Skip to content

Commit

Permalink
refactor: address redhog feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyashankar committed Oct 13, 2024
1 parent 658b59e commit eecbd41
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 36 deletions.
17 changes: 6 additions & 11 deletions docetl/operations/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,6 @@ def syntax_check(self) -> None:
if "center" in self.config.get("method_kwargs", {}):
if not isinstance(self.config.get("method_kwargs", {})["center"], dict):
raise TypeError("'center' must be a dictionary")
for key, value in self.config.get("method_kwargs", {})["center"].items():
if not isinstance(value, (int, float)):
raise TypeError(
f"Values in 'center' must be numbers, got {type(value)} for key '{key}'"
)

def execute(
self, input_data: List[Dict], is_build: bool = False
Expand Down Expand Up @@ -133,13 +128,13 @@ def execute(
cost += embedding_cost
embeddings = np.array(embeddings)

if "center" in self.config:
center = np.array(
[
outliers_config["center"][key]
for key in outliers_config["embedding_keys"]
]
if "center" in outliers_config:
center_embeddings, cost2 = get_embeddings_for_clustering(
[outliers_config["center"]], outliers_config, self.runner.api
)
cost += cost2
center = np.array(center_embeddings[0])

else:
center = embeddings.mean(axis=0)

Expand Down
109 changes: 84 additions & 25 deletions docs/operators/sample.md
Original file line number Diff line number Diff line change
@@ -1,64 +1,123 @@
# Sample operation

The Sample operation in DocETL samples items from the input. It is
meant mostly as a debugging tool:
The Sample operation in DocETL samples items from the input. It is meant mostly as a debugging tool:

Insert it before the last operation, the one you're currently trying
to tack on to the end of a working pipeline, to limit the amount of
data it will be fed, so that the run time is small enough to
comfortably debug its prompt. Once it seems to be working, you can
remove the sample operation. You can then repeat this for each
operation you add while developing your pipeline!
Insert it before the last operation, the one you're currently trying to add to the end of a working pipeline, to limit the amount of data it will be fed, so that the run time is small enough to comfortably debug its prompt. Once it seems to be working, you can remove the sample operation. You can then repeat this for each operation you add while developing your pipeline!

## 🚀 Example:

```yaml
- name: cluster_concepts
type: sample
method: stratify
samples: 0.1
method_kwargs:
stratify_key: category
random_state: 42
stratify: category
```
This sample operation will return a pseudo-randomly selected 10% of
the samples (samples: 0.1). The random selection will be seeded with
a constant (42), meaning the same selection will be returned if you
rerun the pipeline (If no random state is given, a different sample
will be returned every time). Additionally, the random sampling will
sample each value of the category key equally.
This sample operation will return a pseudo-randomly selected 10% of the samples (samples: 0.1). The random selection will be seeded with a constant (42), meaning the same sample will be returned if you rerun the pipeline (If no random state is given, a different sample will be returned every time). Additionally, the random sampling will sample each value of the category key equally.
## Required Parameters
- name: A unique name for the operation.
- type: Must be set to "sample".
- method: The sampling method to use. Can be "uniform", "stratify", "outliers", or "custom".
- samples: Either a list of key-value pairs representing document ids and values, an integer count of samples, or a float fraction of samples.
## Optional Parameters
| Parameter | Description | Default |
| ------------ | -------------------------------------------- | ----------------------------------- |
| random_state | An integer to seed the random generator with | Use the (numpy) global random state |
| stratify | The key to stratify by | |
| Parameter | Description | Default |
| ------------- | -------------------------------------------- | ----------------------------------- |
| random_state | An integer to seed the random generator with | Use the (numpy) global random state |
| method_kwargs | Additional parameters for the chosen method | {} |
## Outliers
## Sampling Methods
The Sample operation can also be used to sample outliers. To do this, instead of specifying "samples", specify an "outliers" object with the following parameters:
### Uniform Sampling
For uniform sampling, no additional parameters are required in method_kwargs.
### Stratified Sampling
For stratified sampling, specify the following in method_kwargs:
- stratify_key: The key to stratify by
### Outlier Sampling
For outlier sampling, specify the following in method_kwargs:
- embedding_keys: A list of keys to use for creating embeddings.
- std: The number of standard deviations to use as the cutoff for outliers.
- samples: The number or fraction of samples to consider as outliers.
- keep: Whether to keep (true) or remove (false) the outliers. Defaults to false.
- center: (Optional) A dictionary specifying the center point for distance calculations. It should look like a document, with all the keys present in the embedding_keys list.
You must specify either "std" or "samples" in the outliers configuration, but not both.
Example:
### Custom Sampling
For custom sampling, provide a list of documents to sample in the "samples" parameter. Each document in the list should be a dictionary containing keys that match the keys in your input data.
## Examples:
Uniform sampling:
```yaml
- name: uniform_sample
type: sample
method: uniform
samples: 100
```
Stratified sampling:
```yaml
- name: stratified_sample
type: sample
method: stratify
samples: 0.2
method_kwargs:
stratify_key: category
```
Outlier sampling:
```yaml
- name: remove_outliers
type: sample
method: outliers
method_kwargs:
embedding_keys:
- concept
- description
std: 2
keep: false
```
Custom sampling:
```yaml
- name: custom_sample
type: sample
method: custom
samples:
- id: 1
- id: 5
```
Outlier sampling with a center:
```yaml
- name: remove-worst-10
- name: remove_outliers
type: sample
outliers:
method: outliers
method_kwargs:
embedding_keys:
- concept
- description
samples: 0.9
center:
concept: Tree house
description: A small house built among the branches of a tree for children to play in.
```
21 changes: 21 additions & 0 deletions tests/basic/test_cluster_and_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,24 @@ def test_sample_operation_empty_input(

assert len(results) == 0
assert cost == 0


def test_sample_operation_with_outliers_and_center(
sample_config, sample_data, api_wrapper, default_model, max_threads
):
sample_config["method"] = "outliers"
sample_config["method_kwargs"] = {
"std": 2,
"embedding_keys": ["concept", "description"],
"keep": True,
"center": {
"concept": "Tree house",
"description": "A small house built among the branches of a tree for children to play in.",
},
}
operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads)
results, cost = operation.execute(sample_data)

assert len(results) < len(sample_data)
assert cost > 0
assert all(item in sample_data for item in results)

0 comments on commit eecbd41

Please sign in to comment.