encode_captions.py
import argparse
import shutil
import sys
from pathlib import Path

import h5py
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer, LightningModule, seed_everything
import clip

sys.path.append('.')
from dataset import VisualGenomeCaptions


class CaptionDB(LightningModule):
    """Encodes captions with CLIP's text transformer and appends key/value pairs to an HDF5 database."""

    def __init__(self, save_dir, clip_model):
        super().__init__()
        self.save_dir = save_dir
        self.model, _ = clip.load(clip_model, device="cpu")

    def test_step(self, batch, batch_idx):
        captions, tokens = batch

        # Run CLIP's text transformer (mirrors clip.model.CLIP.encode_text).
        x = self.model.token_embedding(tokens).type(self.model.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding.type(self.model.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        keys = self.model.ln_final(x).type(self.model.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # Take features at the EOT embedding (the EOT token is the highest token id in each sequence):
        # values are the transformer features before the final LayerNorm and projection,
        # keys are the projected, L2-normalized CLIP text embeddings used for retrieval.
        values = x[torch.arange(len(x)), tokens.argmax(dim=-1)]
        keys = keys[torch.arange(len(x)), tokens.argmax(dim=-1)] @ self.model.text_projection
        keys /= keys.norm(dim=-1, keepdim=True)

        values = values.detach().cpu().numpy()
        keys = keys.detach().cpu().numpy()

        # Append one group per batch to the caption database.
        with h5py.File(self.save_dir / "caption_db.hdf5", "a") as f:
            g = f.create_group(str(batch_idx))
            g.create_dataset("keys", data=keys, compression="gzip")
            g.create_dataset("values", data=values, compression="gzip")
            g.create_dataset("captions", data=captions, compression="gzip")


def encode_captions(args):
    # Encode every Visual Genome caption in a single deterministic test pass.
    dset = VisualGenomeCaptions(args.ann_dir, clip.tokenize)
    dloader = DataLoader(
        dataset=dset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers
    )

    cap_db = CaptionDB(args.save_dir, args.model)
    trainer = Trainer(
        gpus=[args.device, ],
        deterministic=True,
        benchmark=False,
        default_root_dir=args.save_dir
    )
    trainer.test(cap_db, dloader)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Encode captions')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--exp_name', type=str, default='captions_db')
    parser.add_argument('--ann_dir', type=str, default='datasets/visual_genome')
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--num_workers', type=int, default=6)
    parser.add_argument(
        "--model", type=str, default="ViT-L/14",
        choices=[
            "RN50", "RN101", "RN50x4", "RN50x16", "RN50x64",
            "ViT-B/32", "ViT-B/16", "ViT-L/14"
        ]
    )
    args = parser.parse_args()

    # Recreate the output directory from scratch for this experiment.
    setattr(args, "save_dir", Path("outputs") / args.exp_name)
    shutil.rmtree(args.save_dir, ignore_errors=True)
    args.save_dir.mkdir(parents=True, exist_ok=True)
    print(args)

    seed_everything(1, workers=True)
    encode_captions(args)
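
The script writes `outputs/<exp_name>/caption_db.hdf5` with one group per batch, each containing `keys`, `values`, and `captions` datasets. Below is a minimal sketch (not part of the original script) of reading the database back; the function name `load_caption_db` and the concatenation into single arrays are assumptions about how downstream code might consume it.

# Sketch: load caption_db.hdf5 back into memory, assuming the layout written above.
import h5py
import numpy as np

def load_caption_db(path="outputs/captions_db/caption_db.hdf5"):
    keys, values, captions = [], [], []
    with h5py.File(path, "r") as f:
        for batch_idx in f:  # one group per test batch
            g = f[batch_idx]
            keys.append(g["keys"][:])
            values.append(g["values"][:])
            # h5py returns variable-length strings as bytes; decode to str
            captions.extend(
                c.decode("utf-8") if isinstance(c, bytes) else c
                for c in g["captions"][:]
            )
    return np.concatenate(keys), np.concatenate(values), captions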