diff --git a/milvus_model/__init__.py b/milvus_model/__init__.py
index ca6d9c1..4f811c4 100644
--- a/milvus_model/__init__.py
+++ b/milvus_model/__init__.py
@@ -1,6 +1,6 @@
 from . import dense, hybrid, sparse
 from .dense.sentence_transformer import SentenceTransformerEmbeddingFunction
 
-__all__ = ["DefaultEmbeddingFunction", "dense", "sparse", "hybrid"]
+__all__ = ["DefaultEmbeddingFunction", "dense", "sparse", "hybrid", "reranker"]
 
 DefaultEmbeddingFunction = SentenceTransformerEmbeddingFunction
diff --git a/milvus_model/base.py b/milvus_model/base.py
index 52e3c9f..0d6b689 100644
--- a/milvus_model/base.py
+++ b/milvus_model/base.py
@@ -3,7 +3,6 @@
 
 
 class BaseEmbeddingFunction:
-    @abstractmethod
     def __call__(self, texts: List[str]):
         """ """
 
diff --git a/milvus_model/dense/__init__.py b/milvus_model/dense/__init__.py
index 35106a4..c586b7e 100644
--- a/milvus_model/dense/__init__.py
+++ b/milvus_model/dense/__init__.py
@@ -1,7 +1,7 @@
+from .jinaai import JinaEmbeddingFunction
 from .openai import OpenAIEmbeddingFunction
 from .sentence_transformer import SentenceTransformerEmbeddingFunction
 from .voyageai import VoyageEmbeddingFunction
