
improve adaptation + format
markus583 committed Aug 7, 2024
1 parent 831b184 commit 41dce4b
Showing 18 changed files with 72 additions and 55 deletions.
1 change: 1 addition & 0 deletions scripts/export_to_onnx_charbert.py
@@ -9,6 +9,7 @@
import wtpsplit # noqa
import wtpsplit.models # noqa


@dataclass
class Args:
model_name_or_path: str = "benjamin/wtp-bert-mini"
2 changes: 1 addition & 1 deletion setup.py
@@ -17,7 +17,7 @@
"pandas>=1",
"cached_property", # for Py37
"mosestokenizer",
"adapters"
"adapters",
],
url="https://github.com/segment-any-text/wtpsplit",
package_data={"wtpsplit": ["data/*"]},
15 changes: 8 additions & 7 deletions wtpsplit/__init__.py
@@ -23,6 +23,7 @@
warnings.simplefilter("default", DeprecationWarning) # show by default
warnings.simplefilter("ignore", category=FutureWarning)  # for transformers


class WtP:
def __init__(
self,
@@ -435,9 +436,7 @@ def __init__(

self.use_lora = False

self.tokenizer = AutoTokenizer.from_pretrained(
"facebookAI/xlm-roberta-base"
)
self.tokenizer = AutoTokenizer.from_pretrained("facebookAI/xlm-roberta-base")

if isinstance(model_name_or_model, (str, Path)):
model_name = str(model_name_or_model)
@@ -498,12 +497,14 @@ def __init__(
)
# LoRA LOADING
# TODO: LoRA + ONNX ?
if (style_or_domain and not language) or (language and not style_or_domain):
raise ValueError("Please specify both language and style_or_domain!")
if style_or_domain and language:
if not lora_path:
if (style_or_domain and not language) or (language and not style_or_domain):
raise ValueError("Please specify both language and style_or_domain!")
if (style_or_domain and language) or lora_path:
import adapters # noqa
from adapters.models import MODEL_MIXIN_MAPPING # noqa
from adapters.models.bert.mixin_bert import BertModelAdaptersMixin # noqa

# monkey patch mixin to avoid forking whole adapters library
MODEL_MIXIN_MAPPING["SubwordXLMRobertaModel"] = BertModelAdaptersMixin
model_type = self.model.model.config.model_type
@@ -642,7 +643,7 @@ def newline_probability_fn(logits):
batch_size=batch_size,
pad_last_batch=pad_last_batch,
verbose=verbose,
tokenizer=self.tokenizer
tokenizer=self.tokenizer,
)

# convert token probabilities to character probabilities for the entire array
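The reordered LoRA check in this file means adaptation can now be selected either with a (style_or_domain, language) pair or with a direct lora_path. A minimal usage sketch, assuming the SaT class in wtpsplit exposes these keyword arguments; the model identifier and local path below are placeholders:

```python
# Hedged sketch of the two LoRA paths the corrected check allows; only the
# keyword arguments come from the hunk above, the class name, model id, and
# paths are illustrative assumptions.
from wtpsplit import SaT

# Path 1: select an adaptation module by style/domain plus language.
sat_ud = SaT("sat-3l", style_or_domain="ud", language="en")

# Path 2: point directly at a locally trained LoRA module; with this change,
# style_or_domain and language are no longer required in that case.
sat_local = SaT("sat-3l", lora_path="path/to/lora/module")

print(sat_ud.split("This is a test This is another test"))
```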
4 changes: 2 additions & 2 deletions wtpsplit/evaluation/adapt.py
@@ -182,7 +182,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
if args.adapter_path:
if args.clf_from_scratch:
model.model.classifier = torch.nn.Linear(model.model.classifier.in_features, 1)

dataset_load_name = dataset_name
model.model.load_adapter(
args.adapter_path + "/" + dataset_load_name + "/" + lang_code,
@@ -354,7 +354,7 @@ def main(args):
clfs = {}
if args.return_indices:
indices = {}

u_scores, t_scores, punct_scores = [], [], []

for lang_code, dsets in tqdm(eval_data.items()):
2 changes: 1 addition & 1 deletion wtpsplit/evaluation/evaluate_sepp_nlg_subtask1.py
@@ -12,7 +12,7 @@ def evaluate_subtask1(splits, langs, prediction_dir: str, supervisions, include_
Mirrors the original SEPP-NLG 2021 Shared Task evaluation function
https://sites.google.com/view/sentence-segmentation
"""

results = {}
avg_holder = {}
for supervision in supervisions:
2 changes: 1 addition & 1 deletion wtpsplit/evaluation/intrinsic_baselines_multilingual.py
@@ -61,7 +61,7 @@ class Args:
continue
results[lang][dataset_name] = {}
indices[lang][dataset_name] = {}

if "-" in lang:
# code-switched data: eval 2x
lang_code = lang.split("_")[1].lower()
1 change: 0 additions & 1 deletion wtpsplit/evaluation/intrinsic_pairwise.py
@@ -70,7 -70,6 @@ class Args:
min_k_mer_length: int = 0



def process_logits_k_mers(pairs, model, lang_code, block_size, batch_size, verbose=True) -> List[np.ndarray]:
logits_list = []
n_tokens_list = []
2 changes: 1 addition & 1 deletion wtpsplit/evaluation/legal_baselines.py
@@ -131,7 +131,7 @@ def load_or_compute_logits(args, eval_data, save_str: str = None):

start_time = time.time()
test_logits = get_law_preds(test_text, model, current_name, args)
end_time = time.time()
end_time = time.time()
total_test_time += end_time - start_time
if isinstance(test_sentences[0], list):
test_logit_lengths = []
12 changes: 3 additions & 9 deletions wtpsplit/evaluation/llm_sentence.py
@@ -246,11 +246,7 @@ def load_or_compute_logits(args, eval_data, save_str: str = None):
if isinstance(test_sentences[0], list) or args.type == "pairs":
if args.type == "pairs":
all_pairs = generate_k_mers(
test_sentences,
k=2,
do_lowercase=False,
do_remove_punct=False,
sample_pct=0.5
test_sentences, k=2, do_lowercase=False, do_remove_punct=False, sample_pct=0.5
)
test_sentences = all_pairs
# list of lists: chunk each sublist
@@ -568,10 +564,8 @@ def main(args):

eval_data = torch.load(eval_data_path)

save_str = (
f"{args.model.split('/')[-1]}_k{args.k}_s{args.n_shots}"
).replace("/", "_")

save_str = (f"{args.model.split('/')[-1]}_k{args.k}_s{args.n_shots}").replace("/", "_")

if args.max_n_test_sentences < sys.maxsize and args.max_n_test_sentences != -1:
save_str += f"_n{args.max_n_test_sentences}"
if args.max_n_test_sentences == -1:
1 change: 0 additions & 1 deletion wtpsplit/evaluation/stat_tests/permutation_test_data.py
@@ -63,7 +63,6 @@
if isinstance(data_list[0], int):
data_list = [data_list]


raw_data[lang][dataset][model + "-" + model_type] = data_list

true_indices = data[lang][dataset][model_type]["true_indices"]
15 changes: 10 additions & 5 deletions wtpsplit/extract_batched.py
@@ -110,12 +110,17 @@ def extract_batched(
if use_subwords and model.config.model_type == "xlm-roberta":
# TODO: generalize
import torch

with torch.no_grad():
logits = model.model(
input_ids=torch.from_numpy(input_ids).to(model.model.device),
attention_mask=torch.from_numpy(attention_mask).to(model.model.device),
**kwargs,
)["logits"].cpu().numpy()
logits = (
model.model(
input_ids=torch.from_numpy(input_ids).to(model.model.device),
attention_mask=torch.from_numpy(attention_mask).to(model.model.device),
**kwargs,
)["logits"]
.cpu()
.numpy()
)
else:
logits = model(
input_ids=input_ids if use_subwords else None,
10 changes: 5 additions & 5 deletions wtpsplit/models.py
@@ -1231,11 +1231,11 @@ def get_extended_attention_mask(
"""

# if not (attention_mask.dim() == 2 and config.is_decoder):
# show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
# if device is not None:
# warnings.warn(
# "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
# )
# show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
# if device is not None:
# warnings.warn(
# "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
# )
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
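The comments kept in this hunk describe providing a self-attention mask of shape [batch_size, from_seq_length, to_seq_length] and making it broadcastable to all heads. A small illustrative sketch of that standard additive-mask convention (not code from this file):

```python
# Illustrative only: expand a [batch, seq] padding mask so it can be added to
# attention scores of shape [batch, heads, from_seq, to_seq].
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])          # 1 = token, 0 = padding
extended = attention_mask[:, None, None, :].float()    # [batch, 1, 1, to_seq]
extended = (1.0 - extended) * torch.finfo(torch.float32).min
print(extended.shape)  # torch.Size([1, 1, 1, 4])
```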
7 changes: 5 additions & 2 deletions wtpsplit/train/adaptertrainer.py
@@ -78,6 +78,7 @@

TRAINING_ARGS_NAME = "training_args.bin"


class AdapterTrainer(Trainer):
def __init__(
self,
@@ -484,11 +485,13 @@ def evaluation_loop(
if all_inputs is not None:
all_inputs = nested_truncate(all_inputs, num_samples)
else:
xm.rendezvous("eval_metrics")
if is_torch_tpu_available():
xm.rendezvous("eval_metrics")
all_losses, all_preds, all_labels, all_inputs, num_samples = None, None, None, None, 0

# Metrics!
xm.rendezvous("eval_metrics")
if is_torch_tpu_available():
xm.rendezvous("eval_metrics")
# MODIFIED: always compute metrics
if self.compute_metrics is not None:
metrics = self.compute_metrics(self)
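The edit above guards the XLA rendezvous calls so the evaluation loop no longer assumes a TPU runtime. The guard pattern in isolation (import and call are taken from the hunk, the surrounding loop is omitted):

```python
# Sketch of the TPU guard applied above: only import torch_xla and hit the
# rendezvous barrier when a TPU is actually available.
from transformers.trainer import is_torch_tpu_available

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm

    xm.rendezvous("eval_metrics")  # synchronize workers before computing metrics
```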
13 changes: 6 additions & 7 deletions wtpsplit/train/evaluate.py
@@ -5,7 +5,7 @@
import pysbd
import sklearn.metrics

from wtpsplit.evaluation.intrinsic_pairwise import generate_pairs, generate_k_mers, process_logits_k_mers
from wtpsplit.evaluation.intrinsic_pairwise import generate_k_mers, process_logits_k_mers
from wtpsplit.extract import PyTorchWrapper, extract
from wtpsplit.utils import Constants, sigmoid, corrupt, token_to_char_probs

@@ -153,13 +153,12 @@ def evaluate_sentence_pairwise(
accuracy_list = []

# get pairs of sentences (non-overlapping)
sampled_pairs = generate_pairs(
sampled_pairs = generate_k_mers(
sentences=sentences,
pair_sample_pct=pair_sample_pct,
max_n_pairs=max_pairs,
min_pair_length=0,
do_lowercase=do_lowercase,
do_remove_punct=do_remove_punct,
k=2,
sample_pct=pair_sample_pct,
max_n_samples=max_pairs,
min_k_mer_length=0,
)

# get logits for each pair
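evaluate_sentence_pairwise now builds its sentence pairs through the generic k-mer helper with k=2 instead of the removed generate_pairs. A hedged sketch of the renamed arguments; the argument names follow the hunk above, the input sentences are made up:

```python
# Sketch of the k=2 call that replaces the old pairwise helper.
from wtpsplit.evaluation.intrinsic_pairwise import generate_k_mers

sentences = ["First sentence.", "Second sentence.", "Third one.", "Fourth one."]
sampled_pairs = generate_k_mers(
    sentences=sentences,
    k=2,                  # pairs of adjacent sentences
    do_lowercase=False,
    do_remove_punct=False,
    sample_pct=1.0,       # was pair_sample_pct
    max_n_samples=100,    # was max_n_pairs
    min_k_mer_length=0,   # was min_pair_length
)
```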
23 changes: 18 additions & 5 deletions wtpsplit/train/train.py
@@ -4,6 +4,7 @@
import random
import shutil
import sys

# import time
from collections import Counter, defaultdict
from dataclasses import dataclass
@@ -14,14 +15,15 @@
import datasets
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import transformers
from datasets import load_dataset

# from datasets.download import DownloadConfig
from tokenizers import AddedToken
from torchinfo import summary
from tqdm.auto import tqdm
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
from transformers.trainer import is_torch_tpu_available

import wandb
from wtpsplit.models import (
@@ -35,6 +37,7 @@
from wtpsplit.train.evaluate import evaluate_sentence
from wtpsplit.train.trainer import Trainer
from wtpsplit.train.utils import Model

# from wtpsplit.train.utils import cleanup_cache_files
from wtpsplit.utils import Constants, LabelArgs, corrupt_training, get_label_dict, get_subword_label_dict

@@ -197,12 +200,22 @@ def main():
else:
(args, training_args, label_args) = parser.parse_args_into_dataclasses()
wandb_name = None
if xm.xrt_world_size() == 4:
# ensure same batch size on TPUv3 and TPUv4 using same config.json
training_args.per_device_train_batch_size *= 2

if is_torch_tpu_available():
import torch_xla.core.xla_model as xm

world_size = xm.xrt_world_size()
if world_size == 4:
# ensure same batch size on TPUv3 and TPUv4 using same config.json
training_args.per_device_train_batch_size *= 2
elif torch.cuda.is_available():
world_size = torch.cuda.device_count()
else:
world_size = 1

logger.warning(f"Per device train batch size: {training_args.per_device_train_batch_size}")
logger.warning(
f"Total train batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps* xm.xrt_world_size()}"
f"Total train batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * world_size}"
)

setup_logging(training_args)
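The hunk above replaces the unconditional torch_xla import with a runtime check, so the same training script runs on TPU, GPU, or CPU and the logged total batch size uses whichever world size applies. The detection logic, isolated as a sketch (the wrapper function and batch-size numbers are illustrative):

```python
# Sketch of the device-agnostic world-size detection added above.
import torch
from transformers.trainer import is_torch_tpu_available


def get_world_size() -> int:
    if is_torch_tpu_available():
        import torch_xla.core.xla_model as xm

        return xm.xrt_world_size()
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    return 1


per_device_bs, grad_accum = 32, 2
print(f"Total train batch size: {per_device_bs * grad_accum * get_world_size()}")
```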
13 changes: 7 additions & 6 deletions wtpsplit/train/train_SM.py
@@ -22,10 +22,11 @@
@dataclass
class Args:
block_size: int = 256
num_layers: int = 12 # number of layers
lim_lookahead: bool = False # our "Lookahead" ablation
without_pretraining: bool = False # our "No pre-training" ablation
no_sm_corruption: bool = False # our "Only clean text" ablation
num_layers: int = 12 # number of layers
lim_lookahead: bool = False # our "Lookahead" ablation
without_pretraining: bool = False # our "No pre-training" ablation
no_sm_corruption: bool = False # our "Only clean text" ablation


# Parsing command line arguments or JSON config files as needed
parser = HfArgumentParser([Args, TrainingArguments])
@@ -55,7 +56,7 @@ class Args:
and all_data[lang_code]["sentence"]["ud"]["meta"]["train_data"] is not None
):
train_data = all_data[lang_code]["sentence"]["ud"]["meta"]["train_data"]

if len(train_data) < 10000:
# some languages have an insufficient number of sentences to fill a single batch
# this is just a quick way to upsample these so we don't run into problems later
@@ -155,7 +156,7 @@ class Args:
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

if args.num_layers == 3 and args.without_pretraining:
# special case for one of our ablations, where we trim XLM-R (without any of our newline pretraining) to 3 layers
# special case for one of our ablations, where we trim XLM-R (without any of our newline pretraining) to 3 layers
model = SubwordXLMForTokenClassification.from_pretrained(
model_checkpoint,
num_labels=1,
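The first hunk in this file keeps a comment about upsampling languages whose UD training split has fewer than 10,000 sentences; the upsampling code itself is truncated from this view, so the snippet below is only an illustrative reconstruction of that idea, not the repository's implementation:

```python
# Illustrative only: repeat a small per-language training list until it is big
# enough to fill batches reliably.
train_data = ["Sent A.", "Sent B.", "Sent C."]
MIN_SENTENCES = 10000

if len(train_data) < MIN_SENTENCES:
    repeats = MIN_SENTENCES // len(train_data) + 1
    train_data = train_data * repeats

print(len(train_data) >= MIN_SENTENCES)  # True
```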
2 changes: 1 addition & 1 deletion wtpsplit/train/train_lora.py
@@ -573,7 +573,7 @@ def compute_metrics(trainer):
)
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
logger.warning(f"Finished training for {lang} {dataset_name}.")

# only save trained module
if training_args.local_rank == 0:
if not os.path.exists(os.path.join(training_args.output_dir, dataset_name, lang)):
2 changes: 2 additions & 0 deletions wtpsplit/utils/download_spacy.py
@@ -27,9 +27,11 @@
"multi": "xx_sent_ud_sm"
}


def download_models():
for lang, model in SPACY_LANG_TO_DP_MODEL.items():
subprocess.run(["python3", "-m", "spacy", "download", model])


if __name__ == "__main__":
download_models()
