Skip to content

Commit

Permalink
mark all AI Foundation Models as deprecated, use recommended alternat…
Browse files Browse the repository at this point in the history
…ives instead
  • Loading branch information
mattf committed Apr 15, 2024
1 parent 9b81a30 commit b73bb87
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 13 deletions.
72 changes: 60 additions & 12 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,80 @@ class Model(BaseModel):
MODEL_SPECS = {
"playground_smaug_72b": {"model_type": "chat", "api_type": "aifm"},
"playground_kosmos_2": {"model_type": "image_in", "api_type": "aifm"},
"playground_llama2_70b": {"model_type": "chat", "api_type": "aifm"},
"playground_llama2_70b": {
"model_type": "chat",
"api_type": "aifm",
"alternative": "ai-llama2-70b",
},
"playground_nvolveqa_40k": {"model_type": "embedding", "api_type": "aifm"},
"playground_nemotron_qa_8b": {"model_type": "qa", "api_type": "aifm"},
"playground_gemma_7b": {"model_type": "chat", "api_type": "aifm"},
"playground_mistral_7b": {"model_type": "chat", "api_type": "aifm"},
"playground_gemma_7b": {
"model_type": "chat",
"api_type": "aifm",
"alternative": "ai-gemma-7b",
},
"playground_mistral_7b": {
"model_type": "chat",
"api_type": "aifm",
"alternative": "ai-mistral-7b-instruct-v2",
},
"playground_mamba_chat": {"model_type": "chat", "api_type": "aifm"},
"playground_phi2": {"model_type": "chat", "api_type": "aifm"},
"playground_sdxl": {"model_type": "image_out", "api_type": "aifm"},
"playground_nv_llama2_rlhf_70b": {"model_type": "chat", "api_type": "aifm"},
"playground_neva_22b": {"model_type": "image_in", "api_type": "aifm"},
"playground_neva_22b": {
"model_type": "image_in",
"api_type": "aifm",
"alternative": "ai-neva-22b",
},
"playground_yi_34b": {"model_type": "chat", "api_type": "aifm"},
"playground_nemotron_steerlm_8b": {"model_type": "chat", "api_type": "aifm"},
"playground_cuopt": {"model_type": "cuopt", "api_type": "aifm"},
"playground_llama_guard": {"model_type": "classifier", "api_type": "aifm"},
"playground_starcoder2_15b": {"model_type": "completion", "api_type": "aifm"},
"playground_deplot": {"model_type": "image_in", "api_type": "aifm"},
"playground_llama2_code_70b": {"model_type": "chat", "api_type": "aifm"},
"playground_gemma_2b": {"model_type": "chat", "api_type": "aifm"},
"playground_deplot": {
"model_type": "image_in",
"api_type": "aifm",
"alternative": "ai-google-deplot",
},
"playground_llama2_code_70b": {
"model_type": "chat",
"api_type": "aifm",
"alternative": "ai-codellama-70b",
},
"playground_gemma_2b": {
"model_type": "chat",
"api_type": "aifm",
"alternative": "ai-gemma-2b",
},
"playground_seamless": {"model_type": "translation", "api_type": "aifm"},
"playground_mixtral_8x7b": {"model_type": "chat", "api_type": "aifm"},
"playground_fuyu_8b": {"model_type": "image_in", "api_type": "aifm"},
"playground_llama2_code_34b": {"model_type": "chat", "api_type": "aifm"},
"playground_llama2_code_13b": {"model_type": "chat", "api_type": "aifm"},
"playground_mixtral_8x7b": {
"model_type": "chat",
"api_type": "aifm",
"alternative": "ai-mixtral-8x7b-instruct",
},
"playground_fuyu_8b": {
"model_type": "image_in",
"api_type": "aifm",
"alternative": "ai-fuyu-8b",
},
"playground_llama2_code_34b": {
"model_type": "chat",
"api_type": "aifm",
"alternative": "ai-codellama-70b",
},
"playground_llama2_code_13b": {
"model_type": "chat",
"api_type": "aifm",
"alternative": "ai-codellama-70b",
},
"playground_steerlm_llama_70b": {"model_type": "chat", "api_type": "aifm"},
"playground_clip": {"model_type": "similarity", "api_type": "aifm"},
"playground_llama2_13b": {"model_type": "chat", "api_type": "aifm"},
"playground_llama2_13b": {
"model_type": "chat",
"api_type": "aifm",
"alternative": "ai-llama2-70b",
},
}

MODEL_SPECS.update(
Expand Down
18 changes: 17 additions & 1 deletion libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import sys
import urllib.parse
import warnings
from typing import (
Any,
AsyncIterator,
Expand Down Expand Up @@ -39,12 +40,13 @@
ChatGenerationChunk,
ChatResult,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_core.runnables import Runnable
from langchain_core.runnables.config import run_in_executor
from langchain_core.tools import BaseTool

from langchain_nvidia_ai_endpoints import _common as nvidia_ai_endpoints
from langchain_nvidia_ai_endpoints._statics import MODEL_SPECS

_CallbackManager = Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
_DictOrPydanticClass = Union[Dict[str, Any], Type[BaseModel]]
Expand Down Expand Up @@ -140,6 +142,20 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, BaseChatModel):
labels: Optional[Dict[str, float]] = Field(description="Steering parameters")
streaming: bool = Field(True)

@validator("model")
def aifm_deprecated(cls, value: str) -> str:
"""All AI Foundataion Models are deprecate, use API Catalog models instead."""
for model in [value, f"playground_{value}"]:
if model in MODEL_SPECS and MODEL_SPECS[model].get("api_type") == "aifm":
alternative = MODEL_SPECS[model].get(
"alternative", ChatNVIDIA._default_model
)
warnings.warn(
f"{value} is deprecated. Try {alternative} instead.",
DeprecationWarning,
)
return value

@property
def _llm_type(self) -> str:
"""Return type of NVIDIA AI Foundation Model Interface."""
Expand Down
25 changes: 25 additions & 0 deletions libs/ai-endpoints/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""Test chat model integration."""


import warnings

import pytest

from langchain_nvidia_ai_endpoints._statics import MODEL_SPECS
from langchain_nvidia_ai_endpoints.chat_models import ChatNVIDIA


Expand All @@ -14,3 +19,23 @@ def test_integration_initialization() -> None:
max_tokens=50,
)
ChatNVIDIA(model="mistral", nvidia_api_key="nvapi-...")


@pytest.mark.parametrize(
"model",
[
name
for pair in [
(model, model.replace("playground_", ""))
for model, config in MODEL_SPECS.items()
if "api_type" in config and config["api_type"] == "aifm"
]
for name in pair
],
)
def test_aifm_deprecated(model: str) -> None:
with warnings.catch_warnings():
warnings.simplefilter("error")
ChatNVIDIA()
with pytest.deprecated_call():
ChatNVIDIA(model=model)

0 comments on commit b73bb87

Please sign in to comment.