Skip to content

Commit

Permalink
Fix bug where AzureOpenAIEmbeddings inherits from OpenAIEmbeddings (#203)
Browse files Browse the repository at this point in the history

* Fix bug where AzureOpenAIEmbeddings inherits from OpenAIEmbeddings

* Cast embedding to list[float]

* Refactored embed_query method to be in base class
  • Loading branch information
willtai authored Oct 25, 2024
1 parent bc8540e commit cb96815
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 32 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
- Removed support for neo4j.AsyncDriver in the KG creation pipeline, affecting Neo4jWriter and related components.
- Updated examples and unit tests to reflect the removal of async driver support.

### Fixed
- Resolved an issue where `AzureOpenAIEmbeddings` incorrectly inherited from `OpenAIEmbeddings`; it now inherits from `BaseOpenAIEmbeddings`.

## 1.1.0

Expand Down
4 changes: 2 additions & 2 deletions examples/customize/embeddings/azure_openai_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from neo4j_graphrag.embeddings import AzureOpenAIEmbeddings

embeder = AzureOpenAIEmbeddings(
embedder = AzureOpenAIEmbeddings(
model="text-embedding-ada-002",
azure_endpoint="https://my-endpoint.openai.azure.com/",
api_key="<my key>",
api_version="<update version>",
)
res = embeder.embed_query("my question")
res = embedder.embed_query("my question")
print(res[:10])
67 changes: 49 additions & 18 deletions src/neo4j_graphrag/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,22 @@

from __future__ import annotations

from typing import Any
import abc
from typing import TYPE_CHECKING, Any

from neo4j_graphrag.embeddings.base import Embedder

if TYPE_CHECKING:
import openai

class OpenAIEmbeddings(Embedder):
"""
OpenAI embeddings class.
This class uses the OpenAI python client to generate embeddings for text data.

Args:
model (str): The name of the OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
kwargs: All other parameters will be passed to the openai.OpenAI init.
class BaseOpenAIEmbeddings(Embedder, abc.ABC):
"""
Abstract base class for OpenAI embeddings.
"""

client: openai.OpenAI

def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
try:
import openai
Expand All @@ -39,23 +41,52 @@ def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None
)
self.openai = openai
self.model = model
self.openai_client = self.openai.OpenAI(**kwargs)
self.client = self._initialize_client(**kwargs)

@abc.abstractmethod
def _initialize_client(self, **kwargs: Any) -> Any:
"""
Initialize the OpenAI client.
Must be implemented by subclasses.
"""
pass

def embed_query(self, text: str, **kwargs: Any) -> list[float]:
"""
Generate embeddings for a given query using a OpenAI text embedding model.
Generate embeddings for a given query using an OpenAI text embedding model.
Args:
text (str): The text to generate an embedding for.
**kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function.
"""
response = self.openai_client.embeddings.create(
input=text, model=self.model, **kwargs
)
return response.data[0].embedding
response = self.client.embeddings.create(input=text, model=self.model, **kwargs)
embedding: list[float] = response.data[0].embedding
return embedding


class AzureOpenAIEmbeddings(OpenAIEmbeddings):
def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
super().__init__(model, **kwargs)
self.openai_client = self.openai.AzureOpenAI(**kwargs)
class OpenAIEmbeddings(BaseOpenAIEmbeddings):
    """
    Embedder backed by the standard OpenAI python client.

    Args:
        model (str): The name of the OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
        kwargs: All other parameters will be passed to the openai.OpenAI init.
    """

    def _initialize_client(self, **kwargs: Any) -> Any:
        """
        Create the ``openai.OpenAI`` client used for embedding requests.

        Args:
            **kwargs (Any): Keyword arguments forwarded unchanged to ``openai.OpenAI``.
        """
        # ``self.openai`` is the module object stored by the base ``__init__``.
        client_factory = self.openai.OpenAI
        return client_factory(**kwargs)


class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings):
    """
    Embedder backed by the Azure OpenAI python client.

    Args:
        model (str): The name of the Azure OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
        kwargs: All other parameters will be passed to the openai.AzureOpenAI init.
    """

    def _initialize_client(self, **kwargs: Any) -> Any:
        """
        Create the ``openai.AzureOpenAI`` client used for embedding requests.

        Args:
            **kwargs (Any): Keyword arguments forwarded unchanged to ``openai.AzureOpenAI``.
        """
        # ``self.openai`` is the module object stored by the base ``__init__``.
        client_factory = self.openai.AzureOpenAI
        return client_factory(**kwargs)
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/llm/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import abc
from typing import Any, Optional, TYPE_CHECKING, Iterable
from typing import TYPE_CHECKING, Any, Iterable, Optional

from ..exceptions import LLMGenerationError
from .base import LLMInterface
Expand Down
23 changes: 22 additions & 1 deletion tests/unit/embeddings/test_openai_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
# limitations under the License.
from unittest.mock import MagicMock, Mock, patch

import openai
import pytest
from neo4j_graphrag.embeddings.openai import (
AzureOpenAIEmbeddings,
OpenAIEmbeddings,
)
import openai


def get_mock_openai() -> MagicMock:
Expand Down Expand Up @@ -71,3 +71,24 @@ def test_azure_openai_embedder_happy_path(mock_import: Mock) -> None:
res = embedder.embed_query("my text")
assert isinstance(res, list)
assert res == [1.0, 2.0]


def test_azure_openai_embedder_does_not_call_openai_client() -> None:
    """
    Constructing ``AzureOpenAIEmbeddings`` must build only the Azure client.

    Regression test for the inheritance bug: the embedder used to extend
    ``OpenAIEmbeddings``, whose constructor instantiated ``openai.OpenAI``
    before the Azure client replaced it.
    """
    # ``patch`` is the module-level ``unittest.mock`` import; no local re-import needed.
    mock_openai = get_mock_openai()

    # Swap the real ``openai`` package for the mock so that the ``import openai``
    # executed inside the embedder's ``__init__`` resolves to it.
    with patch.dict("sys.modules", {"openai": mock_openai}):
        AzureOpenAIEmbeddings(
            model="text-embedding-ada-002",
            azure_endpoint="https://test.openai.azure.com/",
            api_key="my_key",
            api_version="2023-05-15",
        )

    # The plain OpenAI client must never be created for the Azure embedder.
    mock_openai.OpenAI.assert_not_called()
    # ``model`` is consumed by ``__init__`` itself; only the remaining keyword
    # arguments reach the Azure client constructor.
    mock_openai.AzureOpenAI.assert_called_once_with(
        azure_endpoint="https://test.openai.azure.com/",
        api_key="my_key",
        api_version="2023-05-15",
    )
4 changes: 2 additions & 2 deletions tests/unit/embeddings/test_sentence_transformers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from unittest.mock import MagicMock, patch, Mock
from unittest.mock import MagicMock, Mock, patch

import numpy as np
import pytest
import torch
from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.embeddings.sentence_transformers import (
SentenceTransformerEmbeddings,
)
import torch


def get_mock_sentence_transformers() -> MagicMock:
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/llm/test_anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
# limitations under the License.
from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch, Mock
import sys
from typing import Generator
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import anthropic
import pytest
from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM
import sys
import anthropic
from typing import Generator


@pytest.fixture
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/llm/test_cohere_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Generator
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import cohere.core
import pytest
from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm import LLMResponse
from neo4j_graphrag.llm.cohere_llm import CohereLLM
import sys
from typing import Generator


@pytest.fixture
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/llm/test_openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch, Mock
from unittest.mock import MagicMock, Mock, patch

import openai
import pytest
from neo4j_graphrag.llm import LLMResponse
from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM
import openai


def get_mock_openai() -> MagicMock:
Expand Down

0 comments on commit cb96815

Please sign in to comment.