From 2c5e9ecc559f4a337e2a1e501c6d5cfd88629945 Mon Sep 17 00:00:00 2001 From: Rohit Jena Date: Thu, 27 Jun 2024 21:06:36 -0700 Subject: [PATCH 1/4] Fixed bug with controlnet where the model can be assumed to be saved directly from the unet. also added a checkpoint converter for SD that converts both SD and SDXL Signed-off-by: Rohit Jena --- .../text_to_image/controlnet/controlnet.py | 11 +- .../diffusionmodules/openaimodel.py | 14 +- .../convert_stablediffusion_hf_to_nemo.py | 124 +++++++----------- 3 files changed, 66 insertions(+), 83 deletions(-) diff --git a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py index fc661d91ab61..cd2c2ca1c661 100644 --- a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py +++ b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py @@ -550,11 +550,18 @@ def load_from_unet(self, from_pretrained_unet, from_NeMo=True): print("Loading unet blocks from sd") state_dict = torch.load(from_pretrained_unet, map_location='cpu') - state_dict = state_dict['state_dict'] + if 'state_dict' in state_dict.keys(): + state_dict = state_dict['state_dict'] model_state_dict = self.state_dict() + model_state_keys = model_state_dict.keys() re_state_dict = {} for key_, value_ in state_dict.items(): + # check if key is a raw parameter + if key_ in model_state_keys: + re_state_dict[key_] = value_ + continue + # prune from model prefix if key_.startswith('model.model.diffusion_model'): re_state_dict[key_.replace('model.model.diffusion_model.', '')] = value_ if key_.startswith('model.diffusion_model'): @@ -593,6 +600,8 @@ def load_from_unet(self, from_pretrained_unet, from_NeMo=True): ) print(f'There is {len(missing_keys)} total missing keys') print("Missing:", missing_keys) + for key in sorted(missing_keys): + print(key) print("Unexpected:", unexpected_keys) else: print("sd blocks loaded successfully") diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index eb449c5406b9..3fb3a8cee3a2 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -971,6 +971,8 @@ def __init__( ) logging.info(f"Missing keys: {missing_key}") logging.info(f"Unexpected keys: {unexpected_keys}") + else: + logging.info(f"There are no missing keys, model loaded properly!") if unet_precision == "fp16-mixed": # AMP O2 self.convert_to_fp16() @@ -1217,6 +1219,7 @@ def _state_key_mapping(self, state_dict: dict): def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo=False): state_dict = self._strip_unet_key_prefix(state_dict) if not from_NeMo: + print("creating state key mapping from HF") state_dict = self._state_key_mapping(state_dict) state_dict = self._legacy_unet_ckpt_mapping(state_dict) @@ -1242,13 +1245,10 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from ): # GroupNormOpt fuses activation function to one layer, thus the indexing of weights are shifted for following for key_ in missing_keys: - try: - s = key_.split('.') - idx = int(s[-2]) - new_key_ = ".".join(s[:-2] + [str(int(idx + 1))] + [s[-1]]) - state_dict[key_] = state_dict[new_key_] - except: - continue + s = key_.split('.') + idx = int(s[-2]) + new_key_ = ".".join(s[:-2] + [str(int(idx + 1))] + [s[-1]]) + state_dict[key_] = state_dict[new_key_] loaded_keys = list(state_dict.keys()) missing_keys = list(set(expected_keys) - set(loaded_keys)) diff --git a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py index 67bc975708d0..c0a11eca6100 100644 --- a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py @@ -13,50 +13,42 @@ # limitations under the License. r""" -Conversion script to convert HuggingFace Starcoder2 checkpoints into nemo checkpoint. +Conversion script to convert HuggingFace StableDiffusion checkpoints into nemo checkpoint. Example to run this conversion script: python convert_hf_starcoder2_to_nemo.py \ --input_name_or_path \ - --output_path + --output_path --model """ -from argparse import ArgumentParser - +import torch import numpy as np import safetensors +from argparse import ArgumentParser import torch import torch.nn - from nemo.utils import logging - -intkey = lambda x: int(x) - +import os def filter_keys(rule, dict): keys = list(dict.keys()) nd = {k: dict[k] for k in keys if rule(k)} return nd - def map_keys(rule, dict): new = {rule(k): v for k, v in dict.items()} return new - def split_name(name, dots=0): l = name.split(".") - return ".".join(l[: dots + 1]), ".".join(l[dots + 1 :]) - + return ".".join(l[:dots+1]), ".".join(l[dots+1:]) def is_prefix(shortstr, longstr): # is the first string a prefix of the second one return longstr == shortstr or longstr.startswith(shortstr + ".") - def numdots(str): return str.count(".") - class SegTree: def __init__(self): self.nodes = dict() @@ -66,10 +58,10 @@ def __init__(self): def __len__(self): return len(self.nodes) - + def is_leaf(self): return len(self.nodes) == 0 - + def add(self, name, val=0): prefix, subname = split_name(name) if subname == '': @@ -79,10 +71,10 @@ def add(self, name, val=0): if self.nodes.get(prefix) is None: self.nodes[prefix] = SegTree() self.nodes[prefix].add(subname, val) - + def change(self, name, val): self.add(name, val) - + def __getitem__(self, name: str): if hasattr(self, name): return getattr(self, name) @@ -95,7 +87,7 @@ def __getitem__(self, name: str): return None # either more than 1 match (error) or exactly 1 (success) if np.sum(p_flag) > 1: - print(f"error: multiple matches of key {name} with {keys}") + logging.warning(f"warning: multiple matches of key {name} with {keys}") else: i = np.where(p_flag)[0][0] n = numdots(keys[i]) @@ -103,7 +95,6 @@ def __getitem__(self, name: str): return self.nodes[prefix][substr] return val - def model_to_tree(model): keys = list(model.keys()) tree = SegTree() @@ -111,7 +102,6 @@ def model_to_tree(model): tree.add(k, "leaf") return tree - def get_args(): parser = ArgumentParser() parser.add_argument( @@ -125,27 +115,20 @@ def get_args(): parser.add_argument("--precision", type=str, default="32", help="Model precision") parser.add_argument("--model", type=str, default="unet", required=True, choices=['unet', 'vae']) parser.add_argument("--debug", action='store_true', help="Useful for debugging purposes.") - + args = parser.parse_args() return args - -def make_tiny_config(config): - '''dial down the config file to make things tractable''' - # TODO - return config - - def load_hf_ckpt(in_dir, args): ckpt = {} + assert os.path.isdir(in_dir), "Currently supports only directories with a safetensor file in it." with safetensors.safe_open(in_dir + "/diffusion_pytorch_model.safetensors", framework="pt") as f: for k in f.keys(): ckpt[k] = f.get_tensor(k) - return args, ckpt - + return args, ckpt def dup_convert_name_recursive(tree: SegTree, convert_name=None): - '''inside this tree, convert all nodes recursively + ''' inside this tree, convert all nodes recursively optionally, convert the name of the root as given by name (if not None) ''' if tree is None: @@ -156,25 +139,23 @@ def dup_convert_name_recursive(tree: SegTree, convert_name=None): for k, v in tree.nodes.items(): dup_convert_name_recursive(v, k) - def sanity_check(hf_tree, hf_unet, nemo_unet): # check if i'm introducing new keys for hfk, nk in hf_to_nemo_mapping(hf_tree).items(): if nk not in nemo_unet.keys(): - print(nk) + logging.info(nk) if hfk not in hf_unet.keys(): - print(hfk) - + logging.info(hfk) def convert_input_keys(hf_tree: SegTree): - '''map the input blocks of huggingface model''' + ''' map the input blocks of huggingface model ''' # map `conv_in` to first input block dup_convert_name_recursive(hf_tree['conv_in'], 'input_blocks.0.0') # start counting blocks from now on nemo_inp_blk = 1 down_blocks = hf_tree['down_blocks'] - down_blocks_keys = sorted(list(down_blocks.nodes.keys()), key=intkey) + down_blocks_keys = sorted(list(down_blocks.nodes.keys()), key=int) for downblockid in down_blocks_keys: block = down_blocks[str(downblockid)] # compute number of resnets, attentions, downsamplers in this block @@ -182,15 +163,15 @@ def convert_input_keys(hf_tree: SegTree): attentions = block.nodes.get('attentions', SegTree()) downsamplers = block.nodes.get('downsamplers', SegTree()) - if len(attentions) == 0: # no attentions, this is a DownBlock2d - for resid in sorted(list(resnets.nodes.keys()), key=intkey): + if len(attentions) == 0: # no attentions, this is a DownBlock2d + for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0" map_resnet_block(resnets[resid]) nemo_inp_blk += 1 elif len(attentions) == len(resnets): # there are attention blocks here -- each resnet+attention becomes a block - for resid in sorted(list(resnets.nodes.keys()), key=intkey): + for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0" map_resnet_block(resnets[resid]) @@ -199,7 +180,6 @@ def convert_input_keys(hf_tree: SegTree): nemo_inp_blk += 1 else: logging.warning("number of attention blocks is not the same as resnets - whats going on?") - # if there is a downsampler, then also append it if len(downsamplers) > 0: for k in downsamplers.nodes.keys(): @@ -207,20 +187,16 @@ def convert_input_keys(hf_tree: SegTree): dup_convert_name_recursive(downsamplers[k]['conv'], 'op') nemo_inp_blk += 1 - def clean_convert_names(tree): tree.convert_name = None for k, v in tree.nodes.items(): clean_convert_names(v) - def map_attention_block(att_tree: SegTree): - '''this HF tree can either be an AttentionBlock or a DualAttention block + ''' this HF tree can either be an AttentionBlock or a DualAttention block currently assumed AttentionBlock - ''' - - # TODO (rohit): Add check for dual attention block + # TODO(@rohitrango): Add check for dual attention block, but this works for both SD and SDXL def check_att_type(tree): return "att_block" @@ -237,7 +213,7 @@ def check_att_type(tree): dup_convert_name_recursive(tblock['norm1'], 'attn1.norm') dup_convert_name_recursive(tblock['norm2'], 'attn2.norm') dup_convert_name_recursive(tblock['norm3'], 'ff.net.0') - # map ff module + # map ff tblock['ff'].convert_name = "ff" tblock['ff.net'].convert_name = 'net' dup_convert_name_recursive(tblock['ff.net.0'], '1') @@ -245,9 +221,8 @@ def check_att_type(tree): else: logging.warning("failed to identify type of attention block here.") - def map_resnet_block(resnet_tree: SegTree): - '''this HF tree is supposed to have all the keys for a resnet''' + ''' this HF tree is supposed to have all the keys for a resnet ''' dup_convert_name_recursive(resnet_tree.nodes.get('time_emb_proj'), 'emb_layers.1') dup_convert_name_recursive(resnet_tree['norm1'], 'in_layers.0') dup_convert_name_recursive(resnet_tree['conv1'], 'in_layers.1') @@ -255,7 +230,6 @@ def map_resnet_block(resnet_tree: SegTree): dup_convert_name_recursive(resnet_tree['conv2'], 'out_layers.2') dup_convert_name_recursive(resnet_tree.nodes.get('conv_shortcut'), 'skip_connection') - def hf_to_nemo_mapping(tree: SegTree): mapping = {} for nodename, subtree in tree.nodes.items(): @@ -269,19 +243,21 @@ def hf_to_nemo_mapping(tree: SegTree): mapping[nodename + "." + k] = convert_name + v return mapping - def convert_cond_keys(tree: SegTree): # map all conditioning keys - tree['add_embedding'].convert_name = 'label_emb.0' - dup_convert_name_recursive(tree['add_embedding.linear_1'], '0') - dup_convert_name_recursive(tree['add_embedding.linear_2'], '2') - tree['time_embedding'].convert_name = 'time_embed' - dup_convert_name_recursive(tree['time_embedding.linear_1'], '0') - dup_convert_name_recursive(tree['time_embedding.linear_2'], '2') - + if tree.nodes.get("add_embedding"): + logging.info("Add embedding found...") + tree['add_embedding'].convert_name = 'label_emb.0' + dup_convert_name_recursive(tree['add_embedding.linear_1'], '0') + dup_convert_name_recursive(tree['add_embedding.linear_2'], '2') + if tree.nodes.get("time_embedding"): + logging.info("Time embedding found...") + tree['time_embedding'].convert_name = 'time_embed' + dup_convert_name_recursive(tree['time_embedding.linear_1'], '0') + dup_convert_name_recursive(tree['time_embedding.linear_2'], '2') def convert_middle_keys(tree: SegTree): - '''middle block is fixed (resnet -> attention -> resnet)''' + ''' middle block is fixed (resnet -> attention -> resnet) ''' mid = tree['mid_block'] resnets = mid['resnets'] attns = mid['attentions'] @@ -293,12 +269,11 @@ def convert_middle_keys(tree: SegTree): map_resnet_block(resnets['1']) map_attention_block(attns['0']) - def convert_output_keys(hf_tree: SegTree): - '''output keys is similar to input keys''' + ''' output keys is similar to input keys ''' nemo_inp_blk = 0 up_blocks = hf_tree['up_blocks'] - up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=intkey) + up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=int) for downblockid in up_blocks_keys: block = up_blocks[str(downblockid)] @@ -307,8 +282,8 @@ def convert_output_keys(hf_tree: SegTree): attentions = block.nodes.get('attentions', SegTree()) upsamplers = block.nodes.get('upsamplers', SegTree()) - if len(attentions) == 0: # no attentions, this is a DownBlock2d - for resid in sorted(list(resnets.nodes.keys()), key=intkey): + if len(attentions) == 0: # no attentions, this is a UpBlock2D + for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0" map_resnet_block(resnets[resid]) @@ -316,7 +291,7 @@ def convert_output_keys(hf_tree: SegTree): elif len(attentions) == len(resnets): # there are attention blocks here -- each resnet+attention becomes a block - for resid in sorted(list(resnets.nodes.keys()), key=intkey): + for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0" map_resnet_block(resnets[resid]) @@ -326,20 +301,18 @@ def convert_output_keys(hf_tree: SegTree): else: logging.warning("number of attention blocks is not the same as resnets - whats going on?") - # if there is a downsampler, then also append it + # if there is a upsampler, then also append it if len(upsamplers) > 0: - # for k in upsamplers.nodes.keys(): nemo_inp_blk -= 1 - upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.2" + upsamplenum = 1 if len(attentions) == 0 else 2 # if there are attention modules, upsample is module2, else it is module 1 (to stay consistent with SD) + upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.{upsamplenum}" dup_convert_name_recursive(upsamplers['0.conv'], 'conv') nemo_inp_blk += 1 - def convert_finalout_keys(hf_tree: SegTree): dup_convert_name_recursive(hf_tree['conv_norm_out'], "out.0") dup_convert_name_recursive(hf_tree['conv_out'], "out.1") - def convert_encoder(hf_tree: SegTree): encoder = hf_tree['encoder'] encoder.convert_name = 'encoder' @@ -387,6 +360,7 @@ def convert_decoder(hf_tree: SegTree): decoder['mid_block'].convert_name = 'mid' dup_convert_name_recursive(decoder[f'mid_block.resnets.0'], 'block_1') dup_convert_name_recursive(decoder[f'mid_block.resnets.1'], 'block_2') + # attention blocks att = decoder['mid_block.attentions.0'] att.convert_name = 'attn_1' dup_convert_name_recursive(att['group_norm'], 'norm') @@ -395,7 +369,7 @@ def convert_decoder(hf_tree: SegTree): dup_convert_name_recursive(att['to_v'], 'v') dup_convert_name_recursive(att['to_out.0'], 'proj_out') - # up blocks contain resnets and upsamplers + # up blocks contain resnets and upsamplers decoder['up_blocks'].convert_name = 'up' num_up_blocks = len(decoder['up_blocks']) for upid, upblock in decoder['up_blocks'].nodes.items(): @@ -434,7 +408,7 @@ def convert(args): else: logging.error("incorrect model specification.") return - + # check mapping mapping = hf_to_nemo_mapping(hf_tree) if len(mapping) != len(hf_ckpt.keys()): @@ -443,10 +417,10 @@ def convert(args): for hf_key, nemo_key in mapping.items(): nemo_ckpt[nemo_key] = hf_ckpt[hf_key] + # save this torch.save(nemo_ckpt, args.output_path) logging.info(f"Saved nemo file to {args.output_path}") - if __name__ == '__main__': args = get_args() convert(args) From 95a8a981f53b135516d2248d410f3a918a9ced41 Mon Sep 17 00:00:00 2001 From: Rohit Jena Date: Wed, 10 Jul 2024 14:36:19 -0700 Subject: [PATCH 2/4] delete unnecessary comments and messages Signed-off-by: Rohit Jena --- .../multimodal/models/text_to_image/controlnet/controlnet.py | 2 -- .../modules/stable_diffusion/diffusionmodules/openaimodel.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py index cd2c2ca1c661..65e31b5343de 100644 --- a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py +++ b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py @@ -600,8 +600,6 @@ def load_from_unet(self, from_pretrained_unet, from_NeMo=True): ) print(f'There is {len(missing_keys)} total missing keys') print("Missing:", missing_keys) - for key in sorted(missing_keys): - print(key) print("Unexpected:", unexpected_keys) else: print("sd blocks loaded successfully") diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index 3fb3a8cee3a2..7c845db6c066 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -1219,7 +1219,7 @@ def _state_key_mapping(self, state_dict: dict): def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo=False): state_dict = self._strip_unet_key_prefix(state_dict) if not from_NeMo: - print("creating state key mapping from HF") + logging.info("creating state key mapping from HF") state_dict = self._state_key_mapping(state_dict) state_dict = self._legacy_unet_ckpt_mapping(state_dict) From 1bd4a73d84eda23feb31eec10193ff69f9c38501 Mon Sep 17 00:00:00 2001 From: Rohit Jena Date: Tue, 30 Jul 2024 10:29:17 -0700 Subject: [PATCH 3/4] reinstate try-catch for loading missing modules Signed-off-by: Rohit Jena --- .../stable_diffusion/diffusionmodules/openaimodel.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index 7c845db6c066..b94624b33ba2 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -1245,10 +1245,13 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from ): # GroupNormOpt fuses activation function to one layer, thus the indexing of weights are shifted for following for key_ in missing_keys: - s = key_.split('.') - idx = int(s[-2]) - new_key_ = ".".join(s[:-2] + [str(int(idx + 1))] + [s[-1]]) - state_dict[key_] = state_dict[new_key_] + try: + s = key_.split('.') + idx = int(s[-2]) + new_key_ = ".".join(s[:-2] + [str(int(idx + 1))] + [s[-1]]) + state_dict[key_] = state_dict[new_key_] + except: + continue loaded_keys = list(state_dict.keys()) missing_keys = list(set(expected_keys) - set(loaded_keys)) From feb713d40e770349b5172c46953f118cad6d4003 Mon Sep 17 00:00:00 2001 From: Victor49152 Date: Tue, 30 Jul 2024 18:40:15 +0000 Subject: [PATCH 4/4] Apply isort and black reformatting Signed-off-by: Victor49152 --- .../convert_stablediffusion_hf_to_nemo.py | 68 +++++++++++++------ 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py index c0a11eca6100..ff10dab4bc90 100644 --- a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py @@ -20,35 +20,42 @@ --output_path --model """ -import torch +import os +from argparse import ArgumentParser + import numpy as np import safetensors -from argparse import ArgumentParser import torch import torch.nn + from nemo.utils import logging -import os + def filter_keys(rule, dict): keys = list(dict.keys()) nd = {k: dict[k] for k in keys if rule(k)} return nd + def map_keys(rule, dict): new = {rule(k): v for k, v in dict.items()} return new + def split_name(name, dots=0): l = name.split(".") - return ".".join(l[:dots+1]), ".".join(l[dots+1:]) + return ".".join(l[: dots + 1]), ".".join(l[dots + 1 :]) + def is_prefix(shortstr, longstr): # is the first string a prefix of the second one return longstr == shortstr or longstr.startswith(shortstr + ".") + def numdots(str): return str.count(".") + class SegTree: def __init__(self): self.nodes = dict() @@ -58,10 +65,10 @@ def __init__(self): def __len__(self): return len(self.nodes) - + def is_leaf(self): return len(self.nodes) == 0 - + def add(self, name, val=0): prefix, subname = split_name(name) if subname == '': @@ -71,10 +78,10 @@ def add(self, name, val=0): if self.nodes.get(prefix) is None: self.nodes[prefix] = SegTree() self.nodes[prefix].add(subname, val) - + def change(self, name, val): self.add(name, val) - + def __getitem__(self, name: str): if hasattr(self, name): return getattr(self, name) @@ -95,6 +102,7 @@ def __getitem__(self, name: str): return self.nodes[prefix][substr] return val + def model_to_tree(model): keys = list(model.keys()) tree = SegTree() @@ -102,6 +110,7 @@ def model_to_tree(model): tree.add(k, "leaf") return tree + def get_args(): parser = ArgumentParser() parser.add_argument( @@ -115,20 +124,22 @@ def get_args(): parser.add_argument("--precision", type=str, default="32", help="Model precision") parser.add_argument("--model", type=str, default="unet", required=True, choices=['unet', 'vae']) parser.add_argument("--debug", action='store_true', help="Useful for debugging purposes.") - + args = parser.parse_args() return args + def load_hf_ckpt(in_dir, args): ckpt = {} assert os.path.isdir(in_dir), "Currently supports only directories with a safetensor file in it." with safetensors.safe_open(in_dir + "/diffusion_pytorch_model.safetensors", framework="pt") as f: for k in f.keys(): ckpt[k] = f.get_tensor(k) - return args, ckpt + return args, ckpt + def dup_convert_name_recursive(tree: SegTree, convert_name=None): - ''' inside this tree, convert all nodes recursively + '''inside this tree, convert all nodes recursively optionally, convert the name of the root as given by name (if not None) ''' if tree is None: @@ -139,6 +150,7 @@ def dup_convert_name_recursive(tree: SegTree, convert_name=None): for k, v in tree.nodes.items(): dup_convert_name_recursive(v, k) + def sanity_check(hf_tree, hf_unet, nemo_unet): # check if i'm introducing new keys for hfk, nk in hf_to_nemo_mapping(hf_tree).items(): @@ -147,8 +159,9 @@ def sanity_check(hf_tree, hf_unet, nemo_unet): if hfk not in hf_unet.keys(): logging.info(hfk) + def convert_input_keys(hf_tree: SegTree): - ''' map the input blocks of huggingface model ''' + '''map the input blocks of huggingface model''' # map `conv_in` to first input block dup_convert_name_recursive(hf_tree['conv_in'], 'input_blocks.0.0') @@ -163,7 +176,7 @@ def convert_input_keys(hf_tree: SegTree): attentions = block.nodes.get('attentions', SegTree()) downsamplers = block.nodes.get('downsamplers', SegTree()) - if len(attentions) == 0: # no attentions, this is a DownBlock2d + if len(attentions) == 0: # no attentions, this is a DownBlock2d for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0" @@ -187,15 +200,18 @@ def convert_input_keys(hf_tree: SegTree): dup_convert_name_recursive(downsamplers[k]['conv'], 'op') nemo_inp_blk += 1 + def clean_convert_names(tree): tree.convert_name = None for k, v in tree.nodes.items(): clean_convert_names(v) + def map_attention_block(att_tree: SegTree): - ''' this HF tree can either be an AttentionBlock or a DualAttention block + '''this HF tree can either be an AttentionBlock or a DualAttention block currently assumed AttentionBlock ''' + # TODO(@rohitrango): Add check for dual attention block, but this works for both SD and SDXL def check_att_type(tree): return "att_block" @@ -221,8 +237,9 @@ def check_att_type(tree): else: logging.warning("failed to identify type of attention block here.") + def map_resnet_block(resnet_tree: SegTree): - ''' this HF tree is supposed to have all the keys for a resnet ''' + '''this HF tree is supposed to have all the keys for a resnet''' dup_convert_name_recursive(resnet_tree.nodes.get('time_emb_proj'), 'emb_layers.1') dup_convert_name_recursive(resnet_tree['norm1'], 'in_layers.0') dup_convert_name_recursive(resnet_tree['conv1'], 'in_layers.1') @@ -230,6 +247,7 @@ def map_resnet_block(resnet_tree: SegTree): dup_convert_name_recursive(resnet_tree['conv2'], 'out_layers.2') dup_convert_name_recursive(resnet_tree.nodes.get('conv_shortcut'), 'skip_connection') + def hf_to_nemo_mapping(tree: SegTree): mapping = {} for nodename, subtree in tree.nodes.items(): @@ -243,6 +261,7 @@ def hf_to_nemo_mapping(tree: SegTree): mapping[nodename + "." + k] = convert_name + v return mapping + def convert_cond_keys(tree: SegTree): # map all conditioning keys if tree.nodes.get("add_embedding"): @@ -256,8 +275,9 @@ def convert_cond_keys(tree: SegTree): dup_convert_name_recursive(tree['time_embedding.linear_1'], '0') dup_convert_name_recursive(tree['time_embedding.linear_2'], '2') + def convert_middle_keys(tree: SegTree): - ''' middle block is fixed (resnet -> attention -> resnet) ''' + '''middle block is fixed (resnet -> attention -> resnet)''' mid = tree['mid_block'] resnets = mid['resnets'] attns = mid['attentions'] @@ -269,8 +289,9 @@ def convert_middle_keys(tree: SegTree): map_resnet_block(resnets['1']) map_attention_block(attns['0']) + def convert_output_keys(hf_tree: SegTree): - ''' output keys is similar to input keys ''' + '''output keys is similar to input keys''' nemo_inp_blk = 0 up_blocks = hf_tree['up_blocks'] up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=int) @@ -282,7 +303,7 @@ def convert_output_keys(hf_tree: SegTree): attentions = block.nodes.get('attentions', SegTree()) upsamplers = block.nodes.get('upsamplers', SegTree()) - if len(attentions) == 0: # no attentions, this is a UpBlock2D + if len(attentions) == 0: # no attentions, this is a UpBlock2D for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0" @@ -304,15 +325,19 @@ def convert_output_keys(hf_tree: SegTree): # if there is a upsampler, then also append it if len(upsamplers) > 0: nemo_inp_blk -= 1 - upsamplenum = 1 if len(attentions) == 0 else 2 # if there are attention modules, upsample is module2, else it is module 1 (to stay consistent with SD) + upsamplenum = ( + 1 if len(attentions) == 0 else 2 + ) # if there are attention modules, upsample is module2, else it is module 1 (to stay consistent with SD) upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.{upsamplenum}" dup_convert_name_recursive(upsamplers['0.conv'], 'conv') nemo_inp_blk += 1 + def convert_finalout_keys(hf_tree: SegTree): dup_convert_name_recursive(hf_tree['conv_norm_out'], "out.0") dup_convert_name_recursive(hf_tree['conv_out'], "out.1") + def convert_encoder(hf_tree: SegTree): encoder = hf_tree['encoder'] encoder.convert_name = 'encoder' @@ -369,7 +394,7 @@ def convert_decoder(hf_tree: SegTree): dup_convert_name_recursive(att['to_v'], 'v') dup_convert_name_recursive(att['to_out.0'], 'proj_out') - # up blocks contain resnets and upsamplers + # up blocks contain resnets and upsamplers decoder['up_blocks'].convert_name = 'up' num_up_blocks = len(decoder['up_blocks']) for upid, upblock in decoder['up_blocks'].nodes.items(): @@ -408,7 +433,7 @@ def convert(args): else: logging.error("incorrect model specification.") return - + # check mapping mapping = hf_to_nemo_mapping(hf_tree) if len(mapping) != len(hf_ckpt.keys()): @@ -421,6 +446,7 @@ def convert(args): torch.save(nemo_ckpt, args.output_path) logging.info(f"Saved nemo file to {args.output_path}") + if __name__ == '__main__': args = get_args() convert(args)