Skip to content

Commit

Permalink
Resurrect potentially useful additional sample with test
Browse files Browse the repository at this point in the history
  • Loading branch information
dandavison committed May 31, 2024
1 parent 569eed1 commit 139100d
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 0 deletions.
95 changes: 95 additions & 0 deletions tests/update/serialized_handling_of_n_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import asyncio
import logging
import uuid
from dataclasses import dataclass
from unittest.mock import patch

import temporalio.api.common.v1
import temporalio.api.enums.v1
import temporalio.api.update.v1
import temporalio.api.workflowservice.v1
from temporalio.client import Client, WorkflowHandle
from temporalio.worker import Worker
from temporalio.workflow import UpdateMethodMultiParam

from update.serialized_handling_of_n_messages import (
MessageProcessor,
Result,
get_current_time,
)


async def test_continue_as_new_doesnt_lose_updates(client: Client):
    """Updates admitted before any worker exists must all be processed,
    even when continue-as-new is suggested on every workflow-task check."""
    # Force the workflow to believe continue-as-new is always suggested.
    with patch(
        "temporalio.workflow.Info.is_continue_as_new_suggested", return_value=True
    ):
        task_queue = str(uuid.uuid4())
        handle = await client.start_workflow(
            MessageProcessor.run, id=str(uuid.uuid4()), task_queue=task_queue
        )
        # Send every update before starting a worker, then wait for the
        # server to admit each one.
        requests = [
            UpdateRequest(handle, MessageProcessor.process_message, n)
            for n in range(10)
        ]
        for request in requests:
            await request.wait_until_admitted()

        # Only now bring up a worker; every admitted update must still
        # complete with the expected result.
        async with Worker(
            client,
            task_queue=task_queue,
            workflows=[MessageProcessor],
            activities=[get_current_time],
        ):
            for request in requests:
                result = await request.task
                assert result.startswith(request.expected_result_prefix())


@dataclass
class UpdateRequest:
    """One update sent to the workflow, tracked from submission to result.

    The update is submitted as soon as the instance is constructed; await
    ``self.task`` for its result. ``sequence_number`` doubles as both the
    update id and the update argument.
    """

    wf_handle: WorkflowHandle
    update: UpdateMethodMultiParam
    sequence_number: int

    def __post_init__(self):
        # Schedule the update RPC immediately. asyncio.create_task is the
        # supported way to spawn a Task; instantiating asyncio.Task directly
        # is discouraged by the asyncio documentation.
        self.task: asyncio.Task[Result] = asyncio.create_task(
            self.wf_handle.execute_update(self.update, args=[self.arg], id=self.id)
        )

    async def wait_until_admitted(self):
        """Poll the server until this update reaches the ADMITTED stage."""
        while True:
            try:
                return await self._poll_update_non_blocking()
            except Exception as err:
                # Expected while the server has not yet registered the
                # update; log and retry.
                logging.warning(err)
                # Back off briefly rather than hot-looping RPCs at the server.
                await asyncio.sleep(0.1)

    async def _poll_update_non_blocking(self):
        """Issue one PollWorkflowExecutionUpdate RPC; raise via assert if the
        update has not been admitted yet."""
        req = temporalio.api.workflowservice.v1.PollWorkflowExecutionUpdateRequest(
            namespace=self.wf_handle._client.namespace,
            update_ref=temporalio.api.update.v1.UpdateRef(
                workflow_execution=temporalio.api.common.v1.WorkflowExecution(
                    workflow_id=self.wf_handle.id,
                    run_id="",  # empty run id targets the latest run
                ),
                update_id=self.id,
            ),
            identity=self.wf_handle._client.identity,
        )
        res = await self.wf_handle._client.workflow_service.poll_workflow_execution_update(
            req
        )
        # TODO: @cretz how do we work with these raw proto objects?
        assert "stage: UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED" in str(res)

    @property
    def arg(self) -> str:
        """Update argument: the sequence number as a string."""
        return str(self.sequence_number)

    @property
    def id(self) -> str:
        """Update id: the sequence number as a string."""
        return str(self.sequence_number)

    def expected_result_prefix(self) -> str:
        # TODO: Currently the server does not send updates to the worker in order of admission When
        # this is fixed (https://github.com/temporalio/temporal/pull/5831), we can make a stronger
        # assertion about the activity numbers used to construct each result.
        return f"{self.arg}-result"
