From 1da053f7497a8d72af6a7c8d65159bedd9ec6396 Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Mon, 16 Oct 2023 11:28:47 -0700 Subject: [PATCH] Linting / mypy --- temporalio/client.py | 77 ++++++++++++++----------- temporalio/contrib/opentelemetry.py | 18 ++++-- temporalio/worker/_interceptor.py | 3 +- temporalio/worker/_workflow_instance.py | 3 +- temporalio/workflow.py | 7 ++- tests/test_workflow.py | 2 + tests/worker/test_workflow.py | 6 +- 7 files changed, 73 insertions(+), 43 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 95c18143..ccde4f80 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -8,6 +8,7 @@ import inspect import json import re +import sys import uuid import warnings from abc import ABC, abstractmethod @@ -1668,7 +1669,7 @@ async def start_update( *, args: Sequence[Any] = [], id: Optional[str] = None, - wait_for_stage: temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage = temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED, + wait_for_stage: temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType = temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED, result_type: Optional[Type] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, @@ -2815,8 +2816,8 @@ class ScheduleActionStartWorkflow(ScheduleAction): @staticmethod def _from_proto( - info: temporalio.api.workflow.v1.NewWorkflowExecutionInfo, - ) -> ScheduleActionStartWorkflow: # type: ignore[override] + info: temporalio.api.workflow.v1.NewWorkflowExecutionInfo, # type: ignore[override] + ) -> ScheduleActionStartWorkflow: return ScheduleActionStartWorkflow("", raw_info=info) # Overload for no-param workflow @@ -3797,13 +3798,18 @@ def __init__( run_id: Optional[str] = None, result_type: Optional[Type] = None, ): + """Create a workflow update handle. + + Users should not create this directly, but rather use + :py:meth:`Client.start_workflow_update`. + """ self._client = client self._id = id self._name = name self._workflow_id = workflow_id self._run_id = run_id self._result_type = result_type - self._known_result = None + self._known_result: Optional[temporalio.api.update.v1.Outcome] = None @property def id(self) -> str: @@ -3829,7 +3835,7 @@ async def result( self, *, timeout: Optional[timedelta] = None, - rpc_metadata: Mapping[str, str] = None, + rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, ) -> Any: """Wait for and return the result of the update. The result may already be known in which case no call is made. @@ -3840,6 +3846,10 @@ async def result( rpc_metadata: Headers used on the RPC call. Keys here override client-level RPC metadata keys. rpc_timeout: Optional RPC deadline to set for the RPC call. If this elapses, the poll is retried until the overall timeout has been reached. + + Raises: + TimeoutError: The specified timeout was reached when waiting for the update result. + RPCError: Update result could not be fetched for some other reason. """ outcome: temporalio.api.update.v1.Outcome if self._known_result is not None: @@ -4084,7 +4094,7 @@ class UpdateWorkflowInput: update: str args: Sequence[Any] wait_for_stage: Optional[ - temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage + temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType ] headers: Mapping[str, temporalio.api.common.v1.Payload] ret_type: Optional[Type] @@ -4724,9 +4734,7 @@ async def start_workflow_update( # If the status is INVALID_ARGUMENT, we can assume it's an update # failed error if err.status == RPCStatusCode.INVALID_ARGUMENT: - raise WorkflowUpdateFailedError( - input.workflow_id, input.update, err.cause - ) + raise WorkflowUpdateFailedError(input.workflow_id, input.update, err) else: raise @@ -4758,31 +4766,34 @@ async def poll_workflow_update(self, input: PollUpdateWorkflowInput) -> Any: lifecycle_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED ), ) - try: - # Wait for at most the *overall* timeout - async with asyncio.timeout(input.timeout.total_seconds()): - # Continue polling as long as we have either an empty response, or an *rpc* timeout - while True: - try: - res = await self._client.workflow_service.poll_workflow_execution_update( - req, - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, + + async def poll_loop(): + # Continue polling as long as we have either an empty response, or an *rpc* timeout + while True: + try: + res = await self._client.workflow_service.poll_workflow_execution_update( + req, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + if res.HasField("outcome"): + return await _update_outcome_to_result( + res.outcome, + input.update_id, + input.update, + self._client.data_converter, + input.ret_type, ) - if res.HasField("outcome"): - return await _update_outcome_to_result( - res.outcome, - input.update_id, - input.update, - self._client.data_converter, - input.ret_type, - ) - except RPCError as err: - if err.status == RPCStatusCode.DEADLINE_EXCEEDED: - continue - except TimeoutError: - pass + except RPCError as err: + if err.status == RPCStatusCode.DEADLINE_EXCEEDED: + continue + + # Wait for at most the *overall* timeout + return await asyncio.wait_for( + poll_loop(), + input.timeout.total_seconds() if input.timeout else sys.float_info.max, + ) ### Async activity calls diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index b1e1df6b..bf963ca9 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -25,6 +25,7 @@ import opentelemetry.trace import opentelemetry.trace.propagation.tracecontext import opentelemetry.util.types +from client import PollUpdateWorkflowInput, WorkflowUpdateHandle from typing_extensions import Protocol, TypeAlias, TypedDict import temporalio.activity @@ -244,16 +245,25 @@ async def signal_workflow( ): return await super().signal_workflow(input) - async def update_workflow( + async def start_workflow_update( self, input: temporalio.client.UpdateWorkflowInput - ) -> Any: + ) -> WorkflowUpdateHandle: + with self.root._start_as_current_span( + f"StartWorkflowUpdate:{input.update}", + attributes={"temporalWorkflowID": input.workflow_id}, + input=input, + kind=opentelemetry.trace.SpanKind.CLIENT, + ): + return await super().start_workflow_update(input) + + async def poll_workflow_update(self, input: PollUpdateWorkflowInput) -> Any: with self.root._start_as_current_span( - f"UpdateWorkflow:{input.update}", + f"PollWorkflowUpdate:{input.update}", attributes={"temporalWorkflowID": input.workflow_id}, input=input, kind=opentelemetry.trace.SpanKind.CLIENT, ): - return await super().update_workflow(input) + return await super().poll_workflow_update(input) class _TracingActivityInboundInterceptor(temporalio.worker.ActivityInboundInterceptor): diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index bbfec877..5d2cc685 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -193,7 +193,8 @@ class HandleQueryInput: @dataclass class HandleUpdateInput: """Input for :py:meth:`WorkflowInboundInterceptor.handle_update_validator` - and :py:meth:`WorkflowInboundInterceptor.handle_update_handler`.""" + and :py:meth:`WorkflowInboundInterceptor.handle_update_handler`. + """ id: str update: str diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 20db27f3..46279b1b 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -452,6 +452,7 @@ async def run_update( accpetance_command: temporalio.bridge.proto.workflow_commands.WorkflowCommand, ) -> None: command = accpetance_command + assert defn is not None try: if defn.validator is not None: # Run the validator @@ -459,7 +460,7 @@ async def run_update( # Accept the update command.update_response.accepted.SetInParent() - command = None + command = None # type: ignore # Run the handler success = await self._inbound.handle_update_handler(handler_input) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 5380babf..4e1d448b 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -811,7 +811,8 @@ def _update_validator( update_def: _UpdateDefinition, fn: Optional[Callable[..., None]] = None ): """Decorator for a workflow update validator method.""" - update_def.set_validator(fn) + if fn is not None: + update_def.set_validator(fn) def upsert_search_attributes(attributes: temporalio.common.SearchAttributes) -> None: @@ -1375,7 +1376,9 @@ def bind_fn(self, obj: Any) -> Callable[..., Any]: return _bind_method(obj, self.fn) def bind_validator(self, obj: Any) -> Callable[..., Any]: - return _bind_method(obj, self.validator) + if self.validator is not None: + return _bind_method(obj, self.validator) + return lambda *args, **kwargs: None def set_validator(self, validator: Callable[..., None]) -> None: # TODO: Verify arg types are the same diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 3f24d530..e9851445 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -87,6 +87,8 @@ def test_workflow_defn_good(): name="base_query", fn=GoodDefnBase.base_query, is_method=True ), }, + # TODO: Add + updates={}, sandboxed=True, ) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 1ae41a5f..6fded7d5 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -3541,7 +3541,7 @@ async def last_event_async(self, an_arg: str) -> str: @workflow.update async def runs_activity(self, name: str) -> str: - act = workflow.start_activity_method( + act = workflow.start_activity( say_hello, name, schedule_to_close_timeout=timedelta(seconds=5) ) act.cancel() @@ -3565,7 +3565,9 @@ async def runs_activity(self, name: str) -> str: async def test_workflow_update_handlers(client: Client): - async with new_worker(client, UpdateHandlersWorkflow) as worker: + async with new_worker( + client, UpdateHandlersWorkflow, activities=[say_hello] + ) as worker: handle = await client.start_workflow( UpdateHandlersWorkflow.run, id=f"update-handlers-workflow-{uuid.uuid4()}",