Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fp8 support for SD/Update notebook paths #8489

Merged
merged 3 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ model:
use_flash_attention: True
enable_amp_o2_fp16: False
resblock_gn_groups: 32
use_te_fp8: False

first_stage_config:
_target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL
Expand Down Expand Up @@ -191,7 +192,7 @@ model:
synthetic_data_length: 10000
train:
dataset_path:
- /datasets/coyo/test.pkl
- /datasets/coyo/wdinfo/coyo-700m/wdinfo-selene.pkl
augmentations:
resize_smallest_side: 512
center_crop_h_w: 512, 512
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def model_cfg_modifier(model_cfg):
model_cfg.unet_config.use_flash_attention = False
model_cfg.unet_config.from_pretrained = None
model_cfg.first_stage_config.from_pretrained = None
model_cfg.first_stage_config._target_ = (
'nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL'
)

torch.backends.cuda.matmul.allow_tf32 = True
trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference(
Expand Down
47 changes: 35 additions & 12 deletions nemo/collections/multimodal/modules/stable_diffusion/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from inspect import isfunction

import torch
Expand All @@ -21,6 +22,13 @@
from torch import einsum, nn
from torch._dynamo import disable

if os.environ.get("USE_NATIVE_GROUP_NORM", "0") == "1":
from nemo.gn_native import GroupNormNormlization as GroupNorm
else:
from apex.contrib.group_norm import GroupNorm

from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP

from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import checkpoint
from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import (
AdapterName,
Expand Down Expand Up @@ -95,13 +103,19 @@ def forward(self, x):


class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, use_te=False):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)

self.net = nn.Sequential(project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out))
if use_te:
activation = 'gelu' if not glu else 'geglu'
# TODO: more parameters to be confirmed, dropout, seq_length
self.net = LayerNormMLP(hidden_size=dim, ffn_hidden_size=inner_dim, activation=activation,)
else:
norm = nn.LayerNorm(dim)
project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(norm, project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out))

def forward(self, x):
return self.net(x)
Expand Down Expand Up @@ -224,6 +238,7 @@ def __init__(
dropout=0.0,
use_flash_attention=False,
lora_network_alpha=None,
use_te=False,
):
super().__init__()

Expand All @@ -237,10 +252,16 @@ def __init__(
self.scale = dim_head ** -0.5
self.heads = heads

self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
self.to_k = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
self.to_v = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)

