diff --git a/optimizer/optimizers.py b/optimizer/optimizers.py
index fa0ae78..2a1bf80 100644
--- a/optimizer/optimizers.py
+++ b/optimizer/optimizers.py
@@ -27,6 +27,8 @@
 from colorama import Fore, Style
 import pprint
 
+from plugins.plugins import PluginRunner
+
 BETAS_DEFAULT = [0.9, 0.999]
 EPSILON_DEFAULT = 1e-8
 WEIGHT_DECAY_DEFAULT = 0.01
@@ -120,8 +122,9 @@ def _calculate_norm(self, param, p):
         else:
             return 0.0
 
-    def step(self, loss, step, global_step):
+    def step(self, loss, step, global_step, plugin_runner: PluginRunner, ed_state: 'EveryDreamTrainingState'):
         self.scaler.scale(loss).backward()
+        plugin_runner.run_on_backpropagation(ed_state=ed_state)
 
         if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1):
             if self.clip_grad_norm is not None:
@@ -142,6 +145,7 @@ def step(self, loss, step, global_step):
                 self.log_writer.add_scalar("optimizer/te_grad_norm", te_grad_norm, global_step)
 
             for optimizer in self.optimizers:
+                # the scaler steps the optimizer on our behalf
                 self.scaler.step(optimizer)
 
             self.scaler.update()
@@ -480,6 +484,7 @@ def _create_optimizer(self, label, args, local_optimizer_config, parameters):
     def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
         num_layers = len(text_encoder.text_model.encoder.layers)
         unfreeze_embeddings = True
+        unfreeze_position_embeddings = True
         unfreeze_last_n_layers = None
         unfreeze_final_layer_norm = True
         if "freeze_front_n_layers" in self.te_freeze_config:
@@ -499,7 +504,6 @@ def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
                 unfreeze_last_n_layers = num_layers
             else:
                 # something specified:
-                assert(unfreeze_last_n_layers > 0)
                 if unfreeze_last_n_layers < num_layers:
                     # if we're unfreezing layers then by default we ought to freeze the embeddings
                     unfreeze_embeddings = False
@@ -508,11 +512,13 @@ def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
             unfreeze_embeddings = not self.te_freeze_config["freeze_embeddings"]
         if "freeze_final_layer_norm" in self.te_freeze_config:
             unfreeze_final_layer_norm = not self.te_freeze_config["freeze_final_layer_norm"]
+        if "freeze_position_embeddings" in self.te_freeze_config:
+            unfreeze_position_embeddings = not self.te_freeze_config["freeze_position_embeddings"]
 
         parameters = itertools.chain([])
 
         if unfreeze_embeddings:
-            parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.parameters())
+            parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.token_embedding.parameters())
         else:
             print(" ❄️ freezing embeddings")
 
@@ -530,6 +536,15 @@ def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
         else:
             print(" ❄️ freezing final layer norm")
 
+        if unfreeze_position_embeddings:
+            parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.position_embedding.parameters())
+        else:
+            print(" ❄️ freezing position embeddings")
+
+        # make sure there's some requires_grad in some places
+        parameters = list(parameters)
+        set_requires_grad(text_encoder.parameters(), False)
+        set_requires_grad(parameters, True)
         return parameters
 
 
@@ -548,3 +563,7 @@ def log_optimizer(label: str, optimizer: torch.optim.Optimizer, betas, epsilon,
 
     logging.info(f"{Fore.CYAN} * {label} optimizer: {optimizer.__class__.__name__} {param_info} *{Style.RESET_ALL}")
     logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
+
+def set_requires_grad(params, requires_grad: bool):
+    for param in params:
+        param.requires_grad = requires_grad
\ No newline at end of file
diff --git
a/plugins/plugins.py b/plugins/plugins.py index 029bde4..1e8f6d0 100644 --- a/plugins/plugins.py +++ b/plugins/plugins.py @@ -18,10 +18,18 @@ def on_step_start(self, **kwargs): pass def on_step_end(self, **kwargs): pass - def transform_caption(self, caption:str): + def on_model_load(self, **kwargs): + pass + def on_model_save(self, **kwargs): + pass + def on_backpropagation(self, **kwargs): + pass + def transform_caption(self, caption:str) -> str: return caption - def transform_pil_image(self, img:Image): + def transform_pil_image(self, img:Image) -> Image: return img + def modify_sample_prompt(self, prompt:str) -> str: + return prompt def load_plugin(plugin_path): print(f" - Attempting to load plugin: {plugin_path}") @@ -92,7 +100,22 @@ def run_on_step_end(self, **kwargs): for plugin in self.plugins: with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'): plugin.on_step_end(**kwargs) - + + def run_on_backpropagation(self, **kwargs): + for plugin in self.plugins: + with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'): + plugin.on_backpropagation(**kwargs) + + def run_on_model_load(self, **kwargs): + for plugin in self.plugins: + with Timer(warn_seconds=self.training_warn_seconds, label=f'{plugin.__class__.__name__}'): + plugin.on_model_load(**kwargs) + + def run_on_model_save(self, **kwargs): + for plugin in self.plugins: + with Timer(warn_seconds=self.training_warn_seconds, label=f'{plugin.__class__.__name__}'): + plugin.on_model_save(**kwargs) + def run_transform_caption(self, caption): with Timer(warn_seconds=self.step_warn_seconds, label="plugin.transform_caption"): for plugin in self.plugins: @@ -104,3 +127,9 @@ def run_transform_pil_image(self, img): for plugin in self.plugins: img = plugin.transform_pil_image(img) return img + + def run_modify_sample_prompt(self, prompt) -> str: + with Timer(warn_seconds=self.step_warn_seconds, label="plugin.modify_sample_prompt"): + for plugin in self.plugins: + prompt = plugin.modify_sample_prompt(prompt) + return prompt diff --git a/plugins/textual_inversion.json b/plugins/textual_inversion.json new file mode 100644 index 0000000..7b21167 --- /dev/null +++ b/plugins/textual_inversion.json @@ -0,0 +1,16 @@ + +{ + "documentation": { + "tokens": { + "token": "the token (word) to train, given exactly as it appears in tokens", + "initializer": "starting point for the embedding. make it close to the intended meaning to kick-start training. should be shorter (in tokens) than the vector_length.", + "vector_length": "length of the embedding (default 1). use more if you have more images and/or if what you want to train is complex." + }, + "example": "the example below trains `hat*`, `dancing shoes` and `cane` as custom tokens, if you have training data where the captions include those tokens." 
+ }, + "tokens": [ + { "token": "hat*", "initializer": "a man's hat", "vector_length": 8 }, + { "token": "dancing shoes", "initializer": "shoes" }, + { "token": "cane", "initializer": "cane" } + ] +} diff --git a/plugins/textual_inversion.py b/plugins/textual_inversion.py index 147750f..8b506a7 100644 --- a/plugins/textual_inversion.py +++ b/plugins/textual_inversion.py @@ -4,6 +4,9 @@ import torch from colorama import Fore +import re + +from transformers import CLIPTextModel, CLIPTokenizer from plugins.plugins import BasePlugin from train import EveryDreamTrainingState @@ -25,12 +28,75 @@ "text_encoder_freezing": { "unfreeze_last_n_layers": 0, "freeze_embeddings": false, - "freeze_final_layer_norm": true + "freeze_final_layer_norm": true, + "freeze_position_embeddings": true } -In addition, you'll need a very high LR on the TE - maybe even as high as 1e-3. I recommend using the LR finder method. +In addition, you'll need a very high LR on the TE - maybe even as high as 5e-2. I recommend using the LR finder method. """ +class TextualInversionLoaderPlugin(BasePlugin): + def __init__(self): + path = os.path.join(os.path.dirname(__file__), "textual_inversion_loader.json") + logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}") + with open(path, 'rt') as f: + self.config = json.load(f) + self.padding_tokens = {} + + def on_model_load(self, **kwargs): + ed_state: EveryDreamTrainingState = kwargs['ed_state'] + resume_ckpt: str = kwargs['resume_ckpt'] + #self.original_tokens_length = len(ed_state.tokenizer) + tokenizer = ed_state.tokenizer + + token_config = self.config['tokens'] + + embeddings = {} + for token_info in self.config['tokens']: + token = token_info["token"] + path = token_info.get("path", None) or _get_embedding_path(resume_ckpt, token) + with open(path, "rb") as f: + embedding_dict = torch.load(f) + embedding = list(embedding_dict.values())[0] + embeddings[token] = embedding + token_info["vector_length"] = embedding.shape[0] + print(f" * Textual Inversion Loader loaded embedding with vector length {token_info['vector_length']} for token '{token}' from {path}") + + training_tokens, padding_tokens, tokens_to_add, tokens_to_overwrite = ( + _setup_tokens(text_encoder=ed_state.text_encoder, tokenizer=ed_state.tokenizer, token_infos=token_config)) + self.padding_tokens = padding_tokens + + input_embeddings = ed_state.text_encoder.get_input_embeddings() + for token_info in self.config['tokens']: + token = token_info["token"] + vector_length = token_info["vector_length"] + trigger_and_padding_tokens = [token] + padding_tokens[token] + embedding = embeddings[token] + for i in range(vector_length): + token_ids = _get_token_ids(tokenizer, trigger_and_padding_tokens[i]) + token_id = token_ids[0] + input_embeddings.weight.data[token_id] = embedding[i] + + + def transform_caption(self, caption:str) -> str: + return self.expand_trigger_tokens(caption) + + def modify_sample_prompt(self, prompt: str) -> str: + return self.expand_trigger_tokens(prompt) + + def expand_trigger_tokens(self, caption: str) -> str: + tokens = self.config['tokens'] + # for multi-vector tokens, replace the trigger token with a padded sequence of the correct length. 
+ # eg "hat*" with vector length 3 -> "hat* hat*_pad!!!_1 hat*_pad!!!_2" + for t in tokens: + trigger = t['token'] + replacement = " ".join([trigger] + self.padding_tokens[trigger]) + caption = re.sub(trigger, replacement, caption) + return caption + + + + class TextualInversionPlugin(BasePlugin): def __init__(self): @@ -38,19 +104,25 @@ def __init__(self): logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}") with open(path, 'rt') as f: self.config = json.load(f) - self.this_batch_tokens = None + self.training_tokens = None self.training_token_ids = None - self.original_text_embeddings = None + self.padding_tokens = {} + self.padding_token_ids = {} + self.textual_inversion_tokens_only_grads = None def on_model_load(self, **kwargs): ed_state: EveryDreamTrainingState = kwargs.get('ed_state') - optimizer_config: dict = kwargs.get('optimizer_config') - def get_token_ids(t: str): - return ed_state.tokenizer.convert_tokens_to_ids(ed_state.tokenizer.tokenize(t)) + tokenizer = ed_state.tokenizer # check for correctly configured text encoder training + disable_unet_training: bool = kwargs.get('disable_unet_training') + disable_textenc_training: bool = kwargs.get('disable_textenc_training') + if not disable_unet_training or disable_textenc_training: + logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES {Fore.RESET}\"disable_unet_training\": true{Fore.LIGHTRED_EX} and {Fore.RESET}\"disable_textenc_training\": false{Fore.LIGHTRED_EX} in your train.json{Fore.RESET}") + raise RuntimeError("Unet training must be disabled and text encoder training enabled") num_te_layers = len(ed_state.text_encoder.text_model.encoder.layers) + optimizer_config: dict = kwargs.get('optimizer_config') if (optimizer_config is None or 'text_encoder_freezing' not in optimizer_config or optimizer_config['text_encoder_freezing'].get('freeze_embeddings') != False or @@ -62,73 +134,148 @@ def get_token_ids(t: str): logging.error(f" * {Fore.LIGHTRED_EX} {json.dumps(required_js_fragment)}{Fore.RESET}") raise RuntimeError("Misconfigured optimizer config") - tokens_to_add = [t['token'] for t in self.config['tokens'] if len(get_token_ids(t['token']))>1] + for token_info in self.config["tokens"]: + start_token = token_info["token"] + vector_length = token_info.get("vector_length", 1) + print(f" * Textual Inversion training on '{start_token}' with vector length {vector_length}") + + + training_tokens, padding_tokens, tokens_to_add, tokens_to_overwrite = ( + _setup_tokens(text_encoder=ed_state.text_encoder, tokenizer=ed_state.tokenizer, token_infos=self.config['tokens'])) + self.padding_tokens = padding_tokens + for trigger_token, padding_tokens in padding_tokens.items(): + this_padding_token_ids = [_get_token_ids(tokenizer, t)[0] for t in padding_tokens] + self.padding_token_ids[trigger_token] = this_padding_token_ids + logging.info( - f" * Textual inversion training adding the following tokens: {tokens_to_add}") - tokens_to_overwrite = [t['token'] for t in self.config['tokens'] if t['token'] not in tokens_to_add] + f" * Textual inversion training adding the following tokens: {sorted(tokens_to_add)}") if any(tokens_to_overwrite): logging.warning(f" * {Fore.LIGHTYELLOW_EX}Textual inversion training overwriting the following tokens: {tokens_to_overwrite}{Fore.RESET}") - num_added_tokens = ed_state.tokenizer.add_tokens(tokens_to_add) - if num_added_tokens != len(tokens_to_add): - raise RuntimeError(f"Tokens not added successfully - tried to add {len(tokens_to_add)} but only added 
{num_added_tokens}") - ed_state.text_encoder.resize_token_embeddings(len(ed_state.tokenizer)) - - added_token_ids = [] + # copy initializer embedding input_embeddings = ed_state.text_encoder.get_input_embeddings() for token_info in self.config['tokens']: - # get newly added token id - t = token_info['token'] - token_ids = get_token_ids(t) - if len(token_ids) != 1: - raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {t}, found {len(token_ids)}") - token_id = token_ids[0] - added_token_ids.append(token_id) - - # copy initializer embedding - initializer_word = token_info['initializer_word'] - initializer_word_token_ids = get_token_ids(initializer_word) - if len(initializer_word_token_ids) != 1: - raise RuntimeError(f"Tokens not added succesfully - initializer word '{initializer_word}' needs " - f"{len(initializer_word_token_ids)} tokens, but only single tokens are supported.") - initializer_word_token_id = initializer_word_token_ids[0] - initializer_embedding = input_embeddings.weight.data[initializer_word_token_id] - input_embeddings.weight.data[token_id] = initializer_embedding - - overwriting_token_ids = [get_token_ids(t)[0] for t in tokens_to_overwrite] + vector_length = token_info.get('vector_length', 1) + # make sure it's very long + initializer_text = None if token_info.get('random_initializer', True) else " ".join([token_info['initializer']] * vector_length) + if initializer_text is None: + reference = input_embeddings.weight[0] + embedding_length = reference.shape[0] + initializer_embedding = torch.rand([vector_length, embedding_length], + dtype=reference.dtype, + device=reference.device) * 0.1 - 0.05 + else: + with torch.no_grad(): + initializer_token_ids_full = ed_state.tokenizer(initializer_text, + truncation=True, + padding="max_length", + max_length=ed_state.tokenizer.model_max_length, + ).input_ids + initializer_embedding_full = ed_state.text_encoder( + torch.tensor(initializer_token_ids_full, device=ed_state.text_encoder.device).unsqueeze(0), output_hidden_states=True + ).last_hidden_state + initializer_embedding = initializer_embedding_full[0][1:vector_length+1] + + trigger_token = token_info['token'] + trigger_and_padding_tokens = [trigger_token] + self.padding_tokens[trigger_token] + for i in range(vector_length): + token_ids = _get_token_ids(tokenizer, trigger_and_padding_tokens[i]) + token_id = token_ids[0] + input_embeddings.weight.data[token_id] = initializer_embedding[i] + self.training_tokens = tokens_to_add + tokens_to_overwrite - self.training_token_ids = added_token_ids + overwriting_token_ids - self.original_text_embeddings = ed_state.text_encoder.get_input_embeddings().weight.data.detach().clone() + self.training_token_ids = [_get_token_ids(tokenizer, t)[0] for t in self.training_tokens] + # get indices of non-training tokens (ie tokens whose grads should be reset to 0 every step) + total_len = len(ed_state.text_encoder.get_input_embeddings().weight) + all_token_ids = torch.arange(total_len, dtype=torch.int) - def on_step_start(self, **kwargs): - batch = kwargs['batch'] - tokens = batch['tokens'] # a torch.stack - self.this_batch_tokens = torch.unique(torch.flatten(tokens)).tolist() + untrained_tokens_working = torch.cat((all_token_ids, torch.tensor(self.training_token_ids, dtype=torch.int))) + uniques, counts = untrained_tokens_working.unique(return_counts=True) + untrained_tokens = uniques[counts == 1] + self.non_training_token_ids = untrained_tokens - def on_step_end(self, **kwargs): + def on_backpropagation(self, **kwargs): + # Zero 
out the gradients for all token embeddings except the newly added + # embeddings for the concept, as we only want to optimize the concept embeddings + index_grads_to_zero = self.non_training_token_ids ed_state: EveryDreamTrainingState = kwargs['ed_state'] + grads = ed_state.text_encoder.get_input_embeddings().weight.grad + #print(f"before zeroing: global sum {torch.sum(grads)}, training sum {torch.sum(grads[self.training_token_ids])}, individual: {grads[self.training_token_ids]}") + grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) + #print(f"after zeroing: global sum {torch.sum(grads)}, training sum {torch.sum(grads[self.training_token_ids])}, individual: {grads[self.training_token_ids]}") - # reset the embeddings that have been touched this step, except the ones we're training, to their original state - with (torch.no_grad()): - embeddings = ed_state.text_encoder.get_input_embeddings() - embeddings_to_restore = [t for t in self.this_batch_tokens if t not in self.training_token_ids] - for t in embeddings_to_restore: - embeddings.weight[t] = self.original_text_embeddings[t] def on_model_save(self, **kwargs): ed_state: EveryDreamTrainingState = kwargs['ed_state'] embeddings = ed_state.text_encoder.get_input_embeddings() - save_folder = kwargs['save_folder'] + save_folder = kwargs['diffusers_save_path'] for token_id, token in zip(self.training_token_ids, self.training_tokens): - _save_embedding(token=token, embedding=embeddings.weight[token_id], save_folder=save_folder) + if token not in self.padding_token_ids: + continue + padding_token_ids = self.padding_token_ids[token] + all_token_ids = [token_id] + padding_token_ids + full_embedding = embeddings.weight[all_token_ids] + _save_embedding(token=token, embedding=full_embedding, save_folder=save_folder) + + def transform_caption(self, caption:str) -> str: + return self.expand_trigger_tokens(caption) + + def modify_sample_prompt(self, prompt: str) -> str: + return self.expand_trigger_tokens(prompt) + + def expand_trigger_tokens(self, caption: str) -> str: + tokens = self.config['tokens'] + # for multi-vector tokens, replace the trigger token with a padded sequence of the correct length. 
+ # eg "hat*" with vector length 3 -> "hat* hat*_pad!!!_1 hat*_pad!!!_2" + for t in tokens: + trigger = t['token'] + replacement = " ".join([trigger] + self.padding_tokens[trigger]) + caption = re.sub(trigger, replacement, caption) + return caption + + def _save_embedding(token, embedding, save_folder): dict_to_save = {token: embedding} - token_name_safe = clean_filename(token) - ti_folder = os.path.join(save_folder, 'textual_inversions') - os.makedirs(ti_folder, exist_ok=True) - save_path = os.path.join(ti_folder, token_name_safe + '.bin') + save_path = _get_embedding_path(save_folder, token) + os.makedirs(os.path.dirname(save_path), exist_ok=True) logging.info(f"Saving textual inversion for '{token}' to {save_path}") torch.save(dict_to_save, save_path) +def _get_embedding_path(save_folder: str, token: str) -> str: + token_name_safe = clean_filename(token) + ti_folder = os.path.join(save_folder, 'textual_inversions') + return os.path.join(ti_folder, token_name_safe + '.bin') + +def _setup_tokens(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, token_infos: list[dict]) -> tuple[set, dict, list, list]: + training_tokens = set() + padding_tokens = {} + for token_info in token_infos: + start_token = token_info['token'] + vector_length = token_info.get('vector_length', 1) + this_padding_tokens = [f"{start_token}_pad!!!_{n + 1}" for n in range(vector_length - 1)] + padding_tokens[start_token] = this_padding_tokens + training_tokens.update([start_token] + this_padding_tokens) + + tokens_to_add = [t for t in training_tokens if len(_get_token_ids(tokenizer, t)) > 1] + tokens_to_overwrite = [t for t in training_tokens if t not in tokens_to_add] + + num_added_tokens = tokenizer.add_tokens(tokens_to_add) + if num_added_tokens != len(tokens_to_add): + raise RuntimeError(f"Tokens not added successfully - tried to add {len(tokens_to_add)} but only added {num_added_tokens}") + text_encoder.resize_token_embeddings(len(tokenizer)) + + added_token_ids = [] + for token in tokens_to_add: + token_ids = _get_token_ids(tokenizer, token) + if len(token_ids) != 1: + raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {token}, found {len(token_ids)}") + token_id = token_ids[0] + added_token_ids.append(token_id) + + return training_tokens, padding_tokens, tokens_to_add, tokens_to_overwrite + + +def _get_token_ids(tokenizer, t: str): + return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(t)) diff --git a/plugins/textual_inversion_loader.json b/plugins/textual_inversion_loader.json new file mode 100644 index 0000000..4773ff1 --- /dev/null +++ b/plugins/textual_inversion_loader.json @@ -0,0 +1,16 @@ + +{ + "documentation": { + "tokens": { + "token": "the trigger token (word or phrase). whenever this word or phrase appears in image captions, the embedding will be trained.", + "path": "(optional) /path/to/embedding.bin. If omitted, tries to load an embedding from the resume_ckpt diffusers folder, textual_inversions/.bin where is the token" + }, + "example": "the example below tries to load textual_inversions/hat*.bin from inside the resume ckpt's textual_inversion folder and textual_inversions/dancing shoes.bin from inside the model folder, and cane from the path specified." 
+ }, + "tokens": [ + { "token": "hat*" }, + { "token": "dancing shoes" }, + { "token": "cane", "path": "/workspace/embeddings/my_cane_embedding_ep30.bin"} + ] + +} diff --git a/train.py b/train.py index 749c5c7..f2d5ee3 100644 --- a/train.py +++ b/train.py @@ -57,6 +57,7 @@ from data.every_dream import EveryDreamBatch, build_torch_dataloader from data.every_dream_validation import EveryDreamValidator from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID +from plugins.plugins import PluginRunner from utils.huggingface_downloader import try_download_model_from_hf from utils.convert_diff_to_ckpt import convert as converter from utils.isolate_rng import isolate_rng @@ -118,12 +119,12 @@ def convert_to_hf(ckpt_path): class EveryDreamTrainingState: def __init__(self, - optimizer: EveryDreamOptimizer, - train_batch: EveryDreamBatch, + optimizer: Optional[EveryDreamOptimizer], + train_batch: Optional[EveryDreamBatch], unet: UNet2DConditionModel, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - scheduler, + scheduler: Optional, vae: AutoencoderKL, unet_ema: Optional[UNet2DConditionModel], text_encoder_ema: Optional[CLIPTextModel] @@ -141,7 +142,7 @@ def __init__(self, @torch.no_grad() def save_model(save_path, ed_state: EveryDreamTrainingState, global_step: int, save_ckpt_dir, yaml_name, - save_full_precision=False, save_optimizer_flag=False, save_ckpt=True): + save_full_precision=False, save_optimizer_flag=False, save_ckpt=True, plugin_runner: PluginRunner=None): """ Save the model to disk """ @@ -167,7 +168,6 @@ def save_ckpt_file(diffusers_model_path, sd_ckpt_path): logging.info(f" * Saving yaml to {yaml_save_path}") shutil.copyfile(yaml_name, yaml_save_path) - if global_step is None or global_step == 0: logging.warning(" No model to save, something likely blew up on startup, not saving") return @@ -217,6 +217,11 @@ def save_ckpt_file(diffusers_model_path, sd_ckpt_path): logging.info(f" Saving optimizer state to {save_path}") ed_state.optimizer.save(save_path) + plugin_runner.run_on_model_save( + ed_state=ed_state, + diffusers_save_path=diffusers_model_path + ) + def setup_local_logger(args): """ @@ -784,6 +789,14 @@ def release_memory(model_to_delete, original_device): from plugins.plugins import PluginRunner plugin_runner = PluginRunner(plugins=plugins) + plugin_runner.run_on_model_load( + ed_state=EveryDreamTrainingState(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, + optimizer=None, train_batch=None, scheduler=noise_scheduler, unet_ema=None, text_encoder_ema=None), + resume_ckpt=args.resume_ckpt, + optimizer_config=optimizer_config, + disable_unet_training=args.disable_unet_training, + disable_textenc_training=args.disable_textenc_training + ) data_loader = DataLoaderMultiAspect( image_train_items=image_train_items, @@ -872,7 +885,7 @@ def sigterm_handler(signum, frame): time.sleep(2) # give opportunity to ctrl-C again to cancel save save_model(interrupted_checkpoint_path, global_step=global_step, ed_state=make_current_ed_state(), save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision, - save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt) + save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt, plugin_runner=plugin_runner) exit(_SIGTERM_EXIT_CODE) else: # non-main threads (i.e. 
dataloader workers) should exit cleanly @@ -891,7 +904,7 @@ def sigterm_handler(signum, frame): train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size) - unet.train() if not args.disable_unet_training else unet.eval() + unet.train() if (args.gradient_checkpointing or not args.disable_unet_training) else unet.eval() text_encoder.train() if not args.disable_textenc_training else text_encoder.eval() logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}") @@ -1054,7 +1067,7 @@ def generate_samples(global_step: int, batch): vae=vae, diffusers_scheduler_config=inference_scheduler.config ).to(device) - sample_generator.generate_samples(inference_pipe, global_step, extra_info=extra_info) + sample_generator.generate_samples(inference_pipe, global_step, extra_info=extra_info, plugin_runner=plugin_runner) # Cleanup del inference_pipe @@ -1169,7 +1182,7 @@ def update_arg(arg: str, newValue): runt_loss_scale = (batch["runt_size"] / args.batch_size)**1.5 # further discount runts by **1.5 loss = loss * runt_loss_scale - ed_optimizer.step(loss, step, global_step) + ed_optimizer.step(loss, step, global_step, plugin_runner=plugin_runner, ed_state=make_current_ed_state()) if args.ema_decay_rate != None: if ((global_step + 1) % args.ema_update_interval) == 0: @@ -1239,7 +1252,7 @@ def update_arg(arg: str, newValue): save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(), save_ckpt_dir=args.save_ckpt_dir, yaml_name=None, save_full_precision=args.save_full_precision, - save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt) + save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt, plugin_runner=plugin_runner) plugin_runner.run_on_step_end(epoch=epoch, global_step=global_step, @@ -1286,7 +1299,7 @@ def update_arg(arg: str, newValue): save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-")) save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(), save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision, - save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt) + save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt, plugin_runner=plugin_runner) total_elapsed_time = time.time() - training_start_time logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}") @@ -1298,7 +1311,7 @@ def update_arg(arg: str, newValue): save_path = make_save_path(epoch, global_step, prepend="errored-") save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(), save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision, - save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt) + save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt, plugin_runner=plugin_runner) logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. 
Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}") raise ex diff --git a/utils/sample_generator.py b/utils/sample_generator.py index ad40123..4fe5371 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -19,6 +19,8 @@ from tqdm.auto import tqdm from compel import Compel +from plugins.plugins import PluginRunner + def clean_filename(filename): """ @@ -184,7 +186,7 @@ def _reload_config_json(self, path): self.sample_requests = self._make_random_caption_sample_requests() @torch.no_grad() - def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int, extra_info: str = ""): + def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int, plugin_runner: PluginRunner, extra_info: str = ""): """ generates samples at different cfg scales and saves them to disk """ @@ -211,8 +213,10 @@ def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool: text_encoder=pipe.text_encoder, use_penultimate_clip_layer=self.use_penultimate_clip_layer) for batch in batches: - prompts = [p.prompt for p in batch] - negative_prompts = [p.negative_prompt for p in batch] + prompts = [plugin_runner.run_modify_sample_prompt(p.prompt) + for p in batch] + negative_prompts = [plugin_runner.run_modify_sample_prompt(p.negative_prompt) + for p in batch] seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30)) for p in batch] # all sizes in a batch are the same
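
The hook surface added in plugins/plugins.py (on_model_load, on_backpropagation, on_model_save, modify_sample_prompt) is enough to write small custom plugins. Below is a minimal sketch of a hypothetical plugin, not part of this patch; the class name and logging are illustrative, and it assumes only the kwargs this diff passes from train.py and optimizers.py (ed_state, resume_ckpt).

import logging

from plugins.plugins import BasePlugin


class EmbeddingGradNormLoggerPlugin(BasePlugin):
    # hypothetical example plugin: logs the text encoder input-embedding gradient
    # norm right after backward(), using the on_backpropagation hook added above

    def on_model_load(self, **kwargs):
        # called once before training starts; resume_ckpt is passed by train.py
        logging.info(f"training starts from {kwargs.get('resume_ckpt')}")

    def on_backpropagation(self, **kwargs):
        # runs after loss.backward() but before the optimizer step, so gradients
        # exist and can still be inspected or modified
        ed_state = kwargs['ed_state']
        grads = ed_state.text_encoder.get_input_embeddings().weight.grad
        if grads is not None:
            logging.info(f"TE embedding grad norm: {grads.norm().item():.4f}")

    def modify_sample_prompt(self, prompt: str) -> str:
        # sample prompts can be rewritten just before sample generation
        return prompt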
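The embeddings written by _save_embedding are plain {token: tensor} dicts with shape [vector_length, hidden_dim]. The sketch below shows one way such a file could be loaded back into a tokenizer/text encoder outside the trainer; the model directory and "hat.bin" filename are illustrative, and the padding-token naming mirrors _setup_tokens above.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# illustrative paths; point these at a diffusers folder saved by the trainer
model_dir = "my_trained_model"
tokenizer = CLIPTokenizer.from_pretrained(model_dir, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_dir, subfolder="text_encoder")

embedding_dict = torch.load(f"{model_dir}/textual_inversions/hat.bin")
token, embedding = next(iter(embedding_dict.items()))  # embedding: [vector_length, hidden_dim]

# one vocab entry per vector, mirroring the trigger + padding-token scheme above
pieces = [token] + [f"{token}_pad!!!_{i + 1}" for i in range(embedding.shape[0] - 1)]
tokenizer.add_tokens(pieces)
text_encoder.resize_token_embeddings(len(tokenizer))
with torch.no_grad():
    for piece, vector in zip(pieces, embedding):
        token_id = tokenizer.convert_tokens_to_ids(piece)
        text_encoder.get_input_embeddings().weight[token_id] = vector

# prompts must then spell out the full sequence, e.g. " ".join(pieces),
# which is what expand_trigger_tokens does during training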