-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
docs: add samples to migrate pinecone to alloy db #292
base: main
Are you sure you want to change the base?
Changes from 4 commits
ec9a0b5
7d6f68d
0ee4df7
ec7abce
3c38878
7644c9b
b17a017
6336516
4daacc5
ab89d03
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
#!/usr/bin/env python | ||
|
||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
from typing import Any, Optional | ||
|
||
# TODO(dev): Replace the values below
# Required AlloyDB connection settings; all of these must be present in the
# environment or the sample fails fast with a KeyError at import time.
project_id = os.environ["PROJECT_ID"]
region = os.environ["REGION"]
cluster = os.environ["CLUSTER_ID"]
instance = os.environ["INSTANCE_ID"]
db_name = os.environ["DATABASE_ID"]

# TODO(dev): (optional values) Replace the values below
# Optional settings with defaults; empty user/password means IAM-based
# authentication is used by AlloyDBEngine.
db_user = os.environ.get("DB_USER", "")
db_pwd = os.environ.get("DB_PASSWORD", "")
table_name = os.environ.get("TABLE_NAME", "alloy_db_migration_table")
vector_size = int(os.environ.get("VECTOR_SIZE", "768"))
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See note on variables not env vars |
||
|
||
|
||
# [START langchain_alloydb_migration_get_client] | ||
vishwarajanand marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from langchain_google_alloydb_pg import AlloyDBEngine | ||
|
||
|
||
async def aget_alloydb_client(
    project_id: str = project_id,
    region: str = region,
    cluster: str = cluster,
    instance: str = instance,
    database: str = db_name,
    user: Optional[str] = db_user,
    password: Optional[str] = db_pwd,
) -> AlloyDBEngine:
    """Create an AlloyDBEngine connected to the given AlloyDB instance.

    Defaults come from the module-level environment configuration. An empty
    user/password pair lets the engine fall back to IAM authentication.
    """
    connection_settings = dict(
        project_id=project_id,
        region=region,
        cluster=cluster,
        instance=instance,
        database=database,
        user=user,
        password=password,
    )
    alloydb_engine = await AlloyDBEngine.afrom_instance(**connection_settings)

    print("Langchain AlloyDB client initiated.")
    return alloydb_engine
|
||
|
||
# [END langchain_alloydb_migration_get_client] | ||
|
||
# [START langchain_alloydb_migration_fake_embedding_service] | ||
from langchain_core.embeddings import FakeEmbeddings | ||
|
||
|
||
def get_embeddings_service(size: int = vector_size) -> FakeEmbeddings:
    """Return a FakeEmbeddings service producing vectors of the given size.

    Fake embeddings are sufficient here because the migration copies the
    pre-computed vectors verbatim rather than re-embedding any text.
    """
    service = FakeEmbeddings(size=size)
    print("Langchain FakeEmbeddings service initiated.")
    return service
|
||
|
||
# [END langchain_alloydb_migration_fake_embedding_service] | ||
|
||
|
||
# [START langchain_create_alloydb_migration_vector_store_table] | ||
async def ainit_vector_store(
    engine: AlloyDBEngine,
    table_name: str = table_name,
    vector_size: int = vector_size,
    **kwargs: Any,
) -> None:
    """Create the vector store table on the given engine.

    Extra keyword arguments (e.g. ``id_column``, ``overwrite_existing``) are
    forwarded unchanged to ``AlloyDBEngine.ainit_vectorstore_table``.
    """
    await engine.ainit_vectorstore_table(
        table_name=table_name,
        vector_size=vector_size,
        **kwargs,
    )
    print("Langchain AlloyDB vector store table initialized.")
|
||
|
||
# [END langchain_create_alloydb_migration_vector_store_table] | ||
|
||
|
||
# [START langchain_get_alloydb_migration_vector_store] | ||
from langchain_core.embeddings import Embeddings | ||
|
||
from langchain_google_alloydb_pg import AlloyDBVectorStore | ||
|
||
|
||
async def aget_vector_store(
    engine: AlloyDBEngine,
    embeddings_service: Embeddings,
    table_name: str = table_name,
    **kwargs: Any,
) -> AlloyDBVectorStore:
    """Instantiate an AlloyDBVectorStore bound to ``table_name``.

    Extra keyword arguments are forwarded to ``AlloyDBVectorStore.create``.
    """
    store = await AlloyDBVectorStore.create(
        engine=engine,
        embedding_service=embeddings_service,
        table_name=table_name,
        **kwargs,
    )

    print("Langchain AlloyDB vector store instantiated.")
    return store
