diff --git a/packages/@n8n/task-runner-python/src/task_runner.py b/packages/@n8n/task-runner-python/src/task_runner.py
index 8cbc4c5a59..3981917bc4 100644
--- a/packages/@n8n/task-runner-python/src/task_runner.py
+++ b/packages/@n8n/task-runner-python/src/task_runner.py
@@ -89,6 +89,7 @@ class TaskRunner:
         self.idle_coroutine: Optional[asyncio.Task] = None
         self.on_idle_timeout: Optional[Callable[[], Awaitable[None]]] = None
         self.last_activity_time = time.time()
+        self.is_shutting_down = False
 
         self.task_broker_uri = config.task_broker_uri
         websocket_host = urlparse(config.task_broker_uri).netloc
@@ -102,38 +103,42 @@ class TaskRunner:
 
         headers = {"Authorization": f"Bearer {self.config.grant_token}"}
 
-        try:
-            self.websocket_connection = await websockets.connect(
-                self.websocket_url,
-                additional_headers=headers,
-                max_size=self.config.max_payload_size,
-            )
+        while not self.is_shutting_down:
+            try:
+                self.websocket_connection = await websockets.connect(
+                    self.websocket_url,
+                    additional_headers=headers,
+                    max_size=self.config.max_payload_size,
+                )
+                self.logger.info("Connected to broker")
+                await self._listen_for_messages()
 
-            self.logger.info("Connected to broker")
+            except Exception:
+                raise WebsocketConnectionError(self.task_broker_uri)
 
-            await self._listen_for_messages()
+            if not self.is_shutting_down:
+                self.websocket_connection = None
+                self.can_send_offers = False
+                await self._cancel_coroutine(self.offers_coroutine)
+                await self._cancel_coroutine(self.idle_coroutine)
+                await asyncio.sleep(5)
 
-        except Exception:
-            raise WebsocketConnectionError(self.task_broker_uri)
+    async def _cancel_coroutine(self, coroutine: Optional[asyncio.Task]) -> None:
+        if coroutine and not coroutine.done():
+            coroutine.cancel()
+            try:
+                await coroutine
+            except asyncio.CancelledError:
+                pass
 
     # ========== Shutdown ==========
 
     async def stop(self) -> None:
+        self.is_shutting_down = True
         self.can_send_offers = False
 
-        if self.offers_coroutine and not self.offers_coroutine.done():
-            self.offers_coroutine.cancel()
-            try:
-                await self.offers_coroutine
-            except asyncio.CancelledError:
-                pass
-
-        if self.idle_coroutine and not self.idle_coroutine.done():
-            self.idle_coroutine.cancel()
-            try:
-                await self.idle_coroutine
-            except asyncio.CancelledError:
-                pass
+        await self._cancel_coroutine(self.offers_coroutine)
+        await self._cancel_coroutine(self.idle_coroutine)
 
         await self._wait_for_tasks()
         await self._terminate_tasks()
@@ -148,23 +153,25 @@ class TaskRunner:
         if not self.running_tasks:
             return
 
-        self.logger.debug("Waiting for tasks to complete...")
+        timeout = self.config.graceful_shutdown_timeout
+        self.logger.debug(
+            f"Waiting for {len(self.running_tasks)} tasks to complete (timeout: {timeout}s)..."
+        )
 
         start_time = time.time()
-        while (
-            self.running_tasks
-            and (time.time() - start_time) < self.config.graceful_shutdown_timeout
-        ):
+        while self.running_tasks and (time.time() - start_time) < timeout:
             await asyncio.sleep(0.5)
 
         if self.running_tasks:
-            self.logger.warning("Timed out waiting for tasks to complete")
+            self.logger.warning(
+                f"Timed out waiting for {len(self.running_tasks)} tasks to complete"
+            )
 
     async def _terminate_tasks(self):
         if not self.running_tasks:
             return
 
-        self.logger.warning("Terminating tasks...")
+        self.logger.warning(f"Terminating {len(self.running_tasks)} tasks...")
 
         tasks_to_terminate = [
             asyncio.to_thread(self.executor.stop_process, task_state.process)