Skip to content

Commit

Permalink
Linting / mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
Sushisource committed Oct 16, 2023
1 parent d1e0681 commit 713c847
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 42 deletions.
77 changes: 44 additions & 33 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import inspect
import json
import re
import sys
import uuid
import warnings
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -1668,7 +1669,7 @@ async def start_update(
*,
args: Sequence[Any] = [],
id: Optional[str] = None,
wait_for_stage: temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage = temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED,
wait_for_stage: temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType = temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED,
result_type: Optional[Type] = None,
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
Expand Down Expand Up @@ -2815,8 +2816,8 @@ class ScheduleActionStartWorkflow(ScheduleAction):

@staticmethod
def _from_proto(
info: temporalio.api.workflow.v1.NewWorkflowExecutionInfo,
) -> ScheduleActionStartWorkflow: # type: ignore[override]
info: temporalio.api.workflow.v1.NewWorkflowExecutionInfo, # type: ignore[override]
) -> ScheduleActionStartWorkflow:
return ScheduleActionStartWorkflow("<unset>", raw_info=info)

# Overload for no-param workflow
Expand Down Expand Up @@ -3797,13 +3798,18 @@ def __init__(
run_id: Optional[str] = None,
result_type: Optional[Type] = None,
):
"""Create a workflow update handle.
Users should not create this directly, but rather use
:py:meth:`Client.start_workflow_update`.
"""
self._client = client
self._id = id
self._name = name
self._workflow_id = workflow_id
self._run_id = run_id
self._result_type = result_type
self._known_result = None
self._known_result: Optional[temporalio.api.update.v1.Outcome] = None

