mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-17 01:56:46 +00:00
366 lines
12 KiB
Python
366 lines
12 KiB
Python
import asyncio
|
|
from dataclasses import dataclass
|
|
import logging
|
|
import time
|
|
from typing import Dict, Optional, Any, Set
|
|
from urllib.parse import urlparse
|
|
import websockets
|
|
import random
|
|
|
|
|
|
from src.errors import (
|
|
WebsocketConnectionError,
|
|
TaskMissingError,
|
|
)
|
|
from src.message_types.broker import TaskSettings
|
|
from src.nanoid import nanoid
|
|
|
|
from src.constants import (
|
|
RUNNER_NAME,
|
|
TASK_REJECTED_REASON_AT_CAPACITY,
|
|
TASK_REJECTED_REASON_OFFER_EXPIRED,
|
|
TASK_TYPE_PYTHON,
|
|
OFFER_INTERVAL,
|
|
OFFER_VALIDITY,
|
|
OFFER_VALIDITY_MAX_JITTER,
|
|
OFFER_VALIDITY_LATENCY_BUFFER,
|
|
TASK_BROKER_WS_PATH,
|
|
RPC_BROWSER_CONSOLE_LOG_METHOD,
|
|
LOG_TASK_COMPLETE,
|
|
LOG_TASK_CANCEL,
|
|
LOG_TASK_CANCEL_UNKNOWN,
|
|
LOG_TASK_CANCEL_WAITING,
|
|
)
|
|
from src.message_types import (
|
|
BrokerMessage,
|
|
RunnerMessage,
|
|
BrokerInfoRequest,
|
|
BrokerRunnerRegistered,
|
|
BrokerTaskOfferAccept,
|
|
BrokerTaskSettings,
|
|
BrokerTaskCancel,
|
|
BrokerRpcResponse,
|
|
RunnerInfo,
|
|
RunnerTaskOffer,
|
|
RunnerTaskAccepted,
|
|
RunnerTaskRejected,
|
|
RunnerTaskDone,
|
|
RunnerTaskError,
|
|
RunnerRpcCall,
|
|
)
|
|
from src.message_serde import MessageSerde
|
|
from src.task_state import TaskState, TaskStatus
|
|
from src.task_executor import TaskExecutor
|
|
from src.task_analyzer import TaskAnalyzer
|
|
|
|
|
|
class TaskOffer:
    """An offer this runner has advertised to the broker.

    The offer is identified by ``offer_id`` and remains acceptable until the
    wall-clock timestamp ``valid_until``, expressed in the same units as
    ``time.time()`` (seconds since the epoch).
    """

    def __init__(self, offer_id: str, valid_until: float):
        self.offer_id = offer_id
        self.valid_until = valid_until

    @property
    def has_expired(self) -> bool:
        """True once the current time is strictly later than ``valid_until``."""
        now = time.time()
        return self.valid_until < now
|
|
|
|
|
|
@dataclass
class TaskRunnerOpts:
    """Configuration values consumed by ``TaskRunner``."""

    # Sent to the broker as ``Authorization: Bearer <grant_token>``.
    grant_token: str
    # Broker URI; only its network location (host[:port]) is used to build
    # the websocket URL.
    task_broker_uri: str
    # Upper bound on open offers plus running tasks.
    max_concurrency: int
    # Passed to the websocket connection as ``max_size``.
    max_payload_size: int
    # Forwarded to the executor when running a task.
    task_timeout: int
    # Standard-library modules task code may import.
    stdlib_allow: set[str]
    # Third-party modules task code may import.
    external_allow: set[str]
    # Builtins to deny inside task code.
    builtins_deny: set[str]
