Skip to content

Commit

Permalink
langchain[patch]: infer mistral provider in init_chat_model (#26557)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Sep 17, 2024
1 parent 31f61d4 commit d8952b8
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions libs/langchain/langchain/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def init_chat_model(
- google_vertexai (langchain-google-vertexai)
- google_genai (langchain-google-genai)
- bedrock (langchain-aws)
- bedrock_converse (langchain-aws)
- cohere (langchain-cohere)
- fireworks (langchain-fireworks)
- together (langchain-together)
Expand All @@ -120,12 +121,13 @@ def init_chat_model(
Will attempt to infer model_provider from model if not specified. The
following providers will be inferred based on these model prefixes:
- gpt-3... or gpt-4... -> openai
- gpt-3..., gpt-4..., or o1... -> openai
- claude... -> anthropic
- amazon.... -> bedrock
- gemini... -> google_vertexai
- command... -> cohere
- accounts/fireworks... -> fireworks
- mistral... -> mistralai
configurable_fields: Which model parameters are
configurable:
Expand Down Expand Up @@ -276,8 +278,13 @@ class GetPopulation(BaseModel):
.. versionchanged:: 0.2.12
Support for Ollama via langchain-ollama package added. Previously
langchain-community version of Ollama (now deprecated) was installed by default.
Support for ChatOllama via langchain-ollama package added
(langchain_ollama.ChatOllama). Previously,
the now-deprecated langchain-community version of Ollama was imported
(langchain_community.chat_models.ChatOllama).
Support for langchain_aws.ChatBedrockConverse added
(model_provider="bedrock_converse").
""" # noqa: E501
if not model and not configurable_fields:
Expand Down Expand Up @@ -424,7 +431,7 @@ def _init_chat_model_helper(


def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
if model_name.startswith("gpt-3") or model_name.startswith("gpt-4"):
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1")):
return "openai"
elif model_name.startswith("claude"):
return "anthropic"
Expand All @@ -436,6 +443,8 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
return "google_vertexai"
elif model_name.startswith("amazon."):
return "bedrock"
elif model_name.startswith("mistral"):
return "mistralai"
else:
return None

Expand Down

0 comments on commit d8952b8

Please sign in to comment.