feat(core): Add execution flow to native Python runner (no-changelog) (#18485)

Iván Ovejero authored on 2025-08-22 12:00:42 +02:00, committed by GitHub
parent 67e75c389d
commit e3772c13d2
19 changed files with 677 additions and 105 deletions


@@ -1,3 +1,8 @@
check:
just lint
just format-check
just typecheck
run:
uv run python -m src.main
@@ -13,8 +18,15 @@ lintfix:
format:
uv run ruff format
format-check:
uv run ruff format --check
test:
@echo "No tests yet"
typecheck:
uv run ty check src/
# For debugging only, start the runner with a manually fetched grant token.
debug:
GRANT_TOKEN=$(curl -s -X POST http://127.0.0.1:5679/runners/auth -H "Content-Type: application/json" -d '{"token":"test"}' | jq -r '.data.token') && N8N_RUNNERS_GRANT_TOKEN="$GRANT_TOKEN" N8N_RUNNERS_HIDE_TASK_OFFER_LOGS=true just run
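For debugging without curl and jq, the same grant-token handshake can be done from Python. A minimal sketch using only the standard library; the endpoint, payload, and env var name mirror the recipe above, and the broker address is the local default:
import json
import os
import urllib.request

request = urllib.request.Request(
    "http://127.0.0.1:5679/runners/auth",
    data=json.dumps({"token": "test"}).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    grant_token = json.loads(response.read())["data"]["token"]

os.environ["N8N_RUNNERS_GRANT_TOKEN"] = grant_token  # then start the runner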


@@ -1,26 +1,48 @@
# Message Types
# Messages
BROKER_INFO_REQUEST = "broker:inforequest"
BROKER_RUNNER_REGISTERED = "broker:runnerregistered"
BROKER_TASK_OFFER_ACCEPT = "broker:taskofferaccept"
BROKER_TASK_SETTINGS = "broker:tasksettings"
BROKER_TASK_CANCEL = "broker:taskcancel"
RUNNER_INFO = "runner:info"
RUNNER_TASK_OFFER = "runner:taskoffer"
RUNNER_TASK_ACCEPTED = "runner:taskaccepted"
RUNNER_TASK_REJECTED = "runner:taskrejected"
RUNNER_TASK_DONE = "runner:taskdone"
RUNNER_TASK_ERROR = "runner:taskerror"
# Task Runner Defaults
# Runner
TASK_TYPE_PYTHON = "python"
DEFAULT_MAX_CONCURRENCY = 5
RUNNER_NAME = "Python Task Runner"
DEFAULT_MAX_CONCURRENCY = 5 # tasks
DEFAULT_MAX_PAYLOAD_SIZE = 1024 * 1024 * 1024 # 1 GiB
DEFAULT_TASK_TIMEOUT = 60 # seconds
OFFER_INTERVAL = 0.25 # 250ms
OFFER_VALIDITY = 5000 # ms
OFFER_VALIDITY_MAX_JITTER = 500 # ms
OFFER_VALIDITY_LATENCY_BUFFER = 0.1 # 100ms
DEFAULT_TASK_BROKER_URI = "http://127.0.0.1:5679"
# Environment Variables
# Executor
EXECUTOR_USER_OUTPUT_KEY = "__n8n_internal_user_output__"
# Broker
DEFAULT_TASK_BROKER_URI = "http://127.0.0.1:5679"
TASK_BROKER_WS_PATH = "/runners/_ws"
# Env vars
ENV_TASK_BROKER_URI = "N8N_RUNNERS_TASK_BROKER_URI"
ENV_GRANT_TOKEN = "N8N_RUNNERS_GRANT_TOKEN"
ENV_MAX_CONCURRENCY = "N8N_RUNNERS_MAX_CONCURRENCY"
ENV_MAX_PAYLOAD_SIZE = "N8N_RUNNERS_MAX_PAYLOAD"
ENV_TASK_TIMEOUT = "N8N_RUNNERS_TASK_TIMEOUT"
ENV_HIDE_TASK_OFFER_LOGS = "N8N_RUNNERS_HIDE_TASK_OFFER_LOGS"
# WebSocket Paths
WS_RUNNERS_PATH = "/runners/_ws"
# Logging
LOG_FORMAT = "%(asctime)s.%(msecs)03d\t%(levelname)s\t%(message)s"
LOG_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"
# Rejection reasons
TASK_REJECTED_REASON_OFFER_EXPIRED = (
"Offer expired - not accepted within validity window"
)
TASK_REJECTED_REASON_AT_CAPACITY = "No open task slots - runner already at capacity"


@@ -0,0 +1,15 @@
from .task_missing_error import TaskMissingError
from .task_result_missing_error import TaskResultMissingError
from .task_process_exit_error import TaskProcessExitError
from .task_runtime_error import TaskRuntimeError
from .task_timeout_error import TaskTimeoutError
from .websocket_connection_error import WebsocketConnectionError
__all__ = [
"TaskMissingError",
"TaskProcessExitError",
"TaskResultMissingError",
"TaskRuntimeError",
"TaskTimeoutError",
"WebsocketConnectionError",
]


@@ -0,0 +1,12 @@
class TaskMissingError(Exception):
"""Raised when attempting to operate on a task that does not exist.
This typically indicates an internal error where the task runner
received a message referencing a task ID that is not currently
being tracked in the runner's running tasks.
"""
def __init__(self, task_id: str):
super().__init__(
f"Failed to find task {task_id}. This is likely an internal error."
)


@@ -0,0 +1,6 @@
class TaskProcessExitError(Exception):
"""Raised when a task subprocess exits with a non-zero exit code."""
def __init__(self, exit_code: int):
super().__init__(f"Process exited with code {exit_code}")
self.exit_code = exit_code


@@ -0,0 +1,11 @@
class TaskResultMissingError(Exception):
"""Raised when a task subprocess exits successfully but returns no result.
This typically indicates an internal error where the subprocess did not
put any data in the communication queue.
"""
def __init__(self):
super().__init__(
"Process completed but returned no result. This is likely an internal error."
)


@@ -0,0 +1,10 @@
from typing import Dict, Any
class TaskRuntimeError(Exception):
"""Raised when user code throws an exception during task execution."""
def __init__(self, error_dict: Dict[str, Any]):
message = error_dict["message"]
super().__init__(message)
self.stack_trace = error_dict.get("stack", "")


@@ -0,0 +1,7 @@
class TaskTimeoutError(Exception):
"""Raised when a task execution takes longer than the timeout limit."""
def __init__(self, task_timeout: int):
message = f"Task execution timed out after {task_timeout} {'second' if task_timeout == 1 else 'seconds'}"
super().__init__(message)
self.task_timeout = task_timeout


@@ -0,0 +1,10 @@
class WebsocketConnectionError(ConnectionError):
"""Raised when the task runner fails to establish a WebSocket connection to the broker.
Common causes include network issues, incorrect broker URI, or the broker service being unavailable.
"""
def __init__(self, broker_uri: str):
super().__init__(
f"Failed to connect to broker. Please check if broker is reachable at: {broker_uri}"
)


@@ -0,0 +1,72 @@
import logging
import os
from .constants import LOG_FORMAT, LOG_TIMESTAMP_FORMAT, ENV_HIDE_TASK_OFFER_LOGS
COLORS = {
"DEBUG": "\033[34m", # blue
"INFO": "\033[32m", # green
"WARNING": "\033[33m", # yellow
"ERROR": "\033[31m", # red
"CRITICAL": "\033[31m", # red
}
RESET = "\033[0m"
class ColorFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.use_colors = os.getenv("NO_COLOR") is None
def format(self, record):
formatted = super().format(record)
if not self.use_colors:
return formatted
parts = formatted.split("\t")
if len(parts) >= 3:
timestamp = parts[0]
level = parts[1]
message = " ".join(parts[2:])
level_color = COLORS.get(record.levelname, "")
if level_color:
level = level_color + level + RESET
message = level_color + message + RESET
formatted = f"{timestamp} {level} {message}"
return formatted
class TaskOfferFilter(logging.Filter):
def __init__(self):
super().__init__()
self.hide_offers = os.getenv(ENV_HIDE_TASK_OFFER_LOGS, "").lower() == "true"
def filter(self, record):
"""Filter out task offers if N8N_RUNNERS_HIDE_TASK_OFFER_LOGS is 'true'."""
return not (self.hide_offers and self._is_task_offer_message(record))
def _is_task_offer_message(self, record):
return (
record.levelname == "DEBUG"
and "websockets" in record.name
and '"runner:taskoffer"' in record.getMessage()
)
def setup_logging():
logger = logging.getLogger()
logger.setLevel(logging.INFO)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(ColorFormatter(LOG_FORMAT, LOG_TIMESTAMP_FORMAT))
stream_handler.addFilter(TaskOfferFilter())
logger.addHandler(stream_handler)
logging.getLogger("websockets.client").setLevel(logging.DEBUG)


@@ -3,35 +3,52 @@ import logging
import os
import sys
from .constants import ENV_TASK_BROKER_URI, ENV_GRANT_TOKEN, DEFAULT_TASK_BROKER_URI
from .task_runner import TaskRunner
os.environ["WEBSOCKETS_MAX_LOG_SIZE"] = "256"
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
from .constants import (
DEFAULT_MAX_CONCURRENCY,
DEFAULT_TASK_TIMEOUT,
ENV_MAX_CONCURRENCY,
ENV_MAX_PAYLOAD_SIZE,
ENV_TASK_BROKER_URI,
ENV_GRANT_TOKEN,
DEFAULT_TASK_BROKER_URI,
DEFAULT_MAX_PAYLOAD_SIZE,
ENV_TASK_TIMEOUT,
)
logger = logging.getLogger(__name__)
from .logs import setup_logging
from .task_runner import TaskRunner, TaskRunnerOpts
async def main():
task_broker_uri = os.getenv(ENV_TASK_BROKER_URI, DEFAULT_TASK_BROKER_URI)
setup_logging()
logger = logging.getLogger(__name__)
logger.info("Starting runner...")
grant_token = os.getenv(ENV_GRANT_TOKEN, "")
if not grant_token:
logger.error(f"{ENV_GRANT_TOKEN} environment variable is required")
sys.exit(1)
runner = TaskRunner(
task_broker_uri=task_broker_uri,
grant_token=grant_token,
opts = TaskRunnerOpts(
grant_token,
os.getenv(ENV_TASK_BROKER_URI, DEFAULT_TASK_BROKER_URI),
int(os.getenv(ENV_MAX_CONCURRENCY, DEFAULT_MAX_CONCURRENCY)),
int(os.getenv(ENV_MAX_PAYLOAD_SIZE, DEFAULT_MAX_PAYLOAD_SIZE)),
int(os.getenv(ENV_TASK_TIMEOUT, DEFAULT_TASK_TIMEOUT)),
)
task_runner = TaskRunner(opts)
try:
await runner.start()
except KeyboardInterrupt:
logger.info("Shutting down...")
await task_runner.start()
except (KeyboardInterrupt, asyncio.CancelledError):
logger.info("Shutting down runner...")
finally:
await runner.stop()
await task_runner.stop()
logger.info("Runner stopped")
if __name__ == "__main__":
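The runner can also be driven programmatically instead of via environment variables. A rough sketch, assuming the src.task_runner module path and the TaskRunnerOpts fields shown in this diff; the token and limits are placeholders:
import asyncio

from src.task_runner import TaskRunner, TaskRunnerOpts

opts = TaskRunnerOpts(
    grant_token="my-grant-token",         # normally from N8N_RUNNERS_GRANT_TOKEN
    task_broker_uri="http://127.0.0.1:5679",
    max_concurrency=5,                    # tasks
    max_payload_size=1024 * 1024 * 1024,  # 1 GiB
    task_timeout=60,                      # seconds
)

async def run() -> None:
    runner = TaskRunner(opts)
    try:
        await runner.start()
    finally:
        await runner.stop()

asyncio.run(run())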


@@ -1,9 +1,14 @@
import json
from dataclasses import asdict
from typing import cast
from .message_types.broker import NodeMode, TaskSettings
from .constants import (
BROKER_INFO_REQUEST,
BROKER_RUNNER_REGISTERED,
BROKER_TASK_CANCEL,
BROKER_TASK_OFFER_ACCEPT,
BROKER_TASK_SETTINGS,
)
from .message_types import (
BrokerMessage,
@@ -11,30 +16,87 @@ from .message_types import (
BrokerInfoRequest,
BrokerRunnerRegistered,
BrokerTaskOfferAccept,
BrokerTaskSettings,
BrokerTaskCancel,
)
class MessageSerde:
"""Handles serialization and deserialization of broker and runner messages."""
NODE_MODE_MAP = {
"runOnceForAllItems": "all_items",
"runOnceForEachItem": "per_item",
}
MESSAGE_TYPE_MAP = {
BROKER_INFO_REQUEST: lambda _: BrokerInfoRequest(),
BROKER_RUNNER_REGISTERED: lambda _: BrokerRunnerRegistered(),
BROKER_TASK_OFFER_ACCEPT: lambda d: BrokerTaskOfferAccept(
task_id=d["taskId"], offer_id=d["offerId"]
def _get_node_mode(node_mode_str: str) -> NodeMode:
if node_mode_str not in NODE_MODE_MAP:
raise ValueError(f"Unknown nodeMode: {node_mode_str}")
return cast(NodeMode, NODE_MODE_MAP[node_mode_str])
def _parse_task_settings(d: dict) -> BrokerTaskSettings:
try:
task_id = d["taskId"]
settings_dict = d["settings"]
code = settings_dict["code"]
node_mode = _get_node_mode(settings_dict["nodeMode"])
continue_on_fail = settings_dict.get("continueOnFail", False)
items = settings_dict["items"]
except KeyError as e:
raise ValueError(f"Missing field in task settings message: {e}")
return BrokerTaskSettings(
task_id=task_id,
settings=TaskSettings(
code=code,
node_mode=node_mode,
continue_on_fail=continue_on_fail,
items=items,
),
}
)
def _parse_task_offer_accept(d: dict) -> BrokerTaskOfferAccept:
try:
task_id = d["taskId"]
offer_id = d["offerId"]
except KeyError as e:
raise ValueError(f"Missing field in task offer acceptance message: {e}")
return BrokerTaskOfferAccept(task_id=task_id, offer_id=offer_id)
def _parse_task_cancel(d: dict) -> BrokerTaskCancel:
try:
task_id = d["taskId"]
reason = d["reason"]
except KeyError as e:
raise ValueError(f"Missing field in task cancel message: {e}")
return BrokerTaskCancel(task_id=task_id, reason=reason)
MESSAGE_TYPE_MAP = {
BROKER_INFO_REQUEST: lambda _: BrokerInfoRequest(),
BROKER_RUNNER_REGISTERED: lambda _: BrokerRunnerRegistered(),
BROKER_TASK_OFFER_ACCEPT: _parse_task_offer_accept,
BROKER_TASK_SETTINGS: _parse_task_settings,
BROKER_TASK_CANCEL: _parse_task_cancel,
}
class MessageSerde:
"""Responsible for deserializing incoming messages and serializing outgoing messages."""
@staticmethod
def deserialize_broker_message(data: str) -> BrokerMessage:
message_dict = json.loads(data)
message_type = message_dict.get("type")
if message_type in MessageSerde.MESSAGE_TYPE_MAP:
return MessageSerde.MESSAGE_TYPE_MAP[message_type](message_dict)
else:
if message_type not in MESSAGE_TYPE_MAP:
raise ValueError(f"Unknown message type: {message_type}")
return MESSAGE_TYPE_MAP[message_type](message_dict)
@staticmethod
def serialize_runner_message(message: RunnerMessage) -> str:
data = asdict(message)


@@ -3,6 +3,8 @@ from .broker import (
BrokerInfoRequest,
BrokerRunnerRegistered,
BrokerTaskOfferAccept,
BrokerTaskSettings,
BrokerTaskCancel,
)
from .runner import (
RunnerMessage,
@@ -10,6 +12,8 @@ from .runner import (
RunnerTaskOffer,
RunnerTaskAccepted,
RunnerTaskRejected,
RunnerTaskDone,
RunnerTaskError,
)
__all__ = [
@@ -17,9 +21,13 @@ __all__ = [
"BrokerInfoRequest",
"BrokerRunnerRegistered",
"BrokerTaskOfferAccept",
"BrokerTaskSettings",
"BrokerTaskCancel",
"RunnerMessage",
"RunnerInfo",
"RunnerTaskOffer",
"RunnerTaskAccepted",
"RunnerTaskRejected",
"RunnerTaskDone",
"RunnerTaskError",
]


@@ -1,26 +1,63 @@
from dataclasses import dataclass
from typing import Literal, Union
from typing import Literal, Union, List, Dict, Any
from ..constants import (
BROKER_INFO_REQUEST,
BROKER_RUNNER_REGISTERED,
BROKER_TASK_CANCEL,
BROKER_TASK_OFFER_ACCEPT,
BROKER_TASK_SETTINGS,
)
@dataclass
class BrokerInfoRequest:
type: Literal["broker:inforequest"] = "broker:inforequest"
type: Literal["broker:inforequest"] = BROKER_INFO_REQUEST
@dataclass
class BrokerRunnerRegistered:
type: Literal["broker:runnerregistered"] = "broker:runnerregistered"
type: Literal["broker:runnerregistered"] = BROKER_RUNNER_REGISTERED
@dataclass
class BrokerTaskOfferAccept:
task_id: str
offer_id: str
type: Literal["broker:taskofferaccept"] = "broker:taskofferaccept"
type: Literal["broker:taskofferaccept"] = BROKER_TASK_OFFER_ACCEPT
NodeMode = Literal["all_items", "per_item"]
Items = List[Dict[str, Any]] # INodeExecutionData[]
@dataclass
class TaskSettings:
code: str
node_mode: NodeMode
continue_on_fail: bool
items: Items
@dataclass
class BrokerTaskSettings:
task_id: str
settings: TaskSettings
type: Literal["broker:tasksettings"] = BROKER_TASK_SETTINGS
@dataclass
class BrokerTaskCancel:
task_id: str
reason: str
type: Literal["broker:taskcancel"] = BROKER_TASK_CANCEL
BrokerMessage = Union[
BrokerInfoRequest,
BrokerRunnerRegistered,
BrokerTaskOfferAccept,
BrokerTaskSettings,
BrokerTaskCancel,
]


@@ -1,12 +1,21 @@
from dataclasses import dataclass
from typing import List, Literal, Union
from typing import List, Literal, Union, Any, Dict
from ..constants import (
RUNNER_INFO,
RUNNER_TASK_ACCEPTED,
RUNNER_TASK_DONE,
RUNNER_TASK_ERROR,
RUNNER_TASK_OFFER,
RUNNER_TASK_REJECTED,
)
@dataclass
class RunnerInfo:
name: str
types: List[str]
type: Literal["runner:info"] = "runner:info"
type: Literal["runner:info"] = RUNNER_INFO
@dataclass
@@ -14,20 +23,34 @@ class RunnerTaskOffer:
offer_id: str
task_type: str
valid_for: int
type: Literal["runner:taskoffer"] = "runner:taskoffer"
type: Literal["runner:taskoffer"] = RUNNER_TASK_OFFER
@dataclass
class RunnerTaskAccepted:
task_id: str
type: Literal["runner:taskaccepted"] = "runner:taskaccepted"
type: Literal["runner:taskaccepted"] = RUNNER_TASK_ACCEPTED
@dataclass
class RunnerTaskRejected:
task_id: str
reason: str
type: Literal["runner:taskrejected"] = "runner:taskrejected"
type: Literal["runner:taskrejected"] = RUNNER_TASK_REJECTED
@dataclass
class RunnerTaskDone:
task_id: str
data: Dict[str, Any]
type: Literal["runner:taskdone"] = RUNNER_TASK_DONE
@dataclass
class RunnerTaskError:
task_id: str
error: Dict[str, Any]
type: Literal["runner:taskerror"] = RUNNER_TASK_ERROR
RunnerMessage = Union[
@@ -35,4 +58,6 @@ RunnerMessage = Union[
RunnerTaskOffer,
RunnerTaskAccepted,
RunnerTaskRejected,
RunnerTaskDone,
RunnerTaskError,
]


@@ -0,0 +1,9 @@
from nanoid.generate import generate
import string
NANOID_CHARSET = string.ascii_uppercase + string.ascii_lowercase + string.digits
NANOID_LENGTH = 21
def nanoid() -> str:
return generate(NANOID_CHARSET, NANOID_LENGTH)


@@ -0,0 +1,133 @@
from queue import Empty
import multiprocessing
import traceback
import textwrap
from .errors import (
TaskResultMissingError,
TaskRuntimeError,
TaskTimeoutError,
TaskProcessExitError,
)
from .message_types.broker import NodeMode, Items
from .constants import EXECUTOR_USER_OUTPUT_KEY
from multiprocessing.context import SpawnProcess
MULTIPROCESSING_CONTEXT = multiprocessing.get_context("spawn")
class TaskExecutor:
"""Responsible for executing Python code tasks in isolated subprocesses."""
@staticmethod
def create_process(code: str, node_mode: NodeMode, items: Items):
"""Create a subprocess for executing a Python code task and a queue for communication."""
fn = (
TaskExecutor._all_items
if node_mode == "all_items"
else TaskExecutor._per_item
)
queue = MULTIPROCESSING_CONTEXT.Queue()
process = MULTIPROCESSING_CONTEXT.Process(target=fn, args=(code, items, queue))
return process, queue
@staticmethod
def execute_process(
process: SpawnProcess,
queue: multiprocessing.Queue,
task_timeout: int,
continue_on_fail: bool,
):
"""Execute a subprocess for a Python code task."""
try:
process.start()
process.join(timeout=task_timeout)
if process.is_alive():
TaskExecutor.stop_process(process)
raise TaskTimeoutError(task_timeout)
if process.exitcode != 0:
assert process.exitcode is not None
raise TaskProcessExitError(process.exitcode)
try:
returned = queue.get_nowait()
except Empty:
raise TaskResultMissingError()
if "error" in returned:
raise TaskRuntimeError(returned["error"])
return returned["result"] or []
except Exception as e:
if continue_on_fail:
return [{"json": {"error": str(e)}}]
raise
@staticmethod
def stop_process(process: SpawnProcess | None):
"""Stop a running subprocess, gracefully else force-killing."""
if process is None or not process.is_alive():
return
process.terminate()
process.join(timeout=1) # 1s grace period
if process.is_alive():
process.kill()
@staticmethod
def _all_items(raw_code: str, items: Items, queue: multiprocessing.Queue):
"""Execute a Python code task in all-items mode."""
try:
code = TaskExecutor._wrap_code(raw_code)
globals = {"__builtins__": __builtins__, "_items": items}
exec(code, globals)
queue.put({"result": globals[EXECUTOR_USER_OUTPUT_KEY]})
except Exception as e:
TaskExecutor._put_error(queue, e)
@staticmethod
def _per_item(raw_code: str, items: Items, queue: multiprocessing.Queue):
"""Execute a Python code task in per-item mode."""
try:
wrapped_code = TaskExecutor._wrap_code(raw_code)
compiled_code = compile(wrapped_code, "<per_item_task_execution>", "exec")
result = []
for index, item in enumerate(items):
globals = {"__builtins__": __builtins__, "_item": item}
exec(compiled_code, globals)
user_output = globals[EXECUTOR_USER_OUTPUT_KEY]
if user_output is None:
continue
user_output["pairedItem"] = {"item": index}
result.append(user_output)
queue.put({"result": result})
except Exception as e:
TaskExecutor._put_error(queue, e)
@staticmethod
def _wrap_code(raw_code: str) -> str:
indented_code = textwrap.indent(raw_code, " ")
return f"def _user_function():\n{indented_code}\n\n{EXECUTOR_USER_OUTPUT_KEY} = _user_function()"
@staticmethod
def _put_error(queue: multiprocessing.Queue, e: Exception):
queue.put({"error": {"message": str(e), "stack": traceback.format_exc()}})


@@ -1,23 +1,27 @@
import asyncio
from dataclasses import dataclass
import logging
import time
from typing import Dict, Optional
from typing import Dict, Optional, Any
from urllib.parse import urlparse
from typing import Any
import websockets
import random
from nanoid import generate as nanoid
from .errors import WebsocketConnectionError, TaskMissingError
from .message_types.broker import TaskSettings
from .nanoid import nanoid
from .constants import (
RUNNER_NAME,
TASK_REJECTED_REASON_AT_CAPACITY,
TASK_REJECTED_REASON_OFFER_EXPIRED,
TASK_TYPE_PYTHON,
DEFAULT_MAX_CONCURRENCY,
DEFAULT_MAX_PAYLOAD_SIZE,
OFFER_INTERVAL,
OFFER_VALIDITY,
OFFER_VALIDITY_MAX_JITTER,
OFFER_VALIDITY_LATENCY_BUFFER,
WS_RUNNERS_PATH,
TASK_BROKER_WS_PATH,
)
from .message_types import (
BrokerMessage,
@@ -25,17 +29,18 @@ from .message_types import (
BrokerInfoRequest,
BrokerRunnerRegistered,
BrokerTaskOfferAccept,
BrokerTaskSettings,
BrokerTaskCancel,
RunnerInfo,
RunnerTaskOffer,
RunnerTaskAccepted,
RunnerTaskRejected,
RunnerTaskDone,
RunnerTaskError,
)
from .message_serde import MessageSerde
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
from .task_state import TaskState, TaskStatus
from .task_executor import TaskExecutor
class TaskOffer:
@@ -48,82 +53,95 @@ class TaskOffer:
return time.time() > self.valid_until
@dataclass()
class TaskRunnerOpts:
grant_token: str
task_broker_uri: str
max_concurrency: int
max_payload_size: int
task_timeout: int
class TaskRunner:
def __init__(
self,
task_broker_uri: str = "http://127.0.0.1:5679",
grant_token: str = "",
opts: TaskRunnerOpts,
):
self.runner_id = nanoid()
self.name = RUNNER_NAME
self.task_broker_uri = task_broker_uri
self.grant_token = grant_token
self.name = "Python Task Runner"
self.max_concurrency = DEFAULT_MAX_CONCURRENCY
self.max_payload_size = DEFAULT_MAX_PAYLOAD_SIZE
self.grant_token = opts.grant_token
self.opts = opts
self.websocket: Optional[Any] = None
self.websocket_connection: Optional[Any] = None
self.can_send_offers = False
self.open_offers: Dict[str, TaskOffer] = {} # offer_id -> TaskOffer
self.running_tasks: Dict[str, str] = {} # task_id -> offer_id
self.open_offers: Dict[str, TaskOffer] = {}
self.running_tasks: Dict[str, TaskState] = {}
self.offers_coroutine: Optional[asyncio.Task] = None
self.serde = MessageSerde()
self.executor = TaskExecutor()
self.logger = logging.getLogger(__name__)
ws_host = urlparse(task_broker_uri).netloc
self.ws_url = f"ws://{ws_host}{WS_RUNNERS_PATH}?id={self.runner_id}"
self.task_broker_uri = opts.task_broker_uri
websocket_host = urlparse(opts.task_broker_uri).netloc
self.websocket_url = (
f"ws://{websocket_host}{TASK_BROKER_WS_PATH}?id={self.runner_id}"
)
async def start(self) -> None:
logger.info("Starting Python task runner...")
headers = {"Authorization": f"Bearer {self.grant_token}"}
try:
self.websocket = await websockets.connect(
self.ws_url,
self.websocket_connection = await websockets.connect(
self.websocket_url,
additional_headers=headers,
max_size=self.max_payload_size,
max_size=self.opts.max_payload_size,
)
logger.info(f"Connected to task broker at {self.ws_url}")
self.logger.info("Connected to broker")
await self._listen_for_messages()
except Exception as e:
logger.error(f"Failed to connect to task broker: {e}")
raise
except Exception:
raise WebsocketConnectionError(self.task_broker_uri)
async def stop(self) -> None:
logger.info("Stopping Python task runner...")
if self.offers_coroutine:
self.offers_coroutine.cancel()
if self.websocket:
await self.websocket.close()
if self.websocket_connection:
await self.websocket_connection.close()
self.logger.info("Disconnected from broker")
# ========== Messages ==========
async def _listen_for_messages(self) -> None:
if self.websocket is None:
raise RuntimeError("WebSocket not connected")
if self.websocket_connection is None:
raise WebsocketConnectionError(self.task_broker_uri)
async for raw_message in self.websocket:
async for raw_message in self.websocket_connection:
try:
message = MessageSerde.deserialize_broker_message(raw_message)
message = self.serde.deserialize_broker_message(raw_message)
await self._handle_message(message)
except Exception as e:
logger.error(f"Error handling message: {e}")
self.logger.error(f"Error handling message: {e}")
async def _handle_message(self, message: BrokerMessage) -> None:
if isinstance(message, BrokerInfoRequest):
await self._handle_info_request()
elif isinstance(message, BrokerRunnerRegistered):
await self._handle_runner_registered()
elif isinstance(message, BrokerTaskOfferAccept):
await self._handle_task_offer_accept(message)
else:
logger.warning(f"Unhandled message type: {type(message)}")
match message:
case BrokerInfoRequest():
await self._handle_info_request()
case BrokerRunnerRegistered():
await self._handle_runner_registered()
case BrokerTaskOfferAccept():
await self._handle_task_offer_accept(message)
case BrokerTaskSettings():
await self._handle_task_settings(message)
case BrokerTaskCancel():
await self._handle_task_cancel(message)
case _:
self.logger.warning(f"Unhandled message type: {type(message)}")
async def _handle_info_request(self) -> None:
response = RunnerInfo(name=self.name, types=[TASK_TYPE_PYTHON])
@@ -132,38 +150,103 @@ class TaskRunner:
async def _handle_runner_registered(self) -> None:
self.can_send_offers = True
self.offers_coroutine = asyncio.create_task(self._send_offers_loop())
self.logger.info("Registered with broker")
async def _handle_task_offer_accept(self, message: BrokerTaskOfferAccept) -> None:
offer = self.open_offers.get(message.offer_id)
if not offer or offer.has_expired:
if offer is None or offer.has_expired:
response = RunnerTaskRejected(
task_id=message.task_id,
reason="Offer expired - not accepted within validity window",
reason=TASK_REJECTED_REASON_OFFER_EXPIRED,
)
await self._send_message(response)
return
if len(self.running_tasks) >= self.max_concurrency:
if len(self.running_tasks) >= self.opts.max_concurrency:
response = RunnerTaskRejected(
task_id=message.task_id,
reason="No open task slots - runner already at capacity",
reason=TASK_REJECTED_REASON_AT_CAPACITY,
)
await self._send_message(response)
return
del self.open_offers[message.offer_id]
self.running_tasks[message.task_id] = message.offer_id
task_state = TaskState(message.task_id)
self.running_tasks[message.task_id] = task_state
response = RunnerTaskAccepted(task_id=message.task_id)
await self._send_message(response)
self.logger.info(f"Accepted task {message.task_id}")
async def _handle_task_settings(self, message: BrokerTaskSettings) -> None:
task_state = self.running_tasks.get(message.task_id)
if task_state is None:
raise TaskMissingError(message.task_id)
if task_state.status != TaskStatus.WAITING_FOR_SETTINGS:
self.logger.warning(
f"Received settings for task but it is already {task_state.status}. Discarding message."
)
return
task_state.status = TaskStatus.RUNNING
asyncio.create_task(self._execute_task(message.task_id, message.settings))
self.logger.info(f"Received task {message.task_id}")
async def _execute_task(self, task_id: str, task_settings: TaskSettings) -> None:
try:
task_state = self.running_tasks.get(task_id)
if task_state is None:
raise TaskMissingError(task_id)
process, queue = self.executor.create_process(
task_settings.code, task_settings.node_mode, task_settings.items
)
task_state.process = process
result = self.executor.execute_process(
process, queue, self.opts.task_timeout, task_settings.continue_on_fail
)
response = RunnerTaskDone(task_id=task_id, data={"result": result})
await self._send_message(response)
self.logger.info(f"Completed task {task_id}")
except Exception as e:
response = RunnerTaskError(task_id=task_id, error={"message": str(e)})
await self._send_message(response)
finally:
self.running_tasks.pop(task_id, None)
async def _handle_task_cancel(self, message: BrokerTaskCancel) -> None:
task_state = self.running_tasks.get(message.task_id)
if task_state is None:
self.logger.warning(
f"Received cancel for unknown task: {message.task_id}. Discarding message."
)
return
if task_state.status == TaskStatus.WAITING_FOR_SETTINGS:
self.running_tasks.pop(message.task_id, None)
await self._send_offers()
return
if task_state.status == TaskStatus.RUNNING:
task_state.status = TaskStatus.ABORTING
self.executor.stop_process(task_state.process)
async def _send_message(self, message: RunnerMessage) -> None:
if not self.websocket:
raise RuntimeError("WebSocket not connected")
if self.websocket_connection is None:
raise WebsocketConnectionError(self.task_broker_uri)
serialized = MessageSerde.serialize_runner_message(message)
await self.websocket.send(serialized)
serialized = self.serde.serialize_runner_message(message)
await self.websocket_connection.send(serialized)
# ========== Offers ==========
@@ -175,7 +258,7 @@ class TaskRunner:
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error sending offers: {e}")
self.logger.error(f"Error sending offers: {e}")
async def _send_offers(self) -> None:
if not self.can_send_offers:
@@ -188,9 +271,9 @@ class TaskRunner:
]
for offer_id in expired_offer_ids:
del self.open_offers[offer_id]
self.open_offers.pop(offer_id, None)
offers_to_send = self.max_concurrency - (
offers_to_send = self.opts.max_concurrency - (
len(self.open_offers) + len(self.running_tasks)
)
@@ -203,8 +286,7 @@ class TaskRunner:
time.time() + (valid_for_ms / 1000) + OFFER_VALIDITY_LATENCY_BUFFER
)
offer = TaskOffer(offer_id, valid_until)
self.open_offers[offer_id] = offer
self.open_offers[offer_id] = TaskOffer(offer_id, valid_until)
message = RunnerTaskOffer(
offer_id=offer_id, task_type=TASK_TYPE_PYTHON, valid_for=valid_for_ms
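The offers loop above leans on the timing constants from constants.py. A hedged sketch of how an offer's validity window is computed; only the latency-buffer line appears verbatim in this hunk, and the jitter formula is an assumption based on the OFFER_VALIDITY_MAX_JITTER constant and the random import:
import random
import time

OFFER_VALIDITY = 5000                # ms the broker honors the offer
OFFER_VALIDITY_MAX_JITTER = 500      # ms, spreads offers from concurrent runners
OFFER_VALIDITY_LATENCY_BUFFER = 0.1  # seconds (100ms)

# Assumed: add jitter so runners at capacity do not re-offer in lockstep.
valid_for_ms = OFFER_VALIDITY + random.randint(0, OFFER_VALIDITY_MAX_JITTER)

# Shown in the diff: the runner keeps the offer slightly longer than the
# broker does, to absorb network latency before expiring it locally.
valid_until = time.time() + (valid_for_ms / 1000) + OFFER_VALIDITY_LATENCY_BUFFER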


@@ -0,0 +1,22 @@
from enum import Enum
from dataclasses import dataclass
from multiprocessing.context import SpawnProcess
from typing import Optional
class TaskStatus(Enum):
WAITING_FOR_SETTINGS = "waiting_for_settings"
RUNNING = "running"
ABORTING = "aborting"
@dataclass
class TaskState:
task_id: str
status: TaskStatus
process: Optional[SpawnProcess] = None
def __init__(self, task_id: str):
self.task_id = task_id
self.status = TaskStatus.WAITING_FOR_SETTINGS
self.process = None
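For context, the transitions TaskRunner drives through these states (settings arriving moves a task to RUNNING, a cancel on a running task moves it to ABORTING); a compact sketch assuming the src.task_state module path:
from src.task_state import TaskState, TaskStatus

state = TaskState("task-1")         # created on offer acceptance
assert state.status is TaskStatus.WAITING_FOR_SETTINGS

state.status = TaskStatus.RUNNING   # set when broker:tasksettings arrives
state.status = TaskStatus.ABORTING  # set when broker:taskcancel hits a running task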