Skip to content

Commit

Permalink
Fix for train.controlnet.controlnet_v1_5_1node_100steps (NVIDIA#9678)
Browse files Browse the repository at this point in the history
* Fixed a bug in ControlNet where the model was incorrectly assumed to be saved
directly from the UNet.

Also added a checkpoint converter for Stable Diffusion that handles both SD and SDXL.

Signed-off-by: Rohit Jena <[email protected]>

* delete unnecessary comments and messages

Signed-off-by: Rohit Jena <[email protected]>

* reinstate try-catch for loading missing modules

Signed-off-by: Rohit Jena <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

---------

Signed-off-by: Rohit Jena <[email protected]>
Signed-off-by: Victor49152 <[email protected]>
Co-authored-by: Rohit Jena <[email protected]>
Co-authored-by: Ming <[email protected]>
Co-authored-by: Victor49152 <[email protected]>
Signed-off-by: Vivian Chen <[email protected]>
  • Loading branch information
4 people authored and Vivian Chen committed Aug 1, 2024
1 parent 2678eb7 commit 465d7f2
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -550,11 +550,18 @@ def load_from_unet(self, from_pretrained_unet, from_NeMo=True):
print("Loading unet blocks from sd")

state_dict = torch.load(from_pretrained_unet, map_location='cpu')
state_dict = state_dict['state_dict']
if 'state_dict' in state_dict.keys():
state_dict = state_dict['state_dict']
model_state_dict = self.state_dict()
model_state_keys = model_state_dict.keys()

re_state_dict = {}
for key_, value_ in state_dict.items():
# check if key is a raw parameter
if key_ in model_state_keys:
re_state_dict[key_] = value_
continue
# prune from model prefix
if key_.startswith('model.model.diffusion_model'):
re_state_dict[key_.replace('model.model.diffusion_model.', '')] = value_
if key_.startswith('model.diffusion_model'):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,8 @@ def __init__(
)
logging.info(f"Missing keys: {missing_key}")
logging.info(f"Unexpected keys: {unexpected_keys}")
else:
logging.info(f"There are no missing keys, model loaded properly!")

if unet_precision == "fp16-mixed": # AMP O2
self.convert_to_fp16()
Expand Down Expand Up @@ -1217,6 +1219,7 @@ def _state_key_mapping(self, state_dict: dict):
def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo=False):
state_dict = self._strip_unet_key_prefix(state_dict)
if not from_NeMo:
logging.info("creating state key mapping from HF")
state_dict = self._state_key_mapping(state_dict)
state_dict = self._legacy_unet_ckpt_mapping(state_dict)

Expand Down
66 changes: 33 additions & 33 deletions scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.

r"""
Conversion script to convert HuggingFace Starcoder2 checkpoints into nemo checkpoint.
Conversion script to convert HuggingFace StableDiffusion checkpoints into nemo checkpoint.
Example to run this conversion script:
python convert_hf_starcoder2_to_nemo.py \
--input_name_or_path <path_to_sc2_checkpoints_folder> \
--output_path <path_to_output_nemo_file>
--output_path <path_to_output_nemo_file> --model <unet|vae>
"""

import os
from argparse import ArgumentParser

import numpy as np
Expand All @@ -29,8 +30,6 @@

from nemo.utils import logging

intkey = lambda x: int(x)


def filter_keys(rule, dict):
keys = list(dict.keys())
Expand Down Expand Up @@ -95,7 +94,7 @@ def __getitem__(self, name: str):
return None
# either more than 1 match (error) or exactly 1 (success)
if np.sum(p_flag) > 1:
print(f"error: multiple matches of key {name} with {keys}")
logging.warning(f"warning: multiple matches of key {name} with {keys}")
else:
i = np.where(p_flag)[0][0]
n = numdots(keys[i])
Expand Down Expand Up @@ -130,14 +129,9 @@ def get_args():
return args


def make_tiny_config(config):
'''dial down the config file to make things tractable'''
# TODO
return config


