Langchain V0.2 and Sync API deprecated (#116)
* upgrade langchain and remove sync apis

* remove return prompt from agenerate

* remove prompt from agenerate

* add sync apis
ProKil authored Jun 20, 2024
1 parent bf0321b commit 8d9b9be
Showing 12 changed files with 299 additions and 517 deletions.
6 changes: 3 additions & 3 deletions examples/generate_specific_envs.py
@@ -12,7 +12,7 @@
import numpy as np
from datasets import DatasetDict, load_dataset

from sotopia.generation_utils.generate import StrOutputParser, generate
from sotopia.generation_utils.generate import StrOutputParser, agenerate


async def generate_mutual_friend_envs() -> tuple[str, list[str]]:
@@ -78,7 +78,7 @@ async def generate_craigslist_bargains_envs() -> tuple[str, list[str]]:
all_data = craigslist_bargains_dataset["train"]
# sample one datum from all data
datum = np.random.choice(all_data)
scenario = generate(
scenario = await agenerate(
model_name="gpt-4",
template="The following sentence is automatically generated with the following"
'template: "One person is selling <item> for <price>, another person is'
@@ -100,7 +100,7 @@ async def generate_craigslist_bargains_envs() -> tuple[str, list[str]]:
datum["agent_info"]["Target"][i] = datum["items"]["Price"][0] / (
1 + markup_ratio
)
goal = generate(
goal = await agenerate(
model_name="gpt-4",
template="The following sentence is automatically generated with the following"
'template: "You want to <role> this item. Your target price '
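Since the env-generation helpers in this example are coroutines and agenerate must be awaited, here is a caller-side sketch of running one of them, assuming it is added to this same example script; the printing is only illustrative:

import asyncio


async def main() -> None:
    # the helper is async, so it has to be awaited
    scenario, goals = await generate_craigslist_bargains_envs()
    print(scenario)
    for goal in goals:
        print(goal)


if __name__ == "__main__":
    # run the coroutine to completion from a synchronous entry point
    asyncio.run(main())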
435 changes: 217 additions & 218 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -11,7 +11,7 @@ packages = [{include = "sotopia"}]
python = ">=3.10, <3.13"
lxml = ">=4.9.3,<6.0.0"
openai = "^1.11.0"
langchain = "0.1.5"
langchain = "~0.2.5"
rich = "^13.6.0"
PettingZoo = "1.24.3"
redis-om = "^0.2.1"
@@ -21,7 +21,7 @@ absl-py = "^2.0.0"
together = "^0.2.4"
pydantic = "1.10.12"
beartype = "^0.14.0"
langchain-openai = ">=0.0.5,<0.0.7"
langchain-openai = "~0.1.8"
litellm = "~1.23.12"

# dependency versions for extras
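A quick, optional way to confirm an environment picked up the new pins after reinstalling (the expected version prefixes come from the constraints above):

from importlib.metadata import version

# expect 0.2.x and 0.1.x respectively after this upgrade
print("langchain:", version("langchain"))
print("langchain-openai:", version("langchain-openai"))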
2 changes: 0 additions & 2 deletions sotopia/agents/__init__.py
@@ -8,7 +8,6 @@
HumanAgent,
LLMAgent,
ScriptWritingAgent,
SpeakAgent,
)
from .redis_agent import RedisAgent

@@ -17,7 +16,6 @@
"LLMAgent",
"Agents",
"HumanAgent",
"SpeakAgent",
"generate_background",
"generate_background_conversation",
"RedisAgent",
9 changes: 3 additions & 6 deletions sotopia/agents/generate_agent_background.py
@@ -2,14 +2,11 @@
import os
from typing import Callable

from sotopia.generation_utils.generate import (
convert_narratives,
generate_init_profile,
)
from sotopia.generation_utils.generate import convert_narratives, agenerate_init_profile
from sotopia.messages import Message, ScriptBackground


def generate_background(
async def generate_background(
info_json_file: str, basic_info: dict[str, str]
) -> tuple[str, str, str, str, list[dict[str, str]]]:
if os.path.isfile(info_json_file):
@@ -22,7 +19,7 @@ def generate_background(
previous_messages = info_dict["messages"]
else:
initial_profile = str(basic_info)
profile = generate_init_profile(
profile = await agenerate_init_profile(
model_name="gpt-3.5-turbo", basic_info=basic_info
)
first_narrative = convert_narratives(
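Because generate_background is now a coroutine, synchronous callers need a small adaptation. A hedged sketch; the cache path and basic_info fields are invented for illustration:

import asyncio

from sotopia.agents import generate_background


async def build_background() -> None:
    result = await generate_background(
        info_json_file="background_cache.json",  # hypothetical cache file
        basic_info={"name": "Ada", "occupation": "engineer"},  # hypothetical fields
    )
    initial_profile, profile, first_narrative, second_narrative, messages = result
    print(profile)


asyncio.run(build_background())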
53 changes: 14 additions & 39 deletions sotopia/agents/llm_agent.py
@@ -1,16 +1,14 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, cast
from typing import cast

from sotopia.agents import BaseAgent
from sotopia.database import AgentProfile
from sotopia.generation_utils.generate import (
LLM_Name,
agenerate_action,
agenerate_goal,
agenerate_script,
generate_action,
generate_action_speak,
generate_goal,
)
from sotopia.messages import AgentAction, Observation
from sotopia.messages.message_classes import ScriptBackground
@@ -44,48 +42,34 @@ def __init__(
def goal(self) -> str:
if self._goal is not None:
return self._goal
assert (
len(self.inbox) > 0
), "attribute goal has to be called after at least one step"
goal = generate_goal(
self.model_name,
background=self.inbox[0][
1
].to_natural_language(), # Only consider the first message for now
)
return goal
else:
raise Exception("Goal is not set.")

@goal.setter
def goal(self, goal: str) -> None:
self._goal = goal

def act(
self,
obs: Observation,
gen_func: Callable[..., AgentAction] = generate_action,
_obs: Observation,
) -> AgentAction:
raise Exception("Sync act method is deprecated. Use aact instead.")

async def aact(self, obs: Observation) -> AgentAction:
self.recv_message("Environment", obs)

if len(obs.available_actions) == 1 and "none" in obs.available_actions:
return AgentAction(action_type="none", argument="")
else:
action = gen_func(
if self._goal is None:
self._goal = await agenerate_goal(
self.model_name,
history="\n".join(f"{y.to_natural_language()}" for x, y in self.inbox),
turn_number=obs.turn_number,
action_types=obs.available_actions,
agent=self.agent_name,
goal=self.goal,
background=self.inbox[0][
1
].to_natural_language(), # Only consider the first message for now
)
return action

async def aact(self, obs: Observation) -> AgentAction:
self.recv_message("Environment", obs)

if len(obs.available_actions) == 1 and "none" in obs.available_actions:
return AgentAction(action_type="none", argument="")
else:
action, prompt = await agenerate_action(
action = await agenerate_action(
self.model_name,
history="\n".join(f"{y.to_natural_language()}" for x, y in self.inbox),
turn_number=obs.turn_number,
@@ -147,15 +131,6 @@ async def aact(self, obs: Observation) -> AgentAction:
return returned_action


class SpeakAgent(LLMAgent):
def act(
self,
obs: Observation,
gen_func: Callable[..., AgentAction] = generate_action_speak,
) -> AgentAction:
return super().act(obs, gen_func=gen_func)


class HumanAgent(BaseAgent[Observation, AgentAction]):
"""
A human agent that takes input from the command line.
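With the synchronous act path now raising, LLMAgent is driven through aact inside an event loop. A rough caller-side sketch; the constructor arguments and observation construction are assumptions, since this page only shows the agent internals:

import asyncio

from sotopia.agents import LLMAgent
from sotopia.messages import Observation


async def one_turn(agent: LLMAgent, obs: Observation) -> None:
    # aact fills in the goal lazily (via agenerate_goal) on first use and
    # returns a single AgentAction; the prompt is no longer returned
    action = await agent.aact(obs)
    print(action.action_type, action.argument)


# agent = LLMAgent(agent_name="alice", model_name="gpt-4")  # constructor args assumed
# asyncio.run(one_turn(agent, obs))  # obs would normally come from the environment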
3 changes: 1 addition & 2 deletions sotopia/envs/evaluators.py
@@ -273,7 +273,7 @@ async def __acall__(
response: (
EnvResponsePlus | EnvResponse
) # fix type error from langchain 0.0.264. we don't need this line for langchain 0.0.263
response, prompt = await agenerate(
response = await agenerate(
model_name=self.model_name,
template="""{history},
Based on previous interactions, evaluate how well participants achieve their goals.
@@ -286,7 +286,6 @@
),
temperature=temperature,
)
self.prompt = prompt
response_list = []
# TODO: multiple agents
for dimension in response.agent_1_evaluation.dict().keys():
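The underlying API change this hunk adapts to: agenerate now returns only the parsed response instead of a (response, prompt) tuple. A sketch of the adjustment any other caller would need; model_name, template, output_parser, and temperature mirror the evaluator call above, while the input_values keyword and the template text are assumptions:

from sotopia.generation_utils.generate import agenerate


async def evaluate(history: str, parser):
    # before: response, prompt = await agenerate(...)
    # after: only the parsed output is returned
    return await agenerate(
        model_name="gpt-4",
        template="{history}\nEvaluate how well participants achieve their goals.",
        input_values={"history": history},  # assumed keyword name
        output_parser=parser,  # e.g. the EnvResponse parser used above
        temperature=0.0,
    )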
15 changes: 11 additions & 4 deletions sotopia/generation_utils/__init__.py
@@ -2,14 +2,21 @@
EnvResponse,
LLM_Name,
agenerate_env_profile,
fill_in_background,
generate_goal,
agenerate,
agenerate_action,
)

from .sync import (
generate,
generate_action,
)

__all__ = [
"EnvResponse",
"agenerate_env_profile",
"LLM_Name",
"fill_in_background",
"generate_goal",
"agenerate",
"agenerate_action",
"generate",
"generate_action",
]
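The new .sync module itself is not rendered on this page, so its contents are an assumption; a conventional way to re-expose a blocking generate on top of the async implementation would look roughly like this:

# sketch only: the real sotopia/generation_utils/sync.py may differ
import asyncio
from typing import Any

from sotopia.generation_utils.generate import agenerate


def generate(*args: Any, **kwargs: Any) -> Any:
    # drive the coroutine to completion on a fresh event loop
    return asyncio.run(agenerate(*args, **kwargs))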