Commit

cleanup
markus583 committed Feb 26, 2024
1 parent c346a88 commit 0fee97d
Showing 2 changed files with 15 additions and 25 deletions.
16 changes: 4 additions & 12 deletions wtpsplit/train/train_adapter.py
@@ -72,7 +72,9 @@ def main():
setup_logging(training_args)
set_seed(training_args.seed)

num_labels = Constants.AUX_OFFSET + ((1 + len(Constants.PUNCTUATION_CHARS)) if args.do_auxiliary_training else 0)
num_labels = Constants.AUX_OFFSET + (
(1 + len(Constants.PUNCTUATION_CHARS)) if args.do_auxiliary_training or label_args.use_auxiliary else 0
)
config = SubwordXLMConfig.from_pretrained(
args.model_name_or_path,
num_labels=num_labels,
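As an aside (not part of the diff), a minimal sketch of what the widened condition computes, with AUX_OFFSET and PUNCTUATION_CHARS replaced by illustrative stand-ins rather than the actual values from wtpsplit.utils.Constants:

# Illustrative sketch only; AUX_OFFSET and PUNCTUATION_CHARS are assumed
# stand-ins for wtpsplit.utils.Constants, not the library's real values.
AUX_OFFSET = 1                      # hypothetical: boundary labels live below this offset
PUNCTUATION_CHARS = list(".,;:!?")  # hypothetical punctuation subset

def compute_num_labels(do_auxiliary_training: bool, use_auxiliary: bool) -> int:
    # Mirrors the changed line: auxiliary punctuation labels are added when
    # either auxiliary training or auxiliary label use is enabled.
    aux = (1 + len(PUNCTUATION_CHARS)) if (do_auxiliary_training or use_auxiliary) else 0
    return AUX_OFFSET + aux

print(compute_num_labels(False, False))  # 1 -> boundary labels only
print(compute_num_labels(False, True))   # 8 -> boundary + "no punct" + 6 punctuation labels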
@@ -349,7 +351,7 @@ def maybe_pad(text):
wandb.config.update(args)
wandb.config.update(training_args)
wandb.config.update(label_args)

for file in glob(os.path.join(os.path.dirname(__file__), "*.py")):
wandb.save(os.path.abspath(file), policy="now")

@@ -488,7 +490,6 @@ def compute_metrics(trainer):
),
logging_suffix=f"{lang}_{dataset_name}",
)

trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
with training_args.main_process_first():
if not os.path.exists(os.path.join(training_args.output_dir, dataset_name, lang)):
@@ -501,15 +502,6 @@ def compute_metrics(trainer):
with_head=True,
)


# TODO: at end, do full eval?
# TODO: multi-TPU training - split up?

# TODO: try 1. double aux, 2. no aux at all (new head?), 3. no aux training but use_aux 4. higher/different aux prob
# TODO: try freezing head
# TODO: faster safe?!


def _mp_fn(index):
# For xla_spawn (TPUs)
main()
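The _mp_fn(index) hook kept at the bottom of the file is what TPU launchers import; a minimal sketch of how such a launcher would drive it, assuming the common xla_spawn-style pattern (the launcher script itself is not part of this repository):

# Hypothetical launcher sketch: spawn one process per TPU core and call the
# module's _mp_fn(index) in each. Modeled on the usual xla_spawn pattern.
import torch_xla.distributed.xla_multiprocessing as xmp

from wtpsplit.train import train_adapter  # module that defines _mp_fn / main

if __name__ == "__main__":
    # Each spawned process receives its core index as the first positional argument.
    xmp.spawn(train_adapter._mp_fn, args=(), nprocs=8)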
24 changes: 11 additions & 13 deletions wtpsplit/train/train_adapter_parallel.py
@@ -25,21 +25,22 @@
import wandb
from adapters import AdapterArguments
from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification
from wtpsplit.train.adapter_utils import (
ParallelTPUAdapterTrainingArguments as TrainingArguments,
)
from wtpsplit.train.adapter_utils import (
ParallelTPUWandbCallback as WandbCallback,
)
from wtpsplit.train.adaptertrainer import AdapterTrainer
from wtpsplit.train.evaluate import evaluate_sentence
from wtpsplit.train.train import collate_fn
from wtpsplit.train.utils import Model
from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict
from wtpsplit.train.adapter_utils import (
ParallelTPUAdapterTrainingArguments as TrainingArguments,
ParallelTPUWandbCallback as WandbCallback,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def setup_logging(training_args, job_id=None) -> None:
# TODO: log saving based on folders
# Generate a unique logger name based on the job_id or process identifier
unique_logger_name = f"{__name__}.{job_id}" if job_id is not None else __name__
logger = logging.getLogger(unique_logger_name)
@@ -241,7 +242,7 @@ def compute_metrics(trainer):
p.requires_grad = False
if args.clf_from_scratch:
model.backbone.classifier = torch.nn.Linear(model.backbone.config.hidden_size, num_labels)

trainer = AdapterTrainer(
model,
training_args,
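The freezing and optional head re-initialization visible at the top of this hunk can be read in isolation; a minimal sketch of that pattern, where the backbone object is an assumption standing in for the actual SubwordXLMForTokenClassification:

# Stand-in sketch of the freeze + fresh-head pattern shown above; backbone
# is assumed to expose .parameters(), .config.hidden_size and .classifier.
import torch

def prepare_backbone(backbone, num_labels, clf_from_scratch=False):
    # Freeze all pretrained weights so that only adapters (and, if desired,
    # a new head) receive gradient updates.
    for p in backbone.parameters():
        p.requires_grad = False
    if clf_from_scratch:
        # Replace the pretrained classification head with a freshly initialized one.
        backbone.classifier = torch.nn.Linear(backbone.config.hidden_size, num_labels)
    return backbone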
@@ -257,9 +258,9 @@
),
logging_prefix=f"{dataset_name}/{lang}/",
)
if callbacks:
if callbacks:
trainer.add_callback(callbacks)

logger.warning(f"{tpu_core_idx}: START TRAIN {lang} {dataset_name}.")
# wait until all TPUs are ready
xm.rendezvous("start_training")
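For the barrier call above, a short sketch of the xm.rendezvous pattern, assuming one Python process per TPU core (the function and its arguments here are illustrative, not the repository's code):

# Illustrative sketch: each core blocks at the named rendezvous point until
# every core has reached it, so training starts in lockstep across cores.
import torch_xla.core.xla_model as xm

def start_training(tpu_core_idx, trainer, lang, dataset_name):
    print(f"{tpu_core_idx}: START TRAIN {lang} {dataset_name}.")
    xm.rendezvous("start_training")  # wait until all TPU cores are ready
    trainer.train()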
@@ -674,11 +675,8 @@ def merge_dict_into_sublists(d):
nprocs=8,
)

# FIXME: integrate trainer, training_args (world!)
# FIXME: grouping for TPUs: 1k, 10k, ...
# TODO: what if very different token distribution? check & fix - duplicate data?
# TODO: accommodate last batch (or, generally, leverage division by 8)
# TODO: check grouping for TPUs: 1k, 10k, ...; what is most sensible?

# TODO: see if shuffle x1, shuffle x num_epochs, or no shuffle is best
# TODO: double-check effect of non_punctuation_sample_ratio
# TODO: try Benjamin's idea: freeze head, add clf on top
# TODO: try: freeze head, add clf on top (or do not freeze head, diff LR, etc.)
