diff --git a/tools/base_model_trainer.py b/tools/base_model_trainer.py index 2f3a689..98f4009 100644 --- a/tools/base_model_trainer.py +++ b/tools/base_model_trainer.py @@ -86,7 +86,11 @@ def setup_pycaret(self): def train_model(self): LOG.info("Training and selecting the best model") - self.best_model = self.exp.compare_models() + if hasattr(self, 'models') and self.models is not None: + self.best_model = self.exp.compare_models( + include=self.models) + else: + self.best_model = self.exp.compare_models() self.results = self.exp.pull() def save_model(self): diff --git a/tools/pycaret_train.xml b/tools/pycaret_train.xml index 4f7d32a..2a06c11 100644 --- a/tools/pycaret_train.xml +++ b/tools/pycaret_train.xml @@ -7,6 +7,16 @@ - + - + diff --git a/tools/test-data/expected_comparison_result_regression.html b/tools/test-data/expected_comparison_result_regression.html index 149bac3..7e78624 100644 --- a/tools/test-data/expected_comparison_result_regression.html +++ b/tools/test-data/expected_comparison_result_regression.html @@ -3,7 +3,8 @@ - + PyCaret Model Training Report