|
||
|
||
# [END langchain_get_alloydb_migration_vector_store] | ||
|
||
|
||
# [START langchain_alloydb_migration_vector_store_insert_data] | ||
async def ainsert_data(
    vector_store: "AlloyDBVectorStore",
    texts: list[str],
    embeddings: list[list[float]],
    metadatas: list[dict[str, Any]],
    ids: list[str],
) -> list[str]:
    """Insert pre-computed embeddings (with texts/metadata/ids) into the store.

    Returns the list of ids actually inserted, as reported by the store.
    """
    stored_ids = await vector_store.aadd_embeddings(
        texts=texts,
        embeddings=embeddings,
        metadatas=metadatas,
        ids=ids,
    )
    print("AlloyDB client inserted the provided data.")
    return stored_ids
|
||
|
||
# [END langchain_alloydb_migration_vector_store_insert_data] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
#!/usr/bin/env python | ||
|
||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import asyncio | ||
import os | ||
from typing import Any, Iterator | ||
|
||
"""Migrate Pinecone to Langchain AlloyDBVectorStore. | ||
|
||
Given a pinecone index, the following code fetches the data from pinecone | ||
in batches and uploads to an AlloyDBVectorStore. | ||
""" | ||
|
||
# TODO(dev): Replace the values below
# Required: Pinecone credentials (fails fast with KeyError if unset).
pinecone_api_key = os.environ["PINECONE_API_KEY"]

# TODO(dev): (optional values) Replace the values below
# Optional Pinecone source settings; "" selects the default namespace.
pinecone_index_name = os.environ.get("PINECONE_INDEX_NAME", "sample-movies")
pinecone_namespace = os.environ.get("PINECONE_NAMESPACE", "")
pinecone_serverless_cloud = os.environ.get("PINECONE_SERVERLESS_CLOUD", "aws")
pinecone_serverless_region = os.environ.get("PINECONE_SERVERLESS_REGION", "us-east-1")

# Destination table in AlloyDB plus paging/vector-size tuning for the copy.
pinecone_migration_table = os.environ.get(
    "PINECONE_MIGRATION_TABLE", "pinecone_migration"
)
pinecone_batch_size = int(os.environ.get("PINECONE_BATCH_SIZE", "100"))
pinecone_vector_size = int(os.environ.get("PINECONE_VECTOR_SIZE", "1024"))
|
||
|
||
# [START pinecone_get_ids_batch] | ||
from pinecone import Index # type: ignore | ||
|
||
|
||
def get_ids_batch(
    pinecone_index: "Index", namespace: str = "", batch_size: int = 100
) -> Iterator[list[str]]:
    """Yield batches of vector ids from a Pinecone index.

    Args:
        pinecone_index: Pinecone ``Index`` to list ids from.
        namespace: Pinecone namespace to list ("" is the default namespace).
        batch_size: Maximum number of ids per yielded batch.

    Yields:
        Lists of vector id strings, one list per result page.
    """
    results = pinecone_index.list_paginated(
        prefix="", namespace=namespace, limit=batch_size
    )
    yield [v.id for v in results.vectors]

    while results.pagination is not None:
        # Bug fix: forward the namespace on paginated follow-up calls too.
        # Without it, pages after the first are listed from the default
        # namespace instead of the requested one.
        results = pinecone_index.list_paginated(
            prefix="",
            namespace=namespace,
            pagination_token=results.pagination.next,
            limit=batch_size,
        )
        # Extract and yield the next batch of IDs
        yield [v.id for v in results.vectors]
    print("Pinecone client fetched all ids from index.")
