Support probing T5 models (#16)
* Track the functions for probing T5

* Batchify the prompts to fit within the GPU limits

* Some tweaks for probing T5 models
AMR-KELEG authored May 28, 2023
1 parent 475c734 commit 9d07402
Showing 4 changed files with 225 additions and 21 deletions.
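The core technique in this commit is scoring candidate objects with T5's span-infilling objective: the "[Y]" slot in a prompt becomes "<extra_id_0>", each candidate answer becomes the decoder target "<extra_id_0> {answer} <extra_id_1>", and candidates are ranked by the mean negative log-likelihood of the answer sub-words. A minimal sketch of that scoring for a single candidate, assuming an illustrative t5-small checkpoint and prompt (neither is part of the commit):

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")  # illustrative checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-small")

prompt = "Paris is the capital of <extra_id_0> ."  # "[Y]" already replaced
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
labels = tokenizer("<extra_id_0> France <extra_id_1>", return_tensors="pt").input_ids

with torch.no_grad():
    # logits: (1, target_length, vocab_size); position t predicts labels[t]
    logits = model(input_ids=input_ids, labels=labels).logits

log_probs = torch.log_softmax(logits[0], dim=-1)
token_nll = -log_probs[torch.arange(labels.shape[1]), labels[0]]
# Skip <extra_id_0> at the start and <extra_id_1> plus </s> at the end,
# mirroring how get_T5_ranking below averages only the answer sub-words.
score = token_nll[1:-2].mean()
print(float(score))  # lower mean NLL means a better-ranked candidate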
3 changes: 1 addition & 2 deletions mlama/metrics_utils.py
@@ -42,8 +42,7 @@ def compute_probs(logits):

# Infer the domain from the sample ID
domains = [d for d in DOMAINS if d in sample_id]
assert len(domains) == 1
domain = domains[0]
domain = domains[0] if domains else "general"

region = (
normalize_region_name(fields[-2])
66 changes: 66 additions & 0 deletions mlama/scripts/eval_utils.py
@@ -184,3 +184,69 @@ def run_evaluation(args, NUM_MASK, candidate_objects_dict, model=None):
all_results = dict(list_of_results=list_of_results)
with open("{}/result.pkl".format(log_directory), "wb") as f:
pickle.dump(all_results, f)


def get_T5_ranking(model, tokenizer, candidate_answers, prompt, device):
"""Rank the answers according to their probability of filling the masked object.
Args:
model: A T5 model
tokenizer: The model's tokenizer
candidate_answers: A list of strings for all the candidate answers
prompt: The manual prompt used to probe the model
device: The GPU to use or "cpu"
Returns:
A dict mapping each candidate answer to the mean negative log-likelihood of its sub-words (lower is better).
"""
# Replace the span for the object within the template
input_ids = tokenizer(
prompt.replace("[Y]", "<extra_id_0>"), return_tensors="pt"
).input_ids

# Tokenize the different answers for the span
answers_probabilities = {}

# TODO: Make this an argument to the function
BATCH_SIZE = 128
for i in tqdm(range(0, len(candidate_answers), BATCH_SIZE)):
answers = candidate_answers[i : i + BATCH_SIZE]
labels = tokenizer(
["<extra_id_0> " + answer + " <extra_id_1>" for answer in answers],
return_tensors="pt",
padding=True,
).input_ids

# Output in the form (Queries, Token Index, Value in Vocab)
# T5 generates an output in the form "<extra_id_0> 'answer' <extra_id_1>"
outputs = model(
input_ids=torch.concat([input_ids for _ in range(len(answers))]).to(device),
labels=labels.to(device),
).logits

# Find the ids of the extra tokens
EXTRA_ID_0_index = tokenizer("<extra_id_0>").input_ids[0]
EXTRA_ID_1_index = tokenizer("<extra_id_1>").input_ids[0]

for answer_id in range(len(answers)):
target_ids = labels[answer_id]
answer_subword_probabilities = []

for idx, t_idx in enumerate(target_ids):
# Skip the first t_idx which is always <extra_id_0>
if idx == 0:
assert t_idx == EXTRA_ID_0_index
continue

# Stop computing the probabilities just before the <extra_id_1>
if t_idx == EXTRA_ID_1_index:
break

logits = outputs[answer_id, idx, :]
probs = logits.cpu().softmax(dim=-1).detach().numpy()
answer_subword_probabilities.append(-np.log(probs[t_idx]))

answer_probability = np.mean(answer_subword_probabilities)
answers_probabilities[answers[answer_id]] = answer_probability

return answers_probabilities
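Not part of the diff: a hedged usage sketch of the get_T5_ranking helper added above. The checkpoint, prompt, and candidate list are illustrative; the function returns a dict mapping each candidate to its mean negative log-likelihood, so sorting ascending gives the ranking.

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from eval_utils import get_T5_ranking

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")  # illustrative checkpoint
model = T5ForConditionalGeneration.from_pretrained("google/mt5-small").to(device)

candidates = ["France", "Germany", "Italy"]
prompt = "Paris is the capital of [Y] ."  # [Y] is swapped for <extra_id_0> inside the helper

scores = get_T5_ranking(model, tokenizer, candidates, prompt, device)
for nll, answer in sorted((v, k) for k, v in scores.items()):
    print(f"{answer}: {nll:.3f}")  # lower mean NLL ranks higher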
38 changes: 27 additions & 11 deletions mlama/scripts/model_config.py
@@ -1,4 +1,4 @@
def prepare_lms_configuration(models_tuples):
def prepare_bert_lms_configuration(models_tuples):
"""
Convert list of models' names to configuration dictionaries
@@ -17,7 +17,25 @@ def prepare_lms_configuration(models_tuples):
]


AR_LMs = prepare_lms_configuration(
def prepare_T5_lms_configuration(models_tuples):
"""
Convert a tuple of (label, HF model name) pairs into T5 configuration dictionaries
Args:
models_tuples -- A tuple of tuples of the models' labels and names on the HF site
"""
return [
{
"lm": "T5",
"label": model_label,
"model_name": "T5",
"T5_model_name": hf_model_name,
}
for model_label, hf_model_name in models_tuples
]


AR_LMs = prepare_bert_lms_configuration(
(
("mbert_base_cased", "bert-base-multilingual-cased"),
("mbert_base_uncased", "bert-base-multilingual-uncased"),
@@ -38,9 +56,9 @@ def prepare_lms_configuration(models_tuples):
("marbert", "UBC-NLP/MARBERT"),
("marbert_v2", "UBC-NLP/MARBERTv2"),
)
)
) + prepare_T5_lms_configuration((("mT5_base", "google/mt5-base"),))

EN_LMs = prepare_lms_configuration(
EN_LMs = prepare_bert_lms_configuration(
(
("mbert_base_cased", "bert-base-multilingual-cased"),
("mbert_base_uncased", "bert-base-multilingual-uncased"),
@@ -51,9 +69,9 @@ def prepare_lms_configuration(models_tuples):
("bert-large_cased", "bert-large-cased"),
("bert-large_uncased", "bert-large-uncased"),
)
)
) + prepare_T5_lms_configuration((("mT5_base", "google/mt5-base"), ("T5_base", "t5-base")))

JA_LMs = prepare_lms_configuration(
JA_LMs = prepare_bert_lms_configuration(
(
("mbert_base_cased", "bert-base-multilingual-cased"),
("mbert_base_uncased", "bert-base-multilingual-uncased"),
@@ -63,13 +81,11 @@ def prepare_lms_configuration(models_tuples):
("tohoku_bert_base_char_v2", "cl-tohoku/bert-base-japanese-char-v2"),
("tohoku_bert_large", "cl-tohoku/bert-large-japanese"),
("tohoku_bert_large_char", "cl-tohoku/bert-large-japanese-char"),
("tohoku_bert_base_word_masking", "cl-tohoku/bert-base-japanese-whole-word-masking"),
("tohoku_bert_base_word_masking", "cl-tohoku/bert-base-japanese-whole-word-masking"),
("japanese_bert_base", "colorfulscoop/bert-base-ja"),
)
)

KO_LMs = prepare_lms_configuration(
KO_LMs = prepare_bert_lms_configuration(
(
("mbert_base_cased", "bert-base-multilingual-cased"),
("mbert_base_uncased", "bert-base-multilingual-uncased"),
@@ -78,7 +94,7 @@ def prepare_lms_configuration(models_tuples):
)
)

ES_LMs = prepare_lms_configuration(
ES_LMs = prepare_bert_lms_configuration(
(
("beto_cased", "dccuchile/bert-base-spanish-wwm-cased"),
("beto_uncased", "dccuchile/bert-base-spanish-wwm-uncased"),
@@ -87,7 +103,7 @@ def prepare_lms_configuration(models_tuples):
)
)

ZH_LMs = prepare_lms_configuration(
ZH_LMs = prepare_bert_lms_configuration(
(
("mbert_base_cased", "bert-base-multilingual-cased"),
("mbert_base_uncased", "bert-base-multilingual-uncased"),
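Not part of the diff: for reference, the new prepare_T5_lms_configuration helper emits plain dictionaries, so a single registered model expands roughly as below (assuming mlama/scripts is on the import path).

from model_config import prepare_T5_lms_configuration

configs = prepare_T5_lms_configuration((("mT5_base", "google/mt5-base"),))
print(configs)
# [{'lm': 'T5', 'label': 'mT5_base', 'model_name': 'T5', 'T5_model_name': 'google/mt5-base'}]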
139 changes: 131 additions & 8 deletions mlama/scripts/run_prompting_experiment.py
@@ -16,6 +16,119 @@
from pathlib import Path
import glob

# T5 dependencies
import os
import re
import pickle
from tqdm import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration
from eval_utils import get_T5_ranking


def run_T5_experiments(
relations_templates,
data_path_pre,
language,
device,
input_param={
"lm": "T5",
"label": "mt5_base",
"model_name": "T5",
"T5_model_name": "google/mt5-small",
},
use_dlama=False,
):
# Load the model
model_name = input_param["T5_model_name"]
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)

LOGDIR = "output" if not use_dlama else "output_dlama"

for relation in relations_templates:
relation_name = relation["relation"]

# Build the list of candidate objects
if use_dlama:
# The relation can have multiple subsets
relation_files_path = str(Path(data_path_pre, f"{relation_name}_*.jsonl"))
else:
relation_files_path = str(Path(data_path_pre, f"{relation_name}.jsonl"))

relation_files = [f for f in glob.glob(relation_files_path)]

if not relation_files:
print("Relation {} excluded.".format(relation["relation"]))
continue

relation_triples = []
for file in set(relation_files):
with open(file, "r") as f:
relation_triples += [json.loads(line) for line in f]

# TODO: Augment valid objects with normalized values
candidate_objects = [
triple["obj_label"]
if type(triple["obj_label"]) == list
else [triple["obj_label"]]
for triple in relation_triples
]

unique_candidate_objects = sorted(
set([c for c_l in candidate_objects for c in c_l])
)

relation_template = relation["template"]
triples_results = []
for triple in tqdm(relation_triples):
triple_results = {"sample": triple, "uuid": triple["uuid"]}
sub_label = triple["sub_label"]
# obj_label can be a single string or a list of acceptable answers
obj_labels = triple["obj_label"]
if not isinstance(obj_labels, list):
    obj_labels = [obj_labels]

# Find the candidate answers probabilities for this triple
answers_probabilities = get_T5_ranking(
model,
tokenizer,
unique_candidate_objects,
re.sub("[X]", sub_label, relation_template),
device,
)

# Sort the answers
sorted_answers_probabilities = sorted(
[
(answers_probabilities[answer], answer)
for answer in answers_probabilities
]
)
sorted_probabilities = [t[0] for t in sorted_answers_probabilities]
sorted_answers = [t[1] for t in sorted_answers_probabilities]

# Form the output dictionary for this relation
ranks = [sorted_answers.index(obj_label) for obj_label in obj_labels]
probs = [answers_probabilities[obj_label] for obj_label in obj_labels]

triple_results["masked_topk"] = {
"ranks": ranks,
"prob_true": probs,
"predicted": sorted_answers,
"probs": sorted_probablities,
}

triples_results.append(triple_results)

log_directory = str(
Path(
LOGDIR, "results", input_param["label"], language, relation["relation"],
)
)
os.makedirs(log_directory, exist_ok=True)

# Dump the results to a .pkl file
with open("{}/result.pkl".format(log_directory), "wb") as f:
output_dict = {"list_of_results": triples_results}
pickle.dump(output_dict, f)


def run_experiments(
relations_templates,
@@ -135,14 +248,24 @@ def run_experiment_on_list_of_lms(
for lm in language_models:
print(lm["label"])
try:
run_experiments(
relations_templates,
data_path_pre,
language,
input_param=lm,
use_dlama=use_dlama,
device=device,
)
if "T5" in lm["label"]:
run_T5_experiments(
relations_templates,
data_path_pre,
language,
input_param=lm,
use_dlama=use_dlama,
device=device,
)
else:
run_experiments(
relations_templates,
data_path_pre,
language,
input_param=lm,
use_dlama=use_dlama,
device=device,
)
except Exception as e:
print(e)
print(f'Failed for: {lm["label"]}', file=sys.stderr)
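Not part of the diff: a hedged sketch of consuming the result.pkl files written by run_T5_experiments. Only the {"list_of_results": [...]} layout and the "masked_topk" keys come from the code above; the path and the P@1 computation are illustrative.

import pickle
from pathlib import Path

# One result.pkl is written per relation under output/results/<label>/<language>/<relation>.
result_path = Path("output/results/mT5_base/en/P19/result.pkl")  # illustrative path
with open(result_path, "rb") as f:
    results = pickle.load(f)["list_of_results"]

# Precision@1: a triple counts as correct if any of its gold objects is ranked first.
p_at_1 = sum(min(r["masked_topk"]["ranks"]) == 0 for r in results) / len(results)
print(f"P@1 = {p_at_1:.3f}")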
