# Copied from https://github.com/mlcommons/training/blob/637c82f9e699cd6caf108f92efb2c1d446b630e0/single_stage_detector/ssd/model/transform.py

import torch

from torch import nn, Tensor
from typing import List, Tuple, Dict, Optional

from test.external.mlperf_retinanet.model.image_list import ImageList


@torch.jit.unused
def _get_shape_onnx(image: Tensor) -> Tensor:
    from torch.onnx import operators
    return operators.shape_as_tensor(image)[-2:]


def _resize_image_and_masks(image: Tensor,
                            target: Optional[Dict[str, Tensor]] = None,
                            image_size: Optional[Tuple[int, int]] = None,
                            ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
    image = torch.nn.functional.interpolate(image[None], size=image_size, scale_factor=None, mode='bilinear',
                                            recompute_scale_factor=None, align_corners=False)[0]

    if target is None:
        return image, target

    if "masks" in target:
        mask = target["masks"]
        mask = torch.nn.functional.interpolate(mask[:, None].float(), size=image_size, scale_factor=None,
                                               recompute_scale_factor=None)[:, 0].byte()
        target["masks"] = mask
    return image, target
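
# Illustrative note (not part of the upstream MLPerf file): with image_size=(800, 800),
# _resize_image_and_masks turns a (3, 480, 640) image into a (3, 800, 800) one via
# bilinear interpolation; a (N, 480, 640) "masks" tensor in the target is resized the
# same way (with interpolate's default nearest-neighbour mode) and cast back to uint8.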


class GeneralizedRCNNTransform(nn.Module):
    """
    Performs input / target transformation before feeding the data to a GeneralizedRCNN
    model.

    The transformations it performs are:
        - input normalization (mean subtraction and std division)
        - input / target resizing to match image_size

    It returns an ImageList for the inputs, and a List[Dict[Tensor]] for the targets
    """

    def __init__(self, image_size: Optional[Tuple[int, int]],
                 image_mean: List[float], image_std: List[float],):
        super(GeneralizedRCNNTransform, self).__init__()
        self.image_size = image_size
        self.image_mean = image_mean
        self.image_std = image_std

    def forward(self,
                images: List[Tensor],
                targets: Optional[List[Dict[str, Tensor]]] = None
                ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
        images = list(img for img in images)
        if targets is not None:
            # make a copy of targets to avoid modifying it in-place
            # once torchscript supports dict comprehension
            # this can be simplified as follows
            # targets = [{k: v for k,v in t.items()} for t in targets]
            targets_copy: List[Dict[str, Tensor]] = []
            for t in targets:
                data: Dict[str, Tensor] = {}
                for k, v in t.items():
                    data[k] = v
                targets_copy.append(data)
            targets = targets_copy
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            if image.dim() != 3:
                raise ValueError("images is expected to be a list of 3d tensors "
                                 "of shape [C, H, W], got {}".format(image.shape))
            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        image_sizes = [img.shape[-2:] for img in images]
        images = torch.stack(images)
        image_sizes_list: List[Tuple[int, int]] = []
        for image_size in image_sizes:
            assert len(image_size) == 2
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets

    def normalize(self, image: Tensor) -> Tensor:
        if not image.is_floating_point():
            raise TypeError(
                f"Expected input images to be of floating type (in range [0, 1]), "
                f"but found type {image.dtype} instead"
            )
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        return (image - mean[:, None, None]) / std[:, None, None]

    def torch_choice(self, k: List[int]) -> int:
        """
        Implements `random.choice` via torch ops so it can be compiled with
        TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
        is fixed.
        """
        index = int(torch.empty(1).uniform_(0., float(len(k))).item())
        return k[index]

    def resize(self,
               image: Tensor,
               target: Optional[Dict[str, Tensor]] = None,
               ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        h, w = image.shape[-2:]
        image, target = _resize_image_and_masks(image, target, self.image_size)

        if target is None:
            return image, target

        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        return image, target

    def postprocess(self,
                    result: List[Dict[str, Tensor]],
                    image_shapes: List[Tuple[int, int]],
                    original_image_sizes: List[Tuple[int, int]]
                    ) -> List[Dict[str, Tensor]]:
        if self.training:
            return result
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
        return result

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + '('
        _indent = '\n    '
        format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
        format_string += "{0}Resize(height={1}, width={2}, mode='bilinear')".format(_indent, self.image_size[0],
                                                                                    self.image_size[1])
        format_string += '\n)'
        return format_string


def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=boxes.device) /
        torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_height, ratio_width = ratios
    xmin, ymin, xmax, ymax = boxes.unbind(1)

    xmin = xmin * ratio_width
    xmax = xmax * ratio_width
    ymin = ymin * ratio_height
    ymax = ymax * ratio_height
    res = torch.stack((xmin, ymin, xmax, ymax), dim=1)
    return res
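

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the upstream MLPerf file. The target
# size and the normalization mean/std below are placeholder values, not values
# read from this repo's RetinaNet configuration.
if __name__ == "__main__":
    transform = GeneralizedRCNNTransform(image_size=(800, 800),
                                         image_mean=[0.485, 0.456, 0.406],
                                         image_std=[0.229, 0.224, 0.225])
    transform.eval()  # postprocess() only rescales boxes outside training mode

    # two dummy float images of different sizes; forward() normalizes and
    # resizes each one, then batches them into a single ImageList
    images = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
    targets = [{"boxes": torch.tensor([[10., 20., 200., 300.]]), "labels": torch.tensor([1])},
               {"boxes": torch.tensor([[50., 60., 150., 160.]]), "labels": torch.tensor([2])}]

    image_list, targets = transform(images, targets)
    # ImageList exposes .tensors and .image_sizes in the torchvision
    # implementation this file mirrors
    print(image_list.tensors.shape)   # torch.Size([2, 3, 800, 800])
    print(targets[0]["boxes"])        # boxes rescaled into the 800x800 frame

    # map the (here: ground-truth) boxes back to the original image coordinates,
    # the way detections would be post-processed at inference time
    detections = [{"boxes": t["boxes"]} for t in targets]
    detections = transform.postprocess(detections, image_list.image_sizes,
                                       [(480, 640), (512, 512)])
    print(detections[0]["boxes"])     # approximately the original boxes again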