diff --git a/tools/base_model_trainer.py b/tools/base_model_trainer.py
index 1a6b5ec..46bea14 100644
--- a/tools/base_model_trainer.py
+++ b/tools/base_model_trainer.py
@@ -8,7 +8,7 @@
LOG = logging.getLogger(__name__)
class BaseModelTrainer:
- def __init__(self, input_file, target_col, output_dir):
+ def __init__(self, input_file, target_col, output_dir, **kwargs):
self.exp = None # This will be set in the subclass
self.input_file = input_file
self.target_col = target_col
@@ -18,6 +18,11 @@ def __init__(self, input_file, target_col, output_dir):
self.best_model = None
self.results = None
self.plots = {}
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+ self.setup_params = {}
+
+ LOG.info(f"Model kwargs: {self.__dict__}")
def load_data(self):
LOG.info(f"Loading data from {self.input_file}")
@@ -29,9 +34,48 @@ def load_data(self):
def setup_pycaret(self):
+        """Build ``self.setup_params`` and initialise the PyCaret experiment.
+
+        Fixed defaults are always passed. Optional settings are read from
+        instance attributes (set from ``**kwargs`` in ``__init__``) and are
+        forwarded to ``self.exp.setup`` only when they are not ``None``.
+        """
LOG.info("Initializing PyCaret")
- self.exp.setup(self.data, target=self.target,
- session_id=123, html=True,
- log_experiment=False, system_log=False)
+        self.setup_params = {
+            'target': self.target,
+            'session_id': 123,
+            'html': True,
+            'log_experiment': False,
+            'system_log': False
+        }
+
+        def forward(*names):
+            # Copy each option that was actually provided into setup_params.
+            for name in names:
+                value = getattr(self, name, None)
+                if value is not None:
+                    self.setup_params[name] = value
+
+        # Insertion order is kept stable because save_html_report renders
+        # setup_params as a table in this order.
+        forward('train_size', 'normalize', 'feature_selection')
+
+        cross_validation = getattr(self, 'cross_validation', None)
+        # NOTE(review): only an explicit False is forwarded; confirm that
+        # setup() accepts a 'cross_validation' keyword in the PyCaret version used.
+        if cross_validation is False:
+            self.setup_params['cross_validation'] = cross_validation
+
+        # The fold count applies only when cross_validation was set, and a
+        # missing fold count must not be forwarded as fold=None.
+        folds = getattr(self, 'cross_validation_folds', None)
+        if cross_validation is not None and folds is not None:
+            self.setup_params['fold'] = folds
+
+        forward('remove_outliers', 'remove_multicollinearity',
+                'polynomial_features', 'fix_imbalance')
+
+        LOG.info(self.setup_params)
+        self.exp.setup(self.data, **self.setup_params)
def train_model(self):
LOG.info("Training and selecting the best model")
@@ -54,6 +94,10 @@ def save_html_report(self):
model_name = type(self.best_model).__name__
+ excluded_params = ['html', 'log_experiment', 'system_log']
+ filtered_setup_params = {k: v for k, v in self.setup_params.items() if k not in excluded_params}
+ setup_params_table = pd.DataFrame(list(filtered_setup_params.items()), columns=['Parameter', 'Value'])
+
# Save model summary
best_model_params = pd.DataFrame(self.best_model.get_params().items(), columns=['Parameter', 'Value'])
best_model_params.to_csv(os.path.join(self.output_dir, 'best_model.csv'), index=False)
@@ -132,6 +176,11 @@ def save_html_report(self):
PyCaret Model Training Report
+
Setup Parameters
+
+ Parameter | Value |
+ {setup_params_table.to_html(index=False, header=False, classes='table')}
+
Best Model: {model_name}
Parameter | Value |
diff --git a/tools/pycaret_classification.py b/tools/pycaret_classification.py
index d382cd1..570fbd4 100644
--- a/tools/pycaret_classification.py
+++ b/tools/pycaret_classification.py
@@ -6,8 +6,8 @@
LOG = logging.getLogger(__name__)
class ClassificationModelTrainer(BaseModelTrainer):
- def __init__(self, input_file, target_col, output_dir):
- super().__init__(input_file, target_col, output_dir)
+ def __init__(self, input_file, target_col, output_dir, **kwargs):
+ super().__init__(input_file, target_col, output_dir, **kwargs)
self.exp = ClassificationExperiment()
def save_dashboard(self):
@@ -19,5 +19,9 @@ def generate_plots(self):
LOG.info("Generating and saving plots")
plots = ['auc', 'confusion_matrix', 'threshold', 'pr', 'error', 'class_report', 'learning', 'calibration', 'vc', 'dimension', 'manifold', 'rfe', 'feature', 'feature_all']
for plot_name in plots:
- plot_path = self.exp.plot_model(self.best_model, plot=plot_name, save=True)
- self.plots[plot_name] = plot_path
+ try:
+ plot_path = self.exp.plot_model(self.best_model, plot=plot_name, save=True)
+ self.plots[plot_name] = plot_path
+ except Exception as e:
+ LOG.error(f"Error generating plot {plot_name}: {e}")
+ continue
diff --git a/tools/pycaret_regression.py b/tools/pycaret_regression.py
index f04700b..a9799af 100644
--- a/tools/pycaret_regression.py
+++ b/tools/pycaret_regression.py
@@ -17,7 +17,12 @@ def save_dashboard(self):
def generate_plots(self):
LOG.info("Generating and saving plots")
- plots = ['residuals', 'error', 'cooks', 'learning', 'vc', 'manifold', 'rfe', 'feature']
+        plots = ['residuals', 'error', 'cooks', 'learning', 'vc', 'manifold', 'rfe', 'feature', 'feature_all']
+        # Same handling as the classification trainer: a plot that cannot be
+        # produced is logged and skipped so the remaining plots still run.
for plot_name in plots:
- plot_path = self.exp.plot_model(self.best_model, plot=plot_name, save=True)
- self.plots[plot_name] = plot_path
+            try:
+                plot_path = self.exp.plot_model(self.best_model, plot=plot_name, save=True)
+                self.plots[plot_name] = plot_path
+            except Exception as e:
+                LOG.error(f"Error generating plot {plot_name}: {e}")
diff --git a/tools/pycaret_train.py b/tools/pycaret_train.py
index acfcb46..4884fca 100644
--- a/tools/pycaret_train.py
+++ b/tools/pycaret_train.py
@@ -1,4 +1,4 @@
-import sys
+import argparse
import logging
from pycaret_classification import ClassificationModelTrainer
@@ -7,14 +7,77 @@
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)
-input_file = sys.argv[1]
-target_col = sys.argv[2]
-output_dir = sys.argv[3]
-model_type = sys.argv[4]
+def main():
+    """Parse the command line and run the selected PyCaret trainer.
+
+    Options that are not given on the command line stay ``None`` and are
+    dropped before the trainer is built, so only explicit choices are passed
+    on as keyword arguments.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_file", required=True, help="Path to the input file")
+    parser.add_argument("--target_col", required=True, help="Column number of the target")
+    parser.add_argument("--output_dir", required=True, help="Path to the output directory")
+    parser.add_argument("--model_type", required=True, choices=["classification", "regression"], help="Type of the model")
+    parser.add_argument("--train_size", type=float, default=None, help="Train size for PyCaret setup")
+    # store_true with default=None: a flag is True when given and None (not
+    # False) when omitted, so omitted flags are filtered out below.
+    parser.add_argument("--normalize", action="store_true", default=None, help="Normalize data for PyCaret setup")
+    parser.add_argument("--feature_selection", action="store_true", default=None, help="Perform feature selection for PyCaret setup")
+    parser.add_argument("--cross_validation", action="store_true", default=None, help="Perform cross-validation for PyCaret setup")
+    parser.add_argument("--cross_validation_folds", type=int, default=None, help="Number of cross-validation folds for PyCaret setup")
+    parser.add_argument("--remove_outliers", action="store_true", default=None, help="Remove outliers for PyCaret setup")
+    parser.add_argument("--remove_multicollinearity", action="store_true", default=None, help="Remove multicollinearity for PyCaret setup")
+    parser.add_argument("--polynomial_features", action="store_true", default=None, help="Generate polynomial features for PyCaret setup")
+    parser.add_argument("--feature_interaction", action="store_true", default=None, help="Generate feature interactions for PyCaret setup")
+    parser.add_argument("--feature_ratio", action="store_true", default=None, help="Generate feature ratios for PyCaret setup")
+    parser.add_argument("--fix_imbalance", action="store_true", default=None, help="Fix class imbalance for PyCaret setup")
+    parser.add_argument("--models", nargs='+', default=None, help="Selected models for training")
-if model_type == "classification":
- trainer = ClassificationModelTrainer(input_file, target_col, output_dir)
+    args = parser.parse_args()
+
+    # NOTE(review): feature_interaction and feature_ratio are handed to the
+    # trainer, but setup_pycaret in this patch does not forward them to
+    # PyCaret - confirm they are consumed elsewhere or drop them.
+    model_kwargs = {
+        "train_size": args.train_size,
+        "normalize": args.normalize,
+        "feature_selection": args.feature_selection,
+        "cross_validation": args.cross_validation,
+        "cross_validation_folds": args.cross_validation_folds,
+        "remove_outliers": args.remove_outliers,
+        "remove_multicollinearity": args.remove_multicollinearity,
+        "polynomial_features": args.polynomial_features,
+        "feature_interaction": args.feature_interaction,
+        "feature_ratio": args.feature_ratio,
+        "fix_imbalance": args.fix_imbalance,
+    }
+
+    # Accept both "--models a,b" and "--models a b"; blank names are skipped.
+    selected_models = [
+        name.strip()
+        for chunk in (args.models or [])
+        for name in chunk.split(",")
+        if name.strip()
+    ]
+    if selected_models:
+        model_kwargs["models"] = selected_models
+
+    # Remove None values from model_kwargs
+    model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
+    LOG.info(f"Model kwargs: {model_kwargs}")
+
+    # --model_type is required and limited by ``choices``, so this lookup
+    # cannot miss.
+    trainer_classes = {
+        "classification": ClassificationModelTrainer,
+        "regression": RegressionModelTrainer,
+    }
+    trainer = trainer_classes[args.model_type](
+        args.input_file, args.target_col, args.output_dir, **model_kwargs
+    )
trainer.run()
-elif model_type == "regression":
- trainer = RegressionModelTrainer(input_file, target_col, output_dir)
- trainer.run()
\ No newline at end of file
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/pycaret_train.xml b/tools/pycaret_train.xml
index cf5b555..6c6f849 100644
--- a/tools/pycaret_train.xml
+++ b/tools/pycaret_train.xml
@@ -1,25 +1,134 @@
-
- Train and evaluate machine learning models using PyCaret.
+
+ Compare different machine learning models on a dataset using PyCaret.
pycaret_macros.xml
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
+
@@ -27,4 +136,4 @@
This tool uses PyCaret to train and evaluate machine learning models.
Ensure that the Conda environment specified in the requirements is correctly set up.
-
+
\ No newline at end of file