diff --git a/demo.ipynb b/demo.ipynb
index bb172bc..72924c8 100644
--- a/demo.ipynb
+++ b/demo.ipynb
@@ -2,10 +2,19 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "bff972b0",
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 1,
+   "id": "e468c4d7",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2022-06-22 10:15:14.085329: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory\n",
+      "2022-06-22 10:15:14.085364: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
+     ]
+    }
+   ],
    "source": [
     "from PIL import Image\n",
     "import requests\n",
@@ -18,10 +27,96 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "eb3cb644",
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 2,
+   "id": "b88c91d8",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2aac41dcbabc4b488d73cbf4efc48ae5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4ac0b9f1c76841d6a01ecbcec1151547",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ede9398befdf4f4aa1da5e73ada4f558",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1253d25503bc486dbae2516355c1a7d7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "load checkpoint from ./ckpt/model_base_vqa_capfilt_large.pth\n"
+     ]
+    }
+   ],
    "source": [
     "image_size = 480\n",
     "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
@@ -32,8 +127,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "fdca9681",
+   "execution_count": 3,
+   "id": "97b3f806",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -49,8 +144,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "079b131e",
+   "execution_count": 4,
+   "id": "8ff0f445",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -71,8 +166,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "2744fa59",
+   "execution_count": 5,
+   "id": "28558fef",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -86,8 +181,8 @@
   },
   {
"cell_type": "code", - "execution_count": null, - "id": "259b4290", + "execution_count": 6, + "id": "f1759896", "metadata": {}, "outputs": [], "source": [ @@ -103,18 +198,37 @@ }, { "cell_type": "code", + "execution_count": 8, + "id": "c8ce3a9a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Runned time: 1.788 s\n", + "Answer : woman and dog\n" + ] + }, + { + "data": { + "text/plain": [ + "{'Answer': 'woman and dog'}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "handle({'Photo': './img/demo.jpg', 'Question': 'What is in this image?'})" + ] + }, + { + "cell_type": "code", "execution_count": null, - "id": "4dfc8004", - "metadata": {}, - "outputs": [], - "source": [ - "# handle({'Photo': './img/demo.jpg', 'Question': 'What is in this image?'})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "34e85fb7", + "id": "c1668fb5", "metadata": {}, "outputs": [], "source": [] diff --git a/models/.ipynb_checkpoints/blip_vqa-checkpoint.py b/models/.ipynb_checkpoints/blip_vqa-checkpoint.py new file mode 100644 index 0000000..d4cb368 --- /dev/null +++ b/models/.ipynb_checkpoints/blip_vqa-checkpoint.py @@ -0,0 +1,186 @@ +from models.med import BertConfig, BertModel, BertLMHeadModel +from models.blip import create_vit, init_tokenizer, load_checkpoint + +import torch +from torch import nn +import torch.nn.functional as F +from transformers import BertTokenizer +import numpy as np + +class BLIP_VQA(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 480, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) + self.tokenizer = init_tokenizer() + + encoder_config = BertConfig.from_json_file(med_config) + encoder_config.encoder_width = vision_width + self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) + + decoder_config = BertConfig.from_json_file(med_config) + self.text_decoder = BertLMHeadModel(config=decoder_config) + + + def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128): + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + + question = self.tokenizer(question, padding='longest', truncation=True, max_length=35, + return_tensors="pt").to(image.device) + question.input_ids[:,0] = self.tokenizer.enc_token_id + + if train: + ''' + n: number of answers for each question + weights: weight for each answer + ''' + answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device) + answer.input_ids[:,0] = self.tokenizer.bos_token_id + answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100) + + question_output = self.text_encoder(question.input_ids, + attention_mask = question.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True) + + question_states = [] + question_atts = [] + for b, n in enumerate(n): + question_states += [question_output.last_hidden_state[b]]*n + question_atts += [question.attention_mask[b]]*n 
+            question_states = torch.stack(question_states,0)
+            question_atts = torch.stack(question_atts,0)
+
+            answer_output = self.text_decoder(answer.input_ids,
+                                              attention_mask = answer.attention_mask,
+                                              encoder_hidden_states = question_states,
+                                              encoder_attention_mask = question_atts,
+                                              labels = answer_targets,
+                                              return_dict = True,
+                                              reduction = 'none',
+                                             )
+
+            loss = weights * answer_output.loss
+            loss = loss.sum()/image.size(0)
+
+            return loss
+
+
+        else:
+            question_output = self.text_encoder(question.input_ids,
+                                                attention_mask = question.attention_mask,
+                                                encoder_hidden_states = image_embeds,
+                                                encoder_attention_mask = image_atts,
+                                                return_dict = True)
+
+            if inference=='generate':
+                num_beams = 3
+                question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
+                question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
+                model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
+
+                bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
+
+                outputs = self.text_decoder.generate(input_ids=bos_ids,
+                                                     max_length=10,
+                                                     min_length=1,
+                                                     num_beams=num_beams,
+                                                     eos_token_id=self.tokenizer.sep_token_id,
+                                                     pad_token_id=self.tokenizer.pad_token_id,
+                                                     **model_kwargs)
+
+                answers = []
+                for output in outputs:
+                    answer = self.tokenizer.decode(output, skip_special_tokens=True)
+                    answers.append(answer)
+                return answers
+
+            elif inference=='rank':
+                max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
+                                           answer.input_ids, answer.attention_mask, k_test)
+                return max_ids
+
+
+
+    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
+
+        num_ques = question_states.size(0)
+        start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
+
+        start_output = self.text_decoder(start_ids,
+                                         encoder_hidden_states = question_states,
+                                         encoder_attention_mask = question_atts,
+                                         return_dict = True,
+                                         reduction = 'none')
+        logits = start_output.logits[:,0,:] # first token's logit
+
+        # topk_probs: top-k probability
+        # topk_ids: [num_question, k]
+        answer_first_token = answer_ids[:,1]
+        prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
+        topk_probs, topk_ids = prob_first_token.topk(k,dim=1)
+
+        # answer input: [num_question*k, answer_len]
+        input_ids = []
+        input_atts = []
+        for b, topk_id in enumerate(topk_ids):
+            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
+            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
+        input_ids = torch.cat(input_ids,dim=0)
+        input_atts = torch.cat(input_atts,dim=0)
+
+        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
+
+        # repeat encoder's output for top-k answers
+        question_states = tile(question_states, 0, k)
+        question_atts = tile(question_atts, 0, k)
+
+        output = self.text_decoder(input_ids,
+                                   attention_mask = input_atts,
+                                   encoder_hidden_states = question_states,
+                                   encoder_attention_mask = question_atts,
+                                   labels = targets_ids,
+                                   return_dict = True,
+                                   reduction = 'none')
+
+        log_probs_sum = -output.loss
+        log_probs_sum = log_probs_sum.view(num_ques,k)
+
+        max_topk_ids = log_probs_sum.argmax(dim=1)
+        max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
+
+        return max_ids
+
+
+def blip_vqa(pretrained='',**kwargs):
+    model = BLIP_VQA(**kwargs)
+    if pretrained:
+        model,msg = load_checkpoint(model,pretrained)
+#        assert(len(msg.missing_keys)==0)
+    return model
+
+
+def tile(x, dim, n_tile):
+    init_dim = x.size(dim)
+    repeat_idx = [1] * x.dim()
+    repeat_idx[dim] = n_tile
+    x = x.repeat(*(repeat_idx))
+    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
+    return torch.index_select(x, dim, order_index.to(x.device))
+
+
\ No newline at end of file
diff --git a/requirement.txt b/requirement.txt
new file mode 100644
index 0000000..ecbaafd
--- /dev/null
+++ b/requirement.txt
@@ -0,0 +1,4 @@
+timm==0.4.12
+transformers==4.15.0
+fairscale==0.4.4
+pycocoevalcap
\ No newline at end of file
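
Usage note: the notebook above drives the new blip_vqa module through its handle() wrapper, whose definition falls outside the hunks shown here. Below is a minimal standalone sketch of the same inference path. The checkpoint path and image_size are taken from the notebook cells in this diff; the torchvision resize/normalization constants follow the upstream BLIP demo and should be treated as assumptions rather than part of this patch.

    from PIL import Image
    import torch
    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode
    from models.blip_vqa import blip_vqa

    image_size = 480  # matches the notebook cell above
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # blip_vqa() builds BLIP_VQA and calls load_checkpoint() when `pretrained` is set;
    # the path matches the "load checkpoint from ..." line in the notebook output.
    model = blip_vqa(pretrained='./ckpt/model_base_vqa_capfilt_large.pth',
                     image_size=image_size, vit='base')
    model.eval()
    model = model.to(device)

    # Preprocessing: constants assumed from the upstream BLIP demo, not from this diff.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])
    image = transform(Image.open('./img/demo.jpg').convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        # train=False with inference='generate' takes the beam-search decoding branch of
        # BLIP_VQA.forward(); inference='rank' would instead score candidate answers.
        answers = model(image, 'What is in this image?', train=False, inference='generate')
    print(answers[0])  # the notebook run above produced 'woman and dog'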