Merge pull request #22 from ucbepic/shreyashankar/blockingerr
fix: remove openai client call in `utils.py`
shreyashankar authored Sep 29, 2024
2 parents f65d71b + 691a58e commit 46645a7
Showing 5 changed files with 123 additions and 12 deletions.
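The failure this fixes (as the new tests/test_ollama.py below suggests by stripping OPENAI_API_KEY): docetl/operations/utils.py constructed an OpenAI client at module import time, so merely importing docetl raised for anyone without an OpenAI key configured — including users running everything against local models. A minimal reproduction of that failure mode, assuming openai>=1.0 (illustrative, not part of this diff):

import os

os.environ.pop("OPENAI_API_KEY", None)  # simulate a user with no OpenAI key

from openai import OpenAI

try:
    OpenAI()  # what the old module-level `client = OpenAI()` ran on every import
except Exception as exc:
    print(f"import-time client construction fails: {exc}")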
3 changes: 2 additions & 1 deletion docetl/operations/equijoin.py
@@ -15,7 +15,6 @@
 from jinja2 import Template
 from litellm import embedding, model_cost
 from docetl.utils import completion_cost
-from sklearn.metrics.pairwise import cosine_similarity
 
 from docetl.operations.base import BaseOperation
 from docetl.operations.utils import (
@@ -279,6 +278,8 @@ def get_embeddings(
         )
 
         # Compute all cosine similarities in one call
+        from sklearn.metrics.pairwise import cosine_similarity
+
         similarities = cosine_similarity(left_embeddings, right_embeddings)
 
         # Additional blocking based on embeddings
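The pattern here, repeated in reduce.py and resolve.py below, moves the scikit-learn import from module scope to the point of first use, so importing the operation no longer pulls in scikit-learn unless embedding-based blocking actually runs. A minimal sketch of the same deferred-import idiom, with illustrative names:

import numpy as np

def blocked_pairs(left_embeddings: np.ndarray, right_embeddings: np.ndarray, threshold: float) -> np.ndarray:
    # Import inside the function: callers that never reach this code
    # path never pay the import cost (or need the dependency at all).
    from sklearn.metrics.pairwise import cosine_similarity

    # One vectorized call computes the full (n_left, n_right) matrix.
    similarities = cosine_similarity(left_embeddings, right_embeddings)
    return np.argwhere(similarities >= threshold)  # candidate (i, j) index pairs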
6 changes: 4 additions & 2 deletions docetl/operations/reduce.py
@@ -19,8 +19,6 @@
 from jinja2 import Template
 from docetl.utils import completion_cost
 from litellm import embedding
-from sklearn.cluster import KMeans
-from sklearn.metrics.pairwise import cosine_similarity
 
 from docetl.operations.base import BaseOperation
 from docetl.operations.utils import (
@@ -392,6 +390,8 @@ def _cluster_based_sampling(
     ) -> Tuple[List[Dict], float]:
         embeddings, cost = self._get_embeddings(group_list, value_sampling)
 
+        from sklearn.cluster import KMeans
+
         kmeans = KMeans(n_clusters=sample_size, random_state=42)
         cluster_labels = kmeans.fit_predict(embeddings)
 
@@ -420,6 +420,8 @@ def _semantic_similarity_sampling(
         query_embedding = query_response["data"][0]["embedding"]
         cost += completion_cost(query_response)
 
+        from sklearn.metrics.pairwise import cosine_similarity
+
         similarities = cosine_similarity([query_embedding], embeddings)[0]
 
         top_k_indices = np.argsort(similarities)[-sample_size:]
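For context on the second hunk: _cluster_based_sampling k-means-clusters the group's embeddings and samples across clusters, so a capped sample still spans the whole group. A rough sketch of that idea under illustrative names (docetl's actual selection logic may differ):

import numpy as np
from sklearn.cluster import KMeans

def cluster_sample(items: list, embeddings: np.ndarray, sample_size: int) -> list:
    # One cluster per desired sample point; fixed seed for reproducibility.
    kmeans = KMeans(n_clusters=sample_size, random_state=42)
    labels = kmeans.fit_predict(embeddings)
    sample = []
    for cluster in range(sample_size):
        members = np.where(labels == cluster)[0]
        if members.size:  # guard against an empty cluster
            sample.append(items[members[0]])  # take one representative
    return sample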
3 changes: 2 additions & 1 deletion docetl/operations/resolve.py
@@ -12,7 +12,6 @@
 from jinja2 import Template
 from docetl.utils import completion_cost
 from litellm import embedding
-from sklearn.metrics.pairwise import cosine_similarity
 
 from docetl.operations.base import BaseOperation
 from docetl.operations.utils import (
@@ -311,6 +310,8 @@ def meets_blocking_conditions(pair):
         )
         if remaining_comparisons > 0 and blocking_threshold is not None:
             # Compute cosine similarity for all pairs efficiently
+            from sklearn.metrics.pairwise import cosine_similarity
+
             similarity_matrix = cosine_similarity(embeddings)
 
             cosine_pairs = []
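The resolve hunk scores all pairs with a single matrix operation and then thresholds, rather than comparing embeddings one pair at a time in Python. A sketch of that blocking step (names illustrative):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def pairs_above_threshold(embeddings: np.ndarray, blocking_threshold: float) -> list:
    similarity_matrix = cosine_similarity(embeddings)  # full n-by-n matrix in one call
    n = similarity_matrix.shape[0]
    # Scan only the upper triangle so each unordered pair appears once.
    return [
        (i, j, float(similarity_matrix[i, j]))
        for i in range(n)
        for j in range(i + 1, n)
        if similarity_matrix[i, j] >= blocking_threshold
    ]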
13 changes: 5 additions & 8 deletions docetl/operations/utils.py
@@ -6,7 +6,6 @@
 import threading
 from concurrent.futures import as_completed
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
-from openai import OpenAI
 
 from dotenv import load_dotenv
 from frozendict import frozendict
@@ -30,8 +29,6 @@
 LLM_CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "llm_cache")
 cache = Cache(LLM_CACHE_DIR)
 
-client = OpenAI()
-
 
 def freezeargs(func):
     """
@@ -792,16 +789,16 @@ def parse_llm_response(
     if "tool_calls" in dir(response.choices[0].message):
         # Default behavior for write_output function
        tool_calls = response.choices[0].message.tool_calls
 
         if not tool_calls:
             raise ValueError("No tool calls found in response")
 
         outputs = []
         for tool_call in tool_calls:
-            if tool_call.function.name == "write_output":
-                try:
-                    outputs.append(json.loads(tool_call.function.arguments))
-                except json.JSONDecodeError:
-                    return [{}]
+            try:
+                outputs.append(json.loads(tool_call.function.arguments))
+            except json.JSONDecodeError:
+                return [{}]
         return outputs
+
     else:
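The parse_llm_response hunk drops the write_output name filter: every tool call's arguments are now parsed, and an empty tool_calls list fails loudly rather than silently yielding no outputs. A self-contained mirror of the displayed logic, with a stand-in for litellm's tool-call objects:

import json
from types import SimpleNamespace

def parse_tool_calls(tool_calls) -> list:
    if not tool_calls:
        raise ValueError("No tool calls found in response")
    outputs = []
    for tool_call in tool_calls:
        try:
            outputs.append(json.loads(tool_call.function.arguments))
        except json.JSONDecodeError:
            return [{}]  # malformed arguments degrade to a single empty output
    return outputs

# A well-formed call parses; an empty list now raises.
call = SimpleNamespace(function=SimpleNamespace(arguments='{"sentiment": "positive"}'))
print(parse_tool_calls([call]))  # [{'sentiment': 'positive'}]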
110 changes: 110 additions & 0 deletions tests/test_ollama.py
@@ -0,0 +1,110 @@
import shutil
import pytest
import json
import tempfile
import os
from docetl.api import (
    Pipeline,
    Dataset,
    MapOp,
    ReduceOp,
    PipelineStep,
    PipelineOutput,
)
from dotenv import load_dotenv

load_dotenv()


@pytest.fixture
def temp_input_file():
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tmp:
        json.dump(
            [
                {"text": "This is a test", "group": "A"},
                {"text": "Another test", "group": "B"},
            ],
            tmp,
        )
    yield tmp.name
    os.unlink(tmp.name)


@pytest.fixture
def temp_output_file():
    with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp:
        pass
    yield tmp.name
    os.unlink(tmp.name)


@pytest.fixture
def temp_intermediate_dir():
    with tempfile.TemporaryDirectory() as tmpdirname:
        yield tmpdirname


@pytest.fixture
def map_config():
    return MapOp(
        name="sentiment_analysis",
        type="map",
        prompt="Analyze the sentiment of the following text: '{{ input.text }}'. Classify it as either positive, negative, or neutral.",
        output={"schema": {"sentiment": "string"}},
        model="ollama_chat/llama3",
    )


@pytest.fixture
def reduce_config():
    return ReduceOp(
        name="group_summary",
        type="reduce",
        reduce_key="group",
        prompt="Summarize the following group of values: {{ inputs }} Provide a total and any other relevant statistics.",
        output={"schema": {"total": "number", "avg": "number"}},
        model="ollama_chat/llama3",
    )


@pytest.fixture(autouse=True)
def remove_openai_api_key():
    openai_api_key = os.environ.pop("OPENAI_API_KEY", None)
    yield
    if openai_api_key:
        os.environ["OPENAI_API_KEY"] = openai_api_key


def test_ollama_map_reduce_pipeline(
    map_config, reduce_config, temp_input_file, temp_output_file, temp_intermediate_dir
):
    pipeline = Pipeline(
        name="test_ollama_pipeline",
        datasets={"test_input": Dataset(type="file", path=temp_input_file)},
        operations=[map_config, reduce_config],
        steps=[
            PipelineStep(
                name="pipeline",
                input="test_input",
                operations=["sentiment_analysis", "group_summary"],
            ),
        ],
        output=PipelineOutput(
            type="file", path=temp_output_file, intermediate_dir=temp_intermediate_dir
        ),
        default_model="ollama_chat/llama3",
    )

    cost = pipeline.run()

    assert isinstance(cost, float)
    assert cost == 0

    # Verify output file exists and contains data
    assert os.path.exists(temp_output_file)
    with open(temp_output_file, "r") as f:
        output_data = json.load(f)
        assert len(output_data) > 0

    # Clean up
    shutil.rmtree(temp_intermediate_dir)
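Note for anyone running this test: it presumably requires a local Ollama server with the llama3 model available (litellm routes the ollama_chat/llama3 model string to Ollama's chat endpoint), and the autouse fixture strips OPENAI_API_KEY precisely to prove that nothing on the import or execution path constructs an OpenAI client. The cost == 0 assertion holds because local models incur no per-token API charges.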
