Community: LlamaCppEmbeddings embed_documents and embed_query #28827

Merged: 12 commits, Dec 23, 2024
50 changes: 32 additions & 18 deletions libs/community/langchain_community/embeddings/llamacpp.py
@@ -20,7 +20,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
"""

client: Any = None #: :meta private:
model_path: str
model_path: str = Field(default="")

n_ctx: int = Field(512, alias="n_ctx")
"""Token context window."""
@@ -88,21 +88,22 @@ def validate_environment(self) -> Self:
        if self.n_gpu_layers is not None:
            model_params["n_gpu_layers"] = self.n_gpu_layers

-        try:
-            from llama_cpp import Llama
-
-            self.client = Llama(model_path, embedding=True, **model_params)
-        except ImportError:
-            raise ImportError(
-                "Could not import llama-cpp-python library. "
-                "Please install the llama-cpp-python library to "
-                "use this embedding model: pip install llama-cpp-python"
-            )
-        except Exception as e:
-            raise ValueError(
-                f"Could not load Llama model from path: {model_path}. "
-                f"Received error {e}"
-            )
+        if not self.client:
+            try:
+                from llama_cpp import Llama
+
+                self.client = Llama(model_path, embedding=True, **model_params)
+            except ImportError:
+                raise ImportError(
+                    "Could not import llama-cpp-python library. "
+                    "Please install the llama-cpp-python library to "
+                    "use this embedding model: pip install llama-cpp-python"
+                )
+            except Exception as e:
+                raise ValueError(
+                    f"Could not load Llama model from path: {model_path}. "
+                    f"Received error {e}"
+                )

        return self
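
The new `if not self.client:` guard above is what makes the unit tests at the bottom of this PR possible: a pre-built (or mocked) client can be injected through the `client` field, and `validate_environment` then skips loading a model from disk. A minimal sketch, assuming the injected object satisfies the same interface as `llama_cpp.Llama`:

from unittest.mock import MagicMock

from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings

# Injected client: the guard sees self.client is already set and never
# touches llama-cpp-python or the filesystem.
embeddings = LlamaCppEmbeddings(client=MagicMock())

# No client injected: the guard falls through and Llama(model_path, ...)
# is constructed as before (requires llama-cpp-python and a real model
# file; the path below is hypothetical).
# embeddings = LlamaCppEmbeddings(model_path="/path/to/model.gguf")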

@@ -116,7 +117,17 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
            List of embeddings, one for each text.
        """
        embeddings = self.client.create_embedding(texts)
-        return [list(map(float, e["embedding"])) for e in embeddings["data"]]
+        final_embeddings = []
+        for e in embeddings["data"]:
+            try:
+                if isinstance(e["embedding"][0], list):
+                    for data in e["embedding"]:
+                        final_embeddings.append(list(map(float, data)))
+                else:
+                    final_embeddings.append(list(map(float, e["embedding"])))
+            except (IndexError, TypeError):
+                final_embeddings.append(list(map(float, e["embedding"])))
+        return final_embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using the Llama model.
@@ -128,4 +139,7 @@ def embed_query(self, text: str) -> List[float]:
            Embeddings for the text.
        """
        embedding = self.client.embed(text)
-        return list(map(float, embedding))
+        if not isinstance(embedding, list):
+            return list(map(float, embedding))
+        else:
+            return list(map(float, embedding[0]))
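
`embed_query` gets the matching treatment: `client.embed` can hand back a list of vectors (one per input, so the single query vector sits at index 0) or an array-like sequence of floats, and the branch unwraps accordingly. A quick sketch of the same branch in isolation (the helper name is hypothetical):

from typing import List, Sequence, Union

def normalize_query_embedding(embedding: Union[list, Sequence[float]]) -> List[float]:
    # Mirrors the branch in embed_query (illustrative only).
    if not isinstance(embedding, list):
        # Array-like of floats (e.g., a tuple or numpy array).
        return list(map(float, embedding))
    # A list of vectors, one per input text: unwrap the single query vector.
    return list(map(float, embedding[0]))

assert normalize_query_embedding([[0.1, 0.2]]) == [0.1, 0.2]
assert normalize_query_embedding((0.1, 0.2)) == [0.1, 0.2]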
40 changes: 40 additions & 0 deletions libs/community/tests/unit_tests/embeddings/test_llamacpp.py
@@ -0,0 +1,40 @@
from typing import Generator
from unittest.mock import MagicMock, patch

import pytest

from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings


@pytest.fixture
def mock_llama_client() -> Generator[MagicMock, None, None]:
    with patch(
        "langchain_community.embeddings.llamacpp.LlamaCppEmbeddings"
    ) as MockLlama:
        mock_client = MagicMock()
        MockLlama.return_value = mock_client
        yield mock_client


def test_initialization(mock_llama_client: MagicMock) -> None:
    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]
    assert embeddings.client is not None


def test_embed_documents(mock_llama_client: MagicMock) -> None:
    mock_llama_client.create_embedding.return_value = {
        "data": [{"embedding": [[0.1, 0.2, 0.3]]}, {"embedding": [[0.4, 0.5, 0.6]]}]
    }
    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]
    texts = ["Hello world", "Test document"]
    result = embeddings.embed_documents(texts)
    expected = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    assert result == expected


def test_embed_query(mock_llama_client: MagicMock) -> None:
    mock_llama_client.embed.return_value = [[0.1, 0.2, 0.3]]
    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]
    result = embeddings.embed_query("Sample query")
    expected = [0.1, 0.2, 0.3]
    assert result == expected
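
End to end, usage after this change looks roughly like the following (a sketch, assuming llama-cpp-python is installed and a local GGUF model exists; the path is hypothetical):

from langchain_community.embeddings import LlamaCppEmbeddings

# Requires: pip install llama-cpp-python
embeddings = LlamaCppEmbeddings(model_path="/path/to/model.gguf")  # hypothetical path

doc_vectors = embeddings.embed_documents(["Hello world", "Test document"])
query_vector = embeddings.embed_query("Sample query")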