From 25759b5e0c008d9649307b2734e3382a64e4cb1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Albert=20=C3=96rwall?= Date: Wed, 31 Jul 2024 14:46:53 +0200 Subject: [PATCH] Readd retry_message --- moatless/loop.py | 5 +++-- moatless/state.py | 2 +- moatless/trajectory.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/moatless/loop.py b/moatless/loop.py index 7c9dff95..91495852 100644 --- a/moatless/loop.py +++ b/moatless/loop.py @@ -267,7 +267,7 @@ def _set_state_loop(self, state: AgenticState): def retries(self) -> int: retries = 0 for action in reversed(self._current_transition.actions): - if action.retry_message: + if action.trigger == "retry": retries += 1 else: return retries @@ -281,7 +281,7 @@ def retry_messages(self, state: AgenticState) -> list[Message]: return messages for action in self._current_transition.actions: - if action.retry_message: + if action.trigger == "retry": if isinstance(action.action, Content): messages.append( AssistantMessage( @@ -504,6 +504,7 @@ def _run(self): TrajectoryAction( action=action, trigger=response.trigger, + retry_message=response.retry_message, completion_cost=cost, input_tokens=input_tokens, output_tokens=output_tokens, diff --git a/moatless/state.py b/moatless/state.py index e3b4d79b..866761ad 100644 --- a/moatless/state.py +++ b/moatless/state.py @@ -100,7 +100,7 @@ def required_fields(cls) -> set[str]: def retries(self) -> int: retries = 0 for action in reversed(self.loop._current_transition.actions): - if action.retry_message: + if action.trigger == "retry": retries += 1 else: return retries diff --git a/moatless/trajectory.py b/moatless/trajectory.py index 1d0d6e77..1aaeee1d 100644 --- a/moatless/trajectory.py +++ b/moatless/trajectory.py @@ -17,6 +17,7 @@ class TrajectoryAction(BaseModel): action: ActionRequest trigger: Optional[str] + retry_message: Optional[str] completion_cost: Optional[float] = None input_tokens: Optional[int] = None output_tokens: Optional[int] = None