Skip to content

Commit

Permalink
Add loop test
Browse files Browse the repository at this point in the history
  • Loading branch information
dqii committed Sep 3, 2024
1 parent 2eff4f1 commit 17bc1f5
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 4 deletions.
1 change: 1 addition & 0 deletions scripts/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ def test_vector_search_with_filter(primary, source_table):
row[2] == filter_val
), f"Expected all results to have random_bool == {filter_val}"


# fixture to handle external index server setup
@pytest.fixture
def external_index(request):
Expand Down
67 changes: 67 additions & 0 deletions scripts/test_weighted_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import psycopg2

# Database connection parameters
db_params = {
'database': 'postgres',
'user': 'postgres', # Update with your username if different
'password': '', # Update with your password if required
'host': 'localhost',
'port': '5432'
}

# Connect to the database
conn = psycopg2.connect(**db_params)
conn.autocommit = True
cur = conn.cursor()

# Execute the SQL commands
cur.execute("""
DROP EXTENSION IF EXISTS lantern;
CREATE EXTENSION IF NOT EXISTS vector;
CREATE EXTENSION IF NOT EXISTS lantern;
CREATE TABLE IF NOT EXISTS small_world_weighted_search (
id VARCHAR(3) PRIMARY KEY,
b BOOLEAN,
v VECTOR(3),
s SPARSEVEC(3)
);
INSERT INTO small_world_weighted_search (id, b, v, s) VALUES
('000', TRUE, '[0,0,0]', '{}/3'),
('001', TRUE, '[0,0,1]', '{3:1}/3'),
('010', FALSE, '[0,1,0]' , '{2:1}/3'),
('011', TRUE, '[0,1,1]', '{2:1,3:1}/3'),
('100', FALSE, '[1,0,0]', '{1:1}/3'),
('101', FALSE, '[1,0,1]', '{1:1,3:1}/3'),
('110', FALSE, '[1,1,0]', '{1:1,2:1}/3'),
('111', TRUE, '[1,1,1]', '{1:1,2:1,3:1}/3')
ON CONFLICT DO NOTHING;
""")

distance_metrics = ["", "cos", "l2sq"]
for distance_metric in distance_metrics:
operator = op = { 'l2sq': '<->', 'cos': '<=>', 'hamming': '<+>' }[distance_metric or 'l2sq']
query_s = "{1:0.4,2:0.3,3:0.2}/3"
query_v = "[-0.5,-0.1,-0.3]"
function = f'weighted_vector_search_{distance_metric}' if distance_metric else 'weighted_vector_search'
query = f"""
SELECT
id,
round(cast(0.9 * (s {operator} '{query_s}'::sparsevec) + 0.1 * (v {operator} '{query_v}'::vector) as numeric), 2) as dist
FROM lantern.{function}(CAST(NULL as "small_world_weighted_search"), distance_operator=>'{operator}',
w1=> 0.9, col1=>'s'::text, vec1=>'{query_s}'::sparsevec,
w2=> 0.1, col2=>'v'::text, vec2=>'{query_v}'::vector
)
LIMIT 3;
"""
cur.execute(query)
res = cur.fetchall()
res = [(key, float(value)) for key, value in res]

expected_results_cos = [('111', 0.22), ('110', 0.24), ('101', 0.39)]
expected_results_l2sq = [('000', 0.54), ('100', 0.78), ('010', 0.87)]
if distance_metric == 'cos':
assert res == expected_results_cos
else:
assert res == expected_results_l2sq
4 changes: 2 additions & 2 deletions sql/lantern.sql
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,8 @@ DECLARE
-- function suffix, function default operator
utility_functions text[2][] := ARRAY[
ARRAY['', '<->'],
ARRAY['_cos', '<->'],
ARRAY['_l2sq', '<=>']
ARRAY['_cos', '<=>'],
ARRAY['_l2sq', '<->']
];
BEGIN
-- Check if the vector type from pgvector exists
Expand Down
4 changes: 2 additions & 2 deletions sql/updates/0.3.2--0.3.3.sql
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ DECLARE
-- function suffix, function default operator
utility_functions text[2][] := ARRAY[
ARRAY['', '<->'],
ARRAY['_cos', '<->'],
ARRAY['_l2sq', '<=>']
ARRAY['_cos', '<=>'],
ARRAY['_l2sq', '<->']
];
BEGIN
-- Check if the vector type from pgvector exists
Expand Down

0 comments on commit 17bc1f5

Please sign in to comment.