From 3840e4811a954d518379754037101e1035193943 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sun, 7 Jan 2024 21:08:52 +1300 Subject: [PATCH 01/13] zero the grads rather than resetting the weights --- plugins/textual_inversion.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/plugins/textual_inversion.py b/plugins/textual_inversion.py index 147750f..a6413b4 100644 --- a/plugins/textual_inversion.py +++ b/plugins/textual_inversion.py @@ -42,6 +42,7 @@ def __init__(self): self.training_tokens = None self.training_token_ids = None self.original_text_embeddings = None + self.textual_inversion_tokens_only_grads = None def on_model_load(self, **kwargs): ed_state: EveryDreamTrainingState = kwargs.get('ed_state') @@ -109,12 +110,21 @@ def on_step_start(self, **kwargs): def on_step_end(self, **kwargs): ed_state: EveryDreamTrainingState = kwargs['ed_state'] - # reset the embeddings that have been touched this step, except the ones we're training, to their original state - with (torch.no_grad()): - embeddings = ed_state.text_encoder.get_input_embeddings() - embeddings_to_restore = [t for t in self.this_batch_tokens if t not in self.training_token_ids] - for t in embeddings_to_restore: - embeddings.weight[t] = self.original_text_embeddings[t] + # Zero out the gradients for all token embeddings except the newly added + # embeddings for the concept, as we only want to optimize the concept embeddings + grads = ed_state.text_encoder.get_input_embeddings().weight.grad + # Get the index for tokens that we want to zero the grads for + index_grads_to_zero = torch.arange(len(grads)) != placeholder_token_id + grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) + + grads = ed_state.text_encoder.get_input_embeddings().weight.grad + if self.textual_inversion_tokens_only_grads is None: + self.textual_inversion_tokens_only_grads = torch.zeros_like(grads) + for t in self.training_token_ids: + self.textual_inversion_tokens_only_grads[t] = grads[t] + + ed_state.text_encoder.get_input_embeddings().weight.grad = self.textual_inversion_tokens_only_grads + def on_model_save(self, **kwargs): ed_state: EveryDreamTrainingState = kwargs['ed_state'] From df4647e4ab6800f86b0692fd21170a59d058ddff Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 8 Jan 2024 20:00:39 +1300 Subject: [PATCH 02/13] wip long vector length --- plugins/plugins.py | 12 +++++++++++- plugins/textual_inversion.json | 16 ++++++++++++++++ plugins/textual_inversion.py | 1 + train.py | 12 +++++++++++- 4 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 plugins/textual_inversion.json diff --git a/plugins/plugins.py b/plugins/plugins.py index 029bde4..8153482 100644 --- a/plugins/plugins.py +++ b/plugins/plugins.py @@ -92,7 +92,17 @@ def run_on_step_end(self, **kwargs): for plugin in self.plugins: with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'): plugin.on_step_end(**kwargs) - + + def run_on_model_load(self, **kwargs): + for plugin in self.plugins: + with Timer(warn_seconds=self.epoch_warn_seconds, label=f'{plugin.__class__.__name__}'): + plugin.on_model_load(**kwargs) + + def run_on_model_save(self, **kwargs): + for plugin in self.plugins: + with Timer(warn_seconds=self.epoch_warn_seconds, label=f'{plugin.__class__.__name__}'): + plugin.on_model_save(**kwargs) + def run_transform_caption(self, caption): with Timer(warn_seconds=self.step_warn_seconds, label="plugin.transform_caption"): for plugin in self.plugins: diff --git 
a/plugins/textual_inversion.json b/plugins/textual_inversion.json new file mode 100644 index 0000000..02576cc --- /dev/null +++ b/plugins/textual_inversion.json @@ -0,0 +1,16 @@ + +{ + "documentation": { + "tokens": { + "token": "the token (word) to train, given exactly as it appears in tokens", + "initializer_word": "starting point for the embedding. make it close to the intended meaning to kick-start training", + "vector_length": "length of the embedding. use more if you have more images and/or if what you want to train is complex" + }, + "example": "the example below trains `hat*`, `dancing shoes` and `cane` as custom tokens, if you have training data where the captions include those tokens." + }, + "tokens": [ + { "token": "hat*", "initializer_word": "hat", "vector_length": 8 }, + { "token": "dancing shoes", "initializer_word": "shoes" }, + { "token": "cane", "initializer_word": "cane" } + ] +} diff --git a/plugins/textual_inversion.py b/plugins/textual_inversion.py index a6413b4..4c9fc8f 100644 --- a/plugins/textual_inversion.py +++ b/plugins/textual_inversion.py @@ -133,6 +133,7 @@ def on_model_save(self, **kwargs): for token_id, token in zip(self.training_token_ids, self.training_tokens): _save_embedding(token=token, embedding=embeddings.weight[token_id], save_folder=save_folder) + def _save_embedding(token, embedding, save_folder): dict_to_save = {token: embedding} token_name_safe = clean_filename(token) diff --git a/train.py b/train.py index 749c5c7..a2a5988 100644 --- a/train.py +++ b/train.py @@ -57,6 +57,7 @@ from data.every_dream import EveryDreamBatch, build_torch_dataloader from data.every_dream_validation import EveryDreamValidator from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID +from plugins.plugins import PluginRunner from utils.huggingface_downloader import try_download_model_from_hf from utils.convert_diff_to_ckpt import convert as converter from utils.isolate_rng import isolate_rng @@ -141,7 +142,7 @@ def __init__(self, @torch.no_grad() def save_model(save_path, ed_state: EveryDreamTrainingState, global_step: int, save_ckpt_dir, yaml_name, - save_full_precision=False, save_optimizer_flag=False, save_ckpt=True): + save_full_precision=False, save_optimizer_flag=False, save_ckpt=True, plugin_runner: PluginRunner=None): """ Save the model to disk """ @@ -188,6 +189,11 @@ def save_ckpt_file(diffusers_model_path, sd_ckpt_path): logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}") pipeline_ema.save_pretrained(diffusers_model_path) + plugin_runner.run_on_model_save( + ed_state=ed_state, + diffusers_save_path=diffusers_model_path + ) + if save_ckpt: sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.safetensors" @@ -784,6 +790,10 @@ def release_memory(model_to_delete, original_device): from plugins.plugins import PluginRunner plugin_runner = PluginRunner(plugins=plugins) + plugin_runner.run_on_model_load( + ed_state=EveryDreamTrainingState(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae), + optimizer_config=optimizer_config + ) data_loader = DataLoaderMultiAspect( image_train_items=image_train_items, From 4ff972c67cf6d3b024035df23d11ab5abfbff245 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 15 Jan 2024 18:47:00 +1300 Subject: [PATCH 03/13] ti >1 vector lenght support (untested) --- optimizer/optimizers.py | 11 +++- plugins/textual_inversion.json | 10 ++-- plugins/textual_inversion.py | 94 +++++++++++++++++++++------------- train.py | 9 ++-- 4 files changed, 78 insertions(+), 46 deletions(-) diff --git 
a/optimizer/optimizers.py b/optimizer/optimizers.py index fa0ae78..d942968 100644 --- a/optimizer/optimizers.py +++ b/optimizer/optimizers.py @@ -480,6 +480,7 @@ def _create_optimizer(self, label, args, local_optimizer_config, parameters): def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]: num_layers = len(text_encoder.text_model.encoder.layers) unfreeze_embeddings = True + unfreeze_position_embeddings = True unfreeze_last_n_layers = None unfreeze_final_layer_norm = True if "freeze_front_n_layers" in self.te_freeze_config: @@ -499,7 +500,6 @@ def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]: unfreeze_last_n_layers = num_layers else: # something specified: - assert(unfreeze_last_n_layers > 0) if unfreeze_last_n_layers < num_layers: # if we're unfreezing layers then by default we ought to freeze the embeddings unfreeze_embeddings = False @@ -508,11 +508,13 @@ def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]: unfreeze_embeddings = not self.te_freeze_config["freeze_embeddings"] if "freeze_final_layer_norm" in self.te_freeze_config: unfreeze_final_layer_norm = not self.te_freeze_config["freeze_final_layer_norm"] + if "freeze_position_embeddings" in self.te_freeze_config: + unfreeze_position_embeddings = not self.te_freeze_config["freeze_position_embeddings"] parameters = itertools.chain([]) if unfreeze_embeddings: - parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.parameters()) + parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.token_embedding.parameters()) else: print(" ❄️ freezing embeddings") @@ -530,6 +532,11 @@ def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]: else: print(" ❄️ freezing final layer norm") + if unfreeze_position_embeddings: + parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.position_embeddings.parameters) + else: + print(" ❄️ freezing position embeddings") + return parameters diff --git a/plugins/textual_inversion.json b/plugins/textual_inversion.json index 02576cc..7b21167 100644 --- a/plugins/textual_inversion.json +++ b/plugins/textual_inversion.json @@ -3,14 +3,14 @@ "documentation": { "tokens": { "token": "the token (word) to train, given exactly as it appears in tokens", - "initializer_word": "starting point for the embedding. make it close to the intended meaning to kick-start training", - "vector_length": "length of the embedding. use more if you have more images and/or if what you want to train is complex" + "initializer": "starting point for the embedding. make it close to the intended meaning to kick-start training. should be shorter (in tokens) than the vector_length.", + "vector_length": "length of the embedding (default 1). use more if you have more images and/or if what you want to train is complex." }, "example": "the example below trains `hat*`, `dancing shoes` and `cane` as custom tokens, if you have training data where the captions include those tokens." 
}, "tokens": [ - { "token": "hat*", "initializer_word": "hat", "vector_length": 8 }, - { "token": "dancing shoes", "initializer_word": "shoes" }, - { "token": "cane", "initializer_word": "cane" } + { "token": "hat*", "initializer": "a man's hat", "vector_length": 8 }, + { "token": "dancing shoes", "initializer": "shoes" }, + { "token": "cane", "initializer": "cane" } ] } diff --git a/plugins/textual_inversion.py b/plugins/textual_inversion.py index 4c9fc8f..79d0544 100644 --- a/plugins/textual_inversion.py +++ b/plugins/textual_inversion.py @@ -4,6 +4,7 @@ import torch from colorama import Fore +import re from plugins.plugins import BasePlugin from train import EveryDreamTrainingState @@ -38,10 +39,10 @@ def __init__(self): logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}") with open(path, 'rt') as f: self.config = json.load(f) - self.this_batch_tokens = None + self.training_tokens = None self.training_token_ids = None - self.original_text_embeddings = None + self.padding_tokens = {} self.textual_inversion_tokens_only_grads = None def on_model_load(self, **kwargs): @@ -63,10 +64,20 @@ def get_token_ids(t: str): logging.error(f" * {Fore.LIGHTRED_EX} {json.dumps(required_js_fragment)}{Fore.RESET}") raise RuntimeError("Misconfigured optimizer config") - tokens_to_add = [t['token'] for t in self.config['tokens'] if len(get_token_ids(t['token']))>1] + # new - multi-vector support + training_tokens = set() + for token_info in self.config['tokens']: + start_token = token_info['token'] + vector_length = token_info.get('vector_length', 1) + this_padding_tokens = [f"{start_token}_pad!!!_{n+1}" for n in range(vector_length-1)] + self.padding_tokens[start_token] = this_padding_tokens + training_tokens.update([start_token] + this_padding_tokens) + # end new - multi vector support + + tokens_to_add = [t for t in training_tokens if len(get_token_ids(t))>1] logging.info( f" * Textual inversion training adding the following tokens: {tokens_to_add}") - tokens_to_overwrite = [t['token'] for t in self.config['tokens'] if t['token'] not in tokens_to_add] + tokens_to_overwrite = [t for t in training_tokens if t not in tokens_to_add] if any(tokens_to_overwrite): logging.warning(f" * {Fore.LIGHTYELLOW_EX}Textual inversion training overwriting the following tokens: {tokens_to_overwrite}{Fore.RESET}") @@ -76,55 +87,58 @@ def get_token_ids(t: str): ed_state.text_encoder.resize_token_embeddings(len(ed_state.tokenizer)) added_token_ids = [] - input_embeddings = ed_state.text_encoder.get_input_embeddings() - for token_info in self.config['tokens']: - # get newly added token id - t = token_info['token'] - token_ids = get_token_ids(t) + for token in tokens_to_add: + token_ids = get_token_ids(token) if len(token_ids) != 1: raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {t}, found {len(token_ids)}") token_id = token_ids[0] added_token_ids.append(token_id) - # copy initializer embedding - initializer_word = token_info['initializer_word'] - initializer_word_token_ids = get_token_ids(initializer_word) - if len(initializer_word_token_ids) != 1: - raise RuntimeError(f"Tokens not added succesfully - initializer word '{initializer_word}' needs " - f"{len(initializer_word_token_ids)} tokens, but only single tokens are supported.") - initializer_word_token_id = initializer_word_token_ids[0] - initializer_embedding = input_embeddings.weight.data[initializer_word_token_id] - input_embeddings.weight.data[token_id] = initializer_embedding + # copy initializer embedding + 
input_embeddings = ed_state.text_encoder.get_input_embeddings() + for token_info in self.config['tokens']: + vector_length = token_info.get('vector_length', 1) + initializer_text = token_info['initializer'] + with torch.no_grad(): + initializer_token_ids_full = ed_state.tokenizer(initializer_text, + truncation=True, + padding="max_length", + max_length=ed_state.tokenizer.model_max_length, + ).input_ids + initializer_embedding_full = ed_state.text_encoder( + torch.tensor(initializer_token_ids_full).unsqueeze(0), output_hidden_states=True + ).last_hidden_state + initializer_embedding = initializer_embedding_full[0][1:vector_length+1] + + trigger_token = token_info['token'] + trigger_and_padding_tokens = [trigger_token] + self.padding_tokens[trigger_token] + for i in range(vector_length): + token_ids = get_token_ids(trigger_and_padding_tokens[i]) + token_id = token_ids[0] + input_embeddings.weight.data[token_id] = initializer_embedding[i] overwriting_token_ids = [get_token_ids(t)[0] for t in tokens_to_overwrite] self.training_tokens = tokens_to_add + tokens_to_overwrite self.training_token_ids = added_token_ids + overwriting_token_ids - self.original_text_embeddings = ed_state.text_encoder.get_input_embeddings().weight.data.detach().clone() + # get indices of non-training tokens (ie tokens whose grads should be reset to 0 every step) + total_len = len(ed_state.text_encoder.get_input_embeddings().weight) + all_token_ids = torch.arange(total_len, dtype=torch.int) - def on_step_start(self, **kwargs): - batch = kwargs['batch'] - tokens = batch['tokens'] # a torch.stack - self.this_batch_tokens = torch.unique(torch.flatten(tokens)).tolist() + untrained_tokens_working = torch.cat((all_token_ids, torch.tensor(self.training_token_ids, dtype=torch.int))) + uniques, counts = untrained_tokens_working.unique(return_counts=True) + untrained_tokens = uniques[counts == 1] + self.non_training_token_ids = untrained_tokens - def on_step_end(self, **kwargs): - ed_state: EveryDreamTrainingState = kwargs['ed_state'] + def on_step_end(self, **kwargs): # Zero out the gradients for all token embeddings except the newly added # embeddings for the concept, as we only want to optimize the concept embeddings + index_grads_to_zero = self.non_training_token_ids + ed_state: EveryDreamTrainingState = kwargs['ed_state'] grads = ed_state.text_encoder.get_input_embeddings().weight.grad - # Get the index for tokens that we want to zero the grads for - index_grads_to_zero = torch.arange(len(grads)) != placeholder_token_id grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) - grads = ed_state.text_encoder.get_input_embeddings().weight.grad - if self.textual_inversion_tokens_only_grads is None: - self.textual_inversion_tokens_only_grads = torch.zeros_like(grads) - for t in self.training_token_ids: - self.textual_inversion_tokens_only_grads[t] = grads[t] - - ed_state.text_encoder.get_input_embeddings().weight.grad = self.textual_inversion_tokens_only_grads - def on_model_save(self, **kwargs): ed_state: EveryDreamTrainingState = kwargs['ed_state'] @@ -133,6 +147,16 @@ def on_model_save(self, **kwargs): for token_id, token in zip(self.training_token_ids, self.training_tokens): _save_embedding(token=token, embedding=embeddings.weight[token_id], save_folder=save_folder) + def transform_caption(self, caption:str): + tokens = self.config['tokens'] + # for multi-vector tokens, replace the trigger token with a padded sequence of the correct length. 
+ # eg "hat*" with vector length 3 -> "hat* hat*_pad!!!_1 hat*_pad!!!_2" + for t in tokens: + trigger = t['token'] + replacement = " ".join([trigger] + self.training_tokens[trigger]) + caption = re.sub(trigger, replacement, caption) + return caption + def _save_embedding(token, embedding, save_folder): dict_to_save = {token: embedding} diff --git a/train.py b/train.py index a2a5988..4861ce5 100644 --- a/train.py +++ b/train.py @@ -119,12 +119,12 @@ def convert_to_hf(ckpt_path): class EveryDreamTrainingState: def __init__(self, - optimizer: EveryDreamOptimizer, - train_batch: EveryDreamBatch, + optimizer: Optional[EveryDreamOptimizer], + train_batch: Optional[EveryDreamBatch], unet: UNet2DConditionModel, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - scheduler, + scheduler: Optional, vae: AutoencoderKL, unet_ema: Optional[UNet2DConditionModel], text_encoder_ema: Optional[CLIPTextModel] @@ -791,7 +791,8 @@ def release_memory(model_to_delete, original_device): from plugins.plugins import PluginRunner plugin_runner = PluginRunner(plugins=plugins) plugin_runner.run_on_model_load( - ed_state=EveryDreamTrainingState(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae), + ed_state=EveryDreamTrainingState(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, + optimizer=None, train_batch=None, scheduler=noise_scheduler, unet_ema=None, text_encoder_ema=None), optimizer_config=optimizer_config ) From e90d163571882ae5bfb33e21ffe20a018949be8e Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 15 Jan 2024 19:23:52 +1300 Subject: [PATCH 04/13] fix issues --- plugins/textual_inversion.py | 13 ++++++++++--- train.py | 11 +++++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/plugins/textual_inversion.py b/plugins/textual_inversion.py index 79d0544..fbade7d 100644 --- a/plugins/textual_inversion.py +++ b/plugins/textual_inversion.py @@ -43,6 +43,7 @@ def __init__(self): self.training_tokens = None self.training_token_ids = None self.padding_tokens = {} + self.padding_token_ids = {} self.textual_inversion_tokens_only_grads = None def on_model_load(self, **kwargs): @@ -64,7 +65,6 @@ def get_token_ids(t: str): logging.error(f" * {Fore.LIGHTRED_EX} {json.dumps(required_js_fragment)}{Fore.RESET}") raise RuntimeError("Misconfigured optimizer config") - # new - multi-vector support training_tokens = set() for token_info in self.config['tokens']: start_token = token_info['token'] @@ -72,7 +72,7 @@ def get_token_ids(t: str): this_padding_tokens = [f"{start_token}_pad!!!_{n+1}" for n in range(vector_length-1)] self.padding_tokens[start_token] = this_padding_tokens training_tokens.update([start_token] + this_padding_tokens) - # end new - multi vector support + print(f"textual inversion training: token sequence for {start_token} is \"{' '.join([start_token] + this_padding_tokens)}\"") tokens_to_add = [t for t in training_tokens if len(get_token_ids(t))>1] logging.info( @@ -94,6 +94,10 @@ def get_token_ids(t: str): token_id = token_ids[0] added_token_ids.append(token_id) + for trigger_token, padding_tokens in self.padding_tokens.items(): + this_padding_token_ids = [get_token_ids(t)[0] for t in padding_tokens] + self.padding_token_ids[trigger_token] = this_padding_token_ids + # copy initializer embedding input_embeddings = ed_state.text_encoder.get_input_embeddings() for token_info in self.config['tokens']: @@ -145,7 +149,10 @@ def on_model_save(self, **kwargs): embeddings = ed_state.text_encoder.get_input_embeddings() save_folder = kwargs['save_folder'] 
for token_id, token in zip(self.training_token_ids, self.training_tokens): - _save_embedding(token=token, embedding=embeddings.weight[token_id], save_folder=save_folder) + padding_token_ids = self.padding_token_ids[token] + all_token_ids = [token_id] + padding_token_ids + full_embedding = embeddings.weight[all_token_ids] + _save_embedding(token=token, embedding=full_embedding, save_folder=save_folder) def transform_caption(self, caption:str): tokens = self.config['tokens'] diff --git a/train.py b/train.py index 4861ce5..cff9943 100644 --- a/train.py +++ b/train.py @@ -168,7 +168,6 @@ def save_ckpt_file(diffusers_model_path, sd_ckpt_path): logging.info(f" * Saving yaml to {yaml_save_path}") shutil.copyfile(yaml_name, yaml_save_path) - if global_step is None or global_step == 0: logging.warning(" No model to save, something likely blew up on startup, not saving") return @@ -189,11 +188,6 @@ def save_ckpt_file(diffusers_model_path, sd_ckpt_path): logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}") pipeline_ema.save_pretrained(diffusers_model_path) - plugin_runner.run_on_model_save( - ed_state=ed_state, - diffusers_save_path=diffusers_model_path - ) - if save_ckpt: sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.safetensors" @@ -223,6 +217,11 @@ def save_ckpt_file(diffusers_model_path, sd_ckpt_path): logging.info(f" Saving optimizer state to {save_path}") ed_state.optimizer.save(save_path) + plugin_runner.run_on_model_save( + ed_state=ed_state, + diffusers_save_path=diffusers_model_path + ) + def setup_local_logger(args): """ From bf5b03205e044c2a94a385278545cabbac1a3ed9 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 15 Jan 2024 21:08:12 +1300 Subject: [PATCH 05/13] fix saving, more strict config check --- plugins/textual_inversion.json | 6 +++++- plugins/textual_inversion.py | 25 ++++++++++++++++++------- train.py | 12 +++++++----- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/plugins/textual_inversion.json b/plugins/textual_inversion.json index 7b21167..3511db5 100644 --- a/plugins/textual_inversion.json +++ b/plugins/textual_inversion.json @@ -8,9 +8,13 @@ }, "example": "the example below trains `hat*`, `dancing shoes` and `cane` as custom tokens, if you have training data where the captions include those tokens." }, - "tokens": [ + "example_tokens": [ { "token": "hat*", "initializer": "a man's hat", "vector_length": 8 }, { "token": "dancing shoes", "initializer": "shoes" }, { "token": "cane", "initializer": "cane" } + ], + "tokens": [ + { "token": "ay hairstyle", "initializer": "long red hair, elaborately plaited", "vector_length": 18 }, + { "token": "hz outfit", "initializer": "skins, furs, and pieces of ceramic, plastic, and alloy plating", "vector_length": 18 } ] } diff --git a/plugins/textual_inversion.py b/plugins/textual_inversion.py index fbade7d..a7811a4 100644 --- a/plugins/textual_inversion.py +++ b/plugins/textual_inversion.py @@ -26,7 +26,8 @@ "text_encoder_freezing": { "unfreeze_last_n_layers": 0, "freeze_embeddings": false, - "freeze_final_layer_norm": true + "freeze_final_layer_norm": true, + "freeze_position_embeddings": true } In addition, you'll need a very high LR on the TE - maybe even as high as 1e-3. I recommend using the LR finder method. 
@@ -48,12 +49,17 @@ def __init__(self): def on_model_load(self, **kwargs): ed_state: EveryDreamTrainingState = kwargs.get('ed_state') - optimizer_config: dict = kwargs.get('optimizer_config') def get_token_ids(t: str): return ed_state.tokenizer.convert_tokens_to_ids(ed_state.tokenizer.tokenize(t)) # check for correctly configured text encoder training + disable_unet_training: bool = kwargs.get('disable_unet_training') + disable_textenc_training: bool = kwargs.get('disable_textenc_training') + if not disable_unet_training or disable_textenc_training: + logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES {Fore.RESET}\"disable_unet_training\": true{Fore.LIGHTRED_EX} and {Fore.RESET}\"disable_textenc_training\": false{Fore.LIGHTRED_EX} in your train.json{Fore.RESET}") + raise RuntimeError("Unet training must be disabled and text encoder training enabled") num_te_layers = len(ed_state.text_encoder.text_model.encoder.layers) + optimizer_config: dict = kwargs.get('optimizer_config') if (optimizer_config is None or 'text_encoder_freezing' not in optimizer_config or optimizer_config['text_encoder_freezing'].get('freeze_embeddings') != False or @@ -65,18 +71,21 @@ def get_token_ids(t: str): logging.error(f" * {Fore.LIGHTRED_EX} {json.dumps(required_js_fragment)}{Fore.RESET}") raise RuntimeError("Misconfigured optimizer config") + training_tokens = set() for token_info in self.config['tokens']: start_token = token_info['token'] vector_length = token_info.get('vector_length', 1) + print(f" * Textual Inversion training on '{start_token}' with vector length {vector_length}") this_padding_tokens = [f"{start_token}_pad!!!_{n+1}" for n in range(vector_length-1)] self.padding_tokens[start_token] = this_padding_tokens training_tokens.update([start_token] + this_padding_tokens) - print(f"textual inversion training: token sequence for {start_token} is \"{' '.join([start_token] + this_padding_tokens)}\"") + if vector_length > 1: + print(f" - if you want accurate samples for trigger '{start_token}', replace it in sample prompts with the following text: \"{' '.join([start_token] + this_padding_tokens)}\"") tokens_to_add = [t for t in training_tokens if len(get_token_ids(t))>1] logging.info( - f" * Textual inversion training adding the following tokens: {tokens_to_add}") + f" * Textual inversion training adding the following tokens: {sorted(tokens_to_add)}") tokens_to_overwrite = [t for t in training_tokens if t not in tokens_to_add] if any(tokens_to_overwrite): logging.warning(f" * {Fore.LIGHTYELLOW_EX}Textual inversion training overwriting the following tokens: {tokens_to_overwrite}{Fore.RESET}") @@ -110,7 +119,7 @@ def get_token_ids(t: str): max_length=ed_state.tokenizer.model_max_length, ).input_ids initializer_embedding_full = ed_state.text_encoder( - torch.tensor(initializer_token_ids_full).unsqueeze(0), output_hidden_states=True + torch.tensor(initializer_token_ids_full, device=ed_state.text_encoder.device).unsqueeze(0), output_hidden_states=True ).last_hidden_state initializer_embedding = initializer_embedding_full[0][1:vector_length+1] @@ -147,8 +156,10 @@ def on_step_end(self, **kwargs): def on_model_save(self, **kwargs): ed_state: EveryDreamTrainingState = kwargs['ed_state'] embeddings = ed_state.text_encoder.get_input_embeddings() - save_folder = kwargs['save_folder'] + save_folder = kwargs['diffusers_save_path'] for token_id, token in zip(self.training_token_ids, self.training_tokens): + if token not in self.padding_token_ids: + continue padding_token_ids = 
self.padding_token_ids[token] all_token_ids = [token_id] + padding_token_ids full_embedding = embeddings.weight[all_token_ids] @@ -160,7 +171,7 @@ def transform_caption(self, caption:str): # eg "hat*" with vector length 3 -> "hat* hat*_pad!!!_1 hat*_pad!!!_2" for t in tokens: trigger = t['token'] - replacement = " ".join([trigger] + self.training_tokens[trigger]) + replacement = " ".join([trigger] + self.padding_tokens[trigger]) caption = re.sub(trigger, replacement, caption) return caption diff --git a/train.py b/train.py index cff9943..7dfae05 100644 --- a/train.py +++ b/train.py @@ -792,7 +792,9 @@ def release_memory(model_to_delete, original_device): plugin_runner.run_on_model_load( ed_state=EveryDreamTrainingState(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, optimizer=None, train_batch=None, scheduler=noise_scheduler, unet_ema=None, text_encoder_ema=None), - optimizer_config=optimizer_config + optimizer_config=optimizer_config, + disable_unet_training=args.disable_unet_training, + disable_textenc_training=args.disable_textenc_training ) data_loader = DataLoaderMultiAspect( @@ -882,7 +884,7 @@ def sigterm_handler(signum, frame): time.sleep(2) # give opportunity to ctrl-C again to cancel save save_model(interrupted_checkpoint_path, global_step=global_step, ed_state=make_current_ed_state(), save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision, - save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt) + save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt, plugin_runner=plugin_runner) exit(_SIGTERM_EXIT_CODE) else: # non-main threads (i.e. dataloader workers) should exit cleanly @@ -1249,7 +1251,7 @@ def update_arg(arg: str, newValue): save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(), save_ckpt_dir=args.save_ckpt_dir, yaml_name=None, save_full_precision=args.save_full_precision, - save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt) + save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt, plugin_runner=plugin_runner) plugin_runner.run_on_step_end(epoch=epoch, global_step=global_step, @@ -1296,7 +1298,7 @@ def update_arg(arg: str, newValue): save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-")) save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(), save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision, - save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt) + save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt, plugin_runner=plugin_runner) total_elapsed_time = time.time() - training_start_time logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}") @@ -1308,7 +1310,7 @@ def update_arg(arg: str, newValue): save_path = make_save_path(epoch, global_step, prepend="errored-") save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(), save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision, - save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt) + save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt, plugin_runner=plugin_runner) logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. 
Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}") raise ex From a7f3b0a146662c292d287329e1a71872af19b142 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Tue, 16 Jan 2024 11:20:37 +1300 Subject: [PATCH 06/13] zero non-training embedding grads in the correct place --- optimizer/optimizers.py | 14 +++++++++++++- plugins/plugins.py | 7 +++++++ plugins/textual_inversion.py | 14 ++++++++------ train.py | 4 ++-- 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/optimizer/optimizers.py b/optimizer/optimizers.py index d942968..2a1bf80 100644 --- a/optimizer/optimizers.py +++ b/optimizer/optimizers.py @@ -27,6 +27,8 @@ from colorama import Fore, Style import pprint +from plugins.plugins import PluginRunner + BETAS_DEFAULT = [0.9, 0.999] EPSILON_DEFAULT = 1e-8 WEIGHT_DECAY_DEFAULT = 0.01 @@ -120,8 +122,9 @@ def _calculate_norm(self, param, p): else: return 0.0 - def step(self, loss, step, global_step): + def step(self, loss, step, global_step, plugin_runner: PluginRunner, ed_state: 'EveryDreamTrainingState'): self.scaler.scale(loss).backward() + plugin_runner.run_on_backpropagation(ed_state=ed_state) if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1): if self.clip_grad_norm is not None: @@ -142,6 +145,7 @@ def step(self, loss, step, global_step): self.log_writer.add_scalar("optimizer/te_grad_norm", te_grad_norm, global_step) for optimizer in self.optimizers: + # the scaler steps the optimizer on our behalf self.scaler.step(optimizer) self.scaler.update() @@ -537,6 +541,10 @@ def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]: else: print(" ❄️ freezing position embeddings") + # make sure there's some requires_grad in some places + parameters = list(parameters) + set_requires_grad(text_encoder.parameters(), False) + set_requires_grad(parameters, True) return parameters @@ -555,3 +563,7 @@ def log_optimizer(label: str, optimizer: torch.optim.Optimizer, betas, epsilon, logging.info(f"{Fore.CYAN} * {label} optimizer: {optimizer.__class__.__name__} {param_info} *{Style.RESET_ALL}") logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}") + +def set_requires_grad(params, requires_grad: bool): + for param in params: + param.requires_grad = requires_grad \ No newline at end of file diff --git a/plugins/plugins.py b/plugins/plugins.py index 8153482..48cfbcd 100644 --- a/plugins/plugins.py +++ b/plugins/plugins.py @@ -18,6 +18,8 @@ def on_step_start(self, **kwargs): pass def on_step_end(self, **kwargs): pass + def on_will_step_optimizer(self, **kwargs): + pass def transform_caption(self, caption:str): return caption def transform_pil_image(self, img:Image): @@ -93,6 +95,11 @@ def run_on_step_end(self, **kwargs): with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'): plugin.on_step_end(**kwargs) + def run_on_backpropagation(self, **kwargs): + for plugin in self.plugins: + with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'): + plugin.on_backpropagation(**kwargs) + def run_on_model_load(self, **kwargs): for plugin in self.plugins: with Timer(warn_seconds=self.epoch_warn_seconds, label=f'{plugin.__class__.__name__}'): diff --git a/plugins/textual_inversion.py b/plugins/textual_inversion.py index a7811a4..942f2e6 100644 --- a/plugins/textual_inversion.py +++ b/plugins/textual_inversion.py @@ -55,9 +55,9 @@ def get_token_ids(t: str): # check for correctly configured text encoder training 
disable_unet_training: bool = kwargs.get('disable_unet_training') disable_textenc_training: bool = kwargs.get('disable_textenc_training') - if not disable_unet_training or disable_textenc_training: - logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES {Fore.RESET}\"disable_unet_training\": true{Fore.LIGHTRED_EX} and {Fore.RESET}\"disable_textenc_training\": false{Fore.LIGHTRED_EX} in your train.json{Fore.RESET}") - raise RuntimeError("Unet training must be disabled and text encoder training enabled") + #if not disable_unet_training or disable_textenc_training: + # logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES {Fore.RESET}\"disable_unet_training\": true{Fore.LIGHTRED_EX} and {Fore.RESET}\"disable_textenc_training\": false{Fore.LIGHTRED_EX} in your train.json{Fore.RESET}") + # raise RuntimeError("Unet training must be disabled and text encoder training enabled") num_te_layers = len(ed_state.text_encoder.text_model.encoder.layers) optimizer_config: dict = kwargs.get('optimizer_config') if (optimizer_config is None or @@ -111,7 +111,8 @@ def get_token_ids(t: str): input_embeddings = ed_state.text_encoder.get_input_embeddings() for token_info in self.config['tokens']: vector_length = token_info.get('vector_length', 1) - initializer_text = token_info['initializer'] + # make sure it's very long + initializer_text = " ".join([token_info['initializer']] * vector_length) with torch.no_grad(): initializer_token_ids_full = ed_state.tokenizer(initializer_text, truncation=True, @@ -143,14 +144,15 @@ def get_token_ids(t: str): untrained_tokens = uniques[counts == 1] self.non_training_token_ids = untrained_tokens - - def on_step_end(self, **kwargs): + def on_backpropagation(self, **kwargs): # Zero out the gradients for all token embeddings except the newly added # embeddings for the concept, as we only want to optimize the concept embeddings index_grads_to_zero = self.non_training_token_ids ed_state: EveryDreamTrainingState = kwargs['ed_state'] grads = ed_state.text_encoder.get_input_embeddings().weight.grad + #print(f"before zeroing: global sum {torch.sum(grads)}, training sum {torch.sum(grads[self.training_token_ids])}, individual: {grads[self.training_token_ids]}") grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) + #print(f"after zeroing: global sum {torch.sum(grads)}, training sum {torch.sum(grads[self.training_token_ids])}, individual: {grads[self.training_token_ids]}") def on_model_save(self, **kwargs): diff --git a/train.py b/train.py index 7dfae05..186e1e6 100644 --- a/train.py +++ b/train.py @@ -903,7 +903,7 @@ def sigterm_handler(signum, frame): train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size) - unet.train() if not args.disable_unet_training else unet.eval() + unet.train() if (args.gradient_checkpointing or not args.disable_unet_training) else unet.eval() text_encoder.train() if not args.disable_textenc_training else text_encoder.eval() logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}") @@ -1181,7 +1181,7 @@ def update_arg(arg: str, newValue): runt_loss_scale = (batch["runt_size"] / args.batch_size)**1.5 # further discount runts by **1.5 loss = loss * runt_loss_scale - ed_optimizer.step(loss, step, global_step) + ed_optimizer.step(loss, step, global_step, plugin_runner=plugin_runner, ed_state=make_current_ed_state()) if args.ema_decay_rate != None: if ((global_step + 1) % args.ema_update_interval) == 0: From 
0a3259ae5b499e0f430e77653f6bd1cf2f101ecc Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Fri, 19 Jan 2024 16:34:13 +1300 Subject: [PATCH 07/13] don't clobber trained embeddings when resuming --- plugins/textual_inversion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/plugins/textual_inversion.py b/plugins/textual_inversion.py index 942f2e6..508bbf3 100644 --- a/plugins/textual_inversion.py +++ b/plugins/textual_inversion.py @@ -129,7 +129,9 @@ def get_token_ids(t: str): for i in range(vector_length): token_ids = get_token_ids(trigger_and_padding_tokens[i]) token_id = token_ids[0] - input_embeddings.weight.data[token_id] = initializer_embedding[i] + # don't clobber trained embeddings when resuming + if token_id in tokens_to_add: + input_embeddings.weight.data[token_id] = initializer_embedding[i] overwriting_token_ids = [get_token_ids(t)[0] for t in tokens_to_overwrite] self.training_tokens = tokens_to_add + tokens_to_overwrite From 635e04fe05de6820141346c1b38dbdb3b7ebff08 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Fri, 19 Jan 2024 18:46:39 +1300 Subject: [PATCH 08/13] add modify_sample_prompt plugin callback --- plugins/plugins.py | 16 ++++++++++++---- plugins/textual_inversion.py | 13 +++++++++---- utils/sample_generator.py | 10 +++++++--- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/plugins/plugins.py b/plugins/plugins.py index 48cfbcd..643e162 100644 --- a/plugins/plugins.py +++ b/plugins/plugins.py @@ -20,10 +20,12 @@ def on_step_end(self, **kwargs): pass def on_will_step_optimizer(self, **kwargs): pass - def transform_caption(self, caption:str): + def transform_caption(self, caption:str) -> str: return caption - def transform_pil_image(self, img:Image): + def transform_pil_image(self, img:Image) -> Image: return img + def modify_sample_prompt(self, prompt:str) -> str: + return prompt def load_plugin(plugin_path): print(f" - Attempting to load plugin: {plugin_path}") @@ -102,12 +104,12 @@ def run_on_backpropagation(self, **kwargs): def run_on_model_load(self, **kwargs): for plugin in self.plugins: - with Timer(warn_seconds=self.epoch_warn_seconds, label=f'{plugin.__class__.__name__}'): + with Timer(warn_seconds=self.training_warn_seconds, label=f'{plugin.__class__.__name__}'): plugin.on_model_load(**kwargs) def run_on_model_save(self, **kwargs): for plugin in self.plugins: - with Timer(warn_seconds=self.epoch_warn_seconds, label=f'{plugin.__class__.__name__}'): + with Timer(warn_seconds=self.training_warn_seconds, label=f'{plugin.__class__.__name__}'): plugin.on_model_save(**kwargs) def run_transform_caption(self, caption): @@ -121,3 +123,9 @@ def run_transform_pil_image(self, img): for plugin in self.plugins: img = plugin.transform_pil_image(img) return img + + def run_modify_sample_prompt(self, prompt) -> str: + with Timer(warn_seconds=self.step_warn_seconds, label="plugin.modify_sample_prompt"): + for plugin in self.plugins: + prompt = plugin.modify_sample_prompt(prompt) + return prompt diff --git a/plugins/textual_inversion.py b/plugins/textual_inversion.py index 508bbf3..98b721f 100644 --- a/plugins/textual_inversion.py +++ b/plugins/textual_inversion.py @@ -80,8 +80,6 @@ def get_token_ids(t: str): this_padding_tokens = [f"{start_token}_pad!!!_{n+1}" for n in range(vector_length-1)] self.padding_tokens[start_token] = this_padding_tokens training_tokens.update([start_token] + this_padding_tokens) - if vector_length > 1: - print(f" - if you want accurate samples for trigger '{start_token}', replace it in sample prompts with the 
following text: \"{' '.join([start_token] + this_padding_tokens)}\"") tokens_to_add = [t for t in training_tokens if len(get_token_ids(t))>1] logging.info( @@ -99,7 +97,7 @@ def get_token_ids(t: str): for token in tokens_to_add: token_ids = get_token_ids(token) if len(token_ids) != 1: - raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {t}, found {len(token_ids)}") + raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {token}, found {len(token_ids)}") token_id = token_ids[0] added_token_ids.append(token_id) @@ -169,7 +167,13 @@ def on_model_save(self, **kwargs): full_embedding = embeddings.weight[all_token_ids] _save_embedding(token=token, embedding=full_embedding, save_folder=save_folder) - def transform_caption(self, caption:str): + def transform_caption(self, caption:str) -> str: + return self.expand_trigger_tokens(caption) + + def modify_sample_prompt(self, prompt: str) -> str: + return self.expand_trigger_tokens(prompt) + + def expand_trigger_tokens(self, caption: str) -> str: tokens = self.config['tokens'] # for multi-vector tokens, replace the trigger token with a padded sequence of the correct length. # eg "hat*" with vector length 3 -> "hat* hat*_pad!!!_1 hat*_pad!!!_2" @@ -180,6 +184,7 @@ def transform_caption(self, caption:str): return caption + def _save_embedding(token, embedding, save_folder): dict_to_save = {token: embedding} token_name_safe = clean_filename(token) diff --git a/utils/sample_generator.py b/utils/sample_generator.py index ad40123..4fe5371 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -19,6 +19,8 @@ from tqdm.auto import tqdm from compel import Compel +from plugins.plugins import PluginRunner + def clean_filename(filename): """ @@ -184,7 +186,7 @@ def _reload_config_json(self, path): self.sample_requests = self._make_random_caption_sample_requests() @torch.no_grad() - def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int, extra_info: str = ""): + def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int, plugin_runner: PluginRunner, extra_info: str = ""): """ generates samples at different cfg scales and saves them to disk """ @@ -211,8 +213,10 @@ def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool: text_encoder=pipe.text_encoder, use_penultimate_clip_layer=self.use_penultimate_clip_layer) for batch in batches: - prompts = [p.prompt for p in batch] - negative_prompts = [p.negative_prompt for p in batch] + prompts = [plugin_runner.run_modify_sample_prompt(p.prompt) + for p in batch] + negative_prompts = [plugin_runner.run_modify_sample_prompt(p.negative_prompt) + for p in batch] seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30)) for p in batch] # all sizes in a batch are the same From 6205f7a5741c3fad726ea48c2abdccc4cd986dd4 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Fri, 19 Jan 2024 22:40:33 +1300 Subject: [PATCH 09/13] actually pass through the plugin_runner --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 186e1e6..7886a60 100644 --- a/train.py +++ b/train.py @@ -1066,7 +1066,7 @@ def generate_samples(global_step: int, batch): vae=vae, diffusers_scheduler_config=inference_scheduler.config ).to(device) - sample_generator.generate_samples(inference_pipe, global_step, extra_info=extra_info) + sample_generator.generate_samples(inference_pipe, global_step, extra_info=extra_info, plugin_runner=plugin_runner) # Cleanup del inference_pipe From 
1cb8ff2ed891421743cc50311bb8674cd9088e84 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sat, 20 Jan 2024 02:00:04 +1300 Subject: [PATCH 10/13] fix example --- plugins/textual_inversion.json | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/plugins/textual_inversion.json b/plugins/textual_inversion.json index 3511db5..7b21167 100644 --- a/plugins/textual_inversion.json +++ b/plugins/textual_inversion.json @@ -8,13 +8,9 @@ }, "example": "the example below trains `hat*`, `dancing shoes` and `cane` as custom tokens, if you have training data where the captions include those tokens." }, - "example_tokens": [ + "tokens": [ { "token": "hat*", "initializer": "a man's hat", "vector_length": 8 }, { "token": "dancing shoes", "initializer": "shoes" }, { "token": "cane", "initializer": "cane" } - ], - "tokens": [ - { "token": "ay hairstyle", "initializer": "long red hair, elaborately plaited", "vector_length": 18 }, - { "token": "hz outfit", "initializer": "skins, furs, and pieces of ceramic, plastic, and alloy plating", "vector_length": 18 } ] } From 0dcf0e8c365f557ce2366ce7224c2cac30d2ae64 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 22 Jan 2024 11:02:01 +1300 Subject: [PATCH 11/13] add missing base plugin defs and fix docs --- plugins/plugins.py | 4 ++++ plugins/textual_inversion.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/plugins/plugins.py b/plugins/plugins.py index 643e162..fdf032b 100644 --- a/plugins/plugins.py +++ b/plugins/plugins.py @@ -18,6 +18,10 @@ def on_step_start(self, **kwargs): pass def on_step_end(self, **kwargs): pass + def on_model_load(self, **kwargs): + pass + def on_model_save(self, **kwargs): + pass def on_will_step_optimizer(self, **kwargs): pass def transform_caption(self, caption:str) -> str: diff --git a/plugins/textual_inversion.py b/plugins/textual_inversion.py index 98b721f..21ff658 100644 --- a/plugins/textual_inversion.py +++ b/plugins/textual_inversion.py @@ -29,10 +29,11 @@ "freeze_final_layer_norm": true, "freeze_position_embeddings": true } -In addition, you'll need a very high LR on the TE - maybe even as high as 1e-3. I recommend using the LR finder method. +In addition, you'll need a very high LR on the TE - maybe even as high as 5e-2. I recommend using the LR finder method. 
""" + class TextualInversionPlugin(BasePlugin): def __init__(self): From 072c2a695aba98683a07756f276f75f4dfcdb068 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 22 Jan 2024 21:01:55 +1300 Subject: [PATCH 12/13] fix missing base method for on_backpropagation --- plugins/plugins.py | 2 +- plugins/textual_inversion_loader.json | 0 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 plugins/textual_inversion_loader.json diff --git a/plugins/plugins.py b/plugins/plugins.py index fdf032b..1e8f6d0 100644 --- a/plugins/plugins.py +++ b/plugins/plugins.py @@ -22,7 +22,7 @@ def on_model_load(self, **kwargs): pass def on_model_save(self, **kwargs): pass - def on_will_step_optimizer(self, **kwargs): + def on_backpropagation(self, **kwargs): pass def transform_caption(self, caption:str) -> str: return caption diff --git a/plugins/textual_inversion_loader.json b/plugins/textual_inversion_loader.json new file mode 100644 index 0000000..e69de29 From 1a4ac2d3396737abc64227a7ec0f6adb447de824 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 24 Jan 2024 21:45:46 +1300 Subject: [PATCH 13/13] add TextualInversionLoaderPlugin --- plugins/textual_inversion.py | 190 +++++++++++++++++++------- plugins/textual_inversion_loader.json | 16 +++ train.py | 1 + 3 files changed, 154 insertions(+), 53 deletions(-) diff --git a/plugins/textual_inversion.py b/plugins/textual_inversion.py index 21ff658..8b506a7 100644 --- a/plugins/textual_inversion.py +++ b/plugins/textual_inversion.py @@ -6,6 +6,8 @@ from colorama import Fore import re +from transformers import CLIPTextModel, CLIPTokenizer + from plugins.plugins import BasePlugin from train import EveryDreamTrainingState from utils.sample_generator import clean_filename @@ -33,6 +35,67 @@ """ +class TextualInversionLoaderPlugin(BasePlugin): + def __init__(self): + path = os.path.join(os.path.dirname(__file__), "textual_inversion_loader.json") + logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}") + with open(path, 'rt') as f: + self.config = json.load(f) + self.padding_tokens = {} + + def on_model_load(self, **kwargs): + ed_state: EveryDreamTrainingState = kwargs['ed_state'] + resume_ckpt: str = kwargs['resume_ckpt'] + #self.original_tokens_length = len(ed_state.tokenizer) + tokenizer = ed_state.tokenizer + + token_config = self.config['tokens'] + + embeddings = {} + for token_info in self.config['tokens']: + token = token_info["token"] + path = token_info.get("path", None) or _get_embedding_path(resume_ckpt, token) + with open(path, "rb") as f: + embedding_dict = torch.load(f) + embedding = list(embedding_dict.values())[0] + embeddings[token] = embedding + token_info["vector_length"] = embedding.shape[0] + print(f" * Textual Inversion Loader loaded embedding with vector length {token_info['vector_length']} for token '{token}' from {path}") + + training_tokens, padding_tokens, tokens_to_add, tokens_to_overwrite = ( + _setup_tokens(text_encoder=ed_state.text_encoder, tokenizer=ed_state.tokenizer, token_infos=token_config)) + self.padding_tokens = padding_tokens + + input_embeddings = ed_state.text_encoder.get_input_embeddings() + for token_info in self.config['tokens']: + token = token_info["token"] + vector_length = token_info["vector_length"] + trigger_and_padding_tokens = [token] + padding_tokens[token] + embedding = embeddings[token] + for i in range(vector_length): + token_ids = _get_token_ids(tokenizer, trigger_and_padding_tokens[i]) + token_id = token_ids[0] + input_embeddings.weight.data[token_id] = 
embedding[i] + + + def transform_caption(self, caption:str) -> str: + return self.expand_trigger_tokens(caption) + + def modify_sample_prompt(self, prompt: str) -> str: + return self.expand_trigger_tokens(prompt) + + def expand_trigger_tokens(self, caption: str) -> str: + tokens = self.config['tokens'] + # for multi-vector tokens, replace the trigger token with a padded sequence of the correct length. + # eg "hat*" with vector length 3 -> "hat* hat*_pad!!!_1 hat*_pad!!!_2" + for t in tokens: + trigger = t['token'] + replacement = " ".join([trigger] + self.padding_tokens[trigger]) + caption = re.sub(trigger, replacement, caption) + return caption + + + class TextualInversionPlugin(BasePlugin): @@ -50,15 +113,14 @@ def __init__(self): def on_model_load(self, **kwargs): ed_state: EveryDreamTrainingState = kwargs.get('ed_state') - def get_token_ids(t: str): - return ed_state.tokenizer.convert_tokens_to_ids(ed_state.tokenizer.tokenize(t)) + tokenizer = ed_state.tokenizer # check for correctly configured text encoder training disable_unet_training: bool = kwargs.get('disable_unet_training') disable_textenc_training: bool = kwargs.get('disable_textenc_training') - #if not disable_unet_training or disable_textenc_training: - # logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES {Fore.RESET}\"disable_unet_training\": true{Fore.LIGHTRED_EX} and {Fore.RESET}\"disable_textenc_training\": false{Fore.LIGHTRED_EX} in your train.json{Fore.RESET}") - # raise RuntimeError("Unet training must be disabled and text encoder training enabled") + if not disable_unet_training or disable_textenc_training: + logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES {Fore.RESET}\"disable_unet_training\": true{Fore.LIGHTRED_EX} and {Fore.RESET}\"disable_textenc_training\": false{Fore.LIGHTRED_EX} in your train.json{Fore.RESET}") + raise RuntimeError("Unet training must be disabled and text encoder training enabled") num_te_layers = len(ed_state.text_encoder.text_model.encoder.layers) optimizer_config: dict = kwargs.get('optimizer_config') if (optimizer_config is None or @@ -72,69 +134,57 @@ def get_token_ids(t: str): logging.error(f" * {Fore.LIGHTRED_EX} {json.dumps(required_js_fragment)}{Fore.RESET}") raise RuntimeError("Misconfigured optimizer config") - - training_tokens = set() - for token_info in self.config['tokens']: - start_token = token_info['token'] - vector_length = token_info.get('vector_length', 1) + for token_info in self.config["tokens"]: + start_token = token_info["token"] + vector_length = token_info.get("vector_length", 1) print(f" * Textual Inversion training on '{start_token}' with vector length {vector_length}") - this_padding_tokens = [f"{start_token}_pad!!!_{n+1}" for n in range(vector_length-1)] - self.padding_tokens[start_token] = this_padding_tokens - training_tokens.update([start_token] + this_padding_tokens) - tokens_to_add = [t for t in training_tokens if len(get_token_ids(t))>1] + + training_tokens, padding_tokens, tokens_to_add, tokens_to_overwrite = ( + _setup_tokens(text_encoder=ed_state.text_encoder, tokenizer=ed_state.tokenizer, token_infos=self.config['tokens'])) + self.padding_tokens = padding_tokens + for trigger_token, padding_tokens in padding_tokens.items(): + this_padding_token_ids = [_get_token_ids(tokenizer, t)[0] for t in padding_tokens] + self.padding_token_ids[trigger_token] = this_padding_token_ids + logging.info( f" * Textual inversion training adding the following tokens: {sorted(tokens_to_add)}") - tokens_to_overwrite = [t for t in 
training_tokens if t not in tokens_to_add] if any(tokens_to_overwrite): logging.warning(f" * {Fore.LIGHTYELLOW_EX}Textual inversion training overwriting the following tokens: {tokens_to_overwrite}{Fore.RESET}") - num_added_tokens = ed_state.tokenizer.add_tokens(tokens_to_add) - if num_added_tokens != len(tokens_to_add): - raise RuntimeError(f"Tokens not added successfully - tried to add {len(tokens_to_add)} but only added {num_added_tokens}") - ed_state.text_encoder.resize_token_embeddings(len(ed_state.tokenizer)) - - added_token_ids = [] - for token in tokens_to_add: - token_ids = get_token_ids(token) - if len(token_ids) != 1: - raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {token}, found {len(token_ids)}") - token_id = token_ids[0] - added_token_ids.append(token_id) - - for trigger_token, padding_tokens in self.padding_tokens.items(): - this_padding_token_ids = [get_token_ids(t)[0] for t in padding_tokens] - self.padding_token_ids[trigger_token] = this_padding_token_ids - # copy initializer embedding input_embeddings = ed_state.text_encoder.get_input_embeddings() for token_info in self.config['tokens']: vector_length = token_info.get('vector_length', 1) # make sure it's very long - initializer_text = " ".join([token_info['initializer']] * vector_length) - with torch.no_grad(): - initializer_token_ids_full = ed_state.tokenizer(initializer_text, - truncation=True, - padding="max_length", - max_length=ed_state.tokenizer.model_max_length, - ).input_ids - initializer_embedding_full = ed_state.text_encoder( - torch.tensor(initializer_token_ids_full, device=ed_state.text_encoder.device).unsqueeze(0), output_hidden_states=True - ).last_hidden_state - initializer_embedding = initializer_embedding_full[0][1:vector_length+1] + initializer_text = None if token_info.get('random_initializer', True) else " ".join([token_info['initializer']] * vector_length) + if initializer_text is None: + reference = input_embeddings.weight[0] + embedding_length = reference.shape[0] + initializer_embedding = torch.rand([vector_length, embedding_length], + dtype=reference.dtype, + device=reference.device) * 0.1 - 0.05 + else: + with torch.no_grad(): + initializer_token_ids_full = ed_state.tokenizer(initializer_text, + truncation=True, + padding="max_length", + max_length=ed_state.tokenizer.model_max_length, + ).input_ids + initializer_embedding_full = ed_state.text_encoder( + torch.tensor(initializer_token_ids_full, device=ed_state.text_encoder.device).unsqueeze(0), output_hidden_states=True + ).last_hidden_state + initializer_embedding = initializer_embedding_full[0][1:vector_length+1] trigger_token = token_info['token'] trigger_and_padding_tokens = [trigger_token] + self.padding_tokens[trigger_token] for i in range(vector_length): - token_ids = get_token_ids(trigger_and_padding_tokens[i]) + token_ids = _get_token_ids(tokenizer, trigger_and_padding_tokens[i]) token_id = token_ids[0] - # don't clobber trained embeddings when resuming - if token_id in tokens_to_add: - input_embeddings.weight.data[token_id] = initializer_embedding[i] + input_embeddings.weight.data[token_id] = initializer_embedding[i] - overwriting_token_ids = [get_token_ids(t)[0] for t in tokens_to_overwrite] self.training_tokens = tokens_to_add + tokens_to_overwrite - self.training_token_ids = added_token_ids + overwriting_token_ids + self.training_token_ids = [_get_token_ids(tokenizer, t)[0] for t in self.training_tokens] # get indices of non-training tokens (ie tokens whose grads should be reset to 0 every step) total_len 
= len(ed_state.text_encoder.get_input_embeddings().weight) @@ -188,10 +238,44 @@ def expand_trigger_tokens(self, caption: str) -> str: def _save_embedding(token, embedding, save_folder): dict_to_save = {token: embedding} - token_name_safe = clean_filename(token) - ti_folder = os.path.join(save_folder, 'textual_inversions') - os.makedirs(ti_folder, exist_ok=True) - save_path = os.path.join(ti_folder, token_name_safe + '.bin') + save_path = _get_embedding_path(save_folder, token) + os.makedirs(os.path.dirname(save_path), exist_ok=True) logging.info(f"Saving textual inversion for '{token}' to {save_path}") torch.save(dict_to_save, save_path) +def _get_embedding_path(save_folder: str, token: str) -> str: + token_name_safe = clean_filename(token) + ti_folder = os.path.join(save_folder, 'textual_inversions') + return os.path.join(ti_folder, token_name_safe + '.bin') + +def _setup_tokens(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, token_infos: list[dict]) -> tuple[set, dict, list, list]: + training_tokens = set() + padding_tokens = {} + for token_info in token_infos: + start_token = token_info['token'] + vector_length = token_info.get('vector_length', 1) + this_padding_tokens = [f"{start_token}_pad!!!_{n + 1}" for n in range(vector_length - 1)] + padding_tokens[start_token] = this_padding_tokens + training_tokens.update([start_token] + this_padding_tokens) + + tokens_to_add = [t for t in training_tokens if len(_get_token_ids(tokenizer, t)) > 1] + tokens_to_overwrite = [t for t in training_tokens if t not in tokens_to_add] + + num_added_tokens = tokenizer.add_tokens(tokens_to_add) + if num_added_tokens != len(tokens_to_add): + raise RuntimeError(f"Tokens not added successfully - tried to add {len(tokens_to_add)} but only added {num_added_tokens}") + text_encoder.resize_token_embeddings(len(tokenizer)) + + added_token_ids = [] + for token in tokens_to_add: + token_ids = _get_token_ids(tokenizer, token) + if len(token_ids) != 1: + raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {token}, found {len(token_ids)}") + token_id = token_ids[0] + added_token_ids.append(token_id) + + return training_tokens, padding_tokens, tokens_to_add, tokens_to_overwrite + + +def _get_token_ids(tokenizer, t: str): + return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(t)) diff --git a/plugins/textual_inversion_loader.json b/plugins/textual_inversion_loader.json index e69de29..4773ff1 100644 --- a/plugins/textual_inversion_loader.json +++ b/plugins/textual_inversion_loader.json @@ -0,0 +1,16 @@ + +{ + "documentation": { + "tokens": { + "token": "the trigger token (word or phrase). whenever this word or phrase appears in image captions, the embedding will be trained.", + "path": "(optional) /path/to/embedding.bin. If omitted, tries to load an embedding from the resume_ckpt diffusers folder, textual_inversions/.bin where is the token" + }, + "example": "the example below tries to load textual_inversions/hat*.bin from inside the resume ckpt's textual_inversion folder and textual_inversions/dancing shoes.bin from inside the model folder, and cane from the path specified." 
+ }, + "tokens": [ + { "token": "hat*" }, + { "token": "dancing shoes" }, + { "token": "cane", "path": "/workspace/embeddings/my_cane_embedding_ep30.bin"} + ] + +} diff --git a/train.py b/train.py index 7886a60..f2d5ee3 100644 --- a/train.py +++ b/train.py @@ -792,6 +792,7 @@ def release_memory(model_to_delete, original_device): plugin_runner.run_on_model_load( ed_state=EveryDreamTrainingState(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, optimizer=None, train_batch=None, scheduler=noise_scheduler, unet_ema=None, text_encoder_ema=None), + resume_ckpt=args.resume_ckpt, optimizer_config=optimizer_config, disable_unet_training=args.disable_unet_training, disable_textenc_training=args.disable_textenc_training