|
|
|
|
|
|
class TaskRunner:
|
|
def __init__(
|
|
self,
|
|
opts: TaskRunnerOpts,
|
|
):
|
|
self.runner_id = nanoid()
|
|
self.name = RUNNER_NAME
|
|
|
|
self.grant_token = opts.grant_token
|
|
self.opts = opts
|
|
|
|
self.websocket_connection: Optional[Any] = None
|
|
self.can_send_offers = False
|
|
|
|
self.open_offers: Dict[str, TaskOffer] = {}
|
|
self.running_tasks: Dict[str, TaskState] = {}
|
|
|
|
self.offers_coroutine: Optional[asyncio.Task] = None
|
|
self.serde = MessageSerde()
|
|
self.executor = TaskExecutor()
|
|
self.analyzer = TaskAnalyzer(opts.stdlib_allow, opts.external_allow)
|
|
self.logger = logging.getLogger(__name__)
|
|
|
|
self.task_broker_uri = opts.task_broker_uri
|
|
websocket_host = urlparse(opts.task_broker_uri).netloc
|
|
self.websocket_url = (
|
|
f"ws://{websocket_host}{TASK_BROKER_WS_PATH}?id={self.runner_id}"
|
|
)
|
|
|
|
async def start(self) -> None:
|
|
headers = {"Authorization": f"Bearer {self.grant_token}"}
|
|
|
|
try:
|
|
self.websocket_connection = await websockets.connect(
|
|
self.websocket_url,
|
|
additional_headers=headers,
|
|
max_size=self.opts.max_payload_size,
|
|
)
|
|
|
|
self.logger.info("Connected to broker")
|
|
|
|
await self._listen_for_messages()
|
|
|
|
except Exception:
|
|
raise WebsocketConnectionError(self.task_broker_uri)
|
|
|
|
async def stop(self) -> None:
|
|
if self.offers_coroutine:
|
|
self.offers_coroutine.cancel()
|
|
|
|
if self.websocket_connection:
|
|
await self.websocket_connection.close()
|
|
self.logger.info("Disconnected from broker")
|
|
|
|
self.logger.info("Runner stopped")
|
|
|
|
# ========== Messages ==========
|
|
|
|
async def _listen_for_messages(self) -> None:
|
|
if self.websocket_connection is None:
|
|
raise WebsocketConnectionError(self.task_broker_uri)
|
|
|
|
async for raw_message in self.websocket_connection:
|
|
try:
|
|
message = self.serde.deserialize_broker_message(raw_message)
|
|
await self._handle_message(message)
|
|
except Exception as e:
|
|
self.logger.error(f"Error handling message: {e}")
|
|
|
|
async def _handle_message(self, message: BrokerMessage) -> None:
|
|
match message:
|
|
case BrokerInfoRequest():
|
|
await self._handle_info_request()
|
|
case BrokerRunnerRegistered():
|
|
await self._handle_runner_registered()
|
|
case BrokerTaskOfferAccept():
|
|
await self._handle_task_offer_accept(message)
|
|
case BrokerTaskSettings():
|
|
await self._handle_task_settings(message)
|
|
case BrokerTaskCancel():
|
|
await self._handle_task_cancel(message)
|
|
case BrokerRpcResponse():
|
|
pass # currently only logging, already handled by browser
|
|
case _:
|
|
self.logger.warning(f"Unhandled message type: {type(message)}")
|
|
|
|
async def _handle_info_request(self) -> None:
|
|
response = RunnerInfo(name=self.name, types=[TASK_TYPE_PYTHON])
|
|
await self._send_message(response)
|
|
|
|
async def _handle_runner_registered(self) -> None:
|
|
self.can_send_offers = True
|
|
self.offers_coroutine = asyncio.create_task(self._send_offers_loop())
|
|
self.logger.info("Registered with broker")
|
|
|
|
async def _handle_task_offer_accept(self, message: BrokerTaskOfferAccept) -> None:
|
|
offer = self.open_offers.get(message.offer_id)
|
|
|
|
if offer is None or offer.has_expired:
|
|
response = RunnerTaskRejected(
|
|
task_id=message.task_id,
|
|
reason=TASK_REJECTED_REASON_OFFER_EXPIRED,
|
|
)
|
|
await self._send_message(response)
|
|
return
|
|
|
|
if len(self.running_tasks) >= self.opts.max_concurrency:
|
|
response = RunnerTaskRejected(
|
|
task_id=message.task_id,
|
|
reason=TASK_REJECTED_REASON_AT_CAPACITY,
|
|
)
|
|
await self._send_message(response)
|
|
return
|
|
|
|
del self.open_offers[message.offer_id]
|
|
|
|
task_state = TaskState(message.task_id)
|
|
self.running_tasks[message.task_id] = task_state
|
|
|
|
response = RunnerTaskAccepted(task_id=message.task_id)
|
|
await self._send_message(response)
|
|
self.logger.info(f"Accepted task {message.task_id}")
|
|
|
|
async def _handle_task_settings(self, message: BrokerTaskSettings) -> None:
|
|
task_state = self.running_tasks.get(message.task_id)
|
|
if task_state is None:
|
|
raise TaskMissingError(message.task_id)
|
|
|
|
if task_state.status != TaskStatus.WAITING_FOR_SETTINGS:
|
|
self.logger.warning(
|
|
f"Received settings for task but it is already {task_state.status}. Discarding message."
|
|
)
|
|
return
|
|
|
|
task_state.workflow_name = message.settings.workflow_name
|
|
task_state.workflow_id = message.settings.workflow_id
|
|
task_state.node_name = message.settings.node_name
|
|
task_state.node_id = message.settings.node_id
|
|
|
|
task_state.status = TaskStatus.RUNNING
|
|
asyncio.create_task(self._execute_task(message.task_id, message.settings))
|
|
self.logger.info(f"Received task {message.task_id}")
|
|
|
|
async def _execute_task(self, task_id: str, task_settings: TaskSettings) -> None:
|
|
start_time = time.time()
|
|
|
|
try:
|
|
task_state = self.running_tasks.get(task_id)
|
|
|
|
if task_state is None:
|
|
raise TaskMissingError(task_id)
|
|
|
|
self.analyzer.validate(task_settings.code)
|
|
|
|
process, queue = self.executor.create_process(
|
|
code=task_settings.code,
|
|
node_mode=task_settings.node_mode,
|
|
items=task_settings.items,
|
|
stdlib_allow=self.opts.stdlib_allow,
|
|
external_allow=self.opts.external_allow,
|
|
builtins_deny=self.opts.builtins_deny,
|
|
)
|
|
|
|
task_state.process = process
|
|
|
|
result, print_args = self.executor.execute_process(
|
|
process=process,
|
|
queue=queue,
|
|
task_timeout=self.opts.task_timeout,
|
|
continue_on_fail=task_settings.continue_on_fail,
|
|
)
|
|
|
|
for print_args_per_call in print_args:
|
|
await self._send_rpc_message(
|
|
task_id, RPC_BROWSER_CONSOLE_LOG_METHOD, print_args_per_call
|
|
)
|
|
|
|
response = RunnerTaskDone(task_id=task_id, data={"result": result})
|
|
await self._send_message(response)
|
|
|
|
self.logger.info(
|
|
LOG_TASK_COMPLETE.format(
|
|
task_id=task_id,
|
|
duration=self._get_duration(start_time),
|
|
**task_state.context(),
|
|
)
|
|
)
|
|
|
|
except Exception as e:
|
|
response = RunnerTaskError(task_id=task_id, error={"message": str(e)})
|
|
await self._send_message(response)
|
|
|
|
finally:
|
|
self.running_tasks.pop(task_id, None)
|
|
|
|
async def _handle_task_cancel(self, message: BrokerTaskCancel) -> None:
|
|
task_id = message.task_id
|
|
task_state = self.running_tasks.get(task_id)
|
|
|
|
if task_state is None:
|
|
self.logger.warning(LOG_TASK_CANCEL_UNKNOWN.format(task_id=task_id))
|
|
return
|
|
|
|
if task_state.status == TaskStatus.WAITING_FOR_SETTINGS:
|
|
self.running_tasks.pop(task_id, None)
|
|
self.logger.info(LOG_TASK_CANCEL_WAITING.format(task_id=task_id))
|
|
await self._send_offers()
|
|
return
|
|
|
|
if task_state.status == TaskStatus.RUNNING:
|
|
task_state.status = TaskStatus.ABORTING
|
|
self.executor.stop_process(task_state.process)
|
|
|
|
self.logger.info(
|
|
LOG_TASK_CANCEL.format(task_id=task_id, **task_state.context())
|
|
)
|
|
|
|
async def _send_rpc_message(self, task_id: str, method_name: str, params: list):
|
|
message = RunnerRpcCall(
|
|
call_id=nanoid(), task_id=task_id, name=method_name, params=params
|
|
)
|
|
|
|
await self._send_message(message)
|
|
|
|
async def _send_message(self, message: RunnerMessage) -> None:
|
|
if self.websocket_connection is None:
|
|
raise WebsocketConnectionError(self.task_broker_uri)
|
|
|
|
serialized = self.serde.serialize_runner_message(message)
|
|
await self.websocket_connection.send(serialized)
|
|
|
|
def _get_duration(self, start_time: float) -> str:
|
|
elapsed = time.time() - start_time
|
|
|
|
if elapsed < 1:
|
|
return f"{int(elapsed * 1000)}ms"
|
|
|
|
if elapsed < 60:
|
|
return f"{int(elapsed)}s"
|
|
|
|
return f"{int(elapsed) // 60}m"
|
|
|
|
# ========== Offers ==========
|
|
|
|
async def _send_offers_loop(self) -> None:
|
|
while self.can_send_offers:
|
|
try:
|
|
await self._send_offers()
|
|
await asyncio.sleep(OFFER_INTERVAL)
|
|
except asyncio.CancelledError:
|
|
break
|
|
except Exception as e:
|
|
self.logger.error(f"Error sending offers: {e}")
|
|
|
|
async def _send_offers(self) -> None:
|
|
if not self.can_send_offers:
|
|
return
|
|
|
|
expired_offer_ids = [
|
|
offer_id
|
|
for offer_id, offer in self.open_offers.items()
|
|
if offer.has_expired
|
|
]
|
|
|
|
for offer_id in expired_offer_ids:
|
|
self.open_offers.pop(offer_id, None)
|
|
|
|
offers_to_send = self.opts.max_concurrency - (
|
|
len(self.open_offers) + len(self.running_tasks)
|
|
)
|
|
|
|
for _ in range(offers_to_send):
|
|
offer_id = nanoid()
|
|
|
|
valid_for_ms = OFFER_VALIDITY + random.randint(0, OFFER_VALIDITY_MAX_JITTER)
|
|
|
|
valid_until = (
|
|
time.time() + (valid_for_ms / 1000) + OFFER_VALIDITY_LATENCY_BUFFER
|
|
)
|
|
|
|
self.open_offers[offer_id] = TaskOffer(offer_id, valid_until)
|
|
|
|
message = RunnerTaskOffer(
|
|
offer_id=offer_id, task_type=TASK_TYPE_PYTHON, valid_for=valid_for_ms
|
|
)
|
|
|
|
await self._send_message(message)
|