diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py
index 81c1883d..c0624d04 100644
--- a/silnlp/nmt/hugging_face_config.py
+++ b/silnlp/nmt/hugging_face_config.py
@@ -755,30 +755,7 @@ def train(self) -> None:
             )
 
         if self._config.train["use_lora"]:
-            lora_config = self._config.train["lora_config"]
-
-            if isinstance(lora_config.get("target_modules", []), str):
-                lora_config["target_modules"] = lora_config["target_modules"].split(",")
-            if isinstance(lora_config.get("modules_to_save", []), str):
-                lora_config["modules_to_save"] = lora_config["modules_to_save"].split(",")
-
-            peft_config = LoraConfig(
-                task_type=TaskType.SEQ_2_SEQ_LM,
-                r=lora_config.get("r", 4),
-                lora_alpha=lora_config.get("alpha", 32),
-                lora_dropout=lora_config.get("dropout", 0.1),
-                target_modules=lora_config.get(
-                    "target_modules", LORA_DEFAULT_CONFIGS[self._config.model_prefix]["target_modules"]
-                ),
-                modules_to_save=lora_config.get(
-                    "modules_to_save", LORA_DEFAULT_CONFIGS[self._config.model_prefix]["modules_to_save"]
-                ),
-            )
-            model = get_peft_model(model, peft_config)
-
-            # Necessary to allow gradients to propogate through frozen layers when using PEFT + gradient checkpointing + Trainer
-            if self._config.train["gradient_checkpointing"]:
-                model.enable_input_require_grads()
+            model = self._convert_to_lora_model(model)
 
         # Change specific variables based on the type of model
         model, tokenizer = self._configure_model(model, tokenizer)
@@ -1111,6 +1088,89 @@ def _get_dictionary(self) -> Dict[VerseRef, Set[str]]:
 
         return self._dictionary
 
+    # Tie embedding weights to the "shared" module's weights and untie the embedding modules if necessary
+    def _create_tied_embedding_weights(self, model: PreTrainedModel) -> PreTrainedModel:
+        encoder_embeddings = torch.nn.Embedding(
+            model.config.vocab_size, model.config.d_model, model.config.pad_token_id
+        )
+        decoder_embeddings = torch.nn.Embedding(
+            model.config.vocab_size, model.config.d_model, model.config.pad_token_id
+        )
+
+        if self._config.model_prefix == "facebook/nllb-200":
+            model.base_model.encoder.embed_tokens = encoder_embeddings
+            model.base_model.decoder.embed_tokens = decoder_embeddings
+            model.tie_weights()
+        elif self._config.model_prefix == "google/madlad400":
+            model.encoder.embed_tokens = encoder_embeddings
+            model.decoder.embed_tokens = decoder_embeddings
+            model._tie_or_clone_weights(model.encoder.embed_tokens, model.shared)
+            model._tie_or_clone_weights(model.decoder.embed_tokens, model.shared)
+
+        return model
+
+    def _convert_to_lora_model(self, model: PreTrainedModel) -> PreTrainedModel:
+        # Only tie embedding weights together rather than the entire modules so that peft recognizes each one
+        model = self._create_tied_embedding_weights(model)
+
+        lora_config = self._config.train["lora_config"]
+        target_modules = lora_config.get(
+            "target_modules", LORA_DEFAULT_CONFIGS[self._config.model_prefix]["target_modules"]
+        )
+        modules_to_save = lora_config.get(
+            "modules_to_save", LORA_DEFAULT_CONFIGS[self._config.model_prefix]["modules_to_save"]
+        )
+        if isinstance(target_modules, str):
+            target_modules = target_modules.split(",")
+        if isinstance(modules_to_save, str):
+            modules_to_save = modules_to_save.split(",")
+        peft_config = LoraConfig(
+            task_type=TaskType.SEQ_2_SEQ_LM,
+            r=lora_config.get("r", 4),
+            lora_alpha=lora_config.get("alpha", 32),
+            lora_dropout=lora_config.get("dropout", 0.1),
+            target_modules=target_modules,
+            modules_to_save=modules_to_save,
+        )
+        model = get_peft_model(model, peft_config)
+
+        if self._config.model_prefix == "facebook/nllb-200" and (
+            ("embed_tokens" in modules_to_save and "lm_head" not in modules_to_save)
+            or ("lm_head" in modules_to_save and "embed_tokens" not in modules_to_save)
+        ):
+            LOGGER.warning(
+                "NLLB is typically trained with the embeddings tied. Add both embed_tokens and lm_head to modules_to_save to do this while using LoRA."
+            )
+
+        # Tie LoRA copies of the embedding weights together
+        if "embed_tokens" in modules_to_save:
+            if self._config.model_prefix == "facebook/nllb-200":
+                embedding = model.base_model.model.model.encoder.embed_tokens.modules_to_save.default.weight
+                model.base_model.model.model.decoder.embed_tokens.modules_to_save.default.weight = embedding
+                if "lm_head" in modules_to_save:
+                    model.base_model.model.lm_head.modules_to_save.default.weight = embedding
+            elif self._config.model_prefix == "google/madlad400":
+                embedding = model.base_model.model.encoder.embed_tokens.modules_to_save.default.weight
+                model.base_model.model.decoder.embed_tokens.modules_to_save.default.weight = embedding
+        elif "embed_tokens" in target_modules:
+            if self._config.model_prefix == "facebook/nllb-200":
+                # TODO: figure out how to tie embedding weights and lm_head weights together
+                embedding_A = model.base_model.model.model.encoder.embed_tokens.lora_embedding_A.default
+                embedding_B = model.base_model.model.model.encoder.embed_tokens.lora_embedding_B.default
+                model.base_model.model.model.decoder.embed_tokens.lora_embedding_A.default = embedding_A
+                model.base_model.model.model.decoder.embed_tokens.lora_embedding_B.default = embedding_B
+            elif self._config.model_prefix == "google/madlad400":
+                embedding_A = model.base_model.model.encoder.embed_tokens.lora_embedding_A.default
+                embedding_B = model.base_model.model.encoder.embed_tokens.lora_embedding_B.default
+                model.base_model.model.decoder.embed_tokens.lora_embedding_A.default = embedding_A
+                model.base_model.model.decoder.embed_tokens.lora_embedding_B.default = embedding_B
+
+        # Necessary to allow gradients to propagate through frozen layers when using PEFT + gradient checkpointing + Trainer
+        if self._config.train["gradient_checkpointing"]:
+            model.enable_input_require_grads()
+
+        return model
+
     def _translate_sentences(
         self,
         tokenizer: PreTrainedTokenizer,
@@ -1169,6 +1229,7 @@ def _create_inference_model(
             base_model.resize_token_embeddings(
                 len(tokenizer), pad_to_multiple_of=8 if self._mixed_precision else None
             )
+            base_model = self._create_tied_embedding_weights(base_model)
             model = PeftModel.from_pretrained(base_model, model_name)
         else:
             model: PreTrainedModel = AutoModelForSeq2SeqLM.from_pretrained(
@@ -1178,7 +1239,7 @@
             model = model.to_bettertransformer()
         if model_name == self._config.model and len(tokenizer) != model.get_input_embeddings().weight.size(dim=0):
             model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8 if self._mixed_precision else None)
-        if self._config.model.startswith("google/madlad400") or model_name == self._config.model:
+        if self._config.model_prefix == "google/madlad400" or model_name == self._config.model:
             model, tokenizer = self._configure_model(model, tokenizer)
 
         return model
@@ -1197,7 +1258,7 @@ def _configure_model(
         else:
             model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(self._config.val_trg_lang)
 
-        if self._config.model.startswith("google/madlad400"):
+        if self._config.model_prefix == "google/madlad400":
             model.config.decoder_start_token_id = tokenizer.pad_token_id
             model.generation_config.decoder_start_token_id = tokenizer.pad_token_id
             model.config.max_length = 256
@@ -1412,12 +1473,18 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
         if self._better_transformer:
             self.model = self.model.reverse_bettertransformer()
             self.model.save_pretrained(
-                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                output_dir,
+                state_dict=state_dict,
+                safe_serialization=self.args.save_safetensors,
+                save_embedding_layers=True,
             )
             self.model = self.model.to_bettertransformer()
         else:
             self.model.save_pretrained(
-                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                output_dir,
+                state_dict=state_dict,
+                safe_serialization=self.args.save_safetensors,
+                save_embedding_layers=True,
             )
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
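Note for reviewers: `_create_tied_embedding_weights` ties parameters rather than whole modules, which is what lets peft wrap the encoder and decoder embeddings as separate layers while they still share storage. A minimal standalone sketch of that pattern (toy dimensions stand in for the real `vocab_size` x `d_model`; this is not the silnlp code path itself):

```python
import torch

# Toy sizes: 100 x 16 stands in for vocab_size x d_model; padding_idx=1 mirrors pad_token_id.
shared = torch.nn.Embedding(100, 16, padding_idx=1)
encoder_embed = torch.nn.Embedding(100, 16, padding_idx=1)
decoder_embed = torch.nn.Embedding(100, 16, padding_idx=1)

# What tie_weights() / _tie_or_clone_weights() reduce to here: re-point each
# module's weight Parameter so the distinct modules share one tensor.
encoder_embed.weight = shared.weight
decoder_embed.weight = shared.weight

assert encoder_embed is not decoder_embed             # separate modules, so peft sees two layers
assert encoder_embed.weight is decoder_embed.weight   # but a single shared Parameter
```

The same idea motivates the re-tying after `get_peft_model`: the `modules_to_save` (or `lora_embedding_A`/`lora_embedding_B`) copies are re-pointed at shared objects so the trained encoder, decoder, and `lm_head` embeddings stay in sync.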