feat(core): Harden native Python task runner (no-changelog) (#18826)

This commit is contained in:
Iván Ovejero
2025-08-29 14:19:38 +02:00
committed by GitHub
parent e29ed1532a
commit 3b574306f3
8 changed files with 435 additions and 47 deletions

View File

@@ -23,6 +23,7 @@ OFFER_INTERVAL = 0.25 # 250ms
OFFER_VALIDITY = 5000 # ms
OFFER_VALIDITY_MAX_JITTER = 500 # ms
OFFER_VALIDITY_LATENCY_BUFFER = 0.1 # 100ms
MAX_VALIDATION_CACHE_SIZE = 500 # cached validation results
# Executor
EXECUTOR_USER_OUTPUT_KEY = "__n8n_internal_user_output__"
@@ -38,6 +39,9 @@ ENV_GRANT_TOKEN = "N8N_RUNNERS_GRANT_TOKEN"
ENV_MAX_CONCURRENCY = "N8N_RUNNERS_MAX_CONCURRENCY"
ENV_MAX_PAYLOAD_SIZE = "N8N_RUNNERS_MAX_PAYLOAD"
ENV_TASK_TIMEOUT = "N8N_RUNNERS_TASK_TIMEOUT"
ENV_STDLIB_ALLOW = "N8N_RUNNERS_STDLIB_ALLOW"
ENV_EXTERNAL_ALLOW = "N8N_RUNNERS_EXTERNAL_ALLOW"
ENV_BUILTINS_DENY = "N8N_RUNNERS_BUILTINS_DENY"
# Logging
LOG_FORMAT = "%(asctime)s.%(msecs)03d\t%(levelname)s\t%(message)s"
@@ -51,3 +55,61 @@ TASK_REJECTED_REASON_OFFER_EXPIRED = (
"Offer expired - not accepted within validity window"
)
TASK_REJECTED_REASON_AT_CAPACITY = "No open task slots - runner already at capacity"
# Security
BUILTINS_DENY_DEFAULT = "eval,exec,compile,open,input,breakpoint,__import__,getattr,object,type,vars,setattr,delattr,hasattr,dir,memoryview,__build_class__"
ALWAYS_BLOCKED_ATTRIBUTES = {
"__subclasses__",
"__globals__",
"__builtins__",
"__traceback__",
"tb_frame",
"tb_next",
"f_back",
"f_globals",
"f_locals",
"f_code",
"f_builtins",
"__getattribute__",
"__qualname__",
"__module__",
"gi_frame",
"gi_code",
"gi_yieldfrom",
"cr_frame",
"cr_code",
"ag_frame",
"ag_code",
"__thisclass__",
"__self_class__",
}
# Attributes blocked only in certain contexts:
# - In attribute chains (e.g., x.__class__.__bases__)
# - On literals (e.g., "".__class__)
CONDITIONALLY_BLOCKED_ATTRIBUTES = {
"__class__",
"__bases__",
"__code__",
"__closure__",
"__loader__",
"__cached__",
"__dict__",
"__import__",
"__mro__",
"__init_subclass__",
"__getattr__",
"__setattr__",
"__delattr__",
"__self__",
"__func__",
"__wrapped__",
"__annotations__",
}
UNSAFE_ATTRIBUTES = ALWAYS_BLOCKED_ATTRIBUTES | CONDITIONALLY_BLOCKED_ATTRIBUTES
# errors
ERROR_RELATIVE_IMPORT = "Relative imports are disallowed."
ERROR_STDLIB_DISALLOWED = "Import of standard library module '{module}' is disallowed. Allowed stdlib modules: {allowed}"
ERROR_EXTERNAL_DISALLOWED = "Import of external package '{module}' is disallowed. Allowed external packages: {allowed}"
ERROR_DANGEROUS_ATTRIBUTE = "Access to attribute '{attr}' is disallowed, because it can be used to bypass security restrictions."
ERROR_SECURITY_VIOLATIONS = "Security violations detected:\n{violations}"

View File

@@ -0,0 +1,76 @@
import os
from typing import Set
from src.constants import (
DEFAULT_MAX_CONCURRENCY,
DEFAULT_TASK_TIMEOUT,
DEFAULT_TASK_BROKER_URI,
DEFAULT_MAX_PAYLOAD_SIZE,
BUILTINS_DENY_DEFAULT,
ENV_MAX_CONCURRENCY,
ENV_MAX_PAYLOAD_SIZE,
ENV_TASK_BROKER_URI,
ENV_GRANT_TOKEN,
ENV_TASK_TIMEOUT,
ENV_BUILTINS_DENY,
ENV_STDLIB_ALLOW,
ENV_EXTERNAL_ALLOW,
)
from src.task_runner import TaskRunnerOpts
def parse_allowlist(allowlist_str: str, list_name: str) -> Set[str]:
if not allowlist_str:
return set()
modules = {
module
for raw_module in allowlist_str.split(",")
if (module := raw_module.strip())
}
if "*" in modules and len(modules) > 1:
raise ValueError(
f"Wildcard '*' in {list_name} must be used alone, not with other modules. "
f"Got: {', '.join(sorted(modules))}"
)
return modules
def parse_denylist(denylist_str: str) -> Set[str]:
if not denylist_str:
return set()
return {name for raw_name in denylist_str.split(",") if (name := raw_name.strip())}
def parse_env_vars() -> TaskRunnerOpts:
grant_token = os.getenv(ENV_GRANT_TOKEN, "")
if not grant_token:
raise ValueError(f"{ENV_GRANT_TOKEN} environment variable is required")
builtins_deny_str = os.getenv(ENV_BUILTINS_DENY, BUILTINS_DENY_DEFAULT)
builtins_deny = parse_denylist(builtins_deny_str)
stdlib_allow_str = os.getenv(ENV_STDLIB_ALLOW, "")
stdlib_allow = parse_allowlist(stdlib_allow_str, "stdlib allowlist")
external_allow_str = os.getenv(ENV_EXTERNAL_ALLOW, "")
external_allow = parse_allowlist(external_allow_str, "external allowlist")
return TaskRunnerOpts(
grant_token=grant_token,
task_broker_uri=os.getenv(ENV_TASK_BROKER_URI, DEFAULT_TASK_BROKER_URI),
max_concurrency=int(
os.getenv(ENV_MAX_CONCURRENCY) or str(DEFAULT_MAX_CONCURRENCY)
),
max_payload_size=int(
os.getenv(ENV_MAX_PAYLOAD_SIZE) or str(DEFAULT_MAX_PAYLOAD_SIZE)
),
task_timeout=int(os.getenv(ENV_TASK_TIMEOUT) or str(DEFAULT_TASK_TIMEOUT)),
stdlib_allow=stdlib_allow,
external_allow=external_allow,
builtins_deny=builtins_deny,
)

View File

@@ -1,3 +1,4 @@
from .security_violation_error import SecurityViolationError
from .task_missing_error import TaskMissingError
from .task_result_missing_error import TaskResultMissingError
from .task_process_exit_error import TaskProcessExitError
@@ -6,6 +7,7 @@ from .task_timeout_error import TaskTimeoutError
from .websocket_connection_error import WebsocketConnectionError
__all__ = [
"SecurityViolationError",
"TaskMissingError",
"TaskProcessExitError",
"TaskResultMissingError",

View File

@@ -0,0 +1,4 @@
class SecurityViolationError(Exception):
"""Raised when code violates security policies, typically through use of disallowed modules or builtins."""
pass

View File

@@ -1,23 +1,10 @@
import asyncio
import logging
import os
import sys
os.environ["WEBSOCKETS_MAX_LOG_SIZE"] = "256"
from src.constants import (
DEFAULT_MAX_CONCURRENCY,
DEFAULT_TASK_TIMEOUT,
ENV_MAX_CONCURRENCY,
ENV_MAX_PAYLOAD_SIZE,
ENV_TASK_BROKER_URI,
ENV_GRANT_TOKEN,
DEFAULT_TASK_BROKER_URI,
DEFAULT_MAX_PAYLOAD_SIZE,
ENV_TASK_TIMEOUT,
)
from src.env import parse_env_vars
from src.logs import setup_logging
from src.task_runner import TaskRunner, TaskRunnerOpts
from src.task_runner import TaskRunner
async def main():
@@ -26,20 +13,12 @@ async def main():
logger.info("Starting runner...")
grant_token = os.getenv(ENV_GRANT_TOKEN, "")
if not grant_token:
logger.error(f"{ENV_GRANT_TOKEN} environment variable is required")
try:
opts = parse_env_vars()
except ValueError as e:
logger.error(str(e))
sys.exit(1)
opts = TaskRunnerOpts(
grant_token,
os.getenv(ENV_TASK_BROKER_URI, DEFAULT_TASK_BROKER_URI),
int(os.getenv(ENV_MAX_CONCURRENCY, DEFAULT_MAX_CONCURRENCY)),
int(os.getenv(ENV_MAX_PAYLOAD_SIZE, DEFAULT_MAX_PAYLOAD_SIZE)),
int(os.getenv(ENV_TASK_TIMEOUT, DEFAULT_TASK_TIMEOUT)),
)
task_runner = TaskRunner(opts)
try:

View File

@@ -0,0 +1,171 @@
import ast
import hashlib
import sys
from typing import Set, Tuple
from collections import OrderedDict
from src.errors import SecurityViolationError
from src.constants import (
MAX_VALIDATION_CACHE_SIZE,
ERROR_RELATIVE_IMPORT,
ERROR_STDLIB_DISALLOWED,
ERROR_EXTERNAL_DISALLOWED,
ERROR_DANGEROUS_ATTRIBUTE,
ERROR_SECURITY_VIOLATIONS,
ALWAYS_BLOCKED_ATTRIBUTES,
UNSAFE_ATTRIBUTES,
)
CacheKey = Tuple[str, Tuple] # (code_hash, allowlists_tuple)
CachedViolations = list[str]
ValidationCache = OrderedDict[CacheKey, CachedViolations]
class SecurityValidator(ast.NodeVisitor):
"""AST visitor that enforces import allowlists and blocks dangerous attribute access."""
def __init__(self, stdlib_allow: Set[str], external_allow: Set[str]):
self.checked_modules: Set[str] = set()
self.violations: list[str] = []
self.stdlib_allow = stdlib_allow
self.external_allow = external_allow
self._stdlib_allowed_str = self._format_allowed(stdlib_allow)
self._external_allowed_str = self._format_allowed(external_allow)
self._stdlib_allow_all = "*" in stdlib_allow
self._external_allow_all = "*" in external_allow
# ========== Detection ==========
def visit_Import(self, node: ast.Import) -> None:
"""Detect bare import statements (e.g., import os), including aliased (e.g., import numpy as np)."""
for alias in node.names:
module_name = alias.name
self._validate_import(module_name, node.lineno)
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Detect from import statements (e.g., from os import path)."""
if node.level > 0:
self._add_violation(node.lineno, ERROR_RELATIVE_IMPORT)
elif node.module:
self._validate_import(node.module, node.lineno)
self.generic_visit(node)
def visit_Attribute(self, node: ast.Attribute) -> None:
"""Detect access to unsafe attributes that could bypass security."""
if node.attr in UNSAFE_ATTRIBUTES:
# Block regardless of context
if node.attr in ALWAYS_BLOCKED_ATTRIBUTES:
self._add_violation(
node.lineno, ERROR_DANGEROUS_ATTRIBUTE.format(attr=node.attr)
)
# Block in attribute chains (e.g., x.__class__.__bases__) or on literals (e.g., "".__class__)
elif isinstance(node.value, (ast.Attribute, ast.Constant)):
self._add_violation(
node.lineno, ERROR_DANGEROUS_ATTRIBUTE.format(attr=node.attr)
)
self.generic_visit(node)
# ========== Validation ==========
def _validate_import(self, module_path: str, lineno: int) -> None:
"""Validate that a module import is allowed based on allowlists. Also disallow relative imports."""
if module_path.startswith("."):
self._add_violation(lineno, ERROR_RELATIVE_IMPORT)
return
module_name = module_path.split(".")[0] # e.g., os.path -> os
if module_name in self.checked_modules:
return
self.checked_modules.add(module_name)
is_stdlib = module_name in sys.stdlib_module_names
is_external = not is_stdlib
if is_stdlib and (self._stdlib_allow_all or module_name in self.stdlib_allow):
return
if is_external and (
self._external_allow_all or module_name in self.external_allow
):
return
error, allowed_str = (
(ERROR_STDLIB_DISALLOWED, self._stdlib_allowed_str)
if is_stdlib
else (ERROR_EXTERNAL_DISALLOWED, self._external_allowed_str)
)
self._add_violation(
lineno, error.format(module=module_path, allowed=allowed_str)
)
def _format_allowed(self, allow_set: Set[str]) -> str:
return ", ".join(sorted(allow_set)) if allow_set else "none"
def _add_violation(self, lineno: int, message: str) -> None:
self.violations.append(f"Line {lineno}: {message}")
class TaskAnalyzer:
_cache: ValidationCache = OrderedDict()
def __init__(self, stdlib_allow: Set[str], external_allow: Set[str]):
self._stdlib_allow = stdlib_allow
self._external_allow = external_allow
self._allowlists = (
tuple(sorted(stdlib_allow)),
tuple(sorted(external_allow)),
)
self._allow_all = "*" in stdlib_allow and "*" in external_allow
def validate(self, code: str) -> None:
if self._allow_all:
return
cache_key = self._to_cache_key(code)
cached_violations = self._cache.get(cache_key)
cache_hit = cached_violations is not None
if cache_hit:
self._cache.move_to_end(cache_key)
if len(cached_violations) == 0:
return
if len(cached_violations) > 0:
self._raise_security_error(cached_violations)
tree = ast.parse(code)
security_validator = SecurityValidator(self._stdlib_allow, self._external_allow)
security_validator.visit(tree)
self._set_in_cache(cache_key, security_validator.violations)
if security_validator.violations:
self._raise_security_error(security_validator.violations)
def _raise_security_error(self, violations: CachedViolations) -> None:
message = ERROR_SECURITY_VIOLATIONS.format(violations="\n".join(violations))
raise SecurityViolationError(message)
def _to_cache_key(self, code: str) -> CacheKey:
code_hash = hashlib.sha256(code.encode()).hexdigest()
return (code_hash, self._allowlists)
def _set_in_cache(self, cache_key: CacheKey, violations: CachedViolations) -> None:
if len(self._cache) >= MAX_VALIDATION_CACHE_SIZE:
self._cache.popitem(last=False) # FIFO
self._cache[cache_key] = violations.copy()
self._cache.move_to_end(cache_key)

View File

@@ -3,6 +3,8 @@ import multiprocessing
import traceback
import textwrap
import json
import os
import sys
from src.errors import (
TaskResultMissingError,
@@ -11,9 +13,9 @@ from src.errors import (
TaskProcessExitError,
)
from .message_types.broker import NodeMode, Items
from .constants import EXECUTOR_CIRCULAR_REFERENCE_KEY, EXECUTOR_USER_OUTPUT_KEY
from typing import Any
from src.message_types.broker import NodeMode, Items
from src.constants import EXECUTOR_CIRCULAR_REFERENCE_KEY, EXECUTOR_USER_OUTPUT_KEY
from typing import Any, Set
from multiprocessing.context import SpawnProcess
@@ -26,7 +28,14 @@ class TaskExecutor:
"""Responsible for executing Python code tasks in isolated subprocesses."""
@staticmethod
def create_process(code: str, node_mode: NodeMode, items: Items):
def create_process(
code: str,
node_mode: NodeMode,
items: Items,
stdlib_allow: Set[str],
external_allow: Set[str],
builtins_deny: set[str],
):
"""Create a subprocess for executing a Python code task and a queue for communication."""
fn = (
@@ -36,7 +45,10 @@ class TaskExecutor:
)
queue = MULTIPROCESSING_CONTEXT.Queue()
process = MULTIPROCESSING_CONTEXT.Process(target=fn, args=(code, items, queue))
process = MULTIPROCESSING_CONTEXT.Process(
target=fn,
args=(code, items, queue, stdlib_allow, external_allow, builtins_deny),
)
return process, queue
@@ -95,16 +107,27 @@ class TaskExecutor:
process.kill()
@staticmethod
def _all_items(raw_code: str, items: Items, queue: multiprocessing.Queue):
def _all_items(
raw_code: str,
items: Items,
queue: multiprocessing.Queue,
stdlib_allow: Set[str],
external_allow: Set[str],
builtins_deny: set[str],
):
"""Execute a Python code task in all-items mode."""
os.environ.clear()
TaskExecutor._sanitize_sys_modules(stdlib_allow, external_allow)
print_args: PrintArgs = []
try:
code = TaskExecutor._wrap_code(raw_code)
globals = {
"__builtins__": __builtins__,
"__builtins__": TaskExecutor._filter_builtins(builtins_deny),
"_items": items,
"print": TaskExecutor._create_custom_print(print_args),
}
@@ -119,9 +142,20 @@ class TaskExecutor:
TaskExecutor._put_error(queue, e, print_args)
@staticmethod
def _per_item(raw_code: str, items: Items, queue: multiprocessing.Queue):
def _per_item(
raw_code: str,
items: Items,
queue: multiprocessing.Queue,
stdlib_allow: Set[str],
external_allow: Set[str],
builtins_deny: set[str],
):
"""Execute a Python code task in per-item mode."""
os.environ.clear()
TaskExecutor._sanitize_sys_modules(stdlib_allow, external_allow)
print_args: PrintArgs = []
try:
@@ -131,7 +165,7 @@ class TaskExecutor:
result = []
for index, item in enumerate(items):
globals = {
"__builtins__": __builtins__,
"__builtins__": TaskExecutor._filter_builtins(builtins_deny),
"_item": item,
"print": TaskExecutor._create_custom_print(print_args),
}
@@ -195,7 +229,7 @@ class TaskExecutor:
@staticmethod
def _format_print_args(*args) -> list[str]:
"""
Takes the arguments passed to a `print()` call in user code and converts them
Takes the args passed to a `print()` call in user code and converts them
to string representations suitable for display in a browser console.
Expects all args to be serializable.
@@ -217,3 +251,45 @@ class TaskExecutor:
formatted.append(json.dumps(arg, default=str, ensure_ascii=False))
return formatted
# ========== security ==========
@staticmethod
def _filter_builtins(builtins_deny: set[str]):
"""Get __builtins__ with denied ones removed."""
if len(builtins_deny) == 0:
return __builtins__
return {k: v for k, v in __builtins__.items() if k not in builtins_deny}
@staticmethod
def _sanitize_sys_modules(stdlib_allow: Set[str], external_allow: Set[str]):
safe_modules = {
"builtins",
"__main__",
"sys",
"traceback",
"linecache",
}
if "*" in stdlib_allow:
safe_modules.update(sys.stdlib_module_names)
else:
safe_modules.update(stdlib_allow)
if "*" in external_allow:
safe_modules.update(
name
for name in sys.modules.keys()
if name not in sys.stdlib_module_names
)
else:
safe_modules.update(external_allow)
modules_to_remove = [
name for name in sys.modules.keys() if name not in safe_modules
]
for module_name in modules_to_remove:
del sys.modules[module_name]

View File

@@ -2,17 +2,20 @@ import asyncio
from dataclasses import dataclass
import logging
import time
from typing import Dict, Optional, Any
from typing import Dict, Optional, Any, Set
from urllib.parse import urlparse
import websockets
import random
from src.errors import WebsocketConnectionError, TaskMissingError
from src.errors import (
WebsocketConnectionError,
TaskMissingError,
)
from src.message_types.broker import TaskSettings
from src.nanoid_utils import nanoid
from .constants import (
from src.constants import (
RUNNER_NAME,
TASK_REJECTED_REASON_AT_CAPACITY,
TASK_REJECTED_REASON_OFFER_EXPIRED,
@@ -24,7 +27,7 @@ from .constants import (
TASK_BROKER_WS_PATH,
RPC_BROWSER_CONSOLE_LOG_METHOD,
)
from .message_types import (
from src.message_types import (
BrokerMessage,
RunnerMessage,
BrokerInfoRequest,
@@ -41,9 +44,10 @@ from .message_types import (
RunnerTaskError,
RunnerRpcCall,
)
from .message_serde import MessageSerde
from .task_state import TaskState, TaskStatus
from .task_executor import TaskExecutor
from src.message_serde import MessageSerde
from src.task_state import TaskState, TaskStatus
from src.task_executor import TaskExecutor
from src.task_analyzer import TaskAnalyzer
class TaskOffer:
@@ -56,13 +60,16 @@ class TaskOffer:
return time.time() > self.valid_until
@dataclass()
@dataclass
class TaskRunnerOpts:
grant_token: str
task_broker_uri: str
max_concurrency: int
max_payload_size: int
task_timeout: int
stdlib_allow: Set[str]
external_allow: Set[str]
builtins_deny: Set[str]
class TaskRunner:
@@ -85,6 +92,7 @@ class TaskRunner:
self.offers_coroutine: Optional[asyncio.Task] = None
self.serde = MessageSerde()
self.executor = TaskExecutor()
self.analyzer = TaskAnalyzer(opts.stdlib_allow, opts.external_allow)
self.logger = logging.getLogger(__name__)
self.task_broker_uri = opts.task_broker_uri
@@ -207,14 +215,24 @@ class TaskRunner:
if task_state is None:
raise TaskMissingError(task_id)
self.analyzer.validate(task_settings.code)
process, queue = self.executor.create_process(
task_settings.code, task_settings.node_mode, task_settings.items
code=task_settings.code,
node_mode=task_settings.node_mode,
items=task_settings.items,
stdlib_allow=self.opts.stdlib_allow,
external_allow=self.opts.external_allow,
builtins_deny=self.opts.builtins_deny,
)
task_state.process = process
result, print_args = self.executor.execute_process(
process, queue, self.opts.task_timeout, task_settings.continue_on_fail
process=process,
queue=queue,
task_timeout=self.opts.task_timeout,
continue_on_fail=task_settings.continue_on_fail,
)
for print_args_per_call in print_args: