Skip to content

Commit

Permalink
mark nvolveqa_40k as deprecated, use ai-embed-qa-4 instead
Browse files Browse the repository at this point in the history
  • Loading branch information
mattf committed Apr 15, 2024
1 parent 7682885 commit 9b81a30
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
14 changes: 13 additions & 1 deletion libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Embeddings Components Derived from NVEModel/Embeddings"""

import warnings
from typing import List, Literal, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.pydantic_v1 import Field
from langchain_core.pydantic_v1 import Field, validator

from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints.callbacks import usage_callback_var
Expand All @@ -25,6 +26,17 @@ class NVIDIAEmbeddings(_NVIDIAClient, Embeddings):
None, description="The type of text to be embedded."
)

# todo: fix _NVIDIAClient.validate_client and enable Config.validate_assignment
@validator("model")
def deprecated_nvolveqa_40k(cls, value: str) -> str:
"""Deprecate the nvolveqa_40k model."""
if value == "nvolveqa_40k" or value == "playground_nvolveqa_40k":
warnings.warn(
"nvolveqa_40k is deprecated. Use ai-embed-qa-4 instead.",
DeprecationWarning,
)
return value

def _embed(
self, texts: List[str], model_type: Literal["passage", "query"]
) -> List[List[float]]:
Expand Down
11 changes: 11 additions & 0 deletions libs/ai-endpoints/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Generator

import pytest
Expand Down Expand Up @@ -82,5 +83,15 @@ def test_embed_documents_negative_input_list_mixed(embedding: NVIDIAEmbeddings)
embedding.embed_documents(documents) # type: ignore


def test_embed_deprecated_nvolvqa_40k() -> None:
with warnings.catch_warnings():
warnings.simplefilter("error")
NVIDIAEmbeddings()
with pytest.deprecated_call():
NVIDIAEmbeddings(model="nvolveqa_40k")
with pytest.deprecated_call():
NVIDIAEmbeddings(model="playground_nvolveqa_40k")


# todo: test max_length (-100, 0, 100)
# todo: test max_batch_size (-50, 0, 1, 50)

0 comments on commit 9b81a30

Please sign in to comment.