{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "531218e2",
   "metadata": {},
   "outputs": [
    {
     "ename": "ImportError",
     "evalue": "cannot import name 'InterpolationMode' from 'torchvision.transforms.functional' (/home/jovyan/.virtualenvs/basenv/lib/python3.7/site-packages/torchvision/transforms/functional.py)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_122/3179672444.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorchvision\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtransforms\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtorchvision\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunctional\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mInterpolationMode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mblip_vqa\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mblip_vqa\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mkeras\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpreprocessing\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mimage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mImportError\u001b[0m: cannot import name 'InterpolationMode' from 'torchvision.transforms.functional' (/home/jovyan/.virtualenvs/basenv/lib/python3.7/site-packages/torchvision/transforms/functional.py)"
     ]
    }
   ],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "import torch, numpy as np\n",
    "from torchvision import transforms\n",
    "from torchvision.transforms.functional import InterpolationMode\n",
    "from models.blip_vqa import blip_vqa\n",
    "from keras.preprocessing import image"
   ]
  },
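  {
   "cell_type": "markdown",
   "id": "interp-note",
   "metadata": {},
   "source": [
    "Note: `InterpolationMode` only exists in torchvision >= 0.9. On the older torchvision in this Python 3.7 environment the import raises an `ImportError`, so the cell above falls back to the equivalent `PIL.Image.BICUBIC` constant."
   ]
  },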
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "aa46a919",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a6ca4f752d1743888aa067b138494be7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "946df54241da4fc2b6acce1c79a1c3e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4dbd3f83d7c34ae491d68b4eedc48784",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a3382095fbc847098b596057ecad9527",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "load checkpoint from ./ckpt/model_base_vqa_capfilt_large.pth\n"
     ]
    }
   ],
   "source": [
    "image_size = 480\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model_url = './ckpt/model_base_vqa_capfilt_large.pth'\n",
    "model = blip_vqa(pretrained=model_url, image_size=image_size, vit='base')\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "db83a911",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "def run_time(func):\n",
    "    def inner(model, image, question):\n",
    "        back = func(model, image, question)\n",
    "        print(\"Runned time: {} s\".format(round((time.time() - t)/10, 3)))\n",
    "        return back\n",
    "    t = time.time()\n",
    "    return inner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4edf5af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_demo_image(img_url, image_size, device):\n",
    "    if \"http\" in img_url:\n",
    "        raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')\n",
    "    else:\n",
    "        raw_image = Image.open(img_url).convert('RGB')\n",
    "    \n",
    "    transform = transforms.Compose([\n",
    "        transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
    "        ]) \n",
    "    image = transform(raw_image).unsqueeze(0).to(device)   \n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c67f575b",
   "metadata": {},
   "outputs": [],
   "source": [
    "@run_time\n",
    "def inference(model, image, question = 'what is in the picture?'):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        answer = model(image, question, train=False, inference='generate') \n",
    "        return answer[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9fe0b78f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle(conf):\n",
    "    base64_str = conf['Photo']\n",
    "    question = conf['Question']\n",
    "    image = load_demo_image(base64_str, image_size, device)\n",
    "    res = inference(model, image, question)\n",
    "    print('Answer :', res)\n",
    "    # add your code\n",
    "    return {'Answer': res}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bef4f0f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Runned time: 0.378 s\n",
      "Answer : squirrel eating from basket\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'Answer': 'squirrel eating from basket'}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "handle({'Photo': 'https://img2.baidu.com/it/u=3908142881,2459234098&fm=253&fmt=auto&app=138&f=JPEG?w=750&h=500', 'Question': 'What is in this image?'})"
   ]
  },
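  {
   "cell_type": "markdown",
   "id": "b64-note",
   "metadata": {},
   "source": [
    "`handle` takes `conf['Photo']` as a URL or file path. If base64-encoded photos should also be accepted, a minimal sketch of a loader is below; `load_image_from_base64` and the data-URI prefix handling are illustrative assumptions, not part of the original demo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b64-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import base64, io\n",
    "\n",
    "def load_image_from_base64(b64_str, image_size, device):\n",
    "    # Hypothetical helper: decode a base64 image and apply the same\n",
    "    # preprocessing as load_demo_image.\n",
    "    if ',' in b64_str:\n",
    "        b64_str = b64_str.split(',', 1)[1]  # drop a 'data:image/...;base64,' prefix\n",
    "    raw_image = Image.open(io.BytesIO(base64.b64decode(b64_str))).convert('RGB')\n",
    "    transform = transforms.Compose([\n",
    "        transforms.Resize((image_size, image_size), interpolation=BICUBIC),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
    "    ])\n",
    "    return transform(raw_image).unsqueeze(0).to(device)"
   ]
  }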
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}