feat(core): Add auth and offers flow to native Python runner (no-changelog) (#18354)

This commit is contained in:
Iván Ovejero
2025-08-15 14:36:42 +02:00
committed by GitHub
parent 3848673921
commit a1280b6bf4
13 changed files with 462 additions and 5 deletions

View File

@@ -0,0 +1,213 @@
import asyncio
import logging
import time
from typing import Dict, Optional
from urllib.parse import urlparse
from typing import Any
import websockets
import random
from nanoid import generate as nanoid
from .constants import (
TASK_TYPE_PYTHON,
DEFAULT_MAX_CONCURRENCY,
DEFAULT_MAX_PAYLOAD_SIZE,
OFFER_INTERVAL,
OFFER_VALIDITY,
OFFER_VALIDITY_MAX_JITTER,
OFFER_VALIDITY_LATENCY_BUFFER,
WS_RUNNERS_PATH,
)
from .message_types import (
BrokerMessage,
RunnerMessage,
BrokerInfoRequest,
BrokerRunnerRegistered,
BrokerTaskOfferAccept,
RunnerInfo,
RunnerTaskOffer,
RunnerTaskAccepted,
RunnerTaskRejected,
)
from .message_serde import MessageSerde
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
class TaskOffer:
def __init__(self, offer_id: str, valid_until: float):
self.offer_id = offer_id
self.valid_until = valid_until
@property
def has_expired(self) -> bool:
return time.time() > self.valid_until
class TaskRunner:
def __init__(
self,
task_broker_uri: str = "http://127.0.0.1:5679",
grant_token: str = "",
):
self.runner_id = nanoid()
self.task_broker_uri = task_broker_uri
self.grant_token = grant_token
self.name = "Python Task Runner"
self.max_concurrency = DEFAULT_MAX_CONCURRENCY
self.max_payload_size = DEFAULT_MAX_PAYLOAD_SIZE
self.websocket: Optional[Any] = None
self.can_send_offers = False
self.open_offers: Dict[str, TaskOffer] = {} # offer_id -> TaskOffer
self.running_tasks: Dict[str, str] = {} # task_id -> offer_id
self.offers_coroutine: Optional[asyncio.Task] = None
ws_host = urlparse(task_broker_uri).netloc
self.ws_url = f"ws://{ws_host}{WS_RUNNERS_PATH}?id={self.runner_id}"
async def start(self) -> None:
logger.info("Starting Python task runner...")
headers = {"Authorization": f"Bearer {self.grant_token}"}
try:
self.websocket = await websockets.connect(
self.ws_url,
additional_headers=headers,
max_size=self.max_payload_size,
)
logger.info(f"Connected to task broker at {self.ws_url}")
await self._listen_for_messages()
except Exception as e:
logger.error(f"Failed to connect to task broker: {e}")
raise
async def stop(self) -> None:
logger.info("Stopping Python task runner...")
if self.offers_coroutine:
self.offers_coroutine.cancel()
if self.websocket:
await self.websocket.close()
# ========== Messages ==========
async def _listen_for_messages(self) -> None:
if self.websocket is None:
raise RuntimeError("WebSocket not connected")
async for raw_message in self.websocket:
try:
message = MessageSerde.deserialize_broker_message(raw_message)
await self._handle_message(message)
except Exception as e:
logger.error(f"Error handling message: {e}")
async def _handle_message(self, message: BrokerMessage) -> None:
if isinstance(message, BrokerInfoRequest):
await self._handle_info_request()
elif isinstance(message, BrokerRunnerRegistered):
await self._handle_runner_registered()
elif isinstance(message, BrokerTaskOfferAccept):
await self._handle_task_offer_accept(message)
else:
logger.warning(f"Unhandled message type: {type(message)}")
async def _handle_info_request(self) -> None:
response = RunnerInfo(name=self.name, types=[TASK_TYPE_PYTHON])
await self._send_message(response)
async def _handle_runner_registered(self) -> None:
self.can_send_offers = True
self.offers_coroutine = asyncio.create_task(self._send_offers_loop())
async def _handle_task_offer_accept(self, message: BrokerTaskOfferAccept) -> None:
offer = self.open_offers.get(message.offer_id)
if not offer or offer.has_expired:
response = RunnerTaskRejected(
task_id=message.task_id,
reason="Offer expired - not accepted within validity window",
)
await self._send_message(response)
return
if len(self.running_tasks) >= self.max_concurrency:
response = RunnerTaskRejected(
task_id=message.task_id,
reason="No open task slots - runner already at capacity",
)
await self._send_message(response)
return
del self.open_offers[message.offer_id]
self.running_tasks[message.task_id] = message.offer_id
response = RunnerTaskAccepted(task_id=message.task_id)
await self._send_message(response)
async def _send_message(self, message: RunnerMessage) -> None:
if not self.websocket:
raise RuntimeError("WebSocket not connected")
serialized = MessageSerde.serialize_runner_message(message)
await self.websocket.send(serialized)
# ========== Offers ==========
async def _send_offers_loop(self) -> None:
while self.can_send_offers:
try:
await self._send_offers()
await asyncio.sleep(OFFER_INTERVAL)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error sending offers: {e}")
async def _send_offers(self) -> None:
if not self.can_send_offers:
return
expired_offer_ids = [
offer_id
for offer_id, offer in self.open_offers.items()
if offer.has_expired
]
for offer_id in expired_offer_ids:
del self.open_offers[offer_id]
offers_to_send = self.max_concurrency - (
len(self.open_offers) + len(self.running_tasks)
)
for _ in range(offers_to_send):
offer_id = nanoid()
valid_for_ms = OFFER_VALIDITY + random.randint(0, OFFER_VALIDITY_MAX_JITTER)
valid_until = (
time.time() + (valid_for_ms / 1000) + OFFER_VALIDITY_LATENCY_BUFFER
)
offer = TaskOffer(offer_id, valid_until)
self.open_offers[offer_id] = offer
message = RunnerTaskOffer(
offer_id=offer_id, task_type=TASK_TYPE_PYTHON, valid_for=valid_for_ms
)
await self._send_message(message)