Final fixes.
kosmitive committed May 1, 2024
1 parent 8a4f2c8 commit 59790bc
Showing 10 changed files with 296 additions and 380 deletions.
176 changes: 125 additions & 51 deletions README.md
@@ -1,7 +1,6 @@
# Reproduction "CS-Shapley: Class-wise Shapley Values for Data Valuation in Classification"

Code for the submission to the ML Reproducibility Challenge 2023. The original paper can
be found [here](https://arxiv.org/abs/2211.06800).
Code for the submission to TMLR 2024. The original paper can be found [here](https://arxiv.org/abs/2211.06800).

# Getting started

@@ -18,16 +17,63 @@ poetry install

a new conda environment is created and all dependencies get installed.

## MLflow

## env file

The file `sample.env` contains an example environment configuration. It has to be copied
to a `.env` file prior to starting the

### MLflow Configuration

#### Tracking Server URL
- `MLFLOW_TRACKING_URI`: The URL where the MLflow tracking server is hosted.
- **Example**: `http://localhost:5000`

#### Artifact Storage
- `MLFLOW_S3_ENDPOINT_URL`: The endpoint URL for S3 compatible artifact storage.
- **Example**: `http://localhost:9000`

### AWS Credentials (Used for MinIO in this setup)
- `AWS_ACCESS_KEY_ID`: The access key ID for AWS or AWS-compatible services.
- **Example**: `key`
- `AWS_SECRET_ACCESS_KEY`: The secret access key for AWS or AWS-compatible services.
- **Example**: `access_key`

### MinIO Configuration
- `MINIO_ROOT_USER`: The username for the MinIO root user.
- **Example**: `mlflow`
- `MINIO_ROOT_PASSWORD`: The password for the MinIO root user.
- **Example**: `password`

### MySQL Database Configuration

#### mlflow Database
- `MYSQL_DATABASE`: The database name for MLflow.
- **Example**: `mlflow`
- `MYSQL_USER`: The username for accessing the MLflow database.
- **Example**: `mlflow`
- `MYSQL_PASSWORD`: The password for accessing the MLflow database.
- **Example**: `password`

#### MySQL Root User
- `MYSQL_ROOT_PASSWORD`: The root password for the MySQL database.
- **Example**: `password`

### Security Note
Ensure that these credentials are secured and managed appropriately to prevent unauthorized access.
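
Before starting the containers it can help to check that the copied `.env` file defines every variable listed above. The following is a minimal sketch, not part of the repository: the helper names `parse_env` and `missing_keys` are hypothetical, and the parser only handles plain `KEY=VALUE` lines.

```python
# Variables documented in the sections above.
REQUIRED_KEYS = [
    "MLFLOW_TRACKING_URI",
    "MLFLOW_S3_ENDPOINT_URL",
    "AWS_ACCESS_KEY_ID",
    "AWS_SECRET_ACCESS_KEY",
    "MINIO_ROOT_USER",
    "MINIO_ROOT_PASSWORD",
    "MYSQL_DATABASE",
    "MYSQL_USER",
    "MYSQL_PASSWORD",
    "MYSQL_ROOT_PASSWORD",
]


def parse_env(text: str) -> dict:
    """Parse simple KEY=VALUE lines, skipping blank lines and # comments."""
    values = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Strip optional surrounding quotes from the value.
        values[key.strip()] = value.strip().strip('"').strip("'")
    return values


def missing_keys(text: str) -> list:
    """Return the documented variables that are absent or empty."""
    values = parse_env(text)
    return [key for key in REQUIRED_KEYS if not values.get(key)]
```

Calling `missing_keys(open(".env").read())` returns an empty list when all documented variables are set.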


## Start mlflow using docker

For experiment tracking we use MLflow. To start MLflow

```shell
cd docker/MLflow
cp .env docker/mlflow/.env
cd docker/mlflow
docker compose up -d
```

and open `http://localhost:5000` in your browser. MLflow relies on a S3 bucket served by
and open `http://localhost:5000` in your browser. mlflow relies on an S3 bucket served by
a MinIO server. All plots and artifacts are logged to this bucket. If you want to stop
MLflow execute

@@ -53,14 +99,15 @@ the [Reproduction](#Reproduction) section.

The pipeline is defined in the `dvc.yaml` and consists of the following seven stages:

| Stages | Description |
|---------------------|-----------------------------------------------------------------------------------|
| 1. Fetch data | Fetches a dataset with a specific ID from openml. |
| 2. Preprocess data | Applies filters and preprocessors to each dataset as defined in `params.yaml` |
| 3. Sample data | Use a seed to perform stratified sampling on the preprocessed data. |
| 4. Calculate values | Compute values for a sampled dataset from (1). |
| 5. Evaluate metrics | Calculates several metrics based on the values calculated in (4) |
| 6. Render plots | Renders plots, saves to disk and logs to MLflow. Used information from (1) to (5) |
| Stages              | Description                                                                                  |
|---------------------|----------------------------------------------------------------------------------------------|
| 1. Fetch data       | Fetches a dataset with a specific ID from OpenML.                                            |
| 2. Preprocess data  | Applies filters and preprocessors to each dataset as defined in `params.yaml`.               |
| 3. Sample data      | Uses a seed to perform stratified sampling on the preprocessed data.                         |
| 4. Calculate values | Computes values for a sampled dataset from (3).                                              |
| 5. Evaluate curves  | Calculates several curves based on the values calculated in (4).                             |
| 6. Evaluate metrics | Calculates several metrics based on the curves calculated in (5).                            |
| 7. Render plots     | Renders plots, saves them to disk and logs them to MLflow. Uses information from (1) to (6). |

Each stage requires inputs and writes its results to a sub-folder of the `output` folder.
In the following sections we describe each stage in more detail.
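
To illustrate what seeded stratified sampling in the sampling stage means, here is a minimal sketch in plain Python. It is not the repository's sampler: the function name `stratified_sample` and the `fraction` parameter are assumptions made for illustration only.

```python
import random
from collections import defaultdict


def stratified_sample(labels, fraction, seed):
    """Draw the same fraction of indices from every class, reproducibly."""
    rng = random.Random(seed)  # fixed seed -> identical sample on every run
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    chosen = []
    for label in sorted(by_class):  # sorted for a deterministic class order
        indices = by_class[label]
        k = max(1, round(fraction * len(indices)))
        chosen.extend(rng.sample(indices, k))
    return sorted(chosen)
```

Because each class is sampled separately, the class proportions of the full dataset carry over to the sample, and reusing the seed reproduces the exact same indices.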
@@ -109,20 +156,28 @@ applied method generates two files. The first file has the name
the name `valuation.<method_name>.stats.json` and contains meta information, e.g. the
execution time. Again, the repetition id is used as an initial seed.

### 5. Evaluate metrics
### 5. Evaluate curves

After the values are calculated, the curves need to be evaluated. In general there
can be multiple curves for one valuation result. The curves are defined in the
`curves` section of the experiment. Per curve a file is generated in
`output/curves/<experiment_name>/<model_name>/<dataset_name>/<repetition_id>/
<valuation_method_name>`. The file contains the curve of values, e.g. the accuracy
over points removed, with the file name `<curve_name>.csv`.

### 6. Evaluate metrics

After the values are calculated, the metrics need to be evaluated. In general there
can be multiple metrics for one valuation result. The metrics are defined in the
`metrics` section of the experiment. Per metric two files are generated in
`output/results/<experiment_name>/<model_name>/<dataset_name>/<repetition_id>/
After the curves are calculated, the metrics need to be evaluated. In general there
can be multiple metrics for one curve. The metrics are defined in the
`metrics` section of the experiment. Per metric a file is generated in
`output/metrics/<experiment_name>/<model_name>/<dataset_name>/<repetition_id>/
<valuation_method_name>`. The file contains the aggregated result (a single
number) with the file name `<metric_name>.csv`. The second file contains a curve of
values, e.g. the accuracy over points removed or the precision-recall curve.
number) with the file name `<metric_name>.<curve_name>.csv`.
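
The output layout described for the curve and metric stages can be expressed as two small path helpers. This is only a sketch based on the description above: the function names `curve_file` and `metric_file` are hypothetical, and the pipeline scripts may build these paths differently.

```python
from pathlib import Path


def curve_file(experiment, model, dataset, repetition, method, curve,
               root=Path("output")):
    """Path of the CSV holding one curve for one valuation method."""
    return (root / "curves" / experiment / model / dataset
            / str(repetition) / method / f"{curve}.csv")


def metric_file(experiment, model, dataset, repetition, method, metric, curve,
                root=Path("output")):
    """Path of the CSV holding one aggregated metric computed from one curve."""
    return (root / "metrics" / experiment / model / dataset
            / str(repetition) / method / f"{metric}.{curve}.csv")
```

For example, `metric_file("point_removal", "logistic_regression", "my_dataset", 0, "my_method", "mean", "accuracy")` points to a file named `mean.accuracy.csv`; the dataset and method names here are placeholders.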

### 6. Render plots
### 7. Render plots

Last but not least, the plots are rendered and all relevant information is logged to
MLflow. The following plots are generated
mlflow. The following plots are generated

| Plot | Description |
|------------------|------------------------------------------------------------------------------------------|
@@ -134,10 +189,24 @@ MLflow. The following plots are generated

## Reproduction

In general there are two ways of running the experiments. The former way uses `dvc` to
In general there are two ways of running the experiments. The latter way uses `dvc` to
execute the pipeline. However, writing and reading the `dvc.lock` file takes some time.
Hence, the latter way uses python directly. Both ways can be bridged by using
`dvc commit`.
Hence, the former way uses python directly. Both ways can be bridged by using
`dvc commit`.

### Manual

Sometimes `dvc` takes a lot of time in between stages. Hence, we integrated an option to
run the experiments without `dvc` and to commit the results later on. Execute

```shell
python scripts/run_pipeline.py
dvc commit
```

to do a manual and faster run without `dvc`. Committing the results is optional, but is
necessary if you want to switch back to `dvc` with the results of the run.


### Run with `dvc`

@@ -180,19 +249,6 @@ checks are skipped and thus the command runs faster than `dvc exp run`.
dvc repro [-s <stage>]
```

### Manual

Sometimes `dvc` takes a lot of time in between stages. Hence, we integrated an option to
run the experiments without `dvc` and to commit the results later on. Execute

```shell
python scripts/run_pipeline.py
dvc commit
```

to do a manual and faster run without `dvc`. Committing the results is optional, but is
necessary if you want to switch back to `dvc` with the results of the run.

# Development

Make sure to install the pre-commit hooks:
@@ -312,12 +368,9 @@ experiments:
noise_removal:
sampler: default
preprocessors:
flip_labels:
perc: 0.2
metrics:
roc_auc:
idx: precision_recall_roc_auc
flipped_labels: preprocessor.flip_labels.idx
[...]
metrics:
[...]
```

Note that each experiment has a unique name (in our case `noise_removal`). Furthermore,
@@ -333,26 +386,47 @@ samplers:
max_samples: 3000
```

### Register a new curve

To add a new curve, register your curve with `CurveRegistry` in
`re_classwise_shapley.curve`. See existing curves for more details. After a curve is
registered it can be used in the `params.yaml` file as follows:

```yaml
experiments:
  point_removal:
    [...]
    curves:
      accuracy:
        fn: accuracy
        type: "mse"
```

A special role has the parameter. It defines how much of the curve
should be drawn in the plots. It is not passed to the metric itself, but used in the
last stage. All other parameters are passed as keyword arguments to the metric.

### Register a new metric

To add a new metric, register your metric with `MetricRegistry` in
`re_classwise_shapley.metric`. A metric can accept any subset of parameters
`data`, `values`, `info`, `n_jobs`, `config`, `progress` and `seed`. See existing
metrics for more details. After a metric is registered it can be used in the `dvc.yaml`
metrics for more details. After a metric is registered it can be used in the `params.yaml`
file as follows:

```yaml
experiments:
  point_removal:
    [...]
    metrics:
      accuracy_logistic_regression:
        idx: weighted_metric_drop
        metric: accuracy
        eval_model: logistic_regression
        len_curve_perc: 0.5
      mean:
        curve:
          - accuracy
        fn: mean_accuracy
        arg: arg
```

A special role has the parameter `len_curve_perc`. It defines how much of the curve
A special role has the parameter. It defines how much of the curve
should be drawn in the plots. It is not passed to the metric itself, but used in the
last stage. All other parameters are passed as keyword arguments to the metric.
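
The registration and keyword-argument forwarding described above follow a common registry pattern. The sketch below is a generic stand-in and not the actual API of `MetricRegistry` or `CurveRegistry`: the `Registry` class, its `register` decorator and the `evaluate` helper are hypothetical, and treating `fn` and `curve` as the only non-forwarded keys is an assumption.

```python
class Registry:
    """Hypothetical name -> function registry (not the repository's class)."""

    def __init__(self):
        self._functions = {}

    def register(self, name):
        def decorator(fn):
            if name in self._functions:
                raise KeyError(f"'{name}' is already registered")
            self._functions[name] = fn
            return fn
        return decorator

    def get(self, name):
        return self._functions[name]


metric_registry = Registry()


@metric_registry.register("mean_accuracy")
def mean_accuracy(curve, **kwargs):
    """Aggregate a curve (a sequence of numbers) into a single number."""
    return sum(curve) / len(curve)


def evaluate(config, curve_values):
    """Look up the function named by `fn` and forward the remaining keys."""
    params = dict(config)
    fn = metric_registry.get(params.pop("fn"))
    params.pop("curve", None)  # selects the input curve, not a keyword argument
    return fn(curve_values, **params)
```

With the configuration from the YAML example, `evaluate({"fn": "mean_accuracy", "curve": ["accuracy"], "arg": "arg"}, [0.5, 1.0])` looks up `mean_accuracy` and passes `arg="arg"` through as a keyword argument.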
76 changes: 63 additions & 13 deletions dvc.yaml
@@ -51,20 +51,23 @@ stages:
- output/sampled/${item.experiment}/${item.dataset}:
persist: true

determine-in-out-of-cls-marginal-accuracies:
calculate-threshold-characteristics:
matrix:
experiment: ${active.experiments}
dataset: ${active.datasets}
repetition: ${active.repetitions}
cmd: >
python -m scripts.determine_in_out_of_cls_marginal_accuracies
python -m scripts.calculate_threshold_characteristics
--experiment-name ${item.experiment}
--dataset-name ${item.dataset}
--model-name logistic_regression
--repetition-id ${item.repetition}
params:
- settings.threshold_characteristics
deps:
- scripts/determine_in_out_of_cls_marginal_accuracies.py
- scripts/calculate_threshold_characteristics.py
- output/sampled/${item.experiment}/${item.dataset}
outs:
- output/info/${item.experiment}/${item.dataset}/in_out_of_cls_marginals.json:
- output/threshold_characteristics/${item.experiment}/${item.dataset}/${item.repetition}:
persist: true

calculate-values:
@@ -89,7 +92,56 @@ stages:
- src/re_classwise_shapley/valuation_methods.py
- output/sampled/${item.experiment}/${item.dataset}
outs:
- output/values/${item.experiment}/${item.model}/${item.dataset}/${item.repetition}/valuation.${item.method}.pkl:
- output/values/${item.experiment}/${item.model}/${item.dataset}/${item.repetition}:
persist: true

# TODO Make more concise once https://github.com/iterative/dvc/issues/9948 is closed.
evaluate-curves-point-removal:
matrix:
dataset: ${active.datasets}
model: ${active.models}
repetition: ${active.repetitions}
method: ${active.valuation_methods}
curve: ${experiments.point_removal.curves}
cmd: python -m scripts.evaluate_curves
--experiment-name point_removal
--dataset-name ${item.dataset}
--model-name ${item.model}
--valuation-method-name ${item.method}
--repetition-id ${item.repetition}
--curve-name ${item.curve}
params:
- experiments.point_removal.curves.${item.curve}
deps:
- scripts/evaluate_curves.py
- src/re_classwise_shapley/curve.py
- output/values/point_removal/${item.model}/${item.dataset}/${item.repetition}/valuation.${item.method}.pkl
outs:
- output/curves/point_removal/${item.model}/${item.dataset}/${item.repetition}/${item.method}/${item.curve}.csv:
persist: true

evaluate-curves-noise-removal:
matrix:
dataset: ${active.datasets}
model: ${active.models}
repetition: ${active.repetitions}
method: ${active.valuation_methods}
curve: ${experiments.noise_removal.curves}
cmd: python -m scripts.evaluate_curves
--experiment-name noise_removal
--dataset-name ${item.dataset}
--model-name ${item.model}
--valuation-method-name ${item.method}
--repetition-id ${item.repetition}
--curve-name ${item.curve}
params:
- experiments.noise_removal.curves.${item.curve}
deps:
- scripts/evaluate_curves.py
- src/re_classwise_shapley/curve.py
- output/values/noise_removal/${item.model}/${item.dataset}/${item.repetition}/valuation.${item.method}.pkl
outs:
- output/curves/noise_removal/${item.model}/${item.dataset}/${item.repetition}/${item.method}/${item.curve}.csv:
persist: true

# TODO Make more concise once https://github.com/iterative/dvc/issues/9948 is closed.
@@ -114,9 +166,7 @@ stages:
- src/re_classwise_shapley/metric.py
- output/values/point_removal/${item.model}/${item.dataset}/${item.repetition}/valuation.${item.method}.pkl
outs:
- output/results/point_removal/${item.model}/${item.dataset}/${item.repetition}/${item.method}/${item.metric}.csv:
persist: true
- output/results/point_removal/${item.model}/${item.dataset}/${item.repetition}/${item.method}/${item.metric}.curve.csv:
- output/metrics/point_removal/${item.model}/${item.dataset}/${item.repetition}/${item.method}/${item.metric}:
persist: true

evaluate-metrics-noise-removal:
@@ -140,9 +190,7 @@ stages:
- src/re_classwise_shapley/metric.py
- output/values/noise_removal/${item.model}/${item.dataset}/${item.repetition}/valuation.${item.method}.pkl
outs:
- output/results/noise_removal/${item.model}/${item.dataset}/${item.repetition}/${item.method}/${item.metric}.csv:
persist: true
- output/results/noise_removal/${item.model}/${item.dataset}/${item.repetition}/${item.method}/${item.metric}.curve.csv:
- output/metrics/noise_removal/${item.model}/${item.dataset}/${item.repetition}/${item.method}/${item.metric}:
persist: true

render-plots:
@@ -156,7 +204,9 @@ stages:
- experiments.${item.experiment}.metrics
deps:
- scripts/render_plots.py
- output/results/${item.experiment}/${item.model}
- output/curves/${item.experiment}/${item.model}
- output/metrics/${item.experiment}/${item.model}
- output/threshold_characteristics/${item.experiment}
outs:
- output/plots/${item.experiment}/${item.model}
