Commit

add some stuff for adp

markus583 committed Mar 2, 2024
1 parent 934381d commit 6459683

Showing 3 changed files with 35 additions and 12 deletions.
8 changes: 8 additions & 0 deletions wtpsplit/evaluation/intrinsic.py
@@ -144,6 +144,14 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
                     with_head=True,
                     load_as="text",
                 )
+                if hasattr(model.model.config, "unfreeze_ln"):
+                    if model.model.config.unfreeze_ln:
+                        ln_dict = torch.load(
+                            args.adapter_path + "/" + dataset_name + "/" + lang_code + "/ln_dict.pth"
+                        )
+                        for n, p in model.backbone.named_parameters():
+                            if "LayerNorm" in n:
+                                p.data = ln_dict[n].data
             except Exception as e:
                 print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
                 continue
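For context: the adapters library's save_adapter/load_adapter round-trip does not cover tuned backbone LayerNorm weights (the training script below notes "no way within adapters to do this, need to do it manually"), so this commit stores them in a separate ln_dict.pth at training time and copies them back into the backbone here. A minimal, self-contained sketch of that save/restore pattern; the toy module and file path are illustrative, not taken from the repo:

    # Minimal sketch of the manual LayerNorm save/restore pattern used above.
    import torch
    import torch.nn as nn

    class TinyBackbone(nn.Module):
        def __init__(self):
            super().__init__()
            self.dense = nn.Linear(8, 8)
            # attribute literally named "LayerNorm", mirroring HF transformers naming,
            # so the name-based filter below matches
            self.LayerNorm = nn.LayerNorm(8)

        def forward(self, x):
            return self.LayerNorm(self.dense(x))

    backbone = TinyBackbone()

    # training side: collect every LayerNorm parameter and save it separately
    ln_dict = {n: p for n, p in backbone.named_parameters() if "LayerNorm" in n}
    torch.save(ln_dict, "ln_dict.pth")

    # evaluation side: load the dict and copy the tensors back in place,
    # exactly like `p.data = ln_dict[n].data` in intrinsic.py above
    restored = torch.load("ln_dict.pth")
    for n, p in backbone.named_parameters():
        if "LayerNorm" in n:
            p.data = restored[n].data
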
2 changes: 1 addition & 1 deletion wtpsplit/extract.py
@@ -174,7 +174,7 @@ def extract(
                 np.zeros(
                     (
                         length,
-                        model.config.num_labels,
+                        model.classifier.out_features
                     ),
                     dtype=np.float16,
                 )
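This change sizes the pre-allocated logits buffer from the classifier head that is actually attached to the model, rather than from config.num_labels, which can go stale once the head is swapped after loading (as in the adapter-training changes below). A hypothetical illustration of the mismatch; class and numbers are made up:

    # Hypothetical illustration: why the buffer follows the head's out_features
    # instead of a possibly stale config value.
    import numpy as np
    import torch.nn as nn

    class Cfg:
        num_labels = 3  # whatever the config was originally created with

    config = Cfg()
    classifier = nn.Linear(768, config.num_labels)

    # later the head is replaced (e.g. by a single-logit head) without the
    # config being updated everywhere
    classifier = nn.Linear(768, 1)

    length = 512
    logits = np.zeros((length, classifier.out_features), dtype=np.float16)
    print(logits.shape)       # (512, 1) -- matches what the model will emit
    print(config.num_labels)  # 3        -- would allocate the wrong width
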
37 changes: 26 additions & 11 deletions wtpsplit/train/train_adapter_parallel.py
@@ -113,6 +113,7 @@ class Args:
     use_subwords: bool = False
     freeze_classifier: bool = False
     clf_from_scratch: bool = False
+    unfreeze_ln: bool = False
     do_process: bool = False
     n_train_steps: List[int] = field(default_factory=lambda: [1000, 10000, 100000])
     meta_clf: bool = False
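The new unfreeze_ln flag sits next to the existing head-related switches. Assuming the Args dataclass is consumed by HfArgumentParser, as in the other wtpsplit training scripts (an assumption, not shown in this diff), the boolean options could be toggled from the command line roughly like this:

    # Hypothetical parsing sketch; assumes HfArgumentParser-style CLI handling.
    from dataclasses import dataclass
    from transformers import HfArgumentParser

    @dataclass
    class Args:
        freeze_classifier: bool = False
        clf_from_scratch: bool = False
        unfreeze_ln: bool = False
        meta_clf: bool = False

    parser = HfArgumentParser(Args)
    # e.g. python train_adapter_parallel.py ... --unfreeze_ln --meta_clf
    (args,) = parser.parse_args_into_dataclasses(["--unfreeze_ln", "--meta_clf"])
    print(args.unfreeze_ln, args.clf_from_scratch)  # True False
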
@@ -137,11 +138,11 @@ def main(
     logger.warning(f"{tpu_core_idx}: LANG GROUP {lang_groups}")

     num_labels = Constants.AUX_OFFSET + (
-        (1 + len(Constants.PUNCTUATION_CHARS)) if label_args.use_auxiliary or args.do_auxiliary_training else 0
+        (1 + len(Constants.PUNCTUATION_CHARS)) if (label_args.use_auxiliary or args.do_auxiliary_training or args.meta_clf) else 0
     )
     config = SubwordXLMConfig.from_pretrained(
         args.model_name_or_path,
-        num_labels=num_labels if not args.meta_clf else 1,
+        num_labels=num_labels,
     )

     # 1 wandb run for all language-dataset combinations
@@ -165,7 +166,7 @@ def main(
     for i, ((lang, dataset_name), train_step) in tqdm(enumerate(zip(lang_groups, train_steps)), total=len(lang_groups)):
         # do model stuff here; otherwise, head params would be overwritten every time
         backbone = SubwordXLMForTokenClassification.from_pretrained(
-            args.model_name_or_path, config=config, ignore_mismatched_sizes=True
+            args.model_name_or_path, config=copy.deepcopy(config), ignore_mismatched_sizes=True
         )
         logger.warning(f"{tpu_core_idx}: Loaded backbone {args.model_name_or_path}.")
         backbone.config.base_model = args.base_model
@@ -243,12 +244,19 @@ def compute_metrics(trainer):
                 p.requires_grad = False
         if args.clf_from_scratch:
             model.backbone.classifier = torch.nn.Linear(model.backbone.config.hidden_size, num_labels)
+
+        if args.unfreeze_ln:
+            for n, p in model.backbone.named_parameters():
+                if "LayerNorm" in n:
+                    p.requires_grad = True
+
         if args.meta_clf:
             clf = model.backbone.classifier
             model.backbone.classifier = torch.nn.Sequential(
                 clf,  # original classifier - if frozen above, also frozen here
-                torch.nn.Linear(clf.out_features, num_labels)
+                torch.nn.Linear(clf.out_features, 1)
             )
+            model.backbone.config.num_labels = 1

         trainer = AdapterTrainer(
             model,
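Aside on the unfreeze_ln branch in the hunk above: once backbone and classifier parameters have been frozen for adapter training, re-enabling requires_grad only on parameters whose names contain "LayerNorm" is what lets those weights train alongside the adapter. A small, toy-module sketch of how one might verify which parameters end up trainable; it is illustrative only, not code from the repo:

    # Illustrative check: only LayerNorm weights are re-enabled after freezing.
    import torch.nn as nn

    class TinyLayer(nn.Module):
        def __init__(self):
            super().__init__()
            self.dense = nn.Linear(16, 16)
            self.LayerNorm = nn.LayerNorm(16)  # matches HF's "LayerNorm" naming

    model = TinyLayer()

    # freeze everything, roughly what adapter activation / freeze_classifier do
    for p in model.parameters():
        p.requires_grad = False

    # the unfreeze_ln branch: re-enable LayerNorm parameters by name
    for n, p in model.named_parameters():
        if "LayerNorm" in n:
            p.requires_grad = True

    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    print(trainable)  # ['LayerNorm.weight', 'LayerNorm.bias']
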
@@ -277,19 +285,24 @@ def compute_metrics(trainer):
         if not os.path.exists(os.path.join(training_args.output_dir, dataset_name, lang)):
             os.makedirs(os.path.join(training_args.output_dir, dataset_name, lang))
         save_model = copy.deepcopy(model.backbone)
-        # TODO: check if concurrent saving is fine (if ds duplicated for TPUs)
         save_model = save_model.to("cpu")
         save_model.save_adapter(
             adapter_name="text",
             save_directory=os.path.join(training_args.output_dir, dataset_name, lang),
             with_head=True,
         )
+        # also save LNs
+        if args.unfreeze_ln:
+            # no way within adapters to do this, need to do it manually
+            ln_dict = {n: p for n, p in save_model.named_parameters() if "LayerNorm" in n}
+            torch.save(ln_dict, os.path.join(training_args.output_dir, dataset_name, lang, "ln_dict.pth"))
         logger.warning(f"{tpu_core_idx}: DONE TRAIN {lang} {dataset_name}.")

         if callbacks:
             wandb.log({"train/batch_progress": (i + 1) / len(lang_groups)})

-    xm.rendezvous("end_training")
+    xm.mark_step()
+    xm.rendezvous("all_done")
     wandb.finish()

@@ -671,6 +684,13 @@ def merge_dict_into_sublists(d):
         current_lang_groups,
         train_steps,
     )
+
+    xm.rendezvous("all training done")
+    if index == 0:
+        # eval here within 1 go
+        os.system(
+            f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.5"
+        )


 if __name__ == "__main__":
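The new tail of the spawned per-core function (the hunk does not show its name) follows a common torch_xla pattern: every core blocks at xm.rendezvous so that all adapters and ln_dict files are written before core 0 alone shells out to the intrinsic evaluation script. A stripped-down sketch of that barrier-then-rank-0 pattern; the function name, rendezvous tag, and command arguments are placeholders:

    # Stripped-down sketch of the barrier-then-rank-0 pattern used above.
    import os
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

    def _mp_fn(index):
        # ... per-core training happens here ...

        # every core blocks until all cores have finished training
        xm.rendezvous("all training done")

        # only one core launches the (single-process) evaluation script
        if index == 0:
            cmd = (
                "python3 wtpsplit/evaluation/intrinsic.py "
                "--model_path <model> --adapter_path <out_dir> --threshold 0.5"
            )
            os.system(cmd)  # placeholder arguments; the commit interpolates real paths

    if __name__ == "__main__":
        xmp.spawn(_mp_fn, args=(), nprocs=8)
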
@@ -681,9 +701,4 @@ def merge_dict_into_sublists(d):
         args=(),
         nprocs=8,
     )

-    # TODO: check grouping for TPUs: 1k, 10k, ...; what is most sensible?
-
-    # TODO: see if shuffle x1, shuffle x num_epochs, or no shuffle is best
-    # TODO: double-check effect of non_punctuation_sample_ratio
-    # TODO: try: freeze head, add clf on top (or do not freeze head, diff LR, etc.)
