
Commit

fix db + target modules
abhishekkrthakur committed Nov 3, 2023
1 parent 19c9b82 commit 5c9e1f9
Showing 10 changed files with 165 additions and 92 deletions.
4 changes: 3 additions & 1 deletion .dockerignore
@@ -6,4 +6,6 @@ output2/
test/
test.py
.DS_Store
.vscode/
.vscode/
op*
op_*
2 changes: 2 additions & 0 deletions Dockerfile.api
@@ -0,0 +1,2 @@
FROM huggingface/autotrain-advanced:latest
CMD autotrain setup && autotrain api --port 7860 --host 0.0.0.0
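The new Dockerfile.api just wraps the published autotrain-advanced image and starts the API on port 7860. A minimal reachability check from the host might look like the following sketch (it assumes the container's port is published as localhost:7860; which routes the API actually serves is not shown in this commit):

import requests

# Probe the published port. Any HTTP response (even a 404) means the server
# inside the container is up; the specific routes are outside this commit.
try:
    resp = requests.get("http://localhost:7860/", timeout=5)
    print("API container reachable, status:", resp.status_code)
except requests.exceptions.ConnectionError:
    print("API container not reachable on localhost:7860")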
4 changes: 4 additions & 0 deletions src/autotrain/trainers/clm/utils.py
@@ -17,6 +17,10 @@
DEFAULT_UNK_TOKEN = "</s>"
TARGET_MODULES = {
"Salesforce/codegen25-7b-multi": "q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
"HuggingFaceH4/zephyr-7b-beta": "q_proj,v_proj",
"HuggingFaceH4/zephyr-7b-alpha": "q_proj,v_proj",
"mistralai/Mistral-7B-Instruct-v0.1": "q_proj,v_proj",
"mistralai/Mistral-7B-v0.1": "q_proj,v_proj",
}

MODEL_CARD = """
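For context, the comma-separated strings in TARGET_MODULES are per-model LoRA target projections; they are typically split into a list before being passed to a LoRA config. A minimal sketch of that use, assuming PEFT's LoraConfig (the lora_config_for helper and the hyperparameter values are illustrative, not part of this commit):

from peft import LoraConfig

from autotrain.trainers.clm.utils import TARGET_MODULES


def lora_config_for(model_name: str, default: str = "q_proj,v_proj") -> LoraConfig:
    # Look up the per-model override and split the comma-separated string;
    # fall back to a generic default for models not listed in TARGET_MODULES.
    modules = TARGET_MODULES.get(model_name, default).split(",")
    return LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=modules,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )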
191 changes: 116 additions & 75 deletions src/autotrain/trainers/dreambooth/__main__.py
@@ -4,7 +4,6 @@

import diffusers
import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
@@ -13,11 +12,9 @@
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
SlicedAttnAddedKVProcessor,
)
from diffusers.models.lora import LoRALinearLayer
from huggingface_hub import HfApi, snapshot_download

from autotrain import logger
@@ -57,7 +54,8 @@ def train(config):
config.image_path = "/tmp/model/concept1/"

accelerator_project_config = ProjectConfiguration(
project_dir=config.project_name, logging_dir=os.path.join(config.project_name, "logs")
project_dir=config.project_name,
logging_dir=os.path.join(config.project_name, "logs"),
)

if config.fp16:
@@ -110,31 +108,68 @@ def train(config):
utils.enable_xformers(unet, config)
utils.enable_gradient_checkpointing(unet, text_encoders, config)

unet_lora_attn_procs = {}
unet_lora_parameters = []
for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
lora_attn_processor_class = LoRAAttnAddedKVProcessor
else:
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)

# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features,
out_features=attn_module.to_q.out_features,
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features,
out_features=attn_module.to_k.out_features,
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features,
out_features=attn_module.to_v.out_features,
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
)
)

module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
unet_lora_attn_procs[name] = module
unet_lora_parameters.extend(module.parameters())

unet.set_attn_processor(unet_lora_attn_procs)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())

if not config.xl:
if isinstance(
attn_processor,
(
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
),
):
attn_module.add_k_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_k_proj.in_features,
out_features=attn_module.add_k_proj.out_features,
)
)
attn_module.add_v_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_v_proj.in_features,
out_features=attn_module.add_v_proj.out_features,
)
)
unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
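The loop above replaces the old LoRAAttnProcessor-based setup with LoRALinearLayer adapters attached directly to the attention projections (to_q, to_k, to_v, to_out, plus add_k_proj/add_v_proj for the added-KV processors in the non-XL case). A quick sanity check on the result could look like this sketch, which reuses the diff's variable names but is not part of the commit:

# Count adapted projections and trainable LoRA parameters collected above.
num_lora_params = sum(p.numel() for p in unet_lora_parameters)
num_adapted = sum(1 for m in unet.modules() if getattr(m, "lora_layer", None) is not None)
logger.info(f"{num_adapted} projections adapted, {num_lora_params:,} trainable LoRA parameters")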

text_lora_parameters = []
if config.train_text_encoder:
@@ -160,30 +195,38 @@ def save_model_hook(models, weights, output_dir):
# make sure to pop weight so that corresponding model is not saved again
weights.pop()

if len(text_encoder_lora_layers_to_save) == 0:
LoraLoaderMixin.save_lora_weights(
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=None,
safe_serialization=True,
)
elif len(text_encoder_lora_layers_to_save) == 1:
LoraLoaderMixin.save_lora_weights(
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save[0],
safe_serialization=True,
)
elif len(text_encoder_lora_layers_to_save) == 2:
StableDiffusionXLPipeline.save_lora_weights(
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save[0],
text_encoder_2_lora_layers=text_encoder_lora_layers_to_save[1],
safe_serialization=True,
)
if config.xl:
if len(text_encoder_lora_layers_to_save) > 1:
StableDiffusionXLPipeline.save_lora_weights(
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save[0],
text_encoder_2_lora_layers=text_encoder_lora_layers_to_save[1],
safe_serialization=True,
)
else:
StableDiffusionXLPipeline.save_lora_weights(
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=None,
text_encoder_2_lora_layers=None,
safe_serialization=True,
)
else:
raise ValueError("unexpected number of text encoders")
if len(text_encoder_lora_layers_to_save) > 0:
LoraLoaderMixin.save_lora_weights(
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save[0],
safe_serialization=True,
)
else:
LoraLoaderMixin.save_lora_weights(
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=None,
safe_serialization=True,
)
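For later inference, the weights written by either branch can be loaded back through the pipeline's LoRA loader. A minimal sketch for the SDXL case (the base model id, output path, and prompt are placeholders, not taken from this commit):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/output_dir")  # directory written by save_lora_weights above
image = pipe("a photo of sks dog").images[0]
image.save("lora_sample.png")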

def load_model_hook(models, input_dir):
unet_ = None
@@ -198,34 +241,29 @@ def load_model_hook(models, input_dir):
if isinstance(model, type(accelerator.unwrap_model(_text_encoder))):
text_encoders_.append(model)

lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alphas, unet=unet_)

if len(text_encoders_) == 0:
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict,
network_alpha=network_alpha,
text_encoder=None,
)
elif len(text_encoders_) == 1:
if config.xl:
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict,
network_alpha=network_alpha,
text_encoder_state_dict,
network_alphas=network_alphas,
text_encoder=text_encoders_[0],
)
elif len(text_encoders_) == 2:

text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict,
network_alpha=network_alpha,
text_encoder=text_encoders_[0],
text_encoder_2_state_dict,
network_alphas=network_alphas,
text_encoder=text_encoders_[1],
)
else:
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict,
network_alpha=network_alpha,
text_encoder=text_encoders_[1],
network_alphas=network_alphas,
text_encoder=text_encoders_[0],
)
else:
raise ValueError("unexpected number of text encoders")

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
@@ -276,12 +314,12 @@ def load_model_hook(models, input_dir):
# first check if file exists
if os.path.exists(f"{config.project_name}/training_params.json"):
training_params = json.load(open(f"{config.project_name}/training_params.json"))
training_params.pop("token")
json.dump(training_params, open(f"{config.project_name}/training_params.json", "w"))

# remove config.image_path directory if it exists
if os.path.exists(config.image_path):
os.system(f"rm -rf {config.image_path}")
if "token" in training_params:
training_params.pop("token")
json.dump(
training_params,
open(f"{config.project_name}/training_params.json", "w"),
)

# add config.prompt as a text file in the output directory
with open(f"{config.project_name}/prompt.txt", "w") as f:
@@ -291,6 +329,9 @@ def load_model_hook(models, input_dir):
trainer.push_to_hub()

if "SPACE_ID" in os.environ:
# remove config.image_path directory if it exists
if os.path.exists(config.image_path):
os.system(f"rm -rf {config.image_path}")
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
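The remainder of this hunk is collapsed in the view above. For reference, pausing a Space through huggingface_hub generally looks like the following sketch (not the commit's exact code; it reuses the config.token and SPACE_ID environment variable that appear in the surrounding diff):

import os

from huggingface_hub import HfApi

api = HfApi(token=config.token)  # same client as constructed above
api.pause_space(repo_id=os.environ["SPACE_ID"])  # pause the Space this job runs in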
(The remaining changed files are not rendered in this view.)
