Dynamic update handlers
Sushisource committed Oct 17, 2023
1 parent abe36c3 commit 818f6f2
Showing 3 changed files with 170 additions and 138 deletions.
190 changes: 93 additions & 97 deletions temporalio/worker/_workflow_instance.py
@@ -422,114 +422,96 @@ def _apply_cancel_workflow(
def _apply_do_update(
self, job: temporalio.bridge.proto.workflow_activation.DoUpdate
):
acceptance_command = self._add_command()
acceptance_command.update_response.protocol_instance_id = (
job.protocol_instance_id
)
try:
defn = self._updates.get(job.name) or self._updates.get(None)
if not defn:
raise RuntimeError(
f"Update handler for '{job.name}' expected but not found, and there is no dynamic handler"
# Run the validator & handler in a task. Everything, including looking up the update definition, needs to be
# inside the task, since the update may not be defined until after we have started the workflow - for example
# if an update is in the first WFT & is also registered dynamically at the top of workflow code.
async def run_update() -> None:
command = self._add_command()
command.update_response.protocol_instance_id = job.protocol_instance_id
try:
defn = self._updates.get(job.name) or self._updates.get(None)
if not defn:
raise RuntimeError(
f"Update handler for '{job.name}' expected but not found, and there is no dynamic handler"
)
args = self._process_handler_args(
job.name,
job.input,
defn.name,
defn.arg_types,
defn.dynamic_vararg,
)
handler_input = HandleUpdateInput(
# TODO: update id vs proto instance id
id=job.protocol_instance_id,
update=job.name,
args=args,
headers=job.headers,
)
args = self._process_handler_args(
job.name,
job.input,
defn.name,
defn.arg_types,
defn.dynamic_vararg,
)
handler_input = HandleUpdateInput(
# TODO: update id vs proto instance id
id=job.protocol_instance_id,
update=job.name,
args=args,
headers=job.headers,
)

# Run the validator & handler in a task. Validator needs to be in here since the interceptor might be async.
async def run_update(
accpetance_command: temporalio.bridge.proto.workflow_commands.WorkflowCommand,
) -> None:
command = accpetance_command
assert defn is not None
try:
if defn.validator is not None:
# Run the validator
with self._as_read_only():
await self._inbound.handle_update_validator(handler_input)
if defn.validator is not None:
# Run the validator
with self._as_read_only():
await self._inbound.handle_update_validator(handler_input)

# Accept the update
command.update_response.accepted.SetInParent()
command = None # type: ignore
# Accept the update
command.update_response.accepted.SetInParent()
command = None # type: ignore

# Run the handler
success = await self._inbound.handle_update_handler(handler_input)
result_payloads = self._payload_converter.to_payloads([success])
if len(result_payloads) != 1:
raise ValueError(
f"Expected 1 result payload, got {len(result_payloads)}"
)
# Run the handler
success = await self._inbound.handle_update_handler(handler_input)
result_payloads = self._payload_converter.to_payloads([success])
if len(result_payloads) != 1:
raise ValueError(
f"Expected 1 result payload, got {len(result_payloads)}"
)
command = self._add_command()
command.update_response.protocol_instance_id = job.protocol_instance_id
command.update_response.completed.CopyFrom(result_payloads[0])
except (Exception, asyncio.CancelledError) as err:
logger.debug(
f"Update raised failure with run ID {self._info.run_id}",
exc_info=True,
)
# All asyncio cancelled errors become Temporal cancelled errors
if isinstance(err, asyncio.CancelledError):
err = temporalio.exceptions.CancelledError(
f"Cancellation raised within update {err}"
)
# Read-only issues during validation should fail the task
if isinstance(err, temporalio.workflow.ReadOnlyContextError):
self._current_activation_error = err
return
# All other errors fail the update
if command is None:
command = self._add_command()
command.update_response.protocol_instance_id = (
job.protocol_instance_id
)
command.update_response.completed.CopyFrom(result_payloads[0])
except (Exception, asyncio.CancelledError) as err:
logger.debug(
f"Update raised failure with run ID {self._info.run_id}",
exc_info=True,
)
# All asyncio cancelled errors become Temporal cancelled errors
if isinstance(err, asyncio.CancelledError):
err = temporalio.exceptions.CancelledError(
f"Cancellation raised within update {err}"
)
# Read-only issues during validation should fail the task
if isinstance(err, temporalio.workflow.ReadOnlyContextError):
self._current_activation_error = err
return
# All other errors fail the update
if command is None:
command = self._add_command()
command.update_response.protocol_instance_id = (
job.protocol_instance_id
)
self._failure_converter.to_failure(
err,
self._payload_converter,
command.update_response.rejected.cause,
)
except BaseException as err:
# During tear down, generator exit and no-runtime exceptions can appear
if not self._deleting:
raise
if not isinstance(
err,
(
GeneratorExit,
temporalio.workflow._NotInWorkflowEventLoopError,
),
):
logger.debug(
"Ignoring exception while deleting workflow", exc_info=True
)

self.create_task(
run_update(acceptance_command),
name=f"update: {job.name}",
)

except Exception as err:
# If we failed here we had some issue deserializing or finding the update handlers, so reject it.
try:
self._failure_converter.to_failure(
err,
self._payload_converter,
acceptance_command.update_response.rejected.cause,
command.update_response.rejected.cause,
)
except Exception as inner_err:
raise ValueError("Failed converting application error") from inner_err
except BaseException as err:
# During tear down, generator exit and no-runtime exceptions can appear
if not self._deleting:
raise
if not isinstance(
err,
(
GeneratorExit,
temporalio.workflow._NotInWorkflowEventLoopError,
),
):
logger.debug(
"Ignoring exception while deleting workflow", exc_info=True
)

self.create_task(
run_update(),
name=f"update: {job.name}",
)

def _apply_fire_timer(
self, job: temporalio.bridge.proto.workflow_activation.FireTimer
@@ -1017,7 +999,21 @@ def workflow_set_signal_handler(
else:
self._signals.pop(name, None)

# TODO: Set update handler?
def workflow_set_update_handler(
self, name: Optional[str], handler: Optional[Callable]
) -> None:
self._assert_not_read_only("set update handler")
if handler:
defn = temporalio.workflow._UpdateDefinition(
name=name, fn=handler, is_method=False
)
self._updates[name] = defn
if defn.dynamic_vararg:
raise RuntimeError(
"Dynamic updates do not support a vararg third param, use Sequence[RawValue]",
)
else:
self._updates.pop(name, None)

def workflow_start_activity(
self,
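
The comment added at the top of the new run_update task captures the behavioral point of this change: because the update definition is now looked up inside the task, an update arriving in the first workflow task can be served by a handler registered at the top of workflow code. A minimal sketch of that pattern (the workflow and update names are illustrative; the test changes further below exercise the same flow):

import asyncio

from temporalio import workflow


@workflow.defn
class EarlyUpdateWorkflow:
    @workflow.run
    async def run(self) -> None:
        # Registered dynamically before the first workflow task completes, so an
        # update delivered in that same task still finds a handler.
        workflow.set_update_handler("first_task_update", lambda: "worked")
        # Wait forever; all interaction happens through updates.
        await asyncio.Future()
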
63 changes: 63 additions & 0 deletions temporalio/workflow.py
@@ -505,6 +505,12 @@ def workflow_set_signal_handler(
) -> None:
...

@abstractmethod
def workflow_set_update_handler(
self, name: Optional[str], handler: Optional[Callable]
) -> None:
...

@abstractmethod
def workflow_start_activity(
self,
@@ -4093,6 +4099,63 @@ def set_dynamic_query_handler(handler: Optional[Callable]) -> None:
_Runtime.current().workflow_set_query_handler(None, handler)


def get_update_handler(name: str) -> Optional[Callable]:
    """Get the update handler for the given name if any.

    This includes handlers created via the ``@workflow.update`` decorator.

    Args:
        name: Name of the update.

    Returns:
        Callable for the update if any. If a handler is not found for the name,
        this will not return the dynamic handler even if there is one.
    """
    return _Runtime.current().workflow_get_update_handler(name)


def set_update_handler(name: str, handler: Optional[Callable]) -> None:
    """Set or unset the update handler for the given name.

    This overrides any existing handlers for the given name, including handlers
    created via the ``@workflow.update`` decorator.

    Args:
        name: Name of the update.
        handler: Callable to set or None to unset.
    """
    _Runtime.current().workflow_set_update_handler(name, handler)


def get_dynamic_update_handler() -> Optional[Callable]:
    """Get the dynamic update handler if any.

    This includes dynamic handlers created via the ``@workflow.update``
    decorator.

    Returns:
        Callable for the dynamic update handler if any.
    """
    return _Runtime.current().workflow_get_update_handler(None)


def set_dynamic_update_handler(handler: Optional[Callable]) -> None:
    """Set or unset the dynamic update handler.

    This overrides the existing dynamic handler even if it was created via the
    ``@workflow.update`` decorator.

    Args:
        handler: Callable to set or None to unset.
    """
    _Runtime.current().workflow_set_update_handler(None, handler)


def _is_unbound_method_on_cls(fn: Callable[..., Any], cls: Type) -> bool:
# Python 3 does not make this easy, ref https://stackoverflow.com/questions/3589311
return (
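
As a usage sketch of the four module-level functions added above: a dynamic handler takes the update name plus Sequence[RawValue] arguments, and, per the worker change in _workflow_instance.py, a vararg third parameter is rejected. Workflow, update, and handler names here are illustrative, and the introspection assertions reflect the documented lookup behavior rather than anything shown verbatim in this commit.

import asyncio
from typing import Sequence

from temporalio import workflow
from temporalio.common import RawValue


@workflow.defn
class ConfigWorkflow:
    @workflow.run
    async def run(self) -> None:
        # Named handler: overrides any handler of the same name, including one
        # created with the @workflow.update decorator.
        workflow.set_update_handler("set_speed", lambda speed: f"speed set to {speed}")

        # Dynamic handler: receives the update name plus raw, unconverted args.
        def dynamic(name: str, args: Sequence[RawValue]) -> str:
            return f"unhandled update: {name}"

        workflow.set_dynamic_update_handler(dynamic)

        # Named lookup never falls back to the dynamic handler.
        assert workflow.get_update_handler("set_speed") is not None
        assert workflow.get_update_handler("something_else") is None
        assert workflow.get_dynamic_update_handler() is not None

        # Passing None unregisters a handler again.
        workflow.set_update_handler("set_speed", None)

        # Wait forever; interaction happens through updates.
        await asyncio.Future()
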
55 changes: 14 additions & 41 deletions tests/worker/test_workflow.py
@@ -3516,6 +3516,12 @@ def __init__(self) -> None:

@workflow.run
async def run(self) -> None:
workflow.set_update_handler("first_task_update", lambda: "worked")

def dynahandler(name: str, _args: Sequence[RawValue]) -> str:
return "dynahandler - " + name

workflow.set_dynamic_update_handler(dynahandler)
# Wait forever
await asyncio.Future()

@@ -3565,23 +3571,8 @@ def bad_validator_validator(self) -> None:
say_hello, "boo", schedule_to_close_timeout=timedelta(seconds=5)
)

# @workflow.signal
# def set_signal_handler(self, signal_name: str) -> None:
# def new_handler(arg: str) -> None:
# self._last_event = f"signal {signal_name}: {arg}"
#
# workflow.set_signal_handler(signal_name, new_handler)
#
# @workflow.signal
# def set_dynamic_signal_handler(self) -> None:
# def new_handler(name: str, args: Sequence[RawValue]) -> None:
# arg = workflow.payload_converter().from_payload(args[0].payload, str)
# self._last_event = f"signal dynamic {name}: {arg}"
#
# workflow.set_dynamic_signal_handler(new_handler)


async def test_workflow_update_handlers(client: Client):

async def test_workflow_update_handlers_happy(client: Client):
async with new_worker(
client, UpdateHandlersWorkflow, activities=[say_hello]
) as worker:
@@ -3591,6 +3582,9 @@ async def test_workflow_update_handlers(client: Client):
task_queue=worker.task_queue,
)

# Dynamically registered and used in first task
assert "worked" == await handle.update("first_task_update")

# Normal handling
last_event = await handle.update(UpdateHandlersWorkflow.last_event, "val2")
assert "<no event>" == last_event
Expand All @@ -3600,30 +3594,9 @@ async def test_workflow_update_handlers(client: Client):
UpdateHandlersWorkflow.last_event_async, "val3"
)
assert "val2" == last_event
# # Dynamic signal handling buffered and new
# await handle.signal("unknown_signal2", "val3")
# await handle.signal(UpdateHandlersWorkflow.set_dynamic_signal_handler)
# assert "signal dynamic unknown_signal2: val3" == await handle.query(
# UpdateHandlersWorkflow.last_event
# )
# await handle.signal("unknown_signal3", "val4")
# assert "signal dynamic unknown_signal3: val4" == await handle.query(
# UpdateHandlersWorkflow.last_event
# )
#
# # Normal query handling
# await handle.signal(
# UpdateHandlersWorkflow.set_query_handler, "unknown_query1"
# )
# assert "query unknown_query1: val5" == await handle.query(
# "unknown_query1", "val5"
# )
#
# # Dynamic query handling
# await handle.signal(UpdateHandlersWorkflow.set_dynamic_query_handler)
# assert "query dynamic unknown_query2: val6" == await handle.query(
# "unknown_query2", "val6"
# )

# Dynamic handler
assert "dynahandler - made_up" == await handle.update("made_up")


async def test_workflow_update_handlers_unhappy(client: Client):
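
The dynamic handler in the test above only echoes the update name; when the arguments are needed, the raw payloads have to be converted explicitly, just as the removed dynamic-signal comment block did for signals. A small sketch under the assumption that the first argument is a string (the function name is illustrative):

from typing import Sequence

from temporalio import workflow
from temporalio.common import RawValue


def dynamic_update_handler(name: str, args: Sequence[RawValue]) -> str:
    # Dynamic handlers receive raw payloads; convert them with the workflow's
    # payload converter before use.
    first_arg = workflow.payload_converter().from_payload(args[0].payload, str)
    return f"{name}: {first_arg}"
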