def load_hf_ckpt(in_dir, args):
ckpt = {}
assert os.path.isdir(in_dir), "Currently supports only directories with a safetensor file in it."
with safetensors.safe_open(in_dir + "/diffusion_pytorch_model.safetensors", framework="pt") as f:
for k in f.keys():
ckpt[k] = f.get_tensor(k)
def sanity_check(hf_tree, hf_unet, nemo_unet):
    """Log any keys the HF->NeMo mapping would introduce on either side.

    Args:
        hf_tree: parsed SegTree of the HuggingFace checkpoint keys.
        hf_unet: HuggingFace UNet state dict (key -> tensor).
        nemo_unet: NeMo UNet state dict (key -> tensor).

    Logs every mapped NeMo key missing from ``nemo_unet`` and every HF key
    missing from ``hf_unet`` so mismatches surface before the checkpoint
    is saved. Purely diagnostic; does not raise or return anything.
    """
    # check if i'm introducing new keys
    for hfk, nk in hf_to_nemo_mapping(hf_tree).items():
        if nk not in nemo_unet.keys():
            logging.info(nk)
        if hfk not in hf_unet.keys():
            logging.info(hfk)


def convert_input_keys(hf_tree: SegTree):
Expand All @@ -174,7 +168,7 @@ def convert_input_keys(hf_tree: SegTree):
# start counting blocks from now on
nemo_inp_blk = 1
down_blocks = hf_tree['down_blocks']
down_blocks_keys = sorted(list(down_blocks.nodes.keys()), key=intkey)
down_blocks_keys = sorted(list(down_blocks.nodes.keys()), key=int)
for downblockid in down_blocks_keys:
block = down_blocks[str(downblockid)]
# compute number of resnets, attentions, downsamplers in this block
Expand All @@ -183,14 +177,14 @@ def convert_input_keys(hf_tree: SegTree):
downsamplers = block.nodes.get('downsamplers', SegTree())

if len(attentions) == 0: # no attentions, this is a DownBlock2d
for resid in sorted(list(resnets.nodes.keys()), key=intkey):
for resid in sorted(list(resnets.nodes.keys()), key=int):
resid = str(resid)
resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0"
map_resnet_block(resnets[resid])
nemo_inp_blk += 1
elif len(attentions) == len(resnets):
# there are attention blocks here -- each resnet+attention becomes a block
for resid in sorted(list(resnets.nodes.keys()), key=intkey):
for resid in sorted(list(resnets.nodes.keys()), key=int):
resid = str(resid)
resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0"
map_resnet_block(resnets[resid])
Expand All @@ -199,7 +193,6 @@ def convert_input_keys(hf_tree: SegTree):
nemo_inp_blk += 1
else:
logging.warning("number of attention blocks is not the same as resnets - whats going on?")

# if there is a downsampler, then also append it
if len(downsamplers) > 0:
for k in downsamplers.nodes.keys():
Expand All @@ -217,10 +210,9 @@ def clean_convert_names(tree):
def map_attention_block(att_tree: SegTree):
'''this HF tree can either be an AttentionBlock or a DualAttention block
currently assumed AttentionBlock
'''

# TODO (rohit): Add check for dual attention block
# TODO(@rohitrango): Add check for dual attention block, but this works for both SD and SDXL
def check_att_type(tree):
return "att_block"

Expand All @@ -237,7 +229,7 @@ def check_att_type(tree):
dup_convert_name_recursive(tblock['norm1'], 'attn1.norm')
dup_convert_name_recursive(tblock['norm2'], 'attn2.norm')
dup_convert_name_recursive(tblock['norm3'], 'ff.net.0')
# map ff module
# map ff
tblock['ff'].convert_name = "ff"
tblock['ff.net'].convert_name = 'net'
dup_convert_name_recursive(tblock['ff.net.0'], '1')
Expand Down Expand Up @@ -272,12 +264,16 @@ def hf_to_nemo_mapping(tree: SegTree):

