Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added model selection and tested #20

Merged
merged 2 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tools/base_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ def setup_pycaret(self):

def train_model(self):
LOG.info("Training and selecting the best model")
self.best_model = self.exp.compare_models()
if hasattr(self, 'models') and self.models is not None:
self.best_model = self.exp.compare_models(
include=self.models)
else:
self.best_model = self.exp.compare_models()
self.results = self.exp.pull()

def save_model(self):
Expand Down
16 changes: 13 additions & 3 deletions tools/pycaret_train.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@
<command>
<![CDATA[
python $__tool_directory__/pycaret_train.py --input_file $input_file --target_col $target_feature --output_dir "`pwd`" --model_type $model_type
#if $model_type == "classification"
#if $classification_models
--models $classification_models
#end if
#end if
#if $model_type == "regression"
#if $regression_models
--models $regression_models
#end if
#end if
#if $customize_defaults == "true"
#if $train_size
--train_size $train_size
Expand Down Expand Up @@ -42,12 +52,12 @@
<inputs>
<param name="input_file" type="data" format="csv,tabular" label="Input Dataset (CSV or TSV)" />
<param name="target_feature" multiple="false" type="data_column" use_header_names="true" data_ref="input_file" label="Select the target column:" />
<!-- <conditional name="model_selection"> -->
<conditional name="model_selection">
<param name="model_type" type="select" label="Task">
<option value="classification">classification</option>
<option value="regression">regression</option>
</param>
<!-- <when value="classification">
<when value="classification">
<param name="classification_models" type="select" multiple="true" label="Only Select Classification Models if you don't want to compare all models">
<option value="lr">Logistic Regression</option>
<option value="knn">K Neighbors Classifier</option>
Expand Down Expand Up @@ -98,7 +108,7 @@
<option value="catboost">CatBoost Regressor</option>
</param>
</when>
</conditional> -->
</conditional>
<conditional name="advanced_settings">
<param name="customize_defaults" type="select" label="Customize Default Settings?" help="Select yes if you want to customize the default settings of the experiment.">
<option value="false" selected="true">No</option>
Expand Down
156 changes: 67 additions & 89 deletions tools/test-data/expected_comparison_result_regression.html

Large diffs are not rendered by default.

Loading