{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3f8becd7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-06-22 16:01:19.315518: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory\n",
      "2022-06-22 16:01:19.315554: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
     ]
    }
   ],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "import torch, numpy as np\n",
    "from torchvision import transforms\n",
    "from torchvision.transforms.functional import InterpolationMode\n",
    "from models.blip_vqa import blip_vqa\n",
    "from keras.preprocessing import image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1e193909",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d64bc197cae04268b5496da5d322da16",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fcab18055e1b47038d1102f8e250b25b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b00264e7496f4f0cbae704a96d5f52d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "94212645b9954e439ce10238528af598",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "load checkpoint from ./ckpt/model_base_vqa_capfilt_large.pth\n"
     ]
    }
   ],
   "source": [
    "image_size = 480\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model_url = './ckpt/model_base_vqa_capfilt_large.pth'\n",
    "model = blip_vqa(pretrained=model_url, image_size=image_size, vit='base')\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7c8bc1e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "def run_time(func):\n",
    "    def inner(model, image, question):\n",
    "        back = func(model, image, question)\n",
    "        print(\"Runned time: {} s\".format(round((time.time() - t)/10, 3)))\n",
    "        return back\n",
    "    t = time.time()\n",
    "    return inner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "38495b44",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_demo_image(img_url, image_size, device):\n",
    "    if \"http\" in img_url:\n",
    "        raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')\n",
    "    else:\n",
    "        raw_image = Image.open(img_url).convert('RGB')\n",
    "    \n",
    "    transform = transforms.Compose([\n",
    "        transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
    "        ]) \n",
    "    image = transform(raw_image).unsqueeze(0).to(device)   \n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7ce54831",
   "metadata": {},
   "outputs": [],
   "source": [
    "@run_time\n",
    "def inference(model, image, question = 'what is in the picture?'):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        answer = model(image, question, train=False, inference='generate') \n",
    "        return answer[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "497772e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle(conf):\n",
    "    base64_str = conf['Photo']\n",
    "    question = conf['Question']\n",
    "    image = load_demo_image(base64_str, image_size, device)\n",
    "    res = inference(model, image, question)\n",
    "    print('Answer :', res)\n",
    "    # add your code\n",
    "    return {'Answer': res}"
   ]
  },
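  {
   "cell_type": "markdown",
   "id": "9a7d01f2",
   "metadata": {},
   "source": [
    "In this demo, `handle` receives a local file path or URL in the `Photo` field. In a deployment the image might instead arrive base64-encoded; the next cell is an optional, untested sketch of a loader for that case. `load_base64_image` is not part of the original demo and simply reuses the same preprocessing as `load_demo_image`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c2f8a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import base64, io\n",
    "\n",
    "def load_base64_image(b64_str, image_size, device):\n",
    "    # Decode the base64 payload and open it as an RGB image\n",
    "    raw_image = Image.open(io.BytesIO(base64.b64decode(b64_str))).convert('RGB')\n",
    "    # Same preprocessing as load_demo_image: resize and normalize with the CLIP statistics\n",
    "    transform = transforms.Compose([\n",
    "        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
    "    ])\n",
    "    return transform(raw_image).unsqueeze(0).to(device)"
   ]
  },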
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "158995c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Runned time: 1.17 s\n",
      "Answer : woman and dog\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'Answer': 'woman and dog'}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "handle({'Photo': './img/demo.jpg', 'Question': 'What is in this image?'})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}