-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
71 lines (57 loc) · 2.59 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# -*- coding: utf-8 -*-
import argparse
import json
import os

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

from load_dataset import Dataset
from scorers.delta import Delta_Scorer
def main(args):
    """Score every benchmark dataset with ``Delta_Scorer`` and append results as JSONL.

    The LLaMA model and tokenizer named by ``args.pretrained_name`` are loaded
    once. For each benchmark name in the fixed list below, the dataset is
    loaded, ``Delta_Scorer.compute`` is run on its source/target lines, and
    each record of ``dataset.data`` is augmented with the returned per-example
    values and appended as one JSON line to
    ``./output/<file_name>-fflm-<short model name>.jsonl``.

    Args:
        args: Namespace-like object providing ``pretrained_name`` (must be one
            of the checkpoints in the short-name table) and ``device``.
            ``args.file_name`` is overwritten for every dataset. An optional
            ``args.scorer`` is used only in the progress message.

    Raises:
        KeyError: If ``args.pretrained_name`` has no entry in the short-name
            table. This is raised before the model is loaded.
    """
    short_names = {
        "decapoda-research/llama-7b-hf": "llama7b",
        "decapoda-research/llama-13b-hf": "llama13b",
        "openlm-research/open_llama_3b": "llama3b",
    }
    # Resolve the file-name tag first so an unsupported checkpoint fails
    # immediately instead of after loading the model and scoring a dataset.
    short_name = short_names[args.pretrained_name]

    model = LlamaForCausalLM.from_pretrained(
        args.pretrained_name, torch_dtype=torch.float16
    ).to(args.device)
    tokenizer = LlamaTokenizer.from_pretrained(args.pretrained_name)
    model.eval()

    # The output files are opened in append mode below, which does not create
    # missing parent directories.
    output_dir = "./output/"
    os.makedirs(output_dir, exist_ok=True)

    # The command-line parser of this script defines no ``--scorer`` option,
    # so fall back to the name of the scorer that is actually used.
    scorer_label = getattr(args, "scorer", Delta_Scorer.__name__)

    benchmark_names = (
        "summeval", "xsumfaith", "qagsxsum", "qagscnn", "frankxsum", "frankcnn",
        "summac-summeval", "summac-xsumfaith", "summac-cogensum", "summac-factcc",
        "summac-polytope", "summac-frank",
    )
    for file_name in benchmark_names:
        args.file_name = file_name
        print("Testing on {} {}".format(args.file_name, scorer_label))

        # Load dataset.
        dataset = Dataset(args.file_name)
        source_lines = dataset.source_lines
        target_lines = dataset.target_lines
        data = dataset.data

        # Get scores. ``seperator`` (sic) is the keyword name expected by
        # ``Delta_Scorer.compute`` and must not be "corrected" here.
        scorer = Delta_Scorer(
            model=model,
            tokenizer=tokenizer,
            pretrained_name=args.pretrained_name,
            device=args.device,
        )
        (
            s2s_tok_list,
            lm_tok_list,
            prefix_tok_list,
            s2s_tok_list_doc,
            lm_tok_list_doc,
        ) = scorer.compute(
            sources=source_lines,
            targets=target_lines,
            seperator="TL;DR ",
        )

        # Save to files. Append mode: re-running the script adds new lines to
        # an existing file rather than replacing it.
        outputpath = output_dir + str(args.file_name) + "-fflm-" + short_name + ".jsonl"
        with open(outputpath, "a+") as outputfile:
            for d, s2s, lm, pf, s2s_doc, lm_doc in zip(
                data,
                s2s_tok_list,
                lm_tok_list,
                prefix_tok_list,
                s2s_tok_list_doc,
                lm_tok_list_doc,
            ):
                # Each record is updated in place before being serialized.
                d["s2s_tok_list"] = s2s
                d["lm_tok_list"] = lm
                d["prefix_tok_list"] = pf
                d["s2s_tok_list_1"] = s2s_doc
                d["lm_tok_list_1"] = lm_doc
                outputfile.write(json.dumps(d) + "\n")
if __name__ == "__main__":
    # Script entry point: read the checkpoint name and target device from the
    # command line, then hand the parsed namespace to main().
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--pretrained_name", type=str, default="decapoda-research/llama-7b-hf")
    arg_parser.add_argument("--device", type=str, default="cuda")
    args = arg_parser.parse_args()
    main(args)