# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# based on: https://github.com/anguelos/tormentor/blob/e8050ac235b0c7ad3c7d931cfa47c308a305c486/diamond_square/diamond_square.py  # noqa: E501
import math
from typing import Callable, List, Optional, Tuple, Union

import torch

from kornia.core import Tensor
from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
from kornia.enhance import normalize_min_max
from kornia.filters import filter2d

# the default kernels for the diamond square
default_diamond_kernel: List[List[float]] = [[0.25, 0.0, 0.25], [0.0, 0.0, 0.0], [0.25, 0.0, 0.25]]
default_square_kernel: List[List[float]] = [[0.0, 0.25, 0.0], [0.25, 0.0, 0.25], [0.0, 0.25, 0.0]]


def _diamond_square_seed(
    replicates: int,
    width: int,
    height: int,
    random_fn: Callable[..., Tensor],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Generate the diamond square image seed.

    One of the two spatial sizes must be 3; the other must be odd and bigger than 2.

    Args:
        replicates: the num of batched replicas for the image.
        width: the size of the first spatial dimension (dim 2) of the seed.
        height: the size of the second spatial dimension (dim 3) of the seed.
        random_fn: the random function to generate the image seed. It is called as
          ``random_fn(shape, device=..., dtype=...)``.
        device: the torch device where to create the image seed.
        dtype: the torch dtype where to create the image seed.

    Return:
        the generated image seed of shape :math:`(replicates, 1, width, height)`.

    """
    KORNIA_CHECK(width == 3 or height == 3, "Height or Width must be equal to 3.")
    # TODO(anguelos): can we avoid transposing and passing always fixed size. This will cause issues with onnx/jit
    transpose: bool = False
    if height == 3:
        transpose = True
        width, height = height, width
    # width is always 3
    KORNIA_CHECK(height % 2 == 1 and height > 2, "Height must be odd and height bigger than 2")
    res: Tensor = random_fn([replicates, 1, width, height], device=device, dtype=dtype)
    # re-draw the values sitting on even rows / even columns
    res[..., ::2, ::2] = random_fn([replicates, 1, 2, (height + 1) // 2], device=device, dtype=dtype)

    # Diamond step: every odd column of the middle row gets the mean of its four diagonal neighbours
    res[..., 1, 1::2] = (res[..., ::2, :-2:2] + res[..., ::2, 2::2]).sum(dim=2) / 4.0

    # Square step
    # A former ``if width > 3:`` branch was dropped here: ``width`` is always 3 at this point,
    # so that branch could never execute.
    bottom_left = res[..., 2, 0]
    res[..., 1, 0] = res[..., 0, 0] + res[..., 1, 1] + bottom_left
    # NOTE(review): the right edge adds the bottom-left corner; the top-right corner
    # ``res[..., 0, -1]`` looks like the intended term. Kept as is to preserve seeded outputs -- confirm.
    res[..., 1, -1] = res[..., -1, -1] + res[..., 1, -2] + bottom_left
    centers = res[..., 1, 1::2]
    res[..., 0, 1::2] = res[..., 0, 0:-2:2] + res[..., 0, 2::2] + centers
    res[..., 2, 1::2] = res[..., 2, 0:-2:2] + res[..., 2, 2::2] + centers
    # the division is applied to the whole tensor, not only to the three-term sums above
    res = res / 3.0

    if transpose:
        res = res.transpose(2, 3)
    return res


def _one_diamond_one_square(
    img: Tensor,
    random_scale: Union[float, Tensor],
    random_fn: Callable[..., Tensor] = torch.rand,
    diamond_kernel: Optional[Tensor] = None,
    square_kernel: Optional[Tensor] = None,
) -> Tensor:
    """Doubles the image resolution by applying a single diamond square step.

    Recursive application of this method creates plasma fractals.
    Attention! The function is differentiable and gradients are computed as well.
    If this function is run in the usual sense, it is more efficient if it is run in a no_grad()

    Args:
        img: a 4D tensor where dimensions are Batch, Channel, Height, Width. Height and Width must both be 2^N+1
          and Batch and Channels should in the usual case be 1.
        random_scale: a float number in [0,1] controlling the randomness created pixels get.
          In the usual case, it is halved at every application of this function.
        random_fn: the random function to generate the image seed.
        diamond_kernel: the 3x3 kernel to perform the diamond step.
        square_kernel: the 3x3 kernel to perform the square step.

    Return:
        A tensor on the same device as img with spatial size :math:`(2(H-1)+1, 2(W-1)+1)`.
        The output buffer is allocated with a single channel.

    """
    KORNIA_CHECK_SHAPE(img, ["B", "C", "H", "W"])
    # TODO (anguelos) test multi channel and batch size > 1

    if diamond_kernel is None:
        diamond_kernel = Tensor([default_diamond_kernel]).to(img)  # 1x3x3
    if square_kernel is None:
        square_kernel = Tensor([default_square_kernel]).to(img)  # 1x3x3

    batch_sz, _, height, width = img.shape
    new_img: Tensor = torch.zeros(
        [batch_sz, 1, 2 * (height - 1) + 1, 2 * (width - 1) + 1], device=img.device, dtype=img.dtype
    )
    # the known pixels land on the even rows / even columns of the upsampled grid
    new_img[:, :, ::2, ::2] = img

    # border pixels of the square step are rescaled by 1 / 0.75
    factor: float = 1.0 / 0.75
    pad_compensate = torch.ones_like(new_img)
    pad_compensate[:, :, :, 0] = factor
    pad_compensate[:, :, :, -1] = factor
    pad_compensate[:, :, 0, :] = factor
    pad_compensate[:, :, -1, :] = factor

    random_img: Tensor = random_fn(new_img.size(), device=img.device, dtype=img.dtype) * random_scale

    # TODO(edgar): use kornia.filter2d
    # diamond
    diamond_regions = filter2d(new_img, diamond_kernel)
    # pixels with a strictly positive filter response are treated as the diamond centers
    diamond_centers = (diamond_regions > 0).to(img.dtype)
    # TODO (anguelos) make sure diamond_regions*diamond_centers is needed
    new_img = new_img + (1 - random_scale) * diamond_regions * diamond_centers + diamond_centers * random_img

    # square
    square_regions = filter2d(new_img, square_kernel) * pad_compensate
    square_centers = (square_regions > 0).to(img.dtype)

    # TODO (anguelos) make sure square_centers*square_regions is needed
    new_img = new_img + square_centers * random_img + (1 - random_scale) * square_centers * square_regions

    return new_img


def diamond_square(
    output_size: Tuple[int, int, int, int],
    roughness: Union[float, Tensor] = 0.5,
    random_scale: Union[float, Tensor] = 1.0,
    random_fn: Callable[..., Tensor] = torch.rand,
    normalize_range: Optional[Tuple[float, float]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Generate Plasma Fractal Images using the diamond square algorithm.

    See: https://en.wikipedia.org/wiki/Diamond-square_algorithm

    The fractal is generated on a grid whose sides are of the form 2^N+1 (at least 3) and then
    cropped to the requested height and width.

    Args:
        output_size: a tuple of integers with the BxCxHxW of the image to be generated.
        roughness: the scale value to apply at each recursion step.
        random_scale: the initial value of the scale for recursion.
        random_fn: the callable function to use to sample a random tensor.
        normalize_range: whether to normalize using min-max the output map. In case of a
          range is specified, min-max norm is applied between the provided range.
        device: the torch device to place the output map.
        dtype: the torch dtype to place the output map.

    Returns:
        A tensor with shape :math:`(B,C,H,W)` containing the fractal image.

    """
    KORNIA_CHECK(len(output_size) == 4, "output_size must be (B,C,H,W)")
    # bring both scale factors to shape (B * C, 1, 1, 1) so they broadcast against the flattened batch
    if not isinstance(random_scale, Tensor):
        random_scale = Tensor([[[[random_scale]]]]).to(device, dtype)
        random_scale = random_scale.expand([output_size[0] * output_size[1], 1, 1, 1])
    else:
        KORNIA_CHECK_IS_TENSOR(random_scale)
        random_scale = random_scale.view(-1, 1, 1, 1)
        random_scale = random_scale.expand([output_size[0], output_size[1], 1, 1])
        random_scale = random_scale.reshape([-1, 1, 1, 1])

    if not isinstance(roughness, Tensor):
        roughness = Tensor([[[[roughness]]]]).to(device, dtype)
        roughness = roughness.expand([output_size[0] * output_size[1], 1, 1, 1])
    else:
        roughness = roughness.view(-1, 1, 1, 1)
        roughness = roughness.expand([output_size[0], output_size[1], 1, 1])
        roughness = roughness.reshape([-1, 1, 1, 1])

    height, width = output_size[-2:]
    # every (batch, channel) pair is generated as an independent single-channel sample
    num_samples: int = output_size[0] * output_size[1]

    # compute the image seed
    # Exponent of the smallest 2^N+1 grid covering each side. Clamped to N >= 1 (a 3-pixel side)
    # so that requested sizes below 3 do not hit log2(0) or a negative recursion depth.
    exp_height: int = max(1, math.ceil(math.log2(max(height - 1, 1))))
    exp_width: int = max(1, math.ceil(math.log2(max(width - 1, 1))))
    p2_height: int = 2**exp_height + 1
    p2_width: int = 2**exp_width + 1
    # recurse until the shorter side of the seed is exactly 3
    recursion_depth: int = min(exp_height, exp_width) - 1
    seed_height: int = (p2_height - 1) // 2**recursion_depth + 1
    seed_width: int = (p2_width - 1) // 2**recursion_depth + 1
    # the seed helper takes the size of dim 2 first, then the size of dim 3
    img: Tensor = random_scale * _diamond_square_seed(num_samples, seed_height, seed_width, random_fn, device, dtype)

    # perform recursion
    scale = random_scale
    for _ in range(recursion_depth):
        scale = scale * roughness
        img = _one_diamond_one_square(img, scale, random_fn)

    # slice to match with the output size
    img = img[..., :height, :width]
    img = img.view(output_size)

    # normalize the output in the range using min-max
    if normalize_range is not None:
        min_val, max_val = normalize_range
        img = normalize_min_max(img.contiguous(), min_val, max_val)
    return img