diamond_square.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. # based on: https://github.com/anguelos/tormentor/blob/e8050ac235b0c7ad3c7d931cfa47c308a305c486/diamond_square/diamond_square.py # noqa: E501
  18. import math
  19. from typing import Callable, List, Optional, Tuple, Union
  20. import torch
  21. from kornia.core import Tensor
  22. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
  23. from kornia.enhance import normalize_min_max
  24. from kornia.filters import filter2d
  25. # the default kernels for the diamond square
  26. default_diamond_kernel: List[List[float]] = [[0.25, 0.0, 0.25], [0.0, 0.0, 0.0], [0.25, 0.0, 0.25]]
  27. default_square_kernel: List[List[float]] = [[0.0, 0.25, 0.0], [0.25, 0.0, 0.25], [0.0, 0.25, 0.0]]
  28. def _diamond_square_seed(
  29. replicates: int,
  30. width: int,
  31. height: int,
  32. random_fn: Callable[..., Tensor],
  33. device: Optional[torch.device] = None,
  34. dtype: Optional[torch.dtype] = None,
  35. ) -> Tensor:
  36. """Generate the diamond square image seee.
  37. Args:
  38. replicates: the num of batched replicas for the image.
  39. width: the expected image width.
  40. height: the expected image height.
  41. random_fn: the random function to generate the image seed.
  42. device: the torch device where to create the image seed.
  43. dtype: the torch dtype where to create the image seed.
  44. Return:
  45. the generated image seed of size Bx1xHxW.
  46. """
  47. KORNIA_CHECK(width == 3 or height == 3, "Height or Width must be equal to 3.")
  48. # TODO(anguelos): can we avoid transposing and passing always fixed size. This will cause issues with onnx/jit
  49. transpose: bool = False
  50. if height == 3:
  51. transpose = True
  52. width, height = height, width
  53. # width is always 3
  54. KORNIA_CHECK(height % 2 == 1 and height > 2, "Height must be odd and height bigger than 2")
  55. res: Tensor = random_fn([replicates, 1, width, height], device=device, dtype=dtype)
  56. res[..., ::2, ::2] = random_fn([replicates, 1, 2, (height + 1) // 2], device=device, dtype=dtype)
  57. # Diamond step
  58. res[..., 1, 1::2] = (res[..., ::2, :-2:2] + res[..., ::2, 2::2]).sum(dim=2) / 4.0
  59. # Square step
  60. if width > 3:
  61. res[..., 1, 2:-3:2] = (
  62. res[..., 0, 2:-3:2] + res[..., 2, 2:-3:2] + res[..., 1, 0:-4:2] + res[..., 1, 2:-3:2]
  63. ) / 4.0
  64. tmp1 = res[..., 2, 0]
  65. res[..., 1, 0] = res[..., 0, 0] + res[..., 1, 1] + tmp1
  66. res[..., 1, -1] = res[..., -1, -1] + res[..., 1, -2] + tmp1
  67. tmp2 = res[..., 1, 1::2]
  68. res[..., 0, 1::2] = res[..., 0, 0:-2:2] + res[..., 0, 2::2] + tmp2
  69. res[..., 2, 1::2] = res[..., 2, 0:-2:2] + res[..., 2, 2::2] + tmp2
  70. res = res / 3.0
  71. if transpose:
  72. res = res.transpose(2, 3)
  73. return res
  74. def _one_diamond_one_square(
  75. img: Tensor,
  76. random_scale: Union[float, Tensor],
  77. random_fn: Callable[..., Tensor] = torch.rand,
  78. diamond_kernel: Optional[Tensor] = None,
  79. square_kernel: Optional[Tensor] = None,
  80. ) -> Tensor:
  81. """Doubles the image resolution by applying a single diamond square steps.
  82. Recursive application of this method creates plasma fractals.
  83. Attention! The function is differentiable and gradients are computed as well.
  84. If this function is run in the usual sense, it is more efficient if it is run in a no_grad()
  85. Args:
  86. img: a 4D tensor where dimensions are Batch, Channel, Width, Height. Width and Height must both be 2^N+1 and
  87. Batch and Channels should in the usual case be 1.
  88. random_scale: a float number in [0,1] controlling the randomness created pixels get. I the usual case, it is
  89. halved at every application of this function.
  90. random_fn: the random function to generate the image seed.
  91. diamond_kernel: the 3x3 kernel to perform the diamond step.
  92. square_kernel: the 3x3 kernel to perform the square step.
  93. Return:
  94. A tensor on the same device as img with the same channels as img and width, height of 2^(N+1)+1.
  95. """
  96. KORNIA_CHECK_SHAPE(img, ["B", "C", "H", "W"])
  97. # TODO (anguelos) test multi channel and batch size > 1
  98. if diamond_kernel is None:
  99. diamond_kernel = Tensor([default_diamond_kernel]).to(img) # 1x3x3
  100. if square_kernel is None:
  101. square_kernel = Tensor([default_square_kernel]).to(img) # 1x3x3
  102. batch_sz, _, height, width = img.shape
  103. new_img: Tensor = torch.zeros(
  104. [batch_sz, 1, 2 * (height - 1) + 1, 2 * (width - 1) + 1], device=img.device, dtype=img.dtype
  105. )
  106. new_img[:, :, ::2, ::2] = img
  107. factor: float = 1.0 / 0.75
  108. pad_compencate = torch.ones_like(new_img)
  109. pad_compencate[:, :, :, 0] = factor
  110. pad_compencate[:, :, :, -1] = factor
  111. pad_compencate[:, :, 0, :] = factor
  112. pad_compencate[:, :, -1, :] = factor
  113. random_img: Tensor = random_fn(new_img.size(), device=img.device, dtype=img.dtype) * random_scale
  114. # TODO(edgar): use kornia.filter2d
  115. # diamond
  116. diamond_regions = filter2d(new_img, diamond_kernel)
  117. diamond_centers = (diamond_regions > 0).to(img.dtype)
  118. # TODO (anguelos) make sure diamond_regions*diamond_centers is needed
  119. new_img = new_img + (1 - random_scale) * diamond_regions * diamond_centers + diamond_centers * random_img
  120. # square
  121. square_regions = filter2d(new_img, square_kernel) * pad_compencate
  122. square_centers = (square_regions > 0).to(img.dtype)
  123. # TODO (anguelos) make sure square_centers*square_regions is needed
  124. new_img = new_img + square_centers * random_img + (1 - random_scale) * square_centers * square_regions
  125. return new_img
  126. def diamond_square(
  127. output_size: Tuple[int, int, int, int],
  128. roughness: Union[float, Tensor] = 0.5,
  129. random_scale: Union[float, Tensor] = 1.0,
  130. random_fn: Callable[..., Tensor] = torch.rand,
  131. normalize_range: Optional[Tuple[float, float]] = None,
  132. device: Optional[torch.device] = None,
  133. dtype: Optional[torch.dtype] = None,
  134. ) -> Tensor:
  135. """Generate Plasma Fractal Images using the diamond square algorithm.
  136. See: https://en.wikipedia.org/wiki/Diamond-square_algorithm
  137. Args:
  138. output_size: a tuple of integers with the BxCxHxW of the image to be generated.
  139. roughness: the scale value to apply at each recursion step.
  140. random_scale: the initial value of the scale for recursion.
  141. random_fn: the callable function to use to sample a random tensor.
  142. normalize_range: whether to normalize using min-max the output map. In case of a
  143. range is specified, min-max norm is applied between the provided range.
  144. device: the torch device to place the output map.
  145. dtype: the torch dtype to place the output map.
  146. Returns:
  147. A tensor with shape :math:`(B,C,H,W)` containing the fractal image.
  148. """
  149. KORNIA_CHECK(len(output_size) == 4, "output_size must be (B,C,H,W)")
  150. if not isinstance(random_scale, Tensor):
  151. random_scale = Tensor([[[[random_scale]]]]).to(device, dtype)
  152. random_scale = random_scale.expand([output_size[0] * output_size[1], 1, 1, 1])
  153. else:
  154. KORNIA_CHECK_IS_TENSOR(random_scale)
  155. random_scale = random_scale.view(-1, 1, 1, 1)
  156. random_scale = random_scale.expand([output_size[0], output_size[1], 1, 1])
  157. random_scale = random_scale.reshape([-1, 1, 1, 1])
  158. if not isinstance(roughness, Tensor):
  159. roughness = Tensor([[[[roughness]]]]).to(device, dtype)
  160. roughness = roughness.expand([output_size[0] * output_size[1], 1, 1, 1])
  161. else:
  162. roughness = roughness.view(-1, 1, 1, 1)
  163. roughness = roughness.expand([output_size[0], output_size[1], 1, 1])
  164. roughness = roughness.reshape([-1, 1, 1, 1])
  165. width, height = output_size[-2:]
  166. num_samples: int = 1
  167. for x in output_size[:-2]:
  168. num_samples *= x
  169. # compute the image seed
  170. p2_width: float = 2 ** math.ceil(math.log2(width - 1)) + 1
  171. p2_height: float = 2 ** math.ceil(math.log2(height - 1)) + 1
  172. recursion_depth: int = int(min(math.log2(p2_width - 1) - 1, math.log2(p2_height - 1) - 1))
  173. seed_width: int = (p2_width - 1) // 2**recursion_depth + 1
  174. seed_height: int = (p2_height - 1) // 2**recursion_depth + 1
  175. img: Tensor = random_scale * _diamond_square_seed(num_samples, seed_width, seed_height, random_fn, device, dtype)
  176. # perform recursion
  177. scale = random_scale
  178. for _ in range(recursion_depth):
  179. scale = scale * roughness
  180. img = _one_diamond_one_square(img, scale, random_fn)
  181. # slice to match with the output size
  182. img = img[..., :width, :height]
  183. img = img.view(output_size)
  184. # normalize the output in the range using min-max
  185. if normalize_range is not None:
  186. min_val, max_val = normalize_range
  187. img = normalize_min_max(img.contiguous(), min_val, max_val)
  188. return img