{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e5ad7f34",
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import tensorflow.compat.v1 as tf\n",
    "tf.disable_v2_behavior()  # the notebook uses the TF1 graph API (placeholders, Session)\n",
    "tf.reset_default_graph()\n",
    "\n",
    "import src.model\n",
    "import src.util"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4e847a1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference configuration; relative paths resolve against the kernel's working directory.\n",
    "IMAGE_SZ = 128  # height and width of the square generator input\n",
    "model_PATH = './src/output/models/model2000.ckpt'  # checkpoint restored by inference()\n",
    "out_PATH = './results/test_output.png'  # default destination for the generated image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a3ab7344",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_time(func):\n",
    "    \"\"\"Decorator that prints the wall-clock duration of every call to `func`.\n",
    "\n",
    "    Timing starts when the wrapped function is called (not when it is decorated)\n",
    "    and is reported in seconds, rounded to 3 decimals. Positional and keyword\n",
    "    arguments and the return value are passed through unchanged.\n",
    "    \"\"\"\n",
    "    @functools.wraps(func)  # keep the wrapped function's name and docstring\n",
    "    def inner(*args, **kwargs):\n",
    "        start = time.perf_counter()  # monotonic clock, suited to measuring intervals\n",
    "        back = func(*args, **kwargs)\n",
    "        print(\"Runned time: {} s\".format(round(time.perf_counter() - start, 3)))\n",
    "        return back\n",
    "    return inner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "491be63a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_demo_image(in_PATH):\n",
    "    \"\"\"Read an image file and turn it into generator input.\n",
    "\n",
    "    Args:\n",
    "        in_PATH: path (or file object) accepted by `PIL.Image.open`.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `src.util.preprocess_images_outpainting` returns for a batch of\n",
    "        one RGB image of shape (1, H, W, 3) with values scaled to [0, 1].\n",
    "    \"\"\"\n",
    "    # The context manager closes the file handle; convert() has loaded the pixels by then.\n",
    "    with Image.open(in_PATH) as pil_img:\n",
    "        rgb = np.array(pil_img.convert('RGB'))\n",
    "    img = rgb[np.newaxis] / 255.0\n",
    "    img_p = src.util.preprocess_images_outpainting(img)\n",
    "    return img_p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ac6e19b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def image_to_path(img, path=None):\n",
    "    \"\"\"Save `img` to disk and return the path it was written to.\n",
    "\n",
    "    Args:\n",
    "        img: object with a PIL-style `save(path)` method (a `PIL.Image` here).\n",
    "        path: destination file; defaults to the module-level `out_PATH`.\n",
    "\n",
    "    Returns:\n",
    "        The destination path exactly as given (it is not made absolute).\n",
    "    \"\"\"\n",
    "    if path is None:\n",
    "        path = out_PATH\n",
    "    # Create the parent directory (e.g. ./results) so saving works on a fresh checkout.\n",
    "    parent = os.path.dirname(path)\n",
    "    if parent:\n",
    "        os.makedirs(parent, exist_ok=True)\n",
    "    img.save(path)\n",
    "    return path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "54877357",
   "metadata": {},
   "outputs": [],
   "source": [
    "@run_time\n",
    "def inference(model_PATH, img_p):\n",
    "    \"\"\"Restore the generator from a checkpoint and run it on `img_p`.\n",
    "\n",
    "    The graph is built inside a fresh `tf.Graph` on every call, so repeated calls\n",
    "    (one per `handle` request) do not keep adding placeholders, generator\n",
    "    variables and savers to the shared default graph.\n",
    "\n",
    "    Args:\n",
    "        model_PATH: checkpoint path passed to `tf.train.Saver.restore`.\n",
    "        img_p: batch fed to the float32 `G_Z` placeholder of shape\n",
    "            (N, IMAGE_SZ, IMAGE_SZ, 4); only the first output of the batch is used.\n",
    "\n",
    "    Returns:\n",
    "        `PIL.Image` in RGB mode built from the first generator output.\n",
    "    \"\"\"\n",
    "    graph = tf.Graph()\n",
    "    with graph.as_default():\n",
    "        G_Z = tf.placeholder(tf.float32, shape=[None, IMAGE_SZ, IMAGE_SZ, 4], name='G_Z')\n",
    "        G_sample = src.model.generator(G_Z)\n",
    "        saver = tf.train.Saver()\n",
    "        with tf.Session(graph=graph) as sess:\n",
    "            saver.restore(sess, model_PATH)\n",
    "            output = sess.run(G_sample, feed_dict={G_Z: img_p})\n",
    "    # Assumes the generator emits values roughly in [0, 1]; clip so out-of-range\n",
    "    # values saturate instead of wrapping around in the uint8 cast.\n",
    "    img_norm = np.clip(output[0] * 255.0, 0, 255).astype(np.uint8)\n",
    "    return Image.fromarray(img_norm, 'RGB')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "15a5515c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle(conf):\n",
    "    \"\"\"Request handler called when other users invoke the deployed service.\n",
    "\n",
    "    Platform template note (translated): describe the parameter structure in the\n",
    "    format below so the platform can generate the configuration file for callers.\n",
    "    Example:\n",
    "    params['key'] = value # value_type: str # description: some description\n",
    "    value_type can be one of: img, video, audio, str, int, float, [int], [str], [float]\n",
    "    Put the parameters in the `params` dict; the platform parses that variable automatically.\n",
    "\n",
    "    Args:\n",
    "        conf: request dict. `conf['Photo']` is handed to `load_demo_image`\n",
    "            (the demo call passes an image file path).\n",
    "\n",
    "    Returns:\n",
    "        {'Output': path}, where `path` is the location the generated image was saved to.\n",
    "    \"\"\"\n",
    "    photo = conf['Photo']\n",
    "    model_input = load_demo_image(photo)\n",
    "    generated = inference(model_PATH, model_input)\n",
    "    return {'Output': image_to_path(generated)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "93cf5ae7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from /home/jovyan/work/src/output/models/model2000.ckpt\n",
      "Runned time: 0.317 s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'Output': '/home/jovyan/work/results/test_output.png'}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Demo request; the image path is relative to the notebook's working directory.\n",
    "handle({'Photo': './images/test.png'})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