def convert_cond_keys(tree: SegTree):
    """Map HF conditioning-embedding keys to their NeMo equivalents.

    Args:
        tree: parsed SegTree of the HuggingFace UNet checkpoint keys;
            ``convert_name`` attributes are set in place.

    SD checkpoints carry only ``time_embedding``, while SDXL additionally
    has ``add_embedding`` — each mapping is applied only when the subtree
    is present so the converter works for both model families.
    """
    # map all conditioning keys
    if tree.nodes.get("add_embedding"):
        logging.info("Add embedding found...")
        tree['add_embedding'].convert_name = 'label_emb.0'
        dup_convert_name_recursive(tree['add_embedding.linear_1'], '0')
        dup_convert_name_recursive(tree['add_embedding.linear_2'], '2')
    if tree.nodes.get("time_embedding"):
        logging.info("Time embedding found...")
        tree['time_embedding'].convert_name = 'time_embed'
        dup_convert_name_recursive(tree['time_embedding.linear_1'], '0')
        dup_convert_name_recursive(tree['time_embedding.linear_2'], '2')


def convert_middle_keys(tree: SegTree):
Expand All @@ -298,7 +294,7 @@ def convert_output_keys(hf_tree: SegTree):
'''output keys is similar to input keys'''
nemo_inp_blk = 0
up_blocks = hf_tree['up_blocks']
up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=intkey)
up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=int)

for downblockid in up_blocks_keys:
block = up_blocks[str(downblockid)]
Expand All @@ -307,16 +303,16 @@ def convert_output_keys(hf_tree: SegTree):
attentions = block.nodes.get('attentions', SegTree())
upsamplers = block.nodes.get('upsamplers', SegTree())

if len(attentions) == 0: # no attentions, this is a DownBlock2d
for resid in sorted(list(resnets.nodes.keys()), key=intkey):
if len(attentions) == 0: # no attentions, this is a UpBlock2D
for resid in sorted(list(resnets.nodes.keys()), key=int):
resid = str(resid)
resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0"
map_resnet_block(resnets[resid])
nemo_inp_blk += 1

elif len(attentions) == len(resnets):
# there are attention blocks here -- each resnet+attention becomes a block
for resid in sorted(list(resnets.nodes.keys()), key=intkey):
for resid in sorted(list(resnets.nodes.keys()), key=int):
resid = str(resid)
resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0"
map_resnet_block(resnets[resid])
Expand All @@ -326,11 +322,13 @@ def convert_output_keys(hf_tree: SegTree):
else:
logging.warning("number of attention blocks is not the same as resnets - whats going on?")

# if there is a downsampler, then also append it
# if there is a upsampler, then also append it
if len(upsamplers) > 0:
# for k in upsamplers.nodes.keys():
nemo_inp_blk -= 1
upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.2"
upsamplenum = (
1 if len(attentions) == 0 else 2
) # if there are attention modules, upsample is module2, else it is module 1 (to stay consistent with SD)
upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.{upsamplenum}"
dup_convert_name_recursive(upsamplers['0.conv'], 'conv')
nemo_inp_blk += 1

Expand Down Expand Up @@ -387,6 +385,7 @@ def convert_decoder(hf_tree: SegTree):
decoder['mid_block'].convert_name = 'mid'
dup_convert_name_recursive(decoder[f'mid_block.resnets.0'], 'block_1')
dup_convert_name_recursive(decoder[f'mid_block.resnets.1'], 'block_2')
# attention blocks
att = decoder['mid_block.attentions.0']
att.convert_name = 'attn_1'
dup_convert_name_recursive(att['group_norm'], 'norm')
Expand Down Expand Up @@ -443,6 +442,7 @@ def convert(args):

for hf_key, nemo_key in mapping.items():
nemo_ckpt[nemo_key] = hf_ckpt[hf_key]
# save this
torch.save(nemo_ckpt, args.output_path)
logging.info(f"Saved nemo file to {args.output_path}")

Expand Down

0 comments on commit 465d7f2

Please sign in to comment.