diff --git a/docs/concepts/flows.mdx b/docs/concepts/flows.mdx index e152a31b..ef7346f2 100644 --- a/docs/concepts/flows.mdx +++ b/docs/concepts/flows.mdx @@ -119,19 +119,18 @@ The following flow properties are inferred from the decorated function: | ------------- | ------------- | | `name` | The function's name | | `description` | The function's docstring | -| `context` | The function's arguments (keyed by argument name) | +| `context` | The function's arguments, if specified as `context_kwargs` (keyed by argument name) | -Additional properties can be set by passing keyword arguments directly to the `@flow` decorator or to the `flow_kwargs` parameter when calling the decorated function. - - -You may not want the arguments to your flow function to be used as context. In that case, you can set `args_as_context=False` when decorating or calling the function: +To automatically put some of your flow's arguments into the global context that all agents can see, specify `context_kwargs` when decorating your flow: ```python -@cf.flow(args_as_context=False) -def my_flow(secret_var: str): +@cf.flow(context_kwargs=["x"]) +def my_flow(x: int, y: int): + # x will be automatically added to a global, agent-visible context ... ``` - + +Additional properties can be set by passing keyword arguments directly to the `@flow` decorator or to the `flow_kwargs` parameter when calling the decorated function. ### The `Flow` object and context manager diff --git a/src/controlflow/decorators.py b/src/controlflow/decorators.py index 5f175948..752da59c 100644 --- a/src/controlflow/decorators.py +++ b/src/controlflow/decorators.py @@ -21,12 +21,12 @@ def flow( thread: Optional[str] = None, instructions: Optional[str] = None, tools: Optional[list[Callable[..., Any]]] = None, - default_agent: Optional[Agent] = None, # Changed from 'agents' + default_agent: Optional[Agent] = None, retries: Optional[int] = None, retry_delay_seconds: Optional[Union[float, int]] = None, timeout_seconds: Optional[Union[float, int]] = None, prefect_kwargs: Optional[dict[str, Any]] = None, - args_as_context: Optional[bool] = True, + context_kwargs: Optional[list[str]] = None, **kwargs: Optional[dict[str, Any]], ): """ @@ -46,67 +46,69 @@ def flow( instructions (str, optional): Instructions for the flow. Defaults to None. tools (list[Callable], optional): List of tools to be used in the flow. Defaults to None. default_agent (Agent, optional): The default agent to be used in the flow. Defaults to None. - args_as_context (bool, optional): Whether to pass the arguments as context to the flow. Defaults to True. + context_kwargs (list[str], optional): List of argument names to be added to the flow context. + Defaults to None. Returns: callable: The wrapped function or a new flow decorator if `fn` is not provided. """ - ... - if fn is None: return functools.partial( flow, thread=thread, instructions=instructions, tools=tools, - default_agent=default_agent, # Changed from 'agents' + default_agent=default_agent, retries=retries, retry_delay_seconds=retry_delay_seconds, timeout_seconds=timeout_seconds, - args_as_context=args_as_context, + context_kwargs=context_kwargs, **kwargs, ) sig = inspect.signature(fn) - def _inner_wrapper(*wrapper_args, flow_kwargs: dict = None, **wrapper_kwargs): - # first process callargs - bound = sig.bind(*wrapper_args, **wrapper_kwargs) - bound.apply_defaults() - - flow_kwargs = kwargs | (flow_kwargs or {}) - + def create_flow_context(bound_args): + flow_kwargs = kwargs.copy() if thread is not None: flow_kwargs.setdefault("thread_id", thread) if tools is not None: flow_kwargs.setdefault("tools", tools) - if default_agent is not None: # Changed from 'agents' - flow_kwargs.setdefault( - "default_agent", default_agent - ) # Changed from 'agents' - - context = bound.arguments if args_as_context else {} - - with ( - Flow( - name=fn.__name__, - description=fn.__doc__, - context=context, - **flow_kwargs, - ), - controlflow.instructions(instructions), - ): - return fn(*wrapper_args, **wrapper_kwargs) + if default_agent is not None: + flow_kwargs.setdefault("default_agent", default_agent) + + context = {} + if context_kwargs: + context = {k: bound_args[k] for k in context_kwargs if k in bound_args} + + return Flow( + name=fn.__name__, + description=fn.__doc__, + context=context, + **flow_kwargs, + ) if asyncio.iscoroutinefunction(fn): @functools.wraps(fn) async def wrapper(*wrapper_args, **wrapper_kwargs): - return await _inner_wrapper(*wrapper_args, **wrapper_kwargs) + bound = sig.bind(*wrapper_args, **wrapper_kwargs) + bound.apply_defaults() + with ( + create_flow_context(bound.arguments), + controlflow.instructions(instructions), + ): + return await fn(*wrapper_args, **wrapper_kwargs) else: @functools.wraps(fn) def wrapper(*wrapper_args, **wrapper_kwargs): - return _inner_wrapper(*wrapper_args, **wrapper_kwargs) + bound = sig.bind(*wrapper_args, **wrapper_kwargs) + bound.apply_defaults() + with ( + create_flow_context(bound.arguments), + controlflow.instructions(instructions), + ): + return fn(*wrapper_args, **wrapper_kwargs) wrapper = prefect_flow( timeout_seconds=timeout_seconds, diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 468b652f..c4076f49 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -103,6 +103,33 @@ def partial_flow(): result = partial_flow() assert result == 10 + def test_flow_decorator_with_context_kwargs(self): + @controlflow.flow(context_kwargs=["x", "z"]) + def flow_with_context(x: int, y: int, z: str): + flow = controlflow.flows.get_flow() + return flow.context + + result = flow_with_context(1, 2, "test") + assert result == {"x": 1, "z": "test"} + + def test_flow_decorator_without_context_kwargs(self): + @controlflow.flow + def flow_without_context(x: int, y: int, z: str): + flow = controlflow.flows.get_flow() + return flow.context + + result = flow_without_context(1, 2, "test") + assert result == {} + + async def test_async_flow_decorator_with_context_kwargs(self): + @controlflow.flow(context_kwargs=["a", "b"]) + async def async_flow_with_context(a: int, b: str, c: float): + flow = controlflow.flows.get_flow() + return flow.context + + result = await async_flow_with_context(10, "hello", 3.14) + assert result == {"a": 10, "b": "hello"} + class TestTaskDecorator: def test_task_decorator_sync_as_task(self):