Skip to content

Commit

Permalink
Context propagation sample (#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
cretz authored Jun 4, 2024
1 parent e299047 commit a7c04d4
Show file tree
Hide file tree
Showing 11 changed files with 366 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Some examples require extra dependencies. See each sample's directory for specif
<!-- Keep this list in alphabetical order -->
* [activity_worker](activity_worker) - Use Python activities from a workflow in another language.
* [cloud_export_to_parquet](cloud_export_to_parquet) - Set up schedule workflow to process exported files on an hourly basis
* [context_propagation](context_propagation) - Context propagation through workflows/activities via interceptor.
* [custom_converter](custom_converter) - Use a custom payload converter to handle custom types.
* [custom_decorator](custom_decorator) - Custom decorator to auto-heartbeat a long-running activity.
* [dsl](dsl) - DSL workflow that executes steps defined in a YAML file.
Expand Down
16 changes: 16 additions & 0 deletions context_propagation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Context Propagation Interceptor Sample

This sample shows how to use an interceptor to propagate contextual information through workflows and activities. For
this example, [contextvars](https://docs.python.org/3/library/contextvars.html) holds the contextual information.

To run, first see [README.md](../README.md) for prerequisites. Then, run the following from this directory to start the
worker:

poetry run python worker.py

This will start the worker. Then, in another terminal, run the following to execute the workflow:

poetry run python starter.py

The starter terminal should complete with the hello result and the worker terminal should show the logs with the
propagated user ID contextual information flowing through the workflows/activities.
Empty file added context_propagation/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions context_propagation/activities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from temporalio import activity

from context_propagation import shared


@activity.defn
async def say_hello_activity(name: str) -> str:
activity.logger.info(f"Activity called by user {shared.user_id.get()}")
return f"Hello, {name}"
184 changes: 184 additions & 0 deletions context_propagation/interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Mapping, Protocol, Type

import temporalio.activity
import temporalio.api.common.v1
import temporalio.client
import temporalio.converter
import temporalio.worker
import temporalio.workflow

with temporalio.workflow.unsafe.imports_passed_through():
from context_propagation.shared import HEADER_KEY, user_id


class _InputWithHeaders(Protocol):
headers: Mapping[str, temporalio.api.common.v1.Payload]


def set_header_from_context(
input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter
) -> None:
user_id_val = user_id.get()
if user_id_val:
input.headers = {
**input.headers,
HEADER_KEY: payload_converter.to_payload(user_id_val),
}


@contextmanager
def context_from_header(
input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter
):
payload = input.headers.get(HEADER_KEY)
token = (
user_id.set(payload_converter.from_payload(payload, str)) if payload else None
)
try:
yield
finally:
if token:
user_id.reset(token)


class ContextPropagationInterceptor(
temporalio.client.Interceptor, temporalio.worker.Interceptor
):
"""Interceptor that can serialize/deserialize contexts."""

def __init__(
self,
payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter,
) -> None:
self._payload_converter = payload_converter

def intercept_client(
self, next: temporalio.client.OutboundInterceptor
) -> temporalio.client.OutboundInterceptor:
return _ContextPropagationClientOutboundInterceptor(
next, self._payload_converter
)

def intercept_activity(
self, next: temporalio.worker.ActivityInboundInterceptor
) -> temporalio.worker.ActivityInboundInterceptor:
return _ContextPropagationActivityInboundInterceptor(next)

def workflow_interceptor_class(
self, input: temporalio.worker.WorkflowInterceptorClassInput
) -> Type[_ContextPropagationWorkflowInboundInterceptor]:
return _ContextPropagationWorkflowInboundInterceptor


class _ContextPropagationClientOutboundInterceptor(
temporalio.client.OutboundInterceptor
):
def __init__(
self,
next: temporalio.client.OutboundInterceptor,
payload_converter: temporalio.converter.PayloadConverter,
) -> None:
super().__init__(next)
self._payload_converter = payload_converter

async def start_workflow(
self, input: temporalio.client.StartWorkflowInput
) -> temporalio.client.WorkflowHandle[Any, Any]:
set_header_from_context(input, self._payload_converter)
return await super().start_workflow(input)

async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> Any:
set_header_from_context(input, self._payload_converter)
return await super().query_workflow(input)

async def signal_workflow(
self, input: temporalio.client.SignalWorkflowInput
) -> None:
set_header_from_context(input, self._payload_converter)
await super().signal_workflow(input)

async def start_workflow_update(
self, input: temporalio.client.StartWorkflowUpdateInput
) -> temporalio.client.WorkflowUpdateHandle[Any]:
set_header_from_context(input, self._payload_converter)
return await self.next.start_workflow_update(input)


class _ContextPropagationActivityInboundInterceptor(
temporalio.worker.ActivityInboundInterceptor
):
async def execute_activity(
self, input: temporalio.worker.ExecuteActivityInput
) -> Any:
with context_from_header(input, temporalio.activity.payload_converter()):
return await self.next.execute_activity(input)


class _ContextPropagationWorkflowInboundInterceptor(
temporalio.worker.WorkflowInboundInterceptor
):
def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None:
self.next.init(_ContextPropagationWorkflowOutboundInterceptor(outbound))

async def execute_workflow(
self, input: temporalio.worker.ExecuteWorkflowInput
) -> Any:
with context_from_header(input, temporalio.workflow.payload_converter()):
return await self.next.execute_workflow(input)

async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None:
with context_from_header(input, temporalio.workflow.payload_converter()):
return await self.next.handle_signal(input)

async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
with context_from_header(input, temporalio.workflow.payload_converter()):
return await self.next.handle_query(input)

def handle_update_validator(
self, input: temporalio.worker.HandleUpdateInput
) -> None:
with context_from_header(input, temporalio.workflow.payload_converter()):
self.next.handle_update_validator(input)

async def handle_update_handler(
self, input: temporalio.worker.HandleUpdateInput
) -> Any:
with context_from_header(input, temporalio.workflow.payload_converter()):
return await self.next.handle_update_handler(input)


class _ContextPropagationWorkflowOutboundInterceptor(
temporalio.worker.WorkflowOutboundInterceptor
):
async def signal_child_workflow(
self, input: temporalio.worker.SignalChildWorkflowInput
) -> None:
set_header_from_context(input, temporalio.workflow.payload_converter())
return await self.next.signal_child_workflow(input)

async def signal_external_workflow(
self, input: temporalio.worker.SignalExternalWorkflowInput
) -> None:
set_header_from_context(input, temporalio.workflow.payload_converter())
return await self.next.signal_external_workflow(input)

def start_activity(
self, input: temporalio.worker.StartActivityInput
) -> temporalio.workflow.ActivityHandle:
set_header_from_context(input, temporalio.workflow.payload_converter())
return self.next.start_activity(input)

async def start_child_workflow(
self, input: temporalio.worker.StartChildWorkflowInput
) -> temporalio.workflow.ChildWorkflowHandle:
set_header_from_context(input, temporalio.workflow.payload_converter())
return await self.next.start_child_workflow(input)

def start_local_activity(
self, input: temporalio.worker.StartLocalActivityInput
) -> temporalio.workflow.ActivityHandle:
set_header_from_context(input, temporalio.workflow.payload_converter())
return self.next.start_local_activity(input)
6 changes: 6 additions & 0 deletions context_propagation/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from contextvars import ContextVar
from typing import Optional

HEADER_KEY = "__my_user_id"

user_id: ContextVar[Optional[str]] = ContextVar("user_id", default=None)
35 changes: 35 additions & 0 deletions context_propagation/starter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import asyncio
import logging

from temporalio.client import Client

from context_propagation import interceptor, shared, workflows


async def main():
logging.basicConfig(level=logging.INFO)

# Set the user ID
shared.user_id.set("some-user")

# Connect client
client = await Client.connect(
"localhost:7233",
# Use our interceptor
interceptors=[interceptor.ContextPropagationInterceptor()],
)

# Start workflow, send signal, wait for completion, issue query
handle = await client.start_workflow(
workflows.SayHelloWorkflow.run,
"Temporal",
id=f"context-propagation-workflow-id",
task_queue="context-propagation-task-queue",
)
await handle.signal(workflows.SayHelloWorkflow.signal_complete)
result = await handle.result()
logging.info(f"Workflow result: {result}")


if __name__ == "__main__":
asyncio.run(main())
41 changes: 41 additions & 0 deletions context_propagation/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import asyncio
import logging

from temporalio.client import Client
from temporalio.worker import Worker

from context_propagation import activities, interceptor, workflows

interrupt_event = asyncio.Event()


async def main():
logging.basicConfig(level=logging.INFO)

# Connect client
client = await Client.connect(
"localhost:7233",
# Use our interceptor
interceptors=[interceptor.ContextPropagationInterceptor()],
)

# Run a worker for the workflow
async with Worker(
client,
task_queue="context-propagation-task-queue",
activities=[activities.say_hello_activity],
workflows=[workflows.SayHelloWorkflow],
):
# Wait until interrupted
logging.info("Worker started, ctrl+c to exit")
await interrupt_event.wait()
logging.info("Shutting down")


if __name__ == "__main__":
loop = asyncio.new_event_loop()
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
interrupt_event.set()
loop.run_until_complete(loop.shutdown_asyncgens())
28 changes: 28 additions & 0 deletions context_propagation/workflows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from datetime import timedelta

from temporalio import workflow

with workflow.unsafe.imports_passed_through():
from context_propagation.activities import say_hello_activity
from context_propagation.shared import user_id


@workflow.defn
class SayHelloWorkflow:
def __init__(self) -> None:
self._complete = False

@workflow.run
async def run(self, name: str) -> str:
workflow.logger.info(f"Workflow called by user {user_id.get()}")

# Wait for signal then run activity
await workflow.wait_condition(lambda: self._complete)
return await workflow.execute_activity(
say_hello_activity, name, start_to_close_timeout=timedelta(minutes=5)
)

@workflow.signal
async def signal_complete(self) -> None:
workflow.logger.info(f"Signal called by user {user_id.get()}")
self._complete = True
Empty file.
46 changes: 46 additions & 0 deletions tests/context_propagation/workflow_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import uuid

from temporalio import activity
from temporalio.client import Client
from temporalio.exceptions import ApplicationError
from temporalio.worker import Worker

from context_propagation.interceptor import ContextPropagationInterceptor
from context_propagation.shared import user_id
from context_propagation.workflows import SayHelloWorkflow


async def test_workflow_with_context_propagator(client: Client):
# Mock out the activity to assert the context value
@activity.defn(name="say_hello_activity")
async def say_hello_activity_mock(name: str) -> str:
try:
assert user_id.get() == "test-user"
except Exception as err:
raise ApplicationError("Assertion fail", non_retryable=True) from err
return f"Mock for {name}"

# Replace interceptors in client
new_config = client.config()
new_config["interceptors"] = [ContextPropagationInterceptor()]
client = Client(**new_config)
task_queue = f"tq-{uuid.uuid4()}"

async with Worker(
client,
task_queue=task_queue,
activities=[say_hello_activity_mock],
workflows=[SayHelloWorkflow],
):
# Set the user during start/signal, but unset after
token = user_id.set("test-user")
handle = await client.start_workflow(
SayHelloWorkflow.run,
"some-name",
id=f"wf-{uuid.uuid4()}",
task_queue=task_queue,
)
await handle.signal(SayHelloWorkflow.signal_complete)
user_id.reset(token)
result = await handle.result()
assert result == "Mock for some-name"

0 comments on commit a7c04d4

Please sign in to comment.