diff --git a/CHANGELOG.md b/CHANGELOG.md index 42aaa9eb..4b66adb6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ - Removed support for neo4j.AsyncDriver in the KG creation pipeline, affecting Neo4jWriter and related components. - Updated examples and unit tests to reflect the removal of async driver support. +### Fixed +- Resolved issue with `AzureOpenAIEmbeddings` incorrectly inheriting from `OpenAIEmbeddings`, now inherits from `BaseOpenAIEmbeddings`. ## 1.1.0 diff --git a/examples/customize/embeddings/azure_openai_embeddings.py b/examples/customize/embeddings/azure_openai_embeddings.py index 62fe3ac8..932a2415 100644 --- a/examples/customize/embeddings/azure_openai_embeddings.py +++ b/examples/customize/embeddings/azure_openai_embeddings.py @@ -4,11 +4,11 @@ from neo4j_graphrag.embeddings import AzureOpenAIEmbeddings -embeder = AzureOpenAIEmbeddings( +embedder = AzureOpenAIEmbeddings( model="text-embedding-ada-002", azure_endpoint="https://my-endpoint.openai.azure.com/", api_key="", api_version="", ) -res = embeder.embed_query("my question") +res = embedder.embed_query("my question") print(res[:10]) diff --git a/src/neo4j_graphrag/embeddings/openai.py b/src/neo4j_graphrag/embeddings/openai.py index 96ce7142..6204acb3 100644 --- a/src/neo4j_graphrag/embeddings/openai.py +++ b/src/neo4j_graphrag/embeddings/openai.py @@ -15,20 +15,22 @@ from __future__ import annotations -from typing import Any +import abc +from typing import TYPE_CHECKING, Any + from neo4j_graphrag.embeddings.base import Embedder +if TYPE_CHECKING: + import openai -class OpenAIEmbeddings(Embedder): - """ - OpenAI embeddings class. - This class uses the OpenAI python client to generate embeddings for text data. - Args: - model (str): The name of the OpenAI embedding model to use. Defaults to "text-embedding-ada-002". - kwargs: All other parameters will be passed to the openai.OpenAI init. +class BaseOpenAIEmbeddings(Embedder, abc.ABC): + """ + Abstract base class for OpenAI embeddings. """ + client: openai.OpenAI + def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None: try: import openai @@ -39,23 +41,52 @@ def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None ) self.openai = openai self.model = model - self.openai_client = self.openai.OpenAI(**kwargs) + self.client = self._initialize_client(**kwargs) + + @abc.abstractmethod + def _initialize_client(self, **kwargs: Any) -> Any: + """ + Initialize the OpenAI client. + Must be implemented by subclasses. + """ + pass def embed_query(self, text: str, **kwargs: Any) -> list[float]: """ - Generate embeddings for a given query using a OpenAI text embedding model. + Generate embeddings for a given query using an OpenAI text embedding model. Args: text (str): The text to generate an embedding for. **kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function. """ - response = self.openai_client.embeddings.create( - input=text, model=self.model, **kwargs - ) - return response.data[0].embedding + response = self.client.embeddings.create(input=text, model=self.model, **kwargs) + embedding: list[float] = response.data[0].embedding + return embedding -class AzureOpenAIEmbeddings(OpenAIEmbeddings): - def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None: - super().__init__(model, **kwargs) - self.openai_client = self.openai.AzureOpenAI(**kwargs) +class OpenAIEmbeddings(BaseOpenAIEmbeddings): + """ + OpenAI embeddings class. + This class uses the OpenAI python client to generate embeddings for text data. + + Args: + model (str): The name of the OpenAI embedding model to use. Defaults to "text-embedding-ada-002". + kwargs: All other parameters will be passed to the openai.OpenAI init. + """ + + def _initialize_client(self, **kwargs: Any) -> Any: + return self.openai.OpenAI(**kwargs) + + +class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings): + """ + Azure OpenAI embeddings class. + This class uses the Azure OpenAI python client to generate embeddings for text data. + + Args: + model (str): The name of the Azure OpenAI embedding model to use. Defaults to "text-embedding-ada-002". + kwargs: All other parameters will be passed to the openai.AzureOpenAI init. + """ + + def _initialize_client(self, **kwargs: Any) -> Any: + return self.openai.AzureOpenAI(**kwargs) diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index aa242712..04f3f2bf 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -15,7 +15,7 @@ from __future__ import annotations import abc -from typing import Any, Optional, TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING, Any, Iterable, Optional from ..exceptions import LLMGenerationError from .base import LLMInterface diff --git a/tests/unit/embeddings/test_openai_embedder.py b/tests/unit/embeddings/test_openai_embedder.py index ec216dcf..a1b940f0 100644 --- a/tests/unit/embeddings/test_openai_embedder.py +++ b/tests/unit/embeddings/test_openai_embedder.py @@ -14,12 +14,12 @@ # limitations under the License. from unittest.mock import MagicMock, Mock, patch +import openai import pytest from neo4j_graphrag.embeddings.openai import ( AzureOpenAIEmbeddings, OpenAIEmbeddings, ) -import openai def get_mock_openai() -> MagicMock: @@ -71,3 +71,24 @@ def test_azure_openai_embedder_happy_path(mock_import: Mock) -> None: res = embedder.embed_query("my text") assert isinstance(res, list) assert res == [1.0, 2.0] + + +def test_azure_openai_embedder_does_not_call_openai_client() -> None: + from unittest.mock import patch + + mock_openai = get_mock_openai() + + with patch.dict("sys.modules", {"openai": mock_openai}): + AzureOpenAIEmbeddings( + model="text-embedding-ada-002", + azure_endpoint="https://test.openai.azure.com/", + api_key="my_key", + api_version="2023-05-15", + ) + + mock_openai.OpenAI.assert_not_called() + mock_openai.AzureOpenAI.assert_called_once_with( + azure_endpoint="https://test.openai.azure.com/", + api_key="my_key", + api_version="2023-05-15", + ) diff --git a/tests/unit/embeddings/test_sentence_transformers.py b/tests/unit/embeddings/test_sentence_transformers.py index 97a1172c..197095e4 100644 --- a/tests/unit/embeddings/test_sentence_transformers.py +++ b/tests/unit/embeddings/test_sentence_transformers.py @@ -1,12 +1,12 @@ -from unittest.mock import MagicMock, patch, Mock +from unittest.mock import MagicMock, Mock, patch import numpy as np import pytest +import torch from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.embeddings.sentence_transformers import ( SentenceTransformerEmbeddings, ) -import torch def get_mock_sentence_transformers() -> MagicMock: diff --git a/tests/unit/llm/test_anthropic_llm.py b/tests/unit/llm/test_anthropic_llm.py index 4c19b69b..c8d5f6f8 100644 --- a/tests/unit/llm/test_anthropic_llm.py +++ b/tests/unit/llm/test_anthropic_llm.py @@ -13,13 +13,13 @@ # limitations under the License. from __future__ import annotations -from unittest.mock import AsyncMock, MagicMock, patch, Mock +import sys +from typing import Generator +from unittest.mock import AsyncMock, MagicMock, Mock, patch +import anthropic import pytest from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM -import sys -import anthropic -from typing import Generator @pytest.fixture diff --git a/tests/unit/llm/test_cohere_llm.py b/tests/unit/llm/test_cohere_llm.py index f4062718..db1b9db8 100644 --- a/tests/unit/llm/test_cohere_llm.py +++ b/tests/unit/llm/test_cohere_llm.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +from typing import Generator from unittest.mock import AsyncMock, MagicMock, Mock, patch import cohere.core @@ -19,8 +21,6 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm import LLMResponse from neo4j_graphrag.llm.cohere_llm import CohereLLM -import sys -from typing import Generator @pytest.fixture diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 9c6a1ef7..546d4e39 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -12,12 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock, patch, Mock +from unittest.mock import MagicMock, Mock, patch +import openai import pytest from neo4j_graphrag.llm import LLMResponse from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM -import openai def get_mock_openai() -> MagicMock: