Skip to content

Commit

Permalink
fix: missing quote from table name (#118)
Browse files Browse the repository at this point in the history
* fix: missing quote from table name

* fix tests

* Update alloydb_vectorstore.py

* Update alloydb_vectorstore.py

* fix
  • Loading branch information
averikitsch authored Apr 30, 2024
1 parent 4712bae commit b1d80f2
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
4 changes: 2 additions & 2 deletions src/langchain_google_alloydb_pg/alloydb_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ async def __query_collection(
search_function = self.distance_strategy.search_function

filter = f"WHERE {filter}" if filter else ""
stmt = f"SELECT *, {search_function}({self.embedding_column}, '{embedding}') as distance FROM {self.table_name} {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};"
stmt = f"SELECT *, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};"
if self.index_query_options:
await self.engine._aexecute(
f"SET LOCAL {self.index_query_options.to_string()};"
Expand Down Expand Up @@ -742,7 +742,7 @@ async def aapply_vector_index(
params = "WITH " + index.index_options()
function = index.distance_strategy.index_function
name = name or index.name
stmt = f"CREATE INDEX {'CONCURRENTLY' if concurrently else ''} {name} ON {self.table_name} USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};"
stmt = f"CREATE INDEX {'CONCURRENTLY' if concurrently else ''} {name} ON \"{self.table_name}\" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};"
if concurrently:
await self.engine._aexecute_outside_tx(stmt)
else:
Expand Down
36 changes: 18 additions & 18 deletions tests/test_alloydb_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column

DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()).replace("-", "_")
DEFAULT_TABLE = "test_table" + str(uuid.uuid4())
DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4())
CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4())
VECTOR_SIZE = 768

Expand All @@ -41,7 +41,7 @@
def get_env_var(key: str, desc: str) -> str:
v = os.environ.get(key)
if v is None:
raise ValueError(f"Must set env var {key} to: {desc}")
raise ValueError(f'Must set env var "{key} to: "{desc}"')
return v


Expand Down Expand Up @@ -99,7 +99,7 @@ def vs_sync(self, engine_sync):
table_name=DEFAULT_TABLE_SYNC,
)
yield vs
engine_sync._execute(f"DROP TABLE IF EXISTS {DEFAULT_TABLE_SYNC}")
engine_sync._execute(f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_SYNC}"')

engine_sync._engine.dispose()

Expand All @@ -112,7 +112,7 @@ async def vs(self, engine):
table_name=DEFAULT_TABLE,
)
yield vs
await engine._aexecute(f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
await engine._aexecute(f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"')
await engine._engine.dispose()

@pytest_asyncio.fixture(scope="class")
Expand Down Expand Up @@ -155,45 +155,45 @@ async def test_post_init(self, engine):
async def test_aadd_texts(self, engine, vs):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs.aadd_texts(texts, ids=ids)
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 3

ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs.aadd_texts(texts, metadatas, ids)
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 6
await engine._aexecute(f"TRUNCATE TABLE {DEFAULT_TABLE}")
await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"')

async def test_aadd_texts_edge_cases(self, engine, vs):
texts = ["Taylor's", '"Swift"', "best-friend"]
ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs.aadd_texts(texts, ids=ids)
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 3
await engine._aexecute(f"TRUNCATE TABLE {DEFAULT_TABLE}")
await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"')

async def test_aadd_docs(self, engine, vs):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs.aadd_documents(docs, ids=ids)
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 3
await engine._aexecute(f"TRUNCATE TABLE {DEFAULT_TABLE}")
await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"')

async def test_aadd_embedding(self, engine, vs):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs._aadd_embeddings(texts, embeddings, metadatas, ids)
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 3
await engine._aexecute(f"TRUNCATE TABLE {DEFAULT_TABLE}")
await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"')

async def test_adelete(self, engine, vs):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs.aadd_texts(texts, ids=ids)
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 3
# delete an ID
await vs.adelete([ids[0]])
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 2

async def test_aadd_texts_custom(self, engine, vs_custom):
Expand Down Expand Up @@ -256,11 +256,11 @@ async def test_aadd_embedding_custom(self, engine, vs_custom):
def test_add_docs(self, engine_sync, vs_sync):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
vs_sync.add_documents(docs, ids=ids)
results = engine_sync._fetch(f"SELECT * FROM {DEFAULT_TABLE_SYNC}")
results = engine_sync._fetch(f'SELECT * FROM "{DEFAULT_TABLE_SYNC}"')
assert len(results) == 3

def test_add_texts(self, engine_sync, vs_sync):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
vs_sync.add_texts(texts, ids=ids)
results = engine_sync._fetch(f"SELECT * FROM {DEFAULT_TABLE_SYNC}")
results = engine_sync._fetch(f'SELECT * FROM "{DEFAULT_TABLE_SYNC}"')
assert len(results) == 6

0 comments on commit b1d80f2

Please sign in to comment.