Skip to content

Commit

Permalink
fix module saving
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed Feb 21, 2024
1 parent dc835ec commit d4abc80
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions wtpsplit/train/train_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,6 @@ def maybe_pad(text):
for lang in data.keys():
if lang in args.include_languages:
for dataset_name in data[lang]["sentence"].keys():
if "a" in lang or "b" in lang or "cs" in lang:
continue
# do model stuff here; otherwise, head params would be overwritten every time
backbone = SubwordXLMForTokenClassification.from_pretrained(
args.model_name_or_path, config=config, ignore_mismatched_sizes=True
Expand Down Expand Up @@ -446,7 +444,7 @@ def compute_metrics(trainer):
eval_data,
model,
stride=64,
block_size=512,
block_size=512, ## TODO: change to args version x2?
batch_size=training_args.per_device_eval_batch_size,
)
metrics[f"{lang}_{dataset_name}_pr_auc"] = score
Expand Down Expand Up @@ -495,7 +493,7 @@ def compute_metrics(trainer):
with training_args.main_process_first():
if not os.path.exists(os.path.join(training_args.output_dir, dataset_name, lang)):
os.makedirs(os.path.join(training_args.output_dir, dataset_name, lang))
save_model = copy.deepcopy(model)
save_model = copy.deepcopy(model.backbone)
save_model = save_model.to("cpu")
save_model.to("cpu").save_adapter(
adapter_name="text",
Expand All @@ -509,6 +507,7 @@ def compute_metrics(trainer):

# TODO: try 1. double aux, 2. no aux at all (new head?), 3. no aux training but use_aux 4. higher/different aux prob
# TODO: try freezing head
# TODO: faster safe?!


def _mp_fn(index):
Expand Down

0 comments on commit d4abc80

Please sign in to comment.