SDXL improvements (and support for Draft+) #9654

Merged · 5 commits · Jul 11, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,6 +3,7 @@
*.pkl
#*.ipynb
output
output_2048
result
*.pt
tests/data/asr
@@ -179,3 +180,4 @@ examples/neural_graphs/*.yml
.hydra/
nemo_experiments/

slurm*.out
@@ -17,7 +17,6 @@ trainer:
enable_model_summary: True
limit_val_batches: 0


exp_manager:
exp_dir: null
name: ${name}
@@ -58,8 +58,6 @@ model:
lossconfig:
target: torch.nn.Identity



conditioner_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner
emb_models:
@@ -125,7 +125,6 @@ model:
target: torch.nn.Identity



conditioner_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner
emb_models:
@@ -31,9 +31,9 @@ infer:
sampling:
base:
sampler: EulerEDMSampler
width: 256
height: 256
steps: 40
width: 512
height: 512
steps: 50
discretization: "LegacyDDPMDiscretization"
guider: "VanillaCFG"
thresholder: "None"
@@ -48,8 +48,8 @@ sampling:
s_noise: 1.0
eta: 1.0
order: 4
orig_width: 1024
orig_height: 1024
orig_width: 512
orig_height: 512
crop_coords_top: 0
crop_coords_left: 0
aesthetic_score: 5.0
@@ -0,0 +1,189 @@
trainer:
devices: 1
num_nodes: 1
accelerator: gpu
precision: 32
logger: False # logger provided by exp_manager
enable_checkpointing: False
use_distributed_sampler: False
max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
log_every_n_steps: 10
accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models
gradient_clip_val: 1.0
benchmark: False
enable_model_summary: True
limit_val_batches: 0


infer:
num_samples_per_batch: 1
num_samples: 4
prompt:
- "A professional photograph of an astronaut riding a pig"
- 'A photo of a Shiba Inu dog with a backpack riding a bike. It is wearing sunglasses and a beach hat.'
- 'A cute corgi lives in a house made out of sushi.'
- 'A high contrast portrait of a very happy fuzzy panda dressed as a chef in a high end kitchen making dough. There is a painting of flowers on the wall behind him.'
- 'A brain riding a rocketship heading towards the moon.'
negative_prompt: ""
seed: 123


sampling:
base:
sampler: EulerEDMSampler
width: 512
height: 512
steps: 50
discretization: "LegacyDDPMDiscretization"
guider: "VanillaCFG"
thresholder: "None"
scale: 5.0
img2img_strength: 1.0
sigma_min: 0.0292
sigma_max: 14.6146
rho: 3.0
s_churn: 0.0
s_tmin: 0.0
s_tmax: 999.0
s_noise: 1.0
eta: 1.0
order: 4
orig_width: 512
orig_height: 512
crop_coords_top: 0
crop_coords_left: 0
aesthetic_score: 5.0
negative_aesthetic_score: 5.0

# model:
# is_legacy: False

use_refiner: False
use_fp16: False # use fp16 model weights
out_path: ./output

base_model_config: /opt/NeMo/examples/multimodal/generative/stable_diffusion/conf/sd_xl_base.yaml
refiner_config: /opt/NeMo/examples/multimodal/generative/stable_diffusion/conf/sd_xl_refiner.yaml

model:
scale_factor: 0.13025
disable_first_stage_autocast: True
is_legacy: False
restore_from_path: ""

fsdp: False
fsdp_set_buffer_dtype: null
fsdp_sharding_strategy: 'full'
use_cpu_initialization: True
# hidden_size: 4
# pipeline_model_parallel_size: 4

optim:
name: fused_adam
lr: 1e-4
weight_decay: 0.0
betas:
- 0.9
- 0.999
sched:
name: WarmupHoldPolicy
warmup_steps: 10
hold_steps: 10000000000000 # Incredibly large value to hold the lr as constant

denoiser_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser.DiscreteDenoiser
num_idx: 1000

weighting_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.discretizer.LegacyDDPMDiscretization

unet_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel
from_pretrained: /opt/nemo-aligner/checkpoints/sdxl/unet_nemo.ckpt
from_NeMo: True
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: False
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4 ]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [ 1, 2, 10 ] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 2048
image_size: 64 # unused
# spatial_transformer_attn_type: softmax #note: only default softmax is supported now
legacy: False
use_flash_attention: False

first_stage_config:
# _target_: nemo.collections.multimodal.models.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper
_target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper
from_pretrained: /opt/nemo-aligner/checkpoints/sdxl/vae_nemo.ckpt
from_NeMo: True
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity

conditioner_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
emb_model:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
emb_model:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenOpenCLIPEmbedder2
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
emb_model:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
emb_model:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
emb_model:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND
outdim: 256 # multiplied by two
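
For reference, the `adm_in_channels: 2816` in `unet_config` has to match what this conditioner produces: the pooled text embedding plus the three vector conditionings, each a 2-tuple embedded with `outdim: 256` per coordinate. A minimal sketch of that arithmetic, assuming the standard SDXL layout where the pooled ViT-bigG-14 OpenCLIP output is 1280-dimensional (that dimension is an assumption, it is not spelled out in this config):

```python
# Sketch only: how adm_in_channels = 2816 is assumed to break down for SDXL.
pooled_text_dim = 1280   # assumption: pooled output size of the ViT-bigG-14 OpenCLIP encoder
vector_conds = 3         # original_size_as_tuple, crop_coords_top_left, target_size_as_tuple
outdim = 256             # ConcatTimestepEmbedderND outdim; each cond is a 2-tuple, hence "multiplied by two"

adm_in_channels = pooled_text_dim + vector_conds * outdim * 2
assert adm_in_channels == 2816  # matches unet_config.adm_in_channels above
```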

@@ -74,7 +74,11 @@ def main(cfg) -> None:
n, c, h = cfg.model.micro_batch_size, cfg.model.channels, cfg.model.image_size
x = torch.randn((n, c, h, h), dtype=torch.float32, device="cuda")
t = torch.randint(77, (n,), device="cuda")
cc = torch.randn((n, 77, cfg.model.unet_config.context_dim), dtype=torch.float32, device="cuda",)
cc = torch.randn(
(n, 77, cfg.model.unet_config.context_dim),
dtype=torch.float32,
device="cuda",
)
if cfg.model.precision in [16, '16']:
x = x.type(torch.float16)
cc = cc.type(torch.float16)
@@ -93,9 +97,7 @@ def main(cfg) -> None:
model.zero_grad()

if cfg.model.get('peft', None):

peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme]

if cfg.model.peft.restore_from_path is not None:
# initialize peft weights from a checkpoint instead of randomly
# This is not the same as resume training because optimizer states are not restored.
44 changes: 28 additions & 16 deletions examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py
@@ -26,32 +26,44 @@ def model_cfg_modifier(model_cfg):
model_cfg.precision = cfg.trainer.precision
model_cfg.ckpt_path = None
model_cfg.inductor = False
model_cfg.unet_config.from_pretrained = None
model_cfg.first_stage_config.from_pretrained = None
model_cfg.unet_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/unet_nemo.ckpt"
model_cfg.unet_config.from_NeMo = True
model_cfg.first_stage_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/vae_nemo.ckpt"
model_cfg.first_stage_config.from_NeMo = True
model_cfg.first_stage_config._target_ = 'nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper'
model_cfg.fsdp = False
# model_cfg.fsdp = True

torch.backends.cuda.matmul.allow_tf32 = True
trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference(
model_provider=MegatronDiffusionEngine, cfg=cfg, model_cfg_modifier=model_cfg_modifier
)

### Manually configure sharded model
# model = megatron_diffusion_model
# model = trainer.strategy._setup_model(model)
# model = model.cuda(torch.cuda.current_device())
# get the diffusion part only
model = megatron_diffusion_model.model
model.cuda().eval()

base = SamplingPipeline(model, use_fp16=cfg.use_fp16, is_legacy=cfg.model.is_legacy)
use_refiner = cfg.get('use_refiner', False)
for i, prompt in enumerate(cfg.infer.prompt):
samples = base.text_to_image(
params=cfg.sampling.base,
prompt=[prompt],
negative_prompt=cfg.infer.negative_prompt,
samples=cfg.infer.num_samples,
return_latents=True if use_refiner else False,
seed=int(cfg.infer.seed + i * 100),
)

perform_save_locally(cfg.out_path, samples)
with torch.no_grad():
base = SamplingPipeline(model, use_fp16=cfg.use_fp16, is_legacy=cfg.model.is_legacy)
use_refiner = cfg.get('use_refiner', False)
num_samples_per_batch = cfg.infer.get('num_samples_per_batch', cfg.infer.num_samples)
num_batches = cfg.infer.num_samples // num_samples_per_batch

for i, prompt in enumerate(cfg.infer.prompt):
for batchid in range(num_batches):
samples = base.text_to_image(
params=cfg.sampling.base,
prompt=[prompt],
negative_prompt=cfg.infer.negative_prompt,
samples=num_samples_per_batch,
return_latents=True if use_refiner else False,
seed=int(cfg.infer.seed + i * 100 + batchid * 200),
)
# samples=cfg.infer.num_samples,
perform_save_locally(cfg.out_path, samples)


if __name__ == "__main__":
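
The reworked loop above splits `cfg.infer.num_samples` into batches of `num_samples_per_batch` and offsets the seed per prompt and per batch. A small standalone sketch of just that schedule (the helper name is hypothetical, not part of this PR), using the defaults from the inference config above:

```python
def plan_batches(num_samples, num_samples_per_batch, base_seed, num_prompts):
    """Mirror the batch/seed schedule in sd_xl_infer.py (integer division; remainder samples are dropped)."""
    num_batches = num_samples // num_samples_per_batch
    return [
        (i, batchid, base_seed + i * 100 + batchid * 200)
        for i in range(num_prompts)
        for batchid in range(num_batches)
    ]

# With num_samples=4, num_samples_per_batch=1, seed=123 (the config defaults),
# prompt 0 is sampled with seeds 123, 323, 523, 723 and prompt 1 with 223, 423, 623, 823.
print(plan_batches(4, 1, 123, 2))
```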
@@ -41,7 +41,10 @@ def _training_strategy(self) -> NLPDDPStrategy:
_IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)
if _IS_INTERACTIVE and self.cfg.trainer.devices == 1:
logging.info("Detected interactive environment, using NLPDDPStrategyNotebook")
return NLPDDPStrategyNotebook(no_ddp_communication_hook=True, find_unused_parameters=False,)
return NLPDDPStrategyNotebook(
no_ddp_communication_hook=True,
find_unused_parameters=False,
)

if self.cfg.model.get('fsdp', False):
assert (
@@ -81,9 +84,7 @@ def main(cfg) -> None:
model = MegatronDiffusionEngine(cfg.model, trainer)

if cfg.model.get('peft', None):

peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme]

if cfg.model.peft.restore_from_path is not None:
# initialize peft weights from a checkpoint instead of randomly
# This is not the same as resume training because optimizer states are not restored.