Commit

fix collate length?!
markus583 committed Dec 27, 2023
1 parent e2bff46 commit d55d547
Showing 2 changed files with 21 additions and 5 deletions.
wtpsplit/train/evaluate.py (2 changes: 1 addition & 1 deletion)
@@ -110,7 +110,7 @@ def evaluate_sentence(
     newline_labels[true_end_indices - 1] = 1
 
     if "xlm" in model.config.model_type:
-        tokens = tokenizer.tokenize(text)
+        tokens = tokenizer.tokenize(text, verbose=False)
         char_probs = token_to_char_probs(text, tokens, logits[:, positive_index], tokenizer, offsets_mapping)
     else:
         char_probs = logits[:, positive_index]
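Side note (not part of the commit): the added verbose=False presumably just silences tokenizer length warnings; the mapping back to characters is done by token_to_char_probs. As a rough, hypothetical sketch of that general idea only — the helper name chars_from_tokens and its exact behavior are assumptions, not wtpsplit's implementation:

import numpy as np

def chars_from_tokens(text, token_scores, offsets_mapping):
    # Illustrative only: write each token's score onto the last character it
    # covers, as given by the tokenizer's (start, end) offset mapping;
    # characters not covered by any token keep a sentinel of -inf.
    char_scores = np.full(len(text), -np.inf, dtype=float)
    for (start, end), score in zip(offsets_mapping, token_scores):
        if end > start:  # special tokens typically come with empty (0, 0) spans
            char_scores[end - 1] = score
    return char_scores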
wtpsplit/train/train.py (24 changes: 20 additions & 4 deletions)
@@ -10,6 +10,7 @@
 
 import numpy as np
 import torch
+from tqdm.auto import tqdm
 import wandb
 from datasets import load_dataset
 from datasets.download import DownloadConfig
@@ -34,7 +35,7 @@
 
 # TODO: double-check checkpointing and saving (also to txt)
 
-# os.environ["PJRT_DEVICE"] = "None"
+os.environ["PJRT_DEVICE"] = "None"
 
 
 class Model(nn.Module):
@@ -241,7 +242,8 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
         # always include SEP
         if input_ids[-1] != tokenizer.sep_token_id:
             # also insert PAD token as long as len < block_size
-            while len(input_ids) < args.block_size - 1:
+            while len(input_ids) <= args.block_size - 1:
+                print("first", len(input_ids))
                 input_ids = input_ids + [tokenizer.pad_token_id]
                 labels = labels + [0]
             input_ids = input_ids + [tokenizer.sep_token_id]
@@ -250,6 +252,16 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
             start = np.random.randint(0, len(input_ids) - args.block_size)
             input_ids = input_ids[start : start + args.block_size]
             labels = labels[start : start + args.block_size]
+        elif len(input_ids) != args.block_size and args.use_subwords:
+            del input_ids[-1]
+            del labels[-1]
+            while len(input_ids) <= args.block_size - 1:
+                # insert pad token at second-to-last position
+                print("second", len(input_ids))
+                input_ids = input_ids + [tokenizer.pad_token_id]
+                labels = labels + [0]
+            input_ids = input_ids + [tokenizer.sep_token_id]
+            labels = labels + [0]
 
         input_ids = torch.tensor(input_ids[: args.block_size], dtype=torch.long)
         labels = torch.tensor(labels[: args.block_size], dtype=torch.long)
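Taken together, the two collate_fn hunks above enforce one invariant: each example leaves the collator exactly block_size tokens long with a single trailing SEP. A minimal standalone sketch of that invariant under those assumptions — a hypothetical helper, not wtpsplit's collate_fn; pad_id, sep_id, and block_size mirror the names in the diff, and input_ids/labels are assumed to be equal-length lists:

def pad_to_block(input_ids, labels, block_size, pad_id, sep_id):
    # Illustrative only: strip an existing trailing SEP, pad with PAD up to
    # block_size - 1 positions, then close with a single SEP, so that
    # len(input_ids) == len(labels) == block_size.
    if input_ids and input_ids[-1] == sep_id:
        input_ids, labels = input_ids[:-1], labels[:-1]
    input_ids, labels = input_ids[: block_size - 1], labels[: block_size - 1]
    while len(input_ids) < block_size - 1:
        input_ids = input_ids + [pad_id]
        labels = labels + [0]
    return input_ids + [sep_id], labels + [0]

For example, pad_to_block([5, 6, 7], [0, 1, 0], block_size=8, pad_id=1, sep_id=2) gives ([5, 6, 7, 1, 1, 1, 1, 2], [0, 1, 0, 0, 0, 0, 0, 0]).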
@@ -580,6 +592,10 @@ def maybe_pad(text):
                 num_proc=num_workers,
                 remove_columns=[args.text_column],
             )
+    else:
+        # this is no longer used and would cause an error otherwise
+        with training_args.main_process_first():
+            dataset = dataset.remove_columns([args.text_column])
 
     if split == "train":
         with training_args.main_process_first():
@@ -596,7 +612,7 @@ def maybe_pad(text):
                 batched=True,
                 num_proc=num_workers,
                 # a bit hacky but oh well, only drop if sentence
-                remove_columns=["ends_with_punctuation"]  # FIXME: needed for char-based args.text_column dropping
+                remove_columns=["ends_with_punctuation"]
                 if args.text_column == "text"
                 else [],
             )
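The remove_columns value above is a conditional expression, which is easy to misread in a diff. A small self-contained illustration of the same pattern with Hugging Face datasets — the toy columns and the drop_punct_col flag are made up for the example:

from datasets import Dataset

ds = Dataset.from_dict({"text": ["a b", "c"], "ends_with_punctuation": [False, True]})
drop_punct_col = True  # stand-in for args.text_column == "text"

ds = ds.map(
    lambda example: {"n_tokens": len(example["text"].split())},
    remove_columns=["ends_with_punctuation"] if drop_punct_col else [],
)
print(ds.column_names)  # "ends_with_punctuation" has been dropped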
@@ -649,7 +665,7 @@ def compute_metrics(trainer):
 
     model = trainer._wrap_model(trainer.model, training=False)
 
-    for lang_code, lang_data in eval_data.items():  # TODO: tqdm integration
+    for lang_code, lang_data in tqdm(eval_data.items(), desc="Evaluate!"):
         if args.include_languages is not None and lang_code not in args.include_languages:
             continue
 