-from .jinaai import JinaEmbeddingFunction
 
 __all__ = [
     "OpenAIEmbeddingFunction",
diff --git a/milvus_model/dense/jinaai.py b/milvus_model/dense/jinaai.py
index 4e3315f..c458cee 100644
--- a/milvus_model/dense/jinaai.py
+++ b/milvus_model/dense/jinaai.py
@@ -1,8 +1,8 @@
 import os
-import requests
 from typing import List, Optional
 
 import numpy as np
+import requests
 
 from milvus_model.base import BaseEmbeddingFunction
 
@@ -10,16 +10,22 @@
 
 
 class JinaEmbeddingFunction(BaseEmbeddingFunction):
-    def __init__(self, model_name: str = "jina-embeddings-v2-base-en", api_key: Optional[str] = None, **kwargs):
+    def __init__(
+        self,
+        model_name: str = "jina-embeddings-v2-base-en",
+        api_key: Optional[str] = None,
+        **kwargs,
+    ):
         if api_key is None:
-            if 'JINAAI_API_KEY' in os.environ and os.environ['JINAAI_API_KEY']:
-                self.api_key = os.environ['JINAAI_API_KEY']
+            if "JINAAI_API_KEY" in os.environ and os.environ["JINAAI_API_KEY"]:
+                self.api_key = os.environ["JINAAI_API_KEY"]
             else:
-                raise ValueError(
-                    f"Did not find api_key, please add an environment variable"
-                    f" `JINAAI_API_KEY` which contains it, or pass"
-                    f" `api_key` as a named parameter."
+                error_message = (
+                    "Did not find api_key, please add an environment variable"
+                    " `JINAAI_API_KEY` which contains it, or pass"
+                    " `api_key` as a named parameter."
                 )
+                raise ValueError(error_message)
         else:
             self.api_key = api_key
         self.model_name = model_name
@@ -46,9 +52,8 @@ def __call__(self, texts: List[str]) -> List[np.array]:
         return self._call_jina_api(texts)
 
     def _call_jina_api(self, texts: List[str]):
-        resp = self._session.post(  # type: ignore
-            API_URL,
-            json={"input": texts, "model": self.model_name},
+        resp = self._session.post(  # type: ignore[assignment]
+            API_URL, json={"input": texts, "model": self.model_name},
         ).json()
         if "data" not in resp:
             raise RuntimeError(resp["detail"])
@@ -56,6 +61,5 @@ def _call_jina_api(self, texts: List[str]):
         embeddings = resp["data"]
 
         # Sort resulting embeddings by index
-        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
-
+        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore[no-any-return]
         return [np.array(result["embedding"]) for result in sorted_embeddings]
diff --git a/milvus_model/dense/sentence_transformer.py b/milvus_model/dense/sentence_transformer.py
index 1804f8b..ffb9cee 100644
--- a/milvus_model/dense/sentence_transformer.py
+++ b/milvus_model/dense/sentence_transformer.py
@@ -35,10 +35,7 @@ def __call__(self, texts: List[str]) -> List[np.array]:
 
     def _encode(self, texts: List[str]) -> List[np.array]:
         embs = self.model.encode(
-            texts,
-            batch_size=self.batch_size,
-            show_progress_bar=False,
-            convert_to_numpy=True,
+            texts, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True,
         )
         return list(embs)
 
diff --git a/milvus_model/hybrid/bge_m3.py b/milvus_model/hybrid/bge_m3.py
index 1cc63e7..3c91328 100644
--- a/milvus_model/hybrid/bge_m3.py
+++ b/milvus_model/hybrid/bge_m3.py
@@ -83,7 +83,7 @@ def _encode(self, texts: List[str]) -> Dict:
                 row_indices = [0] * len(indices)
                 csr = csr_array((values, (row_indices, indices)), shape=(1, sparse_dim))
                 results["sparse"].append(csr)
-            results["sparse"] = vstack(results["sparse"])
+            results["sparse"] = vstack(results["sparse"]).tocsr()
         if self._encode_config["return_colbert_vecs"] is True:
             results["colbert_vecs"] = output["colbert_vecs"]
         return results
diff --git a/milvus_model/reranker/__init__.py b/milvus_model/reranker/__init__.py
index 854f771..857e2f1 100644
--- a/milvus_model/reranker/__init__.py
+++ b/milvus_model/reranker/__init__.py
@@ -1,8 +1,8 @@
 from .bgereranker import BGERerankFunction
 from .cohere import CohereRerankFunction
 from .cross_encoder import CrossEncoderRerankFunction
-from .voyageai import VoyageRerankFunction
 from .jinaai import JinaRerankFunction
+from .voyageai import VoyageRerankFunction
 
 __all__ = [
     "CohereRerankFunction",
diff --git a/milvus_model/reranker/jinaai.py b/milvus_model/reranker/jinaai.py
index 28d6d3e..4c918e6 100644
--- a/milvus_model/reranker/jinaai.py
+++ b/milvus_model/reranker/jinaai.py
@@ -1,7 +1,8 @@
 import os
-import requests
 from typing import List, Optional
 
+import requests
+
 from milvus_model.base import BaseRerankFunction, RerankResult
 
 API_URL = "https://api.jina.ai/v1/rerank"
@@ -10,14 +11,15 @@
 class JinaRerankFunction(BaseRerankFunction):
     def __init__(self, model_name: str = "jina-reranker-v1-base-en", api_key: Optional[str] = None):
         if api_key is None:
-            if 'JINAAI_API_KEY' in os.environ and os.environ['JINAAI_API_KEY']:
-                self.api_key = os.environ['JINAAI_API_KEY']
+            if "JINAAI_API_KEY" in os.environ and os.environ["JINAAI_API_KEY"]:
+                self.api_key = os.environ["JINAAI_API_KEY"]
             else:
-                raise ValueError(
-                    f"Did not find api_key, please add an environment variable"
-                    f" `JINAAI_API_KEY` which contains it, or pass"
-                    f" `api_key` as a named parameter."
+                error_message = (
+                    "Did not find api_key, please add an environment variable"
+                    " `JINAAI_API_KEY` which contains it, or pass"
+                    " `api_key` as a named parameter."
                 )
+                raise ValueError(error_message)
         else:
             self.api_key = api_key
         self.model_name = model_name
@@ -28,7 +30,7 @@ def __init__(self, model_name: str = "jina-reranker-v1-base-en", api_key: Option
         self.model_name = model_name
 
     def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
-        resp = self._session.post(  # type: ignore
+        resp = self._session.post(  # type: ignore[assignment]
             API_URL,
             json={
                 "query": query,
@@ -44,7 +46,7 @@ def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[Rer
         for res in resp["results"]:
             results.append(
                 RerankResult(
-                    text=res['document']['text'], score=res['relevance_score'], index=res['index']
+                    text=res["document"]["text"], score=res["relevance_score"], index=res["index"]
                 )
             )
         return results
diff --git a/milvus_model/sparse/bm25/bm25.py b/milvus_model/sparse/bm25/bm25.py
index 0779e13..3823b51 100644
--- a/milvus_model/sparse/bm25/bm25.py
+++ b/milvus_model/sparse/bm25/bm25.py
@@ -160,7 +160,7 @@ def _encode_document(self, doc: str) -> csr_array:
 
     def encode_queries(self, queries: List[str]) -> csr_array:
         sparse_embs = [self._encode_query(query) for query in queries]
-        return vstack(sparse_embs)
+        return vstack(sparse_embs).tocsr()
 
     def __call__(self, texts: List[str]) -> csr_array:
         error_message = "Unsupported function called, please check the documentation of 'BM25EmbeddingFunction'."
@@ -168,7 +168,7 @@ def __call__(self, texts: List[str]) -> csr_array:
 
     def encode_documents(self, documents: List[str]) -> csr_array:
         sparse_embs = [self._encode_document(document) for document in documents]
-        return vstack(sparse_embs)
+        return vstack(sparse_embs).tocsr()
 
     def save(self, path: str):
         bm25_params = {}
diff --git a/milvus_model/sparse/bm25/tokenizers.py b/milvus_model/sparse/bm25/tokenizers.py
index 9ffc4a1..fcf3f8d 100644
--- a/milvus_model/sparse/bm25/tokenizers.py
+++ b/milvus_model/sparse/bm25/tokenizers.py
@@ -5,6 +5,7 @@
 from pathlib import Path
 from typing import Any, Dict, List, Match, Optional, Type
 
+import nltk
 import yaml
 from nltk import word_tokenize
 from nltk.corpus import stopwords
@@ -73,6 +74,11 @@ def apply(self, tokens: List[str]):
 @register_class("StopwordFilter")
 class StopwordFilter(TextFilter):
     def __init__(self, language: str = "english", stopword_list: Optional[List[str]] = None):
+        try:
+            nltk.corpus.stopwords.words(language)
+        except LookupError:
+            nltk.download("stopwords")
+
         if stopword_list is None:
             stopword_list = []
         self.stopwords = set(stopwords.words(language) + stopword_list)
@@ -176,7 +182,7 @@ def build_default_analyzer(language: str = "en"):
 
 
 def build_analyer_from_yaml(filepath: str, name: str):
-    with Path(filepath).open() as file:
+    with Path(filepath).open(encoding="utf-8") as file:
         config = yaml.safe_load(file)
 
     lang_config = config.get(name)
diff --git a/milvus_model/sparse/splade.py b/milvus_model/sparse/splade.py
index 69f0d3d..8371510 100644
--- a/milvus_model/sparse/splade.py
+++ b/milvus_model/sparse/splade.py
@@ -71,8 +71,7 @@ def __call__(self, texts: List[str]) -> csr_array:
 
     def encode_documents(self, documents: List[str]) -> csr_array:
         return self._encode(
-            [self.doc_instruction + document for document in documents],
-            self.k_tokens_document,
+            [self.doc_instruction + document for document in documents], self.k_tokens_document,
         )
 
     def _encode(self, texts: List[str], k_tokens: int) -> csr_array:
@@ -80,8 +79,7 @@ def _encode(self, texts: List[str], k_tokens: int) -> csr_array:
 
     def encode_queries(self, queries: List[str]) -> csr_array:
         return self._encode(
-            [self.query_instruction + query for query in queries],
-            self.k_tokens_query,
+            [self.query_instruction + query for query in queries], self.k_tokens_query,
         )
 
     @property
@@ -130,26 +128,26 @@ def _encode(self, texts: List[str]):
             padding=True,
         )
         encoded_input = {key: val.to(self.device) for key, val in encoded_input.items()}
-        with torch.no_grad():
-            output = self.model(**encoded_input)
+        output = self.model(**encoded_input)
         return output.logits
 
     def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]:
         return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
 
     def forward(self, texts: List[str], k_tokens: int) -> csr_array:
-        batched_texts = self._batchify(texts, self.batch_size)
-        sparse_embs = []
-        for batch_texts in batched_texts:
-            logits = self._encode(texts=batch_texts)
-            activations = self._get_activation(logits=logits)
-            if k_tokens is None:
-                nonzero_indices = torch.nonzero(activations["sparse_activations"])
-                activations["activations"] = nonzero_indices
-            else:
-                activations = self._update_activations(**activations, k_tokens=k_tokens)
-            batch_csr = self._convert_to_csr_array(activations)
-            sparse_embs.extend(batch_csr)
+        with torch.no_grad():
+            batched_texts = self._batchify(texts, self.batch_size)
+            sparse_embs = []
+            for batch_texts in batched_texts:
+                logits = self._encode(texts=batch_texts)
+                activations = self._get_activation(logits=logits)
+                if k_tokens is None:
+                    nonzero_indices = torch.nonzero(activations["sparse_activations"])
+                    activations["activations"] = nonzero_indices
+                else:
+                    activations = self._update_activations(**activations, k_tokens=k_tokens)
+                batch_csr = self._convert_to_csr_array(activations)
+                sparse_embs.extend(batch_csr)
 
         return vstack(sparse_embs).tocsr()
 
diff --git a/pyproject.toml b/pyproject.toml
index 916d5f9..f7417ac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,7 @@ dependencies = [
     "konlpy",
     "mecab-python3",
     "scipy >= 1.10.0",
-    "protobuf==3.20.0",
+    "protobuf==3.20.2",
     "unidic-lite",
     "cohere",
     "voyageai >= 0.2.0",