import torch
import numpy as np
def bbox_overlaps(boxes, query_boxes):
    """Compute the pairwise intersection-over-union of two sets of boxes.

    Columns 0/1 of each row are the minimum corner and columns 2/3 the
    maximum corner. Coordinates are treated as inclusive, so a side length
    is ``max - min + 1``.

    Parameters
    ----------
    boxes: (N, 4) ndarray or tensor or variable
    query_boxes: (K, 4) ndarray or tensor or variable
        Each argument may independently be a NumPy array or a torch tensor.

    Returns
    -------
    overlaps: (N, K) overlap between boxes and query_boxes.
        A NumPy array when ``boxes`` is an ndarray, otherwise a tensor.
    """
    # The output type follows ``boxes`` (same rule as before).
    return_numpy = isinstance(boxes, np.ndarray)
    # Convert each argument independently so mixed ndarray/tensor inputs work.
    # ascontiguousarray avoids from_numpy's error on negative-stride views and
    # only copies when the array is not already C-contiguous.
    if isinstance(boxes, np.ndarray):
        boxes = torch.from_numpy(np.ascontiguousarray(boxes))
    if isinstance(query_boxes, np.ndarray):
        query_boxes = torch.from_numpy(np.ascontiguousarray(query_boxes))

    box_areas = ((boxes[:, 2] - boxes[:, 0] + 1)
                 * (boxes[:, 3] - boxes[:, 1] + 1))
    query_areas = ((query_boxes[:, 2] - query_boxes[:, 0] + 1)
                   * (query_boxes[:, 3] - query_boxes[:, 1] + 1))

    # (N, 1) columns broadcast against (1, K) rows to give (N, K) extents;
    # non-overlapping pairs produce negative extents, clamped to zero.
    iw = (torch.min(boxes[:, 2:3], query_boxes[:, 2:3].t())
          - torch.max(boxes[:, 0:1], query_boxes[:, 0:1].t()) + 1).clamp(min=0)
    ih = (torch.min(boxes[:, 3:4], query_boxes[:, 3:4].t())
          - torch.max(boxes[:, 1:2], query_boxes[:, 1:2].t()) + 1).clamp(min=0)

    inter = iw * ih
    union = box_areas.view(-1, 1) + query_areas.view(1, -1) - inter
    overlaps = inter / union
    return overlaps.numpy() if return_numpy else overlaps