Skip to content

Commit

Permalink
add with_structured_output support for Pydantic models, dicts and Enu…
Browse files Browse the repository at this point in the history
…ms (only include_raw=False)
  • Loading branch information
mattf committed Jul 24, 2024
1 parent c3745e0 commit cacc739
Show file tree
Hide file tree
Showing 8 changed files with 716 additions and 13 deletions.
113 changes: 113 additions & 0 deletions libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,119 @@
"source": [
"See [How to use chat models to call tools](https://python.langchain.com/v0.2/docs/how_to/tool_calling/) for additional examples."
]
},
{
"cell_type": "markdown",
"id": "8d249662",
"metadata": {},
"source": [
"## Structured output\n",
"\n",
"Starting in v0.2.1, `ChatNVIDIA` supports [with_structured_output](https://api.python.langchain.com/en/latest/language_models/langchain_core.language_models.chat_models.BaseChatModel.html#langchain_core.language_models.chat_models.BaseChatModel.with_structured_output).\n",
"\n",
"`ChatNVIDIA` provides integration with the variety of models on [build.nvidia.com](https://build.nvidia.com) as well as local NIMs. Not all these model endpoints implement the structured output features. Be sure to select a model that does have structured output features for your experimention and applications.\n",
"\n",
"Note: `include_raw` is not supported. You can get raw output from your LLM and use a [PydanticOutputParser](https://python.langchain.com/v0.2/docs/how_to/structured_output/#using-pydanticoutputparser) or [JsonOutputParser](https://python.langchain.com/v0.2/docs/how_to/output_parser_json/#without-pydantic)."
]
},
{
"cell_type": "markdown",
"id": "a94e0e69",
"metadata": {},
"source": [
"You can get a list of models that are known to support structured output with,"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0515f558",
"metadata": {},
"outputs": [],
"source": [
"from langchain_nvidia_ai_endpoints import ChatNVIDIA\n",
"structured_models = [model for model in ChatNVIDIA.get_available_models() if model.supports_structured_output]\n",
"structured_models"
]
},
{
"cell_type": "markdown",
"id": "21e56187",
"metadata": {},
"source": [
"### Pydantic style"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "482c37e8",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"class Person(BaseModel):\n",
" first_name: str = Field(..., description=\"The person's first name.\")\n",
" last_name: str = Field(..., description=\"The person's last name.\")\n",
"\n",
"llm = ChatNVIDIA(model=structured_models[0].id).with_structured_output(Person)\n",
"response = llm.invoke(\"Who is Michael Jeffrey Jordon?\")\n",
"response"
]
},
{
"cell_type": "markdown",
"id": "a25ce43f",
"metadata": {},
"source": [
"### Enum style"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f802912",
"metadata": {},
"outputs": [],
"source": [
"from enum import Enum\n",
"\n",
"class Choices(Enum):\n",
" A = \"A\"\n",
" B = \"B\"\n",
" C = \"C\"\n",
"\n",
"llm = ChatNVIDIA(model=structured_models[2].id).with_structured_output(Choices)\n",
"response = llm.invoke(\"\"\"\n",
" What does 1+1 equal?\n",
" A. -100\n",
" B. 2\n",
" C. doorstop\n",
" \"\"\"\n",
")\n",
"response"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02b7ef29",
"metadata": {},
"outputs": [],
"source": [
"model = structured_models[3].id\n",
"llm = ChatNVIDIA(model=model).with_structured_output(Choices)\n",
"print(model)\n",
"response = llm.invoke(\"\"\"\n",
" What does 1+1 equal?\n",
" A. -100\n",
" B. 2\n",
" C. doorstop\n",
" \"\"\"\n",
")\n",
"response"
]
}
],
"metadata": {
Expand Down
6 changes: 6 additions & 0 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Model(BaseModel):
endpoint: custom endpoint for the model
aliases: list of aliases for the model
supports_tools: whether the model supports tool calling
supports_structured_output: whether the model supports structured output
All aliases are deprecated and will trigger a warning when used.
"""
Expand All @@ -28,6 +29,7 @@ class Model(BaseModel):
endpoint: Optional[str] = None
aliases: Optional[list] = None
supports_tools: Optional[bool] = False
supports_structured_output: Optional[bool] = False
base_model: Optional[str] = None

def __hash__(self) -> int:
Expand Down Expand Up @@ -284,24 +286,28 @@ def validate_client(cls, client: str, values: dict) -> str:
id="nv-mistralai/mistral-nemo-12b-instruct",
model_type="chat",
client="ChatNVIDIA",
supports_structured_output=True,
),
"meta/llama-3.1-8b-instruct": Model(
id="meta/llama-3.1-8b-instruct",
model_type="chat",
client="ChatNVIDIA",
supports_tools=True,
supports_structured_output=True,
),
"meta/llama-3.1-70b-instruct": Model(
id="meta/llama-3.1-70b-instruct",
model_type="chat",
client="ChatNVIDIA",
supports_tools=True,
supports_structured_output=True,
),
"meta/llama-3.1-405b-instruct": Model(
id="meta/llama-3.1-405b-instruct",
model_type="chat",
client="ChatNVIDIA",
supports_tools=True,
supports_structured_output=True,
),
}

Expand Down
Loading

0 comments on commit cacc739

Please sign in to comment.