encoder.py
# Encoder using CLIP.
import torch
import torch.nn as nn
from transformers import (
    CLIPModel,
    AutoTokenizer,
    AutoProcessor,
    CLIPTextModel,
    CLIPVisionModelWithProjection,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "openai/clip-vit-base-patch32"


class Encoder(nn.Module):
    def __init__(self, model_id=model_id):
        super(Encoder, self).__init__()
        # Frozen CLIP backbone; no parameters are updated during training.
        self.model = CLIPModel.from_pretrained(model_id)
        for p in self.model.parameters():
            p.requires_grad = False
        self.preprocess = AutoProcessor.from_pretrained(model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Frozen text encoder.
        self.text_encoder = CLIPTextModel.from_pretrained(model_id)
        for p in self.text_encoder.parameters():
            p.requires_grad = False
        # Frozen vision encoder; output_hidden_states=True exposes the
        # intermediate transformer layers for multi-level image features.
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            model_id, output_hidden_states=True
        )
        for p in self.image_encoder.parameters():
            p.requires_grad = False

    def textForward(self, prompt):
        # Tokenize the prompt and run it through the CLIP text encoder.
        tokenized = self.tokenizer([prompt], padding=True, return_tensors='pt')
        prompt_embedding = self.text_encoder(**tokenized)
        return prompt_embedding

    def visualForward(self, image):
        # Preprocess the image (resize, normalize) and encode it.
        preprocessed = self.preprocess(images=image, return_tensors='pt')
        image_embedding = self.image_encoder(**preprocessed)
        return image_embedding

    def forward(self, image, prompt, layers=[8, 9, 10, 11]):
        text_op = self.textForward(prompt)
        image_op_temp = self.visualForward(image)
        # text_op[1] is the pooled (EOS-token) text embedding;
        # image_op_temp[0] is the projected image embedding.
        text_encoding = text_op[1]
        image_encoding = image_op_temp[0]
        # Collect the requested intermediate vision-transformer hidden states.
        mid_layers = []
        for i in range(len(layers)):
            mid_layers.append(image_op_temp['hidden_states'][layers[i]])
        return text_encoding, image_encoding, mid_layers
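

# Minimal usage sketch (not part of the original file): assumes a PIL image
# loaded from a hypothetical path "example.jpg" and an arbitrary prompt.
# With openai/clip-vit-base-patch32, the pooled text and projected image
# embeddings are expected to be roughly (1, 512), and each intermediate
# vision hidden state roughly (1, 50, 768).
if __name__ == "__main__":
    from PIL import Image

    encoder = Encoder()
    image = Image.open("example.jpg")  # hypothetical example image
    with torch.no_grad():
        text_enc, image_enc, mids = encoder(image, "a photo of a cat")
    print(text_enc.shape)             # pooled text embedding
    print(image_enc.shape)            # projected image embedding
    print(len(mids), mids[0].shape)   # intermediate vision layers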