# imagenet-f8_cond.yaml
# Forked from Stability-AI/generative-models.
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    # NOTE(review): latent scaling must match the autoencoder loaded via
    # first_stage_config.params.ckpt_path — confirm when changing checkpoints.
    scale_factor: 0.13025
    disable_first_stage_autocast: true
    log_keys:
      - cls

    # Presumably the LR multiplier ramps from f_start to f_max over
    # warm_up_steps and then moves toward f_min over cycle_lengths; with
    # f_max == f_min == 1.0 it stays flat, and the very large cycle length
    # avoids a restart. TODO confirm against sgm.lr_scheduler.
    scheduler_config:
      target: sgm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [10000]
        cycle_lengths: [10000000000000]
        f_start: [1.0e-6]
        f_max: [1.0]
        f_min: [1.0]

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000  # same value as the sigma sampler's num_idx below
        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: true
        use_fp16: true
        # Latent channel count; same value as first_stage_config embed_dim (4).
        in_channels: 4
        out_channels: 4
        model_channels: 256
        attention_resolutions: [1, 2, 4]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        num_classes: sequential
        # Width of the concatenated "vector cond" embeddings below: two
        # ConcatTimestepEmbedderND entries, each outdim 256 "multiplied by
        # two" -> 2 x 512 = 1024.
        adm_in_channels: 1024
        use_spatial_transformer: true
        transformer_depth: 1
        # Same value as the ClassEmbedder embed_dim (fed through crossattn).
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: true
            input_key: cls
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
              add_sequence_dim: true  # will be used through crossattn then
              embed_dim: 1024
              n_classes: 1000
          # vector cond
          - is_trainable: false
            ucg_rate: 0.2
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              # multiplied by two (presumably one embedding per tuple element)
              outdim: 256
          # vector cond
          - is_trainable: false
            input_key: crop_coords_top_left
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              # multiplied by two (presumably one embedding per tuple element)
              outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        ckpt_path: CKPT_PATH  # USER: set this to your autoencoder checkpoint
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 1000
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 5.0
data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          # USER: adapt this path to the root of your custom dataset
          - "DATA_PATH"
        pipeline_config:
          shardshuffle: 10000
          # USER: you may want to adapt this depending on your available RAM
          sample_shuffle: 10000
        decoders:
          - "pil"
        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              # USER: you may want to adapt this for your custom dataset
              key: 'jpg'
              transforms:
                - target: torchvision.transforms.Resize
                  params:
                    # An int size resizes the shorter image side to 256.
                    size: 256
                    # Pillow resampling constant; 3 = BICUBIC.
                    interpolation: 3
                - target: torchvision.transforms.ToTensor
          - target: sdata.mappers.Rescaler
          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
            params:
              # USER: you may want to adapt these keys for your custom dataset
              h_key: height
              w_key: width
      loader:
        batch_size: 64
        num_workers: 6
lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        enable_autocast: false
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: true
        log_first_step: false
        log_images_kwargs:
          use_ema_scope: false
          # Quoted: YAML 1.1 lists bare N/n as a boolean form on some parsers.
          "N": 8
          n_rows: 2

  trainer:
    # The plain scalar `0,` was already the string "0,"; quoting makes that
    # explicit. NOTE(review): Lightning parses a comma-containing devices
    # string as a GPU index list, so this presumably selects GPU 0 — confirm.
    devices: "0,"
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 1000