PyTorch implementation of our framework for Drivable Area Segmentation and Parameterization using the ENet architecture, applied to the Berkeley Deep Drive dataset, which contains 100K drivable area maps.
Run [main.py], the script used to train and test the ENet model:
python main.py [-h] [--num-epochs NUM_EPOCHS] [--learning-rate LEARNING_RATE]
[--lr-decay LR_DECAY] [--lr-decay-epochs LR_DECAY_EPOCHS]
[--weight-decay WEIGHT_DECAY] [--epochs EPOCHS]
[--run-cuda CUDA] [--batch-size BATCH_SIZE]
[--data-dir DATA_DIR] [--data-list DATA_LIST]
[--input-size INPUT_SIZE] [--gpu-select GPU_SELECT]
[--run-name RUN_NAME] [--mode MODE] [--load LOAD]
For help on the optional arguments, run: python main.py -h
See [src/arguments.py] for the default argument values.
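As a rough illustration of how the options above could be declared, here is a minimal `argparse` sketch. It is not the contents of src/arguments.py: the flag names and short forms (`-g`, `-b`, `-rn`, `-m`, `-l`) are taken from the usage text and the examples in this README, while `build_parser`, the help strings, and every default value are assumptions made for the example.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring a subset of the documented flags.

    All defaults are placeholders, not the repository's real defaults.
    """
    parser = argparse.ArgumentParser(description="Train or test the ENet model")
    parser.add_argument("-g", "--gpu-select", type=int, default=0,
                        help="index of the GPU to use (assumed default)")
    parser.add_argument("-b", "--batch-size", type=int, default=10,
                        help="samples per batch (assumed default)")
    parser.add_argument("-rn", "--run-name", type=str, default="run",
                        help="name of the run (assumed default)")
    # The training example omits --mode, so "train" is assumed to be the default.
    parser.add_argument("-m", "--mode", type=str, default="train",
                        help="'train' or 'test' (assumed default: train)")
    parser.add_argument("-l", "--load", type=str, default=None,
                        help="path to a saved checkpoint")
    return parser


# Parse the short-form test command shown in this README.
args = build_parser().parse_args(
    ["-g", "0", "-b", "5", "-rn", "test_example",
     "-m", "test", "-l", "saved_models/run4/checkpoint_20.h5"]
)
print(args.mode, args.batch_size, args.run_name, args.load)
```

Declaring a long and a short form on the same `add_argument` call is what makes the two command styles in the examples equivalent.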
python main.py --gpu-select 0 --batch-size 10 --run-name train_example
Or, equivalently:
python main.py -g 0 -b 10 -rn train_example
Note: This example uses a model checkpoint stored in [saved_models/run4/]. If the model is stored elsewhere, provide its path instead.
python main.py --gpu-select 0 --batch-size 5 --run-name test_example --mode test --load saved_models/run4/checkpoint_20.h5
Or, equivalently:
python main.py -g 0 -b 5 -rn test_example -m test -l saved_models/run4/checkpoint_20.h5
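If a run folder holds several checkpoints, a small helper can pick the most recent one to pass to `--load`. The sketch below is not part of the repository: `latest_checkpoint` is a hypothetical name, and the `checkpoint_<number>.h5` naming pattern is inferred only from the `checkpoint_20.h5` file used in the example above.

```python
import re
from pathlib import Path

# Assumed naming pattern, inferred from "checkpoint_20.h5" in the example.
CHECKPOINT_PATTERN = re.compile(r"^checkpoint_(\d+)\.h5$")


def latest_checkpoint(run_dir):
    """Return the checkpoint with the highest numeric suffix, or None."""
    best_path = None
    best_number = -1
    for path in Path(run_dir).glob("checkpoint_*.h5"):
        match = CHECKPOINT_PATTERN.match(path.name)
        if match is None:
            continue  # skip files that do not follow the assumed pattern
        number = int(match.group(1))
        # Compare as integers so that 100 sorts after 20.
        if number > best_number:
            best_number = number
            best_path = path
    return best_path
```

For example, `latest_checkpoint("saved_models/run4")` would return the path of the highest-numbered checkpoint in that folder, or `None` if there is none.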
- [dataset] contains the dataloader code, the dataset (not included in the repository; download it yourself using the instructions given below), and the list of labels used by the dataloader.
- [src] contains the training and testing code, the metrics used for evaluation in [src/metrics], the arguments, and the helper functions used by the other code.
- [models] contains the ENet model architecture definitions.
- [saved_models] contains the saved model checkpoints of our implementation, each inside a folder corresponding to the run_name provided at run time.
- [src/arguments.py] contains all the parsable command-line options and their defaults.
- [dataset/bdd_dataset.py] contains the DataLoader classes for each of the train, valid, and test datasets.
- [src/train.py] defines the TrainNetwork class used to train the model.
- [src/test.py] defines the TestNetwork class used to test the trained model.