
Commit 6662c10
support for py3.8, use List for typing instead of list
mattf committed Apr 17, 2024
1 parent da4824c commit 6662c10
Showing 2 changed files with 15 additions and 14 deletions.
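
Background for the change (not part of the diff): on Python 3.8, built-in container types cannot be subscripted in annotations that get evaluated, so list[Model] raises TypeError: 'type' object is not subscriptable, while typing.List[Model] works on 3.8 and later. A minimal sketch of the difference, using only the standard library:

    # Illustrative only; not part of the commit.
    from typing import List

    def rank_scores(scores: List[float]) -> List[float]:  # typing.List is valid on Python 3.8+
        return sorted(scores, reverse=True)

    # The same signature with built-in generics is valid on 3.9+ only:
    #     def rank_scores(scores: list[float]) -> list[float]: ...
    # On 3.8, evaluating list[float] raises TypeError at definition time.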
libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py (6 changes: 3 additions & 3 deletions)
@@ -41,7 +41,7 @@ def __init__(self, **kwargs: Any):
         self._client = _NVIDIAClient(model=self.model)
 
     @property
-    def available_models(self) -> list[Model]:
+    def available_models(self) -> List[Model]:
         """
         Get a list of available models that work with ChatNVIDIA.
         """
@@ -115,7 +115,7 @@ def mode(
         return self
 
     # todo: batching when len(documents) > endpoint's max batch size
-    def _rank(self, documents: list[str], query: str) -> List[Ranking]:
+    def _rank(self, documents: List[str], query: str) -> List[Ranking]:
         response = self._client.client.get_req(
             model_name=self.model,
             payload={
@@ -152,7 +152,7 @@ def compress_documents(
         if len(documents) == 0 or self.top_n < 1:
             return []
 
-        def batch(ls: list, size: int) -> Generator[list[Document], None, None]:
+        def batch(ls: list, size: int) -> Generator[List[Document], None, None]:
             for i in range(0, len(ls), size):
                 yield ls[i : i + size]
 
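Aside (not part of the diff): the nested batch helper touched in the last hunk above simply yields consecutive slices of at most size items. A self-contained sketch of the same pattern, written with the py3.8-compatible List spelling (List[str] here instead of List[Document] purely to keep the example dependency-free):

    # Standalone sketch of the slicing/batching pattern used in compress_documents.
    from typing import Generator, List

    def batch(ls: List[str], size: int) -> Generator[List[str], None, None]:
        # Yield consecutive chunks of at most `size` items.
        for i in range(0, len(ls), size):
            yield ls[i : i + size]

    assert list(batch(["a", "b", "c", "d", "e"], 2)) == [["a", "b"], ["c", "d"], ["e"]]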
libs/ai-endpoints/tests/integration_tests/test_ranking.py (23 changes: 12 additions & 11 deletions)
@@ -1,4 +1,5 @@
 import os
+from typing import List
 
 import faker
 import pytest
@@ -12,7 +13,7 @@ class CharacterTextSplitter:
     def __init__(self, chunk_size: int):
         self.chunk_size = chunk_size
 
-    def create_documents(self, text: str) -> list[Document]:
+    def create_documents(self, text: str) -> List[Document]:
         words = text.split()
         chunks = []
         for i in range(0, len(words), self.chunk_size):
@@ -39,12 +40,12 @@ def splitter() -> CharacterTextSplitter:
 
 
 @pytest.fixture
-def documents(text: str, splitter: CharacterTextSplitter) -> list[Document]:
+def documents(text: str, splitter: CharacterTextSplitter) -> List[Document]:
     return splitter.create_documents(text)
 
 
 def test_langchain_reranker_direct(
-    query: str, documents: list[Document], rerank_model: str, mode: dict
+    query: str, documents: List[Document], rerank_model: str, mode: dict
 ) -> None:
     ranker = NVIDIARerank(model=rerank_model).mode(**mode)
     result_docs = ranker.compress_documents(documents=documents, query=query)
@@ -64,7 +65,7 @@ def test_langchain_reranker_direct_empty_docs(
 
 
 def test_langchain_reranker_direct_top_n_negative(
-    query: str, documents: list[Document], rerank_model: str, mode: dict
+    query: str, documents: List[Document], rerank_model: str, mode: dict
 ) -> None:
     orig = NVIDIARerank.Config.validate_assignment
     NVIDIARerank.Config.validate_assignment = False
@@ -76,7 +77,7 @@ def test_langchain_reranker_direct_top_n_negative(
 
 
 def test_langchain_reranker_direct_top_n_zero(
-    query: str, documents: list[Document], rerank_model: str, mode: dict
+    query: str, documents: List[Document], rerank_model: str, mode: dict
 ) -> None:
     ranker = NVIDIARerank(model=rerank_model).mode(**mode)
     ranker.top_n = 0
@@ -85,7 +86,7 @@ def test_langchain_reranker_direct_top_n_zero(
 
 
 def test_langchain_reranker_direct_top_n_one(
-    query: str, documents: list[Document], rerank_model: str, mode: dict
+    query: str, documents: List[Document], rerank_model: str, mode: dict
 ) -> None:
     ranker = NVIDIARerank(model=rerank_model).mode(**mode)
     ranker.top_n = 1
@@ -94,7 +95,7 @@ def test_langchain_reranker_direct_top_n_one(
 
 
 def test_langchain_reranker_direct_top_n_equal_len_docs(
-    query: str, documents: list[Document], rerank_model: str, mode: dict
+    query: str, documents: List[Document], rerank_model: str, mode: dict
 ) -> None:
     ranker = NVIDIARerank(model=rerank_model).mode(**mode)
     ranker.top_n = len(documents)
@@ -103,7 +104,7 @@ def test_langchain_reranker_direct_top_n_equal_len_docs(
 
 
 def test_langchain_reranker_direct_top_n_greater_len_docs(
-    query: str, documents: list[Document], rerank_model: str, mode: dict
+    query: str, documents: List[Document], rerank_model: str, mode: dict
 ) -> None:
     ranker = NVIDIARerank(model=rerank_model).mode(**mode)
     ranker.top_n = len(documents) * 2
@@ -139,7 +140,7 @@ def test_rerank_invalid_top_n(rerank_model: str, mode: dict) -> None:
 )
 def test_rerank_batching(
     query: str,
-    documents: list[Document],
+    documents: List[Document],
     rerank_model: str,
     mode: dict,
     batch_size: int,
@@ -182,15 +183,15 @@ def test_rerank_batching(
 
 
 def test_langchain_reranker_direct_endpoint_bogus(
-    query: str, documents: list[Document]
+    query: str, documents: List[Document]
 ) -> None:
     ranker = NVIDIARerank().mode(mode="nim", base_url="bogus")
     with pytest.raises(MissingSchema):
         ranker.compress_documents(documents=documents, query=query)
 
 
 def test_langchain_reranker_direct_endpoint_unavailable(
-    query: str, documents: list[Document]
+    query: str, documents: List[Document]
 ) -> None:
     ranker = NVIDIARerank().mode(mode="nim", base_url="http://localhost:12321")
     with pytest.raises(ConnectionError):
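Aside (not part of the diff): the test-only CharacterTextSplitter shown earlier is truncated by the collapsed diff. One plausible completion of create_documents, assuming each chunk of words becomes a Document, is sketched below purely for orientation; the real body is not visible here.

    # Hypothetical completion of the test helper; the actual loop body is collapsed in the diff above.
    from typing import List
    from langchain_core.documents import Document

    class CharacterTextSplitter:
        def __init__(self, chunk_size: int):
            self.chunk_size = chunk_size

        def create_documents(self, text: str) -> List[Document]:
            words = text.split()
            chunks = []
            for i in range(0, len(words), self.chunk_size):
                # Assumed: one Document per chunk of `chunk_size` words.
                chunks.append(Document(page_content=" ".join(words[i : i + self.chunk_size])))
            return chunks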
