Skip to content

Commit

Permalink
added exp setup parameters (goeckslab#15)
Browse files Browse the repository at this point in the history
replicate the brach from exp-parameter-setup
  • Loading branch information
qchiujunhao authored Jul 11, 2024
1 parent 0a9a284 commit 5696cf1
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 30 deletions.
57 changes: 53 additions & 4 deletions tools/base_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
LOG = logging.getLogger(__name__)

class BaseModelTrainer:
def __init__(self, input_file, target_col, output_dir):
def __init__(self, input_file, target_col, output_dir, **kwargs):
self.exp = None # This will be set in the subclass
self.input_file = input_file
self.target_col = target_col
Expand All @@ -18,6 +18,11 @@ def __init__(self, input_file, target_col, output_dir):
self.best_model = None
self.results = None
self.plots = {}
for key, value in kwargs.items():
setattr(self, key, value)
self.setup_params = {}

LOG.info(f"Model kwargs: {self.__dict__}")

def load_data(self):
LOG.info(f"Loading data from {self.input_file}")
Expand All @@ -29,9 +34,44 @@ def load_data(self):

def setup_pycaret(self):
LOG.info("Initializing PyCaret")
self.exp.setup(self.data, target=self.target,
session_id=123, html=True,
log_experiment=False, system_log=False)
self.setup_params = {
'target': self.target,
'session_id': 123,
'html': True,
'log_experiment': False,
'system_log': False
}

if hasattr(self, 'train_size') and self.train_size is not None:
self.setup_params['train_size'] = self.train_size

if hasattr(self, 'normalize') and self.normalize is not None:
self.setup_params['normalize'] = self.normalize

if hasattr(self, 'feature_selection') and self.feature_selection is not None:
self.setup_params['feature_selection'] = self.feature_selection

if hasattr(self, 'cross_validation') and self.cross_validation is not None and self.cross_validation == False:
self.setup_params['cross_validation'] = self.cross_validation

if hasattr(self, 'cross_validation') and self.cross_validation is not None:
if hasattr(self, 'cross_validation_folds'):
self.setup_params['fold'] = self.cross_validation_folds

if hasattr(self, 'remove_outliers') and self.remove_outliers is not None:
self.setup_params['remove_outliers'] = self.remove_outliers

if hasattr(self, 'remove_multicollinearity') and self.remove_multicollinearity is not None:
self.setup_params['remove_multicollinearity'] = self.remove_multicollinearity

if hasattr(self, 'polynomial_features') and self.polynomial_features is not None:
self.setup_params['polynomial_features'] = self.polynomial_features

if hasattr(self, 'fix_imbalance') and self.fix_imbalance is not None:
self.setup_params['fix_imbalance'] = self.fix_imbalance

LOG.info(self.setup_params)
self.exp.setup(self.data, **self.setup_params)

def train_model(self):
LOG.info("Training and selecting the best model")
Expand All @@ -54,6 +94,10 @@ def save_html_report(self):

model_name = type(self.best_model).__name__

excluded_params = ['html', 'log_experiment', 'system_log']
filtered_setup_params = {k: v for k, v in self.setup_params.items() if k not in excluded_params}
setup_params_table = pd.DataFrame(list(filtered_setup_params.items()), columns=['Parameter', 'Value'])

# Save model summary
best_model_params = pd.DataFrame(self.best_model.get_params().items(), columns=['Parameter', 'Value'])
best_model_params.to_csv(os.path.join(self.output_dir, 'best_model.csv'), index=False)
Expand Down Expand Up @@ -132,6 +176,11 @@ def save_html_report(self):
<body>
<div class="container">
<h1>PyCaret Model Training Report</h1>
<h2>Setup Parameters</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{setup_params_table.to_html(index=False, header=False, classes='table')}
</table>
<h2>Best Model: {model_name}</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
Expand Down
12 changes: 8 additions & 4 deletions tools/pycaret_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
LOG = logging.getLogger(__name__)

class ClassificationModelTrainer(BaseModelTrainer):
def __init__(self, input_file, target_col, output_dir):
super().__init__(input_file, target_col, output_dir)
def __init__(self, input_file, target_col, output_dir, **kwargs):
super().__init__(input_file, target_col, output_dir, **kwargs)
self.exp = ClassificationExperiment()

def save_dashboard(self):
Expand All @@ -19,5 +19,9 @@ def generate_plots(self):
LOG.info("Generating and saving plots")
plots = ['auc', 'confusion_matrix', 'threshold', 'pr', 'error', 'class_report', 'learning', 'calibration', 'vc', 'dimension', 'manifold', 'rfe', 'feature', 'feature_all']
for plot_name in plots:
plot_path = self.exp.plot_model(self.best_model, plot=plot_name, save=True)
self.plots[plot_name] = plot_path
try:
plot_path = self.exp.plot_model(self.best_model, plot=plot_name, save=True)
self.plots[plot_name] = plot_path
except Exception as e:
LOG.error(f"Error generating plot {plot_name}: {e}")
continue
2 changes: 1 addition & 1 deletion tools/pycaret_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def save_dashboard(self):

def generate_plots(self):
LOG.info("Generating and saving plots")
plots = ['residuals', 'error', 'cooks', 'learning', 'vc', 'manifold', 'rfe', 'feature']
plots = ['residuals', 'error', 'cooks', 'learning', 'vc', 'manifold', 'rfe', 'feature', 'feature_all']
for plot_name in plots:
plot_path = self.exp.plot_model(self.best_model, plot=plot_name, save=True)
self.plots[plot_name] = plot_path
65 changes: 55 additions & 10 deletions tools/pycaret_train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import sys
import argparse
import logging

from pycaret_classification import ClassificationModelTrainer
Expand All @@ -7,14 +7,59 @@
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)

