test: Add integration tests for execution flows in native Python runner (#19198)

This commit is contained in:
Iván Ovejero
2025-09-05 10:49:45 +02:00
committed by GitHub
parent 2001397387
commit 36958e3ffa
14 changed files with 1228 additions and 16 deletions

View File

@@ -0,0 +1,186 @@
import asyncio
import json
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from aiohttp import web, web_ws
from src.nanoid import nanoid
from tests.fixtures.test_constants import (
TASK_RESPONSE_WAIT,
LOCAL_TASK_BROKER_PORT,
LOCAL_TASK_BROKER_WS_PATH,
)
TaskId = str
TaskSettings = dict[str, Any]
WebsocketMessage = dict[str, Any]
@dataclass
class ActiveTask:
settings: TaskSettings
class LocalTaskBroker:
def __init__(self):
self.port = LOCAL_TASK_BROKER_PORT
self.app = web.Application()
self.runner: web.AppRunner | None = None
self.site: web.TCPSite | None = None
self.connections: dict[str, web_ws.WebSocketResponse] = {}
self.pending_messages: dict[str, asyncio.Queue[WebsocketMessage]] = {}
self.received_messages: list[WebsocketMessage] = []
self.active_tasks: dict[TaskId, ActiveTask] = {}
self.task_settings: dict[TaskId, TaskSettings] = {}
self.rpc_messages: dict[TaskId, list[dict]] = {}
self.app.router.add_get(LOCAL_TASK_BROKER_WS_PATH, self.websocket_handler)
async def start(self) -> None:
self.runner = web.AppRunner(self.app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, "localhost", self.port)
await self.site.start()
print(f"Local task broker started on port {self.port}")
async def stop(self) -> None:
for ws in self.connections.values():
await ws.close()
self.connections.clear()
if self.site:
await self.site.stop()
if self.runner:
await self.runner.cleanup()
async def websocket_handler(self, request: web.Request) -> web_ws.WebSocketResponse:
print(f"WebSocket connection request from {request.remote}")
ws = web_ws.WebSocketResponse()
await ws.prepare(request)
connection_id = nanoid()
self.connections[connection_id] = ws
self.pending_messages[connection_id] = asyncio.Queue()
sender_coroutine = asyncio.create_task(self._message_sender(connection_id, ws))
try:
await self.send_to_connection(connection_id, {"type": "broker:inforequest"})
async for message in ws:
if message.type == web_ws.WSMsgType.TEXT:
json_message = json.loads(message.data)
self.received_messages.append(json_message)
await self._handle_message(connection_id, json_message)
finally:
sender_coroutine.cancel()
try:
await sender_coroutine
except asyncio.CancelledError:
pass
del self.connections[connection_id]
del self.pending_messages[connection_id]
return ws
async def _message_sender(self, connection_id: str, ws: web_ws.WebSocketResponse):
while True:
message = await self.pending_messages[connection_id].get()
await ws.send_str(json.dumps(message))
async def _handle_message(self, connection_id: str, message: WebsocketMessage):
match message.get("type"):
case "runner:info":
await self.send_to_connection(
connection_id, {"type": "broker:runnerregistered"}
)
case "runner:taskoffer":
pass # Handled by send_task() which waits for them
case "runner:taskaccepted":
task_id = message.get("taskId")
if task_id in self.task_settings:
await self.send_to_connection(
connection_id,
{
"type": "broker:tasksettings",
"taskId": task_id,
"settings": self.task_settings[task_id],
},
)
case "runner:taskdone" | "runner:taskerror":
task_id = message.get("taskId")
if task_id in self.active_tasks:
del self.active_tasks[task_id]
case "runner:rpc":
task_id = message.get("taskId")
if task_id:
if task_id not in self.rpc_messages:
self.rpc_messages[task_id] = []
self.rpc_messages[task_id].append(
{"method": message.get("name"), "params": message.get("params")}
)
async def send_to_connection(self, connection_id: str, message: WebsocketMessage):
if connection_id in self.pending_messages:
await self.pending_messages[connection_id].put(message)
async def send_task(
self,
task_id: TaskId,
task_settings: TaskSettings,
):
self.active_tasks[task_id] = ActiveTask(task_settings)
self.task_settings[task_id] = task_settings
offer = await self.wait_for_msg("runner:taskoffer", timeout=2.0)
if offer:
accept = {
"type": "broker:taskofferaccept",
"taskId": task_id,
"offerId": offer.get("offerId"),
}
if self.connections:
connection = next(iter(self.connections.keys()))
await self.send_to_connection(connection, accept)
async def cancel_task(self, task_id: TaskId, reason: str):
cancel_message = {
"type": "broker:taskcancel",
"taskId": task_id,
"reason": reason,
}
for connection_id in self.connections:
await self.send_to_connection(connection_id, cancel_message)
async def wait_for_msg(
self,
msg_type: str,
timeout: float = TASK_RESPONSE_WAIT,
predicate: Callable[[WebsocketMessage], bool] | None = None,
) -> WebsocketMessage | None:
start_time = asyncio.get_running_loop().time()
while asyncio.get_running_loop().time() - start_time < timeout:
for msg in self.received_messages:
if msg.get("type") == msg_type:
if predicate is None or predicate(msg):
return msg
await asyncio.sleep(0.1)
return None
def get_messages_of_type(self, msg_type: str) -> list[WebsocketMessage]:
return [msg for msg in self.received_messages if msg.get("type") == msg_type]
def get_task_rpc_messages(self, task_id: TaskId) -> list[dict]:
return self.rpc_messages.get(task_id, [])