{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e468c4d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-06-22 10:15:14.085329: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory\n",
      "2022-06-22 10:15:14.085364: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
     ]
    }
   ],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "import torch, numpy as np\n",
    "from torchvision import transforms\n",
    "from torchvision.transforms.functional import InterpolationMode\n",
    "from models.blip_vqa import blip_vqa\n",
    "from keras.preprocessing import image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b88c91d8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2aac41dcbabc4b488d73cbf4efc48ae5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4ac0b9f1c76841d6a01ecbcec1151547",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ede9398befdf4f4aa1da5e73ada4f558",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1253d25503bc486dbae2516355c1a7d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "load checkpoint from ./ckpt/model_base_vqa_capfilt_large.pth\n"
     ]
    }
   ],
   "source": [
    "image_size = 480\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model_url = './ckpt/model_base_vqa_capfilt_large.pth'\n",
    "model = blip_vqa(pretrained=model_url, image_size=image_size, vit='base')\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "97b3f806",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "def run_time(func):\n",
    "    def inner(model, image, question):\n",
    "        back = func(model, image, question)\n",
    "        print(\"Runned time: {} s\".format(round((time.time() - t)/10, 3)))\n",
    "        return back\n",
    "    t = time.time()\n",
    "    return inner"
   ]
  },
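  {
   "cell_type": "markdown",
   "id": "1a9f3c72",
   "metadata": {},
   "source": [
    "A tiny sanity check of the decorator on a throwaway function (illustrative only; `_sleep_briefly` is not part of the demo):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8e4d61",
   "metadata": {},
   "outputs": [],
   "source": [
    "@run_time\n",
    "def _sleep_briefly():\n",
    "    # Illustrative: the decorator should report roughly 0.2 s.\n",
    "    time.sleep(0.2)\n",
    "\n",
    "_sleep_briefly()"
   ]
  },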
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8ff0f445",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_demo_image(img_url, image_size, device):\n",
    "    if \"http\" in img_url:\n",
    "        raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')\n",
    "    else:\n",
    "        raw_image = Image.open(img_url).convert('RGB')\n",
    "    \n",
    "    transform = transforms.Compose([\n",
    "        transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
    "        ]) \n",
    "    image = transform(raw_image).unsqueeze(0).to(device)   \n",
    "    return image"
   ]
  },
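  {
   "cell_type": "markdown",
   "id": "3c7d5e90",
   "metadata": {},
   "source": [
    "A quick sanity check of `load_demo_image` (a sketch; it reuses the repo's `./img/demo.jpg` and the `image_size` defined above). The function accepts either a local path or an `http(s)` URL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d6c2f83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess the bundled demo image; expect a (1, 3, 480, 480) float tensor.\n",
    "demo_tensor = load_demo_image('./img/demo.jpg', image_size, device)\n",
    "print(demo_tensor.shape, demo_tensor.dtype)"
   ]
  },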
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "28558fef",
   "metadata": {},
   "outputs": [],
   "source": [
    "@run_time\n",
    "def inference(model, image, question = 'what is in the picture?'):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        answer = model(image, question, train=False, inference='generate') \n",
    "        return answer[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f1759896",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle(conf):\n",
    "    base64_str = conf['Photo']\n",
    "    question = conf['Question']\n",
    "    image = load_demo_image(base64_str, image_size, device)\n",
    "    res = inference(model, image, question)\n",
    "    print('Answer :', res)\n",
    "    # add your code\n",
    "    return {'Answer': res}"
   ]
  },
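  {
   "cell_type": "markdown",
   "id": "5e9b7a14",
   "metadata": {},
   "source": [
    "If the service receives the photo as a base64-encoded string rather than a path or URL (a common deployment pattern, assumed here rather than shown in this demo), a companion loader could look like the sketch below. `load_base64_image` is a hypothetical helper, not part of the BLIP demo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f1c8b25",
   "metadata": {},
   "outputs": [],
   "source": [
    "import base64, io\n",
    "\n",
    "def load_base64_image(b64_str, image_size, device):\n",
    "    # Hypothetical variant of load_demo_image for base64-encoded payloads.\n",
    "    raw_image = Image.open(io.BytesIO(base64.b64decode(b64_str))).convert('RGB')\n",
    "    transform = transforms.Compose([\n",
    "        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
    "    ])\n",
    "    return transform(raw_image).unsqueeze(0).to(device)"
   ]
  },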
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c8ce3a9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Runned time: 1.788 s\n",
      "Answer : woman and dog\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'Answer': 'woman and dog'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "handle({'Photo': './img/demo.jpg', 'Question': 'What is in this image?'})"
   ]
  },
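  {
   "cell_type": "markdown",
   "id": "7a2d9c36",
   "metadata": {},
   "source": [
    "Preprocessing is independent of the question, so one tensor can serve several queries. The extra questions below are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b3e0d47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse a single preprocessed tensor for multiple questions.\n",
    "image = load_demo_image('./img/demo.jpg', image_size, device)\n",
    "for q in ['what is the dog doing?', 'where was this photo taken?']:\n",
    "    print(q, '->', inference(model, image, q))"
   ]
  },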
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1668fb5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}