Skip to content

Commit

Permalink
Support Azure API for agent and env models; Fix updates on langchain V0.2 (#132)
Browse files Browse the repository at this point in the history

* cherry picked generate.py from #69

* Add AzureOpenAI support for agent and env models, and update to the LangChain v0.2 runnable interface

* Delete Azure API key input

* Fix mypy errors

---------
  • Loading branch information
ruiyiw authored Jul 9, 2024
1 parent f6bb5ba commit f53472a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 17 deletions.
57 changes: 41 additions & 16 deletions sotopia/generation_utils/generate.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import logging
import os
import re
from typing import TypeVar
from typing import TypeVar, Any

import gin
from beartype import beartype
from beartype.typing import Type
from langchain.chains.llm import LLMChain
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.messages.base import BaseMessage
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
)
from langchain.schema import BaseOutputParser, OutputParserException
from langchain_openai import ChatOpenAI
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from pydantic import BaseModel, Field
from rich import print
from typing_extensions import Literal
Expand Down Expand Up @@ -295,7 +296,7 @@ def obtain_chain(
input_variables: list[str],
temperature: float = 0.7,
max_retries: int = 6,
) -> LLMChain:
) -> RunnableSerializable[dict[Any, Any], BaseMessage]:
"""
Using langchain to sample profiles for participants
"""
Expand All @@ -316,7 +317,7 @@ def obtain_chain(
openai_api_base="https://api.together.xyz/v1",
openai_api_key=os.environ.get("TOGETHER_API_KEY"),
)
chain = LLMChain(llm=chat_openai, prompt=chat_prompt_template)
chain = chat_prompt_template | chat_openai
return chain
elif "groq" in model_name:
model_name = "/".join(model_name.split("/")[1:])
Expand All @@ -334,7 +335,31 @@ def obtain_chain(
openai_api_base="https://api.groq.com/openai/v1",
openai_api_key=os.environ.get("GROQ_API_KEY"),
)
chain = LLMChain(llm=chat_openai, prompt=chat_prompt_template)
chain = chat_prompt_template | chat_openai
return chain
elif "azure" in model_name:
# azure/resource_name/deployment_name/version
azure_credentials = model_name.split("/")[1:]
resource_name, deployment_name, azure_version = (
azure_credentials[0],
azure_credentials[1],
azure_credentials[2],
)
human_message_prompt = HumanMessagePromptTemplate(
prompt=PromptTemplate(
template=template,
input_variables=input_variables,
)
)
chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
chat_azure_openai = AzureChatOpenAI(
azure_deployment=deployment_name,
openai_api_version=azure_version,
azure_endpoint=f"https://{resource_name}.openai.azure.com",
temperature=temperature,
max_retries=max_retries,
)
chain = chat_prompt_template | chat_azure_openai
return chain
else:
chat = ChatOpenAI(
Expand All @@ -346,7 +371,7 @@ def obtain_chain(
prompt=PromptTemplate(template=template, input_variables=input_variables)
)
chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
chain = LLMChain(llm=chat, prompt=chat_prompt_template)
chain = chat_prompt_template | chat
return chain


Expand All @@ -356,7 +381,7 @@ def format_bad_output_for_script(
format_instructions: str,
agents: list[str],
model_name: str = "gpt-3.5-turbo",
) -> str:
) -> BaseMessage:
template = """
Given the string that can not be parsed by a parser, reformat it to a string that can be parsed by the parser which uses the following format instructions. Do not add or delete any information.
Small tip: for every round of conversation, first determine the name and the case, and whether this line contains errors. Correct it if necessary.
Expand All @@ -380,17 +405,17 @@ def format_bad_output_for_script(
"format_instructions": format_instructions,
"agents": agents,
}
reformat = chain.predict([logging_handler], **input_values)
reformat = chain.invoke(input_values, config={"callbacks": [logging_handler]})
log.info(f"Reformated output: {reformat}")
return reformat


@beartype
def format_bad_output(
ill_formed_output: str,
ill_formed_output: BaseMessage,
format_instructions: str,
model_name: str = "gpt-3.5-turbo",
) -> str:
) -> BaseMessage:
template = """
Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser.
Original string: {ill_formed_output}
Expand All @@ -405,10 +430,10 @@ def format_bad_output(
input_variables=re.findall(r"{(.*?)}", template),
)
input_values = {
"ill_formed_output": ill_formed_output,
"ill_formed_output": ill_formed_output.content,
"format_instructions": format_instructions,
}
reformat = chain.predict([logging_handler], **input_values)
reformat = chain.invoke(input_values, config={"callbacks": [logging_handler]})
log.info(f"Reformated output: {reformat}")
return reformat

Expand Down Expand Up @@ -437,9 +462,9 @@ async def agenerate(
)
if "format_instructions" not in input_values:
input_values["format_instructions"] = output_parser.get_format_instructions()
result = await chain.apredict([logging_handler], **input_values)
result = await chain.ainvoke(input_values, config={"callbacks": [logging_handler]})
try:
parsed_result = output_parser.parse(result)
parsed_result = output_parser.invoke(result)
except Exception as e:
if isinstance(output_parser, ScriptOutputParser):
raise e # the problem has been handled in the parser
Expand All @@ -450,7 +475,7 @@ async def agenerate(
reformat_parsed_result = format_bad_output(
result, format_instructions=output_parser.get_format_instructions()
)
parsed_result = output_parser.parse(reformat_parsed_result)
parsed_result = output_parser.invoke(reformat_parsed_result)
log.info(f"Generated result: {parsed_result}")
return parsed_result

Expand Down
3 changes: 2 additions & 1 deletion sotopia_conf/run_async_server_in_batch.gin
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ from __gin__ import dynamic_registration
import __main__ as main_script

BATCH_SIZE=10
ENV_MODEL="gpt-4"
AGENT1_MODEL="gpt-3.5-turbo"
AGENT2_MODEL="gpt-3.5-turbo"
VERBOSE=False
TAG_TO_CHECK_EXISTING_EPISODES=None

MODEL_NAMES={"env": "gpt-4", "agent1": %AGENT1_MODEL, "agent2": %AGENT2_MODEL}
MODEL_NAMES={"env": %ENV_MODEL, "agent1": %AGENT1_MODEL, "agent2": %AGENT2_MODEL}
ENV_IDS=%gin.REQUIRED


Expand Down

0 comments on commit f53472a

Please sign in to comment.