Commit 5b37331
fix tokenization?, eval during training
markus583 committed Dec 23, 2023
1 parent cdb8743
Showing 8 changed files with 224 additions and 61 deletions.
4 changes: 2 additions & 2 deletions configs/xlmr_stratify_0.1_3layers.json
@@ -12,14 +12,14 @@
"per_device_eval_batch_size": 32,
"gradient_accumulation_steps": 1,
"eval_accumulation_steps": 8,
"dataloader_num_workers": 32,
"dataloader_num_workers": 4,
"preprocessing_num_workers": 32,
"learning_rate": 1e-4,
"save_strategy": "steps",
"fp16": false,
"max_steps": 2000000,
"save_steps": 100000,
"eval_steps": 50000000000,
"eval_steps": 5000,
"logging_steps": 50,
"report_to": "wandb",
"is_decoder": false,
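The config change above is what enables evaluation during training: eval_steps drops from an effectively-never value (50000000000) to 5000, and dataloader_num_workers goes from 32 to 4. As a rough sketch of how such a config is typically consumed, assuming the training script forwards these keys to Hugging Face TrainingArguments (the output_dir and evaluation_strategy values below are illustrative, not part of this commit):

import json
from transformers import TrainingArguments

with open("configs/xlmr_stratify_0.1_3layers.json") as f:
    cfg = json.load(f)

# illustrative mapping; only keys shown in the diff that TrainingArguments understands
training_args = TrainingArguments(
    output_dir="output",                                   # assumed, not in the config above
    per_device_eval_batch_size=cfg["per_device_eval_batch_size"],
    gradient_accumulation_steps=cfg["gradient_accumulation_steps"],
    eval_accumulation_steps=cfg["eval_accumulation_steps"],
    dataloader_num_workers=cfg["dataloader_num_workers"],  # now 4
    learning_rate=cfg["learning_rate"],
    save_strategy=cfg["save_strategy"],
    fp16=cfg["fp16"],
    max_steps=cfg["max_steps"],
    save_steps=cfg["save_steps"],
    eval_steps=cfg["eval_steps"],                          # now 5000: eval actually runs
    logging_steps=cfg["logging_steps"],
    report_to=cfg["report_to"],
    evaluation_strategy="steps",                           # assumed; needed for eval_steps to take effect
)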
2 changes: 2 additions & 0 deletions wtpsplit/configs.py
@@ -45,12 +45,14 @@ class SubwordXLMConfig(XLMRobertaConfig):
XLMRobertaConfig: Base class.
"""
model_type = "xlm-token"
mixture_name = "xlm-token"

def __init__(
self,
**kwargs,
):
super().__init__(**kwargs)
self.mixture_name = "xlm-token"


AutoConfig.register("bert-char", BertCharConfig)
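The class-level mixture_name added above is what the evaluation code reads back: intrinsic.py builds its logits cache path from model.config.mixture_name. A minimal sketch of that lookup, assuming SubwordXLMConfig and Constants are importable from the module paths shown in this diff:

from wtpsplit.configs import SubwordXLMConfig
from wtpsplit.utils import Constants

config = SubwordXLMConfig()
# same expression as in load_or_compute_logits() below
logits_path = Constants.CACHE_DIR / (config.mixture_name + "_logits.h5")
print(logits_path)  # e.g. <CACHE_DIR>/xlm-token_logits.h5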
7 changes: 5 additions & 2 deletions wtpsplit/evaluation/intrinsic.py
@@ -34,7 +34,7 @@ class Args:
# }
eval_data_path: str = "data/eval.pth"
valid_text_path: str = None#"data/sentence/valid.parquet"
device: str = "xla:1"
device: str = "cpu"
block_size: int = 512
stride: int = 64
batch_size: int = 32
@@ -44,7 +44,8 @@ class Args:
def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_sentences=10_000):
logits_path = Constants.CACHE_DIR / (model.config.mixture_name + "_logits.h5")

with h5py.File(logits_path, "a") as f, torch.no_grad():
# TODO: revert to "a"
with h5py.File(logits_path, "w") as f, torch.no_grad():
for lang_code in Constants.LANGINFO.index:
if args.include_langs is not None and lang_code not in args.include_langs:
continue
@@ -152,13 +153,15 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_

if "train_logits" in f[lang_code][dataset_name]:
feature_indices = None
# TODO: tokenize here
clf = train_mixture(
[lang_code],
f[lang_code][dataset_name]["train_logits"][:],
f[lang_code][dataset_name]["train_labels"][:],
features=feature_indices,
)

# TODO: tokenize here, too
score_t, score_punct, _ = evaluate_mixture(
lang_code,
f[lang_code][dataset_name]["test_logits"][:],
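The h5py mode switch above ("a" to "w", with a TODO to revert) decides whether previously computed logits are reused or recomputed from scratch. A small, repo-independent illustration of the difference:

import h5py
import numpy as np

# "a": open or create; existing groups survive, so a second run can skip recomputation
with h5py.File("logits.h5", "a") as f:
    if "en/test_logits" not in f:
        f.create_dataset("en/test_logits", data=np.zeros((4, 2)))
    cached = f["en/test_logits"][:]

# "w": truncate on open; every cached group is gone, so everything is recomputed
with h5py.File("logits.h5", "w") as f:
    assert "en/test_logits" not in f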
92 changes: 65 additions & 27 deletions wtpsplit/extract.py
@@ -3,6 +3,8 @@

import numpy as np
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from tokenizers import AddedToken

from wtpsplit.utils import Constants, hash_encode

@@ -37,7 +39,7 @@ def __getattr__(self, name):
assert hasattr(self, "model")
return getattr(self.model, name)

def __call__(self, hashed_ids, attention_mask, language_ids=None):
def __call__(self, input_ids, hashed_ids, attention_mask, language_ids=None):
try:
import torch
except ImportError:
@@ -46,7 +48,8 @@ def __call__(self, hashed_ids, attention_mask, language_ids=None):
with torch.no_grad():
logits = (
self.model(
hashed_ids=torch.from_numpy(hashed_ids).to(self.model.device),
input_ids=torch.from_numpy(input_ids).to(self.model.device) if input_ids is not None else None,
hashed_ids=torch.from_numpy(hashed_ids).to(self.model.device) if hashed_ids is not None else None,
attention_mask=torch.from_numpy(attention_mask).to(self.model.device),
language_ids=torch.from_numpy(language_ids).to(self.model.device)
if language_ids is not None
@@ -76,6 +79,20 @@ def extract(
ad 1.: text is sliced into partially overlapping chunks by moving forward by a `stride` parameter (think conv1d).
"""
if "xlm" in model.config.model_type:
use_subwords = True
tokenizer = AutoTokenizer.from_pretrained(model.config.base_model)
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
tokens = tokenizer(batch_of_texts, return_offsets_mapping=True)
# remove CLS and SEP tokens, they are added later anyhow
batch_of_texts = [text[1:-1] for text in tokens["input_ids"]]
offset_mapping = [offset[1:-1] for offset in tokens["offset_mapping"]]
cls_token_id = tokenizer.cls_token_id
sep_token_id = tokenizer.sep_token_id
pad_token_id = tokenizer.pad_token_id
else:
pad_token_id = 0
use_subwords = False

text_lengths = [len(text) for text in batch_of_texts]
# reduce block size if possible
@@ -84,44 +101,56 @@
# make sure block_size is a multiple of downsampling rate
downsampling_rate = getattr(model.config, "downsampling_rate", 1)
block_size = math.ceil(block_size / downsampling_rate) * downsampling_rate
actual_block_size = block_size - 2 if use_subwords else block_size # account for CLS and SEP tokens

# total number of forward passes
num_chunks = sum(math.ceil(max(length - block_size, 0) / stride) + 1 for length in text_lengths)
num_chunks = sum(math.ceil(max(length - actual_block_size, 0) / stride) + 1 for length in text_lengths)

# preallocate a buffer for all input hashes & attention masks
input_hashes = np.zeros((num_chunks, block_size, model.config.num_hash_functions), dtype=np.int64)
if not use_subwords:
input_hashes = np.zeros((num_chunks, block_size, model.config.num_hash_functions), dtype=np.int64)
else:
input_ids = np.zeros((num_chunks, block_size), dtype=np.int64)
attention_mask = np.zeros((num_chunks, block_size), dtype=np.float32)

# locs keep track of the location of every chunk with a 3-tuple (text_idx, char_start, char_end) that indexes
# back into the batch_of_texts
locs = np.zeros((num_chunks, 3), dtype=np.int32)

# this is equivalent to (but faster than) np.array([ord(c) for c in "".join(batch_of_texts)])
codec = "utf-32-le" if sys.byteorder == "little" else "utf-32-be"
ordinals = np.frombuffer(bytearray("".join(batch_of_texts), encoding=codec), dtype=np.int32)

# hash encode all ids
flat_hashed_ids = hash_encode(ordinals,
num_hashes=model.config.num_hash_functions,
num_buckets=model.config.num_hash_buckets)
if not use_subwords:
# this is equivalent to (but faster than) np.array([ord(c) for c in "".join(batch_of_texts)])
codec = "utf-32-le" if sys.byteorder == "little" else "utf-32-be"
ordinals = np.frombuffer(bytearray("".join(batch_of_texts), encoding=codec), dtype=np.int32)
# hash encode all ids
flat_hashed_ids = hash_encode(ordinals,
num_hashes=model.config.num_hash_functions,
num_buckets=model.config.num_hash_buckets)
# note that ordinals and flat_hashed_ids have the same length
offset = 0
current_chunk = 0



# create chunks
for i in range(len(batch_of_texts)):
for j in range(0, text_lengths[i], stride):
# for every chunk, assign input hashes, attention mask and loc
start, end = j, j + block_size
start, end = j, j + actual_block_size
done = False

if end >= text_lengths[i]:
end = text_lengths[i]
start = max(end - block_size, 0)
start = max(end - actual_block_size, 0)
done = True

input_hashes[current_chunk, : end - start] = flat_hashed_ids[offset + start : offset + end]
attention_mask[current_chunk, : end - start] = 1
if not use_subwords:
input_hashes[current_chunk, : end - start] = flat_hashed_ids[offset + start : offset + end]
attention_mask[current_chunk, : end - start] = 1
else:
chunk = [cls_token_id] + batch_of_texts[i][start:end] + [sep_token_id]
input_ids[current_chunk, :len(chunk)] = chunk
attention_mask[current_chunk, :len(chunk)] = 1

locs[current_chunk, :] = [i, start, end]

current_chunk += 1

if done:
@@ -130,7 +159,7 @@ def extract(
offset += text_lengths[i]

assert current_chunk == num_chunks
n_batches = math.ceil(len(input_hashes) / batch_size)
n_batches = math.ceil(len(attention_mask) / batch_size)

# containers for the final logits
all_logits = [
@@ -163,21 +192,30 @@ def extract(

# forward passes through all chunks
for batch_idx in tqdm(range(n_batches), disable=not verbose):
start, end = batch_idx * batch_size, min(len(input_hashes), (batch_idx + 1) * batch_size)
start, end = batch_idx * batch_size, min(len(attention_mask), (batch_idx + 1) * batch_size)

batch_input_hashes = input_hashes[start:end]
if not use_subwords:
batch_input_hashes = input_hashes[start:end]
else:
batch_input_ids = input_ids[start:end]
batch_attention_mask = attention_mask[start:end]

if len(batch_input_hashes) < batch_size and pad_last_batch:
n_missing = batch_size - len(batch_input_hashes)
if len(batch_attention_mask) < batch_size and pad_last_batch:
n_missing = batch_size - len(batch_attention_mask)

batch_input_hashes = np.pad(batch_input_hashes, ((0, n_missing), (0, 0), (0, 0)))
if not use_subwords:
batch_input_hashes = np.pad(batch_input_hashes, ((0, n_missing), (0, 0), (0, 0)))
else:
# Pad with the specific pad_token_id for the tokenizer
batch_input_ids = np.pad(batch_input_ids, ((0, n_missing), (0, 0)), constant_values=pad_token_id)
batch_attention_mask = np.pad(batch_attention_mask, ((0, n_missing), (0, 0)))


kwargs = {"language_ids": language_ids[: len(batch_input_hashes)]} if uses_lang_adapters else {}
kwargs = {"language_ids": language_ids[: len(batch_attention_mask)]} if uses_lang_adapters else {}

logits = model(
hashed_ids=batch_input_hashes,
input_ids=batch_input_ids if use_subwords else None,
hashed_ids=batch_input_hashes if not use_subwords else None,
attention_mask=batch_attention_mask,
**kwargs,
)["logits"]
@@ -190,4 +228,4 @@ def extract(
# so far, logits are summed, so we average them here
all_logits = [(logits / counts[:, None]).astype(np.float16) for logits, counts in zip(all_logits, all_counts)]

return all_logits
return all_logits, offset_mapping if use_subwords else None, tokenizer if use_subwords else None
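With this change extract() returns a 3-tuple instead of bare logits: (all_logits, offset_mapping, tokenizer), where the last two are None for the character-level models. A rough calling sketch; the model object and the PyTorchWrapper import location are assumptions based on the evaluate.py changes below:

from wtpsplit.extract import PyTorchWrapper, extract

wrapped = PyTorchWrapper(model.backbone)   # `model` assumed to be loaded elsewhere
logits_list, offset_mappings, tokenizer = extract(
    ["This is a test sentence. And another one."],
    wrapped,
    lang_code="en",
    stride=64,
    block_size=512,
    batch_size=32,
)
logits = logits_list[0]                    # one logits array per input text
if tokenizer is not None:                  # subword ("xlm") path
    offsets = offset_mappings[0]           # one (char_start, char_end) per kept token
else:                                      # char/hash path
    offsets = None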
7 changes: 4 additions & 3 deletions wtpsplit/models.py
@@ -989,8 +989,9 @@ def forward(
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hashed_ids: Optional[torch.Tensor] = None,
language_ids=None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
return super().forward(
input_ids,
@@ -1030,12 +1031,12 @@ def forward(
text = "This is a test\n sentence \n\n"
tokenizer = AutoTokenizer.from_pretrained(model_str)

tokens = tokenizer(text, return_tensors="pt")
tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False)
from tokenizers import AddedToken
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
print(tokenizer.tokenize(text))
print(tokenizer.encode(text))
print(tokens)
# forward pass
print(backbone(**tokens))
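The __main__ snippet now registers "\n" as an additional special token and tokenizes with add_special_tokens=False, mirroring what extract() does. The point of the AddedToken is that XLM-R's sentencepiece tokenizer otherwise treats newlines as plain whitespace, so the newline boundaries the model is meant to predict would vanish from the input. A hedged illustration (checkpoint name assumed):

from transformers import AutoTokenizer
from tokenizers import AddedToken

text = "This is a test\n sentence \n\n"

plain = AutoTokenizer.from_pretrained("xlm-roberta-base")
print(plain.tokenize(text))    # newlines are normalised away like any other whitespace

patched = AutoTokenizer.from_pretrained("xlm-roberta-base")
patched.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
print(patched.tokenize(text))  # "\n" now survives as a token of its own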


45 changes: 41 additions & 4 deletions wtpsplit/train/evaluate.py
@@ -47,6 +47,33 @@ def get_metrics(labels, preds):

return metrics, info

def get_token_spans(tokenizer: object, offsets_mapping: list, tokens: list):
token_spans = []
for idx, token in enumerate(tokens):
# Skip special tokens like [CLS], [SEP]
if idx >= len(offsets_mapping):
continue
if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]:
continue

char_start, char_end = offsets_mapping[idx]
token_spans.append((char_start, char_end))

return token_spans

def token_to_char_probs(text: str, tokens: list, token_probs: np.ndarray, tokenizer, offsets_mapping):
char_probs = np.zeros(len(text))
token_spans = get_token_spans(tokenizer, offsets_mapping, tokens)

for (start, end), prob in zip(token_spans, token_probs):
# assign the token's prob to the last char of the token
# Ensure the end index does not exceed the length of the text
if end >= len(text):
print(f"Adjusting end index from {end} to {len(text)} for token '{text[start:end]}'")
end = len(text) - 1
char_probs[end] = prob

return char_probs
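token_to_char_probs() projects token-level newline probabilities back onto character positions via the tokenizer's offset mapping, so subword models can be scored against the character-level newline_labels used below. A small usage sketch; the checkpoint name, the import path, and the random probabilities are stand-ins:

import numpy as np
from transformers import AutoTokenizer
from wtpsplit.train.evaluate import token_to_char_probs  # the helper defined above

text = "This is a test\n sentence"
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"])
token_probs = np.random.rand(len(tokens))            # stand-in for model outputs

char_probs = token_to_char_probs(text, tokens, token_probs, tokenizer, enc["offset_mapping"])
print(char_probs.shape)                               # (len(text),): one slot per character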

def evaluate_sentence(
lang_code,
@@ -67,21 +94,30 @@ def evaluate_sentence(
separator = Constants.SEPARATORS[lang_code]
text = separator.join(sentences)

logits = extract(
logits, offsets_mapping, tokenizer = extract(
[text],
PyTorchWrapper(model.backbone),
lang_code=lang_code,
stride=stride,
block_size=block_size,
batch_size=batch_size,
verbose=True,
)[0]
)
logits, offsets_mapping = logits[0], offsets_mapping[0]

true_end_indices = np.cumsum(np.array([len(s) for s in sentences])) + np.arange(len(sentences)) * len(separator)
newline_labels = np.zeros(len(text))
newline_labels[true_end_indices - 1] = 1

metrics, info = get_metrics(newline_labels, logits[:, positive_index])

print("newline_labels", newline_labels.shape)

if "xlm" in model.config.model_type:
tokens = tokenizer.tokenize(text)
char_probs = token_to_char_probs(text, tokens, logits[:, positive_index], tokenizer, offsets_mapping)
else:
char_probs = logits[:, positive_index]
print("char probs", char_probs.shape)
metrics, info = get_metrics(newline_labels, char_probs)

info["newline_labels"] = newline_labels

@@ -94,3 +130,4 @@ def evaluate_sentence(
info["newline_probs_pysbd"] = newline_probs_pysbd

return metrics["pr_auc"], info
