Skip to content

Commit

Permalink
Added a user input for the random seed (goeckslab#27)
Browse files: browse the repository at this point in the history
* added input of random seed

* flake8 linting

* added random seed for testing

* passed the tests

* planemo lint
  • Loading branch information
qchiujunhao authored Sep 27, 2024
1 parent 9a3b9e1 commit 82e718c
Show file tree
Hide file tree
Showing 11 changed files with 406 additions and 454 deletions.
4 changes: 3 additions & 1 deletion tools/base_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ def __init__(
target_col,
output_dir,
task_type,
random_seed,
**kwargs
):
self.exp = None # This will be set in the subclass
self.input_file = input_file
self.target_col = target_col
self.output_dir = output_dir
self.task_type = task_type
self.random_seed = random_seed
self.data = None
self.target = None
self.best_model = None
Expand Down Expand Up @@ -72,7 +74,7 @@ def setup_pycaret(self):
LOG.info("Initializing PyCaret")
self.setup_params = {
'target': self.target,
'session_id': 123,
'session_id': self.random_seed,
'html': True,
'log_experiment': False,
'system_log': False
Expand Down
8 changes: 7 additions & 1 deletion tools/pycaret_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@ def __init__(
target_col,
output_dir,
task_type,
random_seed,
**kwargs):
super().__init__(
input_file, target_col, output_dir, task_type, **kwargs)
input_file,
target_col,
output_dir,
task_type,
random_seed,
**kwargs)
self.exp = ClassificationExperiment()

def save_dashboard(self):
Expand Down
8 changes: 7 additions & 1 deletion tools/pycaret_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@ def __init__(
target_col,
output_dir,
task_type,
random_seed,
**kwargs):
super().__init__(
input_file, target_col, output_dir, task_type, **kwargs)
input_file,
target_col,
output_dir,
task_type,
random_seed,
**kwargs)
self.exp = RegressionExperiment()

def save_dashboard(self):
Expand Down
5 changes: 5 additions & 0 deletions tools/pycaret_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def main():
parser.add_argument("--models", nargs='+',
default=None,
help="Selected models for training")
parser.add_argument("--random_seed", type=int,
default=42,
help="Random seed for PyCaret setup")

args = parser.parse_args()

Expand Down Expand Up @@ -87,6 +90,7 @@ def main():
args.target_col,
args.output_dir,
args.model_type,
args.random_seed,
**model_kwargs)
elif args.model_type == "regression":
if "fix_imbalance" in model_kwargs:
Expand All @@ -96,6 +100,7 @@ def main():
args.target_col,
args.output_dir,
args.model_type,
args.random_seed,
**model_kwargs)
else:
LOG.error("Invalid model type. Please choose \
Expand Down
7 changes: 5 additions & 2 deletions tools/pycaret_train.xml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
<tool id="pycaret_tool" name="PyCaret Model Comparison" version="@VERSION@" profile="@PROFILE@">
<description>Compare different machine learning models on a dataset using PyCaret. Do feature analysis using LR, Random Forest and LightGBM. </description>
<description>Compare different machine learning models on a dataset using PyCaret. Do feature analyses using Random Forest and LightGBM. </description>
<macros>
<import>pycaret_macros.xml</import>
</macros>
<expand macro="python_requirements" />
<command>
<![CDATA[
python $__tool_directory__/pycaret_train.py --input_file $input_file --target_col $target_feature --output_dir "`pwd`"
python $__tool_directory__/pycaret_train.py --input_file $input_file --target_col $target_feature --output_dir "`pwd`" --random_seed $random_seed
#if $model_type == "classification"
#if $classification_models
--models $classification_models
Expand Down Expand Up @@ -112,6 +112,7 @@
</param>
</when>
</conditional>
<param name="random_seed" type="integer" value="42" label="Random Seed" help="Random seed for reproducibility." />
<conditional name="advanced_settings">
<param name="customize_defaults" type="select" label="Customize Default Settings?" help="Select yes if you want to customize the default settings of the experiment.">
<option value="false" selected="true">No</option>
Expand Down Expand Up @@ -152,6 +153,7 @@
<param name="input_file" value="pcr.tsv"/>
<param name="target_feature" value="11"/>
<param name="model_type" value="classification"/>
<param name="random_seed" value="42"/>
<output name="model" file="expected_model_classification" compare="sim_size"/>
<output name="comparison_result" file="expected_comparison_result_classification.html" compare="sim_size">
<extra_files type="file" name="best_model.csv" value="expected_best_model_classification.csv" />
Expand All @@ -161,6 +163,7 @@
<param name="input_file" value="auto-mpg.tsv"/>
<param name="target_feature" value="1"/>
<param name="model_type" value="regression"/>
<param name="random_seed" value="42"/>
<!-- <param name="customize_defaults" value="true"/>
<param name="train_size" value="0.8"/>
<param name="normalize" value="true"/>
Expand Down
2 changes: 1 addition & 1 deletion tools/test-data/expected_best_model_classification.csv
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ n_estimators,100
n_jobs,-1
num_leaves,31
objective,
random_state,123
random_state,42
reg_alpha,0.0
reg_lambda,0.0
subsample,1.0
Expand Down
24 changes: 5 additions & 19 deletions tools/test-data/expected_best_model_regression.csv
Original file line number Diff line number Diff line change
@@ -1,20 +1,6 @@
Parameter,Value
boosting_type,gbdt
class_weight,
colsample_bytree,1.0
importance_type,split
learning_rate,0.1
max_depth,-1
min_child_samples,20
min_child_weight,0.001
min_split_gain,0.0
n_estimators,100
n_jobs,-1
num_leaves,31
objective,
random_state,123
reg_alpha,0.0
reg_lambda,0.0
subsample,1.0
subsample_for_bin,200000
subsample_freq,0
loss_function,RMSE
border_count,254
verbose,False
task_type,CPU
random_state,42
Loading

0 comments on commit 82e718c

Please sign in to comment.