import asyncio import logging import time from typing import Dict, Optional, Any, Callable, Awaitable from urllib.parse import urlparse import websockets import random from src.config.task_runner_config import TaskRunnerConfig from src.errors import ( NoIdleTimeoutHandlerError, TaskMissingError, WebsocketConnectionError, ) from src.message_types.broker import TaskSettings from src.nanoid import nanoid from src.constants import ( RUNNER_NAME, TASK_REJECTED_REASON_AT_CAPACITY, TASK_REJECTED_REASON_OFFER_EXPIRED, TASK_TYPE_PYTHON, OFFER_INTERVAL, OFFER_VALIDITY, OFFER_VALIDITY_MAX_JITTER, OFFER_VALIDITY_LATENCY_BUFFER, TASK_BROKER_WS_PATH, RPC_BROWSER_CONSOLE_LOG_METHOD, LOG_TASK_COMPLETE, LOG_TASK_CANCEL, LOG_TASK_CANCEL_UNKNOWN, LOG_TASK_CANCEL_WAITING, ) from src.message_types import ( BrokerMessage, RunnerMessage, BrokerInfoRequest, BrokerRunnerRegistered, BrokerTaskOfferAccept, BrokerTaskSettings, BrokerTaskCancel, BrokerRpcResponse, RunnerInfo, RunnerTaskOffer, RunnerTaskAccepted, RunnerTaskRejected, RunnerTaskDone, RunnerTaskError, RunnerRpcCall, ) from src.message_serde import MessageSerde from src.task_state import TaskState, TaskStatus from src.task_executor import TaskExecutor from src.task_analyzer import TaskAnalyzer class TaskOffer: def __init__(self, offer_id: str, valid_until: float): self.offer_id = offer_id self.valid_until = valid_until @property def has_expired(self) -> bool: return time.time() > self.valid_until class TaskRunner: def __init__( self, config: TaskRunnerConfig, ): self.runner_id = nanoid() self.name = RUNNER_NAME self.config = config self.websocket_connection: Optional[Any] = None self.can_send_offers = False self.open_offers: Dict[str, TaskOffer] = {} self.running_tasks: Dict[str, TaskState] = {} self.offers_coroutine: Optional[asyncio.Task] = None self.serde = MessageSerde() self.executor = TaskExecutor() self.analyzer = TaskAnalyzer(config.stdlib_allow, config.external_allow) self.logger = logging.getLogger(__name__) self.idle_coroutine: Optional[asyncio.Task] = None self.on_idle_timeout: Optional[Callable[[], Awaitable[None]]] = None self.last_activity_time = time.time() self.is_shutting_down = False self.task_broker_uri = config.task_broker_uri websocket_host = urlparse(config.task_broker_uri).netloc self.websocket_url = ( f"ws://{websocket_host}{TASK_BROKER_WS_PATH}?id={self.runner_id}" ) async def start(self) -> None: if self.config.is_auto_shutdown_enabled and not self.on_idle_timeout: raise NoIdleTimeoutHandlerError(self.config.auto_shutdown_timeout) headers = {"Authorization": f"Bearer {self.config.grant_token}"} while not self.is_shutting_down: try: self.websocket_connection = await websockets.connect( self.websocket_url, additional_headers=headers, max_size=self.config.max_payload_size, ) self.logger.info("Connected to broker") await self._listen_for_messages() except Exception: raise WebsocketConnectionError(self.task_broker_uri) if not self.is_shutting_down: self.websocket_connection = None self.can_send_offers = False await self._cancel_coroutine(self.offers_coroutine) await self._cancel_coroutine(self.idle_coroutine) await asyncio.sleep(5) async def _cancel_coroutine(self, coroutine: Optional[asyncio.Task]) -> None: if coroutine and not coroutine.done(): coroutine.cancel() try: await coroutine except asyncio.CancelledError: pass # ========== Shutdown ========== async def stop(self) -> None: self.is_shutting_down = True self.can_send_offers = False await self._cancel_coroutine(self.offers_coroutine) await self._cancel_coroutine(self.idle_coroutine) await self._wait_for_tasks() await self._terminate_tasks() if self.websocket_connection: await self.websocket_connection.close() self.logger.info("Disconnected from broker") self.logger.info("Runner stopped") async def _wait_for_tasks(self): if not self.running_tasks: return timeout = self.config.graceful_shutdown_timeout self.logger.debug( f"Waiting for {len(self.running_tasks)} tasks to complete (timeout: {timeout}s)..." ) start_time = time.time() while self.running_tasks and (time.time() - start_time) < timeout: await asyncio.sleep(0.5) if self.running_tasks: self.logger.warning( f"Timed out waiting for {len(self.running_tasks)} tasks to complete" ) async def _terminate_tasks(self): if not self.running_tasks: return self.logger.warning(f"Terminating {len(self.running_tasks)} tasks...") tasks_to_terminate = [ asyncio.to_thread(self.executor.stop_process, task_state.process) for task_state in self.running_tasks.values() if task_state.process ] if tasks_to_terminate: await asyncio.gather(*tasks_to_terminate, return_exceptions=True) self.running_tasks.clear() self.logger.warning("Terminated tasks") # ========== Messages ========== async def _listen_for_messages(self) -> None: if self.websocket_connection is None: raise WebsocketConnectionError(self.task_broker_uri) async for raw_message in self.websocket_connection: try: message = self.serde.deserialize_broker_message(raw_message) await self._handle_message(message) except Exception as e: self.logger.error(f"Error handling message: {e}") async def _handle_message(self, message: BrokerMessage) -> None: match message: case BrokerInfoRequest(): await self._handle_info_request() case BrokerRunnerRegistered(): await self._handle_runner_registered() case BrokerTaskOfferAccept(): await self._handle_task_offer_accept(message) case BrokerTaskSettings(): await self._handle_task_settings(message) case BrokerTaskCancel(): await self._handle_task_cancel(message) case BrokerRpcResponse(): pass # currently only logging, already handled by browser case _: self.logger.warning(f"Unhandled message type: {type(message)}") async def _handle_info_request(self) -> None: response = RunnerInfo(name=self.name, types=[TASK_TYPE_PYTHON]) await self._send_message(response) async def _handle_runner_registered(self) -> None: self.can_send_offers = True self.offers_coroutine = asyncio.create_task(self._send_offers_loop()) self.logger.info("Registered with broker") self._reset_idle_timer() async def _handle_task_offer_accept(self, message: BrokerTaskOfferAccept) -> None: offer = self.open_offers.get(message.offer_id) if offer is None or offer.has_expired: response = RunnerTaskRejected( task_id=message.task_id, reason=TASK_REJECTED_REASON_OFFER_EXPIRED, ) await self._send_message(response) return if len(self.running_tasks) >= self.config.max_concurrency: response = RunnerTaskRejected( task_id=message.task_id, reason=TASK_REJECTED_REASON_AT_CAPACITY, ) await self._send_message(response) return del self.open_offers[message.offer_id] task_state = TaskState(message.task_id) self.running_tasks[message.task_id] = task_state response = RunnerTaskAccepted(task_id=message.task_id) await self._send_message(response) self.logger.info(f"Accepted task {message.task_id}") self._reset_idle_timer() async def _handle_task_settings(self, message: BrokerTaskSettings) -> None: task_state = self.running_tasks.get(message.task_id) if task_state is None: raise TaskMissingError(message.task_id) if task_state.status != TaskStatus.WAITING_FOR_SETTINGS: self.logger.warning( f"Received settings for task but it is already {task_state.status}. Discarding message." ) return task_state.workflow_name = message.settings.workflow_name task_state.workflow_id = message.settings.workflow_id task_state.node_name = message.settings.node_name task_state.node_id = message.settings.node_id task_state.status = TaskStatus.RUNNING asyncio.create_task(self._execute_task(message.task_id, message.settings)) self.logger.info(f"Received task {message.task_id}") async def _execute_task(self, task_id: str, task_settings: TaskSettings) -> None: start_time = time.time() try: task_state = self.running_tasks.get(task_id) if task_state is None: raise TaskMissingError(task_id) self.analyzer.validate(task_settings.code) process, queue = self.executor.create_process( code=task_settings.code, node_mode=task_settings.node_mode, items=task_settings.items, stdlib_allow=self.config.stdlib_allow, external_allow=self.config.external_allow, builtins_deny=self.config.builtins_deny, can_log=task_settings.can_log, ) task_state.process = process result, print_args = self.executor.execute_process( process=process, queue=queue, task_timeout=self.config.task_timeout, continue_on_fail=task_settings.continue_on_fail, ) for print_args_per_call in print_args: await self._send_rpc_message( task_id, RPC_BROWSER_CONSOLE_LOG_METHOD, print_args_per_call ) response = RunnerTaskDone(task_id=task_id, data={"result": result}) await self._send_message(response) self.logger.info( LOG_TASK_COMPLETE.format( task_id=task_id, duration=self._get_duration(start_time), **task_state.context(), ) ) except Exception as e: self.logger.error(f"Task {task_id} failed", exc_info=True) response = RunnerTaskError(task_id=task_id, error={"message": str(e)}) await self._send_message(response) finally: self.running_tasks.pop(task_id, None) self._reset_idle_timer() async def _handle_task_cancel(self, message: BrokerTaskCancel) -> None: task_id = message.task_id task_state = self.running_tasks.get(task_id) if task_state is None: self.logger.warning(LOG_TASK_CANCEL_UNKNOWN.format(task_id=task_id)) return if task_state.status == TaskStatus.WAITING_FOR_SETTINGS: self.running_tasks.pop(task_id, None) self.logger.info(LOG_TASK_CANCEL_WAITING.format(task_id=task_id)) await self._send_offers() return if task_state.status == TaskStatus.RUNNING: task_state.status = TaskStatus.ABORTING self.executor.stop_process(task_state.process) self.logger.info( LOG_TASK_CANCEL.format(task_id=task_id, **task_state.context()) ) async def _send_rpc_message(self, task_id: str, method_name: str, params: list): message = RunnerRpcCall( call_id=nanoid(), task_id=task_id, name=method_name, params=params ) await self._send_message(message) async def _send_message(self, message: RunnerMessage) -> None: if self.websocket_connection is None: raise WebsocketConnectionError(self.task_broker_uri) serialized = self.serde.serialize_runner_message(message) await self.websocket_connection.send(serialized) def _get_duration(self, start_time: float) -> str: elapsed = time.time() - start_time if elapsed < 1: return f"{int(elapsed * 1000)}ms" if elapsed < 60: return f"{int(elapsed)}s" return f"{int(elapsed) // 60}m" # ========== Offers ========== async def _send_offers_loop(self) -> None: while self.can_send_offers: try: await self._send_offers() await asyncio.sleep(OFFER_INTERVAL) except asyncio.CancelledError: break except Exception as e: self.logger.error(f"Error sending offers: {e}") async def _send_offers(self) -> None: if not self.can_send_offers: return expired_offer_ids = [ offer_id for offer_id, offer in self.open_offers.items() if offer.has_expired ] for offer_id in expired_offer_ids: self.open_offers.pop(offer_id, None) offers_to_send = self.config.max_concurrency - ( len(self.open_offers) + len(self.running_tasks) ) for _ in range(offers_to_send): offer_id = nanoid() valid_for_ms = OFFER_VALIDITY + random.randint(0, OFFER_VALIDITY_MAX_JITTER) valid_until = ( time.time() + (valid_for_ms / 1000) + OFFER_VALIDITY_LATENCY_BUFFER ) self.open_offers[offer_id] = TaskOffer(offer_id, valid_until) message = RunnerTaskOffer( offer_id=offer_id, task_type=TASK_TYPE_PYTHON, valid_for=valid_for_ms ) await self._send_message(message) # ========== Inactivity ========== def _reset_idle_timer(self): """Reset idle timer when key event occurs, namely runner registration, task acceptance, and task completion or failure.""" if not self.config.is_auto_shutdown_enabled: return self.last_activity_time = time.time() if self.idle_coroutine and not self.idle_coroutine.done(): self.idle_coroutine.cancel() self.idle_coroutine = asyncio.create_task(self._idle_timer_coroutine()) async def _idle_timer_coroutine(self): try: await asyncio.sleep(self.config.auto_shutdown_timeout) if len(self.running_tasks) > 0: return assert self.on_idle_timeout is not None # validated at start() await self.on_idle_timeout() except asyncio.CancelledError: pass