feat(core): Add execution flow to native Python runner (no-changelog) (#18485)

This commit is contained in:
Iván Ovejero
2025-08-22 12:00:42 +02:00
committed by GitHub
parent 67e75c389d
commit e3772c13d2
19 changed files with 677 additions and 105 deletions

View File

@@ -1,23 +1,27 @@
import asyncio
from dataclasses import dataclass
import logging
import time
from typing import Dict, Optional
from typing import Dict, Optional, Any
from urllib.parse import urlparse
from typing import Any
import websockets
import random
from nanoid import generate as nanoid
from .errors import WebsocketConnectionError, TaskMissingError
from .message_types.broker import TaskSettings
from .nanoid import nanoid
from .constants import (
RUNNER_NAME,
TASK_REJECTED_REASON_AT_CAPACITY,
TASK_REJECTED_REASON_OFFER_EXPIRED,
TASK_TYPE_PYTHON,
DEFAULT_MAX_CONCURRENCY,
DEFAULT_MAX_PAYLOAD_SIZE,
OFFER_INTERVAL,
OFFER_VALIDITY,
OFFER_VALIDITY_MAX_JITTER,
OFFER_VALIDITY_LATENCY_BUFFER,
WS_RUNNERS_PATH,
TASK_BROKER_WS_PATH,
)
from .message_types import (
BrokerMessage,
@@ -25,17 +29,18 @@ from .message_types import (
BrokerInfoRequest,
BrokerRunnerRegistered,
BrokerTaskOfferAccept,
BrokerTaskSettings,
BrokerTaskCancel,
RunnerInfo,
RunnerTaskOffer,
RunnerTaskAccepted,
RunnerTaskRejected,
RunnerTaskDone,
RunnerTaskError,
)
from .message_serde import MessageSerde
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
from .task_state import TaskState, TaskStatus
from .task_executor import TaskExecutor
class TaskOffer:
@@ -48,82 +53,95 @@ class TaskOffer:
return time.time() > self.valid_until
@dataclass
class TaskRunnerOpts:
    """Settings handed to ``TaskRunner`` at construction time."""

    # Sent as the bearer token in the Authorization header when connecting.
    grant_token: str
    # Broker address; its network location is used to build the websocket URL.
    task_broker_uri: str
    # Cap on open offers plus running tasks.
    max_concurrency: int
    # Forwarded as ``max_size`` to the websocket connection.
    max_payload_size: int
    # Forwarded to the task executor as the per-task timeout.
    task_timeout: int
class TaskRunner:
def __init__(
    self,
    opts: TaskRunnerOpts,
):
    """Initialise runner state from ``opts`` without opening a connection.

    Args:
        opts: Grant token, broker URI and limits for this runner.
    """
    self.runner_id = nanoid()
    self.name = RUNNER_NAME

    self.grant_token = opts.grant_token
    self.opts = opts

    # Assigned by ``start()`` once the websocket is open.
    self.websocket_connection: Optional[Any] = None

    # Switched on when the broker confirms registration.
    self.can_send_offers = False
    self.open_offers: Dict[str, TaskOffer] = {}  # offer_id -> TaskOffer
    self.running_tasks: Dict[str, TaskState] = {}  # task_id -> TaskState
    self.offers_coroutine: Optional[asyncio.Task] = None

    self.serde = MessageSerde()
    self.executor = TaskExecutor()
    self.logger = logging.getLogger(__name__)

    self.task_broker_uri = opts.task_broker_uri
    # Only the host[:port] part of the broker URI is reused; the scheme is
    # always ws:// and the runner id travels as a query parameter.
    websocket_host = urlparse(opts.task_broker_uri).netloc
    self.websocket_url = (
        f"ws://{websocket_host}{TASK_BROKER_WS_PATH}?id={self.runner_id}"
    )
async def start(self) -> None:
    """Connect to the task broker and consume its messages.

    Returns once the incoming message stream ends.

    Raises:
        WebsocketConnectionError: If connecting, or anything raised while
            listening, fails. The original exception is attached as the
            cause.
    """
    headers = {"Authorization": f"Bearer {self.grant_token}"}

    try:
        self.websocket_connection = await websockets.connect(
            self.websocket_url,
            additional_headers=headers,
            max_size=self.opts.max_payload_size,
        )
        self.logger.info("Connected to broker")
        await self._listen_for_messages()
    except Exception as e:
        # Chain explicitly so the underlying failure is not lost behind the
        # generic connection error.
        raise WebsocketConnectionError(self.task_broker_uri) from e
async def stop(self) -> None:
    """Cancel the offers loop and close the broker connection, if present."""
    if self.offers_coroutine:
        self.offers_coroutine.cancel()

    if self.websocket_connection:
        await self.websocket_connection.close()
        # NOTE(review): the flattened source does not show whether this log
        # line sits inside the guard; it is kept here so a runner that never
        # connected does not report a disconnect. Confirm against upstream.
        self.logger.info("Disconnected from broker")
# ========== Messages ==========
async def _listen_for_messages(self) -> None:
    """Deserialize and dispatch every message received from the broker.

    A failure while handling one message is logged and does not stop the
    loop.

    Raises:
        WebsocketConnectionError: If no connection has been established.
    """
    if self.websocket_connection is None:
        raise WebsocketConnectionError(self.task_broker_uri)

    async for raw_message in self.websocket_connection:
        try:
            message = self.serde.deserialize_broker_message(raw_message)
            await self._handle_message(message)
        except Exception as e:
            self.logger.error(f"Error handling message: {e}")
async def _handle_message(self, message: BrokerMessage) -> None:
if isinstance(message, BrokerInfoRequest):
await self._handle_info_request()
elif isinstance(message, BrokerRunnerRegistered):
await self._handle_runner_registered()
elif isinstance(message, BrokerTaskOfferAccept):
await self._handle_task_offer_accept(message)
else:
logger.warning(f"Unhandled message type: {type(message)}")
match message:
case BrokerInfoRequest():
await self._handle_info_request()
case BrokerRunnerRegistered():
await self._handle_runner_registered()
case BrokerTaskOfferAccept():
await self._handle_task_offer_accept(message)
case BrokerTaskSettings():
await self._handle_task_settings(message)
case BrokerTaskCancel():
await self._handle_task_cancel(message)
case _:
self.logger.warning(f"Unhandled message type: {type(message)}")
async def _handle_info_request(self) -> None:
response = RunnerInfo(name=self.name, types=[TASK_TYPE_PYTHON])
@@ -132,38 +150,103 @@ class TaskRunner:
async def _handle_runner_registered(self) -> None:
    """Enable offers and start the background offers loop after registration."""
    self.can_send_offers = True
    offers_loop = self._send_offers_loop()
    self.offers_coroutine = asyncio.create_task(offers_loop)
    self.logger.info("Registered with broker")
async def _handle_task_offer_accept(self, message: BrokerTaskOfferAccept) -> None:
    """Accept or reject a task the broker assigned against one of our offers.

    The task is rejected when the offer is unknown or expired, or when the
    runner already has ``max_concurrency`` tasks. Otherwise the offer is
    consumed, the task is tracked in ``running_tasks`` and an acceptance is
    sent.

    Args:
        message: The broker's acceptance, carrying ``offer_id`` and ``task_id``.
    """
    offer = self.open_offers.get(message.offer_id)

    if offer is None or offer.has_expired:
        response = RunnerTaskRejected(
            task_id=message.task_id,
            reason=TASK_REJECTED_REASON_OFFER_EXPIRED,
        )
        await self._send_message(response)
        return

    if len(self.running_tasks) >= self.opts.max_concurrency:
        response = RunnerTaskRejected(
            task_id=message.task_id,
            reason=TASK_REJECTED_REASON_AT_CAPACITY,
        )
        await self._send_message(response)
        return

    del self.open_offers[message.offer_id]

    task_state = TaskState(message.task_id)
    self.running_tasks[message.task_id] = task_state

    response = RunnerTaskAccepted(task_id=message.task_id)
    await self._send_message(response)
    self.logger.info(f"Accepted task {message.task_id}")
async def _handle_task_settings(self, message: BrokerTaskSettings) -> None:
    """Start executing an accepted task once its settings arrive.

    Settings for a task that is not waiting for them are logged and
    discarded.

    Args:
        message: Carries the ``task_id`` and the ``settings`` to execute.

    Raises:
        TaskMissingError: If the task is not tracked in ``running_tasks``.
    """
    task_state = self.running_tasks.get(message.task_id)

    if task_state is None:
        raise TaskMissingError(message.task_id)

    if task_state.status != TaskStatus.WAITING_FOR_SETTINGS:
        self.logger.warning(
            f"Received settings for task but it is already {task_state.status}. Discarding message."
        )
        return

    task_state.status = TaskStatus.RUNNING

    execution = asyncio.create_task(
        self._execute_task(message.task_id, message.settings)
    )
    # The event loop keeps only weak references to tasks, so hold a strong
    # one until the execution finishes. The set is created on first use so
    # this method does not rely on any other initialisation.
    in_flight = getattr(self, "_execution_tasks", None)
    if in_flight is None:
        in_flight = self._execution_tasks = set()
    in_flight.add(execution)
    execution.add_done_callback(in_flight.discard)

    self.logger.info(f"Received task {message.task_id}")
async def _execute_task(self, task_id: str, task_settings: TaskSettings) -> None:
    """Run a task's code in a child process and report the outcome.

    Sends ``RunnerTaskDone`` with the result on success, or
    ``RunnerTaskError`` with the exception message on failure. The task is
    removed from ``running_tasks`` in either case.

    Args:
        task_id: Identifier of the task to run.
        task_settings: Code, node mode, items and failure policy for the run.
    """
    try:
        task_state = self.running_tasks.get(task_id)

        if task_state is None:
            raise TaskMissingError(task_id)

        process, queue = self.executor.create_process(
            task_settings.code, task_settings.node_mode, task_settings.items
        )

        task_state.process = process

        # execute_process is synchronous. Run it in a worker thread so the
        # event loop stays free to receive messages (e.g. a cancel for this
        # task) while the process runs.
        result = await asyncio.to_thread(
            self.executor.execute_process,
            process,
            queue,
            self.opts.task_timeout,
            task_settings.continue_on_fail,
        )

        response = RunnerTaskDone(task_id=task_id, data={"result": result})
        await self._send_message(response)
        self.logger.info(f"Completed task {task_id}")

    except Exception as e:
        response = RunnerTaskError(task_id=task_id, error={"message": str(e)})
        await self._send_message(response)

    finally:
        self.running_tasks.pop(task_id, None)
async def _handle_task_cancel(self, message: BrokerTaskCancel) -> None:
    """React to a broker request to cancel a task.

    Unknown tasks are logged and ignored. A task still waiting for settings
    is dropped and fresh offers are sent. A running task is marked as
    aborting and its process is handed to the executor to be stopped.
    """
    task_id = message.task_id
    task_state = self.running_tasks.get(task_id)

    if task_state is None:
        self.logger.warning(
            f"Received cancel for unknown task: {task_id}. Discarding message."
        )
        return

    current_status = task_state.status

    if current_status == TaskStatus.WAITING_FOR_SETTINGS:
        self.running_tasks.pop(task_id, None)
        await self._send_offers()
    elif current_status == TaskStatus.RUNNING:
        task_state.status = TaskStatus.ABORTING
        self.executor.stop_process(task_state.process)
async def _send_message(self, message: RunnerMessage) -> None:
    """Serialize a runner message and send it over the broker connection.

    Args:
        message: The runner message to send.

    Raises:
        WebsocketConnectionError: If no connection has been established.
    """
    if self.websocket_connection is None:
        raise WebsocketConnectionError(self.task_broker_uri)

    serialized = self.serde.serialize_runner_message(message)
    await self.websocket_connection.send(serialized)
# ========== Offers ==========
@@ -175,7 +258,7 @@ class TaskRunner:
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error sending offers: {e}")
self.logger.error(f"Error sending offers: {e}")
async def _send_offers(self) -> None:
if not self.can_send_offers:
@@ -188,9 +271,9 @@ class TaskRunner:
]
for offer_id in expired_offer_ids:
del self.open_offers[offer_id]
self.open_offers.pop(offer_id, None)
offers_to_send = self.max_concurrency - (
offers_to_send = self.opts.max_concurrency - (
len(self.open_offers) + len(self.running_tasks)
)
@@ -203,8 +286,7 @@ class TaskRunner:
time.time() + (valid_for_ms / 1000) + OFFER_VALIDITY_LATENCY_BUFFER
)
offer = TaskOffer(offer_id, valid_until)
self.open_offers[offer_id] = offer
self.open_offers[offer_id] = TaskOffer(offer_id, valid_until)
message = RunnerTaskOffer(
offer_id=offer_id, task_type=TASK_TYPE_PYTHON, valid_for=valid_for_ms