Skip to content

Commit

Permalink
Generalize default LoRA config
Browse files Browse the repository at this point in the history
  • Loading branch information
isaac091 committed Jan 31, 2024
1 parent 9ce0819 commit 7843b67
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions silnlp/nmt/hugging_face_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,15 @@ def prepare_decoder_input_ids_from_labels(self: M2M100ForConditionalGeneration,
},
}

LORA_TARGET_MODULES = {
"nllb": ["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"],
"madlad": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"]
}

LORA_MODULES_TO_SAVE = {
"nllb": ["embed_tokens"],
"madlad": ["embed_tokens"]
LORA_DEFAULT_CONFIGS = {
"facebook/nllb-200": {
"target_modules": ["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"],
"modules_to_save": ["embed_tokens"]
},
"google/madlad400": {
"target_modules": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
"modules_to_save": ["embed_tokens"]
}
}

def get_best_checkpoint(model_dir: Path) -> Path:
Expand Down Expand Up @@ -653,15 +654,13 @@ def train(self) -> None:
lora_config = self._config.train["lora_config"]

if "target_modules" not in lora_config:
if "nllb" in self._config.model:
lora_config["target_modules"] = LORA_TARGET_MODULES["nllb"]
elif "madlad" in self._config.model:
lora_config["target_modules"] = LORA_TARGET_MODULES["madlad"]
for model_prefix in LORA_DEFAULT_CONFIGS:
if self._config.model.startswith(model_prefix):
lora_config["target_modules"] = LORA_DEFAULT_CONFIGS[model_prefix]["target_modules"]
if "modules_to_save" not in lora_config:
if "nllb" in self._config.model:
lora_config["modules_to_save"] = LORA_MODULES_TO_SAVE["nllb"]
elif "madlad" in self._config.model:
lora_config["modules_to_save"] = LORA_MODULES_TO_SAVE["madlad"]
for model_prefix in LORA_DEFAULT_CONFIGS:
if self._config.model.startswith(model_prefix):
lora_config["modules_to_save"] = LORA_DEFAULT_CONFIGS[model_prefix]["modules_to_save"]

if isinstance(lora_config["target_modules"], str):
lora_config["target_modules"] = lora_config["target_modules"].split(",")
Expand Down

0 comments on commit 7843b67

Please sign in to comment.