fix(core): Reconnect native Python runner on ws connection drop (no-changelog) (#19140)

This commit is contained in:
Iván Ovejero
2025-09-03 15:37:03 +02:00
committed by GitHub
parent a9a1c23a67
commit b9aa322c3d

View File

@@ -89,6 +89,7 @@ class TaskRunner:
self.idle_coroutine: Optional[asyncio.Task] = None self.idle_coroutine: Optional[asyncio.Task] = None
self.on_idle_timeout: Optional[Callable[[], Awaitable[None]]] = None self.on_idle_timeout: Optional[Callable[[], Awaitable[None]]] = None
self.last_activity_time = time.time() self.last_activity_time = time.time()
self.is_shutting_down = False
self.task_broker_uri = config.task_broker_uri self.task_broker_uri = config.task_broker_uri
websocket_host = urlparse(config.task_broker_uri).netloc websocket_host = urlparse(config.task_broker_uri).netloc
@@ -102,38 +103,42 @@ class TaskRunner:
headers = {"Authorization": f"Bearer {self.config.grant_token}"} headers = {"Authorization": f"Bearer {self.config.grant_token}"}
try: while not self.is_shutting_down:
self.websocket_connection = await websockets.connect( try:
self.websocket_url, self.websocket_connection = await websockets.connect(
additional_headers=headers, self.websocket_url,
max_size=self.config.max_payload_size, additional_headers=headers,
) max_size=self.config.max_payload_size,
)
self.logger.info("Connected to broker")
await self._listen_for_messages()
self.logger.info("Connected to broker") except Exception:
raise WebsocketConnectionError(self.task_broker_uri)
await self._listen_for_messages() if not self.is_shutting_down:
self.websocket_connection = None
self.can_send_offers = False
await self._cancel_coroutine(self.offers_coroutine)
await self._cancel_coroutine(self.idle_coroutine)
await asyncio.sleep(5)
except Exception: async def _cancel_coroutine(self, coroutine: Optional[asyncio.Task]) -> None:
raise WebsocketConnectionError(self.task_broker_uri) if coroutine and not coroutine.done():
coroutine.cancel()
try:
await coroutine
except asyncio.CancelledError:
pass
# ========== Shutdown ========== # ========== Shutdown ==========
async def stop(self) -> None: async def stop(self) -> None:
self.is_shutting_down = True
self.can_send_offers = False self.can_send_offers = False
if self.offers_coroutine and not self.offers_coroutine.done(): await self._cancel_coroutine(self.offers_coroutine)
self.offers_coroutine.cancel() await self._cancel_coroutine(self.idle_coroutine)
try:
await self.offers_coroutine
except asyncio.CancelledError:
pass
if self.idle_coroutine and not self.idle_coroutine.done():
self.idle_coroutine.cancel()
try:
await self.idle_coroutine
except asyncio.CancelledError:
pass
await self._wait_for_tasks() await self._wait_for_tasks()
await self._terminate_tasks() await self._terminate_tasks()
@@ -148,23 +153,25 @@ class TaskRunner:
if not self.running_tasks: if not self.running_tasks:
return return
self.logger.debug("Waiting for tasks to complete...") timeout = self.config.graceful_shutdown_timeout
self.logger.debug(
f"Waiting for {len(self.running_tasks)} tasks to complete (timeout: {timeout}s)..."
)
start_time = time.time() start_time = time.time()
while ( while self.running_tasks and (time.time() - start_time) < timeout:
self.running_tasks
and (time.time() - start_time) < self.config.graceful_shutdown_timeout
):
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
if self.running_tasks: if self.running_tasks:
self.logger.warning("Timed out waiting for tasks to complete") self.logger.warning(
f"Timed out waiting for {len(self.running_tasks)} tasks to complete"
)
async def _terminate_tasks(self): async def _terminate_tasks(self):
if not self.running_tasks: if not self.running_tasks:
return return
self.logger.warning("Terminating tasks...") self.logger.warning(f"Terminating {len(self.running_tasks)} tasks...")
tasks_to_terminate = [ tasks_to_terminate = [
asyncio.to_thread(self.executor.stop_process, task_state.process) asyncio.to_thread(self.executor.stop_process, task_state.process)