| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- # based on: https://github.com/anguelos/tormentor/blob/e8050ac235b0c7ad3c7d931cfa47c308a305c486/diamond_square/diamond_square.py # noqa: E501
- import math
- from typing import Callable, List, Optional, Tuple, Union
- import torch
- from kornia.core import Tensor
- from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
- from kornia.enhance import normalize_min_max
- from kornia.filters import filter2d
# Default 3x3 stencils used by the diamond-square refinement. Both average four
# neighbours with weight 0.25 each: the diamond stencil reads the four diagonal
# neighbours, the square stencil reads the four axis-aligned neighbours.
default_diamond_kernel: List[List[float]] = [
    [0.25, 0.0, 0.25],
    [0.0, 0.0, 0.0],
    [0.25, 0.0, 0.25],
]
default_square_kernel: List[List[float]] = [
    [0.0, 0.25, 0.0],
    [0.25, 0.0, 0.25],
    [0.0, 0.25, 0.0],
]
def _diamond_square_seed(
    replicates: int,
    width: int,
    height: int,
    random_fn: Callable[..., Tensor],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Generate the diamond square image seed.

    One of ``width``/``height`` must be 3 and the other must be odd and at least 3. Internally the seed is
    built with the size-3 axis on dim 2 and transposed back at the end when ``height`` was the size-3 axis.

    Args:
        replicates: the number of batched replicas of the seed.
        width: the size of dim 2 of the returned seed.
        height: the size of dim 3 of the returned seed.
        random_fn: the random function used to sample the seed. It is called as
            ``random_fn(size, device=device, dtype=dtype)``.
        device: the torch device where to create the image seed.
        dtype: the torch dtype of the image seed.

    Return:
        the generated image seed of shape :math:`(replicates, 1, width, height)`.
    """
    KORNIA_CHECK(width == 3 or height == 3, "Height or Width must be equal to 3.")

    # TODO(anguelos): can we avoid transposing and passing always fixed size. This will cause issues with onnx/jit
    # Put the size-3 axis on dim 2; the result is transposed back before returning.
    transpose: bool = height == 3
    if transpose:
        width, height = height, width

    # width is always 3 from here on
    KORNIA_CHECK(height % 2 == 1 and height > 2, "Height must be odd and height bigger than 2")

    res: Tensor = random_fn([replicates, 1, width, height], device=device, dtype=dtype)
    # grid samples: rows 0 and 2, even columns
    res[..., ::2, ::2] = random_fn([replicates, 1, 2, (height + 1) // 2], device=device, dtype=dtype)

    # Diamond step: every odd column of row 1 is the mean of its four diagonal grid neighbours.
    res[..., 1, 1::2] = (res[..., ::2, :-2:2] + res[..., ::2, 2::2]).sum(dim=2) / 4.0

    # Square step
    # NOTE(review): ``width`` is always 3 at this point, so this branch is unreachable. Its slices also look
    # inconsistent (``res[..., 1, 2:-3:2]`` is read on the right-hand side while being written); left as-is
    # pending a decision on the intended behaviour.
    if width > 3:
        res[..., 1, 2:-3:2] = (
            res[..., 0, 2:-3:2] + res[..., 2, 2:-3:2] + res[..., 1, 0:-4:2] + res[..., 1, 2:-3:2]
        ) / 4.0

    # Border square step: each border point is set to the sum of its three in-image neighbours.
    # Left and right ends of row 1 mirror each other: the two grid samples in the same column plus the
    # adjacent diamond centre of row 1.
    res[..., 1, 0] = res[..., 0, 0] + res[..., 1, 1] + res[..., 2, 0]
    res[..., 1, -1] = res[..., 0, -1] + res[..., 1, -2] + res[..., 2, -1]
    # Odd columns of rows 0 and 2: the two grid samples to the left/right plus the diamond centre of row 1.
    res[..., 0, 1::2] = res[..., 0, 0:-2:2] + res[..., 0, 2::2] + res[..., 1, 1::2]
    res[..., 2, 1::2] = res[..., 2, 0:-2:2] + res[..., 2, 2::2] + res[..., 1, 1::2]
    # NOTE: this divides the whole seed, turning the three-term sums above into means but also rescaling
    # the grid samples and the diamond centres.
    res = res / 3.0

    if transpose:
        res = res.transpose(2, 3)
    return res
def _one_diamond_one_square(
    img: Tensor,
    random_scale: Union[float, Tensor],
    random_fn: Callable[..., Tensor] = torch.rand,
    diamond_kernel: Optional[Tensor] = None,
    square_kernel: Optional[Tensor] = None,
) -> Tensor:
    """Double the image resolution by applying a single diamond step and a single square step.

    Recursive application of this function creates plasma fractals.

    Attention! The function is differentiable and gradients are computed as well. If gradients are not
    needed, it is more efficient to run it inside ``torch.no_grad()``.

    Args:
        img: a tensor of shape :math:`(B, C, H, W)`. H and W are expected to be :math:`2^N+1`.
        random_scale: a float in [0, 1], or a tensor broadcastable to the output (e.g. :math:`(B, 1, 1, 1)`),
            controlling how much randomness the created pixels get. In the usual case it is halved at every
            application of this function.
        random_fn: the random function used to sample the per-pixel noise. It is called as
            ``random_fn(size, device=img.device, dtype=img.dtype)``.
        diamond_kernel: the 3x3 kernel used by the diamond step. Defaults to ``default_diamond_kernel``.
        square_kernel: the 3x3 kernel used by the square step. Defaults to ``default_square_kernel``.

    Return:
        A tensor of shape :math:`(B, C, 2H-1, 2W-1)` on the same device and with the same dtype as ``img``.

    Note:
        New pixels are located by testing the filter responses for ``> 0``, so ``img`` and ``random_fn`` are
        expected to produce non-negative values (as ``torch.rand`` does).
    """
    KORNIA_CHECK_SHAPE(img, ["B", "C", "H", "W"])
    # TODO (anguelos) test multi channel and batch size > 1
    if diamond_kernel is None:
        diamond_kernel = Tensor([default_diamond_kernel]).to(img)  # 1x3x3
    if square_kernel is None:
        square_kernel = Tensor([default_square_kernel]).to(img)  # 1x3x3

    batch_sz, channels, height, width = img.shape
    # The existing samples keep their values and land on the even rows/columns of the upsampled grid.
    new_img: Tensor = torch.zeros(
        [batch_sz, channels, 2 * (height - 1) + 1, 2 * (width - 1) + 1], device=img.device, dtype=img.dtype
    )
    new_img[:, :, ::2, ::2] = img

    # With zero padding, border pixels of the square step only see 3 of the 4 kernel taps, so their
    # response is rescaled by 4/3 to turn it back into a mean.
    factor: float = 1.0 / 0.75
    pad_compensate = torch.ones_like(new_img)
    pad_compensate[:, :, :, 0] = factor
    pad_compensate[:, :, :, -1] = factor
    pad_compensate[:, :, 0, :] = factor
    pad_compensate[:, :, -1, :] = factor

    random_img: Tensor = random_fn(new_img.size(), device=img.device, dtype=img.dtype) * random_scale

    # Both filters use zero ("constant") padding: filter2d defaults to reflect padding, which would mirror
    # an in-image neighbour into the missing tap and make the 4/3 compensation above overshoot.

    # diamond: the centre of every 2x2 cell is filled from its four diagonal grid neighbours
    diamond_regions = filter2d(new_img, diamond_kernel, border_type="constant")
    diamond_centers = (diamond_regions > 0).to(img.dtype)
    # TODO (anguelos) make sure diamond_regions*diamond_centers is needed
    new_img = new_img + (1 - random_scale) * diamond_regions * diamond_centers + diamond_centers * random_img

    # square: the remaining edge midpoints are filled from their axis-aligned neighbours
    square_regions = filter2d(new_img, square_kernel, border_type="constant") * pad_compensate
    square_centers = (square_regions > 0).to(img.dtype)
    # TODO (anguelos) make sure square_centers*square_regions is needed
    new_img = new_img + square_centers * random_img + (1 - random_scale) * square_centers * square_regions
    return new_img
def _expand_per_sample(
    value: Union[float, Tensor],
    batch_size: int,
    channels: int,
    device: Optional[torch.device],
    dtype: Optional[torch.dtype],
) -> Tensor:
    """Broadcast a scalar or per-batch parameter to a ``(batch_size * channels, 1, 1, 1)`` tensor."""
    if not isinstance(value, Tensor):
        scalar = Tensor([[[[value]]]]).to(device, dtype)
        return scalar.expand([batch_size * channels, 1, 1, 1])
    # A tensor must hold a single value or one value per batch element; it is shared across channels.
    per_batch = value.view(-1, 1, 1, 1).expand([batch_size, channels, 1, 1])
    return per_batch.reshape([-1, 1, 1, 1])


def diamond_square(
    output_size: Tuple[int, int, int, int],
    roughness: Union[float, Tensor] = 0.5,
    random_scale: Union[float, Tensor] = 1.0,
    random_fn: Callable[..., Tensor] = torch.rand,
    normalize_range: Optional[Tuple[float, float]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Generate Plasma Fractal Images using the diamond square algorithm.

    See: https://en.wikipedia.org/wiki/Diamond-square_algorithm

    Args:
        output_size: a tuple of integers with the BxCxHxW of the image to be generated. H and W must be >= 3.
        roughness: the factor the random scale is multiplied by at each recursion step. Either a float or a
            tensor holding one value or one value per batch element.
        random_scale: the initial value of the scale for recursion. Either a float or a tensor holding one
            value or one value per batch element.
        random_fn: the callable function to use to sample a random tensor.
        normalize_range: whether to normalize using min-max the output map. In case of a
            range is specified, min-max norm is applied between the provided range.
        device: the torch device to place the output map.
        dtype: the torch dtype to place the output map.

    Returns:
        A tensor with shape :math:`(B,C,H,W)` containing the fractal image.
    """
    KORNIA_CHECK(len(output_size) == 4, "output_size must be (B,C,H,W)")
    batch_size, channels, height, width = output_size
    KORNIA_CHECK(height >= 3 and width >= 3, "output_size height and width must be at least 3")

    random_scale = _expand_per_sample(random_scale, batch_size, channels, device, dtype)
    roughness = _expand_per_sample(roughness, batch_size, channels, device, dtype)
    # every (batch, channel) pair is generated as an independent single-channel sample
    num_samples: int = batch_size * channels

    # Round each side up to the next 2^k + 1 and recurse as often as the shorter side allows. The shorter
    # side then starts from a seed of size 3, the longer one from 2^(k_long - k_short + 1) + 1.
    p2_height: int = 2 ** math.ceil(math.log2(height - 1)) + 1
    p2_width: int = 2 ** math.ceil(math.log2(width - 1)) + 1
    recursion_depth: int = int(min(math.log2(p2_height - 1) - 1, math.log2(p2_width - 1) - 1))
    seed_height: int = (p2_height - 1) // 2**recursion_depth + 1
    seed_width: int = (p2_width - 1) // 2**recursion_depth + 1

    # The seed helper sizes dim 2 of its output with its 2nd positional argument and dim 3 with its 3rd,
    # so the row count is passed first.
    img: Tensor = random_scale * _diamond_square_seed(num_samples, seed_height, seed_width, random_fn, device, dtype)

    # perform recursion; each step doubles the resolution and damps the noise by ``roughness``
    scale = random_scale
    for _ in range(recursion_depth):
        scale = scale * roughness
        img = _one_diamond_one_square(img, scale, random_fn)

    # crop the (2^k + 1)-sized result to the requested size and restore the (B, C) axes
    img = img[..., :height, :width]
    img = img.view(output_size)

    # normalize the output in the range using min-max
    if normalize_range is not None:
        min_val, max_val = normalize_range
        img = normalize_min_max(img.contiguous(), min_val, max_val)
    return img
|