diff --git a/train/SDXL/config.yaml b/train/SDXL/config.yaml
new file mode 100644
index 0000000..c4587bc
--- /dev/null
+++ b/train/SDXL/config.yaml
@@ -0,0 +1,17 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+
diff --git a/train/SDXL/dataset.py b/train/SDXL/dataset.py
new file mode 100644
index 0000000..94141dd
--- /dev/null
+++ b/train/SDXL/dataset.py
@@ -0,0 +1,264 @@
+import torch
+import itertools
+import json
+import math
+from typing import Iterable, List, Optional, Union
+from torch.utils.data.dataloader import _collate_fn_t, _worker_init_fn_t
+
+import torchvision.transforms.functional as TF
+from torchvision.transforms import RandomHorizontalFlip
+from torch.utils.data import Dataset, Sampler, default_collate
+from torchvision import transforms
+
+import webdataset as wds
+from braceexpand import braceexpand
+from webdataset.tariterators import (
+    base_plus_ext,
+    tar_file_expander,
+    url_opener,
+    valid_sample,
+)
+import os
+
+from torch.utils.data import DataLoader
+# from diffusers.training_utils import resolve_interpolation_mode
+
+import chardet
+import random
+
+# Adjust for your dataset
+WDS_JSON_WIDTH = "width"    # original_width for LAION
+WDS_JSON_HEIGHT = "height"  # original_height for LAION
+MIN_SIZE = 512  # ~960 for LAION, ideal: 1024 if the dataset contains large images
+
+
+# from torchvision.transforms import ToPILImage
+from diffusers.utils import make_image_grid
+
+def resolve_interpolation_mode(interpolation_type: str):
+    """
+    Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
+    full list of supported enums is documented at
+    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.
+
+    Args:
+        interpolation_type (`str`):
+            A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`,
+            `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes
+            in torchvision.
+
+    Returns:
+        `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
+        transform.
+    """
+    # if not is_torchvision_available():
+    #     raise ImportError(
+    #         "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
+    #     )
+
+    if interpolation_type == "bilinear":
+        interpolation_mode = transforms.InterpolationMode.BILINEAR
+    elif interpolation_type == "bicubic":
+        interpolation_mode = transforms.InterpolationMode.BICUBIC
+    elif interpolation_type == "box":
+        interpolation_mode = transforms.InterpolationMode.BOX
+    elif interpolation_type == "nearest":
+        interpolation_mode = transforms.InterpolationMode.NEAREST
+    elif interpolation_type == "nearest_exact":
+        interpolation_mode = transforms.InterpolationMode.NEAREST_EXACT
+    elif interpolation_type == "hamming":
+        interpolation_mode = transforms.InterpolationMode.HAMMING
+    elif interpolation_type == "lanczos":
+        interpolation_mode = transforms.InterpolationMode.LANCZOS
+    else:
+        raise ValueError(
+            f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
+            f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        )
+
+    return interpolation_mode
+
+
+def filter_keys(key_set):
+    def _f(dictionary):
+        return {k: v for k, v in dictionary.items() if k in key_set}
+
+    return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+    """Return function over iterator that groups key, value pairs into samples.
+
+    :param keys: function that splits the key into key and extension (base_plus_ext)
+    :param lcase: convert suffixes to lower case (Default value = True)
+    """
+    current_sample = None
+    for filesample in data:
+        assert isinstance(filesample, dict)
+        fname, value = filesample["fname"], filesample["data"]
+        prefix, suffix = keys(fname)
+        if prefix is None:
+            continue
+        if lcase:
+            suffix = suffix.lower()
+        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+        # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+        # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
+        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+            if valid_sample(current_sample):
+                yield current_sample
+            current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+        if suffixes is None or suffix in suffixes:
+            current_sample[suffix] = value
+    if valid_sample(current_sample):
+        yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+    streams = url_opener(src, handler=handler)
+    files = tar_file_expander(streams, handler=handler)
+    samples = group_by_keys_nothrow(files, handler=handler)
+    return samples
+
+
+class WebdatasetFilter:
+    def __init__(self, min_size=MIN_SIZE, max_pwatermark=0.5):
+        self.min_size = min_size
+        self.max_pwatermark = max_pwatermark
+
+    def __call__(self, x):
+        try:
+            if "json" in x:
+                x_json = json.loads(x["json"])
+                filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
+                    WDS_JSON_HEIGHT, 0
+                ) >= self.min_size
+                filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
+                return filter_size and filter_watermark
+            else:
+                return False
+        except Exception:
+            return False
+
+
+class SDXLText2ImageDataset:
+    def __init__(
+        self,
+        train_shards_path_or_url: Union[str, List[str]],
+        num_train_examples: int,
+        per_gpu_batch_size: int,
+        global_batch_size: int,
+        num_workers: int,
+        resolution: int = 1024,
+        interpolation_type: str = "bilinear",
+        shuffle_buffer_size: int = 1000,
+        pin_memory: bool = False,
+        persistent_workers: bool = False,
+        use_fix_crop_and_size: bool = False,
+        random_flip: bool = False,
+    ):
+        # train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+        if os.path.isdir(train_shards_path_or_url):
+            tar_list = os.listdir(train_shards_path_or_url)
+            temp_train_shards_path_or_url = [os.path.join(train_shards_path_or_url, f) for f in tar_list if f.endswith('.tar')]  # and f.startswith('cog_')
+            train_shards_path_or_url = temp_train_shards_path_or_url
+            # flatten list using itertools
+            # train_shards_path_or_url = list(itertools.chain.from_iterable(temp_train_shards_path_or_url))
+
+        def get_orig_size(json):
+            if use_fix_crop_and_size:
+                return (resolution, resolution)
+            else:
+                return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
+
+        interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
+        def transform(example):
+            # resize image
+            image = example["image"]
+            image = TF.resize(image, resolution, interpolation=interpolation_mode)
+
+            # random flip
+            if random_flip and random.random() < 0.5:
+                image = TF.hflip(image)
+            example["orig_size"] = (image.size[1], image.size[0])
+
+            # get crop coordinates and crop image
+            c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+            image = TF.crop(image, c_top, c_left, resolution, resolution)
+
+            image = TF.to_tensor(image)
+            image = TF.normalize(image, [0.5], [0.5])
+
+            example["image"] = image
+            example["crop_coords"] = (c_top, c_left) if not use_fix_crop_and_size else (0, 0)
+            return example
+
+        def decode(example):
+            try:
+                text = example["text"]
+                text_encoding = chardet.detect(text)["encoding"]
+                example["text"] = text.decode(text_encoding)
+                # print(example["text"])
+            except Exception:
+                example['text'] = ""
+            return example
+
+        processing_pipeline = [
+            wds.decode("pil", handler=wds.ignore_and_continue),
+            wds.rename(
+                image="jpg;png;jpeg;webp", text="text;txt;caption;prompt", handler=wds.warn_and_stop
+            ),
+            wds.map(filter_keys({"image", "text"})),
+            wds.map(transform),
+            wds.map(decode),
+            wds.to_tuple("image", "text", "orig_size", "crop_coords"),
+        ]
+
+        # Create train dataset and loader
+        pipeline = [
+            wds.ResampledShards(train_shards_path_or_url),
+            tarfile_to_samples_nothrow,
+            # wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
+            wds.shuffle(shuffle_buffer_size),
+            *processing_pipeline,
+            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+        ]
+
+        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker
+        num_batches = num_worker_batches * num_workers
+        num_samples = num_batches * global_batch_size
+
+        # each worker is iterating over this
+        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+
+        self._train_dataloader = wds.WebLoader(
+            self._train_dataset,
+            batch_size=None,
+            shuffle=False,
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            persistent_workers=persistent_workers,
+        )
+
+        # add meta-data to dataloader instance for convenience
+        self._train_dataloader.num_batches = num_batches
+        self._train_dataloader.num_samples = num_samples
+
+    @property
+    def train_dataset(self):
+        return self._train_dataset
+
+    @property
+    def train_dataloader(self):
+        return self._train_dataloader
+
diff --git a/train/SDXL/pcm_discriminator_sdxl.py b/train/SDXL/pcm_discriminator_sdxl.py
new file mode 100644
index 0000000..4cf12e5
--- /dev/null
+++ b/train/SDXL/pcm_discriminator_sdxl.py
@@ -0,0 +1,539 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from typing import Union, Optional, Dict, Any, Tuple
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    BaseOutput,
+    deprecate,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+
+
+def modified_forward(
+    self,
+    sample: torch.FloatTensor,
+    timestep: Union[torch.Tensor, float, int],
+    encoder_hidden_states: torch.Tensor,
+    class_labels: Optional[torch.Tensor] = None,
+    timestep_cond: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+    down_block_additional_residuals:
Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, +): + r""" + The [`UNet2DConditionModel`] forward method. + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. 
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(sample.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + t_emb = self.time_proj(timesteps) + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. 
so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + emb = emb + aug_emb if aug_emb is not None else emb + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + if ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "text_proj" + ): + encoder_hidden_states = 
self.encoder_hid_proj(encoder_hidden_states) + elif ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "text_image_proj" + ): + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj( + encoder_hidden_states, image_embeds + ) + elif ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "image_proj" + ): + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "ip_image_proj" + ): + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + # 2. pre-process + sample = self.conv_in(sample) + # 2.5 GLIGEN position net + if ( + cross_attention_kwargs is not None + and cross_attention_kwargs.get("gligen", None) is not None + ): + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + # 3. down + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + + down_block_res_samples = (sample,) + + output_features = [] + + for downsample_block in self.down_blocks: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, temb=emb, scale=lora_scale + ) + + output_features.append(sample) + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + if ( + hasattr(self.mid_block, "has_cross_attention") + and self.mid_block.has_cross_attention + ): + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + output_features.append(sample) + + return output_features # do not use up blocks to save memory + # 5. 
up
+    for i, upsample_block in enumerate(self.up_blocks):
+        is_final_block = i == len(self.up_blocks) - 1
+        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+        # if we have not reached the final block and need to forward the
+        # upsample size, we do it here
+        if not is_final_block and forward_upsample_size:
+            upsample_size = down_block_res_samples[-1].shape[2:]
+        if (
+            hasattr(upsample_block, "has_cross_attention")
+            and upsample_block.has_cross_attention
+        ):
+            sample = upsample_block(
+                hidden_states=sample,
+                temb=emb,
+                res_hidden_states_tuple=res_samples,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                upsample_size=upsample_size,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+            )
+        else:
+            sample = upsample_block(
+                hidden_states=sample,
+                temb=emb,
+                res_hidden_states_tuple=res_samples,
+                upsample_size=upsample_size,
+                scale=lora_scale,
+            )
+        output_features.append(sample)
+    # 6. post-process
+
+    return output_features
+
+
+class DiscriminatorHead(nn.Module):
+    def __init__(self, input_channel, output_channel=1):
+        super().__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(input_channel, input_channel, 1, 1, 0),  # 1x1 to save memory
+            nn.GroupNorm(32, input_channel),
+            nn.LeakyReLU(inplace=True),
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(input_channel, input_channel, 1, 1, 0),
+            nn.GroupNorm(32, input_channel),
+            nn.LeakyReLU(inplace=True),
+        )
+
+        self.conv_out = nn.Conv2d(input_channel, output_channel, 1, 1, 0)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x) + x
+        x = self.conv_out(x)
+        return x
+
+
+class Discriminator(nn.Module):
+
+    def __init__(
+        self,
+        unet,
+        num_h_per_head=1,
+        adapter_channel_dims=[
+            320,
+            640,
+            1280,
+            1280,
+            # 1280,
+            # 640,
+            # 320
+            # do not use up blocks to save memory
+        ],
+    ):
+        super().__init__()
+        self.unet = unet
+        self.num_h_per_head = num_h_per_head
+        self.head_num = len(adapter_channel_dims)
+        self.heads = nn.ModuleList(
+            [
+                nn.ModuleList(
+                    [
+                        DiscriminatorHead(adapter_channel)
+                        for _ in range(self.num_h_per_head)
+                    ]
+                )
+                for adapter_channel in adapter_channel_dims
+            ]
+        )
+
+    def _forward(self, sample, timestep, encoder_hidden_states, added_cond_kwargs):
+        features = modified_forward(
+            self.unet,
+            sample,
+            timestep,
+            encoder_hidden_states,
+            added_cond_kwargs=added_cond_kwargs,
+        )
+        assert self.head_num == len(features)
+        outputs = []
+        for feature, head in zip(features, self.heads):
+            for h in head:
+                outputs.append(h(feature))
+        return outputs
+
+    def forward(self, flag, *args):
+        if flag == "d_loss":
+            return self.d_loss(*args)
+        elif flag == "g_loss":
+            return self.g_loss(*args)
+        else:
+            assert 0, "not supported"
+
+    def d_loss(
+        self,
+        sample_fake,
+        sample_real,
+        timestep,
+        encoder_hidden_states,
+        added_cond_kwargs,
+        weight,
+    ):
+        loss = 0.0
+        fake_outputs = self._forward(
+            sample_fake.detach(), timestep, encoder_hidden_states, added_cond_kwargs
+        )
+        real_outputs = self._forward(
+            sample_real.detach(), timestep, encoder_hidden_states, added_cond_kwargs
+        )
+        for fake_output, real_output in zip(fake_outputs, real_outputs):
+            loss += (
+                torch.mean(weight * torch.relu(fake_output.float() + 1))
+                + torch.mean(weight * torch.relu(1 - real_output.float()))
+            ) / (self.head_num * self.num_h_per_head)
+        return loss
+
+    def g_loss(
+        self, sample_fake, timestep, encoder_hidden_states, added_cond_kwargs, weight
+    ):
+        loss = 0.0
+        fake_outputs = self._forward(
+            sample_fake, timestep, encoder_hidden_states, added_cond_kwargs
+        )
+        for fake_output in fake_outputs:
+            loss += torch.mean(weight * torch.relu(1 - fake_output.float())) / (
+                self.head_num * self.num_h_per_head
+            )
+        return loss
+
+    def match_loss(
+        self,
+        sample_fake,
+        sample_real,
+        timestep,
+        encoder_hidden_states,
+        added_cond_kwargs,
+        weight,
+    ):
+        loss = 0.0
+        features_fake = self._forward(
+            sample_fake, timestep, encoder_hidden_states, added_cond_kwargs
+        )
+        with torch.no_grad():
+            features_real = self._forward(
+                sample_real.detach(), timestep, encoder_hidden_states, added_cond_kwargs
+            )
+        for feature_fake, feature_real in zip(features_fake, features_real):
+            loss += torch.mean((feature_fake - feature_real.detach()) ** 2) / (
+                self.head_num * self.num_h_per_head
+            )
+        return loss
+
+    def feature_loss(
+        self,
+        sample_fake,
+        sample_real,
+        timestep,
+        encoder_hidden_states,
+        added_cond_kwargs,
+        weight,
+    ):
+        loss = 0.0
+        features_fake = modified_forward(
+            self.unet,
+            sample_fake,
+            timestep,
+            encoder_hidden_states,
+            added_cond_kwargs=added_cond_kwargs,
+        )
+        features_real = modified_forward(
+            self.unet,
+            sample_real.detach(),
+            timestep,
+            encoder_hidden_states,
+            added_cond_kwargs=added_cond_kwargs,
+        )
+        for feature_fake, feature_real in zip(features_fake, features_real):
+            loss += torch.mean((feature_fake - feature_real.detach()) ** 2) / (
+                self.head_num
+            )
+        return loss
+
+
+if __name__ == "__main__":
+    teacher_unet = UNet2DConditionModel.from_pretrained(
+        "stable-diffusion-xl-base-1.0",
+        subfolder="unet",
+    )
+    teacher_unet.cuda()
+    discriminator = Discriminator(teacher_unet).cuda()
+    sample = torch.randn((1, 4, 128, 128)).cuda()
+    timestep = torch.randn((1,)).long().cuda()
+    prompt_embeds = torch.randn((1, 77, 2048)).cuda()
+    add_text_embeds = torch.randn((1, 1280)).cuda()
+    add_time_ids = torch.randn((1, 6)).cuda()
+    added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+    features = modified_forward(
+        discriminator.unet,
+        sample,
+        timestep,
+        prompt_embeds,
+        added_cond_kwargs=added_cond_kwargs,
+    )
+    for feature in features:
+        print(feature.shape)
\ No newline at end of file
diff --git a/train/SDXL/pcm_scheduling_ddpm_modified.py b/train/SDXL/pcm_scheduling_ddpm_modified.py
new file mode 100644
index 0000000..94052d7
--- /dev/null
+++ b/train/SDXL/pcm_scheduling_ddpm_modified.py
@@ -0,0 +1,663 @@
+# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, +) + + +@dataclass +class DDPMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. 
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DDPMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + An array of betas to pass directly to the constructor without using `beta_start` and `beta_end`. + variance_type (`str`, defaults to `"fixed_small"`): + Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, + `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + variance_type: str = "fixed_small", + clip_sample: bool = True, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: int = 0, + rescale_betas_zero_snr: int = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace( + beta_start, beta_end, num_train_timesteps, dtype=torch.float32 + ) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "sigmoid": + # GeoDiff sigmoid schedule + betas = torch.linspace(-6, 6, num_train_timesteps) + self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError( + f"{beta_schedule} does is not implemented for {self.__class__}" + ) + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.custom_timesteps = False + self.num_inference_steps = None + self.timesteps = torch.from_numpy( + np.arange(0, num_train_timesteps)[::-1].copy() + ) + + self.variance_type = variance_type + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Optional[int] = None + ) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed, + `num_inference_steps` must be `None`. 
+ + """ + if num_inference_steps is not None and timesteps is not None: + raise ValueError( + "Can only pass one of `num_inference_steps` or `custom_timesteps`." + ) + + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`custom_timesteps` must be in descending order.") + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.custom_timesteps = True + else: + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + self.custom_timesteps = False + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace( + 0, self.config.num_train_timesteps - 1, num_inference_steps + ) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps) * step_ratio) + .round()[::-1] + .copy() + .astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round( + np.arange(self.config.num_train_timesteps, 0, -step_ratio) + ).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_variance(self, t, predicted_variance=None, variance_type=None): + prev_t = self.previous_timestep(t) + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t + + # we always take the log of variance, so clamp it to ensure it's not 0 + variance = torch.clamp(variance, min=1e-20) + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small": + variance = variance + # for rl-diffuser https://arxiv.org/abs/2205.09991 + elif variance_type == "fixed_small_log": + variance = torch.log(variance) + variance = torch.exp(0.5 * variance) + elif variance_type == "fixed_large": + variance = current_beta_t + elif variance_type == "fixed_large_log": + # Glide max_log + variance = torch.log(current_beta_t) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = torch.log(variance) + max_log = torch.log(current_beta_t) + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = ( + sample.float() + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = ( + torch.clamp(sample, -s, s) / s + ) # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[DDPMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + t = timestep + + prev_t = self.previous_timestep(t) + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: + model_output, predicted_variance = torch.split( + model_output, sample.shape[1], dim=1 + ) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + current_alpha_t = alpha_prod_t / alpha_prod_t_prev + current_beta_t = 1 - current_alpha_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = ( + sample - beta_prod_t ** (0.5) * model_output + ) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - ( + beta_prod_t**0.5 + ) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for the DDPMScheduler." + ) + + # 3. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = ( + alpha_prod_t_prev ** (0.5) * current_beta_t + ) / beta_prod_t + current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = ( + pred_original_sample_coeff * pred_original_sample + + current_sample_coeff * sample + ) + + # 6. 
Add noise + variance = 0 + if t > 0: + device = model_output.device + variance_noise = randn_tensor( + model_output.shape, + generator=generator, + device=device, + dtype=model_output.dtype, + ) + if self.variance_type == "fixed_small_log": + variance = ( + self._get_variance(t, predicted_variance=predicted_variance) + * variance_noise + ) + elif self.variance_type == "learned_range": + variance = self._get_variance(t, predicted_variance=predicted_variance) + variance = torch.exp(0.5 * variance) * variance_noise + else: + variance = ( + self._get_variance(t, predicted_variance=predicted_variance) ** 0.5 + ) * variance_noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return DDPMSchedulerOutput( + prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample + ) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = ( + sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + ) + return noisy_samples + + def noise_travel( + self, + current_samples: torch.FloatTensor, + noise: torch.FloatTensor, + current_timesteps: torch.IntTensor, + target_timesteps: torch.IntTensor, + ): + # assert current_timesteps < target_timesteps + alphas_cumprod = self.alphas_cumprod.to( + device=current_samples.device, dtype=current_samples.dtype + ) + target_timesteps = target_timesteps.to(current_samples.device) + current_timesteps = current_timesteps.to(current_samples.device) + alpha_prod_target = alphas_cumprod[target_timesteps] + alpha_prod_target = alpha_prod_target.flatten() + alpha_prod_current = alphas_cumprod[current_timesteps] + alpha_prod_current = alpha_prod_current.flatten() + + alpha_prod = alpha_prod_target / alpha_prod_current + + sqrt_alpha_prod = alpha_prod**0.5 + sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5 + + while len(sqrt_alpha_prod.shape) < len(current_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + while len(sqrt_one_minus_alpha_prod.shape) < len(current_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = ( + sqrt_alpha_prod * current_samples + sqrt_one_minus_alpha_prod * noise + ) + + return noisy_samples + + def get_velocity( + self, + sample: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod 
= self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + + def previous_timestep(self, timestep): + if self.custom_timesteps: + index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] + if index == self.timesteps.shape[0] - 1: + prev_t = torch.tensor(-1) + else: + prev_t = self.timesteps[index + 1] + else: + num_inference_steps = ( + self.num_inference_steps + if self.num_inference_steps + else self.config.num_train_timesteps + ) + prev_t = timestep - self.config.num_train_timesteps // num_inference_steps + + return prev_t \ No newline at end of file diff --git a/train/SDXL/scheduler.py b/train/SDXL/scheduler.py new file mode 100644 index 0000000..3daea56 --- /dev/null +++ b/train/SDXL/scheduler.py @@ -0,0 +1,825 @@ +from diffusers import TCDScheduler, DPMSolverSinglestepScheduler +from diffusers.schedulers.scheduling_tcd import * +from diffusers.schedulers.scheduling_dpmsolver_singlestep import * + +class TDDScheduler(TCDScheduler): + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.3, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[TCDSchedulerOutput, Tuple]: + + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + assert 0 <= eta <= 1.0, "gamma must be less than or equal to 1.0" + + # 1. get previous step value + prev_step_index = self.step_index + 1 + if prev_step_index < len(self.timesteps): + prev_timestep = self.timesteps[prev_step_index] + else: + prev_timestep = torch.tensor(0) + + timestep_s = torch.floor((1 - eta) * prev_timestep).to(dtype=torch.long) + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t_prev = 1 - alpha_prod_t_prev + + alpha_prod_s = self.alphas_cumprod[timestep_s] + beta_prod_s = 1 - alpha_prod_s + + # 3. 
Compute the predicted noised sample x_s based on the model parameterization + # xx + + + if self.step_index == 0: + self.buffer = [None] * 2 + if self.config.prediction_type == "epsilon": # noise-prediction + pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + pred_epsilon = model_output + pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon + elif self.config.prediction_type == "ode_dpmsolver_1": + lambda_t = torch.log(alpha_prod_t.sqrt() / beta_prod_t.sqrt()) + lambda_s = torch.log(alpha_prod_s.sqrt() / beta_prod_s.sqrt()) + h_t = lambda_s -lambda_t + pred_noised_sample = alpha_prod_s.sqrt() / alpha_prod_t.sqrt() * sample - beta_prod_s.sqrt() * (torch.expm1(h_t)) * model_output + elif self.config.prediction_type == "ode_dpmsolver_2M": + print("222") + self.buffer[0] = self.buffer[1] + self.buffer[-1] = model_output + # self.buffer.append(model_output) + if self.step_index > 0: + alpha_prod_t0 = self.alphas_cumprod[self.timesteps[self.step_index - 1]] + beta_prod_t0 = 1 - alpha_prod_t0 + + lambda_t_prev = torch.log(alpha_prod_t_prev.sqrt() / beta_prod_t_prev.sqrt()) + lambda_t = torch.log(alpha_prod_t.sqrt() / beta_prod_t.sqrt()) + lambda_t0 = torch.log(alpha_prod_t0.sqrt() / beta_prod_t0.sqrt()) + + else: + lambda_t_prev = torch.log(alpha_prod_t_prev.sqrt() / beta_prod_t_prev.sqrt()) + lambda_t = torch.log(alpha_prod_t.sqrt() / beta_prod_t.sqrt()) + lambda_t0 = torch.log(alpha_prod_t0.sqrt() / beta_prod_t0.sqrt()) + + mt, mt_prev = self.buffer[-1], self.buffer[-2] + h, h_0 = lambda_t_prev - lambda_t0, lambda_t - lambda_t0 + r0 = h_0 / h + D0, D1 = mt_prev, (1.0 / r0) * (mt - mt_prev) + pred_noised_sample = ( + alpha_prod_t_prev.sqrt() / alpha_prod_t.sqrt() * sample + - (alpha_prod_t_prev.sqrt() * (torch.expm1(h)) * D0) + - 0.5 * (alpha_prod_t_prev.sqrt() * torch.expm1(h)) * D1 + ) + + elif self.config.prediction_type == "ode_dpmsolver++_1": + lambda_t = torch.log(alpha_prod_t.sqrt() / beta_prod_t.sqrt()) + lambda_s = torch.log(alpha_prod_s.sqrt() / beta_prod_s.sqrt()) + h_t = lambda_s -lambda_t + pred_noised_sample = beta_prod_s.sqrt() / beta_prod_t.sqrt() * sample - alpha_prod_s.sqrt() * (torch.expm1(-h_t)) * model_output + + elif self.config.prediction_type == "sample": # x-prediction + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon + elif self.config.prediction_type == "v_prediction": # v-prediction + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for `TCDScheduler`." + ) + + # 4. Sample and inject noise z ~ N(0, I) for MultiStep Inference + # Noise is not used on the final timestep of the timestep schedule. + # This also means that noise is not used for one-step sampling. + # Eta (referred to as "gamma" in the paper) was introduced to control the stochasticity in every step. + # When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling. 
+ if eta > 0: + if self.step_index != self.num_inference_steps - 1: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=pred_noised_sample.dtype + ) + prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * pred_noised_sample + ( + 1 - alpha_prod_t_prev / alpha_prod_s + ).sqrt() * noise + else: + prev_sample = pred_noised_sample + else: + prev_sample = pred_noised_sample + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample, pred_noised_sample) + + return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample) + +class TDDSchedulerPlus(DPMSolverSinglestepScheduler): + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = False, + use_karras_sigmas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + tdd_train_step: int = 50, + ): + self.tdd_train_step = tdd_train_step + if algorithm_type == "dpmsolver": + deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message) + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
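+            # For reference: "scaled_linear" is linear in sqrt(beta), i.e.
+            # beta_i = (sqrt(beta_start) + i / (N - 1) * (sqrt(beta_end) - sqrt(beta_start))) ** 2,
+            # which is what the torch.linspace(...) ** 2 below computes.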
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.sample = None + self.order_list = self.get_order_list(num_train_timesteps) + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + self.num_inference_steps = num_inference_steps + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + #original_steps = self.config.original_inference_steps + if True: + original_steps = self.tdd_train_step + k = 1000 / original_steps + tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps) + 1))) * k - 1 + else: + tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps)))) + # TCD Inference Steps Schedule + tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy() + # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
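+        # Worked example (illustrative, using the default tdd_train_step = 50): k = 1000 / 50 = 20 and
+        # tcd_origin_timesteps = [999, 979, ..., 39, 19] after the flip. For num_inference_steps = 4,
+        # np.linspace(0, 50, num=4, endpoint=False) -> [0.0, 12.5, 25.0, 37.5] -> floor -> [0, 12, 25, 37],
+        # so timesteps = [999, 759, 499, 259].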
+ inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False) + inference_indices = np.floor(inference_indices).astype(np.int64) + timesteps = tcd_origin_timesteps[inference_indices] + + # clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + # timesteps = ( + # np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) + # .round()[::-1][:-1] + # .copy() + # .astype(np.int64) + # ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f" `final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas).to(device=device) + + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + self.model_outputs = [None] * self.config.solver_order + self.sample = None + + if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0: + logger.warning( + f"Changing scheduler {self.config} to have `lower_order_final` set to True to handle an uneven number of inference steps. Please make sure to always use an even number of `num_inference_steps` when using `lower_order_final=False`." + ) + self.register_to_config(lower_order_final=True) + + if not self.config.lower_order_final and self.config.final_sigmas_type == "zero": + logger.warning( + f" `final_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True." + ) + self.register_to_config(lower_order_final=True) + + self.order_list = self.get_order_list(num_inference_steps) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def set_timesteps_s(self, eta: float = 0.0): + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + num_inference_steps = self.num_inference_steps + device = self.timesteps.device + if True: + original_steps = self.tdd_train_step + k = 1000 / original_steps + tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps) + 1))) * k - 1 + else: + tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps)))) + # TCD Inference Steps Schedule + tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy() + # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
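+        # Worked example (illustrative): set_timesteps_s rebuilds the same schedule as set_timesteps and
+        # then shifts it towards x_0 with timesteps_s = floor((1 - eta) * timesteps); e.g. for
+        # timesteps = [999, 759, 499, 259] and eta = 0.2 this gives timesteps_s = [799, 607, 399, 207].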
+ inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False) + inference_indices = np.floor(inference_indices).astype(np.int64) + timesteps = tcd_origin_timesteps[inference_indices] + #clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + # timesteps = ( + # np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) + # .round()[::-1][:-1] + # .copy() + # .astype(np.int64) + # ) + timesteps_s = np.floor((1 - eta) * timesteps).astype(np.int64) + + sigmas_s = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + raise NotImplementedError("`use_karras_sigmas` is not supported in `set_timesteps_s`.") + else: + sigmas_s = np.interp(timesteps_s, np.arange(0, len(sigmas_s)), sigmas_s) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f" `final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}" + ) + + sigmas_s = np.concatenate([sigmas_s, [sigma_last]]).astype(np.float32) + self.sigmas_s = torch.from_numpy(sigmas_s).to(device=device) + self.timesteps_s = torch.from_numpy(timesteps_s).to(device=device, dtype=torch.int64) + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + if self.step_index == 0: + self.set_timesteps_s(eta) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + order = self.order_list[self.step_index] + + # For img2img, denoising might start with order > 1, which is not possible. + # In this case make sure that the first two steps are both order=1 + while self.model_outputs[-order] is None: + order -= 1 + + # For single-step solvers, we use the initial value at each time with order = 1.
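+        # Note: unlike the stock DPMSolverSinglestepScheduler, the solver update below targets the shifted
+        # level sigmas_s[step_index + 1] (see dpm_solver_first_order_update, which mixes sigmas_s and sigmas);
+        # when eta > 0, the result is then re-noised from alpha_prod_s up to alpha_prod_t_prev, mirroring the
+        # stochastic step used in TDDScheduler above.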
+ if order == 1: + self.sample = sample + + prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order) + + if eta > 0: + if self.step_index != self.num_inference_steps - 1: + + alpha_prod_s = self.alphas_cumprod[self.timesteps_s[self.step_index + 1]] + alpha_prod_t_prev = self.alphas_cumprod[self.timesteps[self.step_index + 1]] + + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=prev_sample.dtype + ) + prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * prev_sample + ( + 1 - alpha_prod_t_prev / alpha_prod_s + ).sqrt() * noise + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s = self.sigmas_s[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + return x_t + + def singlestep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas_s[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = 
self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m1, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s1) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s1) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s1) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s1) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + return x_t + + def singlestep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the + time `timestep_list[-3]`. + + Args: + model_output_list (`List[torch.FloatTensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): + The current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. 
+ """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m2 + D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2) + D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s2) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s2) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s2) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s2) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + + def singlestep_dpm_solver_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + order: int = None, + **kwargs, + ) -> torch.FloatTensor: + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing `order` as a required keyward argument") + if timestep_list is not 
None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if order == 1: + return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample) + elif order == 2: + return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample) + elif order == 3: + return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample) + else: + raise ValueError(f"Order must be 1, 2, 3, got {order}") + + def convert_model_output( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type == "dpmsolver++": + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverSinglestepScheduler." + ) + + if self.step_index == 0: + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type == "dpmsolver": + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. 
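+            # For reference: under "dpmsolver" the solver integrates the noise prediction, so an
+            # epsilon-parameterized output is returned as-is, while "sample" and "v_prediction" outputs
+            # are converted to an equivalent epsilon via epsilon = (x_t - alpha_t * x_0) / sigma_t and
+            # epsilon = alpha_t * v + sigma_t * x_t, respectively.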
+ if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] + return model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverSinglestepScheduler." + ) + + +class TDDSchedulerTest(TCDScheduler): + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.3, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[TCDSchedulerOutput, Tuple]: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + assert 0 <= eta <= 1.0, "gamma must be less than or equal to 1.0" + + # 1. get previous step value + prev_step_index = self.step_index + 1 + if prev_step_index < len(self.timesteps): + prev_timestep = self.timesteps[prev_step_index] + else: + prev_timestep = torch.tensor(0) + + timestep_s = torch.floor((1 - eta) * prev_timestep).to(dtype=torch.long) + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + alpha_prod_s = self.alphas_cumprod[timestep_s] + beta_prod_s = 1 - alpha_prod_s + + # 3. Compute the predicted noised sample x_s based on the model parameterization + if self.config.prediction_type == "epsilon": # noise-prediction + pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + pred_epsilon = model_output + pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon + elif self.config.prediction_type == "sample": # x-prediction + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon + elif self.config.prediction_type == "v_prediction": # v-prediction + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for `TCDScheduler`." + ) + + # 4. Sample and inject noise z ~ N(0, I) for MultiStep Inference + # Noise is not used on the final timestep of the timestep schedule. + # This also means that noise is not used for one-step sampling. + # Eta (referred to as "gamma" in the paper) was introduced to control the stochasticity in every step. + # When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling. 
+ if eta > 0: + if self.step_index != self.num_inference_steps - 1: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=pred_noised_sample.dtype + ) + prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * pred_noised_sample + ( + 1 - alpha_prod_t_prev / alpha_prod_s + ).sqrt() * noise + else: + prev_sample = pred_noised_sample + else: + prev_sample = pred_noised_sample + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample, pred_noised_sample) + + return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample) diff --git a/train/SDXL/train_tdd.py b/train/SDXL/train_tdd.py new file mode 100644 index 0000000..4ce05bd --- /dev/null +++ b/train/SDXL/train_tdd.py @@ -0,0 +1,1457 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import functools +import gc +import itertools +import json +import logging +import math +import os +import random +import shutil +from pathlib import Path +from PIL import Image + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo +from packaging import version +from peft import LoraConfig, get_peft_model, get_peft_model_state_dict +from torch.utils.data import default_collate +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig +from webdataset.tariterators import ( + base_plus_ext, + tar_file_expander, + url_opener, + valid_sample, +) + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DDIMScheduler, + TCDScheduler, + HeunDiscreteScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from dataset import SDXLText2ImageDataset + + +MAX_SEQ_LENGTH = 77 + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.18.0.dev0") + +logger = get_logger(__name__) + + +def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"): + kohya_ss_state_dict = {} + for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items(): + kohya_key = peft_key.replace("base_model.model", prefix) + kohya_key = kohya_key.replace("lora_A", "lora_down") + kohya_key = kohya_key.replace("lora_B", "lora_up") + kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) + kohya_ss_state_dict[kohya_key] = weight.to(dtype) + + # Set alpha parameter + if "lora_down" in kohya_key: + alpha_key = f'{kohya_key.split(".")[0]}.alpha' + kohya_ss_state_dict[alpha_key] = torch.tensor( + module.peft_config[adapter_name].lora_alpha + ).to(dtype) + + return kohya_ss_state_dict + + +@torch.no_grad() +def log_validation( + vae, unet, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, args, accelerator, weight_dtype, step, cfg, num_inference_step +): + logger.info("Running validation... ") + + unet = accelerator.unwrap_model(unet) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_teacher_model, + vae=vae, + text_encoder = text_encoder_one, + text_encoder_2 = text_encoder_two, + tokenizer = tokenizer_one, + tokenizer_2 = tokenizer_two, + scheduler=TCDScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"), + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + print("Loading LORA weights... ") + lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype) + pipeline.load_lora_weights(lora_state_dict) + pipeline.fuse_lora() + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + validation_prompts = [ + "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography", + "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", + "A young woman stands at the edge of a dense forest, her long hair flowing in the breeze. In the distance, sunlight filters through the leaves, casting dappled shadows on the moss-covered ground.", + "On a grassy plain, a shepherd wearing a worn cloak gazes towards distant mountains. Beside him, his flock of sheep grazes leisurely, while a few white clouds drift across the distant sky.", + "In front of a small cabin by the lake, an elderly fisherman prepares his boat. The surface of the lake shimmers, and a white swan glides across, leaving behind a beautiful ripple.", + "Along a rugged mountain path, a young hiker with a backpack strides towards the summit. Surrounding him are dense pine forests and distant peaks.", + "In the courtyard of an ancient castle, a young couple strolls hand in hand. Vines cling to the ancient stone pillars, and fluffy white clouds drift across the sky.", + "In a golden wheat field, a farmer swings a sickle to harvest the crops. 
A small yellow sparrow perches on his shoulder, as if engaged in conversation.", + "By the river in a small town, a group of children frolics in the clear water. Willow trees sway gently on the bank, while a light breeze carries the scent of freshness.", + "Deep in the heart of a tropical rainforest, a cheetah prowls silently among the dense foliage. Sunlight filters through the leaves, creating dappled patterns on the forest floor.", + "Amidst ancient ruins, an archaeologist excavates buried treasures. Carved into the ancient pillars are mysterious symbols, seeming to recount tales of history.", + "At the summit of a snow-capped mountain, a young mountaineer stands, admiring the breathtaking vista. Snowflakes dance in the air, and the distant valley echoes with the chirping of birds.", + ] + + image_logs = [] + + for idx, prompt in enumerate(validation_prompts): + images = [] + with torch.autocast("cuda", dtype=weight_dtype): + images = pipeline( + prompt=prompt, + num_inference_steps=num_inference_step, + num_images_per_prompt=1, + height=args.resolution, + width=args.resolution, + generator=generator, + guidance_scale=cfg, + eta=0.2, + ).images + image_logs.append({"validation_prompt": prompt, "images": images, "idx": idx}) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + idx = int(log["idx"]) + images = log["images"] + validation_prompt = log["validation_prompt"] + formatted_images = [] + for image in images: + formatted_images.append(image) + + total_width = sum(img.width for img in formatted_images) + max_height = max(img.height for img in formatted_images) + + combined_image = Image.new('RGB', (total_width, max_height)) + + x_offset = 0 + for img in formatted_images: + combined_image.paste(img, (x_offset, 0)) + x_offset += img.width + + target_output_dir = os.path.join(args.output_dir, "samples") + os.makedirs(target_output_dir, exist_ok=True) + + combined_image.save(os.path.join(target_output_dir, f"{cfg}_{idx:05d}_{step:06d}.png")) + combined_np_img = np.array(combined_image) + combined_np_img = np.expand_dims(combined_np_img, axis=0) # Add batch dimension + tracker.writer.add_images( + f"{cfg}: " + validation_prompt, combined_np_img, step, dataformats="NHWC" + ) + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({f"validation-{cfg}": formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + return x[(...,) + (None,) * dims_to_append] + + +def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): + c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) + c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + +def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): + if prediction_type == "epsilon": + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + alphas = 
extract_into_tensor(alphas, timesteps, sample.shape) + pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "v_prediction": + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + pred_x_0 = alphas * sample - sigmas * model_output + else: + raise ValueError(f"Prediction type {prediction_type} currently not supported.") + + return pred_x_0 + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +@torch.no_grad() +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +# From LatentConsistencyModel.get_guidance_scale_embedding +def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + +class DDIMSolver: + def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50, + num_inference_steps_min=4, num_inference_steps_max=8): + # DDIM sampling parameters + self.step_ratio = timesteps // ddim_timesteps + self.ddim_timesteps = ( + np.arange(1, ddim_timesteps + 1) * self.step_ratio + ).round().astype(np.int64) - 1 + self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps] + self.ddim_timesteps_prev = np.asarray([0] + self.ddim_timesteps[:-1].tolist()) + self.ddim_alpha_cumprods_prev = np.asarray( + [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist() + ) + # tdd parameters(DDPM) + self.tdd_alpha_cumprods = alpha_cumprods + self.merged_timesteps = self.set_timesteps_s(1000, num_inference_steps_min, num_inference_steps_max) + + # convert to torch tensors + self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long() + self.ddim_timesteps_prev = torch.from_numpy(self.ddim_timesteps_prev).long() + self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods) + self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev) + # tdd parameters(DDPM) + self.tdd_alpha_cumprods = torch.from_numpy(self.tdd_alpha_cumprods) + self.merged_timesteps = torch.from_numpy(self.merged_timesteps) + + def to(self, device): + self.ddim_timesteps = self.ddim_timesteps.to(device) + self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device) + + self.ddim_alpha_cumprods = 
self.ddim_alpha_cumprods.to(device) + self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device) + self.tdd_alpha_cumprods = self.tdd_alpha_cumprods.to(device) + self.merged_timesteps = self.merged_timesteps.to(device) + return self + + def ddim_step(self, pred_x0, pred_noise, timestep_index): + alpha_cumprod_prev = extract_into_tensor( + self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape + ) + dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise + x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt + return x_prev + + def set_timesteps_s( + self, + num_train_timesteps=1000, + num_inference_steps_min=4, + num_inference_steps_max=8, + ): + merged_timesteps = [] + origin_timesteps = self.ddim_timesteps[::-1].copy() + for i in range(num_inference_steps_min, num_inference_steps_max + 1): + inference_indices = np.linspace(0, len(origin_timesteps), num=i, endpoint=False) + inference_indices = np.floor(inference_indices).astype(np.int64) + timesteps = origin_timesteps[inference_indices] + merged_timesteps.append(timesteps) + merged_timesteps = np.concatenate(merged_timesteps) + merged_timesteps = np.unique(merged_timesteps) + merged_timesteps = np.sort(merged_timesteps) + merged_timesteps = np.insert(merged_timesteps, 0, 0)[:-1] + return merged_timesteps + + def select_s(self, timestep_index, s_ratio): + expanded_timestep_index = timestep_index.unsqueeze(1).expand( + -1, self.merged_timesteps.size(0) + ) + valid_indices_mask = expanded_timestep_index >= self.merged_timesteps + last_valid_index = valid_indices_mask.flip(dims=[1]).long().argmax(dim=1) + last_valid_index = self.merged_timesteps.size(0) - 1 - last_valid_index + lower_bounds = torch.clamp(self.merged_timesteps[last_valid_index] - 260, min=0) #260 + upper_bounds = self.merged_timesteps[last_valid_index] + + valid_indices = [torch.where((self.merged_timesteps >= lb) & (self.merged_timesteps <= ub))[0] for lb, ub in zip(lower_bounds, upper_bounds)] + random_indices = [torch.randint(0, len(idx), (1,)).item() for idx in valid_indices] + timestep_index = torch.tensor([self.merged_timesteps[valid_indices[i][ri]] for i, ri in enumerate(random_indices)], device=expanded_timestep_index.device) + + timestep_index = torch.floor( + (1 - torch.rand(timestep_index.size(), device=timestep_index.device) * s_ratio) * timestep_index + ).to(dtype=torch.long) + return timestep_index + + def tdd_step(self, pred_x0, pred_noise, timestep_index_s): + alpha_cumprod_s = extract_into_tensor(self.tdd_alpha_cumprods, timestep_index_s, pred_x0.shape) + dir_xt = (1.0 - alpha_cumprod_s).sqrt() * pred_noise + x_prev_s = alpha_cumprod_s.sqrt() * pred_x0 + dir_xt + return x_prev_s + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder=subfolder, + revision=revision, + use_auth_token=True, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + # ----------Model Checkpoint Loading Arguments---------- + parser.add_argument( + 
"--pretrained_teacher_model", + type=str, + default=None, + required=True, + help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--teacher_revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM model identifier from huggingface.co/models.", + ) + # ----------Training Arguments---------- + # ----General Training Arguments---- + parser.add_argument( + "--output_dir", + type=str, + default="tdd-xl-distilled", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) + # ----Logging---- + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + # ----Checkpointing---- + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + # ----Image Processing---- + parser.add_argument( + "--train_shards_path_or_url", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--use_fix_crop_and_size", + action="store_true", + help="Whether or not to use the fixed crop and size for the teacher model.", + default=False, + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. 
The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + # ----Dataloader---- + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=8, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + # ----Batch Size and Training Steps---- + parser.add_argument( + "--train_batch_size", + type=int, + default=16, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + # ----Learning Rate---- + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + # ----Optimizer (Adam)---- + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + # ----Diffusion Training Arguments---- + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + # ----Latent Consistency Distillation (LCD) Specific Arguments---- + parser.add_argument( + "--w_min", + type=float, + default=3.5, + required=False, + help=( + "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as" + " compared to the original paper." + ), + ) + parser.add_argument( + "--w_max", + type=float, + default=3.5, + required=False, + help=( + "The maximum guidance scale value for guidance scale sampling. 
Note that we are using the Imagen CFG" + " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as" + " compared to the original paper." + ), + ) + parser.add_argument( + "--num_ddim_timesteps", + type=int, + default=50, + help="The number of timesteps to use for DDIM sampling.", + ) + parser.add_argument( + "--loss_type", + type=str, + default="l2", + choices=["l2", "huber"], + help="The type of loss to use for the LCD loss.", + ) + parser.add_argument( + "--huber_c", + type=float, + default=0.001, + help="The huber loss parameter. Only used if `--loss_type=huber`.", + ) + # ----Exponential Moving Average (EMA)---- + parser.add_argument( + "--ema_decay", + type=float, + default=0.95, + required=False, + help="The exponential moving average (EMA) rate or decay factor.", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=64, + help="The rank of the LoRA projection matrix.", + ) + # ----Mixed Precision---- + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cast_teacher_unet", + action="store_true", + help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.", + ) + # ----Training Optimizations---- + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + # ----Distributed Training---- + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) + # ----------Validation Arguments---------- + parser.add_argument( + "--validation_steps", + type=int, + default=200, + help="Run validation every X steps.", + ) + # ----------Huggingface Hub Arguments----------- + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + # ----------Accelerate Arguments---------- + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + parser.add_argument("--val_infer_step", default=4, type=int) + parser.add_argument("--num_inference_steps_min", default=4, type=int) + 
parser.add_argument("--num_inference_steps_max", default=8, type=int) + + parser.add_argument("--s_ratio", default=0.3, type=float) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + return args + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True +): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
+ if args.seed is not None: + set_seed(args.seed + accelerator.process_index) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, + private=True, + ).repo_id + + # 1. Create the noise scheduler and the desired noise schedule. + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_teacher_model, + subfolder="scheduler", + revision=args.teacher_revision, + ) + + # The scheduler calculates the alpha and sigma schedule for us + alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) + sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + solver = DDIMSolver( + noise_scheduler.alphas_cumprod.numpy(), + timesteps=noise_scheduler.config.num_train_timesteps, + ddim_timesteps=args.num_ddim_timesteps, + num_inference_steps_min=args.num_inference_steps_min, + num_inference_steps_max=args.num_inference_steps_max, + ) + print(solver.merged_timesteps) + + # 2. Load tokenizers from SD-XL checkpoint. + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False + ) + + # 3. Load text encoders from SD-XL checkpoint. + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_teacher_model, args.teacher_revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_teacher_model, args.teacher_revision, subfolder="text_encoder_2" + ) + + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision + ) + + # 4. Load VAE from SD-XL checkpoint (or more stable VAE) + vae_path = ( + args.pretrained_teacher_model + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.teacher_revision, + ) + + # 5. Load teacher U-Net from SD-XL checkpoint + teacher_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + + # 6. Freeze teacher vae, text_encoders, and teacher_unet + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + teacher_unet.requires_grad_(False) + + # 7. Create online (`unet`) student U-Nets. + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + unet.train() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." 
+ ) + + if accelerator.unwrap_model(unet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + ) + + # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=[ + "to_q", + "to_k", + "to_v", + "to_out.0", + "proj_in", + "proj_out", + "ff.net.0.proj", + "ff.net.2", + "conv1", + "conv2", + "conv_shortcut", + "downsamplers.0.conv", + "upsamplers.0.conv", + "time_emb_proj", + ], + ) + unet = get_peft_model(unet, lora_config) + + # 9. Handle mixed precision and device placement + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + vae.to(accelerator.device) + if args.pretrained_vae_model_name_or_path is not None: + vae.to(dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Move teacher_unet to device, optionally cast to weight_dtype + teacher_unet.to(accelerator.device) + + if args.cast_teacher_unet: + teacher_unet.to(dtype=weight_dtype) + + # Also move the alpha and sigma noise schedules to accelerator.device. + alpha_schedule = alpha_schedule.to(accelerator.device) + sigma_schedule = sigma_schedule.to(accelerator.device) + solver = solver.to(accelerator.device) + + # 10. Handle saving and loading of checkpoints + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + unet_ = accelerator.unwrap_model(unet) + lora_state_dict = get_peft_model_state_dict( + unet_, adapter_name="default" + ) + StableDiffusionXLPipeline.save_lora_weights( + os.path.join(output_dir, "unet_lora"), lora_state_dict + ) + # save weights in peft format to be able to load them back + unet_.save_pretrained(output_dir) + + for _, model in enumerate(models): + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + # load the LoRA into the model + unet_ = accelerator.unwrap_model(unet) + unet_.load_adapter(input_dir, "default", is_trainable=True) + + for _ in range(len(models)): + # pop models so that they are not loaded again + models.pop() + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # 11. Enable optimizations + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 
+ ) + unet.enable_xformers_memory_efficient_attention() + teacher_unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # 12. Optimizer creation + optimizer = optimizer_class( + filter(lambda p: p.requires_grad, unet.parameters()), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # 13. Dataset creation and data processing + # Here, we compute not just the text embeddings but also the additional embeddings + # needed for the SD XL UNet to operate. + def compute_embeddings( + prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True + ): + target_size = (args.resolution, args.resolution) + original_sizes = list(map(list, zip(*original_sizes))) + crops_coords_top_left = list(map(list, zip(*crop_coords))) + + original_sizes = torch.tensor(original_sizes, dtype=torch.long) + crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long) + + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + + prompt_embeds = prompt_embeds.to(accelerator.device) + add_text_embeds = add_text_embeds.to(accelerator.device) + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + + dataset = SDXLText2ImageDataset( + train_shards_path_or_url=args.train_shards_path_or_url, + num_train_examples=args.max_train_samples, + per_gpu_batch_size=args.train_batch_size, + global_batch_size=args.train_batch_size * accelerator.num_processes, + num_workers=args.dataloader_num_workers, + resolution=args.resolution, + shuffle_buffer_size=1000, + pin_memory=True, + persistent_workers=True, + use_fix_crop_and_size=args.use_fix_crop_and_size, + random_flip=args.random_flip, + ) + train_dataloader = dataset.train_dataloader + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + + compute_embeddings_fn = functools.partial( + compute_embeddings, + proportion_empty_prompts=args.proportion_empty_prompts, + text_encoders=text_encoders, + tokenizers=tokenizers, + ) + + # 14. 
LR Scheduler creation + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil( + train_dataloader.num_batches / args.gradient_accumulation_steps + ) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # 15. Prepare for training + # Prepare everything with our `accelerator`. + unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Create uncond embeds for classifier free guidance + uncond_prompt_embeds = torch.zeros(args.train_batch_size, 77, 2048).to(accelerator.device) + uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, 1280).to(accelerator.device) + + # 16. Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num batches each epoch = {train_dataloader.num_batches}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
+ ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + image, text, orig_size, crop_coords = batch + + image = image.to(accelerator.device, non_blocking=True) + encoded_text = compute_embeddings_fn(text, orig_size, crop_coords) + + if args.pretrained_vae_model_name_or_path is not None: + pixel_values = image.to(dtype=weight_dtype) + if vae.dtype != weight_dtype: + vae.to(dtype=weight_dtype) + else: + pixel_values = image + + # encode pixel values with batch size of at most 8 + latents = [] + for i in range(0, pixel_values.shape[0], 8): + latents.append( + vae.encode(pixel_values[i : i + 8]).latent_dist.sample() + ) + latents = torch.cat(latents, dim=0) + + latents = latents * vae.config.scaling_factor + latents = latents.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps + index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() + + start_timesteps = solver.ddim_timesteps[index] + timesteps = start_timesteps - topk + timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) + + # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noisy_model_input = noise_scheduler.add_noise( + latents, noise, start_timesteps + ) + + # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it + w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min + w = w.reshape(bsz, 1, 1, 1) + w = w.to(device=latents.device, dtype=latents.dtype) + + # 20.4.8. Prepare prompt embeds and unet_added_conditions + prompt_embeds = encoded_text.pop("prompt_embeds") + + # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + noise_pred = unet( + noisy_model_input, + start_timesteps, + timestep_cond=None, + encoder_hidden_states=prompt_embeds.float(), + added_cond_kwargs=encoded_text, + ).sample + + pred_x_0 = predicted_origin( + noise_pred, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + timesteps_s = solver.select_s(timesteps, args.s_ratio) + model_pred = solver.tdd_step(pred_x_0, noise_pred, timesteps_s) + + # 20.4.10. 
Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after + # noisy_latents with both the conditioning embedding c and unconditional embedding 0 + # Get teacher model prediction on noisy_latents and conditional embedding + with torch.no_grad(): + with torch.autocast("cuda"): + cond_teacher_output = teacher_unet( + noisy_model_input.to(weight_dtype), + start_timesteps, + encoder_hidden_states=prompt_embeds.to(weight_dtype), + added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, + ).sample + cond_pred_x0 = predicted_origin( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + + # Get teacher model prediction on noisy_latents and unconditional embedding + uncond_added_conditions = copy.deepcopy(encoded_text) + uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds + uncond_teacher_output = teacher_unet( + noisy_model_input.to(weight_dtype), + start_timesteps, + encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), + added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, + ).sample + uncond_pred_x0 = predicted_origin( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + + # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) + pred_noise = cond_teacher_output + w * ( + cond_teacher_output - uncond_teacher_output + ) + x_prev = solver.ddim_step(pred_x0, pred_noise, index) + + # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n + with torch.no_grad(): + with torch.autocast("cuda", enabled=True, dtype=weight_dtype): + target_noise_pred = unet( + x_prev.float(), + timesteps, + timestep_cond=None, + encoder_hidden_states=prompt_embeds.float(), + added_cond_kwargs=encoded_text, + ).sample + pred_x_0 = predicted_origin( + target_noise_pred, + timesteps, + x_prev, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + target = solver.tdd_step(pred_x_0, target_noise_pred, timesteps_s) + + # 20.4.13. Calculate loss + if args.loss_type == "l2": + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="mean" + ) + elif args.loss_type == "huber": + loss = torch.mean( + torch.sqrt( + (model_pred.float() - target.float()) ** 2 + args.huber_c**2 + ) + - args.huber_c + ) + + # 20.4.14. 
Backpropagate on the online student model (`unet`) + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint + ) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % args.validation_steps == 0: + log_validation(vae, unet, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, args, accelerator, weight_dtype, global_step, + 1, + args.val_infer_step, + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet.save_pretrained(args.output_dir) + lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default") + StableDiffusionXLPipeline.save_lora_weights( + os.path.join(args.output_dir, "unet_lora"), lora_state_dict + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/train/SDXL/train_tdd_adv.py b/train/SDXL/train_tdd_adv.py new file mode 100644 index 0000000..4772548 --- /dev/null +++ b/train/SDXL/train_tdd_adv.py @@ -0,0 +1,1540 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and + +import argparse +import copy +import functools +import gc +import itertools +import json +import logging +import math +import os +import random +import shutil +from pathlib import Path +from PIL import Image + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo +from packaging import version +from peft import LoraConfig, get_peft_model, get_peft_model_state_dict +from torch.utils.data import default_collate +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig +from webdataset.tariterators import ( + base_plus_ext, + tar_file_expander, + url_opener, + valid_sample, +) + +import diffusers +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + TCDScheduler, + HeunDiscreteScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from pcm_scheduling_ddpm_modified import DDPMScheduler +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from dataset import SDXLText2ImageDataset +from pcm_discriminator_sdxl import Discriminator +from scheduler import TDDSchedulerPlus + +MAX_SEQ_LENGTH = 77 + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.18.0.dev0") + +logger = get_logger(__name__) + + +def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"): + kohya_ss_state_dict = {} + for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items(): + kohya_key = peft_key.replace("base_model.model", prefix) + kohya_key = kohya_key.replace("lora_A", "lora_down") + kohya_key = kohya_key.replace("lora_B", "lora_up") + kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) + kohya_ss_state_dict[kohya_key] = weight.to(dtype) + + # Set alpha parameter + if "lora_down" in kohya_key: + alpha_key = f'{kohya_key.split(".")[0]}.alpha' + kohya_ss_state_dict[alpha_key] = torch.tensor( + module.peft_config[adapter_name].lora_alpha + ).to(dtype) + + return kohya_ss_state_dict + + +@torch.no_grad() +def log_validation( + vae, unet, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, args, accelerator, weight_dtype, step, cfg, num_inference_step +): + logger.info("Running validation... ") + + unet = accelerator.unwrap_model(unet) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_teacher_model, + vae=vae, + text_encoder = text_encoder_one, + text_encoder_2 = text_encoder_two, + tokenizer = tokenizer_one, + tokenizer_2 = tokenizer_two, + scheduler=TDDSchedulerPlus.from_config(args.pretrained_teacher_model, subfolder="scheduler", algorithm_type="dpmsolver++", solver_order=1, tdd_train_step=args.num_ddim_timesteps), + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + print("Loading LORA weights... 
") + lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype) + pipeline.load_lora_weights(lora_state_dict) + pipeline.fuse_lora() + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + validation_prompts = [ + "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography", + "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", + "A young woman stands at the edge of a dense forest, her long hair flowing in the breeze. In the distance, sunlight filters through the leaves, casting dappled shadows on the moss-covered ground.", + "On a grassy plain, a shepherd wearing a worn cloak gazes towards distant mountains. Beside him, his flock of sheep grazes leisurely, while a few white clouds drift across the distant sky.", + "In front of a small cabin by the lake, an elderly fisherman prepares his boat. The surface of the lake shimmers, and a white swan glides across, leaving behind a beautiful ripple.", + "Along a rugged mountain path, a young hiker with a backpack strides towards the summit. Surrounding him are dense pine forests and distant peaks.", + "In the courtyard of an ancient castle, a young couple strolls hand in hand. Vines cling to the ancient stone pillars, and fluffy white clouds drift across the sky.", + "In a golden wheat field, a farmer swings a sickle to harvest the crops. A small yellow sparrow perches on his shoulder, as if engaged in conversation.", + "By the river in a small town, a group of children frolics in the clear water. Willow trees sway gently on the bank, while a light breeze carries the scent of freshness.", + "Deep in the heart of a tropical rainforest, a cheetah prowls silently among the dense foliage. Sunlight filters through the leaves, creating dappled patterns on the forest floor.", + "Amidst ancient ruins, an archaeologist excavates buried treasures. Carved into the ancient pillars are mysterious symbols, seeming to recount tales of history.", + "At the summit of a snow-capped mountain, a young mountaineer stands, admiring the breathtaking vista. 
Snowflakes dance in the air, and the distant valley echoes with the chirping of birds.", + ] + + image_logs = [] + + for idx, prompt in enumerate(validation_prompts): + images = [] + with torch.autocast("cuda", dtype=weight_dtype): + images = pipeline( + prompt=prompt, + num_inference_steps=num_inference_step, + num_images_per_prompt=1, + height=args.resolution, + width=args.resolution, + generator=generator, + guidance_scale=cfg, + eta=0.2, + ).images + image_logs.append({"validation_prompt": prompt, "images": images, "idx": idx}) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + idx = int(log["idx"]) + images = log["images"] + validation_prompt = log["validation_prompt"] + formatted_images = [] + for image in images: + formatted_images.append(image) + + total_width = sum(img.width for img in formatted_images) + max_height = max(img.height for img in formatted_images) + + combined_image = Image.new('RGB', (total_width, max_height)) + + x_offset = 0 + for img in formatted_images: + combined_image.paste(img, (x_offset, 0)) + x_offset += img.width + + target_output_dir = os.path.join(args.output_dir, "samples") + os.makedirs(target_output_dir, exist_ok=True) + + combined_image.save(os.path.join(target_output_dir, f"{cfg}_{idx:05d}_{step:06d}.png")) + combined_np_img = np.array(combined_image) + combined_np_img = np.expand_dims(combined_np_img, axis=0) # Add batch dimension + tracker.writer.add_images( + f"{cfg}: " + validation_prompt, combined_np_img, step, dataformats="NHWC" + ) + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({f"validation-{cfg}": formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + return x[(...,) + (None,) * dims_to_append] + + +def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): + c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) + c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + +def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): + if prediction_type == "epsilon": + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "v_prediction": + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + pred_x_0 = alphas * sample - sigmas * model_output + else: + raise ValueError(f"Prediction type {prediction_type} currently not supported.") + + return pred_x_0 + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +@torch.no_grad() +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of 
source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +# From LatentConsistencyModel.get_guidance_scale_embedding +def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + +class DDIMSolver: + def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50, + num_inference_steps_min=4, num_inference_steps_max=8): + # DDIM sampling parameters + self.step_ratio = timesteps // ddim_timesteps + self.ddim_timesteps = ( + np.arange(1, ddim_timesteps + 1) * self.step_ratio + ).round().astype(np.int64) - 1 + self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps] + self.ddim_timesteps_prev = np.asarray([0] + self.ddim_timesteps[:-1].tolist()) + self.ddim_alpha_cumprods_prev = np.asarray( + [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist() + ) + # tdd parameters(DDPM) + self.tdd_alpha_cumprods = alpha_cumprods + self.merged_timesteps = self.set_timesteps_s(1000, num_inference_steps_min, num_inference_steps_max) + + # convert to torch tensors + self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long() + self.ddim_timesteps_prev = torch.from_numpy(self.ddim_timesteps_prev).long() + self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods) + self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev) + # tdd parameters(DDPM) + self.tdd_alpha_cumprods = torch.from_numpy(self.tdd_alpha_cumprods) + self.merged_timesteps = torch.from_numpy(self.merged_timesteps) + + def to(self, device): + self.ddim_timesteps = self.ddim_timesteps.to(device) + self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device) + + self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device) + self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device) + self.tdd_alpha_cumprods = self.tdd_alpha_cumprods.to(device) + self.merged_timesteps = self.merged_timesteps.to(device) + return self + + def ddim_step(self, pred_x0, pred_noise, timestep_index): + alpha_cumprod_prev = extract_into_tensor( + self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape + ) + dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise + x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt + return x_prev + + def set_timesteps_s( + self, + num_train_timesteps=1000, + num_inference_steps_min=4, + num_inference_steps_max=8, + ): + merged_timesteps = [] + origin_timesteps = 
self.ddim_timesteps[::-1].copy() + for i in range(num_inference_steps_min, num_inference_steps_max + 1): + inference_indices = np.linspace(0, len(origin_timesteps), num=i, endpoint=False) + inference_indices = np.floor(inference_indices).astype(np.int64) + timesteps = origin_timesteps[inference_indices] + merged_timesteps.append(timesteps) + merged_timesteps = np.concatenate(merged_timesteps) + merged_timesteps = np.unique(merged_timesteps) + merged_timesteps = np.sort(merged_timesteps) + merged_timesteps = np.insert(merged_timesteps, 0, 0)[:-1] + return merged_timesteps + + def select_s(self, timestep_index, s_ratio): + expanded_timestep_index = timestep_index.unsqueeze(1).expand( + -1, self.merged_timesteps.size(0) + ) + valid_indices_mask = expanded_timestep_index >= self.merged_timesteps + last_valid_index = valid_indices_mask.flip(dims=[1]).long().argmax(dim=1) + last_valid_index = self.merged_timesteps.size(0) - 1 - last_valid_index + lower_bounds = torch.clamp(self.merged_timesteps[last_valid_index] - 260, min=0) #260 + upper_bounds = self.merged_timesteps[last_valid_index] + + valid_indices = [torch.where((self.merged_timesteps >= lb) & (self.merged_timesteps <= ub))[0] for lb, ub in zip(lower_bounds, upper_bounds)] + random_indices = [torch.randint(0, len(idx), (1,)).item() for idx in valid_indices] + timestep_index = torch.tensor([self.merged_timesteps[valid_indices[i][ri]] for i, ri in enumerate(random_indices)], device=expanded_timestep_index.device) + + timestep_index = torch.floor( + (1 - torch.rand(timestep_index.size(), device=timestep_index.device) * s_ratio) * timestep_index + ).to(dtype=torch.long) + return timestep_index + + def tdd_step(self, pred_x0, pred_noise, timestep_index_s): + alpha_cumprod_s = extract_into_tensor(self.tdd_alpha_cumprods, timestep_index_s, pred_x0.shape) + dir_xt = (1.0 - alpha_cumprod_s).sqrt() * pred_noise + x_prev_s = alpha_cumprod_s.sqrt() * pred_x0 + dir_xt + return x_prev_s + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder=subfolder, + revision=revision, + use_auth_token=True, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + # ----------Model Checkpoint Loading Arguments---------- + parser.add_argument( + "--pretrained_teacher_model", + type=str, + default=None, + required=True, + help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. 
More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--teacher_revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM model identifier from huggingface.co/models.", + ) + # ----------Training Arguments---------- + # ----General Training Arguments---- + parser.add_argument( + "--output_dir", + type=str, + default="tdd-xl-distilled", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) + # ----Logging---- + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + # ----Checkpointing---- + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + # ----Image Processing---- + parser.add_argument( + "--train_shards_path_or_url", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--use_fix_crop_and_size", + action="store_true", + help="Whether or not to use the fixed crop and size for the teacher model.", + default=False, + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + # ----Dataloader---- + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=8, + help=( + "Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process." + ), + ) + # ----Batch Size and Training Steps---- + parser.add_argument( + "--train_batch_size", + type=int, + default=16, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + # ----Learning Rate---- + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + # ----Optimizer (Adam)---- + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + # ----Diffusion Training Arguments---- + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + # ----Latent Consistency Distillation (LCD) Specific Arguments---- + parser.add_argument( + "--w_min", + type=float, + default=3.5, + required=False, + help=( + "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as" + " compared to the original paper." + ), + ) + parser.add_argument( + "--w_max", + type=float, + default=3.5, + required=False, + help=( + "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as" + " compared to the original paper." 
+ ), + ) + parser.add_argument( + "--num_ddim_timesteps", + type=int, + default=50, + help="The number of timesteps to use for DDIM sampling.", + ) + parser.add_argument( + "--loss_type", + type=str, + default="l2", + choices=["l2", "huber"], + help="The type of loss to use for the LCD loss.", + ) + parser.add_argument( + "--huber_c", + type=float, + default=0.001, + help="The huber loss parameter. Only used if `--loss_type=huber`.", + ) + # ----Exponential Moving Average (EMA)---- + parser.add_argument( + "--ema_decay", + type=float, + default=0.95, + required=False, + help="The exponential moving average (EMA) rate or decay factor.", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=64, + help="The rank of the LoRA projection matrix.", + ) + # ----Mixed Precision---- + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cast_teacher_unet", + action="store_true", + help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.", + ) + # ----Training Optimizations---- + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + # ----Distributed Training---- + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) + # ----------Validation Arguments---------- + parser.add_argument( + "--validation_steps", + type=int, + default=200, + help="Run validation every X steps.", + ) + # ----------Huggingface Hub Arguments----------- + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + # ----------Accelerate Arguments---------- + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + parser.add_argument("--val_infer_step", default=4, type=int) + parser.add_argument("--num_inference_steps_min", default=4, type=int) + parser.add_argument("--num_inference_steps_max", default=8, type=int) + + parser.add_argument("--s_ratio", default=0.3, type=float) + parser.add_argument("--adv_weight", default=0.1, type=float) + 
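# Adversarial-distillation options: --adv_lr (below) sets the learning rate of the discriminator optimizer, while --adv_weight presumably scales the adversarial loss term used later in the script. +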
parser.add_argument("--adv_lr", default=1e-5, type=float) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + return args + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True +): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed + accelerator.process_index) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, + private=True, + ).repo_id + + # 1. Create the noise scheduler and the desired noise schedule. 
+ noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_teacher_model, + subfolder="scheduler", + revision=args.teacher_revision, + ) + + # The scheduler calculates the alpha and sigma schedule for us + alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) + sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + solver = DDIMSolver( + noise_scheduler.alphas_cumprod.numpy(), + timesteps=noise_scheduler.config.num_train_timesteps, + ddim_timesteps=args.num_ddim_timesteps, + num_inference_steps_min=args.num_inference_steps_min, + num_inference_steps_max=args.num_inference_steps_max, + ) + print(solver.merged_timesteps) + + # 2. Load tokenizers from SD-XL checkpoint. + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False + ) + + # 3. Load text encoders from SD-XL checkpoint. + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_teacher_model, args.teacher_revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_teacher_model, args.teacher_revision, subfolder="text_encoder_2" + ) + + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision + ) + + # 4. Load VAE from SD-XL checkpoint (or more stable VAE) + vae_path = ( + args.pretrained_teacher_model + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.teacher_revision, + ) + + # 5. Load teacher U-Net from SD-XL checkpoint + teacher_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + + + discriminator = Discriminator(teacher_unet) + + discriminator.unet.requires_grad_(False) + teacher_unet.requires_grad_(False) + discriminator_params = [] + for param in discriminator.heads.parameters(): + param.requires_grad = True + discriminator_params.append(param) + + # 6. Freeze teacher vae, text_encoders, and teacher_unet + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + + # 7. Create online (`unet`) student U-Nets. + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + unet.train() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(unet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + ) + + # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. 
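+    # The LoRA target modules cover the attention projections (to_q/k/v/out), transformer feed-forward layers, ResNet convolutions, and the time-embedding projection, so most of the U-Net is adapted.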
+ lora_config = LoraConfig( + r=args.lora_rank, + target_modules=[ + "to_q", + "to_k", + "to_v", + "to_out.0", + "proj_in", + "proj_out", + "ff.net.0.proj", + "ff.net.2", + "conv1", + "conv2", + "conv_shortcut", + "downsamplers.0.conv", + "upsamplers.0.conv", + "time_emb_proj", + ], + ) + unet = get_peft_model(unet, lora_config) + + # 9. Handle mixed precision and device placement + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + vae.to(accelerator.device) + if args.pretrained_vae_model_name_or_path is not None: + vae.to(dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Also move the alpha and sigma noise schedules to accelerator.device. + alpha_schedule = alpha_schedule.to(accelerator.device) + sigma_schedule = sigma_schedule.to(accelerator.device) + solver = solver.to(accelerator.device) + + # 10. Handle saving and loading of checkpoints + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + unet_ = accelerator.unwrap_model(unet) + lora_state_dict = get_peft_model_state_dict( + unet_, adapter_name="default" + ) + StableDiffusionXLPipeline.save_lora_weights( + os.path.join(output_dir, "unet_lora"), lora_state_dict + ) + # save weights in peft format to be able to load them back + unet_.save_pretrained(output_dir) + + for _, model in enumerate(models): + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + # load the LoRA into the model + unet_ = accelerator.unwrap_model(unet) + unet_.load_adapter(input_dir, "default", is_trainable=True) + + for _ in range(len(models)): + # pop models so that they are not loaded again + models.pop() + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # 11. Enable optimizations + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + teacher_unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. 
Make sure it is installed correctly") + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + teacher_unet.enable_gradient_checkpointing() + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # 12. Optimizer creation + optimizer = optimizer_class( + filter(lambda p: p.requires_grad, unet.parameters()), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + discriminator_optimizer = optimizer_class( + discriminator_params, + lr=args.adv_lr, + betas=(0.0, 0.999), + weight_decay=1e-3, + eps=args.adam_epsilon, + ) + + # 13. Dataset creation and data processing + # Here, we compute not just the text embeddings but also the additional embeddings + # needed for the SD XL UNet to operate. + @torch.no_grad() + def compute_embeddings( + prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True + ): + target_size = (args.resolution, args.resolution) + original_sizes = list(map(list, zip(*original_sizes))) + crops_coords_top_left = list(map(list, zip(*crop_coords))) + + original_sizes = torch.tensor(original_sizes, dtype=torch.long) + crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long) + + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + + prompt_embeds = prompt_embeds.to(accelerator.device) + add_text_embeds = add_text_embeds.to(accelerator.device) + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + + dataset = SDXLText2ImageDataset( + train_shards_path_or_url=args.train_shards_path_or_url, + num_train_examples=args.max_train_samples, + per_gpu_batch_size=args.train_batch_size, + global_batch_size=args.train_batch_size * accelerator.num_processes, + num_workers=args.dataloader_num_workers, + resolution=args.resolution, + shuffle_buffer_size=1000, + pin_memory=True, + persistent_workers=True, + use_fix_crop_and_size=args.use_fix_crop_and_size, + random_flip=args.random_flip, + ) + train_dataloader = dataset.train_dataloader + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. 
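The `compute_embeddings` helper above packs SDXL's micro-conditioning into a six-value `time_ids` vector per sample: original size, crop top-left offset, and target size, concatenated in that order and cast to the prompt-embedding dtype. A minimal shape sketch, with made-up sizes and a batch of two at `--resolution=1024`:

    import torch

    original_sizes = torch.tensor([[1024, 1024], [1152, 896]], dtype=torch.long)   # per-sample original sizes
    crop_coords_top_left = torch.tensor([[0, 0], [64, 0]], dtype=torch.long)       # per-sample crop offsets
    target_size = torch.tensor([[1024, 1024]] * 2, dtype=torch.long)               # training resolution
    add_time_ids = torch.cat([original_sizes, crop_coords_top_left, target_size], dim=-1)
    print(add_time_ids.shape)  # torch.Size([2, 6])
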
+ text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + + compute_embeddings_fn = functools.partial( + compute_embeddings, + proportion_empty_prompts=args.proportion_empty_prompts, + text_encoders=text_encoders, + tokenizers=tokenizers, + ) + + # 14. LR Scheduler creation + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil( + train_dataloader.num_batches / args.gradient_accumulation_steps + ) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # 15. Prepare for training + # Prepare everything with our `accelerator`. + unet, discriminator, optimizer, discriminator_optimizer, lr_scheduler = accelerator.prepare(unet, discriminator, optimizer, discriminator_optimizer, lr_scheduler) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Create uncond embeds for classifier free guidance + uncond_prompt_embeds = torch.zeros(args.train_batch_size, 77, 2048).to(accelerator.device) + uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, 1280).to(accelerator.device) + + # 16. Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num batches each epoch = {train_dataloader.num_batches}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
+ ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + image, text, orig_size, crop_coords = batch + + image = image.to(accelerator.device, non_blocking=True) + encoded_text = compute_embeddings_fn(text, orig_size, crop_coords) + + if args.pretrained_vae_model_name_or_path is not None: + pixel_values = image.to(dtype=weight_dtype) + if vae.dtype != weight_dtype: + vae.to(dtype=weight_dtype) + else: + pixel_values = image + + # encode pixel values with batch size of at most 8 + latents = [] + for i in range(0, pixel_values.shape[0], 8): + latents.append( + vae.encode(pixel_values[i : i + 8]).latent_dist.sample() + ) + latents = torch.cat(latents, dim=0) + + latents = latents * vae.config.scaling_factor + latents = latents.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps + index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() + + start_timesteps = solver.ddim_timesteps[index] + timesteps = start_timesteps - topk + timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) + + # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noisy_model_input = noise_scheduler.add_noise( + latents, noise, start_timesteps + ) + + # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it + w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min + w = w.reshape(bsz, 1, 1, 1) + w = w.to(device=latents.device, dtype=latents.dtype) + + # 20.4.8. Prepare prompt embeds and unet_added_conditions + prompt_embeds = encoded_text.pop("prompt_embeds") + + # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + noise_pred = unet( + noisy_model_input, + start_timesteps, + timestep_cond=None, + encoder_hidden_states=prompt_embeds.float(), + added_cond_kwargs=encoded_text, + ).sample + + pred_x_0 = predicted_origin( + noise_pred, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + timesteps_s = solver.select_s(timesteps, args.s_ratio) + model_pred = solver.tdd_step(pred_x_0, noise_pred, timesteps_s) + + # 20.4.10. 
Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after + # noisy_latents with both the conditioning embedding c and unconditional embedding 0 + # Get teacher model prediction on noisy_latents and conditional embedding + with torch.no_grad(): + with torch.autocast("cuda"): + cond_teacher_output = teacher_unet( + noisy_model_input.to(weight_dtype), + start_timesteps, + encoder_hidden_states=prompt_embeds.to(weight_dtype), + added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, + ).sample + cond_pred_x0 = predicted_origin( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + + # Get teacher model prediction on noisy_latents and unconditional embedding + uncond_added_conditions = copy.deepcopy(encoded_text) + uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds + uncond_teacher_output = teacher_unet( + noisy_model_input.to(weight_dtype), + start_timesteps, + encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), + added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, + ).sample + uncond_pred_x0 = predicted_origin( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + + # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) + pred_noise = cond_teacher_output + w * ( + cond_teacher_output - uncond_teacher_output + ) + x_prev = solver.ddim_step(pred_x0, pred_noise, index) + + # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n + with torch.no_grad(): + with torch.autocast("cuda", enabled=True, dtype=weight_dtype): + target_noise_pred = unet( + x_prev.float(), + timesteps, + timestep_cond=None, + encoder_hidden_states=prompt_embeds.float(), + added_cond_kwargs=encoded_text, + ).sample + pred_x_0 = predicted_origin( + target_noise_pred, + timesteps, + x_prev, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + target = solver.tdd_step(pred_x_0, target_noise_pred, timesteps_s) + + # 20.4.13. 
Calculate loss + gan_timesteps = torch.empty_like(timesteps_s) + for i in range(timesteps_s.size(0)): + gan_timesteps[i] = torch.randint( + timesteps_s[i].item(), + min(timesteps_s[i].item() + noise_scheduler.config.num_train_timesteps // args.num_inference_steps_min, + noise_scheduler.config.num_train_timesteps), + (1,), + dtype=timesteps_s.dtype, + device=timesteps_s.device, + ) + real_gan = noise_scheduler.noise_travel( + target, torch.randn_like(latents), timesteps_s, gan_timesteps + ) + fake_gan = noise_scheduler.noise_travel( + model_pred, torch.randn_like(latents), timesteps_s, gan_timesteps + ) + + if global_step % 2 == 0: + discriminator_optimizer.zero_grad(set_to_none=True) + loss = discriminator( + "d_loss", + fake_gan.float(), + real_gan.float(), + gan_timesteps, + prompt_embeds.float(), + encoded_text, + 1, + ) + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_( + discriminator.parameters(), args.max_grad_norm + ) + discriminator_optimizer.step() + discriminator_optimizer.zero_grad(set_to_none=True) + else: + if args.loss_type == "l2": + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="mean" + ) + elif args.loss_type == "huber": + loss = torch.mean( + torch.sqrt( + (model_pred.float() - target.float()) ** 2 + args.huber_c**2 + ) + - args.huber_c + ) + + g_loss = args.adv_weight * discriminator( + "g_loss", + fake_gan.float(), + gan_timesteps, + prompt_embeds.float(), + encoded_text, + 1, + ) + loss += g_loss + + # 20.4.14. Backpropagate on the online student model (`unet`) + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [ + d for d in checkpoints if d.startswith("checkpoint") + ] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1]) + ) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint + ) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % args.validation_steps == 0: + log_validation(vae, unet, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, args, accelerator, weight_dtype, global_step, + 1, + args.multiphase, + ) + log_validation(vae, unet, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, args, accelerator, weight_dtype, global_step, + 2.0, 
+ args.multiphase, + ) + if args.not_apply_cfg_solver: + log_validation(vae, unet, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, args, accelerator, weight_dtype, global_step, + 7.5, + args.multiphase, + ) + if (global_step - 1) % 2 == 0: + logs = { + "d_loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + else: + logs = { + "loss_cm": loss.detach().item() - g_loss.detach().item(), + "g_loss": g_loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet.save_pretrained(args.output_dir) + lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default") + StableDiffusionXLPipeline.save_lora_weights( + os.path.join(args.output_dir, "unet_lora"), lora_state_dict + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/train/SDXL/train_tdd_adv_h800.sh b/train/SDXL/train_tdd_adv_h800.sh new file mode 100644 index 0000000..0fb4014 --- /dev/null +++ b/train/SDXL/train_tdd_adv_h800.sh @@ -0,0 +1,35 @@ +export TRAIN_SHARDS_PATH_OR_URL="/mnt/dataset/laion_6plus" +export PRETRAINED_TEACHER_MODEL="./stable-diffusion-xl-base-1.0" +export PRETRAINED_VAE_MODEL_NAME_OR_PATH="./sdxl-vae-fp16-fix" +accelerate launch --config_file=config.yaml train_tdd_adv.py \ + --pretrained_teacher_model=$PRETRAINED_TEACHER_MODEL \ + --pretrained_vae_model_name_or_path=$PRETRAINED_VAE_MODEL_NAME_OR_PATH \ + --train_shards_path_or_url=$TRAIN_SHARDS_PATH_OR_URL \ + --output_dir="result/TDD_uc0.2_etas0.3_ddim250_adv" \ + --seed=453645634 \ + --resolution=1024 \ + --max_train_samples=4000000 \ + --max_train_steps=100000 \ + --train_batch_size=14 \ + --dataloader_num_workers=32 \ + --gradient_accumulation_steps=4 \ + --checkpointing_steps=5000 \ + --validation_steps=500 \ + --learning_rate=2e-06 \ + --lora_rank=64 \ + --w_max=3.5 \ + --w_min=3.5 \ + --mixed_precision="fp16" \ + --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \ + --val_infer_step=4 \ + --gradient_checkpointing \ + --num_ddim_timesteps=250 \ + --proportion_empty_prompts=0.2 \ + --num_inference_steps_min=4 \ + --num_inference_steps_max=8 \ + --s_ratio=0.3 \ + --adv_lr=1e-5 \ + --adv_weight=0.1 \ + +# cd /mnt/nj-aigc/usr/polu +# bash run_gpu.sh \ No newline at end of file diff --git a/train/SDXL/train_tdd_h800.sh b/train/SDXL/train_tdd_h800.sh new file mode 100644 index 0000000..b872f65 --- /dev/null +++ b/train/SDXL/train_tdd_h800.sh @@ -0,0 +1,33 @@ +export TRAIN_SHARDS_PATH_OR_URL="/mnt/dataset/laion_6plus" +export PRETRAINED_TEACHER_MODEL="./stable-diffusion-xl-base-1.0" +export PRETRAINED_VAE_MODEL_NAME_OR_PATH="./sdxl-vae-fp16-fix" +accelerate launch --config_file=config.yaml train_tdd.py \ + --pretrained_teacher_model=$PRETRAINED_TEACHER_MODEL \ + --pretrained_vae_model_name_or_path=$PRETRAINED_VAE_MODEL_NAME_OR_PATH \ + --train_shards_path_or_url=$TRAIN_SHARDS_PATH_OR_URL \ + --output_dir="result/TDD_uc0.2_etas0.3_ddim250" \ + --seed=453645634 \ + --resolution=1024 \ + --max_train_samples=4000000 \ + --max_train_steps=100000 \ + --train_batch_size=16 \ + --dataloader_num_workers=32 \ + --gradient_accumulation_steps=4 \ + --checkpointing_steps=5000 \ + --validation_steps=1000 \ + 
--learning_rate=1e-06 \ + --lora_rank=64 \ + --w_max=3.5 \ + --w_min=3.5 \ + --mixed_precision="fp16" \ + --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \ + --val_infer_step=4 \ + --gradient_checkpointing \ + --num_ddim_timesteps=250 \ + --proportion_empty_prompts=0.2 \ + --num_inference_steps_min=4 \ + --num_inference_steps_max=8 \ + --s_ratio=0.3 \ + +# cd /mnt/nj-aigc/usr/polu +# bash run_gpu.sh \ No newline at end of file
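
Both launch scripts pass a relative `--config_file=config.yaml`, so they are presumably run from `train/SDXL`. With `num_processes: 8` in that config, the effective batch per optimizer step, matching the `total_batch_size` the training script logs, works out as follows (a quick sketch):

    # train_batch_size * num_processes (8, per config.yaml) * gradient_accumulation_steps (4)
    for script, per_gpu in [("train_tdd_adv_h800.sh", 14), ("train_tdd_h800.sh", 16)]:
        print(script, per_gpu * 8 * 4)   # 448 and 512 respectively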