Fix for train.controlnet.controlnet_v1_5_1node_100steps #9678

Merged · 6 commits · Jul 31, 2024
Changes from all commits
@@ -550,11 +550,18 @@ def load_from_unet(self, from_pretrained_unet, from_NeMo=True):
         print("Loading unet blocks from sd")

         state_dict = torch.load(from_pretrained_unet, map_location='cpu')
-        state_dict = state_dict['state_dict']
+        if 'state_dict' in state_dict.keys():
+            state_dict = state_dict['state_dict']
         model_state_dict = self.state_dict()
+        model_state_keys = model_state_dict.keys()
+
         re_state_dict = {}
         for key_, value_ in state_dict.items():
+            # check if key is a raw parameter
+            if key_ in model_state_keys:
+                re_state_dict[key_] = value_
+                continue
             # prune from model prefix
             if key_.startswith('model.model.diffusion_model'):
                 re_state_dict[key_.replace('model.model.diffusion_model.', '')] = value_
             if key_.startswith('model.diffusion_model'):
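Note: the guard added here lets the loader accept both raw state dicts and Lightning-style checkpoints that nest the weights under a 'state_dict' key. A minimal standalone sketch of the pattern (the helper name is hypothetical, not part of this PR):

    import torch

    def load_maybe_wrapped_state_dict(path):
        # Hypothetical helper mirroring the guard added above.
        state_dict = torch.load(path, map_location='cpu')
        # Lightning-style checkpoints nest weights under 'state_dict';
        # a raw state dict can be used as-is.
        if 'state_dict' in state_dict.keys():
            state_dict = state_dict['state_dict']
        return state_dict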
@@ -971,6 +971,8 @@ def __init__(
             )
             logging.info(f"Missing keys: {missing_key}")
             logging.info(f"Unexpected keys: {unexpected_keys}")
+        else:
+            logging.info(f"There are no missing keys, model loaded properly!")

         if unet_precision == "fp16-mixed":  # AMP O2
             self.convert_to_fp16()
@@ -1217,6 +1219,7 @@ def _state_key_mapping(self, state_dict: dict):
     def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo=False):
         state_dict = self._strip_unet_key_prefix(state_dict)
         if not from_NeMo:
+            logging.info("creating state key mapping from HF")
             state_dict = self._state_key_mapping(state_dict)
         state_dict = self._legacy_unet_ckpt_mapping(state_dict)

66 changes: 33 additions & 33 deletions scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py
@@ -13,13 +13,14 @@
 # limitations under the License.

 r"""
-Conversion script to convert HuggingFace Starcoder2 checkpoints into nemo checkpoint.
+Conversion script to convert HuggingFace StableDiffusion checkpoints into nemo checkpoint.
 Example to run this conversion script:
     python convert_stablediffusion_hf_to_nemo.py \
      --input_name_or_path <path_to_sd_checkpoints_folder> \
-     --output_path <path_to_output_nemo_file>
+     --output_path <path_to_output_nemo_file> --model <unet|vae>
 """

+import os
 from argparse import ArgumentParser

 import numpy as np
@@ -29,8 +30,6 @@

 from nemo.utils import logging

-intkey = lambda x: int(x)
-

 def filter_keys(rule, dict):
     keys = list(dict.keys())
@@ -95,7 +94,7 @@ def __getitem__(self, name: str):
             return None
         # either more than 1 match (error) or exactly 1 (success)
         if np.sum(p_flag) > 1:
-            print(f"error: multiple matches of key {name} with {keys}")
+            logging.warning(f"warning: multiple matches of key {name} with {keys}")
         else:
             i = np.where(p_flag)[0][0]
             n = numdots(keys[i])
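Note: only the logging changes here; `__getitem__` already resolved a queried name against the stored keys and treated more than one hit as ambiguous. A rough sketch of that one-match-wins lookup (prefix matching is an assumption for illustration; the real matcher lives in the script's SegTree class):

    import numpy as np

    def find_unique_match(name, keys):
        # Flag candidate keys, then demand exactly one hit before resolving.
        p_flag = np.array([k.startswith(name) for k in keys])
        if p_flag.sum() == 0:
            return None  # no match
        if p_flag.sum() > 1:
            return None  # ambiguous; now logged as a warning instead of printed
        return keys[int(np.where(p_flag)[0][0])]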
@@ -130,14 +129,9 @@ def get_args():
     return args


-def make_tiny_config(config):
-    '''dial down the config file to make things tractable'''
-    # TODO
-    return config
-
-
 def load_hf_ckpt(in_dir, args):
     ckpt = {}
+    assert os.path.isdir(in_dir), "Currently supports only directories with a safetensor file in it."
     with safetensors.safe_open(in_dir + "/diffusion_pytorch_model.safetensors", framework="pt") as f:
         for k in f.keys():
             ckpt[k] = f.get_tensor(k)
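Note: the new assert (paired with the `import os` added at the top of the script) makes the directory requirement explicit instead of failing later on the path concatenation. A self-contained sketch of the same loading pattern, assuming a diffusers-style folder holding a single diffusion_pytorch_model.safetensors file:

    import os

    import safetensors

    def load_hf_ckpt_sketch(in_dir):
        # Collect every tensor from the safetensors file inside the model folder.
        assert os.path.isdir(in_dir), "expects a directory, not a file path"
        path = os.path.join(in_dir, "diffusion_pytorch_model.safetensors")
        ckpt = {}
        with safetensors.safe_open(path, framework="pt") as f:
            for k in f.keys():
                ckpt[k] = f.get_tensor(k)
        return ckpt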
@@ -161,9 +155,9 @@ def sanity_check(hf_tree, hf_unet, nemo_unet):
     # check if i'm introducing new keys
     for hfk, nk in hf_to_nemo_mapping(hf_tree).items():
         if nk not in nemo_unet.keys():
-            print(nk)
+            logging.info(nk)
         if hfk not in hf_unet.keys():
-            print(hfk)
+            logging.info(hfk)


 def convert_input_keys(hf_tree: SegTree):
@@ -174,7 +168,7 @@
     # start counting blocks from now on
     nemo_inp_blk = 1
     down_blocks = hf_tree['down_blocks']
-    down_blocks_keys = sorted(list(down_blocks.nodes.keys()), key=intkey)
+    down_blocks_keys = sorted(list(down_blocks.nodes.keys()), key=int)
     for downblockid in down_blocks_keys:
         block = down_blocks[str(downblockid)]
         # compute number of resnets, attentions, downsamplers in this block
@@ -183,14 +177,14 @@
         downsamplers = block.nodes.get('downsamplers', SegTree())

         if len(attentions) == 0:  # no attentions, this is a DownBlock2d
-            for resid in sorted(list(resnets.nodes.keys()), key=intkey):
+            for resid in sorted(list(resnets.nodes.keys()), key=int):
                 resid = str(resid)
                 resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0"
                 map_resnet_block(resnets[resid])
                 nemo_inp_blk += 1
         elif len(attentions) == len(resnets):
             # there are attention blocks here -- each resnet+attention becomes a block
-            for resid in sorted(list(resnets.nodes.keys()), key=intkey):
+            for resid in sorted(list(resnets.nodes.keys()), key=int):
                 resid = str(resid)
                 resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0"
                 map_resnet_block(resnets[resid])
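Note: the deleted `intkey` lambda was just `int`, so sorting behavior is unchanged; the numeric key matters because block indices are strings ('10' would sort before '2' lexicographically). A small sketch of the flattening these loops perform, using a hypothetical two-block layout:

    # HF groups layers as down_blocks.<block>.<resnets|attentions>.<idx>;
    # the NeMo UNet numbers input_blocks sequentially instead.
    hf_down_blocks = {
        "0": {"resnets": ["0", "1"], "attentions": ["0", "1"]},
        "1": {"resnets": ["0", "1"], "attentions": []},
    }
    nemo_inp_blk = 1  # input_blocks.0 is the input convolution
    for block_id in sorted(hf_down_blocks, key=int):
        block = hf_down_blocks[block_id]
        for res_id in sorted(block["resnets"], key=int):
            # each resnet (plus its attention, if any) becomes one input block
            print(f"down_blocks.{block_id}.resnets.{res_id} -> input_blocks.{nemo_inp_blk}.0")
            if block["attentions"]:
                print(f"down_blocks.{block_id}.attentions.{res_id} -> input_blocks.{nemo_inp_blk}.1")
            nemo_inp_blk += 1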
@@ -199,7 +193,6 @@
             nemo_inp_blk += 1
         else:
             logging.warning("number of attention blocks is not the same as resnets - whats going on?")
-
         # if there is a downsampler, then also append it
         if len(downsamplers) > 0:
             for k in downsamplers.nodes.keys():
@@ -217,10 +210,9 @@
 def map_attention_block(att_tree: SegTree):
     '''this HF tree can either be an AttentionBlock or a DualAttention block
     currently assumed AttentionBlock
-
     '''

-    # TODO (rohit): Add check for dual attention block
+    # TODO(@rohitrango): Add check for dual attention block, but this works for both SD and SDXL
     def check_att_type(tree):
         return "att_block"

@@ -237,7 +229,7 @@ def check_att_type(tree):
     dup_convert_name_recursive(tblock['norm1'], 'attn1.norm')
     dup_convert_name_recursive(tblock['norm2'], 'attn2.norm')
     dup_convert_name_recursive(tblock['norm3'], 'ff.net.0')
-    # map ff module
+    # map ff
     tblock['ff'].convert_name = "ff"
     tblock['ff.net'].convert_name = 'net'
     dup_convert_name_recursive(tblock['ff.net.0'], '1')
@@ -272,12 +264,16 @@ def hf_to_nemo_mapping(tree: SegTree):

 def convert_cond_keys(tree: SegTree):
     # map all conditioning keys
-    tree['add_embedding'].convert_name = 'label_emb.0'
-    dup_convert_name_recursive(tree['add_embedding.linear_1'], '0')
-    dup_convert_name_recursive(tree['add_embedding.linear_2'], '2')
-    tree['time_embedding'].convert_name = 'time_embed'
-    dup_convert_name_recursive(tree['time_embedding.linear_1'], '0')
-    dup_convert_name_recursive(tree['time_embedding.linear_2'], '2')
+    if tree.nodes.get("add_embedding"):
+        logging.info("Add embedding found...")
+        tree['add_embedding'].convert_name = 'label_emb.0'
+        dup_convert_name_recursive(tree['add_embedding.linear_1'], '0')
+        dup_convert_name_recursive(tree['add_embedding.linear_2'], '2')
+    if tree.nodes.get("time_embedding"):
+        logging.info("Time embedding found...")
+        tree['time_embedding'].convert_name = 'time_embed'
+        dup_convert_name_recursive(tree['time_embedding.linear_1'], '0')
+        dup_convert_name_recursive(tree['time_embedding.linear_2'], '2')
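Note: the guards matter because SD 1.x UNets have no add_embedding (it carries SDXL's added conditioning), so the old unconditional mapping would fail on exactly the v1.5 checkpoints this PR's test exercises. The renames implied by the mapping, shown on hypothetical keys:

    # time_embedding.linear_1.weight -> time_embed.0.weight
    # time_embedding.linear_2.bias   -> time_embed.2.bias
    # add_embedding.linear_1.weight  -> label_emb.0.0.weight  (SDXL only)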


def convert_middle_keys(tree: SegTree):
@@ -298,7 +294,7 @@ def convert_output_keys(hf_tree: SegTree):
     '''output keys is similar to input keys'''
     nemo_inp_blk = 0
     up_blocks = hf_tree['up_blocks']
-    up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=intkey)
+    up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=int)

     for downblockid in up_blocks_keys:
         block = up_blocks[str(downblockid)]
@@ -307,16 +303,16 @@
         attentions = block.nodes.get('attentions', SegTree())
         upsamplers = block.nodes.get('upsamplers', SegTree())

-        if len(attentions) == 0:  # no attentions, this is a DownBlock2d
-            for resid in sorted(list(resnets.nodes.keys()), key=intkey):
+        if len(attentions) == 0:  # no attentions, this is a UpBlock2D
+            for resid in sorted(list(resnets.nodes.keys()), key=int):
                 resid = str(resid)
                 resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0"
                 map_resnet_block(resnets[resid])
                 nemo_inp_blk += 1

         elif len(attentions) == len(resnets):
             # there are attention blocks here -- each resnet+attention becomes a block
-            for resid in sorted(list(resnets.nodes.keys()), key=intkey):
+            for resid in sorted(list(resnets.nodes.keys()), key=int):
                 resid = str(resid)
                 resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0"
                 map_resnet_block(resnets[resid])
@@ -326,11 +322,13 @@
         else:
             logging.warning("number of attention blocks is not the same as resnets - whats going on?")

-        # if there is a downsampler, then also append it
+        # if there is a upsampler, then also append it
         if len(upsamplers) > 0:
-            # for k in upsamplers.nodes.keys():
             nemo_inp_blk -= 1
-            upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.2"
+            upsamplenum = (
+                1 if len(attentions) == 0 else 2
+            )  # if there are attention modules, upsample is module2, else it is module 1 (to stay consistent with SD)
+            upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.{upsamplenum}"
             dup_convert_name_recursive(upsamplers['0.conv'], 'conv')
             nemo_inp_blk += 1
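Note: the index now reflects what actually precedes the upsampler inside each output block. Illustrative module layouts (assuming the usual resnet/attention ordering):

    # with attentions:    output_blocks.N = [resnet (.0), attention (.1), upsampler (.2)]
    # without attentions: output_blocks.N = [resnet (.0), upsampler (.1)]
    # The old hard-coded ".2" mis-numbered SD 1.x's attention-free up blocks.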

@@ -387,6 +385,7 @@ def convert_decoder(hf_tree: SegTree):
     decoder['mid_block'].convert_name = 'mid'
     dup_convert_name_recursive(decoder[f'mid_block.resnets.0'], 'block_1')
     dup_convert_name_recursive(decoder[f'mid_block.resnets.1'], 'block_2')
+    # attention blocks
     att = decoder['mid_block.attentions.0']
     att.convert_name = 'attn_1'
     dup_convert_name_recursive(att['group_norm'], 'norm')
@@ -443,6 +442,7 @@ def convert(args):

     for hf_key, nemo_key in mapping.items():
         nemo_ckpt[nemo_key] = hf_ckpt[hf_key]
+    # save this
     torch.save(nemo_ckpt, args.output_path)
     logging.info(f"Saved nemo file to {args.output_path}")
