Commit

Made it easier to access LLM names

whitead committed Jan 16, 2024
1 parent 39ba98c commit 59ed8d3
Showing 4 changed files with 50 additions and 8 deletions.
9 changes: 4 additions & 5 deletions README.md
@@ -100,25 +100,24 @@ docs = Docs(llm='gpt-3.5-turbo')
 or you can use any other model available in [langchain](https://github.com/hwchase17/langchain):

 ```py
-from paperqa import Docs, LangchainLLMModel
+from paperqa import Docs
 from langchain_community.chat_models import ChatAnthropic
-docs = Docs(llm_model=LangchainLLMModel(),
+docs = Docs(llm="langchain",
             client=ChatAnthropic())
 ```

-Note we split the model into `LangchainLLMModel` (always empty) and `client` which is `ChatAnthropic`. This is because `client` stores the non-pickleable part and langchain LLMs are only sometimes serializable/pickleable. The paper-qa `Docs` must always serializable. Thus, we split the model into two parts.
+Note we split the model into the wrapper and `client`, which is `ChatAnthropic` here. This is because `client` stores the non-pickleable part: langchain LLMs are only sometimes serializable/pickleable, but the paper-qa `Docs` object must always be serializable. Thus, we split the model into two parts.

 ```py
 import pickle
-docs = Docs(llm_model=LangchainLLMModel(),
+docs = Docs(llm="langchain",
             client=ChatAnthropic())
 model_str = pickle.dumps(docs)
 docs = pickle.loads(model_str)
 # but you have to set the client after loading
 docs.set_client(ChatAnthropic())
 ```

-
 #### Locally Hosted

 You can use llama.cpp as the LLM. Note that you should use relatively large models, because paper-qa requires following a lot of instructions. You won't get good performance with 7B models.
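For example, here is a minimal sketch of pointing paper-qa at a llama.cpp server. It assumes llama.cpp is serving its OpenAI-compatible API on localhost:8080, and the model name, server command, and `OpenAILLMModel` import are illustrative rather than taken from this commit:

```py
from openai import AsyncOpenAI
from paperqa import Docs, OpenAILLMModel

# assumes llama.cpp was started with its OpenAI-compatible server,
# e.g. `./server -m my-model.gguf --port 8080` (illustrative command)
local_client = AsyncOpenAI(
    base_url="http://localhost:8080/v1",
    api_key="sk-no-key-required",  # llama.cpp ignores the key
)
docs = Docs(
    client=local_client,
    llm_model=OpenAILLMModel(config=dict(model="my-llm-model", temperature=0.1)),
)
```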
24 changes: 23 additions & 1 deletion paperqa/docs.py
@@ -103,6 +103,10 @@ def __init__(self, **data):
         super().__init__(**data)
         self._client = client
         self._embedding_client = embedding_client
+        # run this here (instead of automatically) so it has access to privates
+        # If I ever figure out a better way of validating privates
+        # I can move this back to the decorator
+        Docs.make_llm_names_consistent(self)

     @model_validator(mode="before")
     @classmethod
@@ -136,7 +140,6 @@ def setup_alias_models(cls, data: Any) -> Any:
             raise ValueError(
                 f"Could not guess embedding model type for {data['embedding']}. "
             )
-
         return data

     @model_validator(mode="after")
@@ -157,6 +160,24 @@ def config_summary_llm_config(cls, data: Any) -> Any:
             data.summary_llm_model = data.llm_model
         return data

+    @classmethod
+    def make_llm_names_consistent(cls, data: Any) -> Any:
+        if isinstance(data, Docs):
+            data.llm = data.llm_model.name
+            if data.llm == "langchain":
+                # from langchain models - kind of hacky
+                # langchain models cannot know type until
+                # it sees client
+                data.llm_model.infer_llm_type(data._client)
+                data.llm = data.llm_model.name
+            if data.summary_llm_model is not None:
+                if data.summary_llm == "langchain":
+                    # from langchain models - kind of hacky
+                    data.summary_llm_model.infer_llm_type(data._client)
+                data.summary_llm = data.summary_llm_model.name
+
+        return data
+
     def clear_docs(self):
         self.texts = []
         self.docs = {}
@@ -193,6 +214,7 @@ def set_client(
         else:
             embedding_client = AsyncOpenAI()
         self._embedding_client = embedding_client
+        Docs.make_llm_names_consistent(self)

     def _get_unique_name(self, docname: str) -> str:
         """Create a unique name given proposed name"""
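Taken together, these docs.py changes mean `Docs.llm` and `Docs.summary_llm` always mirror the underlying model's name once a client is known, since `make_llm_names_consistent` runs both in `__init__` and in `set_client`. A minimal sketch of the resulting behavior, mirroring `test_langchain_llm` below:

```py
from paperqa import Docs
from langchain_openai import ChatOpenAI

docs = Docs(llm="langchain", client=ChatOpenAI(model="gpt-3.5-turbo"))
# the alias is rewritten from "langchain" to the wrapped client's model name
assert docs.llm == "gpt-3.5-turbo"
assert docs.summary_llm == "gpt-3.5-turbo"
```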
12 changes: 12 additions & 0 deletions paperqa/llms.py
@@ -100,6 +100,7 @@ async def embed_documents(self, client: Any, texts: list[str]) -> list[list[floa

 class LLMModel(ABC, BaseModel):
     llm_type: str | None = None
+    name: str
     model_config = ConfigDict(extra="forbid")

     async def acomplete(self, client: Any, prompt: str) -> str:
@@ -208,6 +209,7 @@ async def execute(

 class OpenAILLMModel(LLMModel):
     config: dict = Field(default=dict(model="gpt-3.5-turbo", temperature=0.1))
+    name: str = "gpt-3.5-turbo"

     def _check_client(self, client: Any) -> AsyncOpenAI:
         if client is None:
@@ -227,6 +229,13 @@ def guess_llm_type(cls, data: Any) -> Any:
         m.llm_type = guess_model_type(m.config["model"])
         return m

+    @model_validator(mode="after")
+    @classmethod
+    def set_model_name(cls, data: Any) -> Any:
+        m = cast(OpenAILLMModel, data)
+        m.name = m.config["model"]
+        return m
+
     async def acomplete(self, client: Any, prompt: str) -> str:
         aclient = self._check_client(client)
         completion = await aclient.completions.create(
@@ -428,9 +437,12 @@ async def similarity_search(
 class LangchainLLMModel(LLMModel):
     """A wrapper around the wrapper langchain"""

+    name: str = "langchain"
+
     def infer_llm_type(self, client: Any) -> str:
         from langchain_core.language_models.chat_models import BaseChatModel

+        self.name = client.model_name
         if isinstance(client, BaseChatModel):
             return "chat"
         return "completion"
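Because `name` is now a required field on `LLMModel`, custom subclasses must declare it, as the updated tests below do. A minimal sketch, assuming `LLMModel` is importable from `paperqa.llms` (the `name` value and class are arbitrary):

```py
from paperqa.llms import LLMModel

class EchoLLM(LLMModel):
    name: str = "echo-llm"  # required now that LLMModel declares `name: str`

    async def acomplete(self, client, prompt):
        # a stub completion; a real model would call out to `client` here
        return "Echo"
```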
13 changes: 11 additions & 2 deletions tests/test_paperqa.py
@@ -406,14 +406,15 @@ def accum(x):


 def test_docs():
-    llm_config = dict(temperature=0.1, model="text-ada-001", model_type="completion")
-    docs = Docs(llm_model=OpenAILLMModel(config=llm_config))
+    docs = Docs(llm="babbage-002")
     docs.add_url(
         "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day",
         citation="WikiMedia Foundation, 2023, Accessed now",
         dockey="test",
     )
     assert docs.docs["test"].docname == "Wiki2023"
+    assert docs.llm == "babbage-002"
+    assert docs.summary_llm == "babbage-002"


 def test_evidence():
@@ -486,6 +487,8 @@ async def embed_documents(self, client, texts):

 def test_custom_llm():
     class MyLLM(LLMModel):
+        name: str = "myllm"
+
         async def acomplete(self, client, prompt):
             assert client is None
             return "Echo"
@@ -502,6 +505,8 @@ async def acomplete(self, client, prompt):

 def test_custom_llm_stream():
     class MyLLM(LLMModel):
+        name: str = "myllm"
+
         async def acomplete_iter(self, client, prompt):
             assert client is None
             yield "Echo"
@@ -522,6 +527,8 @@ def test_langchain_llm():
     from langchain_openai import ChatOpenAI, OpenAI

     docs = Docs(llm="langchain", client=ChatOpenAI(model="gpt-3.5-turbo"))
+    assert docs.llm == "gpt-3.5-turbo"
+    assert docs.summary_llm == "gpt-3.5-turbo"
     docs.add_url(
         "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
         citation="WikiMedia Foundation, 2023, Accessed now",
@@ -567,7 +574,9 @@ def test_langchain_llm():
     docs_pickle = pickle.dumps(docs)
     docs2 = pickle.loads(docs_pickle)
     assert docs2._client is None
+    assert docs2.llm == "babbage-002"
     docs2.set_client(OpenAI(model="babbage-002"))
+    assert docs2.summary_llm == "babbage-002"
     docs2.get_evidence(
         Answer(question="What is Frederick Bates's greatest accomplishment?"),
         get_callbacks=lambda x: [lambda y: print(y)],