Commit 1fbe518

Merge pull request #372 from sillsdev/lora_weight_tying_update
Lora weight tying update
isaac091 authored Apr 26, 2024
2 parents 43d83c9 + 9504d8a commit 1fbe518
Showing 1 changed file with 95 additions and 28 deletions.
123 changes: 95 additions & 28 deletions silnlp/nmt/hugging_face_config.py
@@ -755,30 +755,7 @@ def train(self) -> None:
         )

         if self._config.train["use_lora"]:
-            lora_config = self._config.train["lora_config"]
-
-            if isinstance(lora_config.get("target_modules", []), str):
-                lora_config["target_modules"] = lora_config["target_modules"].split(",")
-            if isinstance(lora_config.get("modules_to_save", []), str):
-                lora_config["modules_to_save"] = lora_config["modules_to_save"].split(",")
-
-            peft_config = LoraConfig(
-                task_type=TaskType.SEQ_2_SEQ_LM,
-                r=lora_config.get("r", 4),
-                lora_alpha=lora_config.get("alpha", 32),
-                lora_dropout=lora_config.get("dropout", 0.1),
-                target_modules=lora_config.get(
-                    "target_modules", LORA_DEFAULT_CONFIGS[self._config.model_prefix]["target_modules"]
-                ),
-                modules_to_save=lora_config.get(
-                    "modules_to_save", LORA_DEFAULT_CONFIGS[self._config.model_prefix]["modules_to_save"]
-                ),
-            )
-            model = get_peft_model(model, peft_config)
-
-            # Necessary to allow gradients to propagate through frozen layers when using PEFT + gradient checkpointing + Trainer
-            if self._config.train["gradient_checkpointing"]:
-                model.enable_input_require_grads()
+            model = self._convert_to_lora_model(model)

         # Change specific variables based on the type of model
         model, tokenizer = self._configure_model(model, tokenizer)
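
For orientation, the LoRA settings consumed here (and in _convert_to_lora_model below) all come from self._config.train. A minimal sketch of the expected shape, with illustrative values only (the module names are examples, not the LORA_DEFAULT_CONFIGS defaults):

train_settings = {
    "use_lora": True,
    "gradient_checkpointing": True,
    "lora_config": {
        "r": 4,          # LoRA rank; 4 is the fallback used in the code
        "alpha": 32,     # LoRA scaling factor; fallback 32
        "dropout": 0.1,  # LoRA dropout; fallback 0.1
        # Lists or comma-separated strings are accepted; strings are split on ",".
        "target_modules": "q_proj,v_proj",
        "modules_to_save": ["embed_tokens", "lm_head"],
    },
}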
@@ -1111,6 +1088,89 @@ def _get_dictionary(self) -> Dict[VerseRef, Set[str]]:

         return self._dictionary

+    # Tie embedding weights to "shared" module weights and untie the embeddings modules if necessary
+    def _create_tied_embedding_weights(self, model: PreTrainedModel) -> PreTrainedModel:
+        encoder_embeddings = torch.nn.Embedding(
+            model.config.vocab_size, model.config.d_model, model.config.pad_token_id
+        )
+        decoder_embeddings = torch.nn.Embedding(
+            model.config.vocab_size, model.config.d_model, model.config.pad_token_id
+        )
+
+        if self._config.model_prefix == "facebook/nllb-200":
+            model.base_model.encoder.embed_tokens = encoder_embeddings
+            model.base_model.decoder.embed_tokens = decoder_embeddings
+            model.tie_weights()
+        elif self._config.model_prefix == "google/madlad400":
+            model.encoder.embed_tokens = encoder_embeddings
+            model.decoder.embed_tokens = decoder_embeddings
+            model._tie_or_clone_weights(model.encoder.embed_tokens, model.shared)
+            model._tie_or_clone_weights(model.decoder.embed_tokens, model.shared)
+
+        return model
+
+    def _convert_to_lora_model(self, model: PreTrainedModel) -> PreTrainedModel:
+        # Only tie embedding weights together rather than the entire modules so that peft recognizes each one
+        model = self._create_tied_embedding_weights(model)
+
+        lora_config = self._config.train["lora_config"]
+        target_modules = lora_config.get(
+            "target_modules", LORA_DEFAULT_CONFIGS[self._config.model_prefix]["target_modules"]
+        )
+        modules_to_save = lora_config.get(
+            "modules_to_save", LORA_DEFAULT_CONFIGS[self._config.model_prefix]["modules_to_save"]
+        )
+        if isinstance(target_modules, str):
+            target_modules = target_modules.split(",")
+        if isinstance(modules_to_save, str):
+            modules_to_save = modules_to_save.split(",")
+        peft_config = LoraConfig(
+            task_type=TaskType.SEQ_2_SEQ_LM,
+            r=lora_config.get("r", 4),
+            lora_alpha=lora_config.get("alpha", 32),
+            lora_dropout=lora_config.get("dropout", 0.1),
+            target_modules=target_modules,
+            modules_to_save=modules_to_save,
+        )
+        model = get_peft_model(model, peft_config)
+
+        if self._config.model_prefix == "facebook/nllb-200" and (
+            ("embed_tokens" in modules_to_save and "lm_head" not in modules_to_save)
+            or ("lm_head" in modules_to_save and "embed_tokens" not in modules_to_save)
+        ):
+            LOGGER.warning(
+                "NLLB is typically trained with the embeddings tied. Add both embed_tokens and lm_head to modules_to_save to do this while using LoRA."
+            )
+
+        # Tie LoRA copies of the embedding weights together
+        if "embed_tokens" in modules_to_save:
+            if self._config.model_prefix == "facebook/nllb-200":
+                embedding = model.base_model.model.model.encoder.embed_tokens.modules_to_save.default.weight
+                model.base_model.model.model.decoder.embed_tokens.modules_to_save.default.weight = embedding
+                if "lm_head" in modules_to_save:
+                    model.base_model.model.lm_head.modules_to_save.default.weight = embedding
+            elif self._config.model_prefix == "google/madlad400":
+                embedding = model.base_model.model.encoder.embed_tokens.modules_to_save.default.weight
+                model.base_model.model.decoder.embed_tokens.modules_to_save.default.weight = embedding
+        elif "embed_tokens" in target_modules:
+            if self._config.model_prefix == "facebook/nllb-200":
+                # TODO: figure out how to tie embedding weights and lm_head weights together
+                embedding_A = model.base_model.model.model.encoder.embed_tokens.lora_embedding_A.default
+                embedding_B = model.base_model.model.model.encoder.embed_tokens.lora_embedding_B.default
+                model.base_model.model.model.decoder.embed_tokens.lora_embedding_A.default = embedding_A
+                model.base_model.model.model.decoder.embed_tokens.lora_embedding_B.default = embedding_B
+            elif self._config.model_prefix == "google/madlad400":
+                embedding_A = model.base_model.model.encoder.embed_tokens.lora_embedding_A.default
+                embedding_B = model.base_model.model.encoder.embed_tokens.lora_embedding_B.default
+                model.base_model.model.decoder.embed_tokens.lora_embedding_A.default = embedding_A
+                model.base_model.model.decoder.embed_tokens.lora_embedding_B.default = embedding_B
+
+        # Necessary to allow gradients to propagate through frozen layers when using PEFT + gradient checkpointing + Trainer
+        if self._config.train["gradient_checkpointing"]:
+            model.enable_input_require_grads()
+
+        return model
+
     def _translate_sentences(
         self,
         tokenizer: PreTrainedTokenizer,
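
The tying in _create_tied_embedding_weights, and in the modules_to_save branch of _convert_to_lora_model, works by pointing several modules at one parameter tensor, so the encoder embeddings, decoder embeddings, and (for NLLB) lm_head are trained and saved as a single weight. A standalone sketch of the mechanism with toy sizes, assuming only PyTorch:

import torch

shared = torch.nn.Embedding(100, 16, padding_idx=1)
encoder_embed = torch.nn.Embedding(100, 16, padding_idx=1)
decoder_embed = torch.nn.Embedding(100, 16, padding_idx=1)

# Point both embedding modules at the shared parameter; this is roughly what
# model._tie_or_clone_weights does for the weight tensor itself.
encoder_embed.weight = shared.weight
decoder_embed.weight = shared.weight

# All three modules now read and update the same storage.
assert encoder_embed.weight.data_ptr() == shared.weight.data_ptr()
assert decoder_embed.weight.data_ptr() == shared.weight.data_ptr()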
@@ -1169,6 +1229,7 @@ def _create_inference_model(
             base_model.resize_token_embeddings(
                 len(tokenizer), pad_to_multiple_of=8 if self._mixed_precision else None
             )
+            base_model = self._create_tied_embedding_weights(base_model)
             model = PeftModel.from_pretrained(base_model, model_name)
         else:
             model: PreTrainedModel = AutoModelForSeq2SeqLM.from_pretrained(
@@ -1178,7 +1239,7 @@
                 model = model.to_bettertransformer()
         if model_name == self._config.model and len(tokenizer) != model.get_input_embeddings().weight.size(dim=0):
             model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8 if self._mixed_precision else None)
-        if self._config.model.startswith("google/madlad400") or model_name == self._config.model:
+        if self._config.model_prefix == "google/madlad400" or model_name == self._config.model:
             model, tokenizer = self._configure_model(model, tokenizer)

         return model
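
Taken together, the LoRA branch above resizes the base model's vocabulary and re-creates the separate-but-tied embedding modules before attaching the adapter, so the checkpoint's modules_to_save and LoRA embedding copies line up with the base model's module tree. A rough sketch of that order of operations (load_lora_checkpoint and its tie_embeddings argument are hypothetical stand-ins, not part of silnlp):

from typing import Callable

from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel

def load_lora_checkpoint(
    base_name: str, checkpoint_dir: str, tie_embeddings: Callable[[PreTrainedModel], PreTrainedModel]
) -> PeftModel:
    # Hypothetical helper mirroring the LoRA branch of _create_inference_model.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_name)
    # Match the fine-tuned vocabulary size before attaching the adapter.
    if len(tokenizer) != base_model.get_input_embeddings().weight.size(dim=0):
        base_model.resize_token_embeddings(len(tokenizer))
    # Rebuild the untied embedding modules with tied weights, as was done during training
    # (pass in something equivalent to self._create_tied_embedding_weights).
    base_model = tie_embeddings(base_model)
    return PeftModel.from_pretrained(base_model, checkpoint_dir)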
Expand All @@ -1197,7 +1258,7 @@ def _configure_model(
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(self._config.val_trg_lang)

if self._config.model.startswith("google/madlad400"):
if self._config.model_prefix == "google/madlad400":
model.config.decoder_start_token_id = tokenizer.pad_token_id
model.generation_config.decoder_start_token_id = tokenizer.pad_token_id
model.config.max_length = 256
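
MADLAD-400 is served through the T5 architecture in transformers, which keeps a single shared embedding table (model.shared, referenced in _create_tied_embedding_weights above) and conventionally starts decoding from the pad token, hence the pad_token_id assignments in this branch.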
@@ -1412,12 +1473,18 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
         if self._better_transformer:
             self.model = self.model.reverse_bettertransformer()
             self.model.save_pretrained(
-                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                output_dir,
+                state_dict=state_dict,
+                safe_serialization=self.args.save_safetensors,
+                save_embedding_layers=True,
             )
             self.model = self.model.to_bettertransformer()
         else:
             self.model.save_pretrained(
-                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                output_dir,
+                state_dict=state_dict,
+                safe_serialization=self.args.save_safetensors,
+                save_embedding_layers=True,
             )
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
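
save_embedding_layers=True is an option of PEFT's save_pretrained; when the trained model is a PEFT-wrapped model it forces the embedding layers to be written out with the adapter weights rather than relying on PEFT's automatic detection, which appears to be what keeps the resized, tied embed_tokens and lm_head copies in the saved checkpoint.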