|
||
|
||
# [END pinecone_get_ids_batch] | ||
|
||
|
||
# [START pinecone_get_data_batch] | ||
from pinecone import Index # type: ignore | ||
|
||
|
||
def get_data_batch(
    pinecone_index: "Index", namespace: str = "", batch_size: int = 100
) -> Iterator[tuple[list[str], list[Any], list[str], list[Any]]]:
    """Yield batches of (ids, embeddings, contents, metadatas) from Pinecone.

    Args:
        pinecone_index: Pinecone ``Index`` to read vectors from.
        namespace: Pinecone namespace to read ("" is the default namespace).
        batch_size: Maximum number of vectors per yielded batch.

    Yields:
        Tuples of (ids, embeddings, contents, metadatas), one per id batch.
    """
    # Iterate through the batches of IDs and process them
    for id_batch in get_ids_batch(pinecone_index, namespace, batch_size):
        # Bug fix: fetch from the same namespace the ids were listed from.
        # Without namespace=, fetch() reads the default namespace only and
        # would return nothing for data stored in a named namespace.
        all_data = pinecone_index.fetch(ids=id_batch, namespace=namespace)

        # Fresh accumulators per batch; do not shadow the incoming id list.
        ids: list[str] = []
        embeddings: list[Any] = []
        contents: list[str] = []
        metadatas: list[Any] = []

        # Process each vector in the current batch.  The stringified
        # metadata doubles as the document content for the vector store.
        for doc in all_data["vectors"].values():
            ids.append(doc["id"])
            embeddings.append(doc["values"])
            contents.append(str(doc["metadata"]))
            metadatas.append(doc["metadata"])

        # Yield the current batch of results
        yield ids, embeddings, contents, metadatas
    print("Pinecone client fetched all data from index.")
|
||
|
||
# [END pinecone_get_data_batch] | ||
|
||
|
||
async def main() -> None:
    """Migrate all vectors from a Pinecone index into an AlloyDB vector store.

    Fetches ids and vectors from Pinecone in batches and inserts them,
    batch by batch, into a freshly (re)created AlloyDB vector store table.
    """
    # [START pinecone_get_client]
    from pinecone import Pinecone, ServerlessSpec  # type: ignore

    # NOTE(review): ServerlessSpec is normally passed to create_index, not the
    # client constructor — verify against the Pinecone SDK version in use.
    pinecone_client = Pinecone(
        api_key=pinecone_api_key,
        spec=ServerlessSpec(
            cloud=pinecone_serverless_cloud, region=pinecone_serverless_region
        ),
    )
    print("Pinecone client initiated.")
    # [END pinecone_get_client]

    # [START pinecone_get_index]
    pinecone_index = pinecone_client.Index(pinecone_index_name)
    print("Pinecone index reference initiated.")
    # [END pinecone_get_index]

    from alloydb_snippets import aget_alloydb_client

    alloydb_engine = await aget_alloydb_client()

    # [START pinecone_alloydb_migration_get_alloydb_vectorstore]
    from alloydb_snippets import aget_vector_store, get_embeddings_service

    from langchain_google_alloydb_pg import Column

    # Note that the vector size and id_column name/type are configurable.
    # We need to customize the vector store table because the sample data has
    # 1024 vectors and integer like id values (not UUIDs).
    await alloydb_engine.ainit_vectorstore_table(
        table_name=pinecone_migration_table,
        vector_size=pinecone_vector_size,
        id_column=Column("langchain_id", "text", nullable=False),
        overwrite_existing=True,
    )
    print("Pinecone migration AlloyDBVectorStore table created.")

    embeddings_service = get_embeddings_service(pinecone_vector_size)
    vs = await aget_vector_store(
        engine=alloydb_engine,
        embeddings_service=embeddings_service,
        table_name=pinecone_migration_table,
    )
    # [END pinecone_alloydb_migration_get_alloydb_vectorstore]

    # [START pinecone_alloydb_migration_insert_data_batch]
    for ids, embeddings, contents, metadatas in get_data_batch(
        pinecone_index=pinecone_index,
        namespace=pinecone_namespace,
        batch_size=pinecone_batch_size,
    ):
        # Fix: the returned inserted ids were previously bound to an unused
        # variable; discard them explicitly.
        await vs.aadd_embeddings(
            texts=contents,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids,
        )

    print("Migration completed, inserted all the batches of data to AlloyDB.")
    # [END pinecone_alloydb_migration_insert_data_batch]


if __name__ == "__main__":
    asyncio.run(main())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I have been adding all sample reqs to https://github.com/googleapis/langchain-google-alloydb-pg-python/blob/main/samples/requirements.txt so this file doesn't need to be updated. I am also ok with this pattern of adding the new req file to the workflow
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tried to follow this snippet in the current version of the snippets