diff --git a/README.md b/README.md
index c14263bd..e979aa04 100644
--- a/README.md
+++ b/README.md
@@ -78,7 +78,9 @@ The Python SDK is under development. There are no compatibility guarantees at th
         - [Sandbox is not Secure](#sandbox-is-not-secure)
         - [Sandbox Performance](#sandbox-performance)
         - [Extending Restricted Classes](#extending-restricted-classes)
+        - [Certain Standard Library Calls on Restricted Objects](#certain-standard-library-calls-on-restricted-objects)
         - [is_subclass of ABC-based Restricted Classes](#is_subclass-of-abc-based-restricted-classes)
+        - [Compiled Pydantic Sometimes Using Wrong Types](#compiled-pydantic-sometimes-using-wrong-types)
   - [Activities](#activities)
     - [Definition](#definition-1)
     - [Types of Activities](#types-of-activities)
@@ -672,6 +674,10 @@ workflow will not progress until fixed.
 The sandbox is not foolproof and non-determinism can still occur. It is simply a best-effort way to catch bad code
 early. Users are encouraged to define their workflows in files with no other side effects.
 
+The sandbox offers a mechanism to pass through modules from outside the sandbox. By default this already includes all
+standard library modules and Temporal modules. **For performance and behavior reasons, users are encouraged to pass
+through all third party modules whose calls will be deterministic.** See "Passthrough Modules" below on how to do this.
+
 ##### How the Sandbox Works
 
 The sandbox is made up of two components that work closely together:
@@ -718,23 +724,46 @@ future releases.
 
 When creating the `Worker`, the `workflow_runner` is defaulted to
 `temporalio.worker.workflow_sandbox.SandboxedWorkflowRunner()`. The `SandboxedWorkflowRunner`'s init accepts a
 `restrictions` keyword argument that is defaulted to `SandboxRestrictions.default`. The `SandboxRestrictions` dataclass
-is immutable and contains three fields that can be customized, but only two have notable value
+is immutable and contains three fields that can be customized, but only two have notable value. See below.
 
 ###### Passthrough Modules
 
-To make the sandbox quicker and use less memory when importing known third party libraries, they can be added to the
-`SandboxRestrictions.passthrough_modules` set like so:
+By default the sandbox completely reloads non-standard-library and non-Temporal modules for every workflow run. To make
+the sandbox quicker and use less memory when importing known-side-effect-free third party modules, they can be marked
+as passthrough modules.
+
+**For performance and behavior reasons, users are encouraged to pass through all third party modules whose calls will be
+deterministic.**
+
+One way to pass through a module is at import time in the workflow file using the `imports_passed_through` context
+manager like so:
 
 ```python
-my_restrictions = dataclasses.replace(
-    SandboxRestrictions.default,
-    passthrough_modules=SandboxRestrictions.passthrough_modules_default | SandboxMatcher(access={"pydantic"}),
+# my_workflow_file.py
+
+from temporalio import workflow
+
+with workflow.unsafe.imports_passed_through():
+    import pydantic
+
+@workflow.defn
+class MyWorkflow:
+    ...
+```
+
+Alternatively, this can be done at worker creation time by customizing the runner's restrictions. For example:
+
+```python
+my_worker = Worker(
+  ...,
+  workflow_runner=SandboxedWorkflowRunner(
+    restrictions=SandboxRestrictions.default.with_passthrough_modules("pydantic")
+  )
 )
-my_worker = Worker(..., runner=SandboxedWorkflowRunner(restrictions=my_restrictions))
 ```
 
-If an "access" match succeeds for an import, it will simply be forwarded from outside of the sandbox. See the API for
-more details on exact fields and their meaning.
+In both of these cases, the `pydantic` module will now be passed through from outside of the sandbox instead of
+being reloaded for every workflow run.
 
 ###### Invalid Module Members
@@ -749,7 +778,7 @@ my_restrictions = dataclasses.replace(
         "datetime", "date", "today",
     ),
 )
-my_worker = Worker(..., runner=SandboxedWorkflowRunner(restrictions=my_restrictions))
+my_worker = Worker(..., workflow_runner=SandboxedWorkflowRunner(restrictions=my_restrictions))
 ```
 
 Restrictions can also be added by `|`'ing together matchers, for example to restrict the `datetime.date` class from
@@ -762,7 +791,7 @@ my_restrictions = dataclasses.replace(
         children={"datetime": SandboxMatcher(use={"date"})},
     ),
 )
-my_worker = Worker(..., runner=SandboxedWorkflowRunner(restrictions=my_restrictions))
+my_worker = Worker(..., workflow_runner=SandboxedWorkflowRunner(restrictions=my_restrictions))
 ```
 
 See the API for more details on exact fields and their meaning.
@@ -802,9 +831,22 @@ To mitigate this, users should:
 
 ###### Extending Restricted Classes
 
-Currently, extending classes marked as restricted causes an issue with their `__init__` parameters. This does not affect
-most users, but if there is a dependency that is, say, extending `zipfile.ZipFile` an error may occur and the module
-will have to be marked as pass through.
+Extending a restricted class causes Python to instantiate the restricted metaclass, which is unsupported. Therefore, if
+you attempt to use a class in the sandbox that extends a restricted class, it will fail. For example, if you have a
+`class MyZipFile(zipfile.ZipFile)` and try to use that class inside a workflow, it will fail.
+
+Classes used inside the workflow should not extend restricted classes. For situations where third-party modules need to
+extend restricted classes at import time, they should be marked as passthrough modules.
+
+###### Certain Standard Library Calls on Restricted Objects
+
+If an object is restricted, internal CPython validation may fail in some cases. For example, running
+`dict.items(os.__dict__)` will fail with:
+
+> descriptor 'items' for 'dict' objects doesn't apply to a '_RestrictedProxy' object
+
+This is a low-level check that cannot be subverted. The solution is to not use restricted objects inside the sandbox.
+For situations where third-party modules need such calls at import time, they should be marked as passthrough modules.
 
 ###### is_subclass of ABC-based Restricted Classes
 
@@ -812,6 +854,16 @@ Due to [https://bugs.python.org/issue44847](https://bugs.python.org/issue44847),
 checked to see if they are subclasses of another via `is_subclass` may fail (see also
 [this wrapt issue](https://github.com/GrahamDumpleton/wrapt/issues/130)).
 
+###### Compiled Pydantic Sometimes Using Wrong Types
+
+If the Pydantic dependency is in compiled form (the default) and you are using a Pydantic model inside a workflow
+sandbox that uses a `datetime` type, it will grab the wrong validator and use `date` instead. This is because our
+patched form of `issubclass` is bypassed by compiled Pydantic.
+
+To work around this, either don't use `datetime`-based Pydantic model fields in workflows, or mark the `datetime`
+library as passthrough (which means you lose protection against calling the non-deterministic `now()`), or use a
+non-compiled Pydantic dependency.
+
 ### Activities
 
 #### Definition
diff --git a/temporalio/worker/workflow_sandbox/_importer.py b/temporalio/worker/workflow_sandbox/_importer.py
index 31bd3a6a..62f314c6 100644
--- a/temporalio/worker/workflow_sandbox/_importer.py
+++ b/temporalio/worker/workflow_sandbox/_importer.py
@@ -14,6 +14,7 @@
 import sys
 import threading
 import types
+import warnings
 from contextlib import ExitStack, contextmanager
 from typing import (
     Any,
@@ -29,6 +30,7 @@
     Set,
     Tuple,
     TypeVar,
+    no_type_check,
 )
 
 from typing_extensions import ParamSpec
@@ -107,6 +109,33 @@ def restrict_built_in(name: str, orig: Any, *args, **kwargs):
             )
         )
 
+        # Need to unwrap params for isinstance and issubclass. We have
+        # chosen to do it this way instead of customize __instancecheck__
+        # and __subclasscheck__ because we may have proxied the second
+        # parameter which does not have a way to override. It is unfortunate
+        # we have to change these globals for everybody.
+        def unwrap_second_param(orig: Any, a: Any, b: Any) -> Any:
+            a = RestrictionContext.unwrap_if_proxied(a)
+            b = RestrictionContext.unwrap_if_proxied(b)
+            return orig(a, b)
+
+        thread_local_is_inst = _get_thread_local_builtin("isinstance")
+        self.restricted_builtins.append(
+            (
+                "isinstance",
+                thread_local_is_inst,
+                functools.partial(unwrap_second_param, thread_local_is_inst.orig),
+            )
+        )
+        thread_local_is_sub = _get_thread_local_builtin("issubclass")
+        self.restricted_builtins.append(
+            (
+                "issubclass",
+                thread_local_is_sub,
+                functools.partial(unwrap_second_param, thread_local_is_sub.orig),
+            )
+        )
+
     @contextmanager
     def applied(self) -> Iterator[None]:
         """Context manager to apply this restrictive import.
@@ -153,17 +182,21 @@ def _import(
         fromlist: Sequence[str] = (),
         level: int = 0,
     ) -> types.ModuleType:
+        # We have to resolve the full name, it can be relative at different
+        # levels
+        full_name = _resolve_module_name(name, globals, level)
+
         # Check module restrictions and passthrough modules
-        if name not in sys.modules:
+        if full_name not in sys.modules:
             # Make sure not an entirely invalid module
-            self._assert_valid_module(name)
+            self._assert_valid_module(full_name)
             # Check if passthrough
-            passthrough_mod = self._maybe_passthrough_module(name)
+            passthrough_mod = self._maybe_passthrough_module(full_name)
             if passthrough_mod:
                 # Load all parents. Usually Python does this for us, but not on
                 # passthrough.
- parent, _, child = name.rpartition(".") + parent, _, child = full_name.rpartition(".") if parent and parent not in sys.modules: _trace( "Importing parent module %s before passing through %s", @@ -174,17 +207,17 @@ def _import( # Set the passthrough on the parent setattr(sys.modules[parent], child, passthrough_mod) # Set the passthrough on sys.modules and on the parent - sys.modules[name] = passthrough_mod + sys.modules[full_name] = passthrough_mod # Put it on the parent if parent: - setattr(sys.modules[parent], child, sys.modules[name]) + setattr(sys.modules[parent], child, sys.modules[full_name]) # If the module is __temporal_main__ and not already in sys.modules, # we load it from whatever file __main__ was originally in - if name == "__temporal_main__": + if full_name == "__temporal_main__": orig_mod = _thread_local_sys_modules.orig["__main__"] new_spec = importlib.util.spec_from_file_location( - name, orig_mod.__file__ + full_name, orig_mod.__file__ ) if not new_spec: raise ImportError( @@ -195,7 +228,7 @@ def _import( f"Spec for __main__ file at {orig_mod.__file__} has no loader" ) new_mod = importlib.util.module_from_spec(new_spec) - sys.modules[name] = new_mod + sys.modules[full_name] = new_mod new_spec.loader.exec_module(new_mod) mod = importlib.__import__(name, globals, locals, fromlist, level) @@ -219,10 +252,20 @@ def _assert_valid_module(self, name: str) -> None: raise RestrictedWorkflowAccessError(name) def _maybe_passthrough_module(self, name: str) -> Optional[types.ModuleType]: - if not self.restrictions.passthrough_modules.match_access( - self.restriction_context, *name.split(".") + # If imports not passed through and name not in passthrough modules, + # check parents + if ( + not temporalio.workflow.unsafe.is_imports_passed_through() + and name not in self.restrictions.passthrough_modules ): - return None + end_dot = -1 + while True: + end_dot = name.find(".", end_dot + 1) + if end_dot == -1: + return None + elif name[:end_dot] in self.restrictions.passthrough_modules: + break + # Do the pass through with self._unapplied(): _trace("Passing module %s through from host", name) global _trace_depth @@ -409,3 +452,50 @@ def _get_thread_local_builtin(name: str) -> _ThreadLocalCallable: ret = _ThreadLocalCallable(getattr(builtins, name)) _thread_local_builtins[name] = ret return ret + + +def _resolve_module_name( + name: str, globals: Optional[Mapping[str, object]], level: int +) -> str: + if level == 0: + return name + # Calc the package from globals + package = _calc___package__(globals or {}) + # Logic taken from importlib._resolve_name + bits = package.rsplit(".", level - 1) + if len(bits) < level: + raise ImportError("Attempted relative import beyond top-level package") + base = bits[0] + return f"{base}.{name}" if name else base + + +# Copied from importlib._calc__package__ +@no_type_check +def _calc___package__(globals: Mapping[str, object]) -> str: + """Calculate what __package__ should be. + __package__ is not guaranteed to be defined or could be set to None + to represent that its proper value is unknown. 
+ """ + package = globals.get("__package__") + spec = globals.get("__spec__") + if package is not None: + if spec is not None and package != spec.parent: + warnings.warn( + "__package__ != __spec__.parent " f"({package!r} != {spec.parent!r})", + DeprecationWarning, + stacklevel=3, + ) + return package + elif spec is not None: + return spec.parent + else: + warnings.warn( + "can't resolve package from __spec__ or __package__, " + "falling back on __name__ and __path__", + ImportWarning, + stacklevel=3, + ) + package = globals["__name__"] + if "__path__" not in globals: + package = package.rpartition(".")[0] + return package diff --git a/temporalio/worker/workflow_sandbox/_restrictions.py b/temporalio/worker/workflow_sandbox/_restrictions.py index 5b07eb77..43ed186d 100644 --- a/temporalio/worker/workflow_sandbox/_restrictions.py +++ b/temporalio/worker/workflow_sandbox/_restrictions.py @@ -59,13 +59,12 @@ def __init__(self, qualified_name: str) -> None: class SandboxRestrictions: """Set of restrictions that can be applied to a sandbox.""" - passthrough_modules: SandboxMatcher + passthrough_modules: Set[str] """ Modules which pass through because we know they are side-effect free (or the side-effecting pieces are restricted). These modules will not be reloaded, - but instead will just be forwarded from outside of the sandbox. The check - whether a module matches here is an access match using the fully qualified - module name. + but instead will just be forwarded from outside of the sandbox. Any module + listed will apply to all children. """ invalid_modules: SandboxMatcher @@ -84,19 +83,19 @@ class methods (including __init__, etc). The check compares the against the fully qualified path to the item. """ - passthrough_modules_minimum: ClassVar[SandboxMatcher] + passthrough_modules_minimum: ClassVar[Set[str]] """Set of modules that must be passed through at the minimum.""" - passthrough_modules_with_temporal: ClassVar[SandboxMatcher] + passthrough_modules_with_temporal: ClassVar[Set[str]] """Minimum modules that must be passed through and the Temporal modules.""" - passthrough_modules_maximum: ClassVar[SandboxMatcher] + passthrough_modules_maximum: ClassVar[Set[str]] """ All modules that can be passed through. This includes all standard library modules. """ - passthrough_modules_default: ClassVar[SandboxMatcher] + passthrough_modules_default: ClassVar[Set[str]] """Same as :py:attr:`passthrough_modules_maximum`.""" invalid_module_members_default: ClassVar[SandboxMatcher] @@ -110,6 +109,14 @@ class methods (including __init__, etc). The check compares the against the Combination of :py:attr:`passthrough_modules_default`, :py:attr:`invalid_module_members_default`, and no invalid modules.""" + def with_passthrough_modules(self, *modules: str) -> SandboxRestrictions: + """Create a new restriction set with the given modules added to the + :py:attr:`passthrough_modules` set. + """ + return dataclasses.replace( + self, passthrough_modules=self.passthrough_modules | set(modules) + ) + # We intentionally use specific fields instead of generic "matcher" callbacks # for optimization reasons. 
@@ -321,44 +328,40 @@ def with_child_unrestricted(self, *child_path: str) -> SandboxMatcher: SandboxMatcher.all_uses = SandboxMatcher(use={"*"}) SandboxMatcher.all_uses_runtime = SandboxMatcher(use={"*"}, only_runtime=True) -SandboxRestrictions.passthrough_modules_minimum = SandboxMatcher( - access={ - "grpc", - # Due to some side-effecting calls made on import, these need to be - # allowed - "pathlib", - "importlib", - # Python 3.7 libs often import this to support things in older Python. - # This does not work natively due to an issue with extending - # zipfile.ZipFile. TODO(cretz): Fix when subclassing restricted classes - # is fixed. - "importlib_metadata", - }, - # Required due to https://github.com/protocolbuffers/protobuf/issues/10143. - # This unfortunately means that for now, everyone using Python protos has to - # pass their module through :-( - children={ - "google": SandboxMatcher(access={"protobuf"}), - "temporalio": SandboxMatcher( - access={"api"}, children={"bridge": SandboxMatcher(access={"proto"})} - ), - }, -) - -SandboxRestrictions.passthrough_modules_with_temporal = SandboxRestrictions.passthrough_modules_minimum | SandboxMatcher( - access={ - # is_subclass is broken in sandbox due to Python bug on ABC C extension. - # So we have to include all things that might extend an ABC class and - # do a subclass check. See https://bugs.python.org/issue44847 and - # https://wrapt.readthedocs.io/en/latest/issues.html#using-issubclass-on-abstract-classes - "asyncio", - "abc", - "temporalio", - # Due to pkg_resources use of base classes caused by the ABC issue - # above, and Otel's use of pkg_resources, we pass it through - "pkg_resources", - } -) +SandboxRestrictions.passthrough_modules_minimum = { + "grpc", + # Due to some side-effecting calls made on import, these need to be + # allowed + "pathlib", + "importlib", + # Python 3.7 libs often import this to support things in older Python. + # This does not work natively due to an issue with extending + # zipfile.ZipFile. TODO(cretz): Fix when subclassing restricted classes + # is fixed. + "importlib_metadata", + # Required due to https://github.com/protocolbuffers/protobuf/issues/10143 + # for older versions. This unfortunately means that on those versions, + # everyone using Python protos has to pass their module through. + "google.protobuf", + "temporalio.api", + "temporalio.bridge.proto", +} + +SandboxRestrictions.passthrough_modules_with_temporal = SandboxRestrictions.passthrough_modules_minimum | { + # is_subclass is broken in sandbox due to Python bug on ABC C extension. + # So we have to include all things that might extend an ABC class and + # do a subclass check. See https://bugs.python.org/issue44847 and + # https://wrapt.readthedocs.io/en/latest/issues.html#using-issubclass-on-abstract-classes + "asyncio", + "abc", + "temporalio", + # Due to pkg_resources use of base classes caused by the ABC issue + # above, and Otel's use of pkg_resources, we pass it through + "pkg_resources", + # Due to how Pydantic is importing lazily inside of some classes, we choose + # to always pass it through + "pydantic", +} # sys.stdlib_module_names is only available on 3.10+, so we hardcode here. 
A # test will fail if this list doesn't match the latest Python version it was @@ -398,17 +401,15 @@ def with_child_unrestricted(self, *child_path: str) -> SandboxMatcher: "xdrlib,xml,xmlrpc,zipapp,zipfile,zipimport,zlib,zoneinfo" ) -SandboxRestrictions.passthrough_modules_maximum = SandboxRestrictions.passthrough_modules_with_temporal | SandboxMatcher( - access={ - # All stdlib modules except "sys" and their children. Children are not - # listed in stdlib names but we need them because in some cases (e.g. os - # manually setting sys.modules["os.path"]) they have certain child - # expectations. - v - for v in _stdlib_module_names.split(",") - if v != "sys" - } -) +SandboxRestrictions.passthrough_modules_maximum = SandboxRestrictions.passthrough_modules_with_temporal | { + # All stdlib modules except "sys" and their children. Children are not + # listed in stdlib names but we need them because in some cases (e.g. os + # manually setting sys.modules["os.path"]) they have certain child + # expectations. + v + for v in _stdlib_module_names.split(",") + if v != "sys" +} SandboxRestrictions.passthrough_modules_default = ( SandboxRestrictions.passthrough_modules_maximum @@ -640,6 +641,13 @@ class RestrictionContext: if we're just importing it. """ + @staticmethod + def unwrap_if_proxied(v: Any) -> Any: + """Unwrap a proxy object if proxied.""" + if type(v) is _RestrictedProxy: + v = _RestrictionState.from_proxy(v).obj + return v + def __init__(self) -> None: """Create a restriction context.""" self.is_runtime = False @@ -651,7 +659,14 @@ class _RestrictionState: def from_proxy(v: _RestrictedProxy) -> _RestrictionState: # To prevent recursion, must use __getattribute__ on object to get the # restriction state - return object.__getattribute__(v, "__temporal_state") + try: + return object.__getattribute__(v, "__temporal_state") + except AttributeError: + # This mostly occurs when accessing a field of an extended class on + # that has been proxied + raise RuntimeError( + "Restriction state not present. Using subclasses of proxied objects is unsupported." + ) from None name: str obj: object @@ -787,13 +802,27 @@ def _is_restrictable(v: Any) -> bool: class _RestrictedProxy: - def __init__( - self, name: str, obj: Any, context: RestrictionContext, matcher: SandboxMatcher - ) -> None: - _trace("__init__ on %s", name) - _RestrictionState( - name=name, obj=obj, context=context, matcher=matcher - ).set_on_proxy(self) + def __init__(self, *args, **kwargs) -> None: + # When we instantiate this class, we have the signature of: + # __init__( + # self, + # name: str, + # obj: Any, + # context: RestrictionContext, + # matcher: SandboxMatcher + # ) + # However when Python subclasses a class, it calls metaclass() on the + # class object which doesn't match these args. For now, we'll just + # ignore inits on these metadata classes. + # TODO(cretz): Properly support subclassing restricted classes in + # sandbox + if isinstance(args[2], RestrictionContext): + _trace("__init__ on %s", args[0]) + _RestrictionState( + name=args[0], obj=args[1], context=args[2], matcher=args[3] + ).set_on_proxy(self) + else: + _trace("__init__ unrecognized with args %s", args) def __getattribute__(self, __name: str) -> Any: state = _RestrictionState.from_proxy(self) @@ -802,7 +831,7 @@ def __getattribute__(self, __name: str) -> Any: ret = object.__getattribute__(self, "__getattr__")(__name) # Since Python 3.11, the importer references __spec__ on module, so we - # allow that through + # allow that through. 
if __name != "__spec__": child_matcher = state.matcher.child_matcher(__name) if child_matcher and _is_restrictable(ret): @@ -876,7 +905,6 @@ def __getitem__(self, key: Any) -> Any: # __slots__ used by proxy itself # __dict__ (__getattr__) # __weakref__ (__getattr__) - # __init_subclass__ (proxying metaclass not supported) # __prepare__ (metaclass) __class__ = _RestrictedProxyLookup( # type: ignore fallback_func=lambda self: type(self), is_attr=True diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 2511dbcd..bfd70bbe 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -714,6 +714,7 @@ async def wait_condition( _sandbox_unrestricted = threading.local() _in_sandbox = threading.local() +_imports_passed_through = threading.local() class unsafe: @@ -773,6 +774,35 @@ def sandbox_unrestricted() -> Iterator[None]: finally: _sandbox_unrestricted.value = False + @staticmethod + def is_imports_passed_through() -> bool: + """Whether the current block of code is in + :py:meth:imports_passed_through. + + Returns: + True if the current code's imports will be passed through + """ + # See comment in is_sandbox_unrestricted for why we allow unset instead + # of just global false. + return getattr(_imports_passed_through, "value", False) + + @staticmethod + @contextmanager + def imports_passed_through() -> Iterator[None]: + """Context manager to mark all imports that occur within it as passed + through (meaning not reloaded by the sandbox). + """ + # Only apply if not already applied. Nested calls just continue + # passed through. + if unsafe.is_imports_passed_through(): + yield None + return + _imports_passed_through.value = True + try: + yield None + finally: + _imports_passed_through.value = False + class LoggerAdapter(logging.LoggerAdapter): """Adapter that adds details to the log about the running workflow. 
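
Editorial aside (not part of the diff): the `imports_passed_through` context manager added to `temporalio/workflow.py` above is meant to wrap imports in a workflow definition file, and nested uses simply continue the outer block. A small sketch, assuming a hypothetical greeting workflow:

```python
# Illustrative sketch only: module, class, and field names are hypothetical.
from temporalio import workflow

# Imports inside this block are forwarded from outside the sandbox rather than
# re-imported for every workflow run. Nested uses of this context manager are a
# no-op that simply continues the outer passthrough block.
with workflow.unsafe.imports_passed_through():
    import pydantic


class Greeting(pydantic.BaseModel):
    name: str


@workflow.defn
class GreetingWorkflow:
    @workflow.run
    async def run(self, name: str) -> str:
        # The passed-through module is usable inside the workflow as normal.
        greeting = Greeting(name=name)
        return f"Hello, {greeting.name}!"
```
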
diff --git a/tests/worker/workflow_sandbox/test_importer.py b/tests/worker/workflow_sandbox/test_importer.py index 11647132..07a305b3 100644 --- a/tests/worker/workflow_sandbox/test_importer.py +++ b/tests/worker/workflow_sandbox/test_importer.py @@ -2,6 +2,7 @@ import pytest +from temporalio import workflow from temporalio.worker.workflow_sandbox._importer import ( Importer, _thread_local_sys_modules, @@ -33,22 +34,36 @@ def test_workflow_sandbox_importer_passthrough_module(): assert outside1.module_state == ["module orig"] assert outside2.module_state == ["module orig"] - # Now import both via importer + # Now import via importer with Importer(restrictions, RestrictionContext()).applied(): import tests.worker.workflow_sandbox.testmodules.passthrough_module as inside1 import tests.worker.workflow_sandbox.testmodules.stateful_module as inside2 + from .testmodules import stateful_module as inside_relative2 + # Now if we alter inside1, it's passthrough so it affects outside1 inside1.module_state = ["another val"] assert outside1.module_state == ["another val"] assert id(inside1) == id(outside1) + # Confirm relative is same as non-relative + assert id(inside2) == id(inside_relative2) + # But if we alter non-passthrough inside2 it does not affect outside2 inside2.module_state = ["another val"] assert outside2.module_state != ["another val"] assert id(inside2) != id(outside2) +def test_workflow_sandbox_importer_passthough_context_manager(): + import tests.worker.workflow_sandbox.testmodules.stateful_module as outside + + with Importer(restrictions, RestrictionContext()).applied(): + with workflow.unsafe.imports_passed_through(): + import tests.worker.workflow_sandbox.testmodules.stateful_module as inside + assert id(outside) == id(inside) + + def test_workflow_sandbox_importer_invalid_module_members(): importer = Importer(restrictions, RestrictionContext()) # Can access the function, no problem diff --git a/tests/worker/workflow_sandbox/test_restrictions.py b/tests/worker/workflow_sandbox/test_restrictions.py index e983de52..a00bcf0a 100644 --- a/tests/worker/workflow_sandbox/test_restrictions.py +++ b/tests/worker/workflow_sandbox/test_restrictions.py @@ -10,6 +10,7 @@ RestrictedWorkflowAccessError, RestrictionContext, SandboxMatcher, + SandboxRestrictions, _RestrictedProxy, _stdlib_module_names, ) @@ -34,6 +35,14 @@ def test_workflow_sandbox_stdlib_module_names(): ), f"Expecting names as {actual_names}. 
In code as:\n{code}" +def test_workflow_sandbox_restrictions_add_passthrough_modules(): + updated = SandboxRestrictions.default.with_passthrough_modules("module1", "module2") + assert ( + "module1" in updated.passthrough_modules + and "module2" in updated.passthrough_modules + ) + + @dataclass class RestrictableObject: foo: Optional[RestrictableObject] = None diff --git a/tests/worker/workflow_sandbox/test_runner.py b/tests/worker/workflow_sandbox/test_runner.py index dfe47f6b..cbc1c732 100644 --- a/tests/worker/workflow_sandbox/test_runner.py +++ b/tests/worker/workflow_sandbox/test_runner.py @@ -7,10 +7,11 @@ import time import uuid from dataclasses import dataclass -from datetime import date, timedelta +from datetime import date, datetime, timedelta from enum import IntEnum -from typing import Callable, Dict, List, Optional, Sequence, Type +from typing import Callable, Dict, List, Optional, Sequence, Set, Type +import pydantic import pytest import temporalio.worker.workflow_sandbox._restrictions @@ -33,9 +34,16 @@ # runtime only _ = os.name -# TODO(cretz): This fails because you can't subclass a restricted class -# class MyZipFile(zipfile.ZipFile): -# pass +# This used to fail because our __init__ couldn't handle metaclass init +import zipfile + + +class MyZipFile(zipfile.ZipFile): + pass + + +# This used to fail because subclass checks were inaccurate +assert issubclass(datetime, datetime) @dataclass @@ -76,14 +84,13 @@ def state(self) -> Dict[str, List[str]]: [ # TODO(cretz): Disabling until https://github.com/protocolbuffers/upb/pull/804 # SandboxRestrictions.passthrough_modules_minimum, - # TODO(cretz): See what is failing here SandboxRestrictions.passthrough_modules_with_temporal, SandboxRestrictions.passthrough_modules_maximum, ], ) async def test_workflow_sandbox_global_state( client: Client, - sandboxed_passthrough_modules: SandboxMatcher, + sandboxed_passthrough_modules: Set[str], ): global global_state async with new_worker( @@ -365,7 +372,6 @@ async def test_workflow_sandbox_instance_check(client: Client): class UseProtoWorkflow: @workflow.run async def run(self, param: SomeMessage) -> SomeMessage: - print("In: ", id(SomeMessage)) assert isinstance(param, SomeMessage) return param @@ -383,12 +389,53 @@ async def test_workflow_sandbox_with_proto(client: Client): assert result is not param and result == param +class PydanticMessage(pydantic.BaseModel): + content: datetime + + +@workflow.defn +class KnownIssuesWorkflow: + @workflow.run + async def run(self) -> None: + # Calling an internally defined method passing in self that is a + # restricted proxy object will fail + try: + dict.items(os.__dict__) + raise ApplicationError("Expected failure") + except TypeError: + pass + + # Using a subclass of a proxied class is unsupported + try: + MyZipFile("some/path") + raise ApplicationError("Expected failure") + except RuntimeError as err: + assert "Restriction state not present" in str(err) + + # Using a datetime in binary-compiled Pydantic skips our issubclass when + # building their validators causing it to use date instead + # TODO(cretz): https://github.com/temporalio/sdk-python/issues/207 + if pydantic.compiled: + assert isinstance(PydanticMessage(content=workflow.now()).content, date) + else: + assert isinstance(PydanticMessage(content=workflow.now()).content, datetime) + + +async def test_workflow_sandbox_known_issues(client: Client): + async with new_worker(client, KnownIssuesWorkflow) as worker: + await client.execute_workflow( + KnownIssuesWorkflow.run, + 
id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + def new_worker( client: Client, *workflows: Type, activities: Sequence[Callable] = [], task_queue: Optional[str] = None, - sandboxed_passthrough_modules: Optional[SandboxMatcher] = None, + sandboxed_passthrough_modules: Set[str] = set(), sandboxed_invalid_module_members: Optional[SandboxMatcher] = None, ) -> Worker: restrictions = SandboxRestrictions.default diff --git a/tests/worker/workflow_sandbox/testmodules/__init__.py b/tests/worker/workflow_sandbox/testmodules/__init__.py index 0a9b313f..9d1688fb 100644 --- a/tests/worker/workflow_sandbox/testmodules/__init__.py +++ b/tests/worker/workflow_sandbox/testmodules/__init__.py @@ -5,10 +5,7 @@ restrictions = SandboxRestrictions( passthrough_modules=SandboxRestrictions.passthrough_modules_minimum - | SandboxMatcher.nested_child( - "tests.worker.workflow_sandbox.testmodules".split("."), - SandboxMatcher(access={"passthrough_module"}), - ), + | {"tests.worker.workflow_sandbox.testmodules.passthrough_module"}, invalid_modules=SandboxMatcher.nested_child( "tests.worker.workflow_sandbox.testmodules".split("."), SandboxMatcher(access={"invalid_module"}),