Skip to content

Commit

Permalink
Bugfix for sqlite3 OperationalError in cache
Browse files Browse the repository at this point in the history
  • Loading branch information
Egil committed Oct 2, 2024
1 parent b79c889 commit b344e3a
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions docetl/operations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# On-disk cache locations, rooted under the DocETL home directory.
CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "cache")
LLM_CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "llm_cache")
# Module-level diskcache handle shared by the helpers below.
cache = Cache(LLM_CACHE_DIR)

# Close the sqlite-backed connection immediately after construction;
# callers re-open it per operation via `with cache as c:` instead of
# sharing one long-lived connection. NOTE(review): presumably this is
# the fix for the sqlite3 OperationalError named in the commit message
# (connection shared across threads) — confirm against diskcache docs.
cache.close()

def freezeargs(func):
"""
def gen_embedding(model: str, input: List[str]) -> List[float]:
    """Compute (or fetch from cache) an embedding for the given input.

    The cache key is derived from the *serialized* input before it is
    parsed, so identical (model, input) pairs hit the same entry.

    Args:
        model: Name of the embedding model to use.
        input: JSON-serialized list of items to embed (deserialized below;
            presumably serialized upstream by a freezeargs-style wrapper —
            TODO confirm against the decorator above this function).

    Returns:
        List[float]: The embedding result (as returned by `embedding`).
    """
    # Hash the raw serialized form so the key is stable and hashable.
    key = hashlib.md5(f"{model}_{input}".encode()).hexdigest()
    input = json.loads(input)

    # Open the cache per call (context manager) rather than keeping a
    # long-lived sqlite connection shared across threads.
    with cache as c:
        # Try to get the result from cache
        result = c.get(key)
        if result is None:
            # If not in cache, compute the embedding.
            # Guard on `input` first: the original `input[0]` check raised
            # IndexError for an empty list.
            if input and not isinstance(input[0], str):
                input = [json.dumps(item) for item in input]

            # Empty/falsy items would be rejected by the embedding API.
            input = [item if item else "None" for item in input]

            result = embedding(model=model, input=input)
            # Cache the result
            c.set(key, result)

    return result

Expand Down Expand Up @@ -265,12 +266,13 @@ def cached_call_llm(
Returns:
str: The result from call_llm_with_cache.
"""
result = cache.get(cache_key)
if result is None:
result = call_llm_with_cache(
model, op_type, messages, output_schema, tools, scratchpad
)
cache.set(cache_key, result)
with cache as c:
result = c.get(cache_key)
if result is None:
result = call_llm_with_cache(
model, op_type, messages, output_schema, tools, scratchpad
)
c.set(cache_key, result)
return result


Expand Down Expand Up @@ -411,7 +413,6 @@ def call_llm(
# TODO: HITL
return {}


class InvalidOutputError(Exception):
"""
Custom exception raised when the LLM output is invalid or cannot be parsed.
Expand Down Expand Up @@ -454,7 +455,6 @@ def target():

return decorator


def call_llm_with_cache(
model: str,
op_type: str,
Expand Down

0 comments on commit b344e3a

Please sign in to comment.