diff --git a/Snakefile b/Snakefile index d2a9c1b..e8a0256 100644 --- a/Snakefile +++ b/Snakefile @@ -59,7 +59,7 @@ rule prep_io_data: # shell: # """ # module load analytics cuda10.1/toolkit/10.1.105 -# run_training -e /home/jsadler/.conda/envs/rgcn --no-node-list "python {code_dir}/train_model.py -o {params.run_dir} -i {input[0]} -p {params.pt_epochs} -f {params.ft_epochs} --lambdas {params.lamb} --loss_func multitask_rmse --model rgcn -s 135" +# run_training -e /home/jsadler/.conda/envs/rgcn --no-node-list "python {code_dir}/train_model_cli.py -o {params.run_dir} -i {input[0]} -p {params.pt_epochs} -f {params.ft_epochs} --lambdas {params.lamb} --loss_func multitask_rmse --model rgcn -s 135" # """ diff --git a/river_dl/train_model.py b/river_dl/train_model_cli.py similarity index 90% rename from river_dl/train_model.py rename to river_dl/train_model_cli.py index fd8c86c..ba630aa 100644 --- a/river_dl/train_model.py +++ b/river_dl/train_model_cli.py @@ -1,3 +1,9 @@ +""" +This file provides a commandline interface (CLI) for the `train.train_model` +function. The commandline interface was originally provided to allow a command +to be sent to a slurm scheduler which was necessary to train the model using +GPUs. This has been tested on USGS's Tallgrass supercomputer. +""" import argparse from river_dl.train import train_model import river_dl.loss_functions as lf