Skip to content

Commit

Permalink
fix: Fixed lazy-import bugs that occurred during refactoring. (#20)
Browse files Browse the repository at this point in the history
Signed-off-by: wxywb <[email protected]>
  • Loading branch information
wxywb authored May 17, 2024
1 parent 9c78710 commit c8a11db
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 31 deletions.
3 changes: 1 addition & 2 deletions milvus_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
__all__ = ["DefaultEmbeddingFunction", "dense", "sparse", "hybrid", "reranker", "utils"]

from . import dense, hybrid, sparse, reranker, utils
from .dense import OnnxEmebeddingFunction

DefaultEmbeddingFunction = OnnxEmebeddingFunction
DefaultEmbeddingFunction = dense.onnx.OnnxEmbeddingFunction
4 changes: 2 additions & 2 deletions milvus_model/dense/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
def JinaEmbeddingFunction(*args, **kwargs):
return jinaai.JinaEmbeddingFunction(*args, **kwargs)

def OpenAIEmbeddingFunction(*args, **kwargs):
def OpenAIEmbeddingFunction(*args, **kwargs):
return openai.OpenAIEmbeddingFunction(*args, **kwargs)

def SentenceTransformerEmbeddingFunction(*args, **kwargs):
Expand All @@ -27,4 +27,4 @@ def VoyageEmbeddingFunction(*args, **kwargs):
return voyageai.VoyageEmbeddingFunction(*args, **kwargs)

def OnnxEmbeddingFunction(*args, **kwargs):
return onnx.OnnxEmbeddingFunction(*args, **kwargs)
return onnx.OnnxEmbeddingFunction(*args, **kwargs)
35 changes: 23 additions & 12 deletions milvus_model/dense/onnx.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,37 @@

import onnxruntime

from transformers import AutoTokenizer, AutoConfig
from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download
import numpy as np
from typing import List

from milvus_model.base import BaseEmbeddingFunction

class Onnx(BaseEmbeddingFunction):
def __init__(self, model_name = "GPTCache/paraphrase-albert-onnx", tokenizer_name = "GPTCache/paraphrase-albert-small-v2"):
class OnnxEmbeddingFunction(BaseEmbeddingFunction):
def __init__(self, model_name: str = "GPTCache/paraphrase-albert-onnx", tokenizer_name: str = "GPTCache/paraphrase-albert-small-v2"):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.model_name = model_name
onnx_model_path = hf_hub_download(repo_id=model_name, filename="model.onnx")
self.ort_session = onnxruntime.InferenceSession(onnx_model_path)
config = AutoConfig.from_pretrained(
tokenizer_name
tokenizer_name
)
self.__dimension = config.hidden_size

def to_embeddings(self, data, **_):
def __call__(self, texts: List[str]) -> List[np.array]:
return self._encode(texts)

def encode_queries(self, queries: List[str]) -> List[np.array]:
return self._encode(queries)

def encode_documents(self, documents: List[str]) -> List[np.array]:
return self._encode(documents)

def _encode(self, texts: List[str]) -> List[np.array]:
return [self._to_embedding(text) for text in texts]

def _to_embedding(self, data: str, **_):
encoded_text = self.tokenizer.encode_plus(data, padding="max_length")

ort_inputs = {
Expand All @@ -28,10 +42,10 @@ def to_embeddings(self, data, **_):

ort_outputs = self.ort_session.run(None, ort_inputs)
ort_feat = ort_outputs[0]
emb = self.post_proc(ort_feat, ort_inputs["attention_mask"])
emb = self._post_proc(ort_feat, ort_inputs["attention_mask"])
return emb.flatten()

def post_proc(self, token_embeddings, attention_mask):
def _post_proc(self, token_embeddings, attention_mask):
input_mask_expanded = (
np.expand_dims(attention_mask, -1)
.repeat(token_embeddings.shape[-1], -1)
Expand All @@ -43,9 +57,6 @@ def post_proc(self, token_embeddings, attention_mask):
return sentence_embs

@property
def dimension(self):
"""Embedding dimension.
def dim(self):
return self.__dimension

:return: embedding dimension
"""
return self.__dimension
16 changes: 13 additions & 3 deletions milvus_model/sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from .bm25 import BM25EmbeddingFunction
from .splade import SpladeEmbeddingFunction

__all__ = ["SpladeEmbeddingFunction", "BM25EmbeddingFunction"]


from milvus_model.utils.lazy_import import LazyImport

bm25 = LazyImport("bm25", globals(), "milvus_model.sparse.bm25")
splade = LazyImport("openai", globals(), "milvus_model.sparse.splade")

def BM25EmbeddingFunction(*args, **kwargs):
return bm25.BM25EmbeddingFunction(*args, **kwargs)

def SpladeEmbeddingFunction(*args, **kwargs):
return splade.SpladeEmbeddingFunction(*args, **kwargs)

6 changes: 2 additions & 4 deletions milvus_model/sparse/bm25/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from .bm25 import BM25EmbeddingFunction
from .tokenizers import Analyzer, build_analyzer_from_yaml, build_default_analyzer

__all__ = [
"BM25EmbeddingFunction",
"Analyzer",
Expand All @@ -10,7 +7,7 @@

from milvus_model.utils.lazy_import import LazyImport

bm25 = LazyImport("bm25", globals(), "milvus_model.sparse.bm25")
bm25 = LazyImport("bm25", globals(), "milvus_model.sparse.bm25.bm25")
tokenizers = LazyImport("tokenizers", globals(), "milvus_model.sparse.bm25.tokenizers")

def BM25EmbeddingFunction(*args, **kwargs):
Expand All @@ -24,3 +21,4 @@ def build_analyzer_from_yaml(*args, **kwargs):

def build_default_analyzer(*args, **kwargs):
return tokenizers.build_default_analyzer(*args, **kwargs)

2 changes: 0 additions & 2 deletions milvus_model/sparse/splade.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
import_torch()
import_scipy()
import_transformers()
import torch
from scipy.sparse import csr_array, vstack
from transformers import AutoModelForMaskedLM, AutoTokenizer

logger = logging.getLogger(__name__)
Expand Down
10 changes: 7 additions & 3 deletions milvus_model/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = [
"import_openai",
"import_sentence_transformers",
"import_sentence_transformers",
"import_FlagEmbedding",
"import_nltk",
"import_transformers",
Expand All @@ -12,7 +12,8 @@
"import_unidic_lite",
"import_cohere",
"import_voyageai"
"import_torch"
"import_torch",
"import_huggingface_hub"
]

import importlib.util
Expand Down Expand Up @@ -62,10 +63,13 @@ def import_voyageai():
def import_torch():
_check_library("torch", "torch")

def import_huggingface_hub():
_check_library("huggingface_hub", package="huggingface-hub")

def _check_library(libname: str, prompt: bool = True, package: Optional[str] = None):
is_avail = False
if importlib.util.find_spec(libname):
is_avail = True
if not is_avail and prompt:
prompt_install(package if package else libname)
return is_avail
return is_avail
3 changes: 1 addition & 2 deletions milvus_model/utils/dependency_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,5 @@ def prompt_install(package: str, warn: bool = False): # pragma: no cover
print(f"start to install package: {package}")
subprocess.check_call(cmd, shell=True)
print(f"successfully installed package: {package}")
gptcache_log.info("%s installed successfully!", package)
except subprocess.CalledProcessError as e:
raise PipInstallError(package) from e
raise ValueError(f"install error {e}")
2 changes: 1 addition & 1 deletion milvus_model/utils/lazy_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ def __getattr__(self, item):

def __dir__(self):
module = self._load()
return dir(module)
return dir(module)

0 comments on commit c8a11db

Please sign in to comment.