if use_te:
self.norm_to_q = LayerNormLinear(query_dim, self.inner_dim, bias=False)
else:
norm = nn.LayerNorm(query_dim)
to_q = LinearWrapper(query_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
self.norm_to_q = nn.Sequential(norm, to_q)

self.to_out = nn.Sequential(
LinearWrapper(self.inner_dim, query_dim, lora_network_alpha=lora_network_alpha), nn.Dropout(dropout)
)
Expand All @@ -255,7 +276,7 @@ def __init__(
def forward(self, x, context=None, mask=None):
h = self.heads

q = self.to_q(x)
q = self.norm_to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
Expand Down Expand Up @@ -335,6 +356,7 @@ def __init__(
use_flash_attention=False,
disable_self_attn=False,
lora_network_alpha=None,
use_te=False,
):
super().__init__()
self.disable_self_attn = disable_self_attn
Expand All @@ -346,8 +368,9 @@ def __init__(
use_flash_attention=use_flash_attention,
context_dim=context_dim if self.disable_self_attn else None,
lora_network_alpha=lora_network_alpha,
use_te=use_te,
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, use_te=use_te)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
Expand All @@ -356,10 +379,8 @@ def __init__(
dropout=dropout,
use_flash_attention=use_flash_attention,
lora_network_alpha=lora_network_alpha,
use_te=use_te,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.use_checkpoint = use_checkpoint

def forward(self, x, context=None):
Expand All @@ -369,9 +390,9 @@ def forward(self, x, context=None):
return self._forward(x, context)

def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
x = self.attn1(x, context=context if self.disable_self_attn else None) + x
x = self.attn2(x, context=context) + x
x = self.ff(x) + x
return x


Expand All @@ -397,6 +418,7 @@ def __init__(
use_checkpoint=False,
use_flash_attention=False,
lora_network_alpha=None,
use_te=False,
):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
Expand All @@ -422,6 +444,7 @@ def __init__(
use_flash_attention=use_flash_attention,
disable_self_attn=disable_self_attn,
lora_network_alpha=lora_network_alpha,
use_te=use_te,
)
for d in range(depth)
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from abc import abstractmethod
from contextlib import nullcontext

import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F

# FP8 related import
import transformer_engine

from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer
from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import (
avg_pool_nd,
Expand All @@ -45,6 +50,39 @@ def convert_module_to_fp16(module):
convert_module_to_dtype(module, torch.float16)


def convert_module_to_fp32(module):
    """Hand ``module`` to ``convert_module_to_dtype`` with ``torch.float32`` as the target dtype."""
    target_dtype = torch.float32
    convert_module_to_dtype(module, target_dtype)


def convert_module_to_fp8(model):
    """Swap every ``torch.nn.Linear`` submodule of ``model`` for a Transformer Engine ``Linear``.

    Each replacement is constructed with the same ``in_features``,
    ``out_features`` and bias setting as the layer it replaces, and is given
    deep copies of that layer's ``weight`` (and ``bias``, when present).
    ``model`` is modified in place and nothing is returned.

    Args:
        model: module whose ``named_modules()`` tree is searched for
            ``torch.nn.Linear`` instances.
    """
    import copy

    from transformer_engine.pytorch.module import Linear as te_Linear

    def _set_module(root, submodule_key, module):
        """Install ``module`` at the dotted path ``submodule_key`` below ``root``."""
        *parent_path, leaf_name = submodule_key.split('.')
        parent = root
        for attr in parent_path:
            parent = getattr(parent, attr)
        setattr(parent, leaf_name, module)

    # Snapshot the targets before touching anything so the module tree is not
    # rewritten while the named_modules() generator is still walking it.
    targets = [(name, mod) for name, mod in model.named_modules() if isinstance(mod, torch.nn.Linear)]

    for name, linear in targets:
        if not name:
            # named_modules() reports the root under the empty name. The root
            # cannot be swapped from in here, and setattr(model, '', ...)
            # would only attach a stray child, so leave it alone.
            logging.warning('[WARNING] Root module is a Linear; it is left unchanged')
            continue

        logging.info(f'[INFO] Replace Linear: {name}, weight: {linear.weight.shape}')
        has_bias = linear.bias is not None
        new_linear = te_Linear(linear.in_features, linear.out_features, bias=has_bias)
        new_linear.weight = copy.deepcopy(linear.weight)
        if has_bias:
            new_linear.bias = copy.deepcopy(linear.bias)
        _set_module(model, name, new_linear)


## go
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
Expand Down Expand Up @@ -471,6 +509,7 @@ def __init__(
use_flash_attention: bool = False,
enable_amp_o2_fp16: bool = False,
lora_network_alpha=None,
use_te_fp8: bool = False,
):
super().__init__()
if use_spatial_transformer:
Expand Down Expand Up @@ -526,6 +565,7 @@ def __init__(
input_block_chans = [model_channels]
ch = model_channels
ds = 1
self.use_te_fp8 = use_te_fp8
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
Expand Down Expand Up @@ -568,6 +608,7 @@ def __init__(
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
use_te=self.use_te_fp8,
lora_network_alpha=lora_network_alpha,
)
)
Expand Down Expand Up @@ -633,6 +674,7 @@ def __init__(
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
use_te=self.use_te_fp8,
lora_network_alpha=lora_network_alpha,
),
ResBlock(
Expand Down Expand Up @@ -660,6 +702,7 @@ def __init__(
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
resblock_gn_groups=resblock_gn_groups,
)
]
ch = model_channels * mult
Expand Down Expand Up @@ -690,6 +733,7 @@ def __init__(
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
use_te=self.use_te_fp8,
lora_network_alpha=lora_network_alpha,
)
)
Expand Down Expand Up @@ -746,6 +790,32 @@ def __init__(
if enable_amp_o2_fp16:
self.convert_to_fp16()

elif self.use_te_fp8:
assert enable_amp_o2_fp16 is False, "fp8 training can't work with fp16 O2 amp recipe"
convert_module_to_fp8(self)

fp8_margin = int(os.getenv("FP8_MARGIN", '0'))
fp8_interval = int(os.getenv("FP8_INTERVAL", '1'))
fp8_format = os.getenv("FP8_FORMAT", "hybrid")
fp8_amax_history_len = int(os.getenv("FP8_HISTORY_LEN", '1024'))
fp8_amax_compute_algo = os.getenv("FP8_COMPUTE_ALGO", 'max')
fp8_wgrad = os.getenv("FP8_WGRAD", '1') == '1'

fp8_format_dict = {
'hybrid': transformer_engine.common.recipe.Format.HYBRID,
'e4m3': transformer_engine.common.recipe.Format.E4M3,
}
fp8_format = fp8_format_dict[fp8_format]

self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
margin=fp8_margin,
interval=fp8_interval,
fp8_format=fp8_format,
amax_history_len=fp8_amax_history_len,
amax_compute_algo=fp8_amax_compute_algo,
override_linear_precision=(False, False, not fp8_wgrad),
)

def _input_blocks_mapping(self, input_dict):
res_dict = {}
for key_, value_ in input_dict.items():
Expand Down Expand Up @@ -960,7 +1030,7 @@ def convert_to_fp16(self):
"""
self.apply(convert_module_to_fp16)

def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
def _forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.

Expand Down Expand Up @@ -999,6 +1069,13 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
else:
return self.out(h)

def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
    """Run ``_forward`` on the given inputs, under FP8 autocast when enabled.

    When ``self.use_te_fp8`` is truthy the call is wrapped in
    ``transformer_engine.pytorch.fp8_autocast`` configured with
    ``self.fp8_recipe``; otherwise it runs inside a ``nullcontext``.
    All arguments are forwarded to ``_forward`` and its result is returned.
    """
    if self.use_te_fp8:
        autocast_ctx = transformer_engine.pytorch.fp8_autocast(
            enabled=self.use_te_fp8, fp8_recipe=self.fp8_recipe,
        )
    else:
        autocast_ctx = nullcontext()

    with autocast_ctx:
        return self._forward(x, timesteps, context, y, **kwargs)


class EncoderUNetModel(nn.Module):
"""
Expand Down
25 changes: 25 additions & 0 deletions nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@

try:
from apex.transformer.pipeline_parallel.utils import get_num_microbatches

from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam

HAVE_APEX = True
Expand Down Expand Up @@ -1002,6 +1003,30 @@ def should_process(key):
new_state_dict[key_] = state_dict[key_]
state_dict = new_state_dict

if conf.get('unet_config') and conf.get('unet_config').get('use_te_fp8') == False:
# Drop any fp8 `extra_state` entries that are present in the state dict.
new_state_dict = {}
for key in state_dict.keys():
if 'extra_state' in key:
continue

### LayerNormLinear
# norm_to_q.layer_norm_{weight|bias} -> norm_to_q.0.{weight|bias}
# norm_to_q.weight -> norm_to_q.1.weight
new_key = key.replace('norm_to_q.layer_norm_', 'norm_to_q.0.')
new_key = new_key.replace('norm_to_q.weight', 'norm_to_q.1.weight')

### LayerNormMLP
# ff.net.layer_norm_{weight|bias} -> ff.net.0.{weight|bias}
# ff.net.fc1_{weight|bias} -> ff.net.1.proj.{weight|bias}
# ff.net.fc2_{weight|bias} -> ff.net.3.{weight|bias}
new_key = new_key.replace('ff.net.layer_norm_', 'ff.net.0.')
new_key = new_key.replace('ff.net.fc1_', 'ff.net.1.proj.')
new_key = new_key.replace('ff.net.fc2_', 'ff.net.3.')

new_state_dict[new_key] = state_dict[key]
state_dict = new_state_dict

return state_dict

def _load_state_dict_from_disk(self, model_weights, map_location=None):
Expand Down
Loading
Loading