mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-17 10:02:05 +00:00
feat: Add graceful shutdown and idle timeout to native Python runner (no-changelog) (#19125)
This commit is contained in:
@@ -8,6 +8,8 @@ from src.constants import (
|
|||||||
DEFAULT_MAX_PAYLOAD_SIZE,
|
DEFAULT_MAX_PAYLOAD_SIZE,
|
||||||
DEFAULT_TASK_BROKER_URI,
|
DEFAULT_TASK_BROKER_URI,
|
||||||
DEFAULT_TASK_TIMEOUT,
|
DEFAULT_TASK_TIMEOUT,
|
||||||
|
DEFAULT_AUTO_SHUTDOWN_TIMEOUT,
|
||||||
|
DEFAULT_SHUTDOWN_TIMEOUT,
|
||||||
ENV_BUILTINS_DENY,
|
ENV_BUILTINS_DENY,
|
||||||
ENV_EXTERNAL_ALLOW,
|
ENV_EXTERNAL_ALLOW,
|
||||||
ENV_GRANT_TOKEN,
|
ENV_GRANT_TOKEN,
|
||||||
@@ -16,6 +18,8 @@ from src.constants import (
|
|||||||
ENV_STDLIB_ALLOW,
|
ENV_STDLIB_ALLOW,
|
||||||
ENV_TASK_BROKER_URI,
|
ENV_TASK_BROKER_URI,
|
||||||
ENV_TASK_TIMEOUT,
|
ENV_TASK_TIMEOUT,
|
||||||
|
ENV_AUTO_SHUTDOWN_TIMEOUT,
|
||||||
|
ENV_GRACEFUL_SHUTDOWN_TIMEOUT,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -45,30 +49,54 @@ class TaskRunnerConfig:
|
|||||||
max_concurrency: int
|
max_concurrency: int
|
||||||
max_payload_size: int
|
max_payload_size: int
|
||||||
task_timeout: int
|
task_timeout: int
|
||||||
|
auto_shutdown_timeout: int
|
||||||
|
graceful_shutdown_timeout: int
|
||||||
stdlib_allow: Set[str]
|
stdlib_allow: Set[str]
|
||||||
external_allow: Set[str]
|
external_allow: Set[str]
|
||||||
builtins_deny: Set[str]
|
builtins_deny: Set[str]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_auto_shutdown_enabled(self) -> bool:
|
||||||
|
return self.auto_shutdown_timeout > 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_env(cls):
|
def from_env(cls):
|
||||||
grant_token = os.getenv(ENV_GRANT_TOKEN, "")
|
grant_token = os.getenv(ENV_GRANT_TOKEN, "")
|
||||||
if not grant_token:
|
if not grant_token:
|
||||||
raise ValueError("Environment variable N8N_RUNNERS_GRANT_TOKEN is required")
|
raise ValueError("Environment variable N8N_RUNNERS_GRANT_TOKEN is required")
|
||||||
|
|
||||||
task_timeout = int(os.getenv(ENV_TASK_TIMEOUT, str(DEFAULT_TASK_TIMEOUT)))
|
task_timeout = int(os.getenv(ENV_TASK_TIMEOUT, DEFAULT_TASK_TIMEOUT))
|
||||||
if task_timeout <= 0:
|
if task_timeout <= 0:
|
||||||
raise ValueError(f"Task timeout must be positive, got {task_timeout}")
|
raise ValueError(f"Task timeout must be positive, got {task_timeout}")
|
||||||
|
|
||||||
|
auto_shutdown_timeout = int(
|
||||||
|
os.getenv(ENV_AUTO_SHUTDOWN_TIMEOUT, DEFAULT_AUTO_SHUTDOWN_TIMEOUT)
|
||||||
|
)
|
||||||
|
if auto_shutdown_timeout < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Auto shutdown timeout must be non-negative, got {auto_shutdown_timeout}"
|
||||||
|
)
|
||||||
|
|
||||||
|
graceful_shutdown_timeout = int(
|
||||||
|
os.getenv(ENV_GRACEFUL_SHUTDOWN_TIMEOUT, DEFAULT_SHUTDOWN_TIMEOUT)
|
||||||
|
)
|
||||||
|
if graceful_shutdown_timeout <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Graceful shutdown timeout must be positive, got {graceful_shutdown_timeout}"
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
grant_token=grant_token,
|
grant_token=grant_token,
|
||||||
task_broker_uri=os.getenv(ENV_TASK_BROKER_URI, DEFAULT_TASK_BROKER_URI),
|
task_broker_uri=os.getenv(ENV_TASK_BROKER_URI, DEFAULT_TASK_BROKER_URI),
|
||||||
max_concurrency=int(
|
max_concurrency=int(
|
||||||
os.getenv(ENV_MAX_CONCURRENCY, str(DEFAULT_MAX_CONCURRENCY))
|
os.getenv(ENV_MAX_CONCURRENCY, DEFAULT_MAX_CONCURRENCY)
|
||||||
),
|
),
|
||||||
max_payload_size=int(
|
max_payload_size=int(
|
||||||
os.getenv(ENV_MAX_PAYLOAD_SIZE, str(DEFAULT_MAX_PAYLOAD_SIZE))
|
os.getenv(ENV_MAX_PAYLOAD_SIZE, DEFAULT_MAX_PAYLOAD_SIZE)
|
||||||
),
|
),
|
||||||
task_timeout=task_timeout,
|
task_timeout=task_timeout,
|
||||||
|
auto_shutdown_timeout=auto_shutdown_timeout,
|
||||||
|
graceful_shutdown_timeout=graceful_shutdown_timeout,
|
||||||
stdlib_allow=parse_allowlist(
|
stdlib_allow=parse_allowlist(
|
||||||
os.getenv(ENV_STDLIB_ALLOW, ""), ENV_STDLIB_ALLOW
|
os.getenv(ENV_STDLIB_ALLOW, ""), ENV_STDLIB_ALLOW
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ RUNNER_NAME = "Python Task Runner"
|
|||||||
DEFAULT_MAX_CONCURRENCY = 5 # tasks
|
DEFAULT_MAX_CONCURRENCY = 5 # tasks
|
||||||
DEFAULT_MAX_PAYLOAD_SIZE = 1024 * 1024 * 1024 # 1 GiB
|
DEFAULT_MAX_PAYLOAD_SIZE = 1024 * 1024 * 1024 # 1 GiB
|
||||||
DEFAULT_TASK_TIMEOUT = 60 # seconds
|
DEFAULT_TASK_TIMEOUT = 60 # seconds
|
||||||
|
DEFAULT_AUTO_SHUTDOWN_TIMEOUT = 15 # seconds
|
||||||
|
DEFAULT_SHUTDOWN_TIMEOUT = 10 # seconds
|
||||||
OFFER_INTERVAL = 0.25 # 250ms
|
OFFER_INTERVAL = 0.25 # 250ms
|
||||||
OFFER_VALIDITY = 5000 # ms
|
OFFER_VALIDITY = 5000 # ms
|
||||||
OFFER_VALIDITY_MAX_JITTER = 500 # ms
|
OFFER_VALIDITY_MAX_JITTER = 500 # ms
|
||||||
@@ -46,6 +48,8 @@ ENV_GRANT_TOKEN = "N8N_RUNNERS_GRANT_TOKEN"
|
|||||||
ENV_MAX_CONCURRENCY = "N8N_RUNNERS_MAX_CONCURRENCY"
|
ENV_MAX_CONCURRENCY = "N8N_RUNNERS_MAX_CONCURRENCY"
|
||||||
ENV_MAX_PAYLOAD_SIZE = "N8N_RUNNERS_MAX_PAYLOAD"
|
ENV_MAX_PAYLOAD_SIZE = "N8N_RUNNERS_MAX_PAYLOAD"
|
||||||
ENV_TASK_TIMEOUT = "N8N_RUNNERS_TASK_TIMEOUT"
|
ENV_TASK_TIMEOUT = "N8N_RUNNERS_TASK_TIMEOUT"
|
||||||
|
ENV_AUTO_SHUTDOWN_TIMEOUT = "N8N_RUNNERS_AUTO_SHUTDOWN_TIMEOUT"
|
||||||
|
ENV_GRACEFUL_SHUTDOWN_TIMEOUT = "N8N_RUNNERS_GRACEFUL_SHUTDOWN_TIMEOUT"
|
||||||
ENV_STDLIB_ALLOW = "N8N_RUNNERS_STDLIB_ALLOW"
|
ENV_STDLIB_ALLOW = "N8N_RUNNERS_STDLIB_ALLOW"
|
||||||
ENV_EXTERNAL_ALLOW = "N8N_RUNNERS_EXTERNAL_ALLOW"
|
ENV_EXTERNAL_ALLOW = "N8N_RUNNERS_EXTERNAL_ALLOW"
|
||||||
ENV_BUILTINS_DENY = "N8N_RUNNERS_BUILTINS_DENY"
|
ENV_BUILTINS_DENY = "N8N_RUNNERS_BUILTINS_DENY"
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from .no_idle_timeout_handler_error import NoIdleTimeoutHandlerError
|
||||||
from .security_violation_error import SecurityViolationError
|
from .security_violation_error import SecurityViolationError
|
||||||
from .task_missing_error import TaskMissingError
|
from .task_missing_error import TaskMissingError
|
||||||
from .task_result_missing_error import TaskResultMissingError
|
from .task_result_missing_error import TaskResultMissingError
|
||||||
@@ -7,6 +8,7 @@ from .task_timeout_error import TaskTimeoutError
|
|||||||
from .websocket_connection_error import WebsocketConnectionError
|
from .websocket_connection_error import WebsocketConnectionError
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"NoIdleTimeoutHandlerError",
|
||||||
"SecurityViolationError",
|
"SecurityViolationError",
|
||||||
"TaskMissingError",
|
"TaskMissingError",
|
||||||
"TaskProcessExitError",
|
"TaskProcessExitError",
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
class NoIdleTimeoutHandlerError(Exception):
|
||||||
|
"""Raised when idle timeout is reached but no shutdown handler is configured."""
|
||||||
|
|
||||||
|
def __init__(self, timeout: int):
|
||||||
|
super().__init__(
|
||||||
|
f"Idle timeout is configured ({timeout}s) but no handler is set. "
|
||||||
|
"Set task_runner.on_idle_timeout before calling task_runner.start(). "
|
||||||
|
"This is an internal error."
|
||||||
|
)
|
||||||
@@ -8,6 +8,7 @@ from src.config.sentry_config import SentryConfig
|
|||||||
from src.config.task_runner_config import TaskRunnerConfig
|
from src.config.task_runner_config import TaskRunnerConfig
|
||||||
from src.logs import setup_logging
|
from src.logs import setup_logging
|
||||||
from src.task_runner import TaskRunner
|
from src.task_runner import TaskRunner
|
||||||
|
from src.shutdown import Shutdown
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
@@ -48,21 +49,17 @@ async def main():
|
|||||||
task_runner = TaskRunner(task_runner_config)
|
task_runner = TaskRunner(task_runner_config)
|
||||||
logger.info("Starting runner...")
|
logger.info("Starting runner...")
|
||||||
|
|
||||||
|
shutdown = Shutdown(task_runner, health_check_server, sentry)
|
||||||
|
task_runner.on_idle_timeout = shutdown.start_auto_shutdown
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await task_runner.start()
|
await task_runner.start()
|
||||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
|
||||||
logger.info("Shutting down runner...")
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("Unexpected error", exc_info=True)
|
logger.error("Unexpected error", exc_info=True)
|
||||||
raise
|
await shutdown.start_shutdown()
|
||||||
finally:
|
|
||||||
await task_runner.stop()
|
|
||||||
|
|
||||||
if health_check_server:
|
exit_code = await shutdown.wait_for_shutdown()
|
||||||
await health_check_server.stop()
|
sys.exit(exit_code)
|
||||||
|
|
||||||
if sentry:
|
|
||||||
sentry.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
86
packages/@n8n/task-runner-python/src/shutdown.py
Normal file
86
packages/@n8n/task-runner-python/src/shutdown.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import signal
|
||||||
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.task_runner import TaskRunner
|
||||||
|
from src.health_check_server import HealthCheckServer
|
||||||
|
from src.sentry import TaskRunnerSentry
|
||||||
|
|
||||||
|
|
||||||
|
class Shutdown:
|
||||||
|
"""Responsible for managing the shutdown routine of the task runner."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task_runner: "TaskRunner",
|
||||||
|
health_check_server: Optional["HealthCheckServer"] = None,
|
||||||
|
sentry: Optional["TaskRunnerSentry"] = None,
|
||||||
|
):
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
self.is_shutting_down = False
|
||||||
|
self.shutdown_complete = asyncio.Event()
|
||||||
|
self.exit_code = 0
|
||||||
|
|
||||||
|
self.task_runner = task_runner
|
||||||
|
self.health_check_server = health_check_server
|
||||||
|
self.sentry = sentry
|
||||||
|
|
||||||
|
self._register_handler(signal.SIGINT)
|
||||||
|
self._register_handler(signal.SIGTERM)
|
||||||
|
|
||||||
|
async def start_shutdown(self, custom_timeout: Optional[int] = None):
|
||||||
|
if self.is_shutting_down:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_shutting_down = True
|
||||||
|
|
||||||
|
timeout = (
|
||||||
|
custom_timeout
|
||||||
|
if custom_timeout is not None
|
||||||
|
else self.task_runner.config.graceful_shutdown_timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._perform_shutdown(), timeout=timeout)
|
||||||
|
self.exit_code = 0
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.logger.warning(f"Shutdown timed out after {timeout}s, forcing exit...")
|
||||||
|
self.exit_code = 1
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error during shutdown: {e}", exc_info=True)
|
||||||
|
self.exit_code = 1
|
||||||
|
finally:
|
||||||
|
self.shutdown_complete.set()
|
||||||
|
|
||||||
|
async def wait_for_shutdown(self) -> int:
|
||||||
|
await self.shutdown_complete.wait()
|
||||||
|
return self.exit_code
|
||||||
|
|
||||||
|
def _register_handler(self, sig: signal.Signals):
|
||||||
|
async def handler():
|
||||||
|
self.logger.info(f"Received {sig.name} signal, starting shutdown...")
|
||||||
|
await self.start_shutdown()
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.get_running_loop().add_signal_handler(
|
||||||
|
sig, lambda: asyncio.create_task(handler())
|
||||||
|
)
|
||||||
|
except NotImplementedError:
|
||||||
|
self.logger.warning(
|
||||||
|
f"Signal handler for {sig.name} not supported on this platform"
|
||||||
|
) # e.g. Windows
|
||||||
|
|
||||||
|
async def start_auto_shutdown(self):
|
||||||
|
self.logger.info("Reached idle timeout, starting shutdown...")
|
||||||
|
await self.start_shutdown(3) # no tasks so no grace period
|
||||||
|
|
||||||
|
async def _perform_shutdown(self):
|
||||||
|
await self.task_runner.stop()
|
||||||
|
|
||||||
|
if self.health_check_server:
|
||||||
|
await self.health_check_server.stop()
|
||||||
|
|
||||||
|
if self.sentry:
|
||||||
|
self.sentry.shutdown()
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Optional, Any
|
from typing import Dict, Optional, Any, Callable, Awaitable
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
import websockets
|
import websockets
|
||||||
import random
|
import random
|
||||||
@@ -9,8 +9,9 @@ import random
|
|||||||
|
|
||||||
from src.config.task_runner_config import TaskRunnerConfig
|
from src.config.task_runner_config import TaskRunnerConfig
|
||||||
from src.errors import (
|
from src.errors import (
|
||||||
WebsocketConnectionError,
|
NoIdleTimeoutHandlerError,
|
||||||
TaskMissingError,
|
TaskMissingError,
|
||||||
|
WebsocketConnectionError,
|
||||||
)
|
)
|
||||||
from src.message_types.broker import TaskSettings
|
from src.message_types.broker import TaskSettings
|
||||||
from src.nanoid import nanoid
|
from src.nanoid import nanoid
|
||||||
@@ -85,6 +86,10 @@ class TaskRunner:
|
|||||||
self.analyzer = TaskAnalyzer(config.stdlib_allow, config.external_allow)
|
self.analyzer = TaskAnalyzer(config.stdlib_allow, config.external_allow)
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
self.idle_coroutine: Optional[asyncio.Task] = None
|
||||||
|
self.on_idle_timeout: Optional[Callable[[], Awaitable[None]]] = None
|
||||||
|
self.last_activity_time = time.time()
|
||||||
|
|
||||||
self.task_broker_uri = config.task_broker_uri
|
self.task_broker_uri = config.task_broker_uri
|
||||||
websocket_host = urlparse(config.task_broker_uri).netloc
|
websocket_host = urlparse(config.task_broker_uri).netloc
|
||||||
self.websocket_url = (
|
self.websocket_url = (
|
||||||
@@ -92,6 +97,9 @@ class TaskRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
|
if self.config.is_auto_shutdown_enabled and not self.on_idle_timeout:
|
||||||
|
raise NoIdleTimeoutHandlerError(self.config.auto_shutdown_timeout)
|
||||||
|
|
||||||
headers = {"Authorization": f"Bearer {self.config.grant_token}"}
|
headers = {"Authorization": f"Bearer {self.config.grant_token}"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -108,9 +116,27 @@ class TaskRunner:
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise WebsocketConnectionError(self.task_broker_uri)
|
raise WebsocketConnectionError(self.task_broker_uri)
|
||||||
|
|
||||||
|
# ========== Shutdown ==========
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
if self.offers_coroutine:
|
self.can_send_offers = False
|
||||||
|
|
||||||
|
if self.offers_coroutine and not self.offers_coroutine.done():
|
||||||
self.offers_coroutine.cancel()
|
self.offers_coroutine.cancel()
|
||||||
|
try:
|
||||||
|
await self.offers_coroutine
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self.idle_coroutine and not self.idle_coroutine.done():
|
||||||
|
self.idle_coroutine.cancel()
|
||||||
|
try:
|
||||||
|
await self.idle_coroutine
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await self._wait_for_tasks()
|
||||||
|
await self._terminate_tasks()
|
||||||
|
|
||||||
if self.websocket_connection:
|
if self.websocket_connection:
|
||||||
await self.websocket_connection.close()
|
await self.websocket_connection.close()
|
||||||
@@ -118,6 +144,41 @@ class TaskRunner:
|
|||||||
|
|
||||||
self.logger.info("Runner stopped")
|
self.logger.info("Runner stopped")
|
||||||
|
|
||||||
|
async def _wait_for_tasks(self):
|
||||||
|
if not self.running_tasks:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.logger.debug("Waiting for tasks to complete...")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
while (
|
||||||
|
self.running_tasks
|
||||||
|
and (time.time() - start_time) < self.config.graceful_shutdown_timeout
|
||||||
|
):
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
if self.running_tasks:
|
||||||
|
self.logger.warning("Timed out waiting for tasks to complete")
|
||||||
|
|
||||||
|
async def _terminate_tasks(self):
|
||||||
|
if not self.running_tasks:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.logger.warning("Terminating tasks...")
|
||||||
|
|
||||||
|
tasks_to_terminate = [
|
||||||
|
asyncio.to_thread(self.executor.stop_process, task_state.process)
|
||||||
|
for task_state in self.running_tasks.values()
|
||||||
|
if task_state.process
|
||||||
|
]
|
||||||
|
|
||||||
|
if tasks_to_terminate:
|
||||||
|
await asyncio.gather(*tasks_to_terminate, return_exceptions=True)
|
||||||
|
|
||||||
|
self.running_tasks.clear()
|
||||||
|
|
||||||
|
self.logger.warning("Terminated tasks")
|
||||||
|
|
||||||
# ========== Messages ==========
|
# ========== Messages ==========
|
||||||
|
|
||||||
async def _listen_for_messages(self) -> None:
|
async def _listen_for_messages(self) -> None:
|
||||||
@@ -156,6 +217,7 @@ class TaskRunner:
|
|||||||
self.can_send_offers = True
|
self.can_send_offers = True
|
||||||
self.offers_coroutine = asyncio.create_task(self._send_offers_loop())
|
self.offers_coroutine = asyncio.create_task(self._send_offers_loop())
|
||||||
self.logger.info("Registered with broker")
|
self.logger.info("Registered with broker")
|
||||||
|
self._reset_idle_timer()
|
||||||
|
|
||||||
async def _handle_task_offer_accept(self, message: BrokerTaskOfferAccept) -> None:
|
async def _handle_task_offer_accept(self, message: BrokerTaskOfferAccept) -> None:
|
||||||
offer = self.open_offers.get(message.offer_id)
|
offer = self.open_offers.get(message.offer_id)
|
||||||
@@ -184,6 +246,7 @@ class TaskRunner:
|
|||||||
response = RunnerTaskAccepted(task_id=message.task_id)
|
response = RunnerTaskAccepted(task_id=message.task_id)
|
||||||
await self._send_message(response)
|
await self._send_message(response)
|
||||||
self.logger.info(f"Accepted task {message.task_id}")
|
self.logger.info(f"Accepted task {message.task_id}")
|
||||||
|
self._reset_idle_timer()
|
||||||
|
|
||||||
async def _handle_task_settings(self, message: BrokerTaskSettings) -> None:
|
async def _handle_task_settings(self, message: BrokerTaskSettings) -> None:
|
||||||
task_state = self.running_tasks.get(message.task_id)
|
task_state = self.running_tasks.get(message.task_id)
|
||||||
@@ -258,6 +321,7 @@ class TaskRunner:
|
|||||||
|
|
||||||
finally:
|
finally:
|
||||||
self.running_tasks.pop(task_id, None)
|
self.running_tasks.pop(task_id, None)
|
||||||
|
self._reset_idle_timer()
|
||||||
|
|
||||||
async def _handle_task_cancel(self, message: BrokerTaskCancel) -> None:
|
async def _handle_task_cancel(self, message: BrokerTaskCancel) -> None:
|
||||||
task_id = message.task_id
|
task_id = message.task_id
|
||||||
@@ -351,3 +415,31 @@ class TaskRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
await self._send_message(message)
|
await self._send_message(message)
|
||||||
|
|
||||||
|
# ========== Inactivity ==========
|
||||||
|
|
||||||
|
def _reset_idle_timer(self):
|
||||||
|
"""Reset idle timer when key event occurs, namely runner registration, task acceptance, and task completion or failure."""
|
||||||
|
|
||||||
|
if not self.config.is_auto_shutdown_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.last_activity_time = time.time()
|
||||||
|
|
||||||
|
if self.idle_coroutine and not self.idle_coroutine.done():
|
||||||
|
self.idle_coroutine.cancel()
|
||||||
|
|
||||||
|
self.idle_coroutine = asyncio.create_task(self._idle_timer_coroutine())
|
||||||
|
|
||||||
|
async def _idle_timer_coroutine(self):
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(self.config.auto_shutdown_timeout)
|
||||||
|
|
||||||
|
if len(self.running_tasks) > 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
assert self.on_idle_timeout is not None # validated at start()
|
||||||
|
|
||||||
|
await self.on_idle_timeout()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user