114 changes: 114 additions & 0 deletions update/serialized_handling_of_n_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import asyncio
import logging
from asyncio import Future
from collections import deque
from datetime import timedelta

from temporalio import activity, common, workflow
from temporalio.client import Client, WorkflowHandle
from temporalio.worker import Worker

Arg = str
Result = str

# Problem:
# -------
# - Your workflow receives an unbounded number of updates.
# - Each update must be processed by calling two activities.
# - The next update may not start processing until the previous has completed.

# Solution:
# --------
# Enqueue updates, and process items from the queue in a single coroutine (the main workflow
# coroutine).

# Discussion:
# ----------
# The queue is used because Temporal's async update & signal handlers will interleave if they
# contain multiple yield points. An alternative would be to use standard async handler functions,
# with handling being done with an asyncio.Lock held. The queue approach would be necessary if we
# need to process in an order other than arrival.


@workflow.defn
class MessageProcessor:
    """Workflow that serializes processing of an unbounded stream of updates.

    The update handler only enqueues work; the main workflow coroutine drains
    the queue one item at a time, so no two updates' activity sequences can
    ever interleave.
    """

    def __init__(self):
        # FIFO of (update argument, future to resolve with that update's result).
        self.queue = deque[tuple[Arg, Future[Result]]]()

    @workflow.run
    async def run(self):
        """Drain the queue forever, continuing-as-new when suggested."""
        while True:
            await workflow.wait_condition(lambda: len(self.queue) > 0)
            while self.queue:
                arg, fut = self.queue.popleft()
                # Resolve the handler's future only after processing finishes,
                # guaranteeing strictly serialized handling.
                fut.set_result(await self.execute_processing_task(arg))
            if workflow.info().is_continue_as_new_suggested():
                # Footgun: If we don't let the event loop tick, then CAN will end the workflow
                # before the update handler is notified that the result future has completed.
                # See https://github.com/temporalio/features/issues/481
                await asyncio.sleep(0)  # Let update handler complete
                print("CAN")
                return workflow.continue_as_new()

    # Note: handler must be async if we are both enqueuing, and returning an update result
    # => We could add SDK APIs to manually complete updates.
    @workflow.update
    async def process_message(self, arg: Arg) -> Result:
        # Footgun: handler may need to wait for workflow initialization after CAN
        # See https://github.com/temporalio/features/issues/400
        # await workflow.wait_condition(lambda: hasattr(self, "queue"))
        fut = Future[Result]()
        self.queue.append((arg, fut))  # Note: update validation gates enqueue
        return await fut

    async def execute_processing_task(self, arg: Arg) -> Result:
        # The purpose of the two activities and the result string format is to permit checks that
        # the activities of different tasks do not interleave.
        t1, t2 = [
            await workflow.execute_activity(
                get_current_time, start_to_close_timeout=timedelta(seconds=10)
            )
            for _ in range(2)
        ]
        return f"{arg}-result-{t1}-{t2}"


# Monotonically increasing fake clock shared by all activity invocations.
# Renamed from ``time`` to avoid shadowing the stdlib ``time`` module.
_time = 0


@activity.defn
async def get_current_time() -> int:
    """Return a monotonically increasing integer "timestamp".

    Successive calls return strictly increasing values, which lets the caller
    check that activity executions of different updates do not interleave.
    """
    global _time
    _time += 1
    return _time


async def app(wf: WorkflowHandle):
    """Drive the workflow: send 20 updates one at a time, printing each result."""
    for n in range(20):
        print(f"app(): sending update {n}")
        response = await wf.execute_update(
            MessageProcessor.process_message, f"arg {n}"
        )
        print(f"app(): {response}")


async def main():
    """Run a worker and the update-sending driver against a local server."""
    client = await Client.connect("localhost:7233")

    async with Worker(
        client,
        task_queue="tq",
        workflows=[MessageProcessor],
        activities=[get_current_time],
    ):
        handle = await client.start_workflow(
            MessageProcessor.run,
            id="wid",
            task_queue="tq",
            id_reuse_policy=common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING,
        )
        # Run the driver concurrently with awaiting the workflow result so
        # any workflow failure surfaces alongside update results.
        await asyncio.gather(app(handle), handle.result())


if __name__ == "__main__":
    # Script entry point: enable INFO logging, then run worker + driver.
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())

0 comments on commit 139100d

Please sign in to comment.