From 54b54c61c0a44f3142998b31778b9ce470310cb4 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 2 Oct 2024 17:07:34 -0400 Subject: [PATCH] Improve type hints --- src/controlflow/agents/agent.py | 5 +++-- src/controlflow/flows/flow.py | 7 ++++--- src/controlflow/tasks/task.py | 6 ++++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index 1a6912cd..7fa52d97 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -14,6 +14,7 @@ from langchain_core.language_models import BaseChatModel from pydantic import Field, field_serializer, field_validator +from typing_extensions import Self import controlflow from controlflow.agents.names import AGENT_NAMES @@ -183,11 +184,11 @@ def get_prompt(self) -> str: return template.render() @contextmanager - def create_context(self): + def create_context(self) -> Generator[Self, None, None]: with ctx(agent=self): yield self - def __enter__(self): + def __enter__(self) -> Self: self._cm_stack.append(self.create_context()) return self._cm_stack[-1].__enter__() diff --git a/src/controlflow/flows/flow.py b/src/controlflow/flows/flow.py index 6c0cc07c..8d51588e 100644 --- a/src/controlflow/flows/flow.py +++ b/src/controlflow/flows/flow.py @@ -1,9 +1,10 @@ import uuid from contextlib import contextmanager, nullcontext -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, Union from prefect.context import FlowRunContext from pydantic import Field +from typing_extensions import Self import controlflow from controlflow.agents import Agent @@ -54,7 +55,7 @@ class Flow(ControlFlowModel): context: dict[str, Any] = {} _cm_stack: list[contextmanager] = [] - def __enter__(self): + def __enter__(self) -> Self: # use stack so we can enter the context multiple times cm = self.create_context() self._cm_stack.append(cm) @@ -111,7 +112,7 @@ def add_events(self, events: list[Event]): self.history.add_events(thread_id=self.thread_id, events=events) @contextmanager - def create_context(self, **prefect_kwargs): + def create_context(self, **prefect_kwargs) -> Generator[Self, None, None]: # create a new Prefect flow if we're not already in a flow run if FlowRunContext.get() is None: prefect_context = prefect_flow_context(**prefect_kwargs) diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 22f1bfc7..097e0c21 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -7,6 +7,7 @@ TYPE_CHECKING, Any, Callable, + Generator, GenericAlias, Literal, Optional, @@ -27,6 +28,7 @@ field_serializer, field_validator, ) +from typing_extensions import Self import controlflow from controlflow.agents import Agent @@ -426,12 +428,12 @@ async def run_async( raise ValueError(f"{self.friendly_name()} failed: {self.result}") @contextmanager - def create_context(self): + def create_context(self) -> Generator[Self, None, None]: stack = ctx.get("tasks") or [] with ctx(tasks=stack + [self]): yield self - def __enter__(self): + def __enter__(self) -> Self: # use stack so we can enter the context multiple times self._cm_stack.append(ExitStack()) return self._cm_stack[-1].enter_context(self.create_context())