Skip to content

Commit

Permalink
community[patch], langchain[minor]: Add retriever self_query and score_threshold in DingoDB (#18106)
Browse files Browse the repository at this point in the history
  • Loading branch information
HeChangHaoGary authored Mar 5, 2024
1 parent d039dcb commit 6a08134
Show file tree
Hide file tree
Showing 6 changed files with 656 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/api_reference/guide_imports.json

Large diffs are not rendered by default.

496 changes: 496 additions & 0 deletions docs/docs/integrations/retrievers/self_query/dingo.ipynb

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions libs/community/langchain_community/vectorstores/dingo.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def similarity_search(
List of Documents most similar to the query and score for each
"""
docs_and_scores = self.similarity_search_with_score(
query, k=k, search_params=search_params
query, k=k, search_params=search_params, **kwargs
)
return [doc for doc, _ in docs_and_scores]

Expand Down Expand Up @@ -177,9 +177,15 @@ def similarity_search_with_score(
return []

for res in results[0]["vectorWithDistances"]:
score = res["distance"]
if (
"score_threshold" in kwargs
and kwargs.get("score_threshold") is not None
):
if score > kwargs.get("score_threshold"):
continue
metadatas = res["scalarData"]
id = res["id"]
score = res["distance"]
text = metadatas[self._text_key]["fields"][0]["data"]
metadata = {"id": id, "text": text, "score": score}
for meta_key in metadatas.keys():
Expand Down
3 changes: 3 additions & 0 deletions libs/langchain/langchain/retrievers/self_query/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Chroma,
DashVector,
DeepLake,
Dingo,
ElasticsearchStore,
Milvus,
MongoDBAtlasVectorSearch,
Expand Down Expand Up @@ -39,6 +40,7 @@
from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.dashvector import DashvectorTranslator
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
from langchain.retrievers.self_query.dingo import DingoDBTranslator
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator
from langchain.retrievers.self_query.milvus import MilvusTranslator
from langchain.retrievers.self_query.mongodb_atlas import MongoDBAtlasTranslator
Expand All @@ -65,6 +67,7 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
Pinecone: PineconeTranslator,
Chroma: ChromaTranslator,
DashVector: DashvectorTranslator,
Dingo: DingoDBTranslator,
Weaviate: WeaviateTranslator,
Vectara: VectaraTranslator,
Qdrant: QdrantTranslator,
Expand Down
49 changes: 49 additions & 0 deletions libs/langchain/langchain/retrievers/self_query/dingo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Tuple, Union

from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)


class DingoDBTranslator(Visitor):
    """Visitor that turns a structured query into `DingoDB` search kwargs.

    Operations and comparisons are handed back untouched; the filter
    expression is forwarded as-is under ``search_params["langchain_expr"]``.
    """

    allowed_comparators = (
        Comparator.EQ,
        Comparator.NE,
        Comparator.LT,
        Comparator.LTE,
        Comparator.GT,
        Comparator.GTE,
    )
    """Comparators this translator accepts."""
    allowed_operators = (Operator.AND, Operator.OR)
    """Logical operators this translator accepts."""

    def _format_func(self, func: Union[Operator, Comparator]) -> str:
        """Validate ``func`` and return its value prefixed with ``$``."""
        # `_validate_func` is not defined in this class; it is expected to
        # come from the `Visitor` base.
        self._validate_func(func)
        return "$" + f"{func.value}"

    def visit_operation(self, operation: Operation) -> Operation:
        """Return the operation unchanged."""
        return operation

    def visit_comparison(self, comparison: Comparison) -> Comparison:
        """Return the comparison unchanged."""
        return comparison

    def visit_structured_query(
        self, structured_query: StructuredQuery
    ) -> Tuple[str, dict]:
        """Split a structured query into its text and search kwargs.

        Returns:
            A ``(query, kwargs)`` pair. ``kwargs`` is empty when the query has
            no filter; otherwise it holds the visited filter under
            ``kwargs["search_params"]["langchain_expr"]``.
        """
        query_filter = structured_query.filter
        if query_filter is None:
            return structured_query.query, {}

        search_params = {"langchain_expr": query_filter.accept(self)}
        return structured_query.query, {"search_params": search_params}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Dict, Tuple

from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
)
from langchain.retrievers.self_query.dingo import DingoDBTranslator

DEFAULT_TRANSLATOR = DingoDBTranslator()


def test_visit_comparison() -> None:
    """Visiting a comparison yields a comparison equal to the input."""

    def build() -> Comparison:
        # Fresh object per call so the assertion checks equality, not identity.
        return Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])

    expected = build()
    actual = DEFAULT_TRANSLATOR.visit_comparison(build())
    assert expected == actual


def test_visit_operation() -> None:
    """Visiting an operation yields an operation equal to the input."""

    def build() -> Operation:
        # Fresh object per call so the assertion checks equality, not identity.
        less_than = Comparison(comparator=Comparator.LT, attribute="foo", value=2)
        equal_to = Comparison(comparator=Comparator.EQ, attribute="bar", value="baz")
        return Operation(operator=Operator.AND, arguments=[less_than, equal_to])

    expected = build()
    actual = DEFAULT_TRANSLATOR.visit_operation(build())
    assert expected == actual


def test_visit_structured_query() -> None:
    """Check the (query, kwargs) pair for no filter, a comparison, and an operation."""
    query = "What is the capital of France?"

    def make_comparison() -> Comparison:
        return Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])

    def make_operation() -> Operation:
        return Operation(
            operator=Operator.AND,
            arguments=[
                Comparison(comparator=Comparator.LT, attribute="foo", value=2),
                Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
            ],
        )

    # Case 1: no filter -> empty kwargs.
    expected: Tuple[str, Dict] = (query, {})
    actual = DEFAULT_TRANSLATOR.visit_structured_query(
        StructuredQuery(query=query, filter=None)
    )
    assert expected == actual

    # Case 2: a single comparison is forwarded under "langchain_expr".
    expected = (query, {"search_params": {"langchain_expr": make_comparison()}})
    actual = DEFAULT_TRANSLATOR.visit_structured_query(
        StructuredQuery(query=query, filter=make_comparison())
    )
    assert expected == actual

    # Case 3: a compound operation is forwarded the same way.
    expected = (query, {"search_params": {"langchain_expr": make_operation()}})
    actual = DEFAULT_TRANSLATOR.visit_structured_query(
        StructuredQuery(query=query, filter=make_operation())
    )
    assert expected == actual

0 comments on commit 6a08134

Please sign in to comment.