diff --git a/docs/docs/integrations/llms/llm_caching.ipynb b/docs/docs/integrations/llms/llm_caching.ipynb index 3e1907691f324..bcff9775a03e9 100644 --- a/docs/docs/integrations/llms/llm_caching.ipynb +++ b/docs/docs/integrations/llms/llm_caching.ipynb @@ -912,7 +912,7 @@ "source": [ "## `Cassandra` caches\n", "\n", - "You can use Cassandra / Astra DB for caching LLM responses, choosing from the exact-match `CassandraCache` or the (vector-similarity-based) `CassandraSemanticCache`.\n", + "You can use Cassandra / Astra DB through CQL for caching LLM responses, choosing from the exact-match `CassandraCache` or the (vector-similarity-based) `CassandraSemanticCache`.\n", "\n", "Let's see both in action in the following cells." ] @@ -924,7 +924,7 @@ "source": [ "#### Connect to the DB\n", "\n", - "First you need to establish a `Session` to the DB and to specify a _keyspace_ for the cache table(s). The following gets you started with an Astra DB instance (see e.g. [here](https://cassio.org/start_here/#vector-database) for more backends and connection options)." + "First you need to establish a `Session` to the DB and to specify a _keyspace_ for the cache table(s). The following gets you connected to Astra DB through CQL (see e.g. [here](https://cassio.org/start_here/#vector-database) for more backends and connection options)." 
] }, { @@ -1132,6 +1132,214 @@ "print(llm(\"How come we always see one face of the moon?\"))" ] }, + { + "cell_type": "markdown", + "id": "8712f8fc-bb89-4164-beb9-c672778bbd91", + "metadata": {}, + "source": [ + "## `Astra DB` Caches" + ] + }, + { + "cell_type": "markdown", + "id": "173041d9-e4af-4f68-8461-d302bfc7e1bd", + "metadata": {}, + "source": [ + "You can easily use [Astra DB](https://docs.datastax.com/en/astra/home/astra.html) as an LLM cache, with either the \"exact\" or the \"semantic-based\" cache.\n", + "\n", + "Make sure you have a running database (it must be a Vector-enabled database to use the Semantic cache) and get the required credentials on your Astra dashboard:\n", + "\n", + "- the API Endpoint looks like `https://01234567-89ab-cdef-0123-456789abcdef-us-east1.apps.astra.datastax.com`\n", + "- the Token looks like `AstraCS:6gBhNmsk135....`" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "feb510b6-99a3-4228-8e11-563051f8178e", + "metadata": {}, + "outputs": [ + { + "name": "stdin", + "output_type": "stream", + "text": [ + "ASTRA_DB_API_ENDPOINT = https://01234567-89ab-cdef-0123-456789abcdef-us-east1.apps.astra.datastax.com\n", + "ASTRA_DB_APPLICATION_TOKEN = ········\n" + ] + } + ], + "source": [ + "import getpass\n", + "\n", + "ASTRA_DB_API_ENDPOINT = input(\"ASTRA_DB_API_ENDPOINT = \")\n", + "ASTRA_DB_APPLICATION_TOKEN = getpass.getpass(\"ASTRA_DB_APPLICATION_TOKEN = \")" + ] + }, + { + "cell_type": "markdown", + "id": "ee6d587f-4b7c-43f4-9e90-5129c842a143", + "metadata": {}, + "source": [ + "### Astra DB exact LLM cache\n", + "\n", + "This will avoid invoking the LLM when the supplied prompt is _exactly_ the same as one encountered already:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ad63c146-ee41-4896-90ee-29fcc39f0ed5", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.cache import AstraDBCache\n", + "from langchain.globals import set_llm_cache\n", + "\n", + "set_llm_cache(\n", 
+ " AstraDBCache(\n", + " api_endpoint=ASTRA_DB_API_ENDPOINT,\n", + " token=ASTRA_DB_APPLICATION_TOKEN,\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "83e0fb02-e8eb-4483-9eb1-55b5e14c4487", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "There is no definitive answer to this question as it depends on the interpretation of the terms \"true fakery\" and \"fake truth\". However, one possible interpretation is that a true fakery is a counterfeit or imitation that is intended to deceive, whereas a fake truth is a false statement that is presented as if it were true.\n", + "CPU times: user 70.8 ms, sys: 4.13 ms, total: 74.9 ms\n", + "Wall time: 2.06 s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "print(llm(\"Is a true fakery the same as a fake truth?\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "4d20d498-fe28-4e26-8531-2b31c52ee687", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "There is no definitive answer to this question as it depends on the interpretation of the terms \"true fakery\" and \"fake truth\". However, one possible interpretation is that a true fakery is a counterfeit or imitation that is intended to deceive, whereas a fake truth is a false statement that is presented as if it were true.\n", + "CPU times: user 15.1 ms, sys: 3.7 ms, total: 18.8 ms\n", + "Wall time: 531 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "print(llm(\"Is a true fakery the same as a fake truth?\"))" + ] + }, + { + "cell_type": "markdown", + "id": "524b94fa-6162-4880-884d-d008749d14e2", + "metadata": {}, + "source": [ + "### Astra DB Semantic cache\n", + "\n", + "This cache will do a semantic similarity search and return a hit if it finds a cached entry that is similar enough. For this, you need to provide an `Embeddings` instance of your choice."
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "dc329c55-1cc4-4b74-94f9-61f8990fb214", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings import OpenAIEmbeddings\n", + "\n", + "embedding = OpenAIEmbeddings()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "83952a90-ab14-4e59-87c0-d2bdc1d43e43", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.cache import AstraDBSemanticCache\n", + "\n", + "set_llm_cache(\n", + " AstraDBSemanticCache(\n", + " api_endpoint=ASTRA_DB_API_ENDPOINT,\n", + " token=ASTRA_DB_APPLICATION_TOKEN,\n", + " embedding=embedding,\n", + " collection_name=\"demo_semantic_cache\",\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "d74b249a-94d5-42d0-af74-f7565a994dea", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "There is no definitive answer to this question since it presupposes a great deal about the nature of truth itself, which is a matter of considerable philosophical debate. It is possible, however, to construct scenarios in which something could be considered true despite being false, such as if someone sincerely believes something to be true even though it is not.\n", + "CPU times: user 65.6 ms, sys: 15.3 ms, total: 80.9 ms\n", + "Wall time: 2.72 s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "print(llm(\"Are there truths that are false?\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "11973d73-d2f4-46bd-b229-1c589df9b788", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "There is no definitive answer to this question since it presupposes a great deal about the nature of truth itself, which is a matter of considerable philosophical debate. 
It is possible, however, to construct scenarios in which something could be considered true despite being false, such as if someone sincerely believes something to be true even though it is not.\n", + "CPU times: user 29.3 ms, sys: 6.21 ms, total: 35.5 ms\n", + "Wall time: 1.03 s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "print(llm(\"Is is possible that something false can be also true?\"))" + ] + }, { "cell_type": "markdown", "id": "0c69d84d", diff --git a/docs/docs/integrations/providers/astradb.mdx b/docs/docs/integrations/providers/astradb.mdx index 6fcb1fc605023..fc093aad3f2cc 100644 --- a/docs/docs/integrations/providers/astradb.mdx +++ b/docs/docs/integrations/providers/astradb.mdx @@ -29,8 +29,35 @@ vector_store = AstraDB( Learn more in the [example notebook](/docs/integrations/vectorstores/astradb). +### LLM Cache -### Memory +```python +from langchain.globals import set_llm_cache +from langchain.cache import AstraDBCache +set_llm_cache(AstraDBCache( + api_endpoint="...", + token="...", +)) +``` + +Learn more in the [example notebook](/docs/integrations/llms/llm_caching) (scroll to the Astra DB section). + + +### Semantic LLM Cache + +```python +from langchain.globals import set_llm_cache +from langchain.cache import AstraDBSemanticCache +set_llm_cache(AstraDBSemanticCache( + embedding=my_embedding, + api_endpoint="...", + token="...", +)) +``` + +Learn more in the [example notebook](/docs/integrations/llms/llm_caching) (scroll to the appropriate section). 
+ +### Chat message history ```python from langchain.memory import AstraDBChatMessageHistory diff --git a/libs/langchain/langchain/cache.py b/libs/langchain/langchain/cache.py index 33bea55247d2c..ba15a24a21f19 100644 --- a/libs/langchain/langchain/cache.py +++ b/libs/langchain/langchain/cache.py @@ -1239,3 +1239,318 @@ def clear(self, **kwargs: Any) -> None: @staticmethod def get_md5(input_string: str) -> str: return hashlib.md5(input_string.encode()).hexdigest() + + +ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_cache" + + +class AstraDBCache(BaseCache): + """ + Cache that uses Astra DB as a backend. + + It uses a single collection as a kv store + The lookup keys, combined in the _id of the documents, are: + - prompt, a string + - llm_string, a deterministic str representation of the model parameters. + (needed to prevent same-prompt-different-model collisions) + """ + + def __init__( + self, + *, + collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed + namespace: Optional[str] = None, + ): + """ + Create an AstraDB cache using a collection for storage. + + Args (only keyword-arguments accepted): + collection_name (str): name of the Astra DB collection to create/use. + token (Optional[str]): API token for Astra DB usage. + api_endpoint (Optional[str]): full URL to the API endpoint, + such as "https://-us-east1.apps.astra.datastax.com". + astra_db_client (Optional[Any]): *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AstraDB' instance. + namespace (Optional[str]): namespace (aka keyspace) where the + collection is created. Defaults to the database's "default namespace". + """ + try: + from astrapy.db import ( + AstraDB as LibAstraDB, + ) + except (ImportError, ModuleNotFoundError): + raise ImportError( + "Could not import a recent astrapy python package. 
" + "Please install it with `pip install --upgrade astrapy`." + ) + # Conflicting-arg checks: + if astra_db_client is not None: + if token is not None or api_endpoint is not None: + raise ValueError( + "You cannot pass 'astra_db_client' to AstraDB if passing " + "'token' and 'api_endpoint'." + ) + + self.collection_name = collection_name + self.token = token + self.api_endpoint = api_endpoint + self.namespace = namespace + + if astra_db_client is not None: + self.astra_db = astra_db_client + else: + self.astra_db = LibAstraDB( + token=self.token, + api_endpoint=self.api_endpoint, + namespace=self.namespace, + ) + self.collection = self.astra_db.create_collection( + collection_name=self.collection_name, + ) + + @staticmethod + def _make_id(prompt: str, llm_string: str) -> str: + return f"{_hash(prompt)}#{_hash(llm_string)}" + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + doc_id = self._make_id(prompt, llm_string) + item = self.collection.find_one( + filter={ + "_id": doc_id, + }, + projection={ + "body_blob": 1, + }, + )["data"]["document"] + if item is not None: + generations = _loads_generations(item["body_blob"]) + # this protects against malformed cached items: + if generations is not None: + return generations + else: + return None + else: + return None + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + doc_id = self._make_id(prompt, llm_string) + blob = _dumps_generations(return_val) + self.collection.upsert( + { + "_id": doc_id, + "body_blob": blob, + }, + ) + + def delete_through_llm( + self, prompt: str, llm: LLM, stop: Optional[List[str]] = None + ) -> None: + """ + A wrapper around `delete` with the LLM being passed. 
+ In case the llm(prompt) calls have a `stop` param, you should pass it here + """ + llm_string = get_prompts( + {**llm.dict(), **{"stop": stop}}, + [], + )[1] + return self.delete(prompt, llm_string=llm_string) + + def delete(self, prompt: str, llm_string: str) -> None: + """Evict from cache if there's an entry.""" + doc_id = self._make_id(prompt, llm_string) + return self.collection.delete_one(doc_id) + + def clear(self, **kwargs: Any) -> None: + """Clear cache. This is for all LLMs at once.""" + self.astra_db.truncate_collection(self.collection_name) + + +ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85 +ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache" +ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16 + + +class AstraDBSemanticCache(BaseCache): + """ + Cache that uses Astra DB as a vector-store backend for semantic + (i.e. similarity-based) lookup. + + It uses a single (vector) collection and can store + cached values from several LLMs, so the LLM's 'llm_string' is stored + in the document metadata. + + You can choose the preferred similarity (or use the API default) -- + remember the threshold might require metric-dependent tuning. + """ + + def __init__( + self, + *, + collection_name: str = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed + namespace: Optional[str] = None, + embedding: Embeddings, + metric: Optional[str] = None, + similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD, + ): + """ + Initialize the cache with all relevant parameters. + Args: + + collection_name (str): name of the Astra DB collection to create/use. + token (Optional[str]): API token for Astra DB usage. + api_endpoint (Optional[str]): full URL to the API endpoint, + such as "https://-us-east1.apps.astra.datastax.com".
+ astra_db_client (Optional[Any]): *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AstraDB' instance. + namespace (Optional[str]): namespace (aka keyspace) where the + collection is created. Defaults to the database's "default namespace". + embedding (Embeddings): Embedding provider for semantic + encoding and search. + metric: the function to use for evaluating similarity of text embeddings. + Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product') + similarity_threshold (float, optional): the minimum similarity + for accepting a (semantic-search) match. + + The default score threshold is tuned to the default metric. + Tune it carefully yourself if switching to another distance metric. + """ + try: + from astrapy.db import ( + AstraDB as LibAstraDB, + ) + except (ImportError, ModuleNotFoundError): + raise ImportError( + "Could not import a recent astrapy python package. " + "Please install it with `pip install --upgrade astrapy`." + ) + # Conflicting-arg checks: + if astra_db_client is not None: + if token is not None or api_endpoint is not None: + raise ValueError( + "You cannot pass 'astra_db_client' to AstraDB if passing " + "'token' and 'api_endpoint'." + ) + + self.embedding = embedding + self.metric = metric + self.similarity_threshold = similarity_threshold + + # The contract for this class has separate lookup and update: + # in order to spare some embedding calculations we cache them between + # the two calls. + # Note: each instance of this class has its own `_get_embedding` with + # its own lru.
+ @lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE) + def _cache_embedding(text: str) -> List[float]: + return self.embedding.embed_query(text=text) + + self._get_embedding = _cache_embedding + self.embedding_dimension = self._get_embedding_dimension() + + self.collection_name = collection_name + self.token = token + self.api_endpoint = api_endpoint + self.namespace = namespace + + if astra_db_client is not None: + self.astra_db = astra_db_client + else: + self.astra_db = LibAstraDB( + token=self.token, + api_endpoint=self.api_endpoint, + namespace=self.namespace, + ) + self.collection = self.astra_db.create_collection( + collection_name=self.collection_name, + dimension=self.embedding_dimension, + metric=self.metric, + ) + + def _get_embedding_dimension(self) -> int: + return len(self._get_embedding(text="This is a sample sentence.")) + + @staticmethod + def _make_id(prompt: str, llm_string: str) -> str: + return f"{_hash(prompt)}#{_hash(llm_string)}" + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + doc_id = self._make_id(prompt, llm_string) + llm_string_hash = _hash(llm_string) + embedding_vector = self._get_embedding(text=prompt) + body = _dumps_generations(return_val) + # + self.collection.upsert( + { + "_id": doc_id, + "body_blob": body, + "llm_string_hash": llm_string_hash, + "$vector": embedding_vector, + } + ) + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + hit_with_id = self.lookup_with_id(prompt, llm_string) + if hit_with_id is not None: + return hit_with_id[1] + else: + return None + + def lookup_with_id( + self, prompt: str, llm_string: str + ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: + """ + Look up based on prompt and llm_string. 
+ If there are hits, return (document_id, cached_entry) for the top hit + """ + prompt_embedding: List[float] = self._get_embedding(text=prompt) + llm_string_hash = _hash(llm_string) + + hit = self.collection.vector_find_one( + vector=prompt_embedding, + filter={ + "llm_string_hash": llm_string_hash, + }, + fields=["body_blob", "_id"], + include_similarity=True, + ) + + if hit is None or hit["$similarity"] < self.similarity_threshold: + return None + else: + generations = _loads_generations(hit["body_blob"]) + if generations is not None: + # this protects against malformed cached items: + return (hit["_id"], generations) + else: + return None + + def lookup_with_id_through_llm( + self, prompt: str, llm: LLM, stop: Optional[List[str]] = None + ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: + llm_string = get_prompts( + {**llm.dict(), **{"stop": stop}}, + [], + )[1] + return self.lookup_with_id(prompt, llm_string=llm_string) + + def delete_by_document_id(self, document_id: str) -> None: + """ + Given this is a "similarity search" cache, an invalidation pattern + that makes sense is first a lookup to get an ID, and then deleting + with that ID. This is for the second step. + """ + self.collection.delete_one(document_id) + + def clear(self, **kwargs: Any) -> None: + """Clear the *whole* semantic cache.""" + self.astra_db.truncate_collection(self.collection_name) diff --git a/libs/langchain/tests/integration_tests/cache/test_astradb.py b/libs/langchain/tests/integration_tests/cache/test_astradb.py new file mode 100644 index 0000000000000..d17a631020ee8 --- /dev/null +++ b/libs/langchain/tests/integration_tests/cache/test_astradb.py @@ -0,0 +1,99 @@ +""" +Test AstraDB caches. Requires an Astra DB vector instance. 
+ +Required to run this test: + - a recent `astrapy` Python package available + - an Astra DB instance; + - the two environment variables set: + export ASTRA_DB_API_ENDPOINT="https://-us-east1.apps.astra.datastax.com" + export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........." + - optionally this as well (otherwise defaults are used): + export ASTRA_DB_KEYSPACE="my_keyspace" +""" +import os +from typing import Iterator + +import pytest +from langchain_core.outputs import Generation, LLMResult + +from langchain.cache import AstraDBCache, AstraDBSemanticCache +from langchain.globals import get_llm_cache, set_llm_cache +from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings +from tests.unit_tests.llms.fake_llm import FakeLLM + + +def _has_env_vars() -> bool: + return all( + [ + "ASTRA_DB_APPLICATION_TOKEN" in os.environ, + "ASTRA_DB_API_ENDPOINT" in os.environ, + ] + ) + + +@pytest.fixture(scope="module") +def astradb_cache() -> Iterator[AstraDBCache]: + cache = AstraDBCache( + collection_name="lc_integration_test_cache", + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + ) + yield cache + cache.astra_db.delete_collection("lc_integration_test_cache") + + +@pytest.fixture(scope="module") +def astradb_semantic_cache() -> Iterator[AstraDBSemanticCache]: + fake_embe = FakeEmbeddings() + sem_cache = AstraDBSemanticCache( + collection_name="lc_integration_test_sem_cache", + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + embedding=fake_embe, + ) + yield sem_cache + sem_cache.astra_db.delete_collection("lc_integration_test_sem_cache") + + +@pytest.mark.requires("astrapy") +@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env.
vars") +class TestAstraDBCaches: + def test_astradb_cache(self, astradb_cache: AstraDBCache) -> None: + set_llm_cache(astradb_cache) + llm = FakeLLM() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) + output = llm.generate(["foo"]) + print(output) + expected_output = LLMResult( + generations=[[Generation(text="fizz")]], + llm_output={}, + ) + print(expected_output) + assert output == expected_output + astradb_cache.clear() + + def test_astradb_semantic_cache( + self, astradb_semantic_cache: AstraDBSemanticCache + ) -> None: + set_llm_cache(astradb_semantic_cache) + llm = FakeLLM() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) + output = llm.generate(["bar"]) # same embedding as 'foo' + expected_output = LLMResult( + generations=[[Generation(text="fizz")]], + llm_output={}, + ) + assert output == expected_output + # clear the cache + astradb_semantic_cache.clear() + output = llm.generate(["bar"]) # 'fizz' is erased away now + assert output != expected_output + astradb_semantic_cache.clear()