Skip to content

Commit

Permalink
Fix bugs in pipeline.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Oct 21, 2023
1 parent 4dee9da commit 8199358
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 23 deletions.
31 changes: 9 additions & 22 deletions scripts/render_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
also stored in mlflow. The id of the mlflow experiment is given by the schema
`experiment_name.model_name`.
"""
import math as m
import os
import os.path
from datetime import datetime
Expand All @@ -22,9 +21,15 @@
import mlflow
import numpy as np
from dotenv import load_dotenv
from matplotlib import pyplot as plt

from re_classwise_shapley.io import Accessor
from re_classwise_shapley.log import log_datasets, log_figure, setup_logger
from re_classwise_shapley.log import (
get_or_create_mlflow_experiment,
log_datasets,
log_figure,
setup_logger,
)
from re_classwise_shapley.plotting import (
plot_curves,
plot_histogram,
Expand All @@ -41,25 +46,6 @@
logger = setup_logger("render_plots")


def get_or_create_mlflow_experiment(experiment_name: str) -> str:
    """
    Look up an MLflow experiment by name and create it when the lookup finds
    nothing.

    Args:
        experiment_name: Name of the experiment.

    Returns:
        Identifier of the existing or newly created experiment.
    """
    existing = mlflow.get_experiment_by_name(experiment_name)
    if existing:
        # The experiment is already registered; reuse its identifier.
        return existing.experiment_id
    # Nothing was found under this name, so register a new experiment.
    return mlflow.create_experiment(experiment_name)


@click.command()
@click.option("--experiment-name", type=str, required=True)
@click.option("--model-name", type=str, required=True)
Expand All @@ -77,7 +63,6 @@ def render_plots(experiment_name: str, model_name: str):


def _render_plots(experiment_name: str, model_name: str):
load_dotenv()
logger.info("Starting plotting of data valuation experiment")
output_folder = Accessor.PLOT_PATH / experiment_name / model_name
mlflow_id = f"{experiment_name}.{model_name}"
Expand Down Expand Up @@ -110,6 +95,7 @@ def _render_plots(experiment_name: str, model_name: str):
)
)

plt.switch_backend("agg")
valuation_results = Accessor.valuation_results(
experiment_name,
model_name,
Expand Down Expand Up @@ -169,4 +155,5 @@ def _render_plots(experiment_name: str, model_name: str):


if __name__ == "__main__":
    # Load variables from a .env file into the environment before the CLI
    # command runs.
    load_dotenv()
    # Invoke the click command; its options are parsed from the command line.
    render_plots()
19 changes: 19 additions & 0 deletions src/re_classwise_shapley/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,22 @@ def dataset_to_dataframe(dataset: Dataset) -> pd.DataFrame:
columns=dataset.feature_names + dataset.target_names,
)
return df


def get_or_create_mlflow_experiment(experiment_name: str) -> str:
    """
    Return the identifier of the MLflow experiment with the given name. The
    experiment is created first if no experiment with that name is found.

    Args:
        experiment_name: Name of the experiment.

    Returns:
        Identifier of the experiment.
    """
    found = mlflow.get_experiment_by_name(experiment_name)
    # Reuse the identifier of a found experiment, otherwise create a new one.
    return (
        found.experiment_id
        if found
        else mlflow.create_experiment(experiment_name)
    )
8 changes: 7 additions & 1 deletion src/re_classwise_shapley/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@

logger = setup_logger()

__all__ = ["flatten_dict", "pipeline_seed", "load_params_fast", "n_threaded"]
# Names exported by `from <module> import *`.
# NOTE(review): the definition of `linear_dataframe_to_table` is not visible in
# this hunk - confirm it exists in this module.
__all__ = [
    "flatten_dict",
    "pipeline_seed",
    "load_params_fast",
    "n_threaded",
    "linear_dataframe_to_table",
]


def pipeline_seed(initial_seed: Seed, pipeline_step: int) -> int:
Expand Down

0 comments on commit 8199358

Please sign in to comment.