diff --git a/sotopia/generation_utils/generate.py b/sotopia/generation_utils/generate.py
index 39a0d1548..4d9d710ef 100644
--- a/sotopia/generation_utils/generate.py
+++ b/sotopia/generation_utils/generate.py
@@ -1,12 +1,13 @@
 import logging
 import os
 import re
-from typing import TypeVar
+from typing import TypeVar, Any
 
 import gin
 from beartype import beartype
 from beartype.typing import Type
-from langchain.chains.llm import LLMChain
+from langchain_core.runnables.base import RunnableSerializable
+from langchain_core.messages.base import BaseMessage
 from langchain.output_parsers import PydanticOutputParser
 from langchain.prompts import (
     ChatPromptTemplate,
@@ -14,7 +15,7 @@
     PromptTemplate,
 )
 from langchain.schema import BaseOutputParser, OutputParserException
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, AzureChatOpenAI
 from pydantic import BaseModel, Field
 from rich import print
 from typing_extensions import Literal
@@ -295,7 +296,7 @@ def obtain_chain(
     input_variables: list[str],
     temperature: float = 0.7,
     max_retries: int = 6,
-) -> LLMChain:
+) -> RunnableSerializable[dict[Any, Any], BaseMessage]:
     """
     Using langchain to sample profiles for participants
     """
@@ -316,7 +317,7 @@ def obtain_chain(
             openai_api_base="https://api.together.xyz/v1",
             openai_api_key=os.environ.get("TOGETHER_API_KEY"),
         )
-        chain = LLMChain(llm=chat_openai, prompt=chat_prompt_template)
+        chain = chat_prompt_template | chat_openai
         return chain
     elif "groq" in model_name:
         model_name = "/".join(model_name.split("/")[1:])
@@ -334,7 +335,31 @@ def obtain_chain(
             openai_api_base="https://api.groq.com/openai/v1",
             openai_api_key=os.environ.get("GROQ_API_KEY"),
         )
-        chain = LLMChain(llm=chat_openai, prompt=chat_prompt_template)
+        chain = chat_prompt_template | chat_openai
+        return chain
+    elif "azure" in model_name:
+        # azure/resource_name/deployment_name/version
+        azure_credentials = model_name.split("/")[1:]
+        resource_name, deployment_name, azure_version = (
+            azure_credentials[0],
+            azure_credentials[1],
+            azure_credentials[2],
+        )
+        human_message_prompt = HumanMessagePromptTemplate(
+            prompt=PromptTemplate(
+                template=template,
+                input_variables=input_variables,
+            )
+        )
+        chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
+        chat_azure_openai = AzureChatOpenAI(
+            azure_deployment=deployment_name,
+            openai_api_version=azure_version,
+            azure_endpoint=f"https://{resource_name}.openai.azure.com",
+            temperature=temperature,
+            max_retries=max_retries,
+        )
+        chain = chat_prompt_template | chat_azure_openai
         return chain
     else:
         chat = ChatOpenAI(
@@ -346,7 +371,7 @@ def obtain_chain(
             prompt=PromptTemplate(template=template, input_variables=input_variables)
         )
         chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
-        chain = LLMChain(llm=chat, prompt=chat_prompt_template)
+        chain = chat_prompt_template | chat
         return chain
 
 
@@ -356,7 +381,7 @@ def format_bad_output_for_script(
     format_instructions: str,
     agents: list[str],
     model_name: str = "gpt-3.5-turbo",
-) -> str:
+) -> BaseMessage:
     template = """
     Given the string that can not be parsed by a parser, reformat it to a string that can be parsed by the parser which uses the following format instructions. Do not add or delete any information. Small tip: for every round of conversation, first determine the name and the case, and whether this line contains errors. Correct it if necessary.
@@ -380,17 +405,17 @@ def format_bad_output_for_script(
         "format_instructions": format_instructions,
         "agents": agents,
     }
-    reformat = chain.predict([logging_handler], **input_values)
+    reformat = chain.invoke(input_values, config={"callbacks": [logging_handler]})
     log.info(f"Reformated output: {reformat}")
     return reformat
 
 
 @beartype
 def format_bad_output(
-    ill_formed_output: str,
+    ill_formed_output: BaseMessage,
     format_instructions: str,
     model_name: str = "gpt-3.5-turbo",
-) -> str:
+) -> BaseMessage:
     template = """
     Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser.
     Original string: {ill_formed_output}
@@ -405,10 +430,10 @@ def format_bad_output(
         input_variables=re.findall(r"{(.*?)}", template),
     )
     input_values = {
-        "ill_formed_output": ill_formed_output,
+        "ill_formed_output": ill_formed_output.content,
         "format_instructions": format_instructions,
     }
-    reformat = chain.predict([logging_handler], **input_values)
+    reformat = chain.invoke(input_values, config={"callbacks": [logging_handler]})
     log.info(f"Reformated output: {reformat}")
     return reformat
 
@@ -437,9 +462,9 @@ async def agenerate(
     )
     if "format_instructions" not in input_values:
         input_values["format_instructions"] = output_parser.get_format_instructions()
-    result = await chain.apredict([logging_handler], **input_values)
+    result = await chain.ainvoke(input_values, config={"callbacks": [logging_handler]})
     try:
-        parsed_result = output_parser.parse(result)
+        parsed_result = output_parser.invoke(result)
     except Exception as e:
         if isinstance(output_parser, ScriptOutputParser):
             raise e  # the problem has been handled in the parser
@@ -450,7 +475,7 @@ async def agenerate(
         reformat_parsed_result = format_bad_output(
             result, format_instructions=output_parser.get_format_instructions()
         )
-        parsed_result = output_parser.parse(reformat_parsed_result)
+        parsed_result = output_parser.invoke(reformat_parsed_result)
     log.info(f"Generated result: {parsed_result}")
     return parsed_result
 
diff --git a/sotopia_conf/run_async_server_in_batch.gin b/sotopia_conf/run_async_server_in_batch.gin
index 2d0f88026..544d0679a 100644
--- a/sotopia_conf/run_async_server_in_batch.gin
+++ b/sotopia_conf/run_async_server_in_batch.gin
@@ -2,12 +2,13 @@ from __gin__ import dynamic_registration
 import __main__ as main_script
 
 BATCH_SIZE=10
+ENV_MODEL="gpt-4"
 AGENT1_MODEL="gpt-3.5-turbo"
 AGENT2_MODEL="gpt-3.5-turbo"
 VERBOSE=False
 TAG_TO_CHECK_EXISTING_EPISODES=None
 
-MODEL_NAMES={"env": "gpt-4", "agent1": %AGENT1_MODEL, "agent2": %AGENT2_MODEL}
+MODEL_NAMES={"env": %ENV_MODEL, "agent1": %AGENT1_MODEL, "agent2": %AGENT2_MODEL}
 
 ENV_IDS=%gin.REQUIRED
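
Notes on the migration (supplementary, not part of the patch):

The core change replaces LLMChain(llm=..., prompt=...) with LCEL pipe
composition: a ChatPromptTemplate piped into a chat model yields a
RunnableSerializable that maps an input dict to a BaseMessage, which is why
obtain_chain's return annotation and the downstream str annotations change.
A minimal sketch of the resulting call pattern (the prompt text and model
name below are placeholders, not values from this PR):

    # Piping a prompt into a chat model builds the same kind of chain
    # obtain_chain() now returns.
    from langchain.prompts import ChatPromptTemplate
    from langchain_core.messages.base import BaseMessage
    from langchain_openai import ChatOpenAI

    chat_prompt_template = ChatPromptTemplate.from_template("Reply to: {text}")
    chat = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.7)
    chain = chat_prompt_template | chat

    # invoke()/ainvoke() replace LLMChain.predict()/apredict(); callbacks
    # now travel in the config dict instead of a positional argument.
    message: BaseMessage = chain.invoke({"text": "hello"}, config={"callbacks": []})
    print(message.content)  # the reply is a BaseMessage; .content holds the str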
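
The new "azure" branch parses its credentials out of the model name, following
the convention in the inline comment. A hypothetical example (the resource,
deployment, and version values are made up for illustration; the API key is
read from the AZURE_OPENAI_API_KEY environment variable, since none is passed
explicitly):

    # model_name convention: azure/<resource_name>/<deployment_name>/<api_version>
    model_name = "azure/my-resource/my-gpt4-deployment/2024-02-01"
    resource_name, deployment_name, azure_version = model_name.split("/")[1:]
    # resource_name   -> "my-resource" (endpoint https://my-resource.openai.azure.com)
    # deployment_name -> "my-gpt4-deployment"
    # azure_version   -> "2024-02-01"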
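
In the gin config, lifting the hard-coded "gpt-4" into an ENV_MODEL macro makes
the environment model overridable per run, just like the agent models. A sketch
of an override through standard gin bindings (the override value is
illustrative; since this .gin file uses dynamic registration and imports
__main__ as main_script, it must be parsed from the batch-run entry script):

    import gin

    gin.parse_config_files_and_bindings(
        ["sotopia_conf/run_async_server_in_batch.gin"],
        ['ENV_MODEL="gpt-4-turbo"'],  # replaces the default "gpt-4"
    )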