mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-16 17:46:45 +00:00
feat(core): Harden native Python task runner (no-changelog) (#18826)
This commit is contained in:
@@ -23,6 +23,7 @@ OFFER_INTERVAL = 0.25 # 250ms
|
||||
OFFER_VALIDITY = 5000 # ms
|
||||
OFFER_VALIDITY_MAX_JITTER = 500 # ms
|
||||
OFFER_VALIDITY_LATENCY_BUFFER = 0.1 # 100ms
|
||||
MAX_VALIDATION_CACHE_SIZE = 500 # cached validation results
|
||||
|
||||
# Executor
|
||||
EXECUTOR_USER_OUTPUT_KEY = "__n8n_internal_user_output__"
|
||||
@@ -38,6 +39,9 @@ ENV_GRANT_TOKEN = "N8N_RUNNERS_GRANT_TOKEN"
|
||||
ENV_MAX_CONCURRENCY = "N8N_RUNNERS_MAX_CONCURRENCY"
|
||||
ENV_MAX_PAYLOAD_SIZE = "N8N_RUNNERS_MAX_PAYLOAD"
|
||||
ENV_TASK_TIMEOUT = "N8N_RUNNERS_TASK_TIMEOUT"
|
||||
ENV_STDLIB_ALLOW = "N8N_RUNNERS_STDLIB_ALLOW"
|
||||
ENV_EXTERNAL_ALLOW = "N8N_RUNNERS_EXTERNAL_ALLOW"
|
||||
ENV_BUILTINS_DENY = "N8N_RUNNERS_BUILTINS_DENY"
|
||||
|
||||
# Logging
|
||||
LOG_FORMAT = "%(asctime)s.%(msecs)03d\t%(levelname)s\t%(message)s"
|
||||
@@ -51,3 +55,61 @@ TASK_REJECTED_REASON_OFFER_EXPIRED = (
|
||||
"Offer expired - not accepted within validity window"
|
||||
)
|
||||
TASK_REJECTED_REASON_AT_CAPACITY = "No open task slots - runner already at capacity"
|
||||
|
||||
# Security

# Comma-separated names of builtins withheld from task code by default.
# parse_env_vars() uses this string as the fallback for ENV_BUILTINS_DENY
# and splits it with parse_denylist().
BUILTINS_DENY_DEFAULT = "eval,exec,compile,open,input,breakpoint,__import__,getattr,object,type,vars,setattr,delattr,hasattr,dir,memoryview,__build_class__"

# Attributes rejected wherever they appear in task code:
# SecurityValidator.visit_Attribute reports them regardless of what
# expression they are accessed on.
ALWAYS_BLOCKED_ATTRIBUTES = {
    "__subclasses__",
    "__globals__",
    "__builtins__",
    "__traceback__",
    "tb_frame",
    "tb_next",
    "f_back",
    "f_globals",
    "f_locals",
    "f_code",
    "f_builtins",
    "__getattribute__",
    "__qualname__",
    "__module__",
    "gi_frame",
    "gi_code",
    "gi_yieldfrom",
    "cr_frame",
    "cr_code",
    "ag_frame",
    "ag_code",
    "__thisclass__",
    "__self_class__",
}

# Attributes blocked only in certain contexts:
# - In attribute chains (e.g., x.__class__.__bases__)
# - On literals (e.g., "".__class__)
CONDITIONALLY_BLOCKED_ATTRIBUTES = {
    "__class__",
    "__bases__",
    "__code__",
    "__closure__",
    "__loader__",
    "__cached__",
    "__dict__",
    "__import__",
    "__mro__",
    "__init_subclass__",
    "__getattr__",
    "__setattr__",
    "__delattr__",
    "__self__",
    "__func__",
    "__wrapped__",
    "__annotations__",
}

# Union of both sets: every attribute name the validator inspects at all.
UNSAFE_ATTRIBUTES = ALWAYS_BLOCKED_ATTRIBUTES | CONDITIONALLY_BLOCKED_ATTRIBUTES

# errors

# Message templates; the {placeholders} are filled in with str.format().
ERROR_RELATIVE_IMPORT = "Relative imports are disallowed."
ERROR_STDLIB_DISALLOWED = "Import of standard library module '{module}' is disallowed. Allowed stdlib modules: {allowed}"
ERROR_EXTERNAL_DISALLOWED = "Import of external package '{module}' is disallowed. Allowed external packages: {allowed}"
ERROR_DANGEROUS_ATTRIBUTE = "Access to attribute '{attr}' is disallowed, because it can be used to bypass security restrictions."
ERROR_SECURITY_VIOLATIONS = "Security violations detected:\n{violations}"
|
||||
|
||||
76
packages/@n8n/task-runner-python/src/env.py
Normal file
76
packages/@n8n/task-runner-python/src/env.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
from typing import Set
|
||||
|
||||
from src.constants import (
|
||||
DEFAULT_MAX_CONCURRENCY,
|
||||
DEFAULT_TASK_TIMEOUT,
|
||||
DEFAULT_TASK_BROKER_URI,
|
||||
DEFAULT_MAX_PAYLOAD_SIZE,
|
||||
BUILTINS_DENY_DEFAULT,
|
||||
ENV_MAX_CONCURRENCY,
|
||||
ENV_MAX_PAYLOAD_SIZE,
|
||||
ENV_TASK_BROKER_URI,
|
||||
ENV_GRANT_TOKEN,
|
||||
ENV_TASK_TIMEOUT,
|
||||
ENV_BUILTINS_DENY,
|
||||
ENV_STDLIB_ALLOW,
|
||||
ENV_EXTERNAL_ALLOW,
|
||||
)
|
||||
from src.task_runner import TaskRunnerOpts
|
||||
|
||||
|
||||
def parse_allowlist(allowlist_str: str, list_name: str) -> Set[str]:
    """Turn a comma-separated allowlist string into a set of module names.

    Each entry is stripped of surrounding whitespace; entries that are empty
    after stripping are dropped. An empty input yields an empty set.

    Args:
        allowlist_str: Raw comma-separated value, e.g. "os, json" or "*".
        list_name: Human-readable name of the list, used in the error message.

    Returns:
        The set of distinct, non-empty entries.

    Raises:
        ValueError: If the wildcard "*" is combined with any other entry.
    """
    if not allowlist_str:
        return set()

    entries: Set[str] = set()
    for piece in allowlist_str.split(","):
        cleaned = piece.strip()
        if cleaned:
            entries.add(cleaned)

    # The wildcard means "allow everything", so mixing it with explicit
    # names is treated as a configuration mistake.
    wildcard_with_others = "*" in entries and len(entries) > 1
    if wildcard_with_others:
        raise ValueError(
            f"Wildcard '*' in {list_name} must be used alone, not with other modules. "
            f"Got: {', '.join(sorted(entries))}"
        )

    return entries
|
||||
|
||||
|
||||
def parse_denylist(denylist_str: str) -> Set[str]:
    """Turn a comma-separated denylist string into a set of names.

    Entries are stripped of surrounding whitespace and empty entries are
    discarded. An empty input yields an empty set.
    """
    if not denylist_str:
        return set()

    stripped_entries = (entry.strip() for entry in denylist_str.split(","))
    return {entry for entry in stripped_entries if entry}
|
||||
|
||||
|
||||
def _read_int_env(env_name: str, default: int) -> int:
    """Read an integer setting from the environment.

    An unset or empty variable falls back to `default`.

    Raises:
        ValueError: If the value is not a valid integer. The message names
            the offending variable so the operator can fix the right one.
    """
    raw_value = os.getenv(env_name) or str(default)
    try:
        return int(raw_value)
    except ValueError:
        raise ValueError(
            f"{env_name} environment variable must be an integer, got '{raw_value}'"
        ) from None


def parse_env_vars() -> TaskRunnerOpts:
    """Build the runner options from environment variables.

    Returns:
        A TaskRunnerOpts populated from the environment, using the module
        defaults for anything unset.

    Raises:
        ValueError: If the grant token is missing, an allowlist mixes the
            wildcard with other entries, or a numeric setting is not an
            integer.
    """
    grant_token = os.getenv(ENV_GRANT_TOKEN, "")

    if not grant_token:
        raise ValueError(f"{ENV_GRANT_TOKEN} environment variable is required")

    # Only an *unset* variable falls back to the default denylist; an
    # explicitly empty value parses to an empty set.
    builtins_deny_str = os.getenv(ENV_BUILTINS_DENY, BUILTINS_DENY_DEFAULT)
    builtins_deny = parse_denylist(builtins_deny_str)

    stdlib_allow_str = os.getenv(ENV_STDLIB_ALLOW, "")
    stdlib_allow = parse_allowlist(stdlib_allow_str, "stdlib allowlist")

    external_allow_str = os.getenv(ENV_EXTERNAL_ALLOW, "")
    external_allow = parse_allowlist(external_allow_str, "external allowlist")

    return TaskRunnerOpts(
        grant_token=grant_token,
        task_broker_uri=os.getenv(ENV_TASK_BROKER_URI, DEFAULT_TASK_BROKER_URI),
        max_concurrency=_read_int_env(ENV_MAX_CONCURRENCY, DEFAULT_MAX_CONCURRENCY),
        max_payload_size=_read_int_env(ENV_MAX_PAYLOAD_SIZE, DEFAULT_MAX_PAYLOAD_SIZE),
        task_timeout=_read_int_env(ENV_TASK_TIMEOUT, DEFAULT_TASK_TIMEOUT),
        stdlib_allow=stdlib_allow,
        external_allow=external_allow,
        builtins_deny=builtins_deny,
    )
|
||||
@@ -1,3 +1,4 @@
|
||||
from .security_violation_error import SecurityViolationError
|
||||
from .task_missing_error import TaskMissingError
|
||||
from .task_result_missing_error import TaskResultMissingError
|
||||
from .task_process_exit_error import TaskProcessExitError
|
||||
@@ -6,6 +7,7 @@ from .task_timeout_error import TaskTimeoutError
|
||||
from .websocket_connection_error import WebsocketConnectionError
|
||||
|
||||
__all__ = [
|
||||
"SecurityViolationError",
|
||||
"TaskMissingError",
|
||||
"TaskProcessExitError",
|
||||
"TaskResultMissingError",
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
class SecurityViolationError(Exception):
    """Signals that code broke a security policy, typically by using disallowed modules or builtins."""
|
||||
@@ -1,23 +1,10 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
os.environ["WEBSOCKETS_MAX_LOG_SIZE"] = "256"
|
||||
|
||||
from src.constants import (
|
||||
DEFAULT_MAX_CONCURRENCY,
|
||||
DEFAULT_TASK_TIMEOUT,
|
||||
ENV_MAX_CONCURRENCY,
|
||||
ENV_MAX_PAYLOAD_SIZE,
|
||||
ENV_TASK_BROKER_URI,
|
||||
ENV_GRANT_TOKEN,
|
||||
DEFAULT_TASK_BROKER_URI,
|
||||
DEFAULT_MAX_PAYLOAD_SIZE,
|
||||
ENV_TASK_TIMEOUT,
|
||||
)
|
||||
from src.env import parse_env_vars
|
||||
from src.logs import setup_logging
|
||||
from src.task_runner import TaskRunner, TaskRunnerOpts
|
||||
from src.task_runner import TaskRunner
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -26,20 +13,12 @@ async def main():
|
||||
|
||||
logger.info("Starting runner...")
|
||||
|
||||
grant_token = os.getenv(ENV_GRANT_TOKEN, "")
|
||||
|
||||
if not grant_token:
|
||||
logger.error(f"{ENV_GRANT_TOKEN} environment variable is required")
|
||||
try:
|
||||
opts = parse_env_vars()
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
sys.exit(1)
|
||||
|
||||
opts = TaskRunnerOpts(
|
||||
grant_token,
|
||||
os.getenv(ENV_TASK_BROKER_URI, DEFAULT_TASK_BROKER_URI),
|
||||
int(os.getenv(ENV_MAX_CONCURRENCY, DEFAULT_MAX_CONCURRENCY)),
|
||||
int(os.getenv(ENV_MAX_PAYLOAD_SIZE, DEFAULT_MAX_PAYLOAD_SIZE)),
|
||||
int(os.getenv(ENV_TASK_TIMEOUT, DEFAULT_TASK_TIMEOUT)),
|
||||
)
|
||||
|
||||
task_runner = TaskRunner(opts)
|
||||
|
||||
try:
|
||||
|
||||
171
packages/@n8n/task-runner-python/src/task_analyzer.py
Normal file
171
packages/@n8n/task-runner-python/src/task_analyzer.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import ast
|
||||
import hashlib
|
||||
import sys
|
||||
from typing import Set, Tuple
|
||||
from collections import OrderedDict
|
||||
|
||||
from src.errors import SecurityViolationError
|
||||
from src.constants import (
|
||||
MAX_VALIDATION_CACHE_SIZE,
|
||||
ERROR_RELATIVE_IMPORT,
|
||||
ERROR_STDLIB_DISALLOWED,
|
||||
ERROR_EXTERNAL_DISALLOWED,
|
||||
ERROR_DANGEROUS_ATTRIBUTE,
|
||||
ERROR_SECURITY_VIOLATIONS,
|
||||
ALWAYS_BLOCKED_ATTRIBUTES,
|
||||
UNSAFE_ATTRIBUTES,
|
||||
)
|
||||
|
||||
CacheKey = Tuple[str, Tuple] # (code_hash, allowlists_tuple)
|
||||
CachedViolations = list[str]
|
||||
ValidationCache = OrderedDict[CacheKey, CachedViolations]
|
||||
|
||||
|
||||
class SecurityValidator(ast.NodeVisitor):
|
||||
"""AST visitor that enforces import allowlists and blocks dangerous attribute access."""
|
||||
|
||||
def __init__(self, stdlib_allow: Set[str], external_allow: Set[str]):
|
||||
self.checked_modules: Set[str] = set()
|
||||
self.violations: list[str] = []
|
||||
|
||||
self.stdlib_allow = stdlib_allow
|
||||
self.external_allow = external_allow
|
||||
self._stdlib_allowed_str = self._format_allowed(stdlib_allow)
|
||||
self._external_allowed_str = self._format_allowed(external_allow)
|
||||
self._stdlib_allow_all = "*" in stdlib_allow
|
||||
self._external_allow_all = "*" in external_allow
|
||||
|
||||
# ========== Detection ==========
|
||||
|
||||
def visit_Import(self, node: ast.Import) -> None:
|
||||
"""Detect bare import statements (e.g., import os), including aliased (e.g., import numpy as np)."""
|
||||
|
||||
for alias in node.names:
|
||||
module_name = alias.name
|
||||
self._validate_import(module_name, node.lineno)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
||||
"""Detect from import statements (e.g., from os import path)."""
|
||||
|
||||
if node.level > 0:
|
||||
self._add_violation(node.lineno, ERROR_RELATIVE_IMPORT)
|
||||
elif node.module:
|
||||
self._validate_import(node.module, node.lineno)
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Attribute(self, node: ast.Attribute) -> None:
|
||||
"""Detect access to unsafe attributes that could bypass security."""
|
||||
|
||||
if node.attr in UNSAFE_ATTRIBUTES:
|
||||
# Block regardless of context
|
||||
if node.attr in ALWAYS_BLOCKED_ATTRIBUTES:
|
||||
self._add_violation(
|
||||
node.lineno, ERROR_DANGEROUS_ATTRIBUTE.format(attr=node.attr)
|
||||
)
|
||||
# Block in attribute chains (e.g., x.__class__.__bases__) or on literals (e.g., "".__class__)
|
||||
elif isinstance(node.value, (ast.Attribute, ast.Constant)):
|
||||
self._add_violation(
|
||||
node.lineno, ERROR_DANGEROUS_ATTRIBUTE.format(attr=node.attr)
|
||||
)
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
# ========== Validation ==========
|
||||
|
||||
def _validate_import(self, module_path: str, lineno: int) -> None:
|
||||
"""Validate that a module import is allowed based on allowlists. Also disallow relative imports."""
|
||||
|
||||
if module_path.startswith("."):
|
||||
self._add_violation(lineno, ERROR_RELATIVE_IMPORT)
|
||||
return
|
||||
|
||||
module_name = module_path.split(".")[0] # e.g., os.path -> os
|
||||
|
||||
if module_name in self.checked_modules:
|
||||
return
|
||||
|
||||
self.checked_modules.add(module_name)
|
||||
|
||||
is_stdlib = module_name in sys.stdlib_module_names
|
||||
is_external = not is_stdlib
|
||||
|
||||
if is_stdlib and (self._stdlib_allow_all or module_name in self.stdlib_allow):
|
||||
return
|
||||
|
||||
if is_external and (
|
||||
self._external_allow_all or module_name in self.external_allow
|
||||
):
|
||||
return
|
||||
|
||||
error, allowed_str = (
|
||||
(ERROR_STDLIB_DISALLOWED, self._stdlib_allowed_str)
|
||||
if is_stdlib
|
||||
else (ERROR_EXTERNAL_DISALLOWED, self._external_allowed_str)
|
||||
)
|
||||
|
||||
self._add_violation(
|
||||
lineno, error.format(module=module_path, allowed=allowed_str)
|
||||
)
|
||||
|
||||
def _format_allowed(self, allow_set: Set[str]) -> str:
|
||||
return ", ".join(sorted(allow_set)) if allow_set else "none"
|
||||
|
||||
def _add_violation(self, lineno: int, message: str) -> None:
|
||||
self.violations.append(f"Line {lineno}: {message}")
|
||||
|
||||
|
||||
class TaskAnalyzer:
    """Validates task code with SecurityValidator and caches the outcome.

    Results are cached per (code hash, allowlists), so resubmitting the same
    code under the same allowlists skips parsing.
    """

    # Class-level, hence shared by all instances; the allowlists are part of
    # the key so differently configured analyzers do not collide.
    _cache: ValidationCache = OrderedDict()

    def __init__(self, stdlib_allow: Set[str], external_allow: Set[str]):
        self._stdlib_allow = stdlib_allow
        self._external_allow = external_allow
        # Sorted tuples give a hashable, order-independent form for the cache key.
        stdlib_key = tuple(sorted(stdlib_allow))
        external_key = tuple(sorted(external_allow))
        self._allowlists = (stdlib_key, external_key)
        self._allow_all = "*" in stdlib_allow and "*" in external_allow

    def validate(self, code: str) -> None:
        """Raise SecurityViolationError if `code` violates the security rules.

        When both allowlists are the wildcard, validation is skipped entirely.
        """
        if self._allow_all:
            return

        cache_key = self._to_cache_key(code)
        known_violations = self._cache.get(cache_key)

        if known_violations is not None:
            # Cache hit: mark the entry as recently used, then replay the verdict.
            self._cache.move_to_end(cache_key)
            if known_violations:
                self._raise_security_error(known_violations)
            return

        tree = ast.parse(code)

        validator = SecurityValidator(self._stdlib_allow, self._external_allow)
        validator.visit(tree)
        found = validator.violations

        self._set_in_cache(cache_key, found)

        if found:
            self._raise_security_error(found)

    def _raise_security_error(self, violations: CachedViolations) -> None:
        """Raise SecurityViolationError listing every violation, one per line."""
        joined = "\n".join(violations)
        raise SecurityViolationError(ERROR_SECURITY_VIOLATIONS.format(violations=joined))

    def _to_cache_key(self, code: str) -> CacheKey:
        """Build the cache key from the SHA-256 of the code and the allowlists."""
        digest = hashlib.sha256(code.encode()).hexdigest()
        return (digest, self._allowlists)

    def _set_in_cache(self, cache_key: CacheKey, violations: CachedViolations) -> None:
        """Store a copy of `violations`, evicting the oldest entry when full."""
        cache = self._cache

        if len(cache) >= MAX_VALIDATION_CACHE_SIZE:
            # Drop the front entry. Hits are moved to the end in validate(),
            # so the front is the least recently used entry.
            cache.popitem(last=False)

        cache[cache_key] = list(violations)
        cache.move_to_end(cache_key)
|
||||
@@ -3,6 +3,8 @@ import multiprocessing
|
||||
import traceback
|
||||
import textwrap
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
from src.errors import (
|
||||
TaskResultMissingError,
|
||||
@@ -11,9 +13,9 @@ from src.errors import (
|
||||
TaskProcessExitError,
|
||||
)
|
||||
|
||||
from .message_types.broker import NodeMode, Items
|
||||
from .constants import EXECUTOR_CIRCULAR_REFERENCE_KEY, EXECUTOR_USER_OUTPUT_KEY
|
||||
from typing import Any
|
||||
from src.message_types.broker import NodeMode, Items
|
||||
from src.constants import EXECUTOR_CIRCULAR_REFERENCE_KEY, EXECUTOR_USER_OUTPUT_KEY
|
||||
from typing import Any, Set
|
||||
|
||||
from multiprocessing.context import SpawnProcess
|
||||
|
||||
@@ -26,7 +28,14 @@ class TaskExecutor:
|
||||
"""Responsible for executing Python code tasks in isolated subprocesses."""
|
||||
|
||||
@staticmethod
|
||||
def create_process(code: str, node_mode: NodeMode, items: Items):
|
||||
def create_process(
|
||||
code: str,
|
||||
node_mode: NodeMode,
|
||||
items: Items,
|
||||
stdlib_allow: Set[str],
|
||||
external_allow: Set[str],
|
||||
builtins_deny: set[str],
|
||||
):
|
||||
"""Create a subprocess for executing a Python code task and a queue for communication."""
|
||||
|
||||
fn = (
|
||||
@@ -36,7 +45,10 @@ class TaskExecutor:
|
||||
)
|
||||
|
||||
queue = MULTIPROCESSING_CONTEXT.Queue()
|
||||
process = MULTIPROCESSING_CONTEXT.Process(target=fn, args=(code, items, queue))
|
||||
process = MULTIPROCESSING_CONTEXT.Process(
|
||||
target=fn,
|
||||
args=(code, items, queue, stdlib_allow, external_allow, builtins_deny),
|
||||
)
|
||||
|
||||
return process, queue
|
||||
|
||||
@@ -95,16 +107,27 @@ class TaskExecutor:
|
||||
process.kill()
|
||||
|
||||
@staticmethod
|
||||
def _all_items(raw_code: str, items: Items, queue: multiprocessing.Queue):
|
||||
def _all_items(
|
||||
raw_code: str,
|
||||
items: Items,
|
||||
queue: multiprocessing.Queue,
|
||||
stdlib_allow: Set[str],
|
||||
external_allow: Set[str],
|
||||
builtins_deny: set[str],
|
||||
):
|
||||
"""Execute a Python code task in all-items mode."""
|
||||
|
||||
os.environ.clear()
|
||||
|
||||
TaskExecutor._sanitize_sys_modules(stdlib_allow, external_allow)
|
||||
|
||||
print_args: PrintArgs = []
|
||||
|
||||
try:
|
||||
code = TaskExecutor._wrap_code(raw_code)
|
||||
|
||||
globals = {
|
||||
"__builtins__": __builtins__,
|
||||
"__builtins__": TaskExecutor._filter_builtins(builtins_deny),
|
||||
"_items": items,
|
||||
"print": TaskExecutor._create_custom_print(print_args),
|
||||
}
|
||||
@@ -119,9 +142,20 @@ class TaskExecutor:
|
||||
TaskExecutor._put_error(queue, e, print_args)
|
||||
|
||||
@staticmethod
|
||||
def _per_item(raw_code: str, items: Items, queue: multiprocessing.Queue):
|
||||
def _per_item(
|
||||
raw_code: str,
|
||||
items: Items,
|
||||
queue: multiprocessing.Queue,
|
||||
stdlib_allow: Set[str],
|
||||
external_allow: Set[str],
|
||||
builtins_deny: set[str],
|
||||
):
|
||||
"""Execute a Python code task in per-item mode."""
|
||||
|
||||
os.environ.clear()
|
||||
|
||||
TaskExecutor._sanitize_sys_modules(stdlib_allow, external_allow)
|
||||
|
||||
print_args: PrintArgs = []
|
||||
|
||||
try:
|
||||
@@ -131,7 +165,7 @@ class TaskExecutor:
|
||||
result = []
|
||||
for index, item in enumerate(items):
|
||||
globals = {
|
||||
"__builtins__": __builtins__,
|
||||
"__builtins__": TaskExecutor._filter_builtins(builtins_deny),
|
||||
"_item": item,
|
||||
"print": TaskExecutor._create_custom_print(print_args),
|
||||
}
|
||||
@@ -195,7 +229,7 @@ class TaskExecutor:
|
||||
@staticmethod
|
||||
def _format_print_args(*args) -> list[str]:
|
||||
"""
|
||||
Takes the arguments passed to a `print()` call in user code and converts them
|
||||
Takes the args passed to a `print()` call in user code and converts them
|
||||
to string representations suitable for display in a browser console.
|
||||
|
||||
Expects all args to be serializable.
|
||||
@@ -217,3 +251,45 @@ class TaskExecutor:
|
||||
formatted.append(json.dumps(arg, default=str, ensure_ascii=False))
|
||||
|
||||
return formatted
|
||||
|
||||
# ========== security ==========
|
||||
|
||||
@staticmethod
|
||||
def _filter_builtins(builtins_deny: set[str]):
|
||||
"""Get __builtins__ with denied ones removed."""
|
||||
|
||||
if len(builtins_deny) == 0:
|
||||
return __builtins__
|
||||
|
||||
return {k: v for k, v in __builtins__.items() if k not in builtins_deny}
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_sys_modules(stdlib_allow: Set[str], external_allow: Set[str]):
|
||||
safe_modules = {
|
||||
"builtins",
|
||||
"__main__",
|
||||
"sys",
|
||||
"traceback",
|
||||
"linecache",
|
||||
}
|
||||
|
||||
if "*" in stdlib_allow:
|
||||
safe_modules.update(sys.stdlib_module_names)
|
||||
else:
|
||||
safe_modules.update(stdlib_allow)
|
||||
|
||||
if "*" in external_allow:
|
||||
safe_modules.update(
|
||||
name
|
||||
for name in sys.modules.keys()
|
||||
if name not in sys.stdlib_module_names
|
||||
)
|
||||
else:
|
||||
safe_modules.update(external_allow)
|
||||
|
||||
modules_to_remove = [
|
||||
name for name in sys.modules.keys() if name not in safe_modules
|
||||
]
|
||||
|
||||
for module_name in modules_to_remove:
|
||||
del sys.modules[module_name]
|
||||
|
||||
@@ -2,17 +2,20 @@ import asyncio
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Dict, Optional, Any, Set
|
||||
from urllib.parse import urlparse
|
||||
import websockets
|
||||
import random
|
||||
|
||||
|
||||
from src.errors import WebsocketConnectionError, TaskMissingError
|
||||
from src.errors import (
|
||||
WebsocketConnectionError,
|
||||
TaskMissingError,
|
||||
)
|
||||
from src.message_types.broker import TaskSettings
|
||||
from src.nanoid_utils import nanoid
|
||||
|
||||
from .constants import (
|
||||
from src.constants import (
|
||||
RUNNER_NAME,
|
||||
TASK_REJECTED_REASON_AT_CAPACITY,
|
||||
TASK_REJECTED_REASON_OFFER_EXPIRED,
|
||||
@@ -24,7 +27,7 @@ from .constants import (
|
||||
TASK_BROKER_WS_PATH,
|
||||
RPC_BROWSER_CONSOLE_LOG_METHOD,
|
||||
)
|
||||
from .message_types import (
|
||||
from src.message_types import (
|
||||
BrokerMessage,
|
||||
RunnerMessage,
|
||||
BrokerInfoRequest,
|
||||
@@ -41,9 +44,10 @@ from .message_types import (
|
||||
RunnerTaskError,
|
||||
RunnerRpcCall,
|
||||
)
|
||||
from .message_serde import MessageSerde
|
||||
from .task_state import TaskState, TaskStatus
|
||||
from .task_executor import TaskExecutor
|
||||
from src.message_serde import MessageSerde
|
||||
from src.task_state import TaskState, TaskStatus
|
||||
from src.task_executor import TaskExecutor
|
||||
from src.task_analyzer import TaskAnalyzer
|
||||
|
||||
|
||||
class TaskOffer:
|
||||
@@ -56,13 +60,16 @@ class TaskOffer:
|
||||
return time.time() > self.valid_until
|
||||
|
||||
|
||||
@dataclass()
|
||||
@dataclass
|
||||
class TaskRunnerOpts:
|
||||
grant_token: str
|
||||
task_broker_uri: str
|
||||
max_concurrency: int
|
||||
max_payload_size: int
|
||||
task_timeout: int
|
||||
stdlib_allow: Set[str]
|
||||
external_allow: Set[str]
|
||||
builtins_deny: Set[str]
|
||||
|
||||
|
||||
class TaskRunner:
|
||||
@@ -85,6 +92,7 @@ class TaskRunner:
|
||||
self.offers_coroutine: Optional[asyncio.Task] = None
|
||||
self.serde = MessageSerde()
|
||||
self.executor = TaskExecutor()
|
||||
self.analyzer = TaskAnalyzer(opts.stdlib_allow, opts.external_allow)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
self.task_broker_uri = opts.task_broker_uri
|
||||
@@ -207,14 +215,24 @@ class TaskRunner:
|
||||
if task_state is None:
|
||||
raise TaskMissingError(task_id)
|
||||
|
||||
self.analyzer.validate(task_settings.code)
|
||||
|
||||
process, queue = self.executor.create_process(
|
||||
task_settings.code, task_settings.node_mode, task_settings.items
|
||||
code=task_settings.code,
|
||||
node_mode=task_settings.node_mode,
|
||||
items=task_settings.items,
|
||||
stdlib_allow=self.opts.stdlib_allow,
|
||||
external_allow=self.opts.external_allow,
|
||||
builtins_deny=self.opts.builtins_deny,
|
||||
)
|
||||
|
||||
task_state.process = process
|
||||
|
||||
result, print_args = self.executor.execute_process(
|
||||
process, queue, self.opts.task_timeout, task_settings.continue_on_fail
|
||||
process=process,
|
||||
queue=queue,
|
||||
task_timeout=self.opts.task_timeout,
|
||||
continue_on_fail=task_settings.continue_on_fail,
|
||||
)
|
||||
|
||||
for print_args_per_call in print_args:
|
||||
|
||||
Reference in New Issue
Block a user