diff --git a/wtpsplit/train/train_SM.py b/wtpsplit/train/train_SM.py
index 86088ed7..87cfefa7 100644
--- a/wtpsplit/train/train_SM.py
+++ b/wtpsplit/train/train_SM.py
@@ -22,10 +22,10 @@
 @dataclass
 class Args:
     block_size: int = 256
-    num_layers: int = 12
-    lim_lookahead: bool = False
-    without_pretraining: bool = False
-    no_sm_corruption: bool = False
+    num_layers: int = 12  # number of layers
+    lim_lookahead: bool = False  # our "Lookahead" ablation
+    without_pretraining: bool = False  # our "No pre-training" ablation
+    no_sm_corruption: bool = False  # our "Only clean text" ablation
 
 # Parsing command line arguments or JSON config files as needed
 parser = HfArgumentParser([Args, TrainingArguments])
@@ -46,19 +46,20 @@ class Args:
 
 punct_chars = set(Constants.PUNCTUATION_CHARS)
 
-for lang_code in tqdm(all_data, desc="Loading train/dev data"):
+for lang_code in tqdm(all_data, desc="Loading data"):
     if "-" in lang_code or "_" in lang_code:
+        # we only train on monolingual data in SM, so no "en-de" code-switching for example!
         pass
     elif (
         "ud" in all_data[lang_code]["sentence"]
         and all_data[lang_code]["sentence"]["ud"]["meta"]["train_data"] is not None
     ):
         train_data = all_data[lang_code]["sentence"]["ud"]["meta"]["train_data"]
-        # cf. Appendix A.2
+
         if len(train_data) < 10000:
-            train_data = train_data * (10000 // len(train_data) + 1)
-
-        if len(train_data) < 5000:
+            # some languages have an insufficient number of sentences to fill a single batch
+            # this is just a quick way to upsample these so we don't run into problems later
+            # later we will use a uniform round-robin sampler for all languages
             train_data = train_data * (10000 // len(train_data) + 1)
 
         train_sentences[lang_code]["uncorrupted"].extend(train_data)
@@ -67,6 +68,9 @@ class Args:
             train_data = all_data[lang_code]["sentence"]["ud-corrupted-asr"]["meta"]["train_data"]
 
             if len(train_data) < 5000:
+                # some languages have an insufficient number of sentences to fill a single batch
+                # this is just a quick way to upsample these so we don't run into problems later
+                # later we will use a uniform round-robin sampler for all languages
                 train_data = train_data * (10000 // len(train_data) + 1)
 
             train_sentences[lang_code]["corrupted-asr"].extend(train_data)
@@ -74,6 +78,9 @@ class Args:
             train_data = all_data[lang_code]["sentence"]["ud-corrupted-social-media"]["meta"]["train_data"]
 
             if len(train_data) < 5000:
+                # some languages have an insufficient number of sentences to fill a single batch
+                # this is just a quick way to upsample these so we don't run into problems later
+                # later we will use a uniform round-robin sampler for all languages
                 train_data = train_data * (10000 // len(train_data) + 1)
 
             train_sentences[lang_code]["corrupted-social-media"].extend(train_data)
@@ -83,29 +90,23 @@ class Args:
         and all_data[lang_code]["sentence"]["opus100"]["meta"]["train_data"] is not None
     ):
         train_data = all_data[lang_code]["sentence"]["opus100"]["meta"]["train_data"]
-        assert len(train_data) == 10000
         train_sentences[lang_code]["uncorrupted"].extend(train_data)
 
         if not args.no_sm_corruption:
             train_data = all_data[lang_code]["sentence"]["opus100-corrupted-asr"]["meta"]["train_data"]
-            assert len(train_data) == 10000
             train_sentences[lang_code]["corrupted-asr"].extend(train_data)
 
             train_data = all_data[lang_code]["sentence"]["opus100-corrupted-social-media"]["meta"]["train_data"]
-            assert len(train_data) == 10000
             train_sentences[lang_code]["corrupted-social-media"].extend(train_data)
     else:
         train_data = all_data[lang_code]["sentence"]["nllb"]["meta"]["train_data"]
-        assert len(train_data) == 10000
         train_sentences[lang_code]["uncorrupted"].extend(train_data)
 
         if not args.no_sm_corruption:
             train_data = all_data[lang_code]["sentence"]["nllb-corrupted-asr"]["meta"]["train_data"]
-            assert len(train_data) == 10000
             train_sentences[lang_code]["corrupted-asr"].extend(train_data)
 
             train_data = all_data[lang_code]["sentence"]["nllb-corrupted-social-media"]["meta"]["train_data"]
-            assert len(train_data) == 10000
             train_sentences[lang_code]["corrupted-social-media"].extend(train_data)
 
     for dataset in all_data[lang_code]["sentence"]:
@@ -145,7 +146,6 @@ class Args:
         model_checkpoint = "segment-any-text/sat-12l-no-limited-lookahead"
     else:
         model_checkpoint = "segment-any-text/sat-12l"
-
 else:
     raise ValueError("Invalid number of layers. Valid values are 1, 3, 6, 9, 12.")
 
@@ -155,6 +155,7 @@ class Args:
 assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
 
 if args.num_layers == 3 and args.without_pretraining:
+    # special case for one of our ablations, where we trim XLM-R (without any of our newline pretraining) to 3 layers
     model = SubwordXLMForTokenClassification.from_pretrained(
         model_checkpoint,
         num_labels=1,
@@ -299,8 +300,6 @@ def pack_sentences(input_data_dict, block_size):
 
 experiment_name = model_checkpoint.split("/")[-1]
 
-# experiment_name += str(args.num_layers) + "L"
-
 if args.no_sm_corruption:
     experiment_name += "-no-corruption"
 
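
For context on the repeated comment blocks above: the upsampling they describe simply repeats a language's sentence list until it clears the 10000-sentence threshold, so small treebanks can still fill a batch before the uniform round-robin sampler takes over. A minimal standalone sketch of that arithmetic, not part of the patch (the 1200-sentence toy list is a made-up example):

# toy example: a language with only 1200 training sentences (hypothetical count)
train_data = ["a sentence"] * 1200

if len(train_data) < 10000:
    # repeat the list enough times to exceed 10000 entries
    train_data = train_data * (10000 // len(train_data) + 1)

print(len(train_data))  # 1200 * (10000 // 1200 + 1) = 1200 * 9 = 10800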