Skip to content

Commit

Permalink
Add polling
Browse files Browse the repository at this point in the history
  • Loading branch information
Sushisource committed Oct 16, 2023
1 parent 84b909c commit d1e0681
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 27 deletions.
148 changes: 122 additions & 26 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import copy
import dataclasses
import inspect
Expand Down Expand Up @@ -1716,7 +1717,7 @@ async def start_update(

return await self._client._impl.start_workflow_update(
UpdateWorkflowInput(
id=self._id,
workflow_id=self._id,
run_id=self._run_id,
update_id=id or "",
update=update_name,
Expand Down Expand Up @@ -3829,31 +3830,41 @@ async def result(
*,
timeout: Optional[timedelta] = None,
rpc_metadata: Mapping[str, str] = None,
rpc_timeout: Optional[timedelta] = None,
) -> Any:
"""Wait for and return the result of the update. The result may already be known in which case no call is made.
Otherwise the result will be polled for until returned, or until the provided timeout is reached, if specified.
Args:
timeout: Optional timeout specifying maximum wait time for the result.
rpc_metadata: Headers used on the RPC call. Keys here override client-level RPC metadata keys.
rpc_timeout: Optional RPC deadline to set for the RPC call. If this elapses, the poll is retried until the
overall timeout has been reached.
"""
outcome: temporalio.api.update.v1.Outcome
if self._known_result is not None:
outcome = self._known_result
else:
# TODO: This
raise NotImplementedError

if outcome.HasField("failure"):
raise WorkflowUpdateFailedError(
return await _update_outcome_to_result(
outcome,
self.id,
self.name,
await self._client.data_converter.decode_failure(outcome.failure.cause),
self._client.data_converter,
self._result_type,
)
else:
return await self._client._impl.poll_workflow_update(
PollUpdateWorkflowInput(
self.workflow_id,
self.run_id,
self.id,
self.name,
timeout,
{},
self._result_type,
rpc_metadata,
rpc_timeout,
)
)
if not outcome.success.payloads:
return None
type_hints = [self._result_type] if self._result_type else None
results = await self._client.data_converter.decode(
outcome.success.payloads, type_hints
)
if not results:
return None
elif len(results) > 1:
warnings.warn(f"Expected single update result, got {len(results)}")
return results[0]

def _set_known_result(self, result: temporalio.api.update.v1.Outcome) -> None:
self._known_result = result
Expand Down Expand Up @@ -4065,9 +4076,9 @@ class TerminateWorkflowInput:

@dataclass
class UpdateWorkflowInput:
"""Input for :py:meth:`OutboundInterceptor.update_workflow`."""
"""Input for :py:meth:`OutboundInterceptor.start_workflow_update`."""

id: str
workflow_id: str
run_id: Optional[str]
update_id: str
update: str
Expand All @@ -4076,7 +4087,21 @@ class UpdateWorkflowInput:
temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage
]
headers: Mapping[str, temporalio.api.common.v1.Payload]
# Type may be absent
ret_type: Optional[Type]
rpc_metadata: Mapping[str, str]
rpc_timeout: Optional[timedelta]


@dataclass
class PollUpdateWorkflowInput:
"""Input for :py:meth:`OutboundInterceptor.poll_workflow_update`."""

workflow_id: str
run_id: Optional[str]
update_id: str
update: str
timeout: Optional[timedelta]
headers: Mapping[str, temporalio.api.common.v1.Payload]
ret_type: Optional[Type]
rpc_metadata: Mapping[str, str]
rpc_timeout: Optional[timedelta]
Expand Down Expand Up @@ -4329,9 +4354,13 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:
async def start_workflow_update(
self, input: UpdateWorkflowInput
) -> WorkflowUpdateHandle:
"""Called for every :py:meth:`WorkflowHandle.signal` call."""
"""Called for every :py:meth:`WorkflowHandle.update` and :py:meth:`WorkflowHandle.start_update` call."""
return await self.next.start_workflow_update(input)

async def poll_workflow_update(self, input: PollUpdateWorkflowInput) -> Any:
"""May be called when calling :py:math:`WorkflowUpdateHandle.result`."""
return await self.next.poll_workflow_update(input)

### Async activity calls

async def heartbeat_async_activity(
Expand Down Expand Up @@ -4665,7 +4694,7 @@ async def start_workflow_update(
req = temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest(
namespace=self._client.namespace,
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
workflow_id=input.id,
workflow_id=input.workflow_id,
run_id=input.run_id or "",
),
request=temporalio.api.update.v1.Request(
Expand Down Expand Up @@ -4695,15 +4724,17 @@ async def start_workflow_update(
# If the status is INVALID_ARGUMENT, we can assume it's an update
# failed error
if err.status == RPCStatusCode.INVALID_ARGUMENT:
raise WorkflowUpdateFailedError(input.id, input.update, err.cause)
raise WorkflowUpdateFailedError(
input.workflow_id, input.update, err.cause
)
else:
raise

update_handle = WorkflowUpdateHandle(
client=self._client,
id=input.update_id,
name=input.update,
workflow_id=input.id,
workflow_id=input.workflow_id,
run_id=input.run_id,
result_type=input.ret_type,
)
Expand All @@ -4712,6 +4743,47 @@ async def start_workflow_update(

return update_handle

async def poll_workflow_update(self, input: PollUpdateWorkflowInput) -> Any:
req = temporalio.api.workflowservice.v1.PollWorkflowExecutionUpdateRequest(
namespace=self._client.namespace,
update_ref=temporalio.api.update.v1.UpdateRef(
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
workflow_id=input.workflow_id,
run_id=input.run_id or "",
),
update_id=input.update_id,
),
identity=self._client.identity,
wait_policy=temporalio.api.update.v1.WaitPolicy(
lifecycle_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED
),
)
try:
# Wait for at most the *overall* timeout
async with asyncio.timeout(input.timeout.total_seconds()):
# Continue polling as long as we have either an empty response, or an *rpc* timeout
while True:
try:
res = await self._client.workflow_service.poll_workflow_execution_update(
req,
retry=True,
metadata=input.rpc_metadata,
timeout=input.rpc_timeout,
)
if res.HasField("outcome"):
return await _update_outcome_to_result(
res.outcome,
input.update_id,
input.update,
self._client.data_converter,
input.ret_type,
)
except RPCError as err:
if err.status == RPCStatusCode.DEADLINE_EXCEEDED:
continue
except TimeoutError:
pass

### Async activity calls

async def heartbeat_async_activity(
Expand Down Expand Up @@ -5240,6 +5312,30 @@ def _fix_history_enum(prefix: str, parent: Dict[str, Any], *attrs: str) -> None:
_fix_history_enum(prefix, child_item, *attrs[1:])


async def _update_outcome_to_result(
outcome: temporalio.api.update.v1.Outcome,
id: str,
name: str,
converter: temporalio.converter.DataConverter,
rtype: Optional[Type],
) -> Any:
if outcome.HasField("failure"):
raise WorkflowUpdateFailedError(
id,
name,
await converter.decode_failure(outcome.failure.cause),
)
if not outcome.success.payloads:
return None
type_hints = [rtype] if rtype else None
results = await converter.decode(outcome.success.payloads, type_hints)
if not results:
return None
elif len(results) > 1:
warnings.warn(f"Expected single update result, got {len(results)}")
return results[0]


@dataclass(frozen=True)
class WorkerBuildIdVersionSets:
"""Represents the sets of compatible Build ID versions associated with some Task Queue, as
Expand Down
2 changes: 1 addition & 1 deletion temporalio/contrib/opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ async def update_workflow(
) -> Any:
with self.root._start_as_current_span(
f"UpdateWorkflow:{input.update}",
attributes={"temporalWorkflowID": input.id},
attributes={"temporalWorkflowID": input.workflow_id},
input=input,
kind=opentelemetry.trace.SpanKind.CLIENT,
):
Expand Down
7 changes: 7 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Client,
Interceptor,
OutboundInterceptor,
PollUpdateWorkflowInput,
QueryWorkflowInput,
RPCError,
RPCStatusCode,
Expand Down Expand Up @@ -408,6 +409,12 @@ async def start_workflow_update(
self._parent.traces.append(("start_workflow_update", input))
return await super().start_workflow_update(input)

async def poll_workflow_update(
self, input: PollUpdateWorkflowInput
) -> WorkflowUpdateHandle:
self._parent.traces.append(("poll_workflow_update", input))
return await super().poll_workflow_update(input)


async def test_interceptor(client: Client, worker: ExternalWorker):
# Create new client from existing client but with a tracing interceptor
Expand Down

0 comments on commit d1e0681

Please sign in to comment.