diff --git a/docetl/operations/equijoin.py b/docetl/operations/equijoin.py
index ab9f3d1b..e6e4e9aa 100644
--- a/docetl/operations/equijoin.py
+++ b/docetl/operations/equijoin.py
@@ -15,7 +15,6 @@
 from jinja2 import Template
 from litellm import embedding, model_cost
 from docetl.utils import completion_cost
-from sklearn.metrics.pairwise import cosine_similarity
 
 from docetl.operations.base import BaseOperation
 from docetl.operations.utils import (
@@ -279,6 +278,8 @@ def get_embeddings(
         )
 
         # Compute all cosine similarities in one call
+        from sklearn.metrics.pairwise import cosine_similarity
+
         similarities = cosine_similarity(left_embeddings, right_embeddings)
 
         # Additional blocking based on embeddings
diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py
index a9f89ddd..dda15dca 100644
--- a/docetl/operations/reduce.py
+++ b/docetl/operations/reduce.py
@@ -19,8 +19,6 @@
 from jinja2 import Template
 from docetl.utils import completion_cost
 from litellm import embedding
-from sklearn.cluster import KMeans
-from sklearn.metrics.pairwise import cosine_similarity
 
 from docetl.operations.base import BaseOperation
 from docetl.operations.utils import (
@@ -392,6 +390,8 @@ def _cluster_based_sampling(
     ) -> Tuple[List[Dict], float]:
         embeddings, cost = self._get_embeddings(group_list, value_sampling)
 
+        from sklearn.cluster import KMeans
+
         kmeans = KMeans(n_clusters=sample_size, random_state=42)
         cluster_labels = kmeans.fit_predict(embeddings)
 
@@ -420,6 +420,8 @@ def _semantic_similarity_sampling(
         query_embedding = query_response["data"][0]["embedding"]
         cost += completion_cost(query_response)
 
+        from sklearn.metrics.pairwise import cosine_similarity
+
         similarities = cosine_similarity([query_embedding], embeddings)[0]
 
         top_k_indices = np.argsort(similarities)[-sample_size:]
diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py
index 0daa78b7..27a71340 100644
--- a/docetl/operations/resolve.py
+++ b/docetl/operations/resolve.py
@@ -12,7 +12,6 @@
 from jinja2 import Template
 from docetl.utils import completion_cost
 from litellm import embedding
-from sklearn.metrics.pairwise import cosine_similarity
 
 from docetl.operations.base import BaseOperation
 from docetl.operations.utils import (
@@ -311,6 +310,8 @@ def meets_blocking_conditions(pair):
         )
 
         if remaining_comparisons > 0 and blocking_threshold is not None:
             # Compute cosine similarity for all pairs efficiently
+            from sklearn.metrics.pairwise import cosine_similarity
+
             similarity_matrix = cosine_similarity(embeddings)
             cosine_pairs = []
diff --git a/docetl/operations/utils.py b/docetl/operations/utils.py
index f05046c0..9831e55f 100644
--- a/docetl/operations/utils.py
+++ b/docetl/operations/utils.py
@@ -6,7 +6,6 @@
 import threading
 from concurrent.futures import as_completed
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
-from openai import OpenAI
 from dotenv import load_dotenv
 from frozendict import frozendict
@@ -30,8 +29,6 @@
 LLM_CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "llm_cache")
 cache = Cache(LLM_CACHE_DIR)
 
-client = OpenAI()
-
 
 def freezeargs(func):
     """
@@ -792,16 +789,16 @@ def parse_llm_response(
     if "tool_calls" in dir(response.choices[0].message):
         # Default behavior for write_output function
         tool_calls = response.choices[0].message.tool_calls
+
         if not tool_calls:
             raise ValueError("No tool calls found in response")
 
         outputs = []
         for tool_call in tool_calls:
-            if tool_call.function.name == "write_output":
-                try:
-                    outputs.append(json.loads(tool_call.function.arguments))
-                except json.JSONDecodeError:
-                    return [{}]
+            try:
+                outputs.append(json.loads(tool_call.function.arguments))
+            except json.JSONDecodeError:
+                return [{}]
         return outputs
     else:
diff --git a/tests/test_ollama.py b/tests/test_ollama.py
new file mode 100644
index 00000000..2f9d9ecb
--- /dev/null
+++ b/tests/test_ollama.py
@@ -0,0 +1,110 @@
+import shutil
+import pytest
+import json
+import tempfile
+import os
+from docetl.api import (
+    Pipeline,
+    Dataset,
+    MapOp,
+    ReduceOp,
+    PipelineStep,
+    PipelineOutput,
+)
+from dotenv import load_dotenv
+
+load_dotenv()
+
+
+@pytest.fixture
+def temp_input_file():
+    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tmp:
+        json.dump(
+            [
+                {"text": "This is a test", "group": "A"},
+                {"text": "Another test", "group": "B"},
+            ],
+            tmp,
+        )
+    yield tmp.name
+    os.unlink(tmp.name)
+
+
+@pytest.fixture
+def temp_output_file():
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp:
+        pass
+    yield tmp.name
+    os.unlink(tmp.name)
+
+
+@pytest.fixture
+def temp_intermediate_dir():
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        yield tmpdirname
+
+
+@pytest.fixture
+def map_config():
+    return MapOp(
+        name="sentiment_analysis",
+        type="map",
+        prompt="Analyze the sentiment of the following text: '{{ input.text }}'. Classify it as either positive, negative, or neutral.",
+        output={"schema": {"sentiment": "string"}},
+        model="ollama_chat/llama3",
+    )
+
+
+@pytest.fixture
+def reduce_config():
+    return ReduceOp(
+        name="group_summary",
+        type="reduce",
+        reduce_key="group",
+        prompt="Summarize the following group of values: {{ inputs }} Provide a total and any other relevant statistics.",
+        output={"schema": {"total": "number", "avg": "number"}},
+        model="ollama_chat/llama3",
+    )
+
+
+@pytest.fixture(autouse=True)
+def remove_openai_api_key():
+    openai_api_key = os.environ.pop("OPENAI_API_KEY", None)
+    yield
+    if openai_api_key:
+        os.environ["OPENAI_API_KEY"] = openai_api_key
+
+
+def test_ollama_map_reduce_pipeline(
+    map_config, reduce_config, temp_input_file, temp_output_file, temp_intermediate_dir
+):
+    pipeline = Pipeline(
+        name="test_ollama_pipeline",
+        datasets={"test_input": Dataset(type="file", path=temp_input_file)},
+        operations=[map_config, reduce_config],
+        steps=[
+            PipelineStep(
+                name="pipeline",
+                input="test_input",
+                operations=["sentiment_analysis", "group_summary"],
+            ),
+        ],
+        output=PipelineOutput(
+            type="file", path=temp_output_file, intermediate_dir=temp_intermediate_dir
+        ),
+        default_model="ollama_chat/llama3",
+    )
+
+    cost = pipeline.run()
+
+    assert isinstance(cost, float)
+    assert cost == 0
+
+    # Verify output file exists and contains data
+    assert os.path.exists(temp_output_file)
+    with open(temp_output_file, "r") as f:
+        output_data = json.load(f)
+        assert len(output_data) > 0
+
+    # Clean up
+    shutil.rmtree(temp_intermediate_dir)
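
Note on the recurring change above: in all four operation modules, the scikit-learn imports (and, in utils.py, the eagerly constructed OpenAI client) move from module scope into the code paths that use them, presumably so that importing docetl no longer hard-requires those dependencies up front; this is consistent with the new Ollama-only test, which unsets OPENAI_API_KEY. Below is a minimal, self-contained sketch of the deferred-import pattern under that assumption; the names embedding_block and threshold are hypothetical and do not appear in the diff.

from typing import List, Tuple


def embedding_block(
    left: List[List[float]],
    right: List[List[float]],
    threshold: float = 0.85,  # hypothetical blocking threshold, not from the diff
) -> List[Tuple[int, int]]:
    """Return (i, j) index pairs whose embeddings are similar enough to compare."""
    # Deferred import: scikit-learn is loaded only when this code path actually
    # runs, so callers that never request embedding-based blocking neither pay
    # the import cost nor need the package installed.
    from sklearn.metrics.pairwise import cosine_similarity

    # One vectorized call computes the full (len(left), len(right)) matrix.
    similarities = cosine_similarity(left, right)
    return [
        (i, j)
        for i in range(similarities.shape[0])
        for j in range(similarities.shape[1])
        if similarities[i, j] >= threshold
    ]

Python caches modules in sys.modules after the first import, so repeated calls pay the import cost only once; the trade-off is that a missing dependency surfaces at call time rather than at import time.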