This repository has been archived by the owner on Apr 16, 2024. It is now read-only.

Commit

Merge pull request #3 from S-aiueo32/feature/pl-1.5
v0.2.0rc1
S-aiueo32 authored Mar 5, 2022
2 parents c4a3b03 + fedd913 commit 9218343
Showing 18 changed files with 2,401 additions and 1,590 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -11,7 +11,7 @@ env:
jobs:
build:
runs-on: ubuntu-latest
timeout-minutes: 5
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -33,7 +33,7 @@ jobs:

test:
runs-on: ubuntu-latest
timeout-minutes: 5
timeout-minutes: 10
needs: build
steps:
- name: Checkout
131 changes: 84 additions & 47 deletions README.md
@@ -15,7 +15,7 @@ $ pip install -U hiraishin

# Basic workflow
## 1. Model initialization with type annotations
Define a model class that has training components of PyTorch as instance variables.
Define a model class that declares its training components as type-annotated class attributes.

```python
import torch.nn as nn
@@ -32,73 +32,95 @@ class ToyModel(BaseModel):
scheduler: optim.lr_scheduler.ExponentialLR

def __init__(self, config: DictConfig) -> None:
self.initialize(config) # call `initialize()` instead of `super()__init__()`
super().__init__(config)
```

We recommend that they have the following prefix to indicate their role.
Modules with the following prefixes are instantiated by their own role-specific logic.

- `net` for networks. It must be a subclass of `nn.Module` to initialize and load weights.
- `criterion` for loss functions.
- `optimizer` for optimizers. It must be subclass of `Optimizer`.
- `scheduler` for schedulers. It must be subclass of `_LRScheduler` and the suffix must match to the corresponding optimizer.
- `net`
- `criterion`
- `optimizer`
- `scheduler`

If you need to define modules besides the above components (e.g. tokenizers), feel free to define them. The modules will be defined with the names you specify.
The same notation can be used to define components other than the training components listed above (e.g., tokenizers). You can also define constants of built-in types, as long as they are YAML-serializable.

## 2. Generating configuration file
Hiraishin has the functionality to generate configuration files on the command line.
If the above class was written in `model.py` at the same level as the current directory, you can generate it with the following command.
```python
class ToyModel(BaseModel):

net: nn.Linear
criterion: nn.CrossEntropyLoss
optimizer: optim.Adam
scheduler: optim.lr_scheduler.ExponentialLR

# additional components and constants
tokenizer: MyTokenizer
n_classes: int

def __init__(self, config: DictConfig) -> None:
super().__init__(config)
```
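
As a rough illustration of how attribute names can be sorted into roles by prefix, here is a minimal standard-library sketch. It is not Hiraishin's actual implementation: the stand-in classes, `ROLE_PREFIXES`, and `group_by_role` are all hypothetical names introduced for this example, and a plain class replaces `BaseModel` so the snippet runs without PyTorch.

```python
from typing import Dict, List, get_type_hints

# Role prefixes named in the README; any other attribute is grouped as "other".
ROLE_PREFIXES = ("net", "criterion", "optimizer", "scheduler")


# Hypothetical stand-ins for the real PyTorch classes and a custom tokenizer.
class Linear: ...
class CrossEntropyLoss: ...
class Adam: ...
class ExponentialLR: ...
class MyTokenizer: ...


class ToyModel:
    net: Linear
    criterion: CrossEntropyLoss
    optimizer: Adam
    scheduler: ExponentialLR
    tokenizer: MyTokenizer
    n_classes: int


def group_by_role(cls: type) -> Dict[str, List[str]]:
    """Group annotated attribute names by the first matching role prefix."""
    groups: Dict[str, List[str]] = {prefix: [] for prefix in ROLE_PREFIXES}
    groups["other"] = []
    for name in get_type_hints(cls):
        role = next((p for p in ROLE_PREFIXES if name.startswith(p)), "other")
        groups[role].append(name)
    return groups


roles = group_by_role(ToyModel)
print(roles)
```

Matching on a prefix rather than an exact name is what would let a model declare several components of the same role, such as `net_generator` and `net_discriminator`.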

## 2. Configuration file generation
Hiraishin provides a CLI command that automatically generates a configuration file based on type annotations.

For example, if `ToyModel` is defined in `models.py` (i.e., `from models import ToyModel` can be executed in the code), then the following command will generate the configuration file automatically.

```shell
$ hiraishin configen model.ToyModel
The config has been generated! --> config/model/toy.yaml
$ hiraishin generate models.ToyModel --output_dir config/model
The config has been generated! --> config/model/ToyModel.yaml
```
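
To make the idea of generating a configuration from type annotations more concrete, the following sketch builds a config dictionary from a class's annotations and the constructor signatures of the annotated types. It only illustrates the general technique and assumes nothing about how the `hiraishin` CLI works internally: `Linear` is a hypothetical stand-in for `torch.nn.Linear`, `args_for` and `generate_config` are made-up helper names, and the resulting layout is flatter than the real generated file.

```python
import inspect
from typing import Any, Dict, get_type_hints

MISSING = "???"  # marker Hydra/OmegaConf use for a mandatory value
BUILTIN_CONSTANTS = (int, float, str, bool)


class Linear:
    """Hypothetical stand-in for torch.nn.Linear (same constructor arguments)."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias


class ToyModel:
    net: Linear
    n_classes: int


def args_for(target: type) -> Dict[str, Any]:
    """Build an instantiate-style entry from the constructor signature."""
    entry: Dict[str, Any] = {"_target_": f"{target.__module__}.{target.__qualname__}"}
    for name, param in inspect.signature(target).parameters.items():
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue  # *args / **kwargs have no single value to record
        # Arguments without a default are mandatory, so they get the ??? marker.
        entry[name] = MISSING if param.default is inspect.Parameter.empty else param.default
    return entry


def generate_config(cls: type) -> Dict[str, Any]:
    """Map every annotated attribute to a config entry."""
    config: Dict[str, Any] = {}
    for name, hint in get_type_hints(cls).items():
        config[name] = MISSING if hint in BUILTIN_CONSTANTS else args_for(hint)
    return config


config = generate_config(ToyModel)
print(config["net"]["in_features"], config["net"]["bias"], config["n_classes"])
```

Arguments that have defaults are copied into the config so they can be tuned later, while arguments without defaults are left as `???` and must be filled in before instantiation.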

Let's take a look at the generated file.
The positional arguments are filled with `???` that indicates mandatory parameters in Hydra.
We recommend overwriting them with the default value, otherwise, you must give them through command-line arguments for every run.

```yaml
_target_: model.ToyModel
_target_: models.ToyModel
_recursive_: false
config:

networks:
- name: net
args:
_target_: torch.nn.Linear
_recursive_: true
in_features: ??? # -> 1
out_features: ??? # -> 1
init:
weight_path: null
init_type: null
init_gain: null
net:
args:
_target_: torch.nn.Linear
out_features: ???
in_features: ???
weights:
initializer: null
path: null

losses:
- name: criterion
args:
_target_: torch.nn.CrossEntropyLoss
_recursive_: true
weight: 1.0
criterion:
args:
_target_: torch.nn.CrossEntropyLoss
weight: 1.0

optimizers:
- name: optimizer
args:
_target_: torch.optim.Adam
_recursive_: true
params:
- ??? # -> net
scheduler:
optimizer:
args:
_target_: torch.optim.lr_scheduler.ExponentialLR
_recursive_: true
gamma: ??? # -> 1
interval: epoch
frequency: 1
strict: true
monitor: null
modules: null
_target_: torch.optim.Adam
params:
- ???
scheduler:
args:
_target_: torch.optim.lr_scheduler.ExponentialLR
gamma: ???
interval: epoch
frequency: 1
strict: true
monitor: null

tokenizer:
_target_: MyTokenizer
n_classes: ???

```

First of all, the generated file is compatible with instantiation via `hydra.utils.instantiate`.

Positional arguments are filled with `???`, which marks mandatory parameters. Override them with the values you want to use.
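
Before launching a run, it can help to list which values are still `???`. OmegaConf performs this check itself when a missing value is accessed; the sketch below only mimics the idea on a plain nested dictionary using the standard library. The `missing_paths` helper is hypothetical, and the sample dictionary is hand-written and only loosely modelled on the generated file.

```python
from typing import Any, Iterator

MISSING = "???"  # marker for a mandatory value that has not been set


def missing_paths(node: Any, prefix: str = "") -> Iterator[str]:
    """Yield the dotted path of every value that is still the ??? marker."""
    if isinstance(node, dict):
        for key, value in node.items():
            yield from missing_paths(value, f"{prefix}.{key}" if prefix else str(key))
    elif isinstance(node, list):
        for index, value in enumerate(node):
            yield from missing_paths(value, f"{prefix}.{index}" if prefix else str(index))
    elif node == MISSING:
        yield prefix


# Hand-written sample; the nesting is an assumption made for illustration.
config = {
    "networks": {
        "net": {"args": {"_target_": "torch.nn.Linear", "in_features": MISSING, "out_features": 1}},
    },
    "optimizers": {
        "optimizer": {"args": {"_target_": "torch.optim.Adam", "params": [MISSING]}},
    },
    "n_classes": MISSING,
}

unresolved = list(missing_paths(config))
print(unresolved)
```

The dotted paths have the same shape as Hydra command-line override keys, so each reported entry indicates which key still needs a value.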

## 3. Training routines definition
The rest of the model definition is just your training routine, written in the usual PyTorch Lightning style.

```python
class ToyModel(BaseModel):

@@ -116,7 +138,8 @@ class ToyModel(BaseModel):
```

## 4. Model Instantiation
The defined model can be instantiated from configuration file. Try to train and test models!
The defined model can be instantiated from a configuration file. Let's train your models!

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf
@@ -137,5 +160,19 @@ def app():
trainer.fit(model, ...)
```

## 5. Model loading
You can easily load trained models by using the checkpoints generated by PyTorch Lightning's standard features. Let's test your models!

```python
from hiraishin.utils import load_from_checkpoint

model = load_from_checkpoint('path/to/model.ckpt')
print(model)
# ToyModel(
# (net): Linear(in_features=1, out_features=1, bias=True)
# (criterion): CrossEntropyLoss()
# )
```

# License
Hiraishin is licensed under the Apache License, Version 2.0. See [LICENSE](LICENSE) for the full license text.