roi_align.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. import functools
  2. from typing import Union
  3. import torch
  4. import torch.fx
  5. from torch import nn, Tensor
  6. from torch._dynamo.utils import is_compile_supported
  7. from torch.jit.annotations import BroadcastingList2
  8. from torch.nn.modules.utils import _pair
  9. from torchvision.extension import _assert_has_ops, _has_ops
  10. from ..utils import _log_api_usage_once
  11. from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
  12. def lazy_compile(**compile_kwargs):
  13. """Lazily wrap a function with torch.compile on the first call
  14. This avoids eagerly importing dynamo.
  15. """
  16. def decorate_fn(fn):
  17. @functools.wraps(fn)
  18. def compile_hook(*args, **kwargs):
  19. compiled_fn = torch.compile(fn, **compile_kwargs)
  20. globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
  21. return compiled_fn(*args, **kwargs)
  22. return compile_hook
  23. return decorate_fn
  24. # NB: all inputs are tensors
  25. def _bilinear_interpolate(
  26. input, # [N, C, H, W]
  27. roi_batch_ind, # [K]
  28. y, # [K, PH, IY]
  29. x, # [K, PW, IX]
  30. ymask, # [K, IY]
  31. xmask, # [K, IX]
  32. ):
  33. _, channels, height, width = input.size()
  34. # deal with inverse element out of feature map boundary
  35. y = y.clamp(min=0)
  36. x = x.clamp(min=0)
  37. y_low = y.int()
  38. x_low = x.int()
  39. y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
  40. y_low = torch.where(y_low >= height - 1, height - 1, y_low)
  41. y = torch.where(y_low >= height - 1, y.to(input.dtype), y)
  42. x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
  43. x_low = torch.where(x_low >= width - 1, width - 1, x_low)
  44. x = torch.where(x_low >= width - 1, x.to(input.dtype), x)
  45. ly = y - y_low
  46. lx = x - x_low
  47. hy = 1.0 - ly
  48. hx = 1.0 - lx
  49. # do bilinear interpolation, but respect the masking!
  50. # TODO: It's possible the masking here is unnecessary if y and
  51. # x were clamped appropriately; hard to tell
  52. def masked_index(
  53. y, # [K, PH, IY]
  54. x, # [K, PW, IX]
  55. ):
  56. if ymask is not None:
  57. assert xmask is not None
  58. y = torch.where(ymask[:, None, :], y, 0)
  59. x = torch.where(xmask[:, None, :], x, 0)
  60. return input[
  61. roi_batch_ind[:, None, None, None, None, None],
  62. torch.arange(channels, device=input.device)[None, :, None, None, None, None],
  63. y[:, None, :, None, :, None], # prev [K, PH, IY]
  64. x[:, None, None, :, None, :], # prev [K, PW, IX]
  65. ] # [K, C, PH, PW, IY, IX]
  66. v1 = masked_index(y_low, x_low)
  67. v2 = masked_index(y_low, x_high)
  68. v3 = masked_index(y_high, x_low)
  69. v4 = masked_index(y_high, x_high)
  70. # all ws preemptively [K, C, PH, PW, IY, IX]
  71. def outer_prod(y, x):
  72. return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]
  73. w1 = outer_prod(hy, hx)
  74. w2 = outer_prod(hy, lx)
  75. w3 = outer_prod(ly, hx)
  76. w4 = outer_prod(ly, lx)
  77. val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
  78. return val
  79. # TODO: this doesn't actually cache
  80. # TODO: main library should make this easier to do
  81. def maybe_cast(tensor):
  82. if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double:
  83. return tensor.float()
  84. else:
  85. return tensor
  86. # This is a pure Python and differentiable implementation of roi_align. When
  87. # run in eager mode, it uses a lot of memory, but when compiled it has
  88. # acceptable memory usage. The main point of this implementation is that
  89. # its backwards is deterministic.
  90. # It is transcribed directly off of the roi_align CUDA kernel, see
  91. # https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
  92. @lazy_compile(dynamic=True)
  93. def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
  94. orig_dtype = input.dtype
  95. input = maybe_cast(input)
  96. rois = maybe_cast(rois)
  97. _, _, height, width = input.size()
  98. ph = torch.arange(pooled_height, device=input.device) # [PH]
  99. pw = torch.arange(pooled_width, device=input.device) # [PW]
  100. # input: [N, C, H, W]
  101. # rois: [K, 5]
  102. roi_batch_ind = rois[:, 0].int() # [K]
  103. offset = 0.5 if aligned else 0.0
  104. roi_start_w = rois[:, 1] * spatial_scale - offset # [K]
  105. roi_start_h = rois[:, 2] * spatial_scale - offset # [K]
  106. roi_end_w = rois[:, 3] * spatial_scale - offset # [K]
  107. roi_end_h = rois[:, 4] * spatial_scale - offset # [K]
  108. roi_width = roi_end_w - roi_start_w # [K]
  109. roi_height = roi_end_h - roi_start_h # [K]
  110. if not aligned:
  111. roi_width = torch.clamp(roi_width, min=1.0) # [K]
  112. roi_height = torch.clamp(roi_height, min=1.0) # [K]
  113. bin_size_h = roi_height / pooled_height # [K]
  114. bin_size_w = roi_width / pooled_width # [K]
  115. exact_sampling = sampling_ratio > 0
  116. roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height) # scalar or [K]
  117. roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width) # scalar or [K]
  118. """
  119. iy, ix = dims(2)
  120. """
  121. if exact_sampling:
  122. count = max(roi_bin_grid_h * roi_bin_grid_w, 1) # scalar
  123. iy = torch.arange(roi_bin_grid_h, device=input.device) # [IY]
  124. ix = torch.arange(roi_bin_grid_w, device=input.device) # [IX]
  125. ymask = None
  126. xmask = None
  127. else:
  128. count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1) # [K]
  129. # When doing adaptive sampling, the number of samples we need to do
  130. # is data-dependent based on how big the ROIs are. This is a bit
  131. # awkward because first-class dims can't actually handle this.
  132. # So instead, we inefficiently suppose that we needed to sample ALL
  133. # the points and mask out things that turned out to be unnecessary
  134. iy = torch.arange(height, device=input.device) # [IY]
  135. ix = torch.arange(width, device=input.device) # [IX]
  136. ymask = iy[None, :] < roi_bin_grid_h[:, None] # [K, IY]
  137. xmask = ix[None, :] < roi_bin_grid_w[:, None] # [K, IX]
  138. def from_K(t):
  139. return t[:, None, None]
  140. y = (
  141. from_K(roi_start_h)
  142. + ph[None, :, None] * from_K(bin_size_h)
  143. + (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h)
  144. ) # [K, PH, IY]
  145. x = (
  146. from_K(roi_start_w)
  147. + pw[None, :, None] * from_K(bin_size_w)
  148. + (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w)
  149. ) # [K, PW, IX]
  150. val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) # [K, C, PH, PW, IY, IX]
  151. # Mask out samples that weren't actually adaptively needed
  152. if not exact_sampling:
  153. val = torch.where(ymask[:, None, None, None, :, None], val, 0)
  154. val = torch.where(xmask[:, None, None, None, None, :], val, 0)
  155. output = val.sum((-1, -2)) # remove IY, IX ~> [K, C, PH, PW]
  156. if isinstance(count, torch.Tensor):
  157. output /= count[:, None, None, None]
  158. else:
  159. output /= count
  160. output = output.to(orig_dtype)
  161. return output
  162. @torch.fx.wrap
  163. def roi_align(
  164. input: Tensor,
  165. boxes: Union[Tensor, list[Tensor]],
  166. output_size: BroadcastingList2[int],
  167. spatial_scale: float = 1.0,
  168. sampling_ratio: int = -1,
  169. aligned: bool = False,
  170. ) -> Tensor:
  171. """
  172. Performs Region of Interest (RoI) Align operator with average pooling, as described in Mask R-CNN.
  173. Args:
  174. input (Tensor[N, C, H, W]): The input tensor, i.e. a batch with ``N`` elements. Each element
  175. contains ``C`` feature maps of dimensions ``H x W``.
  176. If the tensor is quantized, we expect a batch size of ``N == 1``.
  177. boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
  178. format where the regions will be taken from.
  179. The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
  180. If a single Tensor is passed, then the first column should
  181. contain the index of the corresponding element in the batch, i.e. a number in ``[0, N - 1]``.
  182. If a list of Tensors is passed, then each Tensor will correspond to the boxes for an element i
  183. in the batch.
  184. output_size (int or Tuple[int, int]): the size of the output (in bins or pixels) after the pooling
  185. is performed, as (height, width).
  186. spatial_scale (float): a scaling factor that maps the box coordinates to
  187. the input coordinates. For example, if your boxes are defined on the scale
  188. of a 224x224 image and your input is a 112x112 feature map (resulting from a 0.5x scaling of
  189. the original image), you'll want to set this to 0.5. Default: 1.0
  190. sampling_ratio (int): number of sampling points in the interpolation grid
  191. used to compute the output value of each pooled output bin. If > 0,
  192. then exactly ``sampling_ratio x sampling_ratio`` sampling points per bin are used. If
  193. <= 0, then an adaptive number of grid points are used (computed as
  194. ``ceil(roi_width / output_width)``, and likewise for height). Default: -1
  195. aligned (bool): If False, use the legacy implementation.
  196. If True, pixel shift the box coordinates it by -0.5 for a better alignment with the two
  197. neighboring pixel indices. This version is used in Detectron2
  198. Returns:
  199. Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs.
  200. """
  201. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  202. _log_api_usage_once(roi_align)
  203. check_roi_boxes_shape(boxes)
  204. rois = boxes
  205. output_size = _pair(output_size)
  206. if not isinstance(rois, torch.Tensor):
  207. rois = convert_boxes_to_roi_format(rois)
  208. if not torch.jit.is_scripting():
  209. if (
  210. not _has_ops()
  211. or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps or input.is_xpu))
  212. ) and is_compile_supported(input.device.type):
  213. return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
  214. _assert_has_ops()
  215. return torch.ops.torchvision.roi_align(
  216. input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
  217. )
  218. class RoIAlign(nn.Module):
  219. """
  220. See :func:`roi_align`.
  221. """
  222. def __init__(
  223. self,
  224. output_size: BroadcastingList2[int],
  225. spatial_scale: float,
  226. sampling_ratio: int,
  227. aligned: bool = False,
  228. ):
  229. super().__init__()
  230. _log_api_usage_once(self)
  231. self.output_size = output_size
  232. self.spatial_scale = spatial_scale
  233. self.sampling_ratio = sampling_ratio
  234. self.aligned = aligned
  235. def forward(self, input: Tensor, rois: Union[Tensor, list[Tensor]]) -> Tensor:
  236. return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)
  237. def __repr__(self) -> str:
  238. s = (
  239. f"{self.__class__.__name__}("
  240. f"output_size={self.output_size}"
  241. f", spatial_scale={self.spatial_scale}"
  242. f", sampling_ratio={self.sampling_ratio}"
  243. f", aligned={self.aligned}"
  244. f")"
  245. )
  246. return s