Revert FP8 integration (#8520)
* Revert FP8 integration

Signed-off-by: Mingyuan Ma <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Mingyuan Ma <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Victor49152 and pre-commit-ci[bot] authored Feb 27, 2024
1 parent e6b7354 commit ae9a2aa
Showing 4 changed files with 14 additions and 139 deletions.
@@ -121,7 +121,6 @@ model:
use_flash_attention: True
enable_amp_o2_fp16: False
resblock_gn_groups: 32
use_te_fp8: False

first_stage_config:
_target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL
@@ -192,7 +191,7 @@ model:
synthetic_data_length: 10000
train:
dataset_path:
- /datasets/coyo/wdinfo/coyo-700m/wdinfo-selene.pkl
- /datasets/coyo/test.pkl
augmentations:
resize_smallest_side: 512
center_crop_h_w: 512, 512
47 changes: 12 additions & 35 deletions nemo/collections/multimodal/modules/stable_diffusion/attention.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from inspect import isfunction

import torch
@@ -22,13 +21,6 @@
from torch import einsum, nn
from torch._dynamo import disable

if os.environ.get("USE_NATIVE_GROUP_NORM", "0") == "1":
from nemo.gn_native import GroupNormNormlization as GroupNorm
else:
from apex.contrib.group_norm import GroupNorm

from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP

from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import checkpoint
from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import (
AdapterName,
@@ -103,19 +95,13 @@ def forward(self, x):


class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, use_te=False):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)

if use_te:
activation = 'gelu' if not glu else 'geglu'
# TODO: more parameters to be confirmed, dropout, seq_length
self.net = LayerNormMLP(hidden_size=dim, ffn_hidden_size=inner_dim, activation=activation,)
else:
norm = nn.LayerNorm(dim)
project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(norm, project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out))
self.net = nn.Sequential(project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out))

def forward(self, x):
return self.net(x)
@@ -238,7 +224,6 @@ def __init__(
dropout=0.0,
use_flash_attention=False,
lora_network_alpha=None,
use_te=False,
):
super().__init__()

@@ -252,16 +237,10 @@ def __init__(
self.scale = dim_head ** -0.5
self.heads = heads

self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
self.to_k = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
self.to_v = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)

if use_te:
self.norm_to_q = LayerNormLinear(query_dim, self.inner_dim, bias=False)
else:
norm = nn.LayerNorm(query_dim)
to_q = LinearWrapper(query_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
self.norm_to_q = nn.Sequential(norm, to_q)

self.to_out = nn.Sequential(
LinearWrapper(self.inner_dim, query_dim, lora_network_alpha=lora_network_alpha), nn.Dropout(dropout)
)
@@ -276,7 +255,7 @@ def __init__(
def forward(self, x, context=None, mask=None):
h = self.heads

q = self.norm_to_q(x)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
@@ -356,7 +335,6 @@ def __init__(
use_flash_attention=False,
disable_self_attn=False,
lora_network_alpha=None,
use_te=False,
):
super().__init__()
self.disable_self_attn = disable_self_attn
@@ -368,9 +346,8 @@ def __init__(
use_flash_attention=use_flash_attention,
context_dim=context_dim if self.disable_self_attn else None,
lora_network_alpha=lora_network_alpha,
use_te=use_te,
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, use_te=use_te)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
@@ -379,8 +356,10 @@ def __init__(
dropout=dropout,
use_flash_attention=use_flash_attention,
lora_network_alpha=lora_network_alpha,
use_te=use_te,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.use_checkpoint = use_checkpoint

def forward(self, x, context=None):
@@ -390,9 +369,9 @@ def forward(self, x, context=None):
return self._forward(x, context)

def _forward(self, x, context=None):
x = self.attn1(x, context=context if self.disable_self_attn else None) + x
x = self.attn2(x, context=context) + x
x = self.ff(x) + x
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x


@@ -418,7 +397,6 @@ def __init__(
use_checkpoint=False,
use_flash_attention=False,
lora_network_alpha=None,
use_te=False,
):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
@@ -444,7 +422,6 @@ def __init__(
use_flash_attention=use_flash_attention,
disable_self_attn=disable_self_attn,
lora_network_alpha=lora_network_alpha,
use_te=use_te,
)
for d in range(depth)
]
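For reference, the attention.py changes above drop the Transformer Engine branches (LayerNormLinear in CrossAttention, LayerNormMLP in FeedForward) and return to explicit nn.LayerNorm layers applied inside BasicTransformerBlock._forward. A minimal sketch of that restored pre-norm wiring, with NeMo's LinearWrapper, LoRA, and flash-attention plumbing simplified to plain nn.Linear and PyTorch 2.x scaled_dot_product_attention (only the module layout and residual pattern follow the diff; everything else is illustrative):

import torch.nn.functional as F
from torch import nn


class FeedForward(nn.Module):
    # Plain GELU MLP; after the revert, LayerNorm lives outside this module again.
    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out or dim),
        )

    def forward(self, x):
        return self.net(x)


class CrossAttention(nn.Module):
    # Separate, un-fused to_q/to_k/to_v projections, as in the restored code path.
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        context_dim = context_dim or query_dim
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None):
        context = x if context is None else context
        b, n, _ = x.shape

        def split(t):
            # (b, n, heads * dim_head) -> (b, heads, n, dim_head)
            return t.view(b, -1, self.heads, t.shape[-1] // self.heads).transpose(1, 2)

        q, k, v = split(self.to_q(x)), split(self.to_k(context)), split(self.to_v(context))
        out = F.scaled_dot_product_attention(q, k, v)  # softmax(q k^T / sqrt(d)) v
        return self.to_out(out.transpose(1, 2).reshape(b, n, -1))


class BasicTransformerBlock(nn.Module):
    # Pre-norm residual wiring restored by the revert: x = sublayer(norm(x)) + x.
    def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        self.attn1 = CrossAttention(dim, None, heads, dim_head, dropout)         # self-attention
        self.attn2 = CrossAttention(dim, context_dim, heads, dim_head, dropout)  # cross-attention
        self.ff = FeedForward(dim, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
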
@@ -12,19 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from abc import abstractmethod
from contextlib import nullcontext

import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F

# FP8 related import
import transformer_engine

from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer
from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import (
avg_pool_nd,
@@ -50,39 +45,6 @@ def convert_module_to_fp16(module):
convert_module_to_dtype(module, torch.float16)


def convert_module_to_fp32(module):
convert_module_to_dtype(module, torch.float32)


def convert_module_to_fp8(model):
def _set_module(model, submodule_key, module):
tokens = submodule_key.split('.')
sub_tokens = tokens[:-1]
cur_mod = model
for s in sub_tokens:
cur_mod = getattr(cur_mod, s)
setattr(cur_mod, tokens[-1], module)

import copy

from transformer_engine.pytorch.module import Linear as te_Linear

for n, v in model.named_modules():
if isinstance(v, torch.nn.Linear):
# if n in ['class_embed', 'bbox_embed.layers.0', 'bbox_embed.layers.1', 'bbox_embed.layers.2']: continue
logging.info(f'[INFO] Replace Linear: {n}, weight: {v.weight.shape}')
if v.bias is None:
is_bias = False
else:
is_bias = True
newlinear = te_Linear(v.in_features, v.out_features, bias=is_bias)
newlinear.weight = copy.deepcopy(v.weight)
if v.bias is not None:
newlinear.bias = copy.deepcopy(v.bias)
_set_module(model, n, newlinear)


## go
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
@@ -509,7 +471,6 @@ def __init__(
use_flash_attention: bool = False,
enable_amp_o2_fp16: bool = False,
lora_network_alpha=None,
use_te_fp8: bool = False,
):
super().__init__()
if use_spatial_transformer:
@@ -565,7 +526,6 @@ def __init__(
input_block_chans = [model_channels]
ch = model_channels
ds = 1
self.use_te_fp8 = use_te_fp8
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
@@ -608,7 +568,6 @@ def __init__(
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
use_te=self.use_te_fp8,
lora_network_alpha=lora_network_alpha,
)
)
@@ -674,7 +633,6 @@ def __init__(
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
use_te=self.use_te_fp8,
lora_network_alpha=lora_network_alpha,
),
ResBlock(
@@ -702,7 +660,6 @@ def __init__(
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
resblock_gn_groups=resblock_gn_groups,
)
]
ch = model_channels * mult
@@ -733,7 +690,6 @@ def __init__(
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
use_te=self.use_te_fp8,
lora_network_alpha=lora_network_alpha,
)
)
@@ -790,32 +746,6 @@ def __init__(
if enable_amp_o2_fp16:
self.convert_to_fp16()

elif self.use_te_fp8:
assert enable_amp_o2_fp16 is False, "fp8 training can't work with fp16 O2 amp recipe"
convert_module_to_fp8(self)

fp8_margin = int(os.getenv("FP8_MARGIN", '0'))
fp8_interval = int(os.getenv("FP8_INTERVAL", '1'))
fp8_format = os.getenv("FP8_FORMAT", "hybrid")
fp8_amax_history_len = int(os.getenv("FP8_HISTORY_LEN", '1024'))
fp8_amax_compute_algo = os.getenv("FP8_COMPUTE_ALGO", 'max')
fp8_wgrad = os.getenv("FP8_WGRAD", '1') == '1'

fp8_format_dict = {
'hybrid': transformer_engine.common.recipe.Format.HYBRID,
'e4m3': transformer_engine.common.recipe.Format.E4M3,
}
fp8_format = fp8_format_dict[fp8_format]

self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
margin=fp8_margin,
interval=fp8_interval,
fp8_format=fp8_format,
amax_history_len=fp8_amax_history_len,
amax_compute_algo=fp8_amax_compute_algo,
override_linear_precision=(False, False, not fp8_wgrad),
)

def _input_blocks_mapping(self, input_dict):
res_dict = {}
for key_, value_ in input_dict.items():
@@ -1030,7 +960,7 @@ def convert_to_fp16(self):
"""
self.apply(convert_module_to_fp16)

def _forward(self, x, timesteps=None, context=None, y=None, **kwargs):
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
@@ -1069,13 +999,6 @@ def _forward(self, x, timesteps=None, context=None, y=None, **kwargs):
else:
return self.out(h)

def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
with transformer_engine.pytorch.fp8_autocast(
enabled=self.use_te_fp8, fp8_recipe=self.fp8_recipe,
) if self.use_te_fp8 else nullcontext():
out = self._forward(x, timesteps, context, y, **kwargs)
return out


class EncoderUNetModel(nn.Module):
"""
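The UNet changes above remove the whole Transformer Engine FP8 path: the convert_module_to_fp8 helper that swapped nn.Linear for TE Linear modules, the environment-variable-driven DelayedScaling recipe, and the fp8_autocast wrapper that the forward/_forward split existed for. As a record of what is being deleted, here is a condensed sketch of that delayed-scaling pattern; it assumes transformer_engine is installed with the same recipe API the removed code targeted, and the env-var names and defaults are copied from the deleted hunk:

import copy
import os

import torch
import transformer_engine
from transformer_engine.pytorch.module import Linear as te_Linear


def swap_linears_for_te(model: torch.nn.Module) -> None:
    # Replace every nn.Linear with a Transformer Engine Linear so its GEMMs can
    # run in FP8 inside an fp8_autocast region (mirrors the removed helper).
    for name, mod in list(model.named_modules()):
        if isinstance(mod, torch.nn.Linear):
            new = te_Linear(mod.in_features, mod.out_features, bias=mod.bias is not None)
            new.weight = copy.deepcopy(mod.weight)
            if mod.bias is not None:
                new.bias = copy.deepcopy(mod.bias)
            parent_name, _, attr = name.rpartition('.')
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, attr, new)


# Recipe knobs; env-var names and defaults are the ones the removed code read.
fp8_format = {
    'hybrid': transformer_engine.common.recipe.Format.HYBRID,  # E4M3 fwd, E5M2 bwd
    'e4m3': transformer_engine.common.recipe.Format.E4M3,
}[os.getenv("FP8_FORMAT", "hybrid")]

fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
    margin=int(os.getenv("FP8_MARGIN", '0')),
    interval=int(os.getenv("FP8_INTERVAL", '1')),
    fp8_format=fp8_format,
    amax_history_len=int(os.getenv("FP8_HISTORY_LEN", '1024')),
    amax_compute_algo=os.getenv("FP8_COMPUTE_ALGO", 'max'),
    # Third flag keeps the wgrad GEMM out of FP8 unless FP8_WGRAD=1.
    override_linear_precision=(False, False, os.getenv("FP8_WGRAD", '1') != '1'),
)

# The removed forward() wrapped the real _forward() in one autocast region so
# all swapped Linears share the delayed-scaling state, roughly:
#
#   with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
#       out = unet._forward(x, timesteps, context, y)
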
24 changes: 0 additions & 24 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -1003,30 +1003,6 @@ def should_process(key):
new_state_dict[key_] = state_dict[key_]
state_dict = new_state_dict

if conf.get('unet_config') and conf.get('unet_config').get('use_te_fp8') == False:
# remove _extra_state in fp8 if there is.
new_state_dict = {}
for key in state_dict.keys():
if 'extra_state' in key:
continue

### LayerNormLinear
# norm_to_q.layer_norm_{weight|bias} -> norm_to_q.0.{weight|bias}
# norm_to_q.weight -> norm_to_q.1.weight
new_key = key.replace('norm_to_q.layer_norm_', 'norm_to_q.0.')
new_key = new_key.replace('norm_to_q.weight', 'norm_to_q.1.weight')

### LayerNormMLP
# ff.net.layer_norm_{weight|bias} -> ff.net.0.{weight|bias}
# ff.net.fc1_{weight|bias} -> ff.net.1.proj.{weight|bias}
# ff.net.fc2_{weight|bias} -> ff.net.3.{weight|bias}
new_key = new_key.replace('ff.net.layer_norm_', 'ff.net.0.')
new_key = new_key.replace('ff.net.fc1_', 'ff.net.1.proj.')
new_key = new_key.replace('ff.net.fc2_', 'ff.net.3.')

new_state_dict[new_key] = state_dict[key]
state_dict = new_state_dict

return state_dict

def _load_state_dict_from_disk(self, model_weights, map_location=None):
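Finally, the nlp_overrides.py hunk drops the checkpoint-key translation that the FP8 integration needed when loading Transformer Engine checkpoints (fused LayerNormLinear / LayerNormMLP parameter names, plus _extra_state scaling buffers) into the non-TE module layout; with both sides of that mapping now reverted, it is dead code. A standalone sketch of the removed logic, with the key substitutions copied verbatim from the deleted hunk, kept here only as a record:

def remap_te_keys(state_dict):
    """Translate TE fused-module parameter names to the plain nn.Sequential
    layout, dropping TE's '_extra_state' FP8 bookkeeping along the way."""
    remapped = {}
    for key, value in state_dict.items():
        if 'extra_state' in key:
            continue  # amax history / scaling factors, meaningless without TE
        new_key = key
        # LayerNormLinear: norm_to_q.layer_norm_{weight,bias} -> norm_to_q.0.*,
        #                  norm_to_q.weight                   -> norm_to_q.1.weight
        new_key = new_key.replace('norm_to_q.layer_norm_', 'norm_to_q.0.')
        new_key = new_key.replace('norm_to_q.weight', 'norm_to_q.1.weight')
        # LayerNormMLP: ff.net.layer_norm_* -> ff.net.0.*,
        #               ff.net.fc1_*        -> ff.net.1.proj.*,
        #               ff.net.fc2_*        -> ff.net.3.*
        new_key = new_key.replace('ff.net.layer_norm_', 'ff.net.0.')
        new_key = new_key.replace('ff.net.fc1_', 'ff.net.1.proj.')
        new_key = new_key.replace('ff.net.fc2_', 'ff.net.3.')
        remapped[new_key] = value
    return remapped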
