|
0 |
from models.med import BertConfig, BertModel, BertLMHeadModel
|
|
1 |
from models.blip import create_vit, init_tokenizer, load_checkpoint
|
|
2 |
|
|
3 |
import torch
|
|
4 |
from torch import nn
|
|
5 |
import torch.nn.functional as F
|
|
6 |
from transformers import BertTokenizer
|
|
7 |
import numpy as np
|
|
8 |
|
|
9 |
class BLIP_VQA(nn.Module):
|
|
10 |
def __init__(self,
|
|
11 |
med_config = 'configs/med_config.json',
|
|
12 |
image_size = 480,
|
|
13 |
vit = 'base',
|
|
14 |
vit_grad_ckpt = False,
|
|
15 |
vit_ckpt_layer = 0,
|
|
16 |
):
|
|
17 |
"""
|
|
18 |
Args:
|
|
19 |
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
|
20 |
image_size (int): input image size
|
|
21 |
vit (str): model size of vision transformer
|
|
22 |
"""
|
|
23 |
super().__init__()
|
|
24 |
|
|
25 |
self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
|
|
26 |
self.tokenizer = init_tokenizer()
|
|
27 |
|
|
28 |
encoder_config = BertConfig.from_json_file(med_config)
|
|
29 |
encoder_config.encoder_width = vision_width
|
|
30 |
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
|
|
31 |
|
|
32 |
decoder_config = BertConfig.from_json_file(med_config)
|
|
33 |
self.text_decoder = BertLMHeadModel(config=decoder_config)
|
|
34 |
|
|
35 |
|
|
36 |
def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
|
|
37 |
|
|
38 |
image_embeds = self.visual_encoder(image)
|
|
39 |
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
|
40 |
|
|
41 |
question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
|
|
42 |
return_tensors="pt").to(image.device)
|
|
43 |
question.input_ids[:,0] = self.tokenizer.enc_token_id
|
|
44 |
|
|
45 |
if train:
|
|
46 |
'''
|
|
47 |
n: number of answers for each question
|
|
48 |
weights: weight for each answer
|
|
49 |
'''
|
|
50 |
answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
|
|
51 |
answer.input_ids[:,0] = self.tokenizer.bos_token_id
|
|
52 |
answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)
|
|
53 |
|
|
54 |
question_output = self.text_encoder(question.input_ids,
|
|
55 |
attention_mask = question.attention_mask,
|
|
56 |
encoder_hidden_states = image_embeds,
|
|
57 |
encoder_attention_mask = image_atts,
|
|
58 |
return_dict = True)
|
|
59 |
|
|
60 |
question_states = []
|
|
61 |
question_atts = []
|
|
62 |
for b, n in enumerate(n):
|
|
63 |
question_states += [question_output.last_hidden_state[b]]*n
|
|
64 |
question_atts += [question.attention_mask[b]]*n
|
|
65 |
question_states = torch.stack(question_states,0)
|
|
66 |
question_atts = torch.stack(question_atts,0)
|
|
67 |
|
|
68 |
answer_output = self.text_decoder(answer.input_ids,
|
|
69 |
attention_mask = answer.attention_mask,
|
|
70 |
encoder_hidden_states = question_states,
|
|
71 |
encoder_attention_mask = question_atts,
|
|
72 |
labels = answer_targets,
|
|
73 |
return_dict = True,
|
|
74 |
reduction = 'none',
|
|
75 |
)
|
|
76 |
|
|
77 |
loss = weights * answer_output.loss
|
|
78 |
loss = loss.sum()/image.size(0)
|
|
79 |
|
|
80 |
return loss
|
|
81 |
|
|
82 |
|
|
83 |
else:
|
|
84 |
question_output = self.text_encoder(question.input_ids,
|
|
85 |
attention_mask = question.attention_mask,
|
|
86 |
encoder_hidden_states = image_embeds,
|
|
87 |
encoder_attention_mask = image_atts,
|
|
88 |
return_dict = True)
|
|
89 |
|
|
90 |
if inference=='generate':
|
|
91 |
num_beams = 3
|
|
92 |
question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
|
|
93 |
question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
|
|
94 |
model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
|
|
95 |
|
|
96 |
bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
|
|
97 |
|
|
98 |
outputs = self.text_decoder.generate(input_ids=bos_ids,
|
|
99 |
max_length=10,
|
|
100 |
min_length=1,
|
|
101 |
num_beams=num_beams,
|
|
102 |
eos_token_id=self.tokenizer.sep_token_id,
|
|
103 |
pad_token_id=self.tokenizer.pad_token_id,
|
|
104 |
**model_kwargs)
|
|
105 |
|
|
106 |
answers = []
|
|
107 |
for output in outputs:
|
|
108 |
answer = self.tokenizer.decode(output, skip_special_tokens=True)
|
|
109 |
answers.append(answer)
|
|
110 |
return answers
|
|
111 |
|
|
112 |
elif inference=='rank':
|
|
113 |
max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
|
|
114 |
answer.input_ids, answer.attention_mask, k_test)
|
|
115 |
return max_ids
|
|
116 |
|
|
117 |
|
|
118 |
|
|
119 |
def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
|
|
120 |
|
|
121 |
num_ques = question_states.size(0)
|
|
122 |
start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
|
|
123 |
|
|
124 |
start_output = self.text_decoder(start_ids,
|
|
125 |
encoder_hidden_states = question_states,
|
|
126 |
encoder_attention_mask = question_atts,
|
|
127 |
return_dict = True,
|
|
128 |
reduction = 'none')
|
|
129 |
logits = start_output.logits[:,0,:] # first token's logit
|
|
130 |
|
|
131 |
# topk_probs: top-k probability
|
|
132 |
# topk_ids: [num_question, k]
|
|
133 |
answer_first_token = answer_ids[:,1]
|
|
134 |
prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
|
|
135 |
topk_probs, topk_ids = prob_first_token.topk(k,dim=1)
|
|
136 |
|
|
137 |
# answer input: [num_question*k, answer_len]
|
|
138 |
input_ids = []
|
|
139 |
input_atts = []
|
|
140 |
for b, topk_id in enumerate(topk_ids):
|
|
141 |
input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
|
|
142 |
input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
|
|
143 |
input_ids = torch.cat(input_ids,dim=0)
|
|
144 |
input_atts = torch.cat(input_atts,dim=0)
|
|
145 |
|
|
146 |
targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
|
|
147 |
|
|
148 |
# repeat encoder's output for top-k answers
|
|
149 |
question_states = tile(question_states, 0, k)
|
|
150 |
question_atts = tile(question_atts, 0, k)
|
|
151 |
|
|
152 |
output = self.text_decoder(input_ids,
|
|
153 |
attention_mask = input_atts,
|
|
154 |
encoder_hidden_states = question_states,
|
|
155 |
encoder_attention_mask = question_atts,
|
|
156 |
labels = targets_ids,
|
|
157 |
return_dict = True,
|
|
158 |
reduction = 'none')
|
|
159 |
|
|
160 |
log_probs_sum = -output.loss
|
|
161 |
log_probs_sum = log_probs_sum.view(num_ques,k)
|
|
162 |
|
|
163 |
max_topk_ids = log_probs_sum.argmax(dim=1)
|
|
164 |
max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
|
|
165 |
|
|
166 |
return max_ids
|
|
167 |
|
|
168 |
|
|
169 |
def blip_vqa(pretrained='',**kwargs):
|
|
170 |
model = BLIP_VQA(**kwargs)
|
|
171 |
if pretrained:
|
|
172 |
model,msg = load_checkpoint(model,pretrained)
|
|
173 |
# assert(len(msg.missing_keys)==0)
|
|
174 |
return model
|
|
175 |
|
|
176 |
|
|
177 |
def tile(x, dim, n_tile):
|
|
178 |
init_dim = x.size(dim)
|
|
179 |
repeat_idx = [1] * x.dim()
|
|
180 |
repeat_idx[dim] = n_tile
|
|
181 |
x = x.repeat(*(repeat_idx))
|
|
182 |
order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
|
|
183 |
return torch.index_select(x, dim, order_index.to(x.device))
|
|
184 |
|
|
185 |
⏎
|