Commit
use pretrained models; reduce leakage in non-whitespace models
markus583 committed Jan 10, 2024
1 parent fd6a63f commit 9e0b3e2
Showing 3 changed files with 143 additions and 72 deletions.
wtpsplit/evaluation/intrinsic.py (9 changes: 5 additions & 4 deletions)
@@ -165,9 +165,10 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
f[lang_code][dataset_name]["train_labels"][:],
features=feature_indices,
)
print(clf)
print(np.argsort(clf[0].coef_[0])[:10], "...", np.argsort(clf[0].coef_[0])[-10:])
print(np.where(np.argsort(clf[0].coef_[0]) == 0)[0])
if clf[0] is not None:
print(clf)
print(np.argsort(clf[0].coef_[0])[:10], "...", np.argsort(clf[0].coef_[0])[-10:])
print(np.where(np.argsort(clf[0].coef_[0]) == 0)[0])
score_t, score_punct, _ = evaluate_mixture(
lang_code,
f[lang_code][dataset_name]["test_logits"][:],
@@ -178,7 +179,7 @@
clfs[lang_code][dataset_name] = clf

clf = list(copy.deepcopy(clf))
clf[-1] = 0.005 # 0.01
clf[-1] = 0.01 # 0.01
else:
score_t = score_punct = None

wtpsplit/train/train.py (93 changes: 43 additions & 50 deletions)
@@ -121,7 +121,10 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
lang = sample["lang"]

while len(input_ids) < args.block_size + args.overflow_size:
input_ids.append(0)
if tokenizer:
input_ids.append(tokenizer.pad_token_id)
else:
input_ids.append(0)

block_ids = [0] * len(input_ids)

@@ -135,47 +138,30 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
min_length=args.block_size,
tokenizer=tokenizer if args.use_subwords else None,
)

actual_block_size = args.block_size - 2 if args.use_subwords else args.block_size

if len(input_ids) > args.block_size:
if tokenizer:
# always include CLS
start = np.random.randint(0, len(input_ids) - args.block_size)
if start != 0:
# this removes the CLS token
# -1 removes the SEP token, for sure
input_ids = [tokenizer.cls_token_id] + input_ids[start : start + args.block_size - 2]
labels = [0] + labels[start : start + args.block_size - 2]
else:
input_ids = input_ids[start : start + args.block_size - 1]
labels = labels[start : start + args.block_size - 1]
# always include SEP
if input_ids[-1] != tokenizer.sep_token_id:
# also insert PAD token as long as len < block_size
while len(input_ids) < args.block_size - 1:
input_ids = input_ids + [tokenizer.pad_token_id]
labels = labels + [0]
input_ids = input_ids + [tokenizer.sep_token_id]
labels = labels + [0]
else:
start = np.random.randint(0, len(input_ids) - args.block_size)
input_ids = input_ids[start : start + args.block_size]
labels = labels[start : start + args.block_size]
if len(input_ids) != args.block_size and tokenizer:
del input_ids[-1]
del labels[-1]
while len(input_ids) < args.block_size - 1:
# insert pad token at second-to-last position
logger.warning("second", len(input_ids))
input_ids = input_ids + [tokenizer.pad_token_id]
labels = labels + [0]
input_ids = input_ids + [tokenizer.sep_token_id]
labels = labels + [0]

if len(input_ids) != args.block_size:
logger.warning(len(input_ids))
input_ids = torch.tensor(input_ids[: args.block_size], dtype=torch.long)
labels = torch.tensor(labels[: args.block_size], dtype=torch.long)
start = np.random.randint(0, len(input_ids) - actual_block_size)
input_ids = input_ids[start : start + actual_block_size]
labels = labels[start : start + actual_block_size]
elif len(input_ids) < actual_block_size:
padding = actual_block_size - len(input_ids)
# print(padding, lang)
input_ids += [tokenizer.pad_token_id] * padding if tokenizer else [0] * padding
labels += [0] * padding

if tokenizer:
input_ids = [tokenizer.cls_token_id] + input_ids[: actual_block_size] + [tokenizer.sep_token_id]
# labels for CLS and SEP tokens are 0 (none)
labels = [0] + labels[: actual_block_size] + [0]
else:
input_ids = input_ids[: actual_block_size]
labels = labels[: actual_block_size]


input_ids = torch.tensor(input_ids, dtype=torch.long)
labels = torch.tensor(labels, dtype=torch.long)
position_ids = torch.arange(len(input_ids), dtype=torch.long)
label_weights = torch.ones(args.block_size, dtype=torch.float32)
if tokenizer:
@@ -232,7 +218,10 @@ def main():
num_hidden_layers=args.num_hidden_layers,
num_labels=num_labels,
)
backbone = SubwordXLMForTokenClassification(config)
backbone = SubwordXLMForTokenClassification.from_pretrained(
args.model_name_or_path,
config=config,
)

backbone.config.base_model = args.model_name_or_path
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
@@ -536,17 +525,17 @@ def maybe_pad(text):
num_workers=args.preprocessing_num_workers,
include_languages=args.include_languages,
shuffle=args.shuffle,
split="train",
split="valid",
)
logger.warning(f"Train dataset has {len(train_dataset)} examples.")

# print some samples from the dataset
count = 0
while count < 5:
while count < 20:
index = random.choice(range(len(train_dataset)))
sample = train_dataset[index]

if sample.get("lang") == "de":
if sample.get("lang") in ["zh", "ja", "my", "km"]:
logger.warning(f"Sample {index} of the training set: {sample}.")
if tokenizer:
logger.warning(tokenizer.decode(sample["input_ids"]))
Expand Down Expand Up @@ -578,6 +567,10 @@ def compute_metrics(trainer):
)
metrics[f"{lang_code}_{dataset_name}_pr_auc"] = score
avg_metrics[f"average_{dataset_name}_pr_auc"].append(score)
if lang_code in ["zh", "ja", "my", "km"]:
avg_metrics[f"average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
else:
avg_metrics[f"average_whitespace_{dataset_name}_pr_auc"].append(score)

for name, values in avg_metrics.items():
if len(values) > 1:
@@ -604,13 +597,13 @@
training_args.adapter_lr_multiplier = args.adapter_lr_multiplier

# give .map in multiprocessing enough of time to finish, to be safe
time.sleep(10)
if training_args.local_rank == 0:
# since both share the *same* cache_dir, we cannot simply call dataset.cleanup_cache_files()
# because that would remove the cache files of the other dataset!
cleanup_cache_files([train_dataset, valid_dataset])
logger.warning("Cleaned up cache files.")
time.sleep(10)
# time.sleep(10)
# if training_args.local_rank == 0:
# # since both share the *same* cache_dir, we cannot simply call dataset.cleanup_cache_files()
# # because that would remove the cache files of the other dataset!
# cleanup_cache_files([train_dataset, valid_dataset])
# logger.warning("Cleaned up cache files.")
# time.sleep(10)

trainer = Trainer(
model,
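Background for the wtpsplit/utils.py changes below (a minimal editorial sketch, not the commit's code): for languages whose separator is the empty string (e.g. zh, ja, my, km), deleting a newline during corruption can leave the SentencePiece space marker "▁" on the following subword, and that marker leaks where the removed boundary was. The snippet sketches such a cleanup in isolation, assuming an xlm-roberta-base tokenizer; the helper name strip_leaked_space_marker is hypothetical, and the actual handling lives in the corrupt() diff below (non_whitespace_remove_spaces).

# Illustrative sketch only, not the commit's implementation: strip a leaked
# SentencePiece space marker "▁" from the token that follows a removed newline.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

def strip_leaked_space_marker(input_ids, labels, pos):
    """Remove the "▁" marker from input_ids[pos]; drop the token if nothing usable remains."""
    token = tokenizer.convert_ids_to_tokens(input_ids[pos])
    if not token.startswith("▁"):
        return  # no leaked space marker here
    remaining = token[1:]
    remaining_id = tokenizer.convert_tokens_to_ids(remaining) if remaining else tokenizer.unk_token_id
    if remaining and remaining_id != tokenizer.unk_token_id:
        # e.g. "▁Test" -> "Test": keep the token, just without the space marker
        input_ids[pos] = remaining_id
    else:
        # bare "▁" or unknown remainder: delete the token and its label
        del input_ids[pos]
        del labels[pos]
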
wtpsplit/utils.py (113 changes: 95 additions & 18 deletions)
@@ -76,6 +76,8 @@ class LabelArgs:
hyphen_chars: List[str] = field(default_factory=lambda: ["-", "‐"])
use_auxiliary: bool = False
custom_punctuation_file: str = None
retain_first_consecutive_punctuation: bool = True
non_whitespace_remove_spaces: bool = True

def __post_init__(self):
if self.custom_punctuation_file:
@@ -184,12 +186,6 @@ def corrupt(
try:
i = next(index for index, label in enumerate(labels) if label != 0)
except StopIteration:
if tokenizer is not None:
input_ids = [tokenizer.cls_token_id] + input_ids + [tokenizer.sep_token_id]
# Extend block_ids for the added CLS and SEP tokens
block_ids = [block_ids[0]] + block_ids + [block_ids[-1]]
# labels for CLS and SEP tokens are 0 (none)
labels = [0] + labels + [0]
return input_ids, block_ids, labels

if tokenizer:
@@ -223,27 +219,64 @@ def corrupt(
labels.insert(last_index_in_block, 0)
else:
del block_ids[i + 1]
if tokenizer and separator == "" and label_args.non_whitespace_remove_spaces and i + 1 < len(input_ids):
# tokenizer.decode() retains the space that leaks the information
# so we need to get the position within the tokenized text and then remove the space
# (so there is no more space when fed into the tokenizer call)
if input_ids[i + 1] == tokenizer.convert_tokens_to_ids("▁"):
# remove artificial space
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
if i + 1 < len(input_ids):
next_token = tokenizer.convert_ids_to_tokens(input_ids[i + 1])
if next_token.startswith("▁"):
# next token starts with _ --> remove the _ from the token and re-tokenize
remove_next = False
remaining_token = tokenizer.convert_ids_to_tokens(input_ids[i + 1])
if len(remaining_token) > 1:
# ▁Test --> Test
remaining_token = remaining_token[1:]
else:
# ▁ --> remove
remove_next = True
remaining_id = tokenizer.convert_tokens_to_ids(remaining_token)
# replace the token with the remaining token
if remaining_id != tokenizer.unk_token_id:
input_ids[i + 1] = remaining_id
else:
# UNK token, remove it
remove_next = True
if remove_next:
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]

elif label_args.use_auxiliary and labels[i] > Constants.AUX_OFFSET: # auxiliary
if pack_samples:
raise NotImplementedError()

if random.random() < label_args.auxiliary_remove_prob:
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
if random.random() < label_args.auxiliary_remove_prob:
if label_args.retain_first_consecutive_punctuation:
# remove only if the next token is not a newline
# this retains the current auxiliary character, even though we decided to remove it
# it may skew the statistics since an auxiliary character is a better proxy for a newline
if labels[i + 1] != 1:
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
else:
# in case of something like ".\n", this removes the "." and the \n label (=1)
# so the newline in the text is kept, but the label is removed!
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]

try:
i = i + 1 + next(index for index, label in enumerate(labels[i + 1 :]) if label != 0)
except StopIteration:
break

# Add CLS and SEP tokens after the corruption process
if tokenizer is not None:
input_ids = [tokenizer.cls_token_id] + input_ids + [tokenizer.sep_token_id]
# Extend block_ids for the added CLS and SEP tokens
block_ids = [block_ids[0]] + block_ids + [block_ids[-1]]
# labels for CLS and SEP tokens are 0 (none)
labels = [0] + labels + [0]

return input_ids, block_ids, labels


@@ -312,3 +345,47 @@ def reconstruct_sentences(text, partial_sentences):
fixed_sentences.append(text[i:])

return fixed_sentences


if __name__ == "__main__":
# test corrupt function
from transformers import AutoTokenizer
from tokenizers import AddedToken

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
text = "That's right, Five!\n!\n!!!\n!\n Always lay the blame on others!"
input_ids = tokenizer(text)["input_ids"]
block_ids = [0] * len(input_ids)
label_args = LabelArgs(
custom_punctuation_file="punctuation_xlmr_unk.txt",
use_auxiliary=True,
auxiliary_remove_prob=1.0,
newline_whitespace_prob=1.0,
)
label_dict = get_subword_label_dict(label_args, tokenizer)

# corrupt
input_ids, block_ids, labels = corrupt(input_ids, block_ids, "en", label_args, label_dict, tokenizer=tokenizer)
print(input_ids)
print(labels)
print(tokenizer.tokenize(text))
print([(tokenizer.decode([input_id]), label) for input_id, label in zip(input_ids, labels)])
print("newline labels in text:")
print(np.where(np.array(labels) == 1))
print("newline ids in output text:")
print(np.where(np.array(input_ids) == tokenizer.all_special_ids[-1]))
print(tokenizer.decode(input_ids))

# ords = [ord(c) for c in text]
# block_ords = [0] * len(ords)
# label_args = LabelArgs(use_auxiliary=True, auxiliary_remove_prob=1.0)
# label_dict = get_label_dict(label_args)

# ords, block_ords, labels = corrupt(ords, block_ords, "en", label_args, label_dict)
# print("ords", ords)
# print("labels", labels)
# print("newline labels in text:")
# print(np.where(np.array(labels) == 1))
# print("newline ids in output text:")
# print(np.where(np.array([ord("\n")]) == ords))
