From 94911ae50345357493118e9144eb9ad1a721a30e Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Mon, 8 Jan 2024 11:33:36 -0800 Subject: [PATCH] community[patch]: Support different Pinecone initializations depending on the version (#15717) Co-authored-by: DosticJelena --- .../vectorstores/pinecone.py | 41 +++++++++++++++---- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/pinecone.py b/libs/community/langchain_community/vectorstores/pinecone.py index 0ecfb875bf48f..8e22004a01f0e 100644 --- a/libs/community/langchain_community/vectorstores/pinecone.py +++ b/libs/community/langchain_community/vectorstores/pinecone.py @@ -1,15 +1,26 @@ from __future__ import annotations import logging +import os import uuid import warnings -from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + List, + Optional, + Tuple, + Union, +) import numpy as np from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.utils.iter import batch_iterate from langchain_core.vectorstores import VectorStore +from packaging import version from langchain_community.vectorstores.utils import ( DistanceStrategy, @@ -63,10 +74,9 @@ def __init__( "Passing in `embedding` as a Callable is deprecated. Please pass in an" " Embeddings object instead." ) - if not isinstance(index, pinecone.index.Index): + if not isinstance(index, pinecone.Index): raise ValueError( - f"client should be an instance of pinecone.index.Index, " - f"got {type(index)}" + f"client should be an instance of pinecone.Index, " f"got {type(index)}" ) self._index = index self._embedding = embedding @@ -359,11 +369,24 @@ def get_pinecone_index( "Please install it with `pip install pinecone-client`." ) - indexes = pinecone.list_indexes() # checks if provided index exists + pinecone_client_version = pinecone.__version__ - if index_name in indexes: - index = pinecone.Index(index_name, pool_threads=pool_threads) - elif len(indexes) == 0: + if version.parse(pinecone_client_version) >= version.parse("3.0.0.dev"): + pinecone_instance = pinecone.Pinecone( + api_key=os.environ.get("PINECONE_API_KEY"), pool_threads=pool_threads + ) + indexes = pinecone_instance.list_indexes() + index_names = [i.name for i in indexes.index_list["indexes"]] + else: + index_names = pinecone.list_indexes() + + if index_name in index_names: + index = ( + pinecone_instance.Index(index_name) + if version.parse(pinecone_client_version) >= version.parse("3.0.0") + else pinecone.Index(index_name, pool_threads=pool_threads) + ) + elif len(index_names) == 0: raise ValueError( "No active indexes found in your Pinecone project, " "are you sure you're using the right Pinecone API key and Environment? " @@ -372,7 +395,7 @@ def get_pinecone_index( else: raise ValueError( f"Index '{index_name}' not found in your Pinecone project. " - f"Did you mean one of the following indexes: {', '.join(indexes)}" + f"Did you mean one of the following indexes: {', '.join(index_names)}" ) return index