Skip to content

Commit

Permalink
Add integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
dqii committed Sep 3, 2024
1 parent 2eff4f1 commit 6d6c72d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
38 changes: 38 additions & 0 deletions scripts/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,44 @@ def test_vector_search_with_filter(primary, source_table):
row[2] == filter_val
), f"Expected all results to have random_bool == {filter_val}"

@pytest.mark.parametrize("distance_metric", ["", "l2sq", "cos"])
def test_weighted_vector_search(primary, distance_metric):
primary.execute("testdb", "CREATE TABLE IF NOT EXISTS small_world (id VARCHAR(3), b BOOLEAN, v VECTOR(3), s SPARSEVEC(3));")
primary.execute("testdb", """
INSERT INTO small_world VALUES
('000', TRUE, '[0,0,0]', '{}/3'),
('001', TRUE, '[0,0,1]', '{3:1}/3'),
('010', FALSE, '[0,1,0]' , '{2:1}/3'),
('011', TRUE, '[0,1,1]', '{2:1,3:1}/3'),
('100', FALSE, '[1,0,0]', '{1:1}/3'),
('101', FALSE, '[1,0,1]', '{1:1,3:1}/3'),
('110', FALSE, '[1,1,0]', '{1:1,2:1}/3'),
('111', TRUE, '[1,1,1]', '{1:1,2:1,3:1}/3');
""")
operator = op = { 'l2sq': '<->', 'cos': '<=>', 'hamming': '<+>' }[distance_metric or 'l2sq']
query_s = "{1:0.4,2:0.3,3:0.2}/3"
query_v = "[-0.5,-0.1,-0.3]"
function = f'weighted_vector_search_{distance_metric}' if distance_metric else 'weighted_vector_search'
query = f"""
SELECT
id,
round(cast(0.9 * (s {operator} :'{query_s}'::sparsevec) + 0.1 * (v {operator} :'{query_v}'::vector) as numeric), 2) as dist
FROM lantern.{function}(CAST(NULL as "small_world"), operator=>'{operator}',
w1=> 0.9, col1=>'s'::text, vec1=>:'{query_s}'::sparsevec,
w2=> 0.1, col2=>'v'::text, vec2=>:'{query_v}'::vector
);
LIMIT 3;
"""
res = primary.execute("testdb", query)

expected_results_cos = [('111', 0.22), ('110', 0.24), ('101', 0.39)]
expected_results_l2sq = [('000', 0.54), ('100', 0.78), ('010', 0.87)]
if distance_metric == 'cos':
assert res == expected_results_cos
else:
assert res == expected_results_l2sq


# fixture to handle external index server setup
@pytest.fixture
def external_index(request):
Expand Down
4 changes: 2 additions & 2 deletions sql/updates/0.3.2--0.3.3.sql
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ DECLARE
-- function suffix, function default operator
utility_functions text[2][] := ARRAY[
ARRAY['', '<->'],
ARRAY['_cos', '<->'],
ARRAY['_l2sq', '<=>']
ARRAY['_cos', '<=>'],
ARRAY['_l2sq', '<->']
];
BEGIN
-- Check if the vector type from pgvector exists
Expand Down

0 comments on commit 6d6c72d

Please sign in to comment.