input_file = sys.argv[1]
target_col = sys.argv[2]
output_dir = sys.argv[3]
model_type = sys.argv[4]
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", help="Path to the input file")
parser.add_argument("--target_col", help="Column number of the target")
parser.add_argument("--output_dir", help="Path to the output directory")
parser.add_argument("--model_type", choices=["classification", "regression"], help="Type of the model")
parser.add_argument("--train_size", type=float, default=None, help="Train size for PyCaret setup")
parser.add_argument("--normalize", action="store_true", default=None, help="Normalize data for PyCaret setup")
parser.add_argument("--feature_selection", action="store_true", default=None, help="Perform feature selection for PyCaret setup")
parser.add_argument("--cross_validation", action="store_true", default=None, help="Perform cross-validation for PyCaret setup")
parser.add_argument("--cross_validation_folds", type=int, default=None, help="Number of cross-validation folds for PyCaret setup")
parser.add_argument("--remove_outliers", action="store_true", default=None, help="Remove outliers for PyCaret setup")
parser.add_argument("--remove_multicollinearity", action="store_true", default=None, help="Remove multicollinearity for PyCaret setup")
parser.add_argument("--polynomial_features", action="store_true", default=None, help="Generate polynomial features for PyCaret setup")
parser.add_argument("--feature_interaction", action="store_true", default=None, help="Generate feature interactions for PyCaret setup")
parser.add_argument("--feature_ratio", action="store_true", default=None, help="Generate feature ratios for PyCaret setup")
parser.add_argument("--fix_imbalance", action="store_true", default=None, help="Fix class imbalance for PyCaret setup")
parser.add_argument("--models", nargs='+', default=None, help="Selected models for training")

if model_type == "classification":
trainer = ClassificationModelTrainer(input_file, target_col, output_dir)
args = parser.parse_args()

model_kwargs = {
"train_size": args.train_size,
"normalize": args.normalize,
"feature_selection": args.feature_selection,
"cross_validation": args.cross_validation,
"cross_validation_folds": args.cross_validation_folds,
"remove_outliers": args.remove_outliers,
"remove_multicollinearity": args.remove_multicollinearity,
"polynomial_features": args.polynomial_features,
"feature_interaction": args.feature_interaction,
"feature_ratio": args.feature_ratio,
"fix_imbalance": args.fix_imbalance,
}
LOG.info(f"Model kwargs: {model_kwargs}")

