From c810b1d878982829e31a7304b136106fc23e0706 Mon Sep 17 00:00:00 2001
From: Hao
Date: Thu, 6 Jun 2024 16:27:01 -0400
Subject: [PATCH] cherry picked generate.py from #69

---
 poetry.lock                          | 138 ++------------------------
 pyproject.toml                       |   1 -
 sotopia/generation_utils/generate.py | 103 +++++++++++---------
 3 files changed, 65 insertions(+), 177 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 3fb7d9cc2..2b4785693 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -278,10 +278,7 @@ files = [
 [package.dependencies]
 jmespath = ">=0.7.1,<2.0.0"
 python-dateutil = ">=2.1,<3.0.0"
-urllib3 = [
-    {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""},
-    {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""},
-]
+urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}
 
 [package.extras]
 crt = ["awscrt (==0.20.9)"]
@@ -973,7 +970,7 @@ files = [
 name = "fsspec"
 version = "2024.3.1"
 description = "File-system specification"
-optional = false
+optional = true
 python-versions = ">=3.8"
 files = [
     {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"},
@@ -1348,7 +1345,6 @@ files = [
 [package.dependencies]
 cloudpickle = ">=1.2.0"
 farama-notifications = ">=0.0.1"
-importlib-metadata = {version = ">=4.8.0", markers = "python_version < \"3.10\""}
 numpy = ">=1.21.0"
 typing-extensions = ">=4.3.0"
 
@@ -1568,7 +1564,7 @@ files = [
 name = "huggingface-hub"
 version = "0.23.1"
 description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
-optional = false
+optional = true
 python-versions = ">=3.8.0"
 files = [
     {file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"},
@@ -1623,25 +1619,6 @@ files = [
     {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"},
 ]
 
-[[package]]
-name = "importlib-metadata"
-version = "7.1.0"
-description = "Read metadata from Python packages"
-optional = false
-python-versions = ">=3.8"
-files = [
-    {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"},
-    {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"},
-]
-
-[package.dependencies]
-zipp = ">=0.5"
-
-[package.extras]
-docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
-perf = ["ipython"]
-testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
-
 [[package]]
 name = "iniconfig"
 version = "2.0.0"
@@ -1722,7 +1699,6 @@ prompt-toolkit = ">=3.0.41,<3.1.0"
 pygments = ">=2.4.0"
 stack-data = "*"
 traitlets = ">=5"
-typing-extensions = {version = "*", markers = "python_version < \"3.10\""}
 
 [package.extras]
 all = ["black", "curio", "docrepr", "exceptiongroup", "ipykernel", "ipyparallel", "ipywidgets", "matplotlib", "matplotlib (!=3.2.0)", "nbconvert", "nbformat", "notebook", "numpy (>=1.22)", "pandas", "pickleshare", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio (<0.22)", "qtconsole", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "trio", "typing-extensions"]
@@ -1760,7 +1736,7 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"]
 name = "jinja2"
 version = "3.1.4"
 description = "A very fast and expressive template engine."
-optional = false
+optional = true
 python-versions = ">=3.7"
 files = [
     {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"},
@@ -1926,7 +1902,6 @@ files = [
 ]
 
 [package.dependencies]
-importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""}
 jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
 python-dateutil = ">=2.8.2"
 pyzmq = ">=23.0"
@@ -2079,32 +2054,6 @@ files = [
 pydantic = ">=1,<3"
 requests = ">=2,<3"
 
-[[package]]
-name = "litellm"
-version = "1.38.8"
-description = "Library to easily interface with LLM API providers"
-optional = false
-python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
-files = [
-    {file = "litellm-1.38.8-py3-none-any.whl", hash = "sha256:8155eb346dda0d766a401bcfb7b89c0e174a1c324d048b801e0cacd3579b1019"},
-    {file = "litellm-1.38.8.tar.gz", hash = "sha256:7ea342b6dd4d409e9bf7642cc0fabe9f0e6fe61613505d36e7dbf9e4128baf72"},
-]
-
-[package.dependencies]
-aiohttp = "*"
-click = "*"
-importlib-metadata = ">=6.8.0"
-jinja2 = ">=3.1.2,<4.0.0"
-openai = ">=1.27.0"
-python-dotenv = ">=0.2.0"
-requests = ">=2.31.0,<3.0.0"
-tiktoken = ">=0.4.0"
-tokenizers = "*"
-
-[package.extras]
-extra-proxy = ["azure-identity (>=1.15.0,<2.0.0)", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "resend (>=0.8.0,<0.9.0)"]
-proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "cryptography (>=42.0.5,<43.0.0)", "fastapi (>=0.111.0,<0.112.0)", "fastapi-sso (>=0.10.0,<0.11.0)", "gunicorn (>=22.0.0,<23.0.0)", "orjson (>=3.9.7,<4.0.0)", "python-multipart (>=0.0.9,<0.0.10)", "pyyaml (>=6.0.1,<7.0.0)", "rq", "uvicorn (>=0.22.0,<0.23.0)"]
-
 [[package]]
 name = "lxml"
 version = "4.9.4"
@@ -2241,7 +2190,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
 name = "markupsafe"
 version = "2.1.5"
 description = "Safely add untrusted strings to HTML/XML markup."
-optional = false
+optional = true
 python-versions = ">=3.7"
 files = [
     {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"},
@@ -3425,20 +3374,6 @@ files = [
 [package.dependencies]
 six = ">=1.5"
 
-[[package]]
-name = "python-dotenv"
-version = "1.0.1"
-description = "Read key-value pairs from a .env file and set them as environment variables"
-optional = false
-python-versions = ">=3.8"
-files = [
-    {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"},
-    {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"},
-]
-
-[package.extras]
-cli = ["click (>=5.0)"]
-
 [[package]]
 name = "python-ulid"
 version = "1.1.0"
@@ -4321,7 +4256,6 @@ files = [
 
 [package.dependencies]
 anyio = ">=3.4.0,<5"
-typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""}
 
 [package.extras]
 full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
@@ -4456,7 +4390,7 @@ typer = ">=0.9.0,<0.10.0"
 name = "tokenizers"
 version = "0.19.1"
 description = ""
-optional = false
+optional = true
 python-versions = ">=3.7"
 files = [
     {file = "tokenizers-0.19.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:952078130b3d101e05ecfc7fc3640282d74ed26bcf691400f872563fca15ac97"},
@@ -4854,20 +4788,6 @@ files = [
 cryptography = ">=35.0.0"
 types-pyOpenSSL = "*"
 
-[[package]]
-name = "types-requests"
-version = "2.31.0.6"
-description = "Typing stubs for requests"
-optional = true
-python-versions = ">=3.7"
-files = [
-    {file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"},
-    {file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"},
-]
-
-[package.dependencies]
-types-urllib3 = "*"
-
 [[package]]
 name = "types-requests"
 version = "2.32.0.20240523"
@@ -4904,17 +4824,6 @@ files = [
     {file = "types_tqdm-4.66.0.20240417-py3-none-any.whl", hash = "sha256:248aef1f9986b7b8c2c12b3cb4399fc17dba0a29e7e3f3f9cd704babb879383d"},
 ]
 
-[[package]]
-name = "types-urllib3"
-version = "1.26.25.14"
-description = "Typing stubs for urllib3"
-optional = true
-python-versions = "*"
-files = [
-    {file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"},
-    {file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"},
-]
-
 [[package]]
 name = "typing-extensions"
 version = "4.12.0"
@@ -4963,22 +4872,6 @@ files = [
     {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"},
 ]
 
-[[package]]
-name = "urllib3"
-version = "1.26.18"
-description = "HTTP library with thread-safe connection pooling, file post, and more."
-optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
-files = [
-    {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"},
-    {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"},
-]
-
-[package.extras]
-brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
-secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
-socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
-
 [[package]]
 name = "urllib3"
 version = "2.2.1"
@@ -5258,21 +5151,6 @@ files = [
 idna = ">=2.0"
 multidict = ">=4.0"
 
-[[package]]
-name = "zipp"
-version = "3.19.0"
-description = "Backport of pathlib-compatible object wrapper for zip files"
-optional = false
-python-versions = ">=3.8"
-files = [
-    {file = "zipp-3.19.0-py3-none-any.whl", hash = "sha256:96dc6ad62f1441bcaccef23b274ec471518daf4fbbc580341204936a5a3dddec"},
-    {file = "zipp-3.19.0.tar.gz", hash = "sha256:952df858fb3164426c976d9338d3961e8e8b3758e2e059e0f754b8c4262625ee"},
-]
-
-[package.extras]
-docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
-testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
-
 [extras]
 anthropic = ["anthropic"]
 chat = ["fastapi"]
@@ -5283,5 +5161,5 @@ groq = ["groq"]
 
 [metadata]
 lock-version = "2.0"
-python-versions = ">=3.9, <=3.12"
-content-hash = "8ce0edb923a09683941fc9f9869e3322b21de89259e2615d28ba1f62dfbb9277"
+python-versions = ">=3.10, <3.13"
+content-hash = "ccc118453c73d909b3ea1bcb9c0a1d06b02894f89a0bda6097b4c76fef3bb1a5"
diff --git a/pyproject.toml b/pyproject.toml
index 102a4161c..b439c1135 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,6 @@ together = "^0.2.4"
 pydantic = "1.10.12"
 beartype = "^0.14.0"
 langchain-openai = "^0.0.5"
-litellm = "~1.23.12"
 
 # dependency versions for extras
 fastapi = { version = "^0.109.2", optional = true }
diff --git a/sotopia/generation_utils/generate.py b/sotopia/generation_utils/generate.py
index 7e8ec5e0c..36e77d07d 100644
--- a/sotopia/generation_utils/generate.py
+++ b/sotopia/generation_utils/generate.py
@@ -1,11 +1,12 @@
 import logging
+import os
 import re
-from typing import Any, TypeVar
+from typing import TypeVar
 
 import gin
 from beartype import beartype
 from beartype.typing import Type
-from langchain.chains import LLMChain
+from langchain.chains.llm import LLMChain
 from langchain.output_parsers import PydanticOutputParser
 from langchain.prompts import (
     ChatPromptTemplate,
@@ -13,7 +14,7 @@
     PromptTemplate,
 )
 from langchain.schema import BaseOutputParser, OutputParserException
-from langchain_community.chat_models import ChatLiteLLM
+from langchain_openai import ChatOpenAI
 from pydantic import BaseModel, Field
 from rich import print
 from typing_extensions import Literal
@@ -32,51 +33,24 @@
 logging_handler = LoggingCallbackHandler("langchain")
 
 LLM_Name = Literal[
-    "togethercomputer/llama-2-7b-chat",
-    "togethercomputer/llama-2-70b-chat",
-    "togethercomputer/mpt-30b-chat",
+    "together_ai/meta-llama/Llama-2-7b-chat-hf",
+    "together_ai/meta-llama/Llama-2-70b-chat-hf",
+    "together_ai/mistralai/Mixtral-8x22B-Instruct-v0.1",
+    "together_ai/meta-llama/Llama-3-8b-chat-hf",
+    "together_ai/meta-llama/Llama-3-70b-chat-hf",
     "gpt-3.5-turbo",
     "gpt-3.5-turbo-finetuned",
     "gpt-3.5-turbo-ft-MF",
-    "text-davinci-003",
     "gpt-4",
     "gpt-4-turbo",
     "human",
     "redis",
-    "mistralai/Mixtral-8x7B-Instruct-v0.1",
-    "together_ai/togethercomputer/llama-2-7b-chat",
-    "together_ai/togethercomputer/falcon-7b-instruct",
     "groq/llama3-70b-8192",
 ]
 
 OutputType = TypeVar("OutputType", bound=object)
 
 
-class PatchedChatLiteLLM(ChatLiteLLM):
-    max_tokens: int | None = None  # type: ignore
-
-    @property
-    def _default_params(self) -> dict[str, Any]:
-        """Get the default parameters for calling OpenAI API."""
-        set_model_value = self.model
-        if self.model_name is not None:
-            set_model_value = self.model_name
-
-        params = {
-            "model": set_model_value,
-            "force_timeout": self.request_timeout,
-            "stream": self.streaming,
-            "n": self.n,
-            "temperature": self.temperature,
-            "custom_llm_provider": self.custom_llm_provider,
-            **self.model_kwargs,
-        }
-        if self.max_tokens is not None:
-            params["max_tokens"] = self.max_tokens
-
-        return params
-
-
 class EnvResponse(BaseModel):
     reasoning: str = Field(
         description="first reiterate agents' social goals and then reason about what agents say/do and whether that aligns with their goals."
@@ -326,17 +300,54 @@ def obtain_chain(
     Using langchain to sample profiles for participants
     """
     model_name = _return_fixed_model_version(model_name)
-    chat = PatchedChatLiteLLM(
-        model=model_name,
-        temperature=temperature,
-        max_retries=max_retries,
-    )
-    human_message_prompt = HumanMessagePromptTemplate(
-        prompt=PromptTemplate(template=template, input_variables=input_variables)
-    )
-    chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
-    chain = LLMChain(llm=chat, prompt=chat_prompt_template)
-    return chain
+    if "together_ai" in model_name:
+        model_name = "/".join(model_name.split("/")[1:])
+        human_message_prompt = HumanMessagePromptTemplate(
+            prompt=PromptTemplate(
+                template=template,
+                input_variables=input_variables,
+            )
+        )
+        chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
+        chat_openai = ChatOpenAI(
+            model_name=model_name,
+            temperature=temperature,
+            max_retries=max_retries,
+            openai_api_base="https://api.together.xyz/v1",
+            openai_api_key=os.environ.get("TOGETHER_API_KEY"),
+        )
+        chain = LLMChain(llm=chat_openai, prompt=chat_prompt_template)
+        return chain
+    elif "groq" in model_name:
+        model_name = "/".join(model_name.split("/")[1:])
+        human_message_prompt = HumanMessagePromptTemplate(
+            prompt=PromptTemplate(
+                template=template,
+                input_variables=input_variables,
+            )
+        )
+        chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
+        chat_openai = ChatOpenAI(
+            model_name=model_name,
+            temperature=temperature,
+            max_retries=max_retries,
+            openai_api_base="https://api.groq.com/openai/v1",
+            openai_api_key=os.environ.get("GROQ_API_KEY"),
+        )
+        chain = LLMChain(llm=chat_openai, prompt=chat_prompt_template)
+        return chain
+    else:
+        chat = ChatOpenAI(
+            model=model_name,
+            temperature=temperature,
+            max_retries=max_retries,
+        )
+        human_message_prompt = HumanMessagePromptTemplate(
+            prompt=PromptTemplate(template=template, input_variables=input_variables)
+        )
+        chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
+        chain = LLMChain(llm=chat, prompt=chat_prompt_template)
+        return chain
 
 
 @beartype