diff --git a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py index fc661d91ab61..65e31b5343de 100644 --- a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py +++ b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py @@ -550,11 +550,18 @@ def load_from_unet(self, from_pretrained_unet, from_NeMo=True): print("Loading unet blocks from sd") state_dict = torch.load(from_pretrained_unet, map_location='cpu') - state_dict = state_dict['state_dict'] + if 'state_dict' in state_dict.keys(): + state_dict = state_dict['state_dict'] model_state_dict = self.state_dict() + model_state_keys = model_state_dict.keys() re_state_dict = {} for key_, value_ in state_dict.items(): + # check if key is a raw parameter + if key_ in model_state_keys: + re_state_dict[key_] = value_ + continue + # prune from model prefix if key_.startswith('model.model.diffusion_model'): re_state_dict[key_.replace('model.model.diffusion_model.', '')] = value_ if key_.startswith('model.diffusion_model'): diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index eb449c5406b9..b94624b33ba2 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -971,6 +971,8 @@ def __init__( ) logging.info(f"Missing keys: {missing_key}") logging.info(f"Unexpected keys: {unexpected_keys}") + else: + logging.info(f"There are no missing keys, model loaded properly!") if unet_precision == "fp16-mixed": # AMP O2 self.convert_to_fp16() @@ -1217,6 +1219,7 @@ def _state_key_mapping(self, state_dict: dict): def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo=False): state_dict = self._strip_unet_key_prefix(state_dict) if not from_NeMo: + logging.info("creating state key mapping from HF") state_dict = self._state_key_mapping(state_dict) state_dict = self._legacy_unet_ckpt_mapping(state_dict) diff --git a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py index 67bc975708d0..ff10dab4bc90 100644 --- a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py @@ -13,13 +13,14 @@ # limitations under the License. r""" -Conversion script to convert HuggingFace Starcoder2 checkpoints into nemo checkpoint. +Conversion script to convert HuggingFace StableDiffusion checkpoints into nemo checkpoint. Example to run this conversion script: python convert_hf_starcoder2_to_nemo.py \ --input_name_or_path \ - --output_path + --output_path --model """ +import os from argparse import ArgumentParser import numpy as np @@ -29,8 +30,6 @@ from nemo.utils import logging -intkey = lambda x: int(x) - def filter_keys(rule, dict): keys = list(dict.keys()) @@ -95,7 +94,7 @@ def __getitem__(self, name: str): return None # either more than 1 match (error) or exactly 1 (success) if np.sum(p_flag) > 1: - print(f"error: multiple matches of key {name} with {keys}") + logging.warning(f"warning: multiple matches of key {name} with {keys}") else: i = np.where(p_flag)[0][0] n = numdots(keys[i]) @@ -130,14 +129,9 @@ def get_args(): return args -def make_tiny_config(config): - '''dial down the config file to make things tractable''' - # TODO - return config - - def load_hf_ckpt(in_dir, args): ckpt = {} + assert os.path.isdir(in_dir), "Currently supports only directories with a safetensor file in it." with safetensors.safe_open(in_dir + "/diffusion_pytorch_model.safetensors", framework="pt") as f: for k in f.keys(): ckpt[k] = f.get_tensor(k) @@ -161,9 +155,9 @@ def sanity_check(hf_tree, hf_unet, nemo_unet): # check if i'm introducing new keys for hfk, nk in hf_to_nemo_mapping(hf_tree).items(): if nk not in nemo_unet.keys(): - print(nk) + logging.info(nk) if hfk not in hf_unet.keys(): - print(hfk) + logging.info(hfk) def convert_input_keys(hf_tree: SegTree): @@ -174,7 +168,7 @@ def convert_input_keys(hf_tree: SegTree): # start counting blocks from now on nemo_inp_blk = 1 down_blocks = hf_tree['down_blocks'] - down_blocks_keys = sorted(list(down_blocks.nodes.keys()), key=intkey) + down_blocks_keys = sorted(list(down_blocks.nodes.keys()), key=int) for downblockid in down_blocks_keys: block = down_blocks[str(downblockid)] # compute number of resnets, attentions, downsamplers in this block @@ -183,14 +177,14 @@ def convert_input_keys(hf_tree: SegTree): downsamplers = block.nodes.get('downsamplers', SegTree()) if len(attentions) == 0: # no attentions, this is a DownBlock2d - for resid in sorted(list(resnets.nodes.keys()), key=intkey): + for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0" map_resnet_block(resnets[resid]) nemo_inp_blk += 1 elif len(attentions) == len(resnets): # there are attention blocks here -- each resnet+attention becomes a block - for resid in sorted(list(resnets.nodes.keys()), key=intkey): + for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0" map_resnet_block(resnets[resid]) @@ -199,7 +193,6 @@ def convert_input_keys(hf_tree: SegTree): nemo_inp_blk += 1 else: logging.warning("number of attention blocks is not the same as resnets - whats going on?") - # if there is a downsampler, then also append it if len(downsamplers) > 0: for k in downsamplers.nodes.keys(): @@ -217,10 +210,9 @@ def clean_convert_names(tree): def map_attention_block(att_tree: SegTree): '''this HF tree can either be an AttentionBlock or a DualAttention block currently assumed AttentionBlock - ''' - # TODO (rohit): Add check for dual attention block + # TODO(@rohitrango): Add check for dual attention block, but this works for both SD and SDXL def check_att_type(tree): return "att_block" @@ -237,7 +229,7 @@ def check_att_type(tree): dup_convert_name_recursive(tblock['norm1'], 'attn1.norm') dup_convert_name_recursive(tblock['norm2'], 'attn2.norm') dup_convert_name_recursive(tblock['norm3'], 'ff.net.0') - # map ff module + # map ff tblock['ff'].convert_name = "ff" tblock['ff.net'].convert_name = 'net' dup_convert_name_recursive(tblock['ff.net.0'], '1') @@ -272,12 +264,16 @@ def hf_to_nemo_mapping(tree: SegTree): def convert_cond_keys(tree: SegTree): # map all conditioning keys - tree['add_embedding'].convert_name = 'label_emb.0' - dup_convert_name_recursive(tree['add_embedding.linear_1'], '0') - dup_convert_name_recursive(tree['add_embedding.linear_2'], '2') - tree['time_embedding'].convert_name = 'time_embed' - dup_convert_name_recursive(tree['time_embedding.linear_1'], '0') - dup_convert_name_recursive(tree['time_embedding.linear_2'], '2') + if tree.nodes.get("add_embedding"): + logging.info("Add embedding found...") + tree['add_embedding'].convert_name = 'label_emb.0' + dup_convert_name_recursive(tree['add_embedding.linear_1'], '0') + dup_convert_name_recursive(tree['add_embedding.linear_2'], '2') + if tree.nodes.get("time_embedding"): + logging.info("Time embedding found...") + tree['time_embedding'].convert_name = 'time_embed' + dup_convert_name_recursive(tree['time_embedding.linear_1'], '0') + dup_convert_name_recursive(tree['time_embedding.linear_2'], '2') def convert_middle_keys(tree: SegTree): @@ -298,7 +294,7 @@ def convert_output_keys(hf_tree: SegTree): '''output keys is similar to input keys''' nemo_inp_blk = 0 up_blocks = hf_tree['up_blocks'] - up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=intkey) + up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=int) for downblockid in up_blocks_keys: block = up_blocks[str(downblockid)] @@ -307,8 +303,8 @@ def convert_output_keys(hf_tree: SegTree): attentions = block.nodes.get('attentions', SegTree()) upsamplers = block.nodes.get('upsamplers', SegTree()) - if len(attentions) == 0: # no attentions, this is a DownBlock2d - for resid in sorted(list(resnets.nodes.keys()), key=intkey): + if len(attentions) == 0: # no attentions, this is a UpBlock2D + for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0" map_resnet_block(resnets[resid]) @@ -316,7 +312,7 @@ def convert_output_keys(hf_tree: SegTree): elif len(attentions) == len(resnets): # there are attention blocks here -- each resnet+attention becomes a block - for resid in sorted(list(resnets.nodes.keys()), key=intkey): + for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0" map_resnet_block(resnets[resid]) @@ -326,11 +322,13 @@ def convert_output_keys(hf_tree: SegTree): else: logging.warning("number of attention blocks is not the same as resnets - whats going on?") - # if there is a downsampler, then also append it + # if there is a upsampler, then also append it if len(upsamplers) > 0: - # for k in upsamplers.nodes.keys(): nemo_inp_blk -= 1 - upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.2" + upsamplenum = ( + 1 if len(attentions) == 0 else 2 + ) # if there are attention modules, upsample is module2, else it is module 1 (to stay consistent with SD) + upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.{upsamplenum}" dup_convert_name_recursive(upsamplers['0.conv'], 'conv') nemo_inp_blk += 1 @@ -387,6 +385,7 @@ def convert_decoder(hf_tree: SegTree): decoder['mid_block'].convert_name = 'mid' dup_convert_name_recursive(decoder[f'mid_block.resnets.0'], 'block_1') dup_convert_name_recursive(decoder[f'mid_block.resnets.1'], 'block_2') + # attention blocks att = decoder['mid_block.attentions.0'] att.convert_name = 'attn_1' dup_convert_name_recursive(att['group_norm'], 'norm') @@ -443,6 +442,7 @@ def convert(args): for hf_key, nemo_key in mapping.items(): nemo_ckpt[nemo_key] = hf_ckpt[hf_key] + # save this torch.save(nemo_ckpt, args.output_path) logging.info(f"Saved nemo file to {args.output_path}")