mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-20 19:32:15 +00:00
test: Add integration tests for execution flows in native Python runner (#19198)
This commit is contained in:
186
packages/@n8n/task-runner-python/tests/fixtures/local_task_broker.py
vendored
Normal file
186
packages/@n8n/task-runner-python/tests/fixtures/local_task_broker.py
vendored
Normal file
@@ -0,0 +1,186 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web, web_ws
|
||||
from src.nanoid import nanoid
|
||||
|
||||
from tests.fixtures.test_constants import (
|
||||
TASK_RESPONSE_WAIT,
|
||||
LOCAL_TASK_BROKER_PORT,
|
||||
LOCAL_TASK_BROKER_WS_PATH,
|
||||
)
|
||||
|
||||
TaskId = str
|
||||
TaskSettings = dict[str, Any]
|
||||
WebsocketMessage = dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveTask:
|
||||
settings: TaskSettings
|
||||
|
||||
|
||||
class LocalTaskBroker:
|
||||
def __init__(self):
|
||||
self.port = LOCAL_TASK_BROKER_PORT
|
||||
self.app = web.Application()
|
||||
self.runner: web.AppRunner | None = None
|
||||
self.site: web.TCPSite | None = None
|
||||
self.connections: dict[str, web_ws.WebSocketResponse] = {}
|
||||
self.pending_messages: dict[str, asyncio.Queue[WebsocketMessage]] = {}
|
||||
self.received_messages: list[WebsocketMessage] = []
|
||||
self.active_tasks: dict[TaskId, ActiveTask] = {}
|
||||
self.task_settings: dict[TaskId, TaskSettings] = {}
|
||||
self.rpc_messages: dict[TaskId, list[dict]] = {}
|
||||
self.app.router.add_get(LOCAL_TASK_BROKER_WS_PATH, self.websocket_handler)
|
||||
|
||||
async def start(self) -> None:
|
||||
self.runner = web.AppRunner(self.app)
|
||||
await self.runner.setup()
|
||||
self.site = web.TCPSite(self.runner, "localhost", self.port)
|
||||
await self.site.start()
|
||||
print(f"Local task broker started on port {self.port}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
for ws in self.connections.values():
|
||||
await ws.close()
|
||||
self.connections.clear()
|
||||
|
||||
if self.site:
|
||||
await self.site.stop()
|
||||
|
||||
if self.runner:
|
||||
await self.runner.cleanup()
|
||||
|
||||
async def websocket_handler(self, request: web.Request) -> web_ws.WebSocketResponse:
|
||||
print(f"WebSocket connection request from {request.remote}")
|
||||
ws = web_ws.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
connection_id = nanoid()
|
||||
self.connections[connection_id] = ws
|
||||
self.pending_messages[connection_id] = asyncio.Queue()
|
||||
|
||||
sender_coroutine = asyncio.create_task(self._message_sender(connection_id, ws))
|
||||
|
||||
try:
|
||||
await self.send_to_connection(connection_id, {"type": "broker:inforequest"})
|
||||
|
||||
async for message in ws:
|
||||
if message.type == web_ws.WSMsgType.TEXT:
|
||||
json_message = json.loads(message.data)
|
||||
self.received_messages.append(json_message)
|
||||
await self._handle_message(connection_id, json_message)
|
||||
finally:
|
||||
sender_coroutine.cancel()
|
||||
try:
|
||||
await sender_coroutine
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
del self.connections[connection_id]
|
||||
del self.pending_messages[connection_id]
|
||||
|
||||
return ws
|
||||
|
||||
async def _message_sender(self, connection_id: str, ws: web_ws.WebSocketResponse):
|
||||
while True:
|
||||
message = await self.pending_messages[connection_id].get()
|
||||
await ws.send_str(json.dumps(message))
|
||||
|
||||
async def _handle_message(self, connection_id: str, message: WebsocketMessage):
|
||||
match message.get("type"):
|
||||
case "runner:info":
|
||||
await self.send_to_connection(
|
||||
connection_id, {"type": "broker:runnerregistered"}
|
||||
)
|
||||
|
||||
case "runner:taskoffer":
|
||||
pass # Handled by send_task() which waits for them
|
||||
|
||||
case "runner:taskaccepted":
|
||||
task_id = message.get("taskId")
|
||||
if task_id in self.task_settings:
|
||||
await self.send_to_connection(
|
||||
connection_id,
|
||||
{
|
||||
"type": "broker:tasksettings",
|
||||
"taskId": task_id,
|
||||
"settings": self.task_settings[task_id],
|
||||
},
|
||||
)
|
||||
|
||||
case "runner:taskdone" | "runner:taskerror":
|
||||
task_id = message.get("taskId")
|
||||
if task_id in self.active_tasks:
|
||||
del self.active_tasks[task_id]
|
||||
|
||||
case "runner:rpc":
|
||||
task_id = message.get("taskId")
|
||||
if task_id:
|
||||
if task_id not in self.rpc_messages:
|
||||
self.rpc_messages[task_id] = []
|
||||
self.rpc_messages[task_id].append(
|
||||
{"method": message.get("name"), "params": message.get("params")}
|
||||
)
|
||||
|
||||
async def send_to_connection(self, connection_id: str, message: WebsocketMessage):
|
||||
if connection_id in self.pending_messages:
|
||||
await self.pending_messages[connection_id].put(message)
|
||||
|
||||
async def send_task(
|
||||
self,
|
||||
task_id: TaskId,
|
||||
task_settings: TaskSettings,
|
||||
):
|
||||
self.active_tasks[task_id] = ActiveTask(task_settings)
|
||||
self.task_settings[task_id] = task_settings
|
||||
|
||||
offer = await self.wait_for_msg("runner:taskoffer", timeout=2.0)
|
||||
|
||||
if offer:
|
||||
accept = {
|
||||
"type": "broker:taskofferaccept",
|
||||
"taskId": task_id,
|
||||
"offerId": offer.get("offerId"),
|
||||
}
|
||||
|
||||
if self.connections:
|
||||
connection = next(iter(self.connections.keys()))
|
||||
await self.send_to_connection(connection, accept)
|
||||
|
||||
async def cancel_task(self, task_id: TaskId, reason: str):
|
||||
cancel_message = {
|
||||
"type": "broker:taskcancel",
|
||||
"taskId": task_id,
|
||||
"reason": reason,
|
||||
}
|
||||
|
||||
for connection_id in self.connections:
|
||||
await self.send_to_connection(connection_id, cancel_message)
|
||||
|
||||
async def wait_for_msg(
|
||||
self,
|
||||
msg_type: str,
|
||||
timeout: float = TASK_RESPONSE_WAIT,
|
||||
predicate: Callable[[WebsocketMessage], bool] | None = None,
|
||||
) -> WebsocketMessage | None:
|
||||
start_time = asyncio.get_running_loop().time()
|
||||
|
||||
while asyncio.get_running_loop().time() - start_time < timeout:
|
||||
for msg in self.received_messages:
|
||||
if msg.get("type") == msg_type:
|
||||
if predicate is None or predicate(msg):
|
||||
return msg
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return None
|
||||
|
||||
def get_messages_of_type(self, msg_type: str) -> list[WebsocketMessage]:
|
||||
return [msg for msg in self.received_messages if msg.get("type") == msg_type]
|
||||
|
||||
def get_task_rpc_messages(self, task_id: TaskId) -> list[dict]:
|
||||
return self.rpc_messages.get(task_id, [])
|
||||
Reference in New Issue
Block a user