Skip to content

Commit

Permalink
fix: Bug fixes for existing issues (#18)
Browse files Browse the repository at this point in the history
Signed-off-by: wxywb <[email protected]>
  • Loading branch information
wxywb authored May 10, 2024
1 parent f7d43b2 commit 7cd6101
Show file tree
Hide file tree
Showing 12 changed files with 59 additions and 53 deletions.
2 changes: 1 addition & 1 deletion milvus_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from . import dense, hybrid, sparse
from .dense.sentence_transformer import SentenceTransformerEmbeddingFunction

__all__ = ["DefaultEmbeddingFunction", "dense", "sparse", "hybrid"]
__all__ = ["DefaultEmbeddingFunction", "dense", "sparse", "hybrid", "reranker"]

DefaultEmbeddingFunction = SentenceTransformerEmbeddingFunction
1 change: 0 additions & 1 deletion milvus_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@


class BaseEmbeddingFunction:

@abstractmethod
def __call__(self, texts: List[str]):
""" """
Expand Down
2 changes: 1 addition & 1 deletion milvus_model/dense/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .jinaai import JinaEmbeddingFunction
from .openai import OpenAIEmbeddingFunction
from .sentence_transformer import SentenceTransformerEmbeddingFunction
from .voyageai import VoyageEmbeddingFunction
from .jinaai import JinaEmbeddingFunction

__all__ = [
"OpenAIEmbeddingFunction",
Expand Down
30 changes: 17 additions & 13 deletions milvus_model/dense/jinaai.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
import os
import requests
from typing import List, Optional

import numpy as np
import requests

from milvus_model.base import BaseEmbeddingFunction

API_URL = "https://api.jina.ai/v1/embeddings"


class JinaEmbeddingFunction(BaseEmbeddingFunction):
def __init__(self, model_name: str = "jina-embeddings-v2-base-en", api_key: Optional[str] = None, **kwargs):
def __init__(
self,
model_name: str = "jina-embeddings-v2-base-en",
api_key: Optional[str] = None,
**kwargs,
):
if api_key is None:
if 'JINAAI_API_KEY' in os.environ and os.environ['JINAAI_API_KEY']:
self.api_key = os.environ['JINAAI_API_KEY']
if "JINAAI_API_KEY" in os.environ and os.environ["JINAAI_API_KEY"]:
self.api_key = os.environ["JINAAI_API_KEY"]
else:
raise ValueError(
f"Did not find api_key, please add an environment variable"
f" `JINAAI_API_KEY` which contains it, or pass"
f" `api_key` as a named parameter."
error_message = (
"Did not find api_key, please add an environment variable"
" `JINAAI_API_KEY` which contains it, or pass"
" `api_key` as a named parameter."
)
raise ValueError(error_message)
else:
self.api_key = api_key
self.model_name = model_name
Expand All @@ -46,16 +52,14 @@ def __call__(self, texts: List[str]) -> List[np.array]:
return self._call_jina_api(texts)

def _call_jina_api(self, texts: List[str]):
resp = self._session.post( # type: ignore
API_URL,
json={"input": texts, "model": self.model_name},
resp = self._session.post( # type: ignore[assignment]
API_URL, json={"input": texts, "model": self.model_name},
).json()
if "data" not in resp:
raise RuntimeError(resp["detail"])

embeddings = resp["data"]

# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore

sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore[no-any-return]
return [np.array(result["embedding"]) for result in sorted_embeddings]
5 changes: 1 addition & 4 deletions milvus_model/dense/sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ def __call__(self, texts: List[str]) -> List[np.array]:

def _encode(self, texts: List[str]) -> List[np.array]:
embs = self.model.encode(
texts,
batch_size=self.batch_size,
show_progress_bar=False,
convert_to_numpy=True,
texts, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True,
)
return list(embs)

Expand Down
2 changes: 1 addition & 1 deletion milvus_model/hybrid/bge_m3.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _encode(self, texts: List[str]) -> Dict:
row_indices = [0] * len(indices)
csr = csr_array((values, (row_indices, indices)), shape=(1, sparse_dim))
results["sparse"].append(csr)
results["sparse"] = vstack(results["sparse"])
results["sparse"] = vstack(results["sparse"]).tocsr()
if self._encode_config["return_colbert_vecs"] is True:
results["colbert_vecs"] = output["colbert_vecs"]
return results
Expand Down
2 changes: 1 addition & 1 deletion milvus_model/reranker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .bgereranker import BGERerankFunction
from .cohere import CohereRerankFunction
from .cross_encoder import CrossEncoderRerankFunction
from .voyageai import VoyageRerankFunction
from .jinaai import JinaRerankFunction
from .voyageai import VoyageRerankFunction

__all__ = [
"CohereRerankFunction",
Expand Down
20 changes: 11 additions & 9 deletions milvus_model/reranker/jinaai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import requests
from typing import List, Optional

import requests

from milvus_model.base import BaseRerankFunction, RerankResult

API_URL = "https://api.jina.ai/v1/rerank"
Expand All @@ -10,14 +11,15 @@
class JinaRerankFunction(BaseRerankFunction):
def __init__(self, model_name: str = "jina-reranker-v1-base-en", api_key: Optional[str] = None):
if api_key is None:
if 'JINAAI_API_KEY' in os.environ and os.environ['JINAAI_API_KEY']:
self.api_key = os.environ['JINAAI_API_KEY']
if "JINAAI_API_KEY" in os.environ and os.environ["JINAAI_API_KEY"]:
self.api_key = os.environ["JINAAI_API_KEY"]
else:
raise ValueError(
f"Did not find api_key, please add an environment variable"
f" `JINAAI_API_KEY` which contains it, or pass"
f" `api_key` as a named parameter."
error_message = (
"Did not find api_key, please add an environment variable"
" `JINAAI_API_KEY` which contains it, or pass"
" `api_key` as a named parameter."
)
raise ValueError(error_message)
else:
self.api_key = api_key
self.model_name = model_name
Expand All @@ -28,7 +30,7 @@ def __init__(self, model_name: str = "jina-reranker-v1-base-en", api_key: Option
self.model_name = model_name

def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
resp = self._session.post( # type: ignore
resp = self._session.post( # type: ignore[assignment]
API_URL,
json={
"query": query,
Expand All @@ -44,7 +46,7 @@ def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[Rer
for res in resp["results"]:
results.append(
RerankResult(
text=res['document']['text'], score=res['relevance_score'], index=res['index']
text=res["document"]["text"], score=res["relevance_score"], index=res["index"]
)
)
return results
4 changes: 2 additions & 2 deletions milvus_model/sparse/bm25/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,15 @@ def _encode_document(self, doc: str) -> csr_array:

def encode_queries(self, queries: List[str]) -> csr_array:
sparse_embs = [self._encode_query(query) for query in queries]
return vstack(sparse_embs)
return vstack(sparse_embs).tocsr()

def __call__(self, texts: List[str]) -> csr_array:
error_message = "Unsupported function called, please check the documentation of 'BM25EmbeddingFunction'."
raise ValueError(error_message)

def encode_documents(self, documents: List[str]) -> csr_array:
sparse_embs = [self._encode_document(document) for document in documents]
return vstack(sparse_embs)
return vstack(sparse_embs).tocsr()

def save(self, path: str):
bm25_params = {}
Expand Down
8 changes: 7 additions & 1 deletion milvus_model/sparse/bm25/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Any, Dict, List, Match, Optional, Type

import nltk
import yaml
from nltk import word_tokenize
from nltk.corpus import stopwords
Expand Down Expand Up @@ -73,6 +74,11 @@ def apply(self, tokens: List[str]):
@register_class("StopwordFilter")
class StopwordFilter(TextFilter):
def __init__(self, language: str = "english", stopword_list: Optional[List[str]] = None):
try:
nltk.corpus.stopwords.words(language)
except LookupError:
nltk.download("stopwords")

if stopword_list is None:
stopword_list = []
self.stopwords = set(stopwords.words(language) + stopword_list)
Expand Down Expand Up @@ -176,7 +182,7 @@ def build_default_analyzer(language: str = "en"):


def build_analyer_from_yaml(filepath: str, name: str):
with Path(filepath).open() as file:
with Path(filepath).open(encoding="utf-8") as file:
config = yaml.safe_load(file)

lang_config = config.get(name)
Expand Down
34 changes: 16 additions & 18 deletions milvus_model/sparse/splade.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,15 @@ def __call__(self, texts: List[str]) -> csr_array:

def encode_documents(self, documents: List[str]) -> csr_array:
return self._encode(
[self.doc_instruction + document for document in documents],
self.k_tokens_document,
[self.doc_instruction + document for document in documents], self.k_tokens_document,
)

def _encode(self, texts: List[str], k_tokens: int) -> csr_array:
return self.model.forward(texts, k_tokens=k_tokens)

def encode_queries(self, queries: List[str]) -> csr_array:
return self._encode(
[self.query_instruction + query for query in queries],
self.k_tokens_query,
[self.query_instruction + query for query in queries], self.k_tokens_query,
)

@property
Expand Down Expand Up @@ -130,26 +128,26 @@ def _encode(self, texts: List[str]):
padding=True,
)
encoded_input = {key: val.to(self.device) for key, val in encoded_input.items()}
with torch.no_grad():
output = self.model(**encoded_input)
output = self.model(**encoded_input)
return output.logits

def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]:
return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

def forward(self, texts: List[str], k_tokens: int) -> csr_array:
batched_texts = self._batchify(texts, self.batch_size)
sparse_embs = []
for batch_texts in batched_texts:
logits = self._encode(texts=batch_texts)
activations = self._get_activation(logits=logits)
if k_tokens is None:
nonzero_indices = torch.nonzero(activations["sparse_activations"])
activations["activations"] = nonzero_indices
else:
activations = self._update_activations(**activations, k_tokens=k_tokens)
batch_csr = self._convert_to_csr_array(activations)
sparse_embs.extend(batch_csr)
with torch.no_grad():
batched_texts = self._batchify(texts, self.batch_size)
sparse_embs = []
for batch_texts in batched_texts:
logits = self._encode(texts=batch_texts)
activations = self._get_activation(logits=logits)
if k_tokens is None:
nonzero_indices = torch.nonzero(activations["sparse_activations"])
activations["activations"] = nonzero_indices
else:
activations = self._update_activations(**activations, k_tokens=k_tokens)
batch_csr = self._convert_to_csr_array(activations)
sparse_embs.extend(batch_csr)

return vstack(sparse_embs).tocsr()

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"konlpy",
"mecab-python3",
"scipy >= 1.10.0",
"protobuf==3.20.0",
"protobuf==3.20.2",
"unidic-lite",
"cohere",
"voyageai >= 0.2.0",
Expand Down

0 comments on commit 7cd6101

Please sign in to comment.