-
Notifications
You must be signed in to change notification settings - Fork 53
/
predict.py
85 lines (65 loc) · 2.62 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
from transformers import BertTokenizer
from PIL import Image
import argparse
from models import caption
from datasets import coco, utils
from configuration import Config
import os
# Command-line interface: the image path is mandatory; the model version and
# an optional local checkpoint select which weights get loaded further down.
parser = argparse.ArgumentParser(description='Image Captioning')
parser.add_argument('--path', type=str, help='path to image', required=True)
parser.add_argument('--v', type=str, help='version', default='v3')
parser.add_argument('--checkpoint', type=str, help='checkpoint path', default=None)
args = parser.parse_args()

# Unpack once so the rest of the script reads plain names, not `args.*`.
image_path, version, checkpoint_path = args.path, args.v, args.checkpoint
config = Config()

# The three published hub versions differ only in the tag string, so a single
# membership test replaces the original if/elif chain. Any other version
# string falls through to loading a local checkpoint file instead.
if version in ('v1', 'v2', 'v3'):
    model = torch.hub.load('saahiluppal/catr', version, pretrained=True)
else:
    print("Checking for checkpoint.")
    if checkpoint_path is None:
        raise NotImplementedError('No model to chose from!')
    if not os.path.exists(checkpoint_path):
        raise NotImplementedError('Give valid checkpoint path')
    print("Found checkpoint! Loading!")
    model, _ = caption.build_model(config)
    print("Loading Checkpoint...")
    # map_location='cpu' keeps loading device-independent.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
# Tokenizer is used both to look up the special-token ids below and to decode
# the generated ids back to text at the end of the script.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Use the public *_token_id properties instead of the private `_cls_token` /
# `_sep_token` attributes; for bert-base-uncased these are [CLS]=101, [SEP]=102.
start_token = tokenizer.cls_token_id
end_token = tokenizer.sep_token_id

# Force RGB so grayscale/palette/RGBA images don't break the 3-channel
# normalization inside coco.val_transform; a no-op for images already in RGB.
image = Image.open(image_path).convert('RGB')
image = coco.val_transform(image)
image = image.unsqueeze(0)  # add batch dimension: (1, C, H, W)
def create_caption_and_mask(start_token, max_length):
    """Return a (1, max_length) caption buffer and matching boolean mask.

    The caption tensor is zero-filled except for `start_token` in slot 0;
    the mask is True everywhere except slot 0 (the only filled position).
    """
    tokens = torch.full((1, max_length), 0, dtype=torch.long)
    tokens[:, 0] = start_token

    pad_mask = torch.full((1, max_length), True, dtype=torch.bool)
    pad_mask[:, 0] = False

    return tokens, pad_mask
# Pre-allocate the full-length caption buffer and padding mask that the
# decoding loop fills in token by token.
caption, cap_mask = create_caption_and_mask(
    start_token, config.max_position_embeddings)
@torch.no_grad()
def evaluate():
    """Greedily decode a caption for the module-level `image`.

    Fills the pre-allocated `caption` / `cap_mask` buffers one token per
    step, always taking the argmax token, and returns the caption tensor —
    either when the end token is produced or when the buffer is full.
    """
    model.eval()
    for i in range(config.max_position_embeddings - 1):
        predictions = model(image, caption, cap_mask)
        predictions = predictions[:, i, :]
        predicted_id = torch.argmax(predictions, axis=-1)

        # Stop on the tokenizer-derived end-of-sequence id rather than the
        # hard-coded 102 ([SEP] for bert-base-uncased) the original used;
        # `end_token` is computed from the tokenizer earlier in this script.
        if predicted_id[0] == end_token:
            return caption

        caption[:, i+1] = predicted_id[0]
        cap_mask[:, i+1] = False

    return caption
output = evaluate()

# Decode ids back to text; skip_special_tokens drops [CLS]/[SEP]/padding.
token_ids = output[0].tolist()
result = tokenizer.decode(token_ids, skip_special_tokens=True)

print(result.capitalize())