LangChain Integration #60
Conversation
I left some comments. In case it's useful, here is some code I used once when dealing with the LangChain API.
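For reference, a minimal sketch of the LangChain chat-model flow (the class and message types are LangChain's public API; the model name and prompt text below are placeholders, not the code referenced above):

```python
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage

# initialize the client once; model parameters are typically set here
client = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5)

# each request is a list of typed messages rather than a raw prompt string
messages = [
    SystemMessage(content="You are a scientific manuscript editor."),
    HumanMessage(content="Revise this paragraph for clarity: ..."),
]

# invoke() sends one request and returns a single AIMessage
response = client.invoke(messages)
print(response.content)
```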
libs/manubot_ai_editor/models.py
```python
if self.endpoint == "edits":
    # FIXME: what's the "edits" equivalent in langchain?
    client_cls = OpenAI
elif self.endpoint == "chat":
    client_cls = ChatOpenAI
else:
    client_cls = OpenAI
```
I don't think we need to take care of this anymore. Before, there were separate "completion" and "edits" endpoints, but now we only have a "chat" endpoint, I believe. Let's research a little bit, but I think we only need the `ChatOpenAI` class here.
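If that turns out to be the case, the client selection could collapse to something like this (a sketch; it assumes `model_parameters` maps cleanly onto `ChatOpenAI` keyword arguments, and `self.model_name` is a hypothetical attribute name):

```python
from langchain_openai import ChatOpenAI

# no branching on self.endpoint: with the legacy "completion" and
# "edits" endpoints gone, everything goes through the chat client
self.client = ChatOpenAI(
    model=self.model_name,    # hypothetical attribute name
    **self.model_parameters,  # assumed to be valid ChatOpenAI kwargs
)
```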
libs/manubot_ai_editor/models.py
```python
# FIXME: 'params' contains a lot of fields that we're not
# currently passing to the langchain client. i need to figure
# out where they're supposed to be given, e.g. in the client
# init or with each request.
```
What are those fields in `params`?
Looking at it again, "a lot" is an overstatement, sorry. On top of the `model_parameters` dict that gets merged into it, and aside from `prompt` (or the other variants, based on whether it's a "chat" or "edits" model), `GPT3CompletionModel.get_params()` introduces just:

- `n`: I assume this is the number of responses you want the API to generate. It seems that it's always 1, and LangChain's `invoke()` returns a single response anyway, so I assume we can ignore this one.
- `stop`: despite being `None` all the time and probably not necessary to include in `invoke()`, this one's easy to integrate, since `invoke()` takes `stop` as an argument; I'll just go ahead and add it.
- `max_tokens`: it seems this is taken at client initialization in LangChain. I'll see if there's a way to provide it for each `invoke()` call, or to change its value prior to the call.

Correct me if I'm wrong, but since `model_parameters` is already used to initialize the client and since AFAICT it's not changed after that, I don't think we need to include its contents in `invoke()`.

I'll go ahead and make the other changes, though.
If I didn't forget what the code does, the only field that should go in each request/invoke (instead of using them to initialize the client) is `max_tokens`, because for each paragraph we restrict the model to generate up to twice (or so) the number of tokens in the input paragraph. So that should go into each request, not the client (or update the client before each request?).
Right, after I made the comment above I discovered that `invoke()` does take `max_tokens` as well as `stop`; I've added it in my most recent commits. I assume we still don't need to change `n` from 1, which AFAICT is the default for `invoke()` as well, so I left that out of the call to `invoke()`.
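For the record, the resulting call is shaped roughly like this (a sketch; `paragraph_token_count` is a hypothetical stand-in for however the input length is measured):

```python
# stop and max_tokens can both be supplied per call rather than at
# client initialization; max_tokens is derived from the input paragraph
max_tokens = 2 * paragraph_token_count  # hypothetical; "twice or so" the input
response = self.client.invoke(
    prompt,
    stop=None,              # always None in our case, but passed through
    max_tokens=max_tokens,  # forwarded to the OpenAI request payload
)
```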
Nice work! I left a few comments where I thought improvements could be made. I'm less familiar with how this might operate in the context of other code in the project, so my comments might miss the mark. If there's a more specific area I should focus on, just let me know; happy to give things another look.
In addition to the individual comments, I wondered: "how does `pytest --runcost` work?" (mentioned in the PR description). Consider adding this to the documentation somewhere, perhaps in the readme or another spot.
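(In case it helps with documenting it: opt-in flags like this are usually wired up in `conftest.py`. A generic sketch of that pattern follows; it's not necessarily how this repo implements it, and the `cost` marker name is a guess.)

```python
# conftest.py -- common pytest pattern for an opt-in flag gating costly tests
import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--runcost",
        action="store_true",
        default=False,
        help="run tests that make paid calls to the OpenAI API",
    )


def pytest_collection_modifyitems(config, items):
    if config.getoption("--runcost"):
        return  # flag given: run everything, including costly tests
    skip_cost = pytest.mark.skip(reason="needs --runcost to run")
    for item in items:
        if "cost" in item.keywords:  # marker name is a guess
            item.add_marker(skip_cost)
```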
libs/manubot_ai_editor/models.py
```diff
@@ -253,6 +255,22 @@ def __init__(
         self.several_spaces_pattern = re.compile(r"\s+")
 
+        if self.endpoint == "edits":
+            # FIXME: what's the "edits" equivalent in langchain?
```
Consider moving this `FIXME` to a GitHub issue (if it's not already), which is more actionable and may be less prone to being forgotten. This comment also applies to other locations where this pattern is found.
Good point, and it's fair that adding "FIXME"s at all runs the risk of them being introduced into merged code. My intent here was to get this FIXME figured out within the scope of this PR, which is why I didn't create an issue for it, but I'll think more on not adding FIXMEs and instead communicating questions some other way (review comments, perhaps?).
libs/manubot_ai_editor/models.py
```diff
@@ -253,6 +255,22 @@ def __init__(
         self.several_spaces_pattern = re.compile(r"\s+")
 
+        if self.endpoint == "edits":
```
Consider documenting class attributes in the docstring for the class to help identify what functionality they're associated with. As I read through this, I wondered "what does `self.endpoint` do; how might it matter later?" and couldn't find much human-readable material on this topic. It could be that I'm missing fundamental common knowledge about how this works; if so, please don't hesitate to link to the appropriate location.
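For example, an `Attributes` section in the class docstring would have answered this. A sketch of what I mean (the descriptions are my guesses from reading the code):

```python
class GPT3CompletionModel:
    """Revises manuscript text using the OpenAI API.

    Attributes:
        endpoint: which OpenAI API surface to use ("chat", "edits", or
            completions); determines which client class is instantiated
            and how prompts are structured (guessed description).
        several_spaces_pattern: compiled regex used to collapse runs of
            whitespace in text (guessed description).
    """
```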
Hey, thanks for pointing this out; I've created an issue to address filling these gaps in the documentation, #68.
```diff
-    completions = openai.Completion.create(**params)
+    # map the prompt to langchain's prompt types, based on what
+    # kind of endpoint we're using
+    if "messages" in params:
```
A bit outside the PR scope, but adding as this is a fresh read of the code and I'm less familiar with how `params` are used. I noticed the docstring doesn't match the method parameters. Consider updating this when there's a chance.
I'm thinking we'll do a comprehensive review of the docstrings for the PR that addresses issue #68, but in this PR I've attempted to add some documentation to the `GPT3CompletionModel.get_params()` method to address this gap.
libs/manubot_ai_editor/models.py
```python
# based on the 'role' field
prompt = [
    HumanMessage(content=msg["content"])
    if msg["role"] == "user" else
```
This might need formatting corrections applied via Black (I tested using the existing `.pre-commit-config.yaml`).
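For comparison, Black would render that conditional expression roughly like this (the `else` branch and the iterable are guesses, since the hunk above is truncated):

```python
# hypothetical completion of the truncated hunk, as Black might format it
prompt = [
    HumanMessage(content=msg["content"])
    if msg["role"] == "user"
    else SystemMessage(content=msg["content"])
    for msg in params["messages"]
]
```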
This (work-in-progress) PR changes the `GPT3CompletionModel` from using the `openai` package directly for communicating with the OpenAI API to using `langchain-openai`, which wraps the `openai` package.

Tests have been updated and should work with LangChain. Executing `pytest --runcost` will actually query the OpenAI API, so that should be considered a good test that the changeover to the new package is working.

There are still many things missing (e.g., mapping all the `openai` params to LangChain equivalents), which is why this PR is a draft, but I thought I'd push it early and get comments as we incrementally make the change to LangChain.
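At a glance, the changeover amounts to something like this (a simplified sketch; parameter plumbing is omitted, and the model name and `prompt_text` are placeholders):

```python
# before: calling the OpenAI API directly via the openai package
import openai

completions = openai.Completion.create(**params)
text = completions.choices[0].text

# after: going through langchain-openai, which wraps the openai package
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage

client = ChatOpenAI(model="gpt-3.5-turbo")  # placeholder model name
response = client.invoke([HumanMessage(content=prompt_text)])
text = response.content
```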