fe35cc9
chenzhiqing 3 years ago
3 changed files with 331 additions and 27 deletions.
11 "cells": [
22 {
33 "cell_type": "code",
4 "execution_count": null,
5 "id": "bff972b0",
6 "metadata": {},
7 "outputs": [],
4 "execution_count": 1,
5 "id": "e468c4d7",
6 "metadata": {},
7 "outputs": [
8 {
9 "name": "stderr",
10 "output_type": "stream",
11 "text": [
12 "2022-06-22 10:15:14.085329: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory\n",
13 "2022-06-22 10:15:14.085364: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
14 ]
15 }
16 ],
817 "source": [
918 "from PIL import Image\n",
1019 "import requests\n",
1726 },
1827 {
1928 "cell_type": "code",
20 "execution_count": null,
21 "id": "eb3cb644",
22 "metadata": {},
23 "outputs": [],
29 "execution_count": 2,
30 "id": "b88c91d8",
31 "metadata": {},
32 "outputs": [
33 {
34 "data": {
35 "application/vnd.jupyter.widget-view+json": {
36 "model_id": "2aac41dcbabc4b488d73cbf4efc48ae5",
37 "version_major": 2,
38 "version_minor": 0
39 },
40 "text/plain": [
41 "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…"
42 ]
43 },
44 "metadata": {},
45 "output_type": "display_data"
46 },
47 {
48 "name": "stdout",
49 "output_type": "stream",
50 "text": [
51 "\n"
52 ]
53 },
54 {
55 "data": {
56 "application/vnd.jupyter.widget-view+json": {
57 "model_id": "4ac0b9f1c76841d6a01ecbcec1151547",
58 "version_major": 2,
59 "version_minor": 0
60 },
61 "text/plain": [
62 "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…"
63 ]
64 },
65 "metadata": {},
66 "output_type": "display_data"
67 },
68 {
69 "name": "stdout",
70 "output_type": "stream",
71 "text": [
72 "\n"
73 ]
74 },
75 {
76 "data": {
77 "application/vnd.jupyter.widget-view+json": {
78 "model_id": "ede9398befdf4f4aa1da5e73ada4f558",
79 "version_major": 2,
80 "version_minor": 0
81 },
82 "text/plain": [
83 "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…"
84 ]
85 },
86 "metadata": {},
87 "output_type": "display_data"
88 },
89 {
90 "name": "stdout",
91 "output_type": "stream",
92 "text": [
93 "\n"
94 ]
95 },
96 {
97 "data": {
98 "application/vnd.jupyter.widget-view+json": {
99 "model_id": "1253d25503bc486dbae2516355c1a7d7",
100 "version_major": 2,
101 "version_minor": 0
102 },
103 "text/plain": [
104 "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…"
105 ]
106 },
107 "metadata": {},
108 "output_type": "display_data"
109 },
110 {
111 "name": "stdout",
112 "output_type": "stream",
113 "text": [
114 "\n",
115 "load checkpoint from ./ckpt/model_base_vqa_capfilt_large.pth\n"
116 ]
117 }
118 ],
24119 "source": [
25120 "image_size = 480\n",
26121 "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
31126 },
32127 {
33128 "cell_type": "code",
34 "execution_count": null,
35 "id": "fdca9681",
129 "execution_count": 3,
130 "id": "97b3f806",
36131 "metadata": {},
37132 "outputs": [],
38133 "source": [
48143 },
49144 {
50145 "cell_type": "code",
51 "execution_count": null,
52 "id": "079b131e",
146 "execution_count": 4,
147 "id": "8ff0f445",
53148 "metadata": {},
54149 "outputs": [],
55150 "source": [
70165 },
71166 {
72167 "cell_type": "code",
73 "execution_count": null,
74 "id": "2744fa59",
168 "execution_count": 5,
169 "id": "28558fef",
75170 "metadata": {},
76171 "outputs": [],
77172 "source": [
85180 },
86181 {
87182 "cell_type": "code",
88 "execution_count": null,
89 "id": "259b4290",
183 "execution_count": 6,
184 "id": "f1759896",
90185 "metadata": {},
91186 "outputs": [],
92187 "source": [
102197 },
103198 {
104199 "cell_type": "code",
200 "execution_count": 8,
201 "id": "c8ce3a9a",
202 "metadata": {},
203 "outputs": [
204 {
205 "name": "stdout",
206 "output_type": "stream",
207 "text": [
208 "Runned time: 1.788 s\n",
209 "Answer : woman and dog\n"
210 ]
211 },
212 {
213 "data": {
214 "text/plain": [
215 "{'Answer': 'woman and dog'}"
216 ]
217 },
218 "execution_count": 8,
219 "metadata": {},
220 "output_type": "execute_result"
221 }
222 ],
223 "source": [
224 "handle({'Photo': './img/demo.jpg', 'Question': 'What is in this image?'})"
225 ]
226 },
227 {
228 "cell_type": "code",
105229 "execution_count": null,
106 "id": "4dfc8004",
107 "metadata": {},
108 "outputs": [],
109 "source": [
110 "# handle({'Photo': './img/demo.jpg', 'Question': 'What is in this image?'})"
111 ]
112 },
113 {
114 "cell_type": "code",
115 "execution_count": null,
116 "id": "34e85fb7",
230 "id": "c1668fb5",
117231 "metadata": {},
118232 "outputs": [],
119233 "source": []
from models.med import BertConfig, BertModel, BertLMHeadModel
from models.blip import create_vit, init_tokenizer, load_checkpoint

import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertTokenizer
import numpy as np

class BLIP_VQA(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 480,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
        self.tokenizer = init_tokenizer()

        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        decoder_config = BertConfig.from_json_file(med_config)
        self.text_decoder = BertLMHeadModel(config=decoder_config)


    def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
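        """
        Args:
            image (tensor): batch of preprocessed images, shape [batch, 3, image_size, image_size]
            question (str or list[str]): question text paired with each image
            answer: ground-truth answer strings (train=True) or pre-tokenized candidate answers (inference='rank')
            n (list[int]): number of ground-truth answers per question (train only)
            weights (tensor): per-answer weight applied to the language-modeling loss (train only)
            train (bool): True returns the weighted LM loss; False runs inference
            inference (str): 'generate' decodes an answer with beam search, 'rank' scores candidate answers
            k_test (int): number of candidates kept per question when inference='rank'
        """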

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
                                  return_tensors="pt").to(image.device)
        question.input_ids[:,0] = self.tokenizer.enc_token_id

        if train:
            '''
            n: number of answers for each question
            weights: weight for each answer
            '''
            answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
            answer.input_ids[:,0] = self.tokenizer.bos_token_id
            answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)

            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            question_states = []
            question_atts = []
            for b, num in enumerate(n):  # num: number of answers for question b
                question_states += [question_output.last_hidden_state[b]] * num
                question_atts += [question.attention_mask[b]] * num
            question_states = torch.stack(question_states, 0)
            question_atts = torch.stack(question_atts, 0)

            answer_output = self.text_decoder(answer.input_ids,
                                              attention_mask = answer.attention_mask,
                                              encoder_hidden_states = question_states,
                                              encoder_attention_mask = question_atts,
                                              labels = answer_targets,
                                              return_dict = True,
                                              reduction = 'none',
                                              )

            loss = weights * answer_output.loss
            loss = loss.sum() / image.size(0)

            return loss


        else:
            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            if inference == 'generate':
                num_beams = 3
                question_states = question_output.last_hidden_state.repeat_interleave(num_beams, dim=0)
                question_atts = torch.ones(question_states.size()[:-1], dtype=torch.long).to(question_states.device)
                model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask": question_atts}

                bos_ids = torch.full((image.size(0), 1), fill_value=self.tokenizer.bos_token_id, device=image.device)

                outputs = self.text_decoder.generate(input_ids=bos_ids,
                                                     max_length=10,
                                                     min_length=1,
                                                     num_beams=num_beams,
                                                     eos_token_id=self.tokenizer.sep_token_id,
                                                     pad_token_id=self.tokenizer.pad_token_id,
                                                     **model_kwargs)

                answers = []
                for output in outputs:
                    answer = self.tokenizer.decode(output, skip_special_tokens=True)
                    answers.append(answer)
                return answers

            elif inference == 'rank':
                max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
                                           answer.input_ids, answer.attention_mask, k_test)
                return max_ids


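    # rank_answer scores a fixed set of candidate answers rather than decoding freely:
    # it ranks all candidates by the probability of their first token under the decoder,
    # keeps the top-k per question, re-scores those k full answers with the decoder's
    # LM loss, and returns the index of the best-scoring candidate for each question.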
    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):

        num_ques = question_states.size(0)
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token

        start_output = self.text_decoder(start_ids,
                                         encoder_hidden_states = question_states,
                                         encoder_attention_mask = question_atts,
                                         return_dict = True,
                                         reduction = 'none')
        logits = start_output.logits[:, 0, :]  # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids, dim=0)
        input_atts = torch.cat(input_atts, dim=0)

        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, k)
        question_atts = tile(question_atts, 0, k)

        output = self.text_decoder(input_ids,
                                   attention_mask = input_atts,
                                   encoder_hidden_states = question_states,
                                   encoder_attention_mask = question_atts,
                                   labels = targets_ids,
                                   return_dict = True,
                                   reduction = 'none')

        log_probs_sum = -output.loss
        log_probs_sum = log_probs_sum.view(num_ques, k)

        max_topk_ids = log_probs_sum.argmax(dim=1)
        max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids]

        return max_ids


def blip_vqa(pretrained='', **kwargs):
    model = BLIP_VQA(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        # assert(len(msg.missing_keys)==0)
    return model


def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(x, dim, order_index.to(x.device))
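A quick illustration of tile, not part of the diff: it repeats entries along a dimension in interleaved order, which is what keeps each question's tiled encoder states aligned with the k candidate answers concatenated per question in rank_answer. Assuming the file above is importable as models.blip_vqa, its behavior matches torch.repeat_interleave:

# Illustration only: tile() repeats each row of x in place (interleaved), not the whole tensor.
import torch
from models.blip_vqa import tile  # assumes the module above lives at models/blip_vqa.py

x = torch.tensor([[1, 2],
                  [3, 4]])

tiled = tile(x, 0, 3)
# tensor([[1, 2],
#         [1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4],
#         [3, 4]])

assert torch.equal(tiled, x.repeat_interleave(3, dim=0))
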
timm==0.4.12
transformers==4.15.0
fairscale==0.4.4
pycocoevalcap