-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
366 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Context Propagation Interceptor Sample | ||
|
||
This sample shows how to use an interceptor to propagate contextual information through workflows and activities. For | ||
this example, [contextvars](https://docs.python.org/3/library/contextvars.html) holds the contextual information. | ||
|
||
To run, first see [README.md](../README.md) for prerequisites. Then, run the following from this directory to start the | ||
worker: | ||
|
||
poetry run python worker.py | ||
|
||
This will start the worker. Then, in another terminal, run the following to execute the workflow: | ||
|
||
poetry run python starter.py | ||
|
||
The starter terminal should complete with the hello result and the worker terminal should show the logs with the | ||
propagated user ID contextual information flowing through the workflows/activities. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from temporalio import activity | ||
|
||
from context_propagation import shared | ||
|
||
|
||
@activity.defn | ||
async def say_hello_activity(name: str) -> str: | ||
activity.logger.info(f"Activity called by user {shared.user_id.get()}") | ||
return f"Hello, {name}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
from __future__ import annotations | ||
|
||
from contextlib import contextmanager | ||
from typing import Any, Mapping, Protocol, Type | ||
|
||
import temporalio.activity | ||
import temporalio.api.common.v1 | ||
import temporalio.client | ||
import temporalio.converter | ||
import temporalio.worker | ||
import temporalio.workflow | ||
|
||
with temporalio.workflow.unsafe.imports_passed_through(): | ||
from context_propagation.shared import HEADER_KEY, user_id | ||
|
||
|
||
class _InputWithHeaders(Protocol): | ||
headers: Mapping[str, temporalio.api.common.v1.Payload] | ||
|
||
|
||
def set_header_from_context( | ||
input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter | ||
) -> None: | ||
user_id_val = user_id.get() | ||
if user_id_val: | ||
input.headers = { | ||
**input.headers, | ||
HEADER_KEY: payload_converter.to_payload(user_id_val), | ||
} | ||
|
||
|
||
@contextmanager | ||
def context_from_header( | ||
input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter | ||
): | ||
payload = input.headers.get(HEADER_KEY) | ||
token = ( | ||
user_id.set(payload_converter.from_payload(payload, str)) if payload else None | ||
) | ||
try: | ||
yield | ||
finally: | ||
if token: | ||
user_id.reset(token) | ||
|
||
|
||
class ContextPropagationInterceptor( | ||
temporalio.client.Interceptor, temporalio.worker.Interceptor | ||
): | ||
"""Interceptor that can serialize/deserialize contexts.""" | ||
|
||
def __init__( | ||
self, | ||
payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter, | ||
) -> None: | ||
self._payload_converter = payload_converter | ||
|
||
def intercept_client( | ||
self, next: temporalio.client.OutboundInterceptor | ||
) -> temporalio.client.OutboundInterceptor: | ||
return _ContextPropagationClientOutboundInterceptor( | ||
next, self._payload_converter | ||
) | ||
|
||
def intercept_activity( | ||
self, next: temporalio.worker.ActivityInboundInterceptor | ||
) -> temporalio.worker.ActivityInboundInterceptor: | ||
return _ContextPropagationActivityInboundInterceptor(next) | ||
|
||
def workflow_interceptor_class( | ||
self, input: temporalio.worker.WorkflowInterceptorClassInput | ||
) -> Type[_ContextPropagationWorkflowInboundInterceptor]: | ||
return _ContextPropagationWorkflowInboundInterceptor | ||
|
||
|
||
class _ContextPropagationClientOutboundInterceptor( | ||
temporalio.client.OutboundInterceptor | ||
): | ||
def __init__( | ||
self, | ||
next: temporalio.client.OutboundInterceptor, | ||
payload_converter: temporalio.converter.PayloadConverter, | ||
) -> None: | ||
super().__init__(next) | ||
self._payload_converter = payload_converter | ||
|
||
async def start_workflow( | ||
self, input: temporalio.client.StartWorkflowInput | ||
) -> temporalio.client.WorkflowHandle[Any, Any]: | ||
set_header_from_context(input, self._payload_converter) | ||
return await super().start_workflow(input) | ||
|
||
async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> Any: | ||
set_header_from_context(input, self._payload_converter) | ||
return await super().query_workflow(input) | ||
|
||
async def signal_workflow( | ||
self, input: temporalio.client.SignalWorkflowInput | ||
) -> None: | ||
set_header_from_context(input, self._payload_converter) | ||
await super().signal_workflow(input) | ||
|
||
async def start_workflow_update( | ||
self, input: temporalio.client.StartWorkflowUpdateInput | ||
) -> temporalio.client.WorkflowUpdateHandle[Any]: | ||
set_header_from_context(input, self._payload_converter) | ||
return await self.next.start_workflow_update(input) | ||
|
||
|
||
class _ContextPropagationActivityInboundInterceptor( | ||
temporalio.worker.ActivityInboundInterceptor | ||
): | ||
async def execute_activity( | ||
self, input: temporalio.worker.ExecuteActivityInput | ||
) -> Any: | ||
with context_from_header(input, temporalio.activity.payload_converter()): | ||
return await self.next.execute_activity(input) | ||
|
||
|
||
class _ContextPropagationWorkflowInboundInterceptor( | ||
temporalio.worker.WorkflowInboundInterceptor | ||
): | ||
def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None: | ||
self.next.init(_ContextPropagationWorkflowOutboundInterceptor(outbound)) | ||
|
||
async def execute_workflow( | ||
self, input: temporalio.worker.ExecuteWorkflowInput | ||
) -> Any: | ||
with context_from_header(input, temporalio.workflow.payload_converter()): | ||
return await self.next.execute_workflow(input) | ||
|
||
async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None: | ||
with context_from_header(input, temporalio.workflow.payload_converter()): | ||
return await self.next.handle_signal(input) | ||
|
||
async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: | ||
with context_from_header(input, temporalio.workflow.payload_converter()): | ||
return await self.next.handle_query(input) | ||
|
||
def handle_update_validator( | ||
self, input: temporalio.worker.HandleUpdateInput | ||
) -> None: | ||
with context_from_header(input, temporalio.workflow.payload_converter()): | ||
self.next.handle_update_validator(input) | ||
|
||
async def handle_update_handler( | ||
self, input: temporalio.worker.HandleUpdateInput | ||
) -> Any: | ||
with context_from_header(input, temporalio.workflow.payload_converter()): | ||
return await self.next.handle_update_handler(input) | ||
|
||
|
||
class _ContextPropagationWorkflowOutboundInterceptor( | ||
temporalio.worker.WorkflowOutboundInterceptor | ||
): | ||
async def signal_child_workflow( | ||
self, input: temporalio.worker.SignalChildWorkflowInput | ||
) -> None: | ||
set_header_from_context(input, temporalio.workflow.payload_converter()) | ||
return await self.next.signal_child_workflow(input) | ||
|
||
async def signal_external_workflow( | ||
self, input: temporalio.worker.SignalExternalWorkflowInput | ||
) -> None: | ||
set_header_from_context(input, temporalio.workflow.payload_converter()) | ||
return await self.next.signal_external_workflow(input) | ||
|
||
def start_activity( | ||
self, input: temporalio.worker.StartActivityInput | ||
) -> temporalio.workflow.ActivityHandle: | ||
set_header_from_context(input, temporalio.workflow.payload_converter()) | ||
return self.next.start_activity(input) | ||
|
||
async def start_child_workflow( | ||
self, input: temporalio.worker.StartChildWorkflowInput | ||
) -> temporalio.workflow.ChildWorkflowHandle: | ||
set_header_from_context(input, temporalio.workflow.payload_converter()) | ||
return await self.next.start_child_workflow(input) | ||
|
||
def start_local_activity( | ||
self, input: temporalio.worker.StartLocalActivityInput | ||
) -> temporalio.workflow.ActivityHandle: | ||
set_header_from_context(input, temporalio.workflow.payload_converter()) | ||
return self.next.start_local_activity(input) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from contextvars import ContextVar | ||
from typing import Optional | ||
|
||
HEADER_KEY = "__my_user_id" | ||
|
||
user_id: ContextVar[Optional[str]] = ContextVar("user_id", default=None) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import asyncio | ||
import logging | ||
|
||
from temporalio.client import Client | ||
|
||
from context_propagation import interceptor, shared, workflows | ||
|
||
|
||
async def main(): | ||
logging.basicConfig(level=logging.INFO) | ||
|
||
# Set the user ID | ||
shared.user_id.set("some-user") | ||
|
||
# Connect client | ||
client = await Client.connect( | ||
"localhost:7233", | ||
# Use our interceptor | ||
interceptors=[interceptor.ContextPropagationInterceptor()], | ||
) | ||
|
||
# Start workflow, send signal, wait for completion, issue query | ||
handle = await client.start_workflow( | ||
workflows.SayHelloWorkflow.run, | ||
"Temporal", | ||
id=f"context-propagation-workflow-id", | ||
task_queue="context-propagation-task-queue", | ||
) | ||
await handle.signal(workflows.SayHelloWorkflow.signal_complete) | ||
result = await handle.result() | ||
logging.info(f"Workflow result: {result}") | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import asyncio | ||
import logging | ||
|
||
from temporalio.client import Client | ||
from temporalio.worker import Worker | ||
|
||
from context_propagation import activities, interceptor, workflows | ||
|
||
interrupt_event = asyncio.Event() | ||
|
||
|
||
async def main(): | ||
logging.basicConfig(level=logging.INFO) | ||
|
||
# Connect client | ||
client = await Client.connect( | ||
"localhost:7233", | ||
# Use our interceptor | ||
interceptors=[interceptor.ContextPropagationInterceptor()], | ||
) | ||
|
||
# Run a worker for the workflow | ||
async with Worker( | ||
client, | ||
task_queue="context-propagation-task-queue", | ||
activities=[activities.say_hello_activity], | ||
workflows=[workflows.SayHelloWorkflow], | ||
): | ||
# Wait until interrupted | ||
logging.info("Worker started, ctrl+c to exit") | ||
await interrupt_event.wait() | ||
logging.info("Shutting down") | ||
|
||
|
||
if __name__ == "__main__": | ||
loop = asyncio.new_event_loop() | ||
try: | ||
loop.run_until_complete(main()) | ||
except KeyboardInterrupt: | ||
interrupt_event.set() | ||
loop.run_until_complete(loop.shutdown_asyncgens()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from datetime import timedelta | ||
|
||
from temporalio import workflow | ||
|
||
with workflow.unsafe.imports_passed_through(): | ||
from context_propagation.activities import say_hello_activity | ||
from context_propagation.shared import user_id | ||
|
||
|
||
@workflow.defn | ||
class SayHelloWorkflow: | ||
def __init__(self) -> None: | ||
self._complete = False | ||
|
||
@workflow.run | ||
async def run(self, name: str) -> str: | ||
workflow.logger.info(f"Workflow called by user {user_id.get()}") | ||
|
||
# Wait for signal then run activity | ||
await workflow.wait_condition(lambda: self._complete) | ||
return await workflow.execute_activity( | ||
say_hello_activity, name, start_to_close_timeout=timedelta(minutes=5) | ||
) | ||
|
||
@workflow.signal | ||
async def signal_complete(self) -> None: | ||
workflow.logger.info(f"Signal called by user {user_id.get()}") | ||
self._complete = True |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import uuid | ||
|
||
from temporalio import activity | ||
from temporalio.client import Client | ||
from temporalio.exceptions import ApplicationError | ||
from temporalio.worker import Worker | ||
|
||
from context_propagation.interceptor import ContextPropagationInterceptor | ||
from context_propagation.shared import user_id | ||
from context_propagation.workflows import SayHelloWorkflow | ||
|
||
|
||
async def test_workflow_with_context_propagator(client: Client): | ||
# Mock out the activity to assert the context value | ||
@activity.defn(name="say_hello_activity") | ||
async def say_hello_activity_mock(name: str) -> str: | ||
try: | ||
assert user_id.get() == "test-user" | ||
except Exception as err: | ||
raise ApplicationError("Assertion fail", non_retryable=True) from err | ||
return f"Mock for {name}" | ||
|
||
# Replace interceptors in client | ||
new_config = client.config() | ||
new_config["interceptors"] = [ContextPropagationInterceptor()] | ||
client = Client(**new_config) | ||
task_queue = f"tq-{uuid.uuid4()}" | ||
|
||
async with Worker( | ||
client, | ||
task_queue=task_queue, | ||
activities=[say_hello_activity_mock], | ||
workflows=[SayHelloWorkflow], | ||
): | ||
# Set the user during start/signal, but unset after | ||
token = user_id.set("test-user") | ||
handle = await client.start_workflow( | ||
SayHelloWorkflow.run, | ||
"some-name", | ||
id=f"wf-{uuid.uuid4()}", | ||
task_queue=task_queue, | ||
) | ||
await handle.signal(SayHelloWorkflow.signal_complete) | ||
user_id.reset(token) | ||
result = await handle.result() | ||
assert result == "Mock for some-name" |