@property
def id(self) -> str:
Expand All @@ -3829,7 +3835,7 @@ async def result(
self,
*,
timeout: Optional[timedelta] = None,
rpc_metadata: Mapping[str, str] = None,
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
) -> Any:
"""Wait for and return the result of the update. The result may already be known in which case no call is made.
Expand All @@ -3840,6 +3846,10 @@ async def result(
rpc_metadata: Headers used on the RPC call. Keys here override client-level RPC metadata keys.
rpc_timeout: Optional RPC deadline to set for the RPC call. If this elapses, the poll is retried until the
overall timeout has been reached.
Raises:
TimeoutError: The specified timeout was reached when waiting for the update result.
RPCError: Update result could not be fetched for some other reason.
"""
outcome: temporalio.api.update.v1.Outcome
if self._known_result is not None:
Expand Down Expand Up @@ -4084,7 +4094,7 @@ class UpdateWorkflowInput:
update: str
args: Sequence[Any]
wait_for_stage: Optional[
temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage
temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType
]
headers: Mapping[str, temporalio.api.common.v1.Payload]
ret_type: Optional[Type]
Expand Down Expand Up @@ -4724,9 +4734,7 @@ async def start_workflow_update(
# If the status is INVALID_ARGUMENT, we can assume it's an update
# failed error
if err.status == RPCStatusCode.INVALID_ARGUMENT:
raise WorkflowUpdateFailedError(
input.workflow_id, input.update, err.cause
)
raise WorkflowUpdateFailedError(input.workflow_id, input.update, err)
else:
raise

Expand Down Expand Up @@ -4758,31 +4766,34 @@ async def poll_workflow_update(self, input: PollUpdateWorkflowInput) -> Any:
lifecycle_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED
),
)
try:
# Wait for at most the *overall* timeout
async with asyncio.timeout(input.timeout.total_seconds()):
# Continue polling as long as we have either an empty response, or an *rpc* timeout
while True:
try:
res = await self._client.workflow_service.poll_workflow_execution_update(
req,
retry=True,
metadata=input.rpc_metadata,
timeout=input.rpc_timeout,

async def poll_loop():
# Continue polling as long as we have either an empty response, or an *rpc* timeout
while True:
try:
res = await self._client.workflow_service.poll_workflow_execution_update(
req,
retry=True,
metadata=input.rpc_metadata,
timeout=input.rpc_timeout,
)
if res.HasField("outcome"):
return await _update_outcome_to_result(
res.outcome,
input.update_id,
input.update,
self._client.data_converter,
input.ret_type,
)
if res.HasField("outcome"):
return await _update_outcome_to_result(
res.outcome,
input.update_id,
input.update,
self._client.data_converter,
input.ret_type,
)
except RPCError as err:
if err.status == RPCStatusCode.DEADLINE_EXCEEDED:
continue
except TimeoutError:
pass
except RPCError as err:
if err.status == RPCStatusCode.DEADLINE_EXCEEDED:
continue

# Wait for at most the *overall* timeout
return await asyncio.wait_for(
poll_loop(),
input.timeout.total_seconds() if input.timeout else sys.float_info.max,
)

### Async activity calls

Expand Down
17 changes: 14 additions & 3 deletions temporalio/contrib/opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,27 @@ async def signal_workflow(
):
return await super().signal_workflow(input)

async def update_workflow(
async def start_workflow_update(
self, input: temporalio.client.UpdateWorkflowInput
) -> temporalio.client.WorkflowUpdateHandle:
with self.root._start_as_current_span(
f"StartWorkflowUpdate:{input.update}",
attributes={"temporalWorkflowID": input.workflow_id},
input=input,
kind=opentelemetry.trace.SpanKind.CLIENT,
):
return await super().start_workflow_update(input)

async def poll_workflow_update(
self, input: temporalio.client.PollUpdateWorkflowInput
) -> Any:
with self.root._start_as_current_span(
f"UpdateWorkflow:{input.update}",
f"PollWorkflowUpdate:{input.update}",
attributes={"temporalWorkflowID": input.workflow_id},
input=input,
kind=opentelemetry.trace.SpanKind.CLIENT,
):
return await super().update_workflow(input)
return await super().poll_workflow_update(input)


class _TracingActivityInboundInterceptor(temporalio.worker.ActivityInboundInterceptor):
Expand Down
3 changes: 2 additions & 1 deletion temporalio/worker/_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ class HandleQueryInput:
@dataclass
class HandleUpdateInput:
"""Input for :py:meth:`WorkflowInboundInterceptor.handle_update_validator`
and :py:meth:`WorkflowInboundInterceptor.handle_update_handler`."""
and :py:meth:`WorkflowInboundInterceptor.handle_update_handler`.
"""

id: str
update: str
Expand Down
3 changes: 2 additions & 1 deletion temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,14 +452,15 @@ async def run_update(
accpetance_command: temporalio.bridge.proto.workflow_commands.WorkflowCommand,
) -> None:
command = accpetance_command
assert defn is not None
try:
if defn.validator is not None:
# Run the validator
await self._inbound.handle_update_validator(handler_input)

# Accept the update
command.update_response.accepted.SetInParent()
command = None
command = None # type: ignore

# Run the handler
success = await self._inbound.handle_update_handler(handler_input)
Expand Down
7 changes: 5 additions & 2 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,8 @@ def _update_validator(
update_def: _UpdateDefinition, fn: Optional[Callable[..., None]] = None
):
"""Decorator for a workflow update validator method."""
update_def.set_validator(fn)
if fn is not None:
update_def.set_validator(fn)


def upsert_search_attributes(attributes: temporalio.common.SearchAttributes) -> None:
Expand Down Expand Up @@ -1375,7 +1376,9 @@ def bind_fn(self, obj: Any) -> Callable[..., Any]:
return _bind_method(obj, self.fn)

def bind_validator(self, obj: Any) -> Callable[..., Any]:
return _bind_method(obj, self.validator)
if self.validator is not None:
return _bind_method(obj, self.validator)
return lambda *args, **kwargs: None

def set_validator(self, validator: Callable[..., None]) -> None:
# TODO: Verify arg types are the same
Expand Down
2 changes: 2 additions & 0 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def test_workflow_defn_good():
name="base_query", fn=GoodDefnBase.base_query, is_method=True
),
},
# TODO: Add
updates={},
sandboxed=True,
)

Expand Down
6 changes: 4 additions & 2 deletions tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3541,7 +3541,7 @@ async def last_event_async(self, an_arg: str) -> str:

@workflow.update
async def runs_activity(self, name: str) -> str:
act = workflow.start_activity_method(
act = workflow.start_activity(
say_hello, name, schedule_to_close_timeout=timedelta(seconds=5)
)
act.cancel()
Expand All @@ -3565,7 +3565,9 @@ async def runs_activity(self, name: str) -> str:


async def test_workflow_update_handlers(client: Client):
async with new_worker(client, UpdateHandlersWorkflow) as worker:
async with new_worker(
client, UpdateHandlersWorkflow, activities=[say_hello]
) as worker:
handle = await client.start_workflow(
UpdateHandlersWorkflow.run,
id=f"update-handlers-workflow-{uuid.uuid4()}",
Expand Down

0 comments on commit 713c847

Please sign in to comment.