diff --git a/docetl/operations/sample.py b/docetl/operations/sample.py new file mode 100644 index 00000000..669b4de4 --- /dev/null +++ b/docetl/operations/sample.py @@ -0,0 +1,188 @@ +from typing import Any, Dict, List, Optional, Tuple +import numpy as np +from docetl.operations.base import BaseOperation +from docetl.operations.clustering_utils import get_embeddings_for_clustering + + +class SampleOperation(BaseOperation): + """ + Params: + - method: "uniform", "stratify", "outliers", "custom" + - samples: int, float, or list + - method_kwargs: dict, optional + - embedding_model: str, optional + - embedding_keys: list, optional + - center: dict, optional + """ + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def syntax_check(self) -> None: + """ + Checks the configuration of the SampleOperation for required keys and valid structure. + + Raises: + ValueError: If required keys are missing or invalid in the configuration. + TypeError: If configuration values have incorrect types. + """ + if "method" not in self.config: + raise ValueError("Must specify 'method' in SampleOperation configuration") + + valid_methods = ["uniform", "stratify", "outliers", "custom"] + if self.config["method"] not in valid_methods: + raise ValueError(f"'method' must be one of {valid_methods}") + + if self.config["method"] == "custom": + # Samples must be a list + if not isinstance(self.config["samples"], list): + raise TypeError("'samples' must be a list for custom sampling") + + if self.config["method"] in ["random", "stratify"]: + if "samples" not in self.config: + raise ValueError( + "Must specify 'samples' for random or stratify sampling" + ) + if not isinstance(self.config["samples"], (int, float, list)) or ( + isinstance(self.config["samples"], (int, float)) + and self.config["samples"] <= 0 + ): + raise TypeError("'samples' must be a positive integer, float, or list") + + if self.config["method"] == "stratify": + if "stratify_key" not in self.config.get("method_kwargs", {}): + raise ValueError("Must specify 'stratify_key' for stratify sampling") + if not isinstance( + self.config.get("method_kwargs", {})["stratify_key"], str + ): + raise TypeError("'stratify_key' must be a string") + + if self.config["method"] == "outliers": + outliers_config = self.config.get("method_kwargs", {}) + if "std" not in outliers_config and "samples" not in outliers_config: + raise ValueError( + "Must specify either 'std' or 'samples' in outliers configuration" + ) + + if "std" in outliers_config: + if ( + not isinstance(outliers_config["std"], (int, float)) + or outliers_config["std"] <= 0 + ): + raise TypeError("'std' in outliers must be a positive number") + + if "samples" in outliers_config: + if ( + not isinstance(outliers_config["samples"], (int, float)) + or outliers_config["samples"] <= 0 + ): + raise TypeError( + "'samples' in outliers must be a positive integer or float" + ) + + if "embedding_keys" not in outliers_config: + raise ValueError( + "'embedding_keys' must be specified in outliers configuration" + ) + + if not isinstance(outliers_config["embedding_keys"], list) or not all( + isinstance(key, str) for key in outliers_config["embedding_keys"] + ): + raise TypeError( + "'embedding_keys' in outliers must be a list of strings" + ) + + if "center" in self.config.get("method_kwargs", {}): + if not isinstance(self.config.get("method_kwargs", {})["center"], dict): + raise TypeError("'center' must be a dictionary") + + def execute( + self, input_data: List[Dict], is_build: bool = False + ) -> Tuple[List[Dict], float]: + """ + Executes the sample operation on the input data. + + Args: + input_data (List[Dict]): A list of dictionaries to process. + is_build (bool): Whether the operation is being executed + in the build phase. Defaults to False. + + Returns: + Tuple[List[Dict], float]: A tuple containing the filtered + list of dictionaries and the total cost of the operation. + """ + cost = 0 + if not input_data: + return [], cost + + if self.config["method"] == "outliers": + # Outlier functionality + outliers_config = self.config.get("method_kwargs", {}) + embeddings, embedding_cost = get_embeddings_for_clustering( + input_data, outliers_config, self.runner.api + ) + cost += embedding_cost + embeddings = np.array(embeddings) + + if "center" in outliers_config: + center_embeddings, cost2 = get_embeddings_for_clustering( + [outliers_config["center"]], outliers_config, self.runner.api + ) + cost += cost2 + center = np.array(center_embeddings[0]) + + else: + center = embeddings.mean(axis=0) + + distances = np.sqrt(((embeddings - center) ** 2).sum(axis=1)) + + if "std" in outliers_config: + cutoff = ( + np.sqrt((embeddings.std(axis=0) ** 2).sum()) + * outliers_config["std"] + ) + else: # "samples" in config + distance_distribution = np.sort(distances) + samples = self.config["samples"] + if isinstance(samples, float): + samples = int(samples * (len(distance_distribution) - 1)) + cutoff = distance_distribution[samples] + + keep = outliers_config.get("keep", False) + include = distances > cutoff if keep else distances <= cutoff + + output_data = [item for idx, item in enumerate(input_data) if include[idx]] + else: + samples = self.config["samples"] + if self.config["method"] == "custom": + keys = list(samples[0].keys()) + key_to_doc = { + tuple([doc[key] for key in keys]): doc for doc in input_data + } + + output_data = [ + key_to_doc[tuple([sample[key] for key in keys])] + for sample in samples + ] + else: + stratify = None + if self.config["method"] == "stratify": + stratify = [ + data[self.config.get("method_kwargs", {})["stratify_key"]] + for data in input_data + ] + + import sklearn.model_selection + + output_data, _ = sklearn.model_selection.train_test_split( + input_data, + train_size=samples, + random_state=self.config.get("random_state", None), + stratify=stratify, + ) + + return output_data, cost diff --git a/docs/operators/sample.md b/docs/operators/sample.md new file mode 100644 index 00000000..6d4fbc89 --- /dev/null +++ b/docs/operators/sample.md @@ -0,0 +1,123 @@ +# Sample operation + +The Sample operation in DocETL samples items from the input. It is meant mostly as a debugging tool: + +Insert it before the last operation, the one you're currently trying to add to the end of a working pipeline, to limit the amount of data it will be fed, so that the run time is small enough to comfortably debug its prompt. Once it seems to be working, you can remove the sample operation. You can then repeat this for each operation you add while developing your pipeline! + +## 🚀 Example: + +```yaml +- name: cluster_concepts + type: sample + method: stratify + samples: 0.1 + method_kwargs: + stratify_key: category + random_state: 42 +``` + +This sample operation will return a pseudo-randomly selected 10% of the samples (samples: 0.1). The random selection will be seeded with a constant (42), meaning the same sample will be returned if you rerun the pipeline (If no random state is given, a different sample will be returned every time). Additionally, the random sampling will sample each value of the category key equally. + +## Required Parameters + +- name: A unique name for the operation. +- type: Must be set to "sample". +- method: The sampling method to use. Can be "uniform", "stratify", "outliers", or "custom". +- samples: Either a list of key-value pairs representing document ids and values, an integer count of samples, or a float fraction of samples. + +## Optional Parameters + +| Parameter | Description | Default | +| ------------- | -------------------------------------------- | ----------------------------------- | +| random_state | An integer to seed the random generator with | Use the (numpy) global random state | +| method_kwargs | Additional parameters for the chosen method | {} | + +## Sampling Methods + +### Uniform Sampling + +For uniform sampling, no additional parameters are required in method_kwargs. + +### Stratified Sampling + +For stratified sampling, specify the following in method_kwargs: + +- stratify_key: The key to stratify by + +### Outlier Sampling + +For outlier sampling, specify the following in method_kwargs: + +- embedding_keys: A list of keys to use for creating embeddings. +- std: The number of standard deviations to use as the cutoff for outliers. +- samples: The number or fraction of samples to consider as outliers. +- keep: Whether to keep (true) or remove (false) the outliers. Defaults to false. +- center: (Optional) A dictionary specifying the center point for distance calculations. It should look like a document, with all the keys present in the embedding_keys list. + +You must specify either "std" or "samples" in the outliers configuration, but not both. + +### Custom Sampling + +For custom sampling, provide a list of documents to sample in the "samples" parameter. Each document in the list should be a dictionary containing keys that match the keys in your input data. + +## Examples: + +Uniform sampling: + +```yaml +- name: uniform_sample + type: sample + method: uniform + samples: 100 +``` + +Stratified sampling: + +```yaml +- name: stratified_sample + type: sample + method: stratify + samples: 0.2 + method_kwargs: + stratify_key: category +``` + +Outlier sampling: + +```yaml +- name: remove_outliers + type: sample + method: outliers + method_kwargs: + embedding_keys: + - concept + - description + std: 2 + keep: false +``` + +Custom sampling: + +```yaml +- name: custom_sample + type: sample + method: custom + samples: + - id: 1 + - id: 5 +``` + +Outlier sampling with a center: + +```yaml +- name: remove_outliers + type: sample + method: outliers + method_kwargs: + embedding_keys: + - concept + - description + center: + concept: Tree house + description: A small house built among the branches of a tree for children to play in. +``` diff --git a/mkdocs.yml b/mkdocs.yml index e995670e..988c272f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -35,6 +35,7 @@ nav: - Split: operators/split.md - Gather: operators/gather.md - Unnest: operators/unnest.md + - Sample: operators/sample.md - Optimization: - Overview: optimization/overview.md - Example: optimization/example.md diff --git a/pyproject.toml b/pyproject.toml index e909ca87..54640b3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ reduce = "docetl.operations.reduce:ReduceOperation" resolve = "docetl.operations.resolve:ResolveOperation" gather = "docetl.operations.gather:GatherOperation" cluster = "docetl.operations.cluster:ClusterOperation" +sample = "docetl.operations.sample:SampleOperation" [tool.poetry.plugins."docetl.parser"] llama_index_simple_directory_reader = "docetl.parsing_tools:llama_index_simple_directory_reader" diff --git a/tests/basic/test_cluster.py b/tests/basic/test_cluster.py deleted file mode 100644 index 0bcaa717..00000000 --- a/tests/basic/test_cluster.py +++ /dev/null @@ -1,117 +0,0 @@ -import pytest -from docetl.operations.cluster import ClusterOperation -from tests.conftest import api_wrapper, default_model, max_threads - - -@pytest.fixture -def cluster_config(): - return { - "name": "test_cluster", - "type": "cluster", - "embedding_keys": ["concept", "description"], - "output_key": "categories", - "summary_schema": {"concept": "string", "description": "string"}, - "summary_prompt": """ - The following describes two related concepts. What concept - encompasses both? Try not to be too broad; it might be that one of - these two concepts already encompasses the other; in that case, - you should just use that concept. - - {% for input in inputs %} - {{input.concept}}: - {{input.description}} - {% endfor %} - - Provide the title of the super-concept, and a description. - """, - "model": "gpt-4o-mini", - } - - -@pytest.fixture -def sample_data(): - return [ - { - "concept": "Shed", - "description": "A simple, single-story roofed structure, often used for storage or as a workshop.", - }, - { - "concept": "Barn", - "description": "A large agricultural building used for storing farm products and sheltering livestock.", - }, - { - "concept": "Tree house", - "description": "A small house built among the branches of a tree for children to play in.", - }, - { - "concept": "Skyscraper", - "description": "A very tall building of many stories, typically found in urban areas.", - }, - { - "concept": "Castle", - "description": "A large fortified building or set of buildings from the medieval period.", - }, - { - "concept": "Igloo", - "description": "A dome-shaped dwelling made of blocks of solid snow, traditionally built by Inuit people.", - }, - { - "concept": "Lighthouse", - "description": "A tower with a bright light at the top, used to warn or guide ships at sea.", - }, - { - "concept": "Windmill", - "description": "A building with sails or vanes that turn in the wind and generate power to grind grain into flour.", - }, - ] - - -def test_cluster_operation( - cluster_config, sample_data, api_wrapper, default_model, max_threads -): - cluster_config["bypass_cache"] = True - operation = ClusterOperation( - api_wrapper, cluster_config, default_model, max_threads - ) - results, cost = operation.execute(sample_data) - - assert len(results) == len(sample_data) - assert cost > 0 - - for result in results: - assert "categories" in result - assert isinstance(result["categories"], tuple) - assert len(result["categories"]) > 0 - - for category in result["categories"]: - assert "concept" in category - assert "description" in category - - -def test_cluster_operation_empty_input( - cluster_config, api_wrapper, default_model, max_threads -): - operation = ClusterOperation( - api_wrapper, cluster_config, default_model, max_threads - ) - results, cost = operation.execute([]) - - assert len(results) == 0 - assert cost == 0 - - -def test_cluster_operation_single_item( - cluster_config, api_wrapper, default_model, max_threads -): - single_item = [ - {"concept": "House", "description": "A building for human habitation."} - ] - operation = ClusterOperation( - api_wrapper, cluster_config, default_model, max_threads - ) - results, cost = operation.execute(single_item) - - assert len(results) == 1 - assert cost == 0 - assert "categories" in results[0] - assert isinstance(results[0]["categories"], tuple) diff --git a/tests/basic/test_cluster_and_sample.py b/tests/basic/test_cluster_and_sample.py new file mode 100644 index 00000000..f45ef06f --- /dev/null +++ b/tests/basic/test_cluster_and_sample.py @@ -0,0 +1,248 @@ +import pytest +from docetl.operations.cluster import ClusterOperation +from docetl.operations.sample import SampleOperation +from tests.conftest import api_wrapper, default_model, max_threads + + +@pytest.fixture +def cluster_config(): + return { + "name": "test_cluster", + "type": "cluster", + "embedding_keys": ["concept", "description"], + "output_key": "categories", + "summary_schema": {"concept": "string", "description": "string"}, + "summary_prompt": """ + The following describes two related concepts. What concept + encompasses both? Try not to be too broad; it might be that one of + these two concepts already encompasses the other; in that case, + you should just use that concept. + + {% for input in inputs %} + {{input.concept}}: + {{input.description}} + {% endfor %} + + Provide the title of the super-concept, and a description. + """, + "model": "gpt-4o-mini", + } + + +@pytest.fixture +def sample_data(): + return [ + { + "id": 1, + "concept": "Shed", + "description": "A simple, single-story roofed structure, often used for storage or as a workshop.", + "group": "A", + }, + { + "id": 2, + "concept": "Barn", + "description": "A large agricultural building used for storing farm products and sheltering livestock.", + "group": "B", + }, + { + "id": 3, + "concept": "Tree house", + "description": "A small house built among the branches of a tree for children to play in.", + "group": "A", + }, + { + "id": 4, + "concept": "Skyscraper", + "description": "A very tall building of many stories, typically found in urban areas.", + "group": "B", + }, + { + "id": 5, + "concept": "Castle", + "description": "A large fortified building or set of buildings from the medieval period.", + "group": "A", + }, + { + "id": 6, + "concept": "Igloo", + "description": "A dome-shaped dwelling made of blocks of solid snow, traditionally built by Inuit people.", + "group": "B", + }, + { + "id": 7, + "concept": "Lighthouse", + "description": "A tower with a bright light at the top, used to warn or guide ships at sea.", + "group": "A", + }, + { + "id": 8, + "concept": "Windmill", + "description": "A building with sails or vanes that turn in the wind and generate power to grind grain into flour.", + "group": "B", + }, + ] + + +def test_cluster_operation( + cluster_config, sample_data, api_wrapper, default_model, max_threads +): + cluster_config["bypass_cache"] = True + operation = ClusterOperation( + api_wrapper, cluster_config, default_model, max_threads + ) + results, cost = operation.execute(sample_data) + + assert len(results) == len(sample_data) + assert cost > 0 + + for result in results: + assert "categories" in result + assert isinstance(result["categories"], tuple) + assert len(result["categories"]) > 0 + + for category in result["categories"]: + assert "concept" in category + assert "description" in category + + +def test_cluster_operation_empty_input( + cluster_config, api_wrapper, default_model, max_threads +): + operation = ClusterOperation( + api_wrapper, cluster_config, default_model, max_threads + ) + results, cost = operation.execute([]) + + assert len(results) == 0 + assert cost == 0 + + +def test_cluster_operation_single_item( + cluster_config, api_wrapper, default_model, max_threads +): + single_item = [ + {"concept": "House", "description": "A building for human habitation."} + ] + operation = ClusterOperation( + api_wrapper, cluster_config, default_model, max_threads + ) + results, cost = operation.execute(single_item) + + assert len(results) == 1 + assert cost == 0 + assert "categories" in results[0] + assert isinstance(results[0]["categories"], tuple) + + +@pytest.fixture +def sample_config(): + return { + "name": "sample_operation", + "type": "sample", + "random_state": 42, # For reproducibility + } + + +def test_sample_operation_with_count( + sample_config, sample_data, api_wrapper, default_model, max_threads +): + sample_config["samples"] = 5 + sample_config["method"] = "uniform" + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute(sample_data) + + assert len(results) == 5 + assert cost == 0 + assert all(item in sample_data for item in results) + + +def test_sample_operation_with_fraction( + sample_config, sample_data, api_wrapper, default_model, max_threads +): + sample_config["samples"] = 0.5 + sample_config["method"] = "uniform" + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute(sample_data) + + assert len(results) == len(sample_data) // 2 + assert cost == 0 + assert all(item in sample_data for item in results) + + +def test_sample_operation_with_list( + sample_config, sample_data, api_wrapper, default_model, max_threads +): + sample_list = [{"id": 1}, {"id": 3}, {"id": 5}] + sample_config["samples"] = sample_list + sample_config["method"] = "custom" + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute(sample_data) + + assert len(results) == len(sample_list) + assert cost == 0 + assert all(item["id"] in [1, 3, 5] for item in results) + + +def test_sample_operation_with_stratify( + sample_config, sample_data, api_wrapper, default_model, max_threads +): + sample_config["samples"] = 5 + sample_config["method"] = "stratify" + sample_config["method_kwargs"] = {"stratify_key": "group"} + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute(sample_data) + + assert len(results) == 5 + assert cost == 0 + assert all(item in sample_data for item in results) + assert len(set(item["group"] for item in results)) > 1 + + +def test_sample_operation_with_outliers( + sample_config, sample_data, api_wrapper, default_model, max_threads +): + sample_config["method"] = "outliers" + sample_config["method_kwargs"] = { + "std": 2, + "embedding_keys": ["concept", "description"], + "keep": True, + } + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute(sample_data) + + assert len(results) < len(sample_data) + assert cost > 0 + assert all(item in sample_data for item in results) + + +def test_sample_operation_empty_input( + sample_config, api_wrapper, default_model, max_threads +): + sample_config["samples"] = 3 + sample_config["method"] = "uniform" + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute([]) + + assert len(results) == 0 + assert cost == 0 + + +def test_sample_operation_with_outliers_and_center( + sample_config, sample_data, api_wrapper, default_model, max_threads +): + sample_config["method"] = "outliers" + sample_config["method_kwargs"] = { + "std": 2, + "embedding_keys": ["concept", "description"], + "keep": True, + "center": { + "concept": "Tree house", + "description": "A small house built among the branches of a tree for children to play in.", + }, + } + operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads) + results, cost = operation.execute(sample_data) + + assert len(results) < len(sample_data) + assert cost > 0 + assert all(item in sample_data for item in results)