# Remove None values from model_kwargs

LOG.info(f"Model kwargs 2: {model_kwargs}")
if args.models:
model_kwargs["models"] = args.models[0].split(",")

model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}

if args.model_type == "classification":
trainer = ClassificationModelTrainer(args.input_file, args.target_col, args.output_dir, **model_kwargs)
elif args.model_type == "regression":
trainer = RegressionModelTrainer(args.input_file, args.target_col, args.output_dir, **model_kwargs)
else:
LOG.error("Invalid model type. Please choose 'classification' or 'regression'.")
return

trainer.run()
elif model_type == "regression":
trainer = RegressionModelTrainer(input_file, target_col, output_dir)
trainer.run()

if __name__ == "__main__":
main()
131 changes: 120 additions & 11 deletions tools/pycaret_train.xml
Original file line number Diff line number Diff line change
@@ -1,30 +1,139 @@
<tool id="pycaret_tool" name="PyCaret Model Training" version="@VERSION@" profile="@PROFILE@">
<description>Train and evaluate machine learning models using PyCaret.</description>
<tool id="pycaret_tool" name="PyCaret Model Comparison" version="@VERSION@" profile="@PROFILE@">
<description>Compare different machine learning models on a dataset using PyCaret.</description>
<macros>
<import>pycaret_macros.xml</import>
</macros>
<expand macro="python_requirements" />
<command>
<![CDATA[
python '$__tool_directory__/pycaret_train.py' $input_file $target_feature "`pwd`" $model_type
python $__tool_directory__/pycaret_train.py --input_file $input_file --target_col $target_feature --output_dir "`pwd`" --model_type $model_type
#if $customize_defaults == "true"
#if $train_size
--train_size $train_size
#end if
#if $normalize
--normalize
#end if
#if $feature_selection
--feature_selection
#end if
#if $enable_cross_validation == "true"
--cross_validation
#end if
#if $cross_validation_folds
--cross_validation_folds $cross_validation_folds
#end if
#if $remove_outliers
--remove_outliers
#end if
#if $remove_multicollinearity
--remove_multicollinearity
#end if
#if $polynomial_features
--polynomial_features
#end if
#if $fix_imbalance
--fix_imbalance
#end if
#end if
]]>
</command>
<inputs>
<param name="model_type" type="select" label="Task">
<option value="classification" >classification</option>
<option value="regression" >regression</option>
</param>
<param name="input_file" type="data" format="csv,tabular" label="Input Dataset (CSV or TSV)"/>
<param name="target_feature" multiple="false" type="data_column" use_header_names="true" data_ref="input_file" label="Select the target column:" />
<param name="input_file" type="data" format="csv,tabular" label="Input Dataset (CSV or TSV)" />
<param name="target_feature" multiple="false" type="data_column" use_header_names="true" data_ref="input_file" label="Select the target column:" />
<!-- <conditional name="model_selection"> -->
<param name="model_type" type="select" label="Task">
<option value="classification">classification</option>
<option value="regression">regression</option>
</param>
<!-- <when value="classification">
<param name="classification_models" type="select" multiple="true" label="Only Select Classification Models if you don't want to compare all models">
<option value="lr">Logistic Regression</option>
<option value="knn">K Neighbors Classifier</option>
<option value="nb">Naive Bayes</option>
<option value="dt">Decision Tree Classifier</option>
<option value="svm">SVM - Linear Kernel</option>
<option value="rbfsvm">SVM - Radial Kernel</option>
<option value="gpc">Gaussian Process Classifier</option>
<option value="mlp">MLP Classifier</option>
<option value="ridge">Ridge Classifier</option>
<option value="rf">Random Forest Classifier</option>
<option value="qda">Quadratic Discriminant Analysis</option>
<option value="ada">Ada Boost Classifier</option>
<option value="gbc">Gradient Boosting Classifier</option>
<option value="lda">Linear Discriminant Analysis</option>
<option value="et">Extra Trees Classifier</option>
<option value="xgboost">Extreme Gradient Boosting</option>
<option value="lightgbm">Light Gradient Boosting Machine</option>
<option value="catboost">CatBoost Classifier</option>
</param>
</when>
<when value="regression">
<param name="regression_models" type="select" multiple="true" label="Only Select Regression Models if you don't want to compare all models">
<option value="lr">Linear Regression</option>
<option value="lasso">Lasso Regression</option>
<option value="ridge">Ridge Regression</option>
<option value="en">Elastic Net</option>
<option value="lar">Least Angle Regression</option>
<option value="llar">Lasso Least Angle Regression</option>
<option value="omp">Orthogonal Matching Pursuit</option>
<option value="br">Bayesian Ridge</option>
<option value="ard">Automatic Relevance Determination</option>
<option value="par">Passive Aggressive Regressor</option>
<option value="ransac">Random Sample Consensus</option>
<option value="tr">TheilSen Regressor</option>
<option value="huber">Huber Regressor</option>
<option value="kr">Kernel Ridge</option>
<option value="svm">Support Vector Regression</option>
<option value="knn">K Neighbors Regressor</option>
<option value="dt">Decision Tree Regressor</option>
<option value="rf">Random Forest Regressor</option>
<option value="et">Extra Trees Regressor</option>
<option value="ada">AdaBoost Regressor</option>
<option value="gbr">Gradient Boosting Regressor</option>
<option value="mlp">MLP Regressor</option>
<option value="xgboost">Extreme Gradient Boosting</option>
<option value="lightgbm">Light Gradient Boosting Machine</option>
<option value="catboost">CatBoost Regressor</option>
</param>
</when>
</conditional> -->
<conditional name="advanced_settings">
<param name="customize_defaults" type="select" label="Customize Default Settings?" help="Select yes if you want to customize the default settings of the experiment.">
<option value="false" selected="true">No</option>
<option value="true">Yes</option>
</param>
<when value="true">
<param name="train_size" type="float" value="0.7" min="0.1" max="0.9" label="Train Size" help="Proportion of the dataset to include in the train split." />
<param name="normalize" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Normalize Data" help="Whether to normalize data before training." />
<param name="feature_selection" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Feature Selection" help="Whether to perform feature selection." />
<conditional name="cross_validation">
<param name="enable_cross_validation" type="boolean" truevalue="true" falsevalue="false" checked="true" label="Enable Cross Validation?" help="Select whether to enable cross-validation. Default: Yes" />
<when value="true">
<param name="cross_validation_folds" type="integer" value="10" min="2" max="20" label="Cross Validation Folds" help="Number of folds to use for cross-validation. Default: 10" />
</when>
<when value="false">
<!-- No additional parameters to show if the user selects 'No' -->
</when>
</conditional>
<param name="remove_outliers" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Remove Outliers" help="Whether to remove outliers from the dataset before training. Default: False" />
<param name="remove_multicollinearity" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Remove Multicollinearity" help="Whether to remove multicollinear features before training. Default: False" />
<param name="polynomial_features" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Polynomial Features" help="Whether to create polynomial features before training. Default: False" />
<param name="fix_imbalance" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Fix Imbalance" help="Whether to use SMOTE or similar methods to fix imbalance in the dataset. Default: False" />
</when>
<when value="false">
<!-- No additional parameters to show if the user selects 'No' -->
</when>
</conditional>
</inputs>
<outputs>
<data name="model" format="data" from_work_dir="model.pkl" label="${tool.name} best model on ${on_string}"/>
<data name="model" format="data" from_work_dir="model.pkl" label="${tool.name} best model on ${on_string}" />
<data name="dashboard" format="html" from_work_dir="dashboard.html" label="${tool.name} Dashboard on ${on_string}"/>
<data name="comparison_result" format="html" from_work_dir="comparison_result.html" label="${tool.name} Comparison result on ${on_string}"/>
</outputs>
<help>
This tool uses PyCaret to train and evaluate machine learning models.
Ensure that the Conda environment specified in the requirements is correctly set up.
</help>
</tool>
</tool>

0 comments on commit 5696cf1

Please sign in to comment.