analyze_answers.py
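
"""Score saved model responses against the meta-llama *-Instruct-evals datasets.

For the chosen eval task, the script loads one pickled completion per prompt from
--response-path (response_0.pkl, response_1.pkl, ...), extracts the model's answer
with a task-specific pattern, and checks it against the dataset's
input_correct_responses. It prints micro- and macro-averaged accuracy over
subtasks, together with the "meta" accuracy implied by the dataset's own
is_correct column.
"""
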
import argparse
from datasets import load_dataset
import transformers
import pickle
import re
import os
from collections import defaultdict


def get_mmlu_answer(response):
    # Non-CoT MMLU completions are expected to be just the letter choice, e.g. " A." -> "A".
    if response is not None:
        return response.choices[0].text.strip().upper().replace(".", "")
    return None


def get_mmlu_cot_answer(response):
    # CoT completions end with a sentence such as "The best answer is C.";
    # try a few phrasings and strip trailing punctuation from the captured choice.
    pattern = r"The best answer is (.+)\.?"
    match = re.search(pattern, response.choices[0].text)
    if match:
        return match.group(1).replace(".", "").replace("*", "")

    pattern = r"the best answer is (.+)\.?"
    match = re.search(pattern, response.choices[0].text)
    if match:
        return match.group(1).replace(".", "")

    pattern = r"The correct answer is (.+)\.?"
    match = re.search(pattern, response.choices[0].text)
    if match:
        return match.group(1).replace(".", "")

    pattern = r"the correct answer is (.+)\.?"
    match = re.search(pattern, response.choices[0].text)
    if match:
        return match.group(1).replace(".", "")
    return None


def get_answer_gsm8k(response):
    # GSM8K completions end with "The final answer is <number>"; strip the
    # currency/percent symbols that the reference answers omit.
    pattern = r"The final answer is (.+)\.?"
    match = re.search(pattern, response.choices[0].text)
    if match:
        s = match.group(1)
        for ok_symbol in ["%", "$"]:
            s = s.replace(ok_symbol, "")
        return s
    return None
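
# A minimal sketch of the response object shape these extractors assume (an
# OpenAI-client-style completion exposing .choices[0].text); the example text
# below is made up:
#
#   from types import SimpleNamespace
#   fake = SimpleNamespace(choices=[SimpleNamespace(text="... The final answer is $72")])
#   get_answer_gsm8k(fake)  # -> "72"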


TASK_TO_ANSWER_EXTRACTOR = {
    "evals__mmlu__details": get_mmlu_answer,
    "evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
    "evals__gsm8k__details": get_answer_gsm8k,
    "evals__mmlu_pro__details": get_mmlu_cot_answer,
}


def get_dataset_from_task(task, response_path):
    # The 405B evals dataset defines the canonical prompt order.
    ds_405b = load_dataset(
        "meta-llama/Meta-Llama-3.1-405B-Instruct-evals",
        f"Meta-Llama-3.1-405B-Instruct-{task}",
    )
    ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]

    # If the responses were generated for a 70B/8B model, load that model's evals
    # dataset instead and reorder its rows to match the 405B prompt order.
    if "70b" in str(response_path) or "8b" in str(response_path):
        if "70" in str(response_path):
            ref_model_ds = load_dataset(
                "meta-llama/Meta-Llama-3.1-70B-Instruct-evals",
                f"Meta-Llama-3.1-70B-Instruct-{task}",
            )
        else:
            ref_model_ds = load_dataset(
                "meta-llama/Meta-Llama-3.1-8B-Instruct-evals",
                f"Meta-Llama-3.1-8B-Instruct-{task}",
            )

        hash_to_row = {}
        for row in ref_model_ds["latest"]:
            hash_to_row[row["input_final_prompts_hash"][0]] = row
        ref_model_ds["latest"] = [hash_to_row[prompt_hash] for prompt_hash in ds_405b_hash_order]
        return ref_model_ds

    return ds_405b
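
# Illustration of the loader's behavior (the response path here is hypothetical;
# only the substrings "8b"/"70b" in it matter):
#
#   ds = get_dataset_from_task("evals__mmlu__details", "responses/70b_run")
#   # loads Meta-Llama-3.1-70B-Instruct-evals and reorders its rows to follow
#   # the 405B dataset's input_final_prompts_hash order.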


def main(task, response_path):
    ds = get_dataset_from_task(task, response_path)

    # Load one pickled response per dataset row: response_0.pkl, response_1.pkl, ...
    responses = []
    total = len(ds["latest"])
    for i in range(total):
        with open(os.path.join(response_path, f"response_{i}.pkl"), "rb") as f:
            responses.append(pickle.load(f))
    from dataclasses import dataclass

    @dataclass
    class Stats:
        correct: int = 0
        total: int = 0
        meta_correct: int = 0
        average: float = None
        meta_average: float = None

    subtask_name_to_stats = defaultdict(Stats)
    # Score each response against the dataset's accepted answers; also tally the
    # dataset's recorded is_correct flag (the "meta" numbers) for comparison.
    for response, ds_row in zip(responses, ds["latest"]):
        model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)
        subtask = ds_row["subtask_name"]
        is_eval_correct = model_answer in ds_row["input_correct_responses"]
        if is_eval_correct:
            subtask_name_to_stats[subtask].correct += 1
        if ds_row["is_correct"]:
            subtask_name_to_stats[subtask].meta_correct += 1
        subtask_name_to_stats[subtask].total += 1
    # Micro average weights every question equally; macro average weights every
    # subtask equally.
    micro_stats = Stats()
    for subtask, stats in subtask_name_to_stats.items():
        stats.average = stats.correct / stats.total
        stats.meta_average = stats.meta_correct / stats.total
        micro_stats.correct += stats.correct
        micro_stats.total += stats.total
        micro_stats.meta_correct += stats.meta_correct
    micro_stats.average = micro_stats.correct / micro_stats.total
    micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total

    import numpy as np

    print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
    print("Meta Macro average", np.mean([x.meta_average for x in subtask_name_to_stats.values()]))
    print("Micro average", micro_stats.average)
    print("Meta Micro average", micro_stats.meta_average)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Score saved model responses for an eval task.")
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        help="The eval task to score (a key of TASK_TO_ANSWER_EXTRACTOR)",
    )
    parser.add_argument(
        "--response-path",
        type=str,
        required=True,
        help="The path to read responses from",
    )
    args = parser.parse_args()
    main(args.task, args.response_path)
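
# Example invocation (the response directory is hypothetical; it must contain one
# response_{i}.pkl per dataset row, produced by a separate generation step):
#
#   python analyze_answers.py \
#       --task evals__mmlu__0_shot__cot__details \
#       --response-path responses/mmlu_cot_70b/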