diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_train.py b/examples/multimodal/text_to_image/stable_diffusion/sd_train.py index 7a25840e8a65..7e151699b38c 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_train.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_train.py @@ -74,7 +74,11 @@ def main(cfg) -> None: n, c, h = cfg.model.micro_batch_size, cfg.model.channels, cfg.model.image_size x = torch.randn((n, c, h, h), dtype=torch.float32, device="cuda") t = torch.randint(77, (n,), device="cuda") - cc = torch.randn((n, 77, cfg.model.unet_config.context_dim), dtype=torch.float32, device="cuda",) + cc = torch.randn( + (n, 77, cfg.model.unet_config.context_dim), + dtype=torch.float32, + device="cuda", + ) if cfg.model.precision in [16, '16']: x = x.type(torch.float16) cc = cc.type(torch.float16) diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py index 722abb693c24..44412aee0d14 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py @@ -41,7 +41,10 @@ def _training_strategy(self) -> NLPDDPStrategy: _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive) if _IS_INTERACTIVE and self.cfg.trainer.devices == 1: logging.info("Detected interactive environment, using NLPDDPStrategyNotebook") - return NLPDDPStrategyNotebook(no_ddp_communication_hook=True, find_unused_parameters=False,) + return NLPDDPStrategyNotebook( + no_ddp_communication_hook=True, + find_unused_parameters=False, + ) if self.cfg.model.get('fsdp', False): assert ( diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py index edd8ad9cff96..755588202ef0 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py @@ -119,7 +119,9 @@ def __init__(self, cfg, model_parallel_config): self._init_first_stage(first_stage_config) self.model_type = None - self.rng = torch.Generator(device=torch.cuda.current_device(),) + self.rng = torch.Generator( + device=torch.cuda.current_device(), + ) self.use_ema = False # TODO use_ema need to switch to NeMo style if self.use_ema: @@ -192,7 +194,12 @@ def training_step(self, batch, batch_idx): self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False) self.log( - "global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False, + "global_step", + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, ) if self.scheduler_config is not None: @@ -238,7 +245,11 @@ def configure_optimizers(self): scheduler = DiffusionEngine.from_config_dict(self.scheduler_config) print("Setting up LambdaLR scheduler...") scheduler = [ - {"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1,} + { + "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + } ] return [opt], scheduler return opt @@ -298,7 +309,14 @@ def set_input_tensor(self, input_tensor): pass @torch.no_grad() - def log_images(self, batch: Dict, N: int = 8, sample: bool = True, ucg_keys: List[str] = None, **kwargs,) -> Dict: + def log_images( + self, + batch: Dict, + N: int = 8, + sample: bool = True, + ucg_keys: List[str] = None, + **kwargs, + ) -> Dict: conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] if ucg_keys: assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( @@ -312,7 +330,8 @@ def log_images(self, batch: Dict, N: int = 8, sample: bool = True, ucg_keys: Lis x = self.get_input(batch) c, uc = self.conditioner.get_unconditional_conditioning( - batch, force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [], + batch, + force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [], ) sampling_kwargs = {} @@ -407,7 +426,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): # handle asynchronous grad reduction no_sync_func = None if not forward_only and self.with_distributed_adam: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) # pipeline schedules will get these from self.model.config for module in self.get_module_list(): @@ -445,12 +467,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): def training_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - Batch should be a list of microbatches and those microbatches should on CPU. - Microbatches are then moved to GPU during the pipeline. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ self._optimizer.zero_grad() @@ -498,20 +520,20 @@ def training_step(self, dataloader_iter): return loss_mean def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. """ pass def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. """ pass def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) @@ -524,12 +546,13 @@ def _append_sequence_parallel_module_grads(self, module, grads): def get_forward_output_and_loss_func(self): def process_batch(batch): - """ Prepares the global batch for apex fwd/bwd functions. - Global batch is a list of micro batches. + """Prepares the global batch for apex fwd/bwd functions. + Global batch is a list of micro batches. """ # SD has more dedicated structure for encoding, so we enable autocasting here as well with torch.cuda.amp.autocast( - self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + self.autocast_dtype in (torch.half, torch.bfloat16), + dtype=self.autocast_dtype, ): if self.model.precache_mode == 'both': x = batch[self.model.input_key].to(torch.cuda.current_device()) @@ -572,7 +595,7 @@ def validation_step(self, dataloader_iter, batch_idx): return loss def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. + """PTL hook that is executed after DDP spawns. We setup datasets here as megatron datasets require DDP to instantiate. See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. Args: @@ -685,20 +708,23 @@ def setup_test_data(self, cfg): f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' ) self._test_dl = torch.utils.data.DataLoader( - self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True, + self._test_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, ) def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. """ if self.trainer.accumulate_grad_batches > 1: raise ValueError( diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py index 13d1196a156a..d79d85c2e026 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py @@ -358,6 +358,7 @@ def __init__( def _state_key_mapping(self, state_dict: dict): import re + res_dict = {} key_list = state_dict.keys() key_str = " ".join(key_list) @@ -397,7 +398,7 @@ def _state_key_mapping(self, state_dict: dict): res_dict[key_] = val_ return res_dict - def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo = False): + def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo=False): if not from_NeMo: state_dict = self._state_key_mapping(state_dict) model_state_dict = self.state_dict() @@ -408,7 +409,10 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from unexpected_keys = list(set(loaded_keys) - set(expected_keys)) def _find_mismatched_keys( - state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes, + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, ): mismatched_keys = [] if ignore_mismatched_sizes: @@ -443,7 +447,10 @@ def _find_mismatched_keys( if state_dict is not None: # Whole checkpoint mismatched_keys = _find_mismatched_keys( - state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes, + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, ) error_msgs = self._load_state_dict_into_model(state_dict) return missing_keys, unexpected_keys, mismatched_keys, error_msgs diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index 35eab9df25b3..eb449c5406b9 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -958,6 +958,7 @@ def __init__( logging.info(f"Attempting to load pretrained unet from {from_pretrained}") if from_pretrained.endswith('safetensors'): from safetensors.torch import load_file as load_safetensors + state_dict = load_safetensors(from_pretrained) else: state_dict = torch.load(from_pretrained, map_location='cpu') diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py index 1afec7a3b78b..bfae8790eeb2 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py @@ -47,7 +47,12 @@ def __init__( ): self.num_steps = num_steps self.discretization = instantiate_from_config(discretization_config) - self.guider = instantiate_from_config(default(guider_config, DEFAULT_GUIDER,)) + self.guider = instantiate_from_config( + default( + guider_config, + DEFAULT_GUIDER, + ) + ) self.verbose = verbose self.device = device @@ -103,22 +108,22 @@ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0, sigma_hat = sigma * (gamma + 1.0) if gamma > 0: eps = torch.randn_like(x) * self.s_noise - x = x + eps * append_dims(sigma_hat ** 2 - sigma ** 2, x.ndim) ** 0.5 + x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) # this is the noise (e_t) d = to_d(x, sigma_hat, denoised) dt = append_dims(next_sigma - sigma_hat, x.ndim) - euler_step = self.euler_step(x, d, dt) # this is x_{t-\delta{t}} + euler_step = self.euler_step(x, d, dt) # this is x_{t-\delta{t}} x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) if return_noise: return x, d return x - + def get_gamma(self, sigmas, num_sigmas, index): gamma = ( - min(self.s_churn / (num_sigmas - 1), 2 ** 0.5 - 1) if self.s_tmin <= sigmas[index] <= self.s_tmax else 0.0 + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[index] <= self.s_tmax else 0.0 ) return gamma @@ -128,9 +133,17 @@ def __call__(self, denoiser, x, cond, uc=None, num_steps=None): for i in self.get_sigma_gen(num_sigmas): gamma = self.get_gamma(sigmas, num_sigmas, i) - x = self.sampler_step(s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc, gamma,) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, + ) return x - + class AncestralSampler(SingleStepDiffusionSampler): def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): @@ -158,14 +171,24 @@ def __call__(self, denoiser, x, cond, uc=None, num_steps=None): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) for i in self.get_sigma_gen(num_sigmas): - x = self.sampler_step(s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc,) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + ) return x class LinearMultistepSampler(BaseDiffusionSampler): def __init__( - self, order=4, *args, **kwargs, + self, + order=4, + *args, + **kwargs, ): super().__init__(*args, **kwargs) @@ -283,7 +306,15 @@ def get_mult(self, h, r, t, t_next, previous_sigma): return mult1, mult2 def sampler_step( - self, old_denoised, previous_sigma, sigma, next_sigma, denoiser, x, cond, uc=None, + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, ): denoised = self.denoise(x, denoiser, sigma, cond, uc) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py index 0299094cbfe1..24e2124e6f83 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py @@ -39,5 +39,9 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch. x = torch.cat((x, c.get("concat")), dim=1) return self.diffusion_model( - x, timesteps=t, context=c.get("crossattn", None), y=c.get("vector", None), **kwargs, + x, + timesteps=t, + context=c.get("crossattn", None), + y=c.get("vector", None), + **kwargs, ) diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py index 21c37c73fbcb..5a01e8702a9e 100644 --- a/nemo/collections/multimodal/parts/utils.py +++ b/nemo/collections/multimodal/parts/utils.py @@ -23,11 +23,11 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins.environments import TorchElasticEnvironment from transformers import CLIPImageProcessor, SiglipImageProcessor -from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform +from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform from nemo.collections.multimodal.data.neva.neva_dataset import process_image from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel -from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector, NLPFSDPStrategy +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPFSDPStrategy, NLPSaveRestoreConnector from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.utils import AppState, logging @@ -285,13 +285,13 @@ def setup_trainer_and_model_for_inference( else: logging.info("Using FSDP strategy.") strategy = NLPFSDPStrategy( - limit_all_gathers=cfg.model.get('fsdp_limit_all_gathers', True), - sharding_strategy=cfg.model.get('fsdp_sharding_strategy', 'full'), - cpu_offload=cfg.model.get('fsdp_cpu_offload', True), - grad_reduce_dtype=cfg.model.get('fsdp_grad_reduce_dtype', 32), - precision=cfg.trainer.precision, - # use_orig_params=cfg.model.inductor, - set_buffer_dtype=cfg.get('fsdp_set_buffer_dtype', None), + limit_all_gathers=cfg.model.get('fsdp_limit_all_gathers', True), + sharding_strategy=cfg.model.get('fsdp_sharding_strategy', 'full'), + cpu_offload=cfg.model.get('fsdp_cpu_offload', True), + grad_reduce_dtype=cfg.model.get('fsdp_grad_reduce_dtype', 32), + precision=cfg.trainer.precision, + # use_orig_params=cfg.model.inductor, + set_buffer_dtype=cfg.get('fsdp_set_buffer_dtype', None), ) # Set up the trainer with the specified plugins and strategy. diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index f2ecc516ba37..7b5d02c86bf7 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -355,7 +355,7 @@ def get_enabled_adapters(self) -> List[str]: # Skip the global adapter config if name == self.adapter_global_cfg_key: continue - + # If name is in the current available modules, and it is enabled in the config if name in available_module_names and self.adapter_cfg[name]['enabled']: # Check if type is supported (if available) and is an enabled adapter @@ -390,16 +390,15 @@ def get_adapter_module(self, name: str): if hasattr(self, "adapter_layer"): return self.adapter_layer[name] if name in self.adapter_layer else None return None - + def get_adapter_cfg(self, name: str): - """ Same logic as `get_adapter_module` but to get the config """ + """Same logic as `get_adapter_module` but to get the config""" _, name = self.resolve_adapter_module_name_(name) if hasattr(self, "adapter_cfg"): return self.adapter_cfg[name] if name in self.adapter_cfg else None return None - def set_accepted_adapter_types(self, adapter_types: List[Union[type, str]]) -> None: """ The module with this mixin can define a list of adapter names that it will accept. diff --git a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py index 0c49215f1ebb..67bc975708d0 100644 --- a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py @@ -20,37 +20,43 @@ --output_path """ -import torch -import numpy as np -import safetensors from argparse import ArgumentParser +import numpy as np +import safetensors import torch import torch.nn + from nemo.utils import logging intkey = lambda x: int(x) + def filter_keys(rule, dict): keys = list(dict.keys()) nd = {k: dict[k] for k in keys if rule(k)} return nd + def map_keys(rule, dict): new = {rule(k): v for k, v in dict.items()} return new + def split_name(name, dots=0): l = name.split(".") - return ".".join(l[:dots+1]), ".".join(l[dots+1:]) + return ".".join(l[: dots + 1]), ".".join(l[dots + 1 :]) + def is_prefix(shortstr, longstr): # is the first string a prefix of the second one return longstr == shortstr or longstr.startswith(shortstr + ".") + def numdots(str): return str.count(".") + class SegTree: def __init__(self): self.nodes = dict() @@ -60,10 +66,10 @@ def __init__(self): def __len__(self): return len(self.nodes) - + def is_leaf(self): return len(self.nodes) == 0 - + def add(self, name, val=0): prefix, subname = split_name(name) if subname == '': @@ -73,10 +79,10 @@ def add(self, name, val=0): if self.nodes.get(prefix) is None: self.nodes[prefix] = SegTree() self.nodes[prefix].add(subname, val) - + def change(self, name, val): self.add(name, val) - + def __getitem__(self, name: str): if hasattr(self, name): return getattr(self, name) @@ -97,6 +103,7 @@ def __getitem__(self, name: str): return self.nodes[prefix][substr] return val + def model_to_tree(model): keys = list(model.keys()) tree = SegTree() @@ -104,6 +111,7 @@ def model_to_tree(model): tree.add(k, "leaf") return tree + def get_args(): parser = ArgumentParser() parser.add_argument( @@ -117,24 +125,27 @@ def get_args(): parser.add_argument("--precision", type=str, default="32", help="Model precision") parser.add_argument("--model", type=str, default="unet", required=True, choices=['unet', 'vae']) parser.add_argument("--debug", action='store_true', help="Useful for debugging purposes.") - + args = parser.parse_args() return args + def make_tiny_config(config): - ''' dial down the config file to make things tractable ''' + '''dial down the config file to make things tractable''' # TODO return config + def load_hf_ckpt(in_dir, args): ckpt = {} with safetensors.safe_open(in_dir + "/diffusion_pytorch_model.safetensors", framework="pt") as f: for k in f.keys(): ckpt[k] = f.get_tensor(k) - return args, ckpt + return args, ckpt + def dup_convert_name_recursive(tree: SegTree, convert_name=None): - ''' inside this tree, convert all nodes recursively + '''inside this tree, convert all nodes recursively optionally, convert the name of the root as given by name (if not None) ''' if tree is None: @@ -145,6 +156,7 @@ def dup_convert_name_recursive(tree: SegTree, convert_name=None): for k, v in tree.nodes.items(): dup_convert_name_recursive(v, k) + def sanity_check(hf_tree, hf_unet, nemo_unet): # check if i'm introducing new keys for hfk, nk in hf_to_nemo_mapping(hf_tree).items(): @@ -153,8 +165,9 @@ def sanity_check(hf_tree, hf_unet, nemo_unet): if hfk not in hf_unet.keys(): print(hfk) + def convert_input_keys(hf_tree: SegTree): - ''' map the input blocks of huggingface model ''' + '''map the input blocks of huggingface model''' # map `conv_in` to first input block dup_convert_name_recursive(hf_tree['conv_in'], 'input_blocks.0.0') @@ -169,7 +182,7 @@ def convert_input_keys(hf_tree: SegTree): attentions = block.nodes.get('attentions', SegTree()) downsamplers = block.nodes.get('downsamplers', SegTree()) - if len(attentions) == 0: # no attentions, this is a DownBlock2d + if len(attentions) == 0: # no attentions, this is a DownBlock2d for resid in sorted(list(resnets.nodes.keys()), key=intkey): resid = str(resid) resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0" @@ -194,16 +207,19 @@ def convert_input_keys(hf_tree: SegTree): dup_convert_name_recursive(downsamplers[k]['conv'], 'op') nemo_inp_blk += 1 + def clean_convert_names(tree): tree.convert_name = None for k, v in tree.nodes.items(): clean_convert_names(v) + def map_attention_block(att_tree: SegTree): - ''' this HF tree can either be an AttentionBlock or a DualAttention block + '''this HF tree can either be an AttentionBlock or a DualAttention block currently assumed AttentionBlock ''' + # TODO (rohit): Add check for dual attention block def check_att_type(tree): return "att_block" @@ -229,8 +245,9 @@ def check_att_type(tree): else: logging.warning("failed to identify type of attention block here.") + def map_resnet_block(resnet_tree: SegTree): - ''' this HF tree is supposed to have all the keys for a resnet ''' + '''this HF tree is supposed to have all the keys for a resnet''' dup_convert_name_recursive(resnet_tree.nodes.get('time_emb_proj'), 'emb_layers.1') dup_convert_name_recursive(resnet_tree['norm1'], 'in_layers.0') dup_convert_name_recursive(resnet_tree['conv1'], 'in_layers.1') @@ -238,6 +255,7 @@ def map_resnet_block(resnet_tree: SegTree): dup_convert_name_recursive(resnet_tree['conv2'], 'out_layers.2') dup_convert_name_recursive(resnet_tree.nodes.get('conv_shortcut'), 'skip_connection') + def hf_to_nemo_mapping(tree: SegTree): mapping = {} for nodename, subtree in tree.nodes.items(): @@ -251,6 +269,7 @@ def hf_to_nemo_mapping(tree: SegTree): mapping[nodename + "." + k] = convert_name + v return mapping + def convert_cond_keys(tree: SegTree): # map all conditioning keys tree['add_embedding'].convert_name = 'label_emb.0' @@ -260,8 +279,9 @@ def convert_cond_keys(tree: SegTree): dup_convert_name_recursive(tree['time_embedding.linear_1'], '0') dup_convert_name_recursive(tree['time_embedding.linear_2'], '2') + def convert_middle_keys(tree: SegTree): - ''' middle block is fixed (resnet -> attention -> resnet) ''' + '''middle block is fixed (resnet -> attention -> resnet)''' mid = tree['mid_block'] resnets = mid['resnets'] attns = mid['attentions'] @@ -273,8 +293,9 @@ def convert_middle_keys(tree: SegTree): map_resnet_block(resnets['1']) map_attention_block(attns['0']) + def convert_output_keys(hf_tree: SegTree): - ''' output keys is similar to input keys ''' + '''output keys is similar to input keys''' nemo_inp_blk = 0 up_blocks = hf_tree['up_blocks'] up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=intkey) @@ -286,7 +307,7 @@ def convert_output_keys(hf_tree: SegTree): attentions = block.nodes.get('attentions', SegTree()) upsamplers = block.nodes.get('upsamplers', SegTree()) - if len(attentions) == 0: # no attentions, this is a DownBlock2d + if len(attentions) == 0: # no attentions, this is a DownBlock2d for resid in sorted(list(resnets.nodes.keys()), key=intkey): resid = str(resid) resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0" @@ -313,10 +334,12 @@ def convert_output_keys(hf_tree: SegTree): dup_convert_name_recursive(upsamplers['0.conv'], 'conv') nemo_inp_blk += 1 + def convert_finalout_keys(hf_tree: SegTree): dup_convert_name_recursive(hf_tree['conv_norm_out'], "out.0") dup_convert_name_recursive(hf_tree['conv_out'], "out.1") + def convert_encoder(hf_tree: SegTree): encoder = hf_tree['encoder'] encoder.convert_name = 'encoder' @@ -372,7 +395,7 @@ def convert_decoder(hf_tree: SegTree): dup_convert_name_recursive(att['to_v'], 'v') dup_convert_name_recursive(att['to_out.0'], 'proj_out') - # up blocks contain resnets and upsamplers + # up blocks contain resnets and upsamplers decoder['up_blocks'].convert_name = 'up' num_up_blocks = len(decoder['up_blocks']) for upid, upblock in decoder['up_blocks'].nodes.items(): @@ -411,7 +434,7 @@ def convert(args): else: logging.error("incorrect model specification.") return - + # check mapping mapping = hf_to_nemo_mapping(hf_tree) if len(mapping) != len(hf_ckpt.keys()): @@ -423,6 +446,7 @@ def convert(args): torch.save(nemo_ckpt, args.output_path) logging.info(f"Saved nemo file to {args.output_path}") + if __name__ == '__main__': args = get_args() - convert(args) \ No newline at end of file + convert(args)