forked from openai/simple-evals
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mmlu_eval.py
109 lines (101 loc) · 4.07 KB
/
mmlu_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Measuring Massive Multitask Language Understanding
Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
https://arxiv.org/abs/2009.03300
"""
import random
import re
import blobfile as bf
import pandas
from . import common
from .common import ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
# MMLU subjects grouped by the coarse category they are reported under.
# Flattened below into the subject -> category lookup used when scoring.
_subjects_by_category = {
    "stem": (
        "abstract_algebra",
        "astronomy",
        "college_biology",
        "college_chemistry",
        "college_computer_science",
        "college_mathematics",
        "college_physics",
        "computer_security",
        "conceptual_physics",
        "electrical_engineering",
        "elementary_mathematics",
        "high_school_biology",
        "high_school_chemistry",
        "high_school_computer_science",
        "high_school_mathematics",
        "high_school_physics",
        "high_school_statistics",
        "machine_learning",
    ),
    "humanities": (
        "formal_logic",
        "high_school_european_history",
        "high_school_us_history",
        "high_school_world_history",
        "international_law",
        "jurisprudence",
        "logical_fallacies",
        "moral_disputes",
        "moral_scenarios",
        "philosophy",
        "prehistory",
        "professional_law",
        "world_religions",
    ),
    "social_sciences": (
        "econometrics",
        "high_school_geography",
        "high_school_government_and_politics",
        "high_school_macroeconomics",
        "high_school_microeconomics",
        "high_school_psychology",
        "human_sexuality",
        "professional_psychology",
        "public_relations",
        "security_studies",
        "sociology",
        "us_foreign_policy",
    ),
    "other": (
        "anatomy",
        "business_ethics",
        "clinical_knowledge",
        "college_medicine",
        "global_facts",
        "human_aging",
        "management",
        "marketing",
        "medical_genetics",
        "miscellaneous",
        "nutrition",
        "professional_accounting",
        "professional_medicine",
        "virology",
    ),
}

# Subject name -> category. Subject names are unique, so sorting the
# (subject, category) pairs orders the keys alphabetically by subject.
subject2category = dict(
    sorted(
        (subject, category)
        for category, subjects in _subjects_by_category.items()
        for subject in subjects
    )
)
class MMLUEval(Eval):
    """MMLU multiple-choice evaluation.

    The question set is read from a CSV when the object is constructed.
    Calling the instance with a sampler asks the sampler each question,
    extracts an answer from the response with ``ANSWER_PATTERN_MULTICHOICE``,
    and scores it against the row's ``"Answer"`` value.
    """

    def __init__(self, num_examples: int | None = None):
        """Load the question set and optionally subsample it.

        Args:
            num_examples: When truthy, keep only this many rows, drawn with
                ``random.Random(0)`` so the same subset is chosen on every
                run. ``None`` (and ``0``) keep every row.

        Raises:
            ValueError: Propagated from ``random.Random.sample`` when
                ``num_examples`` is negative or exceeds the number of rows.
        """
        # Parse the CSV inside a ``with`` block so the blob handle is closed
        # as soon as the data is in memory, rather than whenever the garbage
        # collector gets to it.
        with bf.BlobFile(
            "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
        ) as csv_file:
            df = pandas.read_csv(csv_file)
        examples = [row.to_dict() for _, row in df.iterrows()]
        if num_examples:
            examples = random.Random(0).sample(examples, num_examples)
        self.examples = examples

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Score ``sampler`` on every loaded example.

        Args:
            sampler: Called with a one-message prompt for each example; also
                supplies ``_pack_message`` to build that message.

        Returns:
            The value of ``common.aggregate_results`` applied to the
            per-example ``SingleEvalResult`` objects.
        """

        def fn(row: dict) -> SingleEvalResult:
            """Query the sampler for one example row and score the response."""
            prompt_messages = [
                sampler._pack_message(content=format_multichoice_question(row), role="user")
            ]
            response_text = sampler(prompt_messages)
            # ``re.search`` returns the first occurrence of the pattern; a
            # response with no match is recorded with extracted_answer=None.
            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
            extracted_answer = match.group(1) if match else None
            # Exact string equality; an unmatched response scores 0.0 unless
            # the row's "Answer" value is itself None.
            score = 1.0 if extracted_answer == row["Answer"] else 0.0
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=row["Answer"],
                extracted_answer=extracted_answer,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            # Subjects missing from the lookup table are reported as "other".
            category = subject2category.get(row["Subject"], "other")
            # The score is also reported under its category name in metrics.
            return SingleEvalResult(html=html, score=score, metrics={category: score}, convo=convo)

        results = common.map_with_progress(fn, self.examples)
        return common.aggregate_results(results)