Feat textual inversion plugin v2 #248

Merged
25 changes: 22 additions & 3 deletions optimizer/optimizers.py
@@ -27,6 +27,8 @@
from colorama import Fore, Style
import pprint

from plugins.plugins import PluginRunner

BETAS_DEFAULT = [0.9, 0.999]
EPSILON_DEFAULT = 1e-8
WEIGHT_DECAY_DEFAULT = 0.01
@@ -120,8 +122,9 @@ def _calculate_norm(self, param, p):
else:
return 0.0

def step(self, loss, step, global_step):
def step(self, loss, step, global_step, plugin_runner: PluginRunner, ed_state: 'EveryDreamTrainingState'):
self.scaler.scale(loss).backward()
plugin_runner.run_on_backpropagation(ed_state=ed_state)

if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1):
if self.clip_grad_norm is not None:
@@ -142,6 +145,7 @@ def step(self, loss, step, global_step):
self.log_writer.add_scalar("optimizer/te_grad_norm", te_grad_norm, global_step)

for optimizer in self.optimizers:
# the scaler steps the optimizer on our behalf
self.scaler.step(optimizer)

self.scaler.update()
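Reviewer note: a minimal sketch of how the new `step()` signature might be driven from a training loop, so the extra arguments have context. `plugin_runner` and `ed_state` follow the names in this diff; the loop variables and the `compute_loss` helper are assumptions, not code from this repository.

```python
# Hypothetical call site for the new signature (names other than
# plugin_runner/ed_state are illustrative only).
for step, batch in enumerate(train_dataloader):
    loss = compute_loss(batch)  # assumed helper returning a scalar loss tensor
    ed_optimizer.step(loss, step, global_step,
                      plugin_runner=plugin_runner,
                      ed_state=ed_state)  # plugins see gradients right after backward()
    global_step += 1
```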
@@ -480,6 +484,7 @@ def _create_optimizer(self, label, args, local_optimizer_config, parameters):
def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
num_layers = len(text_encoder.text_model.encoder.layers)
unfreeze_embeddings = True
unfreeze_position_embeddings = True
unfreeze_last_n_layers = None
unfreeze_final_layer_norm = True
if "freeze_front_n_layers" in self.te_freeze_config:
@@ -499,7 +504,6 @@ def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
unfreeze_last_n_layers = num_layers
else:
# something specified:
assert(unfreeze_last_n_layers > 0)
if unfreeze_last_n_layers < num_layers:
# if we're unfreezing layers then by default we ought to freeze the embeddings
unfreeze_embeddings = False
@@ -508,11 +512,13 @@ def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
unfreeze_embeddings = not self.te_freeze_config["freeze_embeddings"]
if "freeze_final_layer_norm" in self.te_freeze_config:
unfreeze_final_layer_norm = not self.te_freeze_config["freeze_final_layer_norm"]
if "freeze_position_embeddings" in self.te_freeze_config:
unfreeze_position_embeddings = not self.te_freeze_config["freeze_position_embeddings"]

parameters = itertools.chain([])

if unfreeze_embeddings:
parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.parameters())
parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.token_embedding.parameters())
else:
print(" ❄️ freezing embeddings")

@@ -530,6 +536,15 @@ def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
else:
print(" ❄️ freezing final layer norm")

if unfreeze_position_embeddings:
parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.position_embedding.parameters())
else:
print(" ❄️ freezing position embeddings")

# freeze everything first, then re-enable requires_grad only on the selected parameters
parameters = list(parameters)
set_requires_grad(text_encoder.parameters(), False)
set_requires_grad(parameters, True)
return parameters


@@ -548,3 +563,7 @@ def log_optimizer(label: str, optimizer: torch.optim.Optimizer, betas, epsilon,
logging.info(f"{Fore.CYAN} * {label} optimizer: {optimizer.__class__.__name__} {param_info} *{Style.RESET_ALL}")
logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")


def set_requires_grad(params, requires_grad: bool):
for param in params:
param.requires_grad = requires_grad
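For reviewers: the keys consumed by `_apply_text_encoder_freeze` (including the `freeze_position_embeddings` option added in this PR), expressed as a Python dict for a quick sketch. Which keys exist comes from the code above; the sample values and the comment on `freeze_front_n_layers` describe assumed usage, not documentation of this repo.

```python
# Sketch of a te_freeze_config that exercises the options read above.
te_freeze_config = {
    "freeze_embeddings": False,           # keep token embeddings trainable
    "freeze_position_embeddings": True,   # new in this PR: freeze position embeddings
    "freeze_front_n_layers": 2,           # freeze the first N encoder layers (assumed meaning)
    "freeze_final_layer_norm": False,
}
```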
35 changes: 32 additions & 3 deletions plugins/plugins.py
@@ -18,10 +18,18 @@ def on_step_start(self, **kwargs):
pass
def on_step_end(self, **kwargs):
pass
def transform_caption(self, caption:str):
def on_model_load(self, **kwargs):
pass
def on_model_save(self, **kwargs):
pass
def on_backpropagation(self, **kwargs):
pass
def transform_caption(self, caption:str) -> str:
return caption
def transform_pil_image(self, img:Image):
def transform_pil_image(self, img:Image) -> Image:
return img
def modify_sample_prompt(self, prompt:str) -> str:
return prompt

def load_plugin(plugin_path):
print(f" - Attempting to load plugin: {plugin_path}")
@@ -92,7 +100,22 @@ def run_on_step_end(self, **kwargs):
for plugin in self.plugins:
with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'):
plugin.on_step_end(**kwargs)


def run_on_backpropagation(self, **kwargs):
for plugin in self.plugins:
with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'):
plugin.on_backpropagation(**kwargs)

def run_on_model_load(self, **kwargs):
for plugin in self.plugins:
with Timer(warn_seconds=self.training_warn_seconds, label=f'{plugin.__class__.__name__}'):
plugin.on_model_load(**kwargs)

def run_on_model_save(self, **kwargs):
for plugin in self.plugins:
with Timer(warn_seconds=self.training_warn_seconds, label=f'{plugin.__class__.__name__}'):
plugin.on_model_save(**kwargs)

def run_transform_caption(self, caption):
with Timer(warn_seconds=self.step_warn_seconds, label="plugin.transform_caption"):
for plugin in self.plugins:
@@ -104,3 +127,9 @@ def run_transform_pil_image(self, img):
for plugin in self.plugins:
img = plugin.transform_pil_image(img)
return img

def run_modify_sample_prompt(self, prompt) -> str:
with Timer(warn_seconds=self.step_warn_seconds, label="plugin.modify_sample_prompt"):
for plugin in self.plugins:
prompt = plugin.modify_sample_prompt(prompt)
return prompt
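To show the new hook surface end to end, here is a minimal sketch of a plugin built against the base class in `plugins/plugins.py` (assumed to be named `BasePlugin`). Only the hook names come from this diff; the keyword arguments each hook receives are decided by the trainer and are treated as opaque here.

```python
from plugins.plugins import BasePlugin


class ExampleLoggingPlugin(BasePlugin):
    """Hypothetical plugin exercising the hooks added in this PR."""

    def on_model_load(self, **kwargs):
        # Called once when the model is loaded; kwargs contents are trainer-defined.
        print(f"model loaded, received kwargs: {sorted(kwargs.keys())}")

    def on_backpropagation(self, **kwargs):
        # Runs right after loss.backward() and before the optimizer step, e.g. a
        # textual-inversion plugin could zero gradients of frozen embedding rows here.
        pass

    def on_model_save(self, **kwargs):
        # A natural place to write plugin-owned artifacts next to the checkpoint.
        pass

    def modify_sample_prompt(self, prompt: str) -> str:
        # Sample prompts can be rewritten before sample images are generated.
        return prompt
```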
16 changes: 16 additions & 0 deletions plugins/textual_inversion.json
@@ -0,0 +1,16 @@

{
"documentation": {
"tokens": {
"token": "the token (word) to train, given exactly as it appears in tokens",
"initializer": "starting point for the embedding. make it close to the intended meaning to kick-start training. should be shorter (in tokens) than the vector_length.",
"vector_length": "length of the embedding (default 1). use more if you have more images and/or if what you want to train is complex."
},
"example": "the example below trains `hat*`, `dancing shoes` and `cane` as custom tokens, if you have training data where the captions include those tokens."
},
"tokens": [
{ "token": "hat*", "initializer": "a man's hat", "vector_length": 8 },
{ "token": "dancing shoes", "initializer": "shoes" },
{ "token": "cane", "initializer": "cane" }
]
}
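As a rough illustration of the `initializer`/`vector_length` relationship documented above, a sketch of seeding a multi-vector token from an initializer phrase. The tokenizer/embedding access pattern follows the usual Hugging Face conventions and is an assumption, not the plugin's actual implementation.

```python
import torch


def build_initial_vectors(initializer: str, vector_length: int,
                          tokenizer, token_embedding: torch.nn.Embedding) -> torch.Tensor:
    """Seed `vector_length` embedding vectors from an initializer phrase (illustrative only)."""
    # Tokenize without special tokens (assumed Hugging Face-style tokenizer).
    ids = tokenizer(initializer, add_special_tokens=False).input_ids
    with torch.no_grad():
        vectors = token_embedding.weight[ids]              # (n_tokens, dim)
    if len(ids) >= vector_length:
        return vectors[:vector_length].clone()             # truncate to vector_length
    # Pad by repeating the last initializer vector until vector_length is reached.
    pad = vectors[-1:].repeat(vector_length - len(ids), 1)
    return torch.cat([vectors, pad], dim=0)
```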