-
Notifications
You must be signed in to change notification settings - Fork 15.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[patch], langchain[minor]: Add retriever self_query and scor…
…e_threshold in DingoDB (#18106)
- Loading branch information
1 parent
d039dcb
commit 6a08134
Showing
6 changed files
with
656 additions
and
3 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
496 changes: 496 additions & 0 deletions
496
docs/docs/integrations/retrievers/self_query/dingo.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from typing import Tuple, Union | ||
|
||
from langchain.chains.query_constructor.ir import ( | ||
Comparator, | ||
Comparison, | ||
Operation, | ||
Operator, | ||
StructuredQuery, | ||
Visitor, | ||
) | ||
|
||
|
||
class DingoDBTranslator(Visitor): | ||
"""Translate `DingoDB` internal query language elements to valid filters.""" | ||
|
||
allowed_comparators = ( | ||
Comparator.EQ, | ||
Comparator.NE, | ||
Comparator.LT, | ||
Comparator.LTE, | ||
Comparator.GT, | ||
Comparator.GTE, | ||
) | ||
"""Subset of allowed logical comparators.""" | ||
allowed_operators = (Operator.AND, Operator.OR) | ||
"""Subset of allowed logical operators.""" | ||
|
||
def _format_func(self, func: Union[Operator, Comparator]) -> str: | ||
self._validate_func(func) | ||
return f"${func.value}" | ||
|
||
def visit_operation(self, operation: Operation) -> Operation: | ||
return operation | ||
|
||
def visit_comparison(self, comparison: Comparison) -> Comparison: | ||
return comparison | ||
|
||
def visit_structured_query( | ||
self, structured_query: StructuredQuery | ||
) -> Tuple[str, dict]: | ||
if structured_query.filter is None: | ||
kwargs = {} | ||
else: | ||
kwargs = { | ||
"search_params": { | ||
"langchain_expr": structured_query.filter.accept(self) | ||
} | ||
} | ||
return structured_query.query, kwargs |
99 changes: 99 additions & 0 deletions
99
libs/langchain/tests/unit_tests/retrievers/self_query/test_dingo.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from typing import Dict, Tuple | ||
|
||
from langchain.chains.query_constructor.ir import ( | ||
Comparator, | ||
Comparison, | ||
Operation, | ||
Operator, | ||
StructuredQuery, | ||
) | ||
from langchain.retrievers.self_query.dingo import DingoDBTranslator | ||
|
||
DEFAULT_TRANSLATOR = DingoDBTranslator() | ||
|
||
|
||
def test_visit_comparison() -> None: | ||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"]) | ||
expected = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"]) | ||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp) | ||
assert expected == actual | ||
|
||
|
||
def test_visit_operation() -> None: | ||
op = Operation( | ||
operator=Operator.AND, | ||
arguments=[ | ||
Comparison(comparator=Comparator.LT, attribute="foo", value=2), | ||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), | ||
], | ||
) | ||
expected = Operation( | ||
operator=Operator.AND, | ||
arguments=[ | ||
Comparison(comparator=Comparator.LT, attribute="foo", value=2), | ||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), | ||
], | ||
) | ||
actual = DEFAULT_TRANSLATOR.visit_operation(op) | ||
assert expected == actual | ||
|
||
|
||
def test_visit_structured_query() -> None: | ||
query = "What is the capital of France?" | ||
|
||
structured_query = StructuredQuery( | ||
query=query, | ||
filter=None, | ||
) | ||
expected: Tuple[str, Dict] = (query, {}) | ||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) | ||
assert expected == actual | ||
|
||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"]) | ||
structured_query = StructuredQuery( | ||
query=query, | ||
filter=comp, | ||
) | ||
expected = ( | ||
query, | ||
{ | ||
"search_params": { | ||
"langchain_expr": Comparison( | ||
comparator=Comparator.LT, attribute="foo", value=["1", "2"] | ||
) | ||
} | ||
}, | ||
) | ||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) | ||
assert expected == actual | ||
|
||
op = Operation( | ||
operator=Operator.AND, | ||
arguments=[ | ||
Comparison(comparator=Comparator.LT, attribute="foo", value=2), | ||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), | ||
], | ||
) | ||
structured_query = StructuredQuery( | ||
query=query, | ||
filter=op, | ||
) | ||
expected = ( | ||
query, | ||
{ | ||
"search_params": { | ||
"langchain_expr": Operation( | ||
operator=Operator.AND, | ||
arguments=[ | ||
Comparison(comparator=Comparator.LT, attribute="foo", value=2), | ||
Comparison( | ||
comparator=Comparator.EQ, attribute="bar", value="baz" | ||
), | ||
], | ||
) | ||
} | ||
}, | ||
) | ||
|
||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) | ||
assert expected == actual |