From 784f26511bb56f79ae3658f02114ebedb440430b Mon Sep 17 00:00:00 2001
From: Shreya Shankar
Date: Thu, 17 Oct 2024 13:16:17 -0700
Subject: [PATCH 1/2] fix: update python api

---
 README.md                 |  2 +-
 docetl/api.py             | 16 ++++++++++------
 docetl/schemas.py         | 20 +++++++++++++++++++-
 docs/operators/cluster.md |  2 --
 4 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index ac3e2c36..e71189aa 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # DocETL: Powering Complex Document Processing Pipelines
 
-[Website (Includes Demo)](https://docetl.com) | [Documentation](https://ucbepic.github.io/docetl) | [Discord](https://discord.gg/fHp7B2X3xx) | [NotebookLM Podcast](https://notebooklm.google.com/notebook/ef73248b-5a43-49cd-9976-432d20f9fa4f/audio?pli=1) (thanks Shabie from our Discord community!) | Paper (coming soon!)
+[Website (Includes Demo)](https://docetl.org) | [Documentation](https://ucbepic.github.io/docetl) | [Discord](https://discord.gg/fHp7B2X3xx) | [Paper](https://arxiv.org/abs/2410.12189)
 
 ![DocETL Figure](docs/assets/readmefig.png)
 
diff --git a/docetl/api.py b/docetl/api.py
index 893a4faf..f7f7f6de 100644
--- a/docetl/api.py
+++ b/docetl/api.py
@@ -64,17 +64,17 @@
     FilterOp,
     GatherOp,
     MapOp,
+    ReduceOp,
+    ResolveOp,
+    SplitOp,
+    UnnestOp,
+    ClusterOp,
+    SampleOp,
     OpType,
     ParallelMapOp,
     ParsingTool,
-)
-from docetl.schemas import (
     PipelineOutput,
     PipelineStep,
-    ReduceOp,
-    ResolveOp,
-    SplitOp,
-    UnnestOp,
 )
@@ -322,6 +322,10 @@ def _update_from_dict(self, config: Dict[str, Any]):
             self.operations.append(GatherOp(**op, type=op_type))
         elif op_type == "unnest":
             self.operations.append(UnnestOp(**op, type=op_type))
+        elif op_type == "cluster":
+            self.operations.append(ClusterOp(**op, type=op_type))
+        elif op_type == "sample":
+            self.operations.append(SampleOp(**op, type=op_type))
         self.steps = [PipelineStep(**step) for step in config["pipeline"]["steps"]]
         self.output = PipelineOutput(**config["pipeline"]["output"])
         self.default_model = config.get("default_model")
diff --git a/docetl/schemas.py b/docetl/schemas.py
index ce89473e..90382fe0 100644
--- a/docetl/schemas.py
+++ b/docetl/schemas.py
@@ -78,7 +78,7 @@ class Dataset(BaseModel):
     parsing: Optional[List[Dict[str, str]]] = None
 
 
-class BaseOp(BaseModel):
+class BaseOp(BaseModel, extra="allow"):
     name: str
     type: str
 
@@ -222,6 +222,22 @@ class UnnestOp(BaseOp):
     depth: Optional[int] = None
 
 
+class ClusterOp(BaseOp):
+    type: str = "cluster"
+    embedding_keys: List[str]
+    summary_prompt: str
+    summary_schema: Dict[str, Any]
+    output_key: Optional[str] = "clusters"
+
+
+class SampleOp(BaseOp):
+    type: str = "sample"
+    method: str
+    samples: Union[int, float, List[Dict[str, Any]]]
+    method_kwargs: Optional[Dict[str, Any]] = None
+    random_state: Optional[int] = None
+
+
 OpType = Union[
     MapOp,
     ResolveOp,
@@ -232,6 +248,8 @@ class UnnestOp(BaseOp):
     SplitOp,
     GatherOp,
     UnnestOp,
+    ClusterOp,
+    SampleOp,
 ]
diff --git a/docs/operators/cluster.md b/docs/operators/cluster.md
index 51ef82a1..e7abb65a 100644
--- a/docs/operators/cluster.md
+++ b/docs/operators/cluster.md
@@ -181,8 +181,6 @@ and a description, and groups them into a tree of categories.
 | `output_key` | The name of the output key where the cluster path will be inserted in the items. | "clusters" |
| "clusters" | | `model` | The language model to use | Falls back to `default_model` | | `embedding_model` | The embedding model to use | "text-embedding-3-small" | -| `tools` | List of tool definitions for LLM use | None | | `timeout` | Timeout for each LLM call in seconds | 120 | | `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | -| `validate` | List of Python expressions to validate the output | None | | `sample` | Number of items to sample for this operation | None | From ead12e8e282361937ae0a525becfa318b47b3ba0 Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Thu, 17 Oct 2024 13:21:17 -0700 Subject: [PATCH 2/2] add flaky test --- tests/test_runner_caching.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_runner_caching.py b/tests/test_runner_caching.py index fbd973fa..42f02f2f 100644 --- a/tests/test_runner_caching.py +++ b/tests/test_runner_caching.py @@ -58,6 +58,7 @@ def create_pipeline(input_file, output_file, intermediate_dir, operation_prompt) ) +@pytest.mark.flaky(reruns=3) def test_pipeline_rerun_on_operation_change( temp_input_file, temp_output_file, temp_intermediate_dir ):