# Copied from https://github.com/mlcommons/training/blob/637c82f9e699cd6caf108f92efb2c1d446b630e0/single_stage_detector/ssd/model/transform.py
import torch
from torch import nn, Tensor
from typing import List, Tuple, Dict, Optional

from test.external.mlperf_retinanet.model.image_list import ImageList


@torch.jit.unused
def _get_shape_onnx(image: Tensor) -> Tensor:
    from torch.onnx import operators
    return operators.shape_as_tensor(image)[-2:]


def _resize_image_and_masks(image: Tensor,
                            target: Optional[Dict[str, Tensor]] = None,
                            image_size: Optional[Tuple[int, int]] = None,
                            ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
    # resize the image to the fixed image_size with bilinear interpolation
    image = torch.nn.functional.interpolate(image[None], size=image_size, scale_factor=None, mode='bilinear',
                                            recompute_scale_factor=None, align_corners=False)[0]

    if target is None:
        return image, target

    if "masks" in target:
        # masks use the default 'nearest' mode so label values stay discrete
        mask = target["masks"]
        mask = torch.nn.functional.interpolate(mask[:, None].float(), size=image_size, scale_factor=None,
                                               recompute_scale_factor=None)[:, 0].byte()
        target["masks"] = mask
    return image, target
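# A minimal shape sketch of the helper above (the (800, 800) target size here is
# an assumption for illustration, not something this file fixes):
#
#   img, tgt = _resize_image_and_masks(torch.rand(3, 480, 640), image_size=(800, 800))
#   img.shape  # -> torch.Size([3, 800, 800])
#
# Boxes are not handled here; callers rescale them separately via resize_boxes below.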
class GeneralizedRCNNTransform(nn.Module):
    """
    Performs input / target transformation before feeding the data to a GeneralizedRCNN
    model.

    The transformations it performs are:
        - input normalization (mean subtraction and std division)
        - input / target resizing to match image_size

    It returns an ImageList for the inputs, and a List[Dict[Tensor]] for the targets
    """

    def __init__(self, image_size: Optional[Tuple[int, int]], image_mean: List[float], image_std: List[float]):
        super(GeneralizedRCNNTransform, self).__init__()
        self.image_size = image_size
        self.image_mean = image_mean
        self.image_std = image_std

    def forward(self,
                images: List[Tensor],
                targets: Optional[List[Dict[str, Tensor]]] = None
                ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
        images = list(img for img in images)
        if targets is not None:
            # make a copy of targets to avoid modifying it in-place;
            # once torchscript supports dict comprehension this can be
            # simplified to: targets = [{k: v for k, v in t.items()} for t in targets]
            targets_copy: List[Dict[str, Tensor]] = []
            for t in targets:
                data: Dict[str, Tensor] = {}
                for k, v in t.items():
                    data[k] = v
                targets_copy.append(data)
            targets = targets_copy
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            if image.dim() != 3:
                raise ValueError("images is expected to be a list of 3d tensors "
                                 "of shape [C, H, W], got {}".format(image.shape))
            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        image_sizes = [img.shape[-2:] for img in images]
        # all images share the fixed image_size after resize, so they can be batched
        images = torch.stack(images)
        image_sizes_list: List[Tuple[int, int]] = []
        for image_size in image_sizes:
            assert len(image_size) == 2
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets

    def normalize(self, image: Tensor) -> Tensor:
        if not image.is_floating_point():
            raise TypeError(
                f"Expected input images to be of floating type (in range [0, 1]), "
                f"but found type {image.dtype} instead"
            )
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        return (image - mean[:, None, None]) / std[:, None, None]

    def torch_choice(self, k: List[int]) -> int:
        """
        Implements `random.choice` via torch ops so it can be compiled with TorchScript.
        Remove if https://github.com/pytorch/pytorch/issues/25803 is fixed.
        """
        index = int(torch.empty(1).uniform_(0., float(len(k))).item())
        return k[index]

    def resize(self,
               image: Tensor,
               target: Optional[Dict[str, Tensor]] = None,
               ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        h, w = image.shape[-2:]
        image, target = _resize_image_and_masks(image, target, self.image_size)

        if target is None:
            return image, target

        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        return image, target

    def postprocess(self,
                    result: List[Dict[str, Tensor]],
                    image_shapes: List[Tuple[int, int]],
                    original_image_sizes: List[Tuple[int, int]]
                    ) -> List[Dict[str, Tensor]]:
        if self.training:
            return result
        # map predicted boxes back from the resized image to the original image size
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
        return result

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + '('
        _indent = '\n    '
        format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
        format_string += "{0}Resize(height={1}, width={2}, mode='bilinear')".format(_indent, self.image_size[0],
                                                                                    self.image_size[1])
        format_string += '\n)'
        return format_string


def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    # scale boxes by the per-axis ratio between the new and original image sizes
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=boxes.device) /
        torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_height, ratio_width = ratios
    xmin, ymin, xmax, ymax = boxes.unbind(1)

    xmin = xmin * ratio_width
    xmax = xmax * ratio_width
    ymin = ymin * ratio_height
    ymax = ymax * ratio_height
    res = torch.stack((xmin, ymin, xmax, ymax), dim=1)
    return res
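

# ---------------------------------------------------------------------------
# A minimal usage sketch, not part of the upstream file. The 800x800 target
# size and the ImageNet mean/std below are assumptions for illustration (they
# match the values the MLPerf RetinaNet reference uses, but check the config
# you care about). It also assumes ImageList exposes `.tensors` and
# `.image_sizes`, as torchvision's ImageList does.
if __name__ == "__main__":
    transform = GeneralizedRCNNTransform(image_size=(800, 800),
                                         image_mean=[0.485, 0.456, 0.406],
                                         image_std=[0.229, 0.224, 0.225])
    transform.eval()  # postprocess() is a no-op while training

    images = [torch.rand(3, 480, 640), torch.rand(3, 500, 375)]
    targets = [{"boxes": torch.tensor([[10., 20., 100., 200.]]), "labels": torch.tensor([1])},
               {"boxes": torch.tensor([[0., 0., 50., 50.]]), "labels": torch.tensor([2])}]

    image_list, targets = transform(images, targets)
    print(image_list.tensors.shape)  # torch.Size([2, 3, 800, 800]): batched at the fixed size
    print(image_list.image_sizes)    # [(800, 800), (800, 800)]: per-image sizes after resize
    print(targets[0]["boxes"])       # boxes rescaled from 480x640 to 800x800

    # postprocess maps predicted boxes back to the original image sizes
    preds = [{"boxes": torch.tensor([[80., 80., 400., 400.]])} for _ in images]
    restored = transform.postprocess(preds, image_list.image_sizes, [(480, 640), (500, 375)])
    print(restored[0]["boxes"])      # rescaled from 800x800 back to 480x640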