diff --git a/milvus_model/dense/jinaai.py b/milvus_model/dense/jinaai.py index c458cee..3ca1d65 100644 --- a/milvus_model/dense/jinaai.py +++ b/milvus_model/dense/jinaai.py @@ -12,8 +12,10 @@ class JinaEmbeddingFunction(BaseEmbeddingFunction): def __init__( self, - model_name: str = "jina-embeddings-v2-base-en", + model_name: str = "jina-embeddings-v3", api_key: Optional[str] = None, + task: str = 'retrieval.passage', + dimensions: Optional[int] = None, **kwargs, ): if api_key is None: @@ -33,8 +35,8 @@ def __init__( self._session.headers.update( {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"} ) - self.model_name = model_name - self._dim = None + self.task = task + self._dim = dimensions @property def dim(self): @@ -43,17 +45,25 @@ def dim(self): return self._dim def encode_queries(self, queries: List[str]) -> List[np.array]: - return self._call_jina_api(queries) + return self._call_jina_api(queries, task='retrieval.query') def encode_documents(self, documents: List[str]) -> List[np.array]: - return self._call_jina_api(documents) + return self._call_jina_api(documents, task='retrieval.passage') def __call__(self, texts: List[str]) -> List[np.array]: - return self._call_jina_api(texts) + return self._call_jina_api(texts, task=self.task) - def _call_jina_api(self, texts: List[str]): + def _call_jina_api(self, texts: List[str], task: Optional[str] = None): + data = { + "input": texts, + "model": self.model_name, + "task": task, + } + if self._dim is not None: + data["dimensions"] = self._dim resp = self._session.post( # type: ignore[assignment] - API_URL, json={"input": texts, "model": self.model_name}, + API_URL, + json=data, ).json() if "data" not in resp: raise RuntimeError(resp["detail"])