From b344e3a8f6c8fdaa2066ef691cc128618aa5f6bc Mon Sep 17 00:00:00 2001
From: Egil
Date: Wed, 2 Oct 2024 12:04:42 +0200
Subject: [PATCH] Bugfix for sqlite3 operation error in cache

---
 docetl/operations/utils.py | 38 +++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/docetl/operations/utils.py b/docetl/operations/utils.py
index 88038b2c..e7fbd917 100644
--- a/docetl/operations/utils.py
+++ b/docetl/operations/utils.py
@@ -33,7 +33,7 @@
 CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "cache")
 LLM_CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "llm_cache")
 cache = Cache(LLM_CACHE_DIR)
-
+cache.close()
 
 def freezeargs(func):
     """
@@ -96,18 +96,19 @@ def gen_embedding(model: str, input: List[str]) -> List[float]:
     key = hashlib.md5(f"{model}_{input}".encode()).hexdigest()
     input = json.loads(input)
 
-    # Try to get the result from cache
-    result = cache.get(key)
-    if result is None:
-        # If not in cache, compute the embedding
-        if not isinstance(input[0], str):
-            input = [json.dumps(item) for item in input]
+    with cache as c:
+        # Try to get the result from cache
+        result = c.get(key)
+        if result is None:
+            # If not in cache, compute the embedding
+            if not isinstance(input[0], str):
+                input = [json.dumps(item) for item in input]
 
-        input = [item if item else "None" for item in input]
+            input = [item if item else "None" for item in input]
 
-        result = embedding(model=model, input=input)
-        # Cache the result
-        cache.set(key, result)
+            result = embedding(model=model, input=input)
+            # Cache the result
+            c.set(key, result)
 
     return result
 
@@ -265,12 +266,13 @@ def cached_call_llm(
     Returns:
         str: The result from call_llm_with_cache.
     """
-    result = cache.get(cache_key)
-    if result is None:
-        result = call_llm_with_cache(
-            model, op_type, messages, output_schema, tools, scratchpad
-        )
-        cache.set(cache_key, result)
+    with cache as c:
+        result = c.get(cache_key)
+        if result is None:
+            result = call_llm_with_cache(
+                model, op_type, messages, output_schema, tools, scratchpad
+            )
+            c.set(cache_key, result)
 
     return result
 
@@ -411,7 +413,6 @@ def call_llm(
                 # TODO: HITL
                 return {}
 
-
 class InvalidOutputError(Exception):
     """
     Custom exception raised when the LLM output is invalid or cannot be parsed.
@@ -454,7 +455,6 @@ def target():
 
     return decorator
 
-
 def call_llm_with_cache(
     model: str,
     op_type: str,