diff --git a/tools/base_model_trainer.py b/tools/base_model_trainer.py index 07be7ea..19d2642 100644 --- a/tools/base_model_trainer.py +++ b/tools/base_model_trainer.py @@ -22,6 +22,7 @@ def __init__( target_col, output_dir, task_type, + random_seed, **kwargs ): self.exp = None # This will be set in the subclass @@ -29,6 +30,7 @@ def __init__( self.target_col = target_col self.output_dir = output_dir self.task_type = task_type + self.random_seed = random_seed self.data = None self.target = None self.best_model = None @@ -72,7 +74,7 @@ def setup_pycaret(self): LOG.info("Initializing PyCaret") self.setup_params = { 'target': self.target, - 'session_id': 123, + 'session_id': self.random_seed, 'html': True, 'log_experiment': False, 'system_log': False diff --git a/tools/pycaret_classification.py b/tools/pycaret_classification.py index 9345b0a..7e5c95a 100644 --- a/tools/pycaret_classification.py +++ b/tools/pycaret_classification.py @@ -18,9 +18,15 @@ def __init__( target_col, output_dir, task_type, + random_seed, **kwargs): super().__init__( - input_file, target_col, output_dir, task_type, **kwargs) + input_file, + target_col, + output_dir, + task_type, + random_seed, + **kwargs) self.exp = ClassificationExperiment() def save_dashboard(self): diff --git a/tools/pycaret_regression.py b/tools/pycaret_regression.py index 70352cc..ef90ea9 100644 --- a/tools/pycaret_regression.py +++ b/tools/pycaret_regression.py @@ -18,9 +18,15 @@ def __init__( target_col, output_dir, task_type, + random_seed, **kwargs): super().__init__( - input_file, target_col, output_dir, task_type, **kwargs) + input_file, + target_col, + output_dir, + task_type, + random_seed, + **kwargs) self.exp = RegressionExperiment() def save_dashboard(self): diff --git a/tools/pycaret_train.py b/tools/pycaret_train.py index 534a997..95f64dc 100644 --- a/tools/pycaret_train.py +++ b/tools/pycaret_train.py @@ -55,6 +55,9 @@ def main(): parser.add_argument("--models", nargs='+', default=None, help="Selected models for training") + parser.add_argument("--random_seed", type=int, + default=42, + help="Random seed for PyCaret setup") args = parser.parse_args() @@ -87,6 +90,7 @@ def main(): args.target_col, args.output_dir, args.model_type, + args.random_seed, **model_kwargs) elif args.model_type == "regression": if "fix_imbalance" in model_kwargs: @@ -96,6 +100,7 @@ def main(): args.target_col, args.output_dir, args.model_type, + args.random_seed, **model_kwargs) else: LOG.error("Invalid model type. Please choose \ diff --git a/tools/pycaret_train.xml b/tools/pycaret_train.xml index 71c31ba..9f201e9 100644 --- a/tools/pycaret_train.xml +++ b/tools/pycaret_train.xml @@ -1,12 +1,12 @@ - Compare different machine learning models on a dataset using PyCaret. Do feature analysis using LR, Random Forest and LightGBM. + Compare different machine learning models on a dataset using PyCaret. Do feature analyses using Random Forest and LightGBM. pycaret_macros.xml + @@ -152,6 +153,7 @@ + @@ -161,6 +163,7 @@ +