forked from clovaai/donut
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
executable file
·106 lines (83 loc) · 3.52 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
Donut
Copyright (c) 2022-present NAVER Corp.
MIT License
"""
import argparse
import json
import os
import re
from pathlib import Path
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm
from donut import DonutModel, JSONParseEvaluator, load_json, save_json
def _build_prompt(task_name, ground_truth):
    """Return the decoder prompt that starts generation for one sample.

    For ``docvqa`` the prompt embeds the lower-cased question of the first
    ground-truth parse, so the model only has to generate the answer. Every
    other task is prompted with its task start token alone.
    """
    if task_name == "docvqa":
        question = ground_truth["gt_parses"][0]["question"].lower()
        return f"<s_{task_name}><s_question>{question}</s_question><s_answer>"
    return f"<s_{task_name}>"


def _score_prediction(task_name, output, ground_truth, evaluator):
    """Score one prediction and return ``(score, gt)``.

    ``gt`` is the ground-truth object the prediction was compared against; the
    caller also collects it for the corpus-level F1 computation.

    - ``rvlcdip``: 1.0 if the predicted class equals the ground-truth class.
    - ``docvqa``: exact match against any of the ground-truth answers (the
      official metric is computed on the challenge website, not here).
    - anything else: ``evaluator.cal_acc`` (tree-edit-distance accuracy).

    A prediction lacking the expected key (presumably output the model could
    not parse into fields) scores 0.0 instead of aborting the evaluation.
    """
    if task_name == "rvlcdip":
        gt = ground_truth["gt_parse"]
        return float(output.get("class") == gt["class"]), gt
    if task_name == "docvqa":
        gt = ground_truth["gt_parses"]
        answers = {qa_parse["answer"] for qa_parse in gt}
        return float(output.get("answer") in answers), gt
    gt = ground_truth["gt_parse"]
    return evaluator.cal_acc(output, gt), gt


def test(args):
    """Evaluate a pretrained Donut model on one split of a dataset.

    Args:
        args: Namespace with ``pretrained_model_name_or_path``,
            ``dataset_name_or_path``, ``split``, ``task_name`` and
            ``save_path`` attributes (as built by this script's CLI).

    Returns:
        The list of per-sample predictions, in dataset order.

    Side effects:
        For ``docvqa`` prints each sample's prompt (``INPUT::``) and raw
        inference result (``OUTPUT::``). Always prints a summary line. When
        ``args.save_path`` is set, writes the scores, predictions and ground
        truths to that path as JSON, creating its parent directory if needed.
    """
    pretrained_model = DonutModel.from_pretrained(args.pretrained_model_name_or_path)
    if torch.cuda.is_available():
        pretrained_model.half()
        pretrained_model.to("cuda")
    pretrained_model.eval()

    if args.save_path:
        # Path(...).parent is "." for a bare filename. os.makedirs("") would
        # raise FileNotFoundError in that case.
        Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)

    predictions = []
    ground_truths = []
    accs = []
    evaluator = JSONParseEvaluator()
    dataset = load_dataset(args.dataset_name_or_path, split=args.split)

    # Evaluation never needs gradients. no_grad avoids building autograd
    # graphs, and is harmless if inference() already disables them.
    with torch.no_grad():
        for sample in tqdm(dataset, total=len(dataset)):
            ground_truth = json.loads(sample["ground_truth"])
            prompt = _build_prompt(args.task_name, ground_truth)
            if args.task_name == "docvqa":
                print(f"INPUT:: {prompt}")

            # Run the model exactly once per sample. The DocVQA debug print
            # reuses this result instead of calling inference a second time.
            result = pretrained_model.inference(image=sample["image"], prompt=prompt)
            if args.task_name == "docvqa":
                print(f"OUTPUT:: {result}")
            output = result["predictions"][0]

            score, gt = _score_prediction(args.task_name, output, ground_truth, evaluator)
            accs.append(score)
            predictions.append(output)
            ground_truths.append(gt)

    scores = {
        "ted_accuracies": accs,
        "ted_accuracy": np.mean(accs),
        "f1_accuracy": evaluator.cal_f1(predictions, ground_truths),
    }
    print(
        f"Total number of samples: {len(accs)}, Tree Edit Distance (TED) based accuracy score: {scores['ted_accuracy']}, F1 accuracy score: {scores['f1_accuracy']}"
    )

    if args.save_path:
        scores["predictions"] = predictions
        scores["ground_truths"] = ground_truths
        save_json(args.save_path, scores)

    return predictions
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # The model and dataset are mandatory. Without them the script would fail
    # later with an unhelpful error (e.g. os.path.basename(None)).
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument("--dataset_name_or_path", type=str, required=True)
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--task_name", type=str, default=None)
    parser.add_argument("--save_path", type=str, default=None)
    # Unrecognised command-line arguments are tolerated and ignored.
    args, left_argv = parser.parse_known_args()
    if args.task_name is None:
        # Default the task name to the dataset's last path component.
        # normpath drops a trailing slash, so "org/docvqa/" still yields
        # "docvqa" rather than an empty task name.
        args.task_name = os.path.basename(os.path.normpath(args.dataset_name_or_path))
    predictions = test(args)