# inference.py

from loadData import DataLayer
from model import VGG16, Network
import torch
from lib.bbox_transform import bbox_transform_inv
import os
import numpy as np
from torchvision.ops import nms
from config import cfg
import cv2

import matplotlib.pyplot as plt

# Output directories: rendered detection images and trained model weights.
image_output = "./image_output"
# exist_ok avoids the race between a separate existence check and the mkdir.
os.makedirs(image_output, exist_ok=True)
weight_output = "./weight_output"

def _clip_boxes(boxes, im_shape):
    """Clip boxes to image boundaries."""
    # x1 >= 0
    boxes[:, 0::4] = np.maximum(boxes[:, 0::4], 0)
    # y1 >= 0
    boxes[:, 1::4] = np.maximum(boxes[:, 1::4], 0)
    # x2 < im_shape[1]
    boxes[:, 2::4] = np.minimum(boxes[:, 2::4], im_shape[1] - 1)
    # y2 < im_shape[0]
    boxes[:, 3::4] = np.minimum(boxes[:, 3::4], im_shape[0] - 1)
    return boxes

def inference():
    """Run detection over every validation image, draw the surviving boxes,
    display each result with matplotlib, and save it under ``image_output``.

    Loads the last checkpoint found in ``weight_output`` (falling back to
    ImageNet-pretrained VGG16 weights when the directory is empty) and runs
    the network on CPU.
    """
    data_layer_test = DataLayer("./dataset/val")

    net = VGG16()
    # Construct the computation graph: 2 classes (background + 1 foreground).
    net.create_architecture(2, tag='default', anchor_scales=[8, 16, 32], anchor_ratios=[0.5, 1, 2])
    net.eval()
    # CPU is forced unconditionally (the CUDA check below was disabled).
    # if not torch.cuda.is_available():
    net._device = 'cpu'
    net.to(net._device)

    # Load model parameters.
    list_dir = os.listdir(weight_output)
    if len(list_dir) == 0:
        # No checkpoint yet: start from ImageNet-pretrained VGG16 weights.
        net.load_pretrained_cnn(torch.load("./imagenet_weights/vgg16.pth"))
    else:
        # NOTE(review): picks the lexicographically last file in the directory —
        # assumes checkpoint names sort chronologically; confirm naming scheme.
        # net.load_state_dict(torch.load(os.path.join(weight_output, list_dir[-1])))
        net.load_state_dict(torch.load(os.path.join(weight_output, list_dir[-1]), map_location='cpu'))

    print("load model params successfully. ")

    fig = plt.gcf()  # get current figure
    fig.set_size_inches(10, 12)  # figure size in inches

    for i in range(data_layer_test.length):
        blobs = data_layer_test.forward()
        im_blob = blobs['data']
        im_scale = blobs['im_info'][2]
        _, scores, bbox_pred, rois = net.test_image(blobs['data'], blobs['im_info'])
        # Map RoIs back from network-input scale to original-image scale.
        boxes = rois[:, 1:5] / im_scale
        img = cv2.resize(
            im_blob[0],
            None,
            None,
            fx=1/im_scale,
            fy=1/im_scale,
            interpolation=cv2.INTER_LINEAR)

        # Undo the mean subtraction applied during preprocessing (in place).
        img += cfg.PIXEL_MEANS

        scores = np.reshape(scores, [scores.shape[0], -1])
        bbox_pred = np.reshape(bbox_pred, [bbox_pred.shape[0], -1])

        # Apply bounding-box regression deltas, then clip to image bounds.
        box_deltas = bbox_pred
        pred_boxes = bbox_transform_inv(
                torch.from_numpy(boxes), torch.from_numpy(box_deltas)).numpy()
        pred_boxes = _clip_boxes(pred_boxes, img.shape)

        # skip j = 0, because it's the background class
        for j in range(1, 2):
            # Keep only detections scoring above the 0.5 confidence threshold.
            inds = np.where(scores[:, j] > 0.5)[0]
            cls_scores = scores[inds, j]
            cls_boxes = pred_boxes[inds, j * 4:(j + 1) * 4]
            cls_dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])) \
                .astype(np.float32, copy=False)
            # Non-maximum suppression: drop overlapping lower-scored boxes.
            keep = nms(
                torch.from_numpy(cls_boxes), torch.from_numpy(cls_scores),
                cfg.TEST.NMS).numpy() if cls_dets.size > 0 else []
            cls_dets = cls_dets[keep, :]

            height, width, channel = img.shape
            for r in range(len(cls_dets)):
                # Draw the predicted box, re-clamped to integer pixel bounds.
                left = int(max(cls_dets[r][0], 0))
                top = int(max(cls_dets[r][1], 0))
                right = int(min(cls_dets[r][2], width))
                bottom = int(min(cls_dets[r][3], height))

                cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 0), 1)
                # Filled label background sized to the rendered score text.
                text_size, baseline = cv2.getTextSize(str(cls_dets[r][4]), 1, 1, 1)
                cv2.rectangle(img, (left, top - text_size[1] - (baseline * 2)), (left + text_size[0], top),
                              (44, 44, 44), -1)
                cv2.putText(img, str(cls_dets[r][4]), (left, top - baseline), 1,
                            1, (255, 255, 255), 1)

        # Show the annotated image; plt.show() blocks until the window closes.
        ax = plt.subplot(1, 1, 1)  # current (only) subplot
        show_image = img.astype(np.int32, copy=False)
        ax.imshow(show_image, cmap="binary")
        ax.set_xticks([])
        ax.set_yticks([])
        plt.show()

        # Save the annotated image as a zero-padded-index JPEG.
        # NOTE(review): assumes img is RGB at this point — confirm the data
        # layer's channel order before trusting this conversion.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(image_output, '{0:0>6}'.format(i)+ ".jpg"), img)

if __name__ == '__main__':
    # NOTE(review): GPU 0 is pinned here, but inference() forces the network
    # onto CPU, so this env var currently has no effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    inference()