video.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, Dict, List, Optional, Tuple, Union, cast
  18. import torch
  19. import kornia.augmentation as K
  20. from kornia.augmentation.base import _AugmentationBase
  21. from kornia.augmentation.container.base import SequentialBase
  22. from kornia.augmentation.container.image import ImageSequential, _get_new_batch_shape
  23. from kornia.core import Module, Tensor
  24. from kornia.geometry.boxes import Boxes
  25. from kornia.geometry.keypoints import Keypoints
  26. from .params import ParamItem
  27. __all__ = ["VideoSequential"]
  28. class VideoSequential(ImageSequential):
  29. r"""VideoSequential for processing 5-dim video data like (B, T, C, H, W) and (B, C, T, H, W).
  30. `VideoSequential` is used to replace `nn.Sequential` for processing video data augmentations.
  31. By default, `VideoSequential` enabled `same_on_frame` to make sure the same augmentations happen
  32. across temporal dimension. Meanwhile, it will not affect other augmentation behaviours like the
  33. settings on `same_on_batch`, etc.
  34. Args:
  35. *args: a list of augmentation module.
  36. data_format: only BCTHW and BTCHW are supported.
  37. same_on_frame: apply the same transformation across the channel per frame.
  38. random_apply: randomly select a sublist (order agnostic) of args to
  39. apply transformation.
  40. If int, a fixed number of transformations will be selected.
  41. If (a,), x number of transformations (a <= x <= len(args)) will be selected.
  42. If (a, b), x number of transformations (a <= x <= b) will be selected.
  43. If None, the whole list of args will be processed as a sequence.
  44. Note:
  45. Transformation matrix returned only considers the transformation applied in ``kornia.augmentation`` module.
  46. Those transformations in ``kornia.geometry`` will not be taken into account.
  47. Example:
  48. If set `same_on_frame` to True, we would expect the same augmentation has been applied to each
  49. timeframe.
  50. >>> import kornia
  51. >>> input = torch.randn(2, 3, 1, 5, 6).repeat(1, 1, 4, 1, 1)
  52. >>> aug_list = VideoSequential(
  53. ... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  54. ... kornia.color.BgrToRgb(),
  55. ... kornia.augmentation.RandomAffine(360, p=1.0),
  56. ... random_apply=10,
  57. ... data_format="BCTHW",
  58. ... same_on_frame=True)
  59. >>> output = aug_list(input)
  60. >>> (output[0, :, 0] == output[0, :, 1]).all()
  61. tensor(True)
  62. >>> (output[0, :, 1] == output[0, :, 2]).all()
  63. tensor(True)
  64. >>> (output[0, :, 2] == output[0, :, 3]).all()
  65. tensor(True)
  66. If set `same_on_frame` to False:
  67. >>> aug_list = VideoSequential(
  68. ... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  69. ... kornia.augmentation.RandomAffine(360, p=1.0),
  70. ... kornia.augmentation.RandomMixUpV2(p=1.0),
  71. ... data_format="BCTHW",
  72. ... same_on_frame=False)
  73. >>> output = aug_list(input)
  74. >>> output.shape
  75. torch.Size([2, 3, 4, 5, 6])
  76. >>> (output[0, :, 0] == output[0, :, 1]).all()
  77. tensor(False)
  78. Reproduce with provided params.
  79. >>> out2 = aug_list(input, params=aug_list._params)
  80. >>> torch.equal(output, out2)
  81. True
  82. Perform ``OneOf`` transformation with ``random_apply=1`` and ``random_apply_weights`` in ``VideoSequential``.
  83. >>> import kornia
  84. >>> input, label = torch.randn(2, 3, 1, 5, 6).repeat(1, 1, 4, 1, 1), torch.tensor([0, 1])
  85. >>> aug_list = VideoSequential(
  86. ... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  87. ... kornia.augmentation.RandomAffine(360, p=1.0),
  88. ... kornia.augmentation.RandomMixUpV2(p=1.0),
  89. ... data_format="BCTHW",
  90. ... same_on_frame=False,
  91. ... random_apply=1,
  92. ... random_apply_weights=[0.5, 0.3, 0.8]
  93. ... )
  94. >>> out = aug_list(input)
  95. >>> out.shape
  96. torch.Size([2, 3, 4, 5, 6])
  97. """
  98. # TODO: implement transform_matrix
  99. def __init__(
  100. self,
  101. *args: Module,
  102. data_format: str = "BTCHW",
  103. same_on_frame: bool = True,
  104. random_apply: Union[int, bool, Tuple[int, int]] = False,
  105. random_apply_weights: Optional[List[float]] = None,
  106. ) -> None:
  107. super().__init__(
  108. *args,
  109. same_on_batch=None,
  110. keepdim=None,
  111. random_apply=random_apply,
  112. random_apply_weights=random_apply_weights,
  113. )
  114. self.same_on_frame = same_on_frame
  115. self.data_format = data_format.upper()
  116. if self.data_format not in ["BCTHW", "BTCHW"]:
  117. raise AssertionError(f"Only `BCTHW` and `BTCHW` are supported. Got `{data_format}`.")
  118. self._temporal_channel: int
  119. if self.data_format == "BCTHW":
  120. self._temporal_channel = 2
  121. elif self.data_format == "BTCHW":
  122. self._temporal_channel = 1
  123. def __infer_channel_exclusive_batch_shape__(self, batch_shape: torch.Size, chennel_index: int) -> torch.Size:
  124. # Fix mypy complains: error: Incompatible return value type (got "Tuple[int, ...]", expected "Size")
  125. return cast(torch.Size, batch_shape[:chennel_index] + batch_shape[chennel_index + 1 :])
  126. def __repeat_param_across_channels__(self, param: Tensor, frame_num: int) -> Tensor:
  127. """Repeat parameters across channels.
  128. The input is shaped as (B, ...), while to output (B * same_on_frame, ...), which
  129. to guarantee that the same transformation would happen for each frame.
  130. (B1, B2, ..., Bn) => (B1, ... B1, B2, ..., B2, ..., Bn, ..., Bn)
  131. | ch_size | | ch_size | ..., | ch_size |
  132. """
  133. repeated = param[:, None, ...].repeat(1, frame_num, *([1] * len(param.shape[1:])))
  134. return repeated.reshape(-1, *list(param.shape[1:]))
  135. def __broadcast_param__(
  136. self, v: Tensor, batch_shape: torch.Size, frame_num: int, same_on_frame: bool, same_on_batch: bool
  137. ) -> Tensor:
  138. if not v.numel():
  139. return v
  140. if same_on_frame and same_on_batch:
  141. return v.repeat(batch_shape[0] * frame_num, *([1] * (v.ndim - 1)))
  142. elif same_on_frame:
  143. return self.__repeat_param_across_channels__(v, frame_num)
  144. elif same_on_batch:
  145. return v.unsqueeze(1).repeat(1, batch_shape[0], *([1] * (v.ndim - 1))).reshape(-1, *v.shape[1:])
  146. return v
  147. def _input_shape_convert_in(self, input: Tensor, frame_num: int) -> Tensor:
  148. # Convert any shape to (B, T, C, H, W)
  149. if self.data_format == "BCTHW":
  150. # Convert (B, C, T, H, W) to (B, T, C, H, W)
  151. input = input.transpose(1, 2)
  152. if self.data_format == "BTCHW":
  153. pass
  154. input = input.reshape(-1, *input.shape[2:])
  155. return input
  156. def _input_shape_convert_back(self, input: Tensor, frame_num: int) -> Tensor:
  157. input = input.view(-1, frame_num, *input.shape[1:])
  158. if self.data_format == "BCTHW":
  159. input = input.transpose(1, 2)
  160. if self.data_format == "BTCHW":
  161. pass
  162. return input
  163. def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
  164. frame_num = batch_shape[self._temporal_channel]
  165. named_modules = self.get_forward_sequence()
  166. # Got param generation shape to (B, C, H, W). Ignoring T.
  167. batch_shape = self.__infer_channel_exclusive_batch_shape__(batch_shape, self._temporal_channel)
  168. params = []
  169. for name, module in named_modules:
  170. if isinstance(module, (K.RandomCrop, _AugmentationBase, K.MixAugmentationBaseV2)):
  171. is_same_on_batch = getattr(module, "same_on_batch", False)
  172. if self.same_on_frame and is_same_on_batch:
  173. mod_shape = torch.Size([1, *batch_shape[1:]])
  174. elif self.same_on_frame:
  175. mod_shape = batch_shape
  176. elif is_same_on_batch:
  177. mod_shape = torch.Size([frame_num, *batch_shape[1:]])
  178. else:
  179. mod_shape = torch.Size([batch_shape[0] * frame_num, *batch_shape[1:]])
  180. mod_param = module.forward_parameters(mod_shape)
  181. if isinstance(mod_param, dict):
  182. for k, v in mod_param.items():
  183. # TODO: revise ColorJiggle and ColorJitter order param in the future to align the standard.
  184. if k == "order" and isinstance(module, (K.ColorJiggle, K.ColorJitter)):
  185. continue
  186. if k == "forward_input_shape":
  187. mod_param.update({k: v})
  188. continue
  189. mod_param[k] = self.__broadcast_param__(
  190. v, batch_shape, frame_num, self.same_on_frame, is_same_on_batch
  191. )
  192. param = ParamItem(name, mod_param)
  193. elif isinstance(module, (SequentialBase,)):
  194. seq_param = module.forward_parameters(batch_shape)
  195. if self.same_on_frame:
  196. raise ValueError("Sequential is currently unsupported for ``same_on_frame``.")
  197. param = ParamItem(name, seq_param)
  198. else:
  199. param = ParamItem(name, None)
  200. batch_shape = _get_new_batch_shape(param, batch_shape)
  201. params.append(param)
  202. return params
  203. def transform_inputs(
  204. self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  205. ) -> Tensor:
  206. frame_num: int = input.size(self._temporal_channel)
  207. input = self._input_shape_convert_in(input, frame_num)
  208. input = super().transform_inputs(input, params, extra_args=extra_args)
  209. input = self._input_shape_convert_back(input, frame_num)
  210. return input
  211. def inverse_inputs(
  212. self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  213. ) -> Tensor:
  214. frame_num: int = input.size(self._temporal_channel)
  215. input = self._input_shape_convert_in(input, frame_num)
  216. input = super().inverse_inputs(input, params, extra_args=extra_args)
  217. input = self._input_shape_convert_back(input, frame_num)
  218. return input
  219. def transform_masks(
  220. self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  221. ) -> Tensor:
  222. frame_num: int = input.size(self._temporal_channel)
  223. input = self._input_shape_convert_in(input, frame_num)
  224. input = super().transform_masks(input, params, extra_args=extra_args)
  225. input = self._input_shape_convert_back(input, frame_num)
  226. return input
  227. def inverse_masks(
  228. self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  229. ) -> Tensor:
  230. frame_num: int = input.size(self._temporal_channel)
  231. input = self._input_shape_convert_in(input, frame_num)
  232. input = super().inverse_masks(input, params, extra_args=extra_args)
  233. input = self._input_shape_convert_back(input, frame_num)
  234. return input
  235. def transform_boxes( # type: ignore[override]
  236. self, input: Union[Tensor, Boxes], params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  237. ) -> Union[Tensor, Boxes]:
  238. """Transform bounding boxes.
  239. Args:
  240. input: tensor with shape :math:`(B, T, N, 4, 2)`.
  241. If input is a `Keypoints` type, the internal shape is :math:`(B * T, N, 4, 2)`.
  242. params: params for the sequence.
  243. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  244. """
  245. if isinstance(input, Tensor):
  246. batchsize, frame_num = input.size(0), input.size(1)
  247. input = Boxes.from_tensor(input.view(-1, input.size(2), input.size(3), input.size(4)), mode="vertices_plus")
  248. input = super().transform_boxes(input, params, extra_args=extra_args)
  249. input = input.data.view(batchsize, frame_num, -1, 4, 2)
  250. else:
  251. input = super().transform_boxes(input, params, extra_args=extra_args)
  252. return input
  253. def inverse_boxes( # type: ignore[override]
  254. self, input: Union[Tensor, Boxes], params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  255. ) -> Union[Tensor, Boxes]:
  256. """Transform bounding boxes.
  257. Args:
  258. input: tensor with shape :math:`(B, T, N, 4, 2)`.
  259. If input is a `Keypoints` type, the internal shape is :math:`(B * T, N, 4, 2)`.
  260. params: params for the sequence.
  261. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  262. """
  263. if isinstance(input, Tensor):
  264. batchsize, frame_num = input.size(0), input.size(1)
  265. input = Boxes.from_tensor(input.view(-1, input.size(2), input.size(3), input.size(4)), mode="vertices_plus")
  266. input = super().inverse_boxes(input, params, extra_args=extra_args)
  267. input = input.data.view(batchsize, frame_num, -1, 4, 2)
  268. else:
  269. input = super().inverse_boxes(input, params, extra_args=extra_args)
  270. return input
  271. def transform_keypoints( # type: ignore[override]
  272. self, input: Union[Tensor, Keypoints], params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  273. ) -> Union[Tensor, Keypoints]:
  274. """Transform bounding boxes.
  275. Args:
  276. input: tensor with shape :math:`(B, T, N, 2)`.
  277. If input is a `Keypoints` type, the internal shape is :math:`(B * T, N, 2)`.
  278. params: params for the sequence.
  279. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  280. """
  281. if isinstance(input, Tensor):
  282. batchsize, frame_num = input.size(0), input.size(1)
  283. input = Keypoints(input.view(-1, input.size(2), input.size(3)))
  284. input = super().transform_keypoints(input, params, extra_args=extra_args)
  285. input = input.data.view(batchsize, frame_num, -1, 2)
  286. else:
  287. input = super().transform_keypoints(input, params, extra_args=extra_args)
  288. return input
  289. def inverse_keypoints( # type: ignore[override]
  290. self, input: Union[Tensor, Keypoints], params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  291. ) -> Union[Tensor, Keypoints]:
  292. """Transform bounding boxes.
  293. Args:
  294. input: tensor with shape :math:`(B, T, N, 2)`.
  295. If input is a `Keypoints` type, the internal shape is :math:`(B * T, N, 2)`.
  296. params: params for the sequence.
  297. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  298. """
  299. if isinstance(input, Tensor):
  300. frame_num, batchsize = input.size(0), input.size(1)
  301. input = Keypoints(input.view(-1, input.size(2), input.size(3)))
  302. input = super().inverse_keypoints(input, params, extra_args=extra_args)
  303. input = input.data.view(batchsize, frame_num, -1, 2)
  304. else:
  305. input = super().inverse_keypoints(input, params, extra_args=extra_args)
  306. return input
  307. def inverse(
  308. self, input: Tensor, params: Optional[List[ParamItem]] = None, extra_args: Optional[Dict[str, Any]] = None
  309. ) -> Tensor:
  310. """Inverse transformation.
  311. Used to inverse a tensor according to the performed transformation by a forward pass, or with respect to
  312. provided parameters.
  313. """
  314. if params is None:
  315. if self._params is not None:
  316. params = self._params
  317. else:
  318. raise RuntimeError("No valid params to inverse the transformation.")
  319. return self.inverse_inputs(input, params, extra_args=extra_args)
  320. def forward(
  321. self, input: Tensor, params: Optional[List[ParamItem]] = None, extra_args: Optional[Dict[str, Any]] = None
  322. ) -> Tensor:
  323. """Define the video computation performed."""
  324. if len(input.shape) != 5:
  325. raise AssertionError(f"Input must be a 5-dim tensor. Got {input.shape}.")
  326. if params is None:
  327. self._params = self.forward_parameters(input.shape)
  328. params = self._params
  329. output = self.transform_inputs(input, params, extra_args=extra_args)
  330. return output