Mirror of https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git (synced 2025-12-17 10:02:05 +00:00)
feat(core): Harden native Python task runner (no-changelog) (#18826)
@@ -23,6 +23,7 @@ OFFER_INTERVAL = 0.25 # 250ms
 OFFER_VALIDITY = 5000  # ms
 OFFER_VALIDITY_MAX_JITTER = 500  # ms
 OFFER_VALIDITY_LATENCY_BUFFER = 0.1  # 100ms
+MAX_VALIDATION_CACHE_SIZE = 500  # cached validation results

 # Executor
 EXECUTOR_USER_OUTPUT_KEY = "__n8n_internal_user_output__"
@@ -38,6 +39,9 @@ ENV_GRANT_TOKEN = "N8N_RUNNERS_GRANT_TOKEN"
 ENV_MAX_CONCURRENCY = "N8N_RUNNERS_MAX_CONCURRENCY"
 ENV_MAX_PAYLOAD_SIZE = "N8N_RUNNERS_MAX_PAYLOAD"
 ENV_TASK_TIMEOUT = "N8N_RUNNERS_TASK_TIMEOUT"
+ENV_STDLIB_ALLOW = "N8N_RUNNERS_STDLIB_ALLOW"
+ENV_EXTERNAL_ALLOW = "N8N_RUNNERS_EXTERNAL_ALLOW"
+ENV_BUILTINS_DENY = "N8N_RUNNERS_BUILTINS_DENY"

 # Logging
 LOG_FORMAT = "%(asctime)s.%(msecs)03d\t%(levelname)s\t%(message)s"
@@ -51,3 +55,61 @@ TASK_REJECTED_REASON_OFFER_EXPIRED = (
     "Offer expired - not accepted within validity window"
 )
 TASK_REJECTED_REASON_AT_CAPACITY = "No open task slots - runner already at capacity"
+
+# Security
+BUILTINS_DENY_DEFAULT = "eval,exec,compile,open,input,breakpoint,__import__,getattr,object,type,vars,setattr,delattr,hasattr,dir,memoryview,__build_class__"
+ALWAYS_BLOCKED_ATTRIBUTES = {
+    "__subclasses__",
+    "__globals__",
+    "__builtins__",
+    "__traceback__",
+    "tb_frame",
+    "tb_next",
+    "f_back",
+    "f_globals",
+    "f_locals",
+    "f_code",
+    "f_builtins",
+    "__getattribute__",
+    "__qualname__",
+    "__module__",
+    "gi_frame",
+    "gi_code",
+    "gi_yieldfrom",
+    "cr_frame",
+    "cr_code",
+    "ag_frame",
+    "ag_code",
+    "__thisclass__",
+    "__self_class__",
+}
+# Attributes blocked only in certain contexts:
+# - In attribute chains (e.g., x.__class__.__bases__)
+# - On literals (e.g., "".__class__)
+CONDITIONALLY_BLOCKED_ATTRIBUTES = {
+    "__class__",
+    "__bases__",
+    "__code__",
+    "__closure__",
+    "__loader__",
+    "__cached__",
+    "__dict__",
+    "__import__",
+    "__mro__",
+    "__init_subclass__",
+    "__getattr__",
+    "__setattr__",
+    "__delattr__",
+    "__self__",
+    "__func__",
+    "__wrapped__",
+    "__annotations__",
+}
+UNSAFE_ATTRIBUTES = ALWAYS_BLOCKED_ATTRIBUTES | CONDITIONALLY_BLOCKED_ATTRIBUTES
+
+# errors
+ERROR_RELATIVE_IMPORT = "Relative imports are disallowed."
+ERROR_STDLIB_DISALLOWED = "Import of standard library module '{module}' is disallowed. Allowed stdlib modules: {allowed}"
+ERROR_EXTERNAL_DISALLOWED = "Import of external package '{module}' is disallowed. Allowed external packages: {allowed}"
+ERROR_DANGEROUS_ATTRIBUTE = "Access to attribute '{attr}' is disallowed, because it can be used to bypass security restrictions."
+ERROR_SECURITY_VIOLATIONS = "Security violations detected:\n{violations}"
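
Aside (not part of the commit): read together with the parsing logic in src/env.py below, the three new environment variables appear to be comma-separated lists, with "*" acting as an allow-everything wildcard and BUILTINS_DENY_DEFAULT applying when no deny list is set. A minimal configuration sketch, with example values only:

import os

# Hypothetical values for illustration - only the variable names come from this commit.
os.environ["N8N_RUNNERS_GRANT_TOKEN"] = "example-token"      # required at startup
os.environ["N8N_RUNNERS_STDLIB_ALLOW"] = "json,datetime,re"   # stdlib modules user code may import
os.environ["N8N_RUNNERS_EXTERNAL_ALLOW"] = "*"                # wildcard: any external package
# Leaving N8N_RUNNERS_BUILTINS_DENY unset falls back to BUILTINS_DENY_DEFAULT (eval, exec, open, ...).
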
packages/@n8n/task-runner-python/src/env.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import os
from typing import Set

from src.constants import (
    DEFAULT_MAX_CONCURRENCY,
    DEFAULT_TASK_TIMEOUT,
    DEFAULT_TASK_BROKER_URI,
    DEFAULT_MAX_PAYLOAD_SIZE,
    BUILTINS_DENY_DEFAULT,
    ENV_MAX_CONCURRENCY,
    ENV_MAX_PAYLOAD_SIZE,
    ENV_TASK_BROKER_URI,
    ENV_GRANT_TOKEN,
    ENV_TASK_TIMEOUT,
    ENV_BUILTINS_DENY,
    ENV_STDLIB_ALLOW,
    ENV_EXTERNAL_ALLOW,
)
from src.task_runner import TaskRunnerOpts


def parse_allowlist(allowlist_str: str, list_name: str) -> Set[str]:
    if not allowlist_str:
        return set()

    modules = {
        module
        for raw_module in allowlist_str.split(",")
        if (module := raw_module.strip())
    }

    if "*" in modules and len(modules) > 1:
        raise ValueError(
            f"Wildcard '*' in {list_name} must be used alone, not with other modules. "
            f"Got: {', '.join(sorted(modules))}"
        )

    return modules


def parse_denylist(denylist_str: str) -> Set[str]:
    if not denylist_str:
        return set()

    return {name for raw_name in denylist_str.split(",") if (name := raw_name.strip())}


def parse_env_vars() -> TaskRunnerOpts:
    grant_token = os.getenv(ENV_GRANT_TOKEN, "")

    if not grant_token:
        raise ValueError(f"{ENV_GRANT_TOKEN} environment variable is required")

    builtins_deny_str = os.getenv(ENV_BUILTINS_DENY, BUILTINS_DENY_DEFAULT)
    builtins_deny = parse_denylist(builtins_deny_str)

    stdlib_allow_str = os.getenv(ENV_STDLIB_ALLOW, "")
    stdlib_allow = parse_allowlist(stdlib_allow_str, "stdlib allowlist")

    external_allow_str = os.getenv(ENV_EXTERNAL_ALLOW, "")
    external_allow = parse_allowlist(external_allow_str, "external allowlist")

    return TaskRunnerOpts(
        grant_token=grant_token,
        task_broker_uri=os.getenv(ENV_TASK_BROKER_URI, DEFAULT_TASK_BROKER_URI),
        max_concurrency=int(
            os.getenv(ENV_MAX_CONCURRENCY) or str(DEFAULT_MAX_CONCURRENCY)
        ),
        max_payload_size=int(
            os.getenv(ENV_MAX_PAYLOAD_SIZE) or str(DEFAULT_MAX_PAYLOAD_SIZE)
        ),
        task_timeout=int(os.getenv(ENV_TASK_TIMEOUT) or str(DEFAULT_TASK_TIMEOUT)),
        stdlib_allow=stdlib_allow,
        external_allow=external_allow,
        builtins_deny=builtins_deny,
    )
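
Illustrative usage of the helpers above (not part of the diff), showing the trimming, empty-string, and wildcard-must-be-alone behaviour:

from src.env import parse_allowlist, parse_denylist

print(parse_allowlist("json, re ,datetime", "stdlib allowlist"))  # {'json', 're', 'datetime'}
print(parse_allowlist("", "stdlib allowlist"))                    # set() - empty means nothing allowed
print(parse_denylist("eval,exec, open"))                          # {'eval', 'exec', 'open'}

try:
    parse_allowlist("*,numpy", "external allowlist")  # wildcard mixed with a named package
except ValueError as e:
    print(e)  # Wildcard '*' in external allowlist must be used alone, not with other modules. Got: *, numpy
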
@@ -1,3 +1,4 @@
+from .security_violation_error import SecurityViolationError
 from .task_missing_error import TaskMissingError
 from .task_result_missing_error import TaskResultMissingError
 from .task_process_exit_error import TaskProcessExitError
@@ -6,6 +7,7 @@ from .task_timeout_error import TaskTimeoutError
 from .websocket_connection_error import WebsocketConnectionError

 __all__ = [
+    "SecurityViolationError",
     "TaskMissingError",
     "TaskProcessExitError",
     "TaskResultMissingError",
@@ -0,0 +1,4 @@
class SecurityViolationError(Exception):
    """Raised when code violates security policies, typically through use of disallowed modules or builtins."""

    pass
@@ -1,23 +1,10 @@
 import asyncio
 import logging
-import os
 import sys

-os.environ["WEBSOCKETS_MAX_LOG_SIZE"] = "256"
-
-from src.constants import (
-    DEFAULT_MAX_CONCURRENCY,
-    DEFAULT_TASK_TIMEOUT,
-    ENV_MAX_CONCURRENCY,
-    ENV_MAX_PAYLOAD_SIZE,
-    ENV_TASK_BROKER_URI,
-    ENV_GRANT_TOKEN,
-    DEFAULT_TASK_BROKER_URI,
-    DEFAULT_MAX_PAYLOAD_SIZE,
-    ENV_TASK_TIMEOUT,
-)
+from src.env import parse_env_vars
 from src.logs import setup_logging
-from src.task_runner import TaskRunner, TaskRunnerOpts
+from src.task_runner import TaskRunner


 async def main():
@@ -26,20 +13,12 @@ async def main():

     logger.info("Starting runner...")

-    grant_token = os.getenv(ENV_GRANT_TOKEN, "")
-
-    if not grant_token:
-        logger.error(f"{ENV_GRANT_TOKEN} environment variable is required")
+    try:
+        opts = parse_env_vars()
+    except ValueError as e:
+        logger.error(str(e))
         sys.exit(1)

-    opts = TaskRunnerOpts(
-        grant_token,
-        os.getenv(ENV_TASK_BROKER_URI, DEFAULT_TASK_BROKER_URI),
-        int(os.getenv(ENV_MAX_CONCURRENCY, DEFAULT_MAX_CONCURRENCY)),
-        int(os.getenv(ENV_MAX_PAYLOAD_SIZE, DEFAULT_MAX_PAYLOAD_SIZE)),
-        int(os.getenv(ENV_TASK_TIMEOUT, DEFAULT_TASK_TIMEOUT)),
-    )
-
     task_runner = TaskRunner(opts)

     try:
packages/@n8n/task-runner-python/src/task_analyzer.py (new file, 171 lines)
@@ -0,0 +1,171 @@
import ast
import hashlib
import sys
from typing import Set, Tuple
from collections import OrderedDict

from src.errors import SecurityViolationError
from src.constants import (
    MAX_VALIDATION_CACHE_SIZE,
    ERROR_RELATIVE_IMPORT,
    ERROR_STDLIB_DISALLOWED,
    ERROR_EXTERNAL_DISALLOWED,
    ERROR_DANGEROUS_ATTRIBUTE,
    ERROR_SECURITY_VIOLATIONS,
    ALWAYS_BLOCKED_ATTRIBUTES,
    UNSAFE_ATTRIBUTES,
)

CacheKey = Tuple[str, Tuple]  # (code_hash, allowlists_tuple)
CachedViolations = list[str]
ValidationCache = OrderedDict[CacheKey, CachedViolations]


class SecurityValidator(ast.NodeVisitor):
    """AST visitor that enforces import allowlists and blocks dangerous attribute access."""

    def __init__(self, stdlib_allow: Set[str], external_allow: Set[str]):
        self.checked_modules: Set[str] = set()
        self.violations: list[str] = []

        self.stdlib_allow = stdlib_allow
        self.external_allow = external_allow
        self._stdlib_allowed_str = self._format_allowed(stdlib_allow)
        self._external_allowed_str = self._format_allowed(external_allow)
        self._stdlib_allow_all = "*" in stdlib_allow
        self._external_allow_all = "*" in external_allow

    # ========== Detection ==========

    def visit_Import(self, node: ast.Import) -> None:
        """Detect bare import statements (e.g., import os), including aliased (e.g., import numpy as np)."""

        for alias in node.names:
            module_name = alias.name
            self._validate_import(module_name, node.lineno)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Detect from import statements (e.g., from os import path)."""

        if node.level > 0:
            self._add_violation(node.lineno, ERROR_RELATIVE_IMPORT)
        elif node.module:
            self._validate_import(node.module, node.lineno)

        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Detect access to unsafe attributes that could bypass security."""

        if node.attr in UNSAFE_ATTRIBUTES:
            # Block regardless of context
            if node.attr in ALWAYS_BLOCKED_ATTRIBUTES:
                self._add_violation(
                    node.lineno, ERROR_DANGEROUS_ATTRIBUTE.format(attr=node.attr)
                )
            # Block in attribute chains (e.g., x.__class__.__bases__) or on literals (e.g., "".__class__)
            elif isinstance(node.value, (ast.Attribute, ast.Constant)):
                self._add_violation(
                    node.lineno, ERROR_DANGEROUS_ATTRIBUTE.format(attr=node.attr)
                )

        self.generic_visit(node)

    # ========== Validation ==========

    def _validate_import(self, module_path: str, lineno: int) -> None:
        """Validate that a module import is allowed based on allowlists. Also disallow relative imports."""

        if module_path.startswith("."):
            self._add_violation(lineno, ERROR_RELATIVE_IMPORT)
            return

        module_name = module_path.split(".")[0]  # e.g., os.path -> os

        if module_name in self.checked_modules:
            return

        self.checked_modules.add(module_name)

        is_stdlib = module_name in sys.stdlib_module_names
        is_external = not is_stdlib

        if is_stdlib and (self._stdlib_allow_all or module_name in self.stdlib_allow):
            return

        if is_external and (
            self._external_allow_all or module_name in self.external_allow
        ):
            return

        error, allowed_str = (
            (ERROR_STDLIB_DISALLOWED, self._stdlib_allowed_str)
            if is_stdlib
            else (ERROR_EXTERNAL_DISALLOWED, self._external_allowed_str)
        )

        self._add_violation(
            lineno, error.format(module=module_path, allowed=allowed_str)
        )

    def _format_allowed(self, allow_set: Set[str]) -> str:
        return ", ".join(sorted(allow_set)) if allow_set else "none"

    def _add_violation(self, lineno: int, message: str) -> None:
        self.violations.append(f"Line {lineno}: {message}")


class TaskAnalyzer:
    _cache: ValidationCache = OrderedDict()

    def __init__(self, stdlib_allow: Set[str], external_allow: Set[str]):
        self._stdlib_allow = stdlib_allow
        self._external_allow = external_allow
        self._allowlists = (
            tuple(sorted(stdlib_allow)),
            tuple(sorted(external_allow)),
        )
        self._allow_all = "*" in stdlib_allow and "*" in external_allow

    def validate(self, code: str) -> None:
        if self._allow_all:
            return

        cache_key = self._to_cache_key(code)
        cached_violations = self._cache.get(cache_key)
        cache_hit = cached_violations is not None

        if cache_hit:
            self._cache.move_to_end(cache_key)

            if len(cached_violations) == 0:
                return

            if len(cached_violations) > 0:
                self._raise_security_error(cached_violations)

        tree = ast.parse(code)

        security_validator = SecurityValidator(self._stdlib_allow, self._external_allow)
        security_validator.visit(tree)

        self._set_in_cache(cache_key, security_validator.violations)

        if security_validator.violations:
            self._raise_security_error(security_validator.violations)

    def _raise_security_error(self, violations: CachedViolations) -> None:
        message = ERROR_SECURITY_VIOLATIONS.format(violations="\n".join(violations))
        raise SecurityViolationError(message)

    def _to_cache_key(self, code: str) -> CacheKey:
        code_hash = hashlib.sha256(code.encode()).hexdigest()
        return (code_hash, self._allowlists)

    def _set_in_cache(self, cache_key: CacheKey, violations: CachedViolations) -> None:
        if len(self._cache) >= MAX_VALIDATION_CACHE_SIZE:
            self._cache.popitem(last=False)  # FIFO

        self._cache[cache_key] = violations.copy()
        self._cache.move_to_end(cache_key)
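
Illustrative usage (not part of the commit), assuming the module paths above resolve: TaskAnalyzer.validate() either returns silently or raises SecurityViolationError listing every violation with its line number:

from src.errors import SecurityViolationError
from src.task_analyzer import TaskAnalyzer

analyzer = TaskAnalyzer(stdlib_allow={"json"}, external_allow=set())

analyzer.validate("import json\nresult = json.dumps({'ok': True})")  # passes: json is allowlisted

try:
    analyzer.validate("import socket\nx = ().__class__.__bases__")
except SecurityViolationError as e:
    print(e)
    # Security violations detected:
    # Line 1: Import of standard library module 'socket' is disallowed. Allowed stdlib modules: json
    # Line 2: Access to attribute '__bases__' is disallowed, because it can be used to bypass security restrictions.
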
@@ -3,6 +3,8 @@ import multiprocessing
 import traceback
 import textwrap
 import json
+import os
+import sys

 from src.errors import (
     TaskResultMissingError,
@@ -11,9 +13,9 @@ from src.errors import (
     TaskProcessExitError,
 )

-from .message_types.broker import NodeMode, Items
-from .constants import EXECUTOR_CIRCULAR_REFERENCE_KEY, EXECUTOR_USER_OUTPUT_KEY
-from typing import Any
+from src.message_types.broker import NodeMode, Items
+from src.constants import EXECUTOR_CIRCULAR_REFERENCE_KEY, EXECUTOR_USER_OUTPUT_KEY
+from typing import Any, Set

 from multiprocessing.context import SpawnProcess

@@ -26,7 +28,14 @@ class TaskExecutor:
     """Responsible for executing Python code tasks in isolated subprocesses."""

     @staticmethod
-    def create_process(code: str, node_mode: NodeMode, items: Items):
+    def create_process(
+        code: str,
+        node_mode: NodeMode,
+        items: Items,
+        stdlib_allow: Set[str],
+        external_allow: Set[str],
+        builtins_deny: set[str],
+    ):
         """Create a subprocess for executing a Python code task and a queue for communication."""

         fn = (
@@ -36,7 +45,10 @@
         )

         queue = MULTIPROCESSING_CONTEXT.Queue()
-        process = MULTIPROCESSING_CONTEXT.Process(target=fn, args=(code, items, queue))
+        process = MULTIPROCESSING_CONTEXT.Process(
+            target=fn,
+            args=(code, items, queue, stdlib_allow, external_allow, builtins_deny),
+        )

         return process, queue

@@ -95,16 +107,27 @@
             process.kill()

     @staticmethod
-    def _all_items(raw_code: str, items: Items, queue: multiprocessing.Queue):
+    def _all_items(
+        raw_code: str,
+        items: Items,
+        queue: multiprocessing.Queue,
+        stdlib_allow: Set[str],
+        external_allow: Set[str],
+        builtins_deny: set[str],
+    ):
         """Execute a Python code task in all-items mode."""

+        os.environ.clear()
+
+        TaskExecutor._sanitize_sys_modules(stdlib_allow, external_allow)
+
         print_args: PrintArgs = []

         try:
             code = TaskExecutor._wrap_code(raw_code)

             globals = {
-                "__builtins__": __builtins__,
+                "__builtins__": TaskExecutor._filter_builtins(builtins_deny),
                 "_items": items,
                 "print": TaskExecutor._create_custom_print(print_args),
             }
@@ -119,9 +142,20 @@
             TaskExecutor._put_error(queue, e, print_args)

     @staticmethod
-    def _per_item(raw_code: str, items: Items, queue: multiprocessing.Queue):
+    def _per_item(
+        raw_code: str,
+        items: Items,
+        queue: multiprocessing.Queue,
+        stdlib_allow: Set[str],
+        external_allow: Set[str],
+        builtins_deny: set[str],
+    ):
         """Execute a Python code task in per-item mode."""

+        os.environ.clear()
+
+        TaskExecutor._sanitize_sys_modules(stdlib_allow, external_allow)
+
         print_args: PrintArgs = []

         try:
@@ -131,7 +165,7 @@
             result = []
             for index, item in enumerate(items):
                 globals = {
-                    "__builtins__": __builtins__,
+                    "__builtins__": TaskExecutor._filter_builtins(builtins_deny),
                     "_item": item,
                     "print": TaskExecutor._create_custom_print(print_args),
                 }
@@ -195,7 +229,7 @@
     @staticmethod
     def _format_print_args(*args) -> list[str]:
         """
-        Takes the arguments passed to a `print()` call in user code and converts them
+        Takes the args passed to a `print()` call in user code and converts them
        to string representations suitable for display in a browser console.

        Expects all args to be serializable.
@@ -217,3 +251,45 @@
             formatted.append(json.dumps(arg, default=str, ensure_ascii=False))

         return formatted
+
+    # ========== security ==========
+
+    @staticmethod
+    def _filter_builtins(builtins_deny: set[str]):
+        """Get __builtins__ with denied ones removed."""
+
+        if len(builtins_deny) == 0:
+            return __builtins__
+
+        return {k: v for k, v in __builtins__.items() if k not in builtins_deny}
+
+    @staticmethod
+    def _sanitize_sys_modules(stdlib_allow: Set[str], external_allow: Set[str]):
+        safe_modules = {
+            "builtins",
+            "__main__",
+            "sys",
+            "traceback",
+            "linecache",
+        }
+
+        if "*" in stdlib_allow:
+            safe_modules.update(sys.stdlib_module_names)
+        else:
+            safe_modules.update(stdlib_allow)
+
+        if "*" in external_allow:
+            safe_modules.update(
+                name
+                for name in sys.modules.keys()
+                if name not in sys.stdlib_module_names
+            )
+        else:
+            safe_modules.update(external_allow)
+
+        modules_to_remove = [
+            name for name in sys.modules.keys() if name not in safe_modules
+        ]
+
+        for module_name in modules_to_remove:
+            del sys.modules[module_name]
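
A standalone sketch (not from the commit) of the effect the filtered __builtins__ mapping has on user code run via exec(); it uses the builtins module rather than __builtins__ so it behaves the same outside the subprocess:

import builtins

from src.constants import BUILTINS_DENY_DEFAULT
from src.env import parse_denylist

deny = parse_denylist(BUILTINS_DENY_DEFAULT)  # {'eval', 'exec', 'open', ...}
filtered = {k: v for k, v in vars(builtins).items() if k not in deny}
sandbox_globals = {"__builtins__": filtered, "_items": [{"json": {"n": 1}}]}

# Builtins that are not denied (len, str, ...) keep working inside the sandboxed globals.
exec("lengths = [len(str(item)) for item in _items]", sandbox_globals)

try:
    exec("open('/etc/passwd')", sandbox_globals)  # a denied builtin is simply absent
except NameError as e:
    print(e)  # name 'open' is not defined
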
@@ -2,17 +2,20 @@ import asyncio
 from dataclasses import dataclass
 import logging
 import time
-from typing import Dict, Optional, Any
+from typing import Dict, Optional, Any, Set
 from urllib.parse import urlparse
 import websockets
 import random


-from src.errors import WebsocketConnectionError, TaskMissingError
+from src.errors import (
+    WebsocketConnectionError,
+    TaskMissingError,
+)
 from src.message_types.broker import TaskSettings
 from src.nanoid_utils import nanoid

-from .constants import (
+from src.constants import (
     RUNNER_NAME,
     TASK_REJECTED_REASON_AT_CAPACITY,
     TASK_REJECTED_REASON_OFFER_EXPIRED,
@@ -24,7 +27,7 @@ from .constants import (
     TASK_BROKER_WS_PATH,
     RPC_BROWSER_CONSOLE_LOG_METHOD,
 )
-from .message_types import (
+from src.message_types import (
     BrokerMessage,
     RunnerMessage,
     BrokerInfoRequest,
@@ -41,9 +44,10 @@ from .message_types import (
     RunnerTaskError,
     RunnerRpcCall,
 )
-from .message_serde import MessageSerde
-from .task_state import TaskState, TaskStatus
-from .task_executor import TaskExecutor
+from src.message_serde import MessageSerde
+from src.task_state import TaskState, TaskStatus
+from src.task_executor import TaskExecutor
+from src.task_analyzer import TaskAnalyzer


 class TaskOffer:
@@ -56,13 +60,16 @@
         return time.time() > self.valid_until


-@dataclass()
+@dataclass
 class TaskRunnerOpts:
     grant_token: str
     task_broker_uri: str
     max_concurrency: int
     max_payload_size: int
     task_timeout: int
+    stdlib_allow: Set[str]
+    external_allow: Set[str]
+    builtins_deny: Set[str]


 class TaskRunner:
@@ -85,6 +92,7 @@
         self.offers_coroutine: Optional[asyncio.Task] = None
         self.serde = MessageSerde()
         self.executor = TaskExecutor()
+        self.analyzer = TaskAnalyzer(opts.stdlib_allow, opts.external_allow)
         self.logger = logging.getLogger(__name__)

         self.task_broker_uri = opts.task_broker_uri
@@ -207,14 +215,24 @@
         if task_state is None:
             raise TaskMissingError(task_id)

+        self.analyzer.validate(task_settings.code)
+
         process, queue = self.executor.create_process(
-            task_settings.code, task_settings.node_mode, task_settings.items
+            code=task_settings.code,
+            node_mode=task_settings.node_mode,
+            items=task_settings.items,
+            stdlib_allow=self.opts.stdlib_allow,
+            external_allow=self.opts.external_allow,
+            builtins_deny=self.opts.builtins_deny,
         )

         task_state.process = process

         result, print_args = self.executor.execute_process(
-            process, queue, self.opts.task_timeout, task_settings.continue_on_fail
+            process=process,
+            queue=queue,
+            task_timeout=self.opts.task_timeout,
+            continue_on_fail=task_settings.continue_on_fail,
         )

         for print_args_per_call in print_args:
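
Purely illustrative (not in the diff): with the three new fields, constructing the options directly, e.g. in a test, might look like the following; the broker URI and numeric values are placeholders, not the project defaults:

from src.task_runner import TaskRunner, TaskRunnerOpts

opts = TaskRunnerOpts(
    grant_token="example-token",
    task_broker_uri="http://127.0.0.1:5679",  # placeholder address
    max_concurrency=5,
    max_payload_size=1024 * 1024,
    task_timeout=60,
    stdlib_allow={"json", "datetime"},
    external_allow=set(),
    builtins_deny={"eval", "exec", "open"},
)

# TaskRunner wires the allowlists into TaskAnalyzer and forwards the deny list to TaskExecutor.create_process().
runner = TaskRunner(opts)
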