diff --git a/README.md b/README.md index 821d9740..aef520bc 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ Some examples require extra dependencies. See each sample's directory for specif * [activity_worker](activity_worker) - Use Python activities from a workflow in another language. * [cloud_export_to_parquet](cloud_export_to_parquet) - Set up schedule workflow to process exported files on an hourly basis +* [context_propagation](context_propagation) - Context propagation through workflows/activities via interceptor. * [custom_converter](custom_converter) - Use a custom payload converter to handle custom types. * [custom_decorator](custom_decorator) - Custom decorator to auto-heartbeat a long-running activity. * [dsl](dsl) - DSL workflow that executes steps defined in a YAML file. diff --git a/context_propagation/README.md b/context_propagation/README.md new file mode 100644 index 00000000..bbf47ac0 --- /dev/null +++ b/context_propagation/README.md @@ -0,0 +1,16 @@ +# Context Propagation Interceptor Sample + +This sample shows how to use an interceptor to propagate contextual information through workflows and activities. For +this example, [contextvars](https://docs.python.org/3/library/contextvars.html) holds the contextual information. + +To run, first see [README.md](../README.md) for prerequisites. Then, run the following from this directory to start the +worker: + + poetry run python worker.py + +This will start the worker. Then, in another terminal, run the following to execute the workflow: + + poetry run python starter.py + +The starter terminal should complete with the hello result and the worker terminal should show the logs with the +propagated user ID contextual information flowing through the workflows/activities. 
diff --git a/context_propagation/__init__.py b/context_propagation/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/context_propagation/activities.py b/context_propagation/activities.py
new file mode 100644
index 00000000..77cde015
--- /dev/null
+++ b/context_propagation/activities.py
@@ -0,0 +1,10 @@
+from temporalio import activity
+
+from context_propagation import shared
+
+
+@activity.defn
+async def say_hello_activity(name: str) -> str:
+    """Log the propagated user ID and return a greeting for ``name``."""
+    activity.logger.info(f"Activity called by user {shared.user_id.get()}")
+    return f"Hello, {name}"
diff --git a/context_propagation/interceptor.py b/context_propagation/interceptor.py
new file mode 100644
index 00000000..f8dca8e5
--- /dev/null
+++ b/context_propagation/interceptor.py
@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import Any, Mapping, Protocol, Type
+
+import temporalio.activity
+import temporalio.api.common.v1
+import temporalio.client
+import temporalio.converter
+import temporalio.worker
+import temporalio.workflow
+
+with temporalio.workflow.unsafe.imports_passed_through():
+    from context_propagation.shared import HEADER_KEY, user_id
+
+
+class _InputWithHeaders(Protocol):
+    """Any interceptor input object that exposes a ``headers`` mapping."""
+
+    headers: Mapping[str, temporalio.api.common.v1.Payload]
+
+
+def set_header_from_context(
+    input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter
+) -> None:
+    """Copy the current ``user_id`` context value into ``input.headers``.
+
+    Nothing is written when ``user_id`` is unset or empty. The headers mapping
+    is replaced with a new dict instead of being mutated in place.
+    """
+    user_id_val = user_id.get()
+    if user_id_val:
+        input.headers = {
+            **input.headers,
+            HEADER_KEY: payload_converter.to_payload(user_id_val),
+        }
+
+
+@contextmanager
+def context_from_header(
+    input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter
+):
+    """Set ``user_id`` from the ``HEADER_KEY`` header for the ``with`` body.
+
+    When the header is absent, ``user_id`` is left untouched. Otherwise the
+    previous value is restored on exit, including when the body raises.
+    """
+    payload = input.headers.get(HEADER_KEY)
+    token = (
+        user_id.set(payload_converter.from_payload(payload, str)) if payload else None
+    )
+    try:
+        yield
+    finally:
+        if token is not None:
+            user_id.reset(token)
+
+
+class ContextPropagationInterceptor(
+    temporalio.client.Interceptor, temporalio.worker.Interceptor
+):
+    """Interceptor that can serialize/deserialize contexts."""
+
+    def __init__(
+        self,
+        payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter,
+    ) -> None:
+        self._payload_converter = payload_converter
+
+    def intercept_client(
+        self, next: temporalio.client.OutboundInterceptor
+    ) -> temporalio.client.OutboundInterceptor:
+        return _ContextPropagationClientOutboundInterceptor(
+            next, self._payload_converter
+        )
+
+    def intercept_activity(
+        self, next: temporalio.worker.ActivityInboundInterceptor
+    ) -> temporalio.worker.ActivityInboundInterceptor:
+        return _ContextPropagationActivityInboundInterceptor(next)
+
+    def workflow_interceptor_class(
+        self, input: temporalio.worker.WorkflowInterceptorClassInput
+    ) -> Type[_ContextPropagationWorkflowInboundInterceptor]:
+        return _ContextPropagationWorkflowInboundInterceptor
+
+
+class _ContextPropagationClientOutboundInterceptor(
+    temporalio.client.OutboundInterceptor
+):
+    """Adds the user ID header to outgoing client calls."""
+
+    def __init__(
+        self,
+        next: temporalio.client.OutboundInterceptor,
+        payload_converter: temporalio.converter.PayloadConverter,
+    ) -> None:
+        super().__init__(next)
+        self._payload_converter = payload_converter
+
+    async def start_workflow(
+        self, input: temporalio.client.StartWorkflowInput
+    ) -> temporalio.client.WorkflowHandle[Any, Any]:
+        set_header_from_context(input, self._payload_converter)
+        return await super().start_workflow(input)
+
+    async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> Any:
+        set_header_from_context(input, self._payload_converter)
+        return await super().query_workflow(input)
+
+    async def signal_workflow(
+        self, input: temporalio.client.SignalWorkflowInput
+    ) -> None:
+        set_header_from_context(input, self._payload_converter)
+        await super().signal_workflow(input)
+
+    async def start_workflow_update(
+        self, input: temporalio.client.StartWorkflowUpdateInput
+    ) -> temporalio.client.WorkflowUpdateHandle[Any]:
+        set_header_from_context(input, self._payload_converter)
+        return await super().start_workflow_update(input)
+
+
+class _ContextPropagationActivityInboundInterceptor(
+    temporalio.worker.ActivityInboundInterceptor
+):
+    """Restores the user ID from the headers while an activity executes."""
+
+    async def execute_activity(
+        self, input: temporalio.worker.ExecuteActivityInput
+    ) -> Any:
+        with context_from_header(input, temporalio.activity.payload_converter()):
+            return await self.next.execute_activity(input)
+
+
+class _ContextPropagationWorkflowInboundInterceptor(
+    temporalio.worker.WorkflowInboundInterceptor
+):
+    """Restores the user ID from the headers around inbound workflow calls."""
+
+    def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None:
+        self.next.init(_ContextPropagationWorkflowOutboundInterceptor(outbound))
+
+    async def execute_workflow(
+        self, input: temporalio.worker.ExecuteWorkflowInput
+    ) -> Any:
+        with context_from_header(input, temporalio.workflow.payload_converter()):
+            return await self.next.execute_workflow(input)
+
+    async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None:
+        with context_from_header(input, temporalio.workflow.payload_converter()):
+            return await self.next.handle_signal(input)
+
+    async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
+        with context_from_header(input, temporalio.workflow.payload_converter()):
+            return await self.next.handle_query(input)
+
+    def handle_update_validator(
+        self, input: temporalio.worker.HandleUpdateInput
+    ) -> None:
+        with context_from_header(input, temporalio.workflow.payload_converter()):
+            self.next.handle_update_validator(input)
+
+    async def handle_update_handler(
+        self, input: temporalio.worker.HandleUpdateInput
+    ) -> Any:
+        with context_from_header(input, temporalio.workflow.payload_converter()):
+            return await self.next.handle_update_handler(input)
+
+
+class _ContextPropagationWorkflowOutboundInterceptor(
+    temporalio.worker.WorkflowOutboundInterceptor
+):
+    """Adds the user ID header to calls a workflow makes outward."""
+
+    async def signal_child_workflow(
+        self, input: temporalio.worker.SignalChildWorkflowInput
+    ) -> None:
+        set_header_from_context(input, temporalio.workflow.payload_converter())
+        return await self.next.signal_child_workflow(input)
+
+    async def signal_external_workflow(
+        self, input: temporalio.worker.SignalExternalWorkflowInput
+    ) -> None:
+        set_header_from_context(input, temporalio.workflow.payload_converter())
+        return await self.next.signal_external_workflow(input)
+
+    def start_activity(
+        self, input: temporalio.worker.StartActivityInput
+    ) -> temporalio.workflow.ActivityHandle:
+        set_header_from_context(input, temporalio.workflow.payload_converter())
+        return self.next.start_activity(input)
+
+    async def start_child_workflow(
+        self, input: temporalio.worker.StartChildWorkflowInput
+    ) -> temporalio.workflow.ChildWorkflowHandle:
+        set_header_from_context(input, temporalio.workflow.payload_converter())
+        return await self.next.start_child_workflow(input)
+
+    def start_local_activity(
+        self, input: temporalio.worker.StartLocalActivityInput
+    ) -> temporalio.workflow.ActivityHandle:
+        set_header_from_context(input, temporalio.workflow.payload_converter())
+        return self.next.start_local_activity(input)
diff --git a/context_propagation/shared.py b/context_propagation/shared.py
new file mode 100644
index 00000000..faae59d8
--- /dev/null
+++ b/context_propagation/shared.py
@@ -0,0 +1,8 @@
+from contextvars import ContextVar
+from typing import Optional
+
+# Header name under which the user ID is carried
+HEADER_KEY = "__my_user_id"
+
+# User ID for the current context; None when no user has been set
+user_id: ContextVar[Optional[str]] = ContextVar("user_id", default=None)
diff --git a/context_propagation/starter.py b/context_propagation/starter.py
new file mode 100644
index 00000000..2865eee2
--- /dev/null
+++ b/context_propagation/starter.py
@@ -0,0 +1,35 @@
+import asyncio
+import logging
+
+from temporalio.client import Client
+
+from context_propagation import interceptor, shared, workflows
+
+
+async def main():
+    logging.basicConfig(level=logging.INFO)
+
+    # Set the user ID
+    shared.user_id.set("some-user")
+
+    # Connect client
+    client = await Client.connect(
+        "localhost:7233",
+        # Use our interceptor
+        interceptors=[interceptor.ContextPropagationInterceptor()],
+    )
+
+    # Start workflow, send signal, wait for completion
+    handle = await client.start_workflow(
+        workflows.SayHelloWorkflow.run,
+        "Temporal",
+        id="context-propagation-workflow-id",
+        task_queue="context-propagation-task-queue",
+    )
+    await handle.signal(workflows.SayHelloWorkflow.signal_complete)
+    result = await handle.result()
+    logging.info(f"Workflow result: {result}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/context_propagation/worker.py b/context_propagation/worker.py
new file mode 100644
index 00000000..14d954da
--- /dev/null
+++ b/context_propagation/worker.py
@@ -0,0 +1,41 @@
+import asyncio
+import logging
+
+from temporalio.client import Client
+from temporalio.worker import Worker
+
+from context_propagation import activities, interceptor, workflows
+
+interrupt_event = asyncio.Event()
+
+
+async def main():
+    logging.basicConfig(level=logging.INFO)
+
+    # Connect client
+    client = await Client.connect(
+        "localhost:7233",
+        # Use our interceptor
+        interceptors=[interceptor.ContextPropagationInterceptor()],
+    )
+
+    # Run a worker for the workflow
+    async with Worker(
+        client,
+        task_queue="context-propagation-task-queue",
+        activities=[activities.say_hello_activity],
+        workflows=[workflows.SayHelloWorkflow],
+    ):
+        # Wait until interrupted
+        logging.info("Worker started, ctrl+c to exit")
+        await interrupt_event.wait()
+        logging.info("Shutting down")
+
+
+if __name__ == "__main__":
+    loop = asyncio.new_event_loop()
+    try:
+        loop.run_until_complete(main())
+    except KeyboardInterrupt:
+        interrupt_event.set()
+        loop.run_until_complete(loop.shutdown_asyncgens())
diff --git a/context_propagation/workflows.py b/context_propagation/workflows.py
new file mode 100644
index 00000000..e9c120b3
--- /dev/null
+++ b/context_propagation/workflows.py
@@ -0,0 +1,28 @@
+from datetime import timedelta
+
+from temporalio import workflow
+
+with workflow.unsafe.imports_passed_through():
+    from context_propagation.activities import say_hello_activity
+    from context_propagation.shared import user_id
+
+
+@workflow.defn
+class SayHelloWorkflow:
+    def __init__(self) -> None:
+        self._complete = False
+
+    @workflow.run
+    async def run(self, name: str) -> str:
+        workflow.logger.info(f"Workflow called by user {user_id.get()}")
+
+        # Wait for signal then run activity
+        await workflow.wait_condition(lambda: self._complete)
+        return await workflow.execute_activity(
+            say_hello_activity, name, start_to_close_timeout=timedelta(minutes=5)
+        )
+
+    @workflow.signal
+    async def signal_complete(self) -> None:
+        workflow.logger.info(f"Signal called by user {user_id.get()}")
+        self._complete = True
diff --git a/tests/context_propagation/__init__.py b/tests/context_propagation/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/context_propagation/workflow_test.py b/tests/context_propagation/workflow_test.py
new file mode 100644
index 00000000..8de2120d
--- /dev/null
+++ b/tests/context_propagation/workflow_test.py
@@ -0,0 +1,46 @@
+import uuid
+
+from temporalio import activity
+from temporalio.client import Client
+from temporalio.exceptions import ApplicationError
+from temporalio.worker import Worker
+
+from context_propagation.interceptor import ContextPropagationInterceptor
+from context_propagation.shared import user_id
+from context_propagation.workflows import SayHelloWorkflow
+
+
+async def test_workflow_with_context_propagator(client: Client):
+    # Mock out the activity to assert the context value
+    @activity.defn(name="say_hello_activity")
+    async def say_hello_activity_mock(name: str) -> str:
+        try:
+            assert user_id.get() == "test-user"
+        except Exception as err:
+            raise ApplicationError("Assertion fail", non_retryable=True) from err
+        return f"Mock for {name}"
+
+    # Replace interceptors in client
+    new_config = client.config()
+    new_config["interceptors"] = [ContextPropagationInterceptor()]
+    client = Client(**new_config)
+    task_queue = f"tq-{uuid.uuid4()}"
+
+    async with Worker(
+        client,
+        task_queue=task_queue,
+        activities=[say_hello_activity_mock],
+        workflows=[SayHelloWorkflow],
+    ):
+        # Set the user during start/signal, but unset after
+        token = user_id.set("test-user")
+        handle = await client.start_workflow(
+            SayHelloWorkflow.run,
+            "some-name",
+            id=f"wf-{uuid.uuid4()}",
+            task_queue=task_queue,
+        )
+        await handle.signal(SayHelloWorkflow.signal_complete)
+        user_id.reset(token)
+        result = await handle.result()
+    assert result == "Mock for some-name"