diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 1417b98f..96ad4660 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -422,114 +422,96 @@ def _apply_cancel_workflow( def _apply_do_update( self, job: temporalio.bridge.proto.workflow_activation.DoUpdate ): - acceptance_command = self._add_command() - acceptance_command.update_response.protocol_instance_id = ( - job.protocol_instance_id - ) - try: - defn = self._updates.get(job.name) or self._updates.get(None) - if not defn: - raise RuntimeError( - f"Update handler for '{job.name}' expected but not found, and there is no dynamic handler" + # Run the validator & handler in a task. Everything, including looking up the update definition, needs to be + # inside the task, since the update may not be defined until after we have started the workflow - for example + # if an update is in the first WFT & is also registered dynamically at the top of workflow code. + async def run_update() -> None: + command = self._add_command() + command.update_response.protocol_instance_id = job.protocol_instance_id + try: + defn = self._updates.get(job.name) or self._updates.get(None) + if not defn: + raise RuntimeError( + f"Update handler for '{job.name}' expected but not found, and there is no dynamic handler" + ) + args = self._process_handler_args( + job.name, + job.input, + defn.name, + defn.arg_types, + defn.dynamic_vararg, + ) + handler_input = HandleUpdateInput( + # TODO: update id vs proto instance id + id=job.protocol_instance_id, + update=job.name, + args=args, + headers=job.headers, ) - args = self._process_handler_args( - job.name, - job.input, - defn.name, - defn.arg_types, - defn.dynamic_vararg, - ) - handler_input = HandleUpdateInput( - # TODO: update id vs proto instance id - id=job.protocol_instance_id, - update=job.name, - args=args, - headers=job.headers, - ) - # Run the validator & handler in a task. Validator needs to be in here since the interceptor might be async. - async def run_update( - accpetance_command: temporalio.bridge.proto.workflow_commands.WorkflowCommand, - ) -> None: - command = accpetance_command - assert defn is not None - try: - if defn.validator is not None: - # Run the validator - with self._as_read_only(): - await self._inbound.handle_update_validator(handler_input) + if defn.validator is not None: + # Run the validator + with self._as_read_only(): + await self._inbound.handle_update_validator(handler_input) - # Accept the update - command.update_response.accepted.SetInParent() - command = None # type: ignore + # Accept the update + command.update_response.accepted.SetInParent() + command = None # type: ignore - # Run the handler - success = await self._inbound.handle_update_handler(handler_input) - result_payloads = self._payload_converter.to_payloads([success]) - if len(result_payloads) != 1: - raise ValueError( - f"Expected 1 result payload, got {len(result_payloads)}" - ) + # Run the handler + success = await self._inbound.handle_update_handler(handler_input) + result_payloads = self._payload_converter.to_payloads([success]) + if len(result_payloads) != 1: + raise ValueError( + f"Expected 1 result payload, got {len(result_payloads)}" + ) + command = self._add_command() + command.update_response.protocol_instance_id = job.protocol_instance_id + command.update_response.completed.CopyFrom(result_payloads[0]) + except (Exception, asyncio.CancelledError) as err: + logger.debug( + f"Update raised failure with run ID {self._info.run_id}", + exc_info=True, + ) + # All asyncio cancelled errors become Temporal cancelled errors + if isinstance(err, asyncio.CancelledError): + err = temporalio.exceptions.CancelledError( + f"Cancellation raised within update {err}" + ) + # Read-only issues during validation should fail the task + if isinstance(err, temporalio.workflow.ReadOnlyContextError): + self._current_activation_error = err + return + # All other errors fail the update + if command is None: command = self._add_command() command.update_response.protocol_instance_id = ( job.protocol_instance_id ) - command.update_response.completed.CopyFrom(result_payloads[0]) - except (Exception, asyncio.CancelledError) as err: - logger.debug( - f"Update raised failure with run ID {self._info.run_id}", - exc_info=True, - ) - # All asyncio cancelled errors become Temporal cancelled errors - if isinstance(err, asyncio.CancelledError): - err = temporalio.exceptions.CancelledError( - f"Cancellation raised within update {err}" - ) - # Read-only issues during validation should fail the task - if isinstance(err, temporalio.workflow.ReadOnlyContextError): - self._current_activation_error = err - return - # All other errors fail the update - if command is None: - command = self._add_command() - command.update_response.protocol_instance_id = ( - job.protocol_instance_id - ) - self._failure_converter.to_failure( - err, - self._payload_converter, - command.update_response.rejected.cause, - ) - except BaseException as err: - # During tear down, generator exit and no-runtime exceptions can appear - if not self._deleting: - raise - if not isinstance( - err, - ( - GeneratorExit, - temporalio.workflow._NotInWorkflowEventLoopError, - ), - ): - logger.debug( - "Ignoring exception while deleting workflow", exc_info=True - ) - - self.create_task( - run_update(acceptance_command), - name=f"update: {job.name}", - ) - - except Exception as err: - # If we failed here we had some issue deserializing or finding the update handlers, so reject it. - try: self._failure_converter.to_failure( err, self._payload_converter, - acceptance_command.update_response.rejected.cause, + command.update_response.rejected.cause, ) - except Exception as inner_err: - raise ValueError("Failed converting application error") from inner_err + except BaseException as err: + # During tear down, generator exit and no-runtime exceptions can appear + if not self._deleting: + raise + if not isinstance( + err, + ( + GeneratorExit, + temporalio.workflow._NotInWorkflowEventLoopError, + ), + ): + logger.debug( + "Ignoring exception while deleting workflow", exc_info=True + ) + + self.create_task( + run_update(), + name=f"update: {job.name}", + ) def _apply_fire_timer( self, job: temporalio.bridge.proto.workflow_activation.FireTimer @@ -1017,7 +999,21 @@ def workflow_set_signal_handler( else: self._signals.pop(name, None) - # TODO: Set update handler? + def workflow_set_update_handler( + self, name: Optional[str], handler: Optional[Callable] + ) -> None: + self._assert_not_read_only("set update handler") + if handler: + defn = temporalio.workflow._UpdateDefinition( + name=name, fn=handler, is_method=False + ) + self._updates[name] = defn + if defn.dynamic_vararg: + raise RuntimeError( + "Dynamic updates do not support a vararg third param, use Sequence[RawValue]", + ) + else: + self._updates.pop(name, None) def workflow_start_activity( self, diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 6b8acbd9..3e82e5e9 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -505,6 +505,12 @@ def workflow_set_signal_handler( ) -> None: ... + @abstractmethod + def workflow_set_update_handler( + self, name: Optional[str], handler: Optional[Callable] + ) -> None: + ... + @abstractmethod def workflow_start_activity( self, @@ -4093,6 +4099,63 @@ def set_dynamic_query_handler(handler: Optional[Callable]) -> None: _Runtime.current().workflow_set_query_handler(None, handler) +def get_update_handler(name: str) -> Optional[Callable]: + """Get the update handler for the given name if any. + + This includes handlers created via the ``@workflow.update`` decorator. + + Args: + name: Name of the update. + + Returns: + Callable for the update if any. If a handler is not found for the name, + this will not return the dynamic handler even if there is one. + """ + return _Runtime.current().workflow_get_update_handler(name) + + +def set_update_handler(name: str, handler: Optional[Callable]) -> None: + """Set or unset the update handler for the given name. + + This overrides any existing handlers for the given name, including handlers + created via the ``@workflow.update`` decorator. + + When set, all unhandled past signals for the given name are immediately sent + to the handler. + + Args: + name: Name of the update. + handler: Callable to set or None to unset. + """ + _Runtime.current().workflow_set_update_handler(name, handler) + + +def get_dynamic_update_handler() -> Optional[Callable]: + """Get the dynamic update handler if any. + + This includes dynamic handlers created via the ``@workflow.update`` + decorator. + + Returns: + Callable for the dynamic update handler if any. + """ + return _Runtime.current().workflow_get_update_handler(None) + + +def set_dynamic_update_handler(handler: Optional[Callable]) -> None: + """Set or unset the dynamic update handler. + + This overrides the existing dynamic handler even if it was created via the + ``@workflow.update`` decorator. + + When set, all unhandled past signals are immediately sent to the handler. + + Args: + handler: Callable to set or None to unset. + """ + _Runtime.current().workflow_set_update_handler(None, handler) + + def _is_unbound_method_on_cls(fn: Callable[..., Any], cls: Type) -> bool: # Python 3 does not make this easy, ref https://stackoverflow.com/questions/3589311 return ( diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 0a8c68a9..ea9d1a90 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -3516,6 +3516,12 @@ def __init__(self) -> None: @workflow.run async def run(self) -> None: + workflow.set_update_handler("first_task_update", lambda: "worked") + + def dynahandler(name: str, _args: Sequence[RawValue]) -> str: + return "dynahandler - " + name + + workflow.set_dynamic_update_handler(dynahandler) # Wait forever await asyncio.Future() @@ -3565,23 +3571,8 @@ def bad_validator_validator(self) -> None: say_hello, "boo", schedule_to_close_timeout=timedelta(seconds=5) ) - # @workflow.signal - # def set_signal_handler(self, signal_name: str) -> None: - # def new_handler(arg: str) -> None: - # self._last_event = f"signal {signal_name}: {arg}" - # - # workflow.set_signal_handler(signal_name, new_handler) - # - # @workflow.signal - # def set_dynamic_signal_handler(self) -> None: - # def new_handler(name: str, args: Sequence[RawValue]) -> None: - # arg = workflow.payload_converter().from_payload(args[0].payload, str) - # self._last_event = f"signal dynamic {name}: {arg}" - # - # workflow.set_dynamic_signal_handler(new_handler) - - -async def test_workflow_update_handlers(client: Client): + +async def test_workflow_update_handlers_happy(client: Client): async with new_worker( client, UpdateHandlersWorkflow, activities=[say_hello] ) as worker: @@ -3591,6 +3582,9 @@ async def test_workflow_update_handlers(client: Client): task_queue=worker.task_queue, ) + # Dynamically registered and used in first task + assert "worked" == await handle.update("first_task_update") + # Normal handling last_event = await handle.update(UpdateHandlersWorkflow.last_event, "val2") assert "" == last_event @@ -3600,30 +3594,9 @@ async def test_workflow_update_handlers(client: Client): UpdateHandlersWorkflow.last_event_async, "val3" ) assert "val2" == last_event - # # Dynamic signal handling buffered and new - # await handle.signal("unknown_signal2", "val3") - # await handle.signal(UpdateHandlersWorkflow.set_dynamic_signal_handler) - # assert "signal dynamic unknown_signal2: val3" == await handle.query( - # UpdateHandlersWorkflow.last_event - # ) - # await handle.signal("unknown_signal3", "val4") - # assert "signal dynamic unknown_signal3: val4" == await handle.query( - # UpdateHandlersWorkflow.last_event - # ) - # - # # Normal query handling - # await handle.signal( - # UpdateHandlersWorkflow.set_query_handler, "unknown_query1" - # ) - # assert "query unknown_query1: val5" == await handle.query( - # "unknown_query1", "val5" - # ) - # - # # Dynamic query handling - # await handle.signal(UpdateHandlersWorkflow.set_dynamic_query_handler) - # assert "query dynamic unknown_query2: val6" == await handle.query( - # "unknown_query2", "val6" - # ) + + # Dynamic handler + assert "dynahandler - made_up" == await handle.update("made_up") async def test_workflow_update_handlers_unhappy(client: Client):