| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import Any, Dict, List, Tuple, Union
- import torch
- import torch.nn.functional as F
- from torch import nn
- from kornia.constants import pi
- from kornia.core import Tensor, cos, sin, tensor, zeros
- from kornia.filters import GaussianBlur2d, SpatialGradient
- from kornia.geometry.conversions import cart2pol
- from kornia.utils import create_meshgrid
# Von Mises kernel coefficients, precomputed for a given order N and concentration K (kappa).
# A list of length N + 1 yields an embedding of dimensionality 2 * N + 1.
sqrt2: float = 1.4142135623730951

COEFFS_N1_K1: List[float] = [0.38214156, 0.48090413]
COEFFS_N2_K8: List[float] = [0.14343168, 0.268285, 0.21979234]
COEFFS_N3_K8: List[float] = [0.14343168, 0.268285, 0.21979234, 0.15838885]

# Lookup by parametrization: "xy" (cartesian position), "rhophi" (polar position),
# "theta" (gradient orientation).
COEFFS: Dict[str, List[float]] = {
    "xy": COEFFS_N1_K1,
    "rhophi": COEFFS_N2_K8,
    "theta": COEFFS_N3_K8,
}

# Download location of the whitening checkpoint for each kernel type.
urls: Dict[str, str] = {
    kernel_name: f"https://github.com/manyids2/mkd_pytorch/raw/master/mkd_pytorch/mkd-{kernel_name}-64.pth"
    for kernel_name in ("cart", "polar", "concat")
}
def get_grid_dict(patch_size: int = 32) -> Dict[str, Tensor]:
    r"""Build the cartesian and polar coordinate grids of a square patch.

    Args:
        patch_size: Side length, in pixels, passed as both height and width of the meshgrid.

    Returns:
        Dictionary with the keys ``"x"``, ``"y"``, ``"rho"`` and ``"phi"``. ``x`` and ``y`` are
        channels 0 and 1 of the last axis of the normalized meshgrid (batch axis removed);
        ``rho`` and ``phi`` are the pair returned by ``cart2pol(x, y)``.
    """
    # Drop the leading batch axis once, then split the trailing coordinate axis.
    mesh = create_meshgrid(height=patch_size, width=patch_size, normalized_coordinates=True)[0]
    xs, ys = mesh[:, :, 0], mesh[:, :, 1]
    radius, angle = cart2pol(xs, ys)
    return {"x": xs, "y": ys, "rho": radius, "phi": angle}
def get_kron_order(d1: int, d2: int) -> Tensor:
    r"""Enumerate the index pairs used to form a Kronecker product of two embeddings.

    Args:
        d1: Number of entries of the first factor.
        d2: Number of entries of the second factor.

    Returns:
        ``int64`` tensor of shape ``(d1 * d2, 2)`` whose row ``i * d2 + j`` is ``[i, j]``,
        i.e. the first column varies slowest.
    """
    # Slow index: each value of the first factor repeated d2 times in a row.
    first = torch.arange(d1).repeat_interleave(d2)
    # Fast index: the whole range of the second factor tiled d1 times.
    second = torch.arange(d2).repeat(d1)
    return torch.stack((first, second), dim=1).to(torch.int64)
class MKDGradients(nn.Module):
    r"""Module that computes patch gradients in polar form, stacked as [magnitudes, orientations].

    The input is differentiated with :class:`SpatialGradient` (``mode="diff"``, first order,
    not normalized), the result is negated, and the two derivative maps :math:`g_x`, :math:`g_y`
    are passed to ``cart2pol`` together with a small ``eps``:

    - :math:`\mbox{mags} = \sqrt{g_x^2 + g_y^2 + eps}`
    - :math:`\mbox{oris} = \mbox{tan}^{-1}(g_y / g_x)`

    The constructor takes no arguments.

    Returns:
        gradients of given patches.

    Shape:
        - Input: (B, 1, patch_size, patch_size)
        - Output: (B, 2, patch_size, patch_size)

    Example:
        >>> patches = torch.rand(23, 1, 32, 32)
        >>> gradient = MKDGradients()
        >>> g = gradient(patches) # 23x2x32x32

    """

    def __init__(self) -> None:
        super().__init__()
        self.eps = 1e-8
        self.grad = SpatialGradient(mode="diff", order=1, normalized=False)

    def forward(self, x: Tensor) -> Tensor:
        """Return magnitudes and orientations of ``x`` concatenated along the channel axis.

        Raises:
            TypeError: if ``x`` is not a tensor.
            ValueError: if ``x`` is not 4-dimensional.
        """
        if not isinstance(x, Tensor):
            raise TypeError(f"Input type is not a Tensor. Got {type(x)}")
        if len(x.shape) != 4:
            raise ValueError(f"Invalid input shape, we expect Bx1xHxW. Got: {x.shape}")

        # The sign of the 'diff' gradient is flipped with a plain negation; an earlier
        # lambda-based version was not jittable.
        flipped = -self.grad(x)
        # Axis 2 of the gradient tensor indexes the derivative direction (0: x, 1: y).
        grad_x = flipped[:, :, 0, :, :]
        grad_y = flipped[:, :, 1, :, :]
        return torch.cat(cart2pol(grad_x, grad_y, self.eps), dim=1)

    def __repr__(self) -> str:
        return self.__class__.__name__
class VonMisesKernel(nn.Module):
    r"""Module, which computes parameters of Von Mises kernel given coefficients, and embeds given patches.

    For ``coeffs = [c_0, c_1, ..., c_n]`` every input value :math:`\theta` is mapped to the
    ``d = 2 * n + 1`` channels

    .. math::
        [\sqrt{c_0},\ \sqrt{c_1}\cos\theta, \dots, \sqrt{c_n}\cos n\theta,\
        \sqrt{c_1}\sin\theta, \dots, \sqrt{c_n}\sin n\theta]

    Args:
        patch_size: Input patch size in pixels. The constant channel is precomputed at this
            size, so the spatial size of the input must match it.
        coeffs: List of coefficients. Some examples are hardcoded in COEFFS,

    Returns:
        Von Mises embedding of given parametrization.

    Shape:
        - Input: (B, 1, patch_size, patch_size)
        - Output: (B, d, patch_size, patch_size)

    Examples:
        >>> oris = torch.rand(23, 1, 32, 32)
        >>> vm = VonMisesKernel(patch_size=32,
        ...                     coeffs=[0.14343168,
        ...                             0.268285,
        ...                             0.21979234])
        >>> emb = vm(oris) # 23x5x32x32

    """

    def __init__(self, patch_size: int, coeffs: Union[List[Union[float, int]], Tuple[Union[float, int], ...]]) -> None:
        super().__init__()

        self.patch_size = patch_size
        b_coeffs = tensor(coeffs)
        self.register_buffer("coeffs", b_coeffs)

        # Compute parameters.
        n = len(coeffs) - 1
        self.n = n
        self.d = 2 * n + 1

        # Precompute helper variables.
        # Constant (zero-frequency) channel, broadcast over the batch in forward().
        emb0 = torch.ones([1, 1, patch_size, patch_size])
        # Frequencies 1..n, shaped to broadcast against a (B, 1, H, W) input.
        frange = (torch.arange(n) + 1).reshape(-1, 1, 1)
        # Per-channel weights: sqrt(c_0..c_n) for [constant, cos] followed by sqrt(c_1..c_n) for sin.
        root = torch.sqrt(b_coeffs)
        weights = torch.cat([root, root[1:]]).reshape(-1, 1, 1)

        self.register_buffer("emb0", emb0)
        self.register_buffer("frange", frange)
        self.register_buffer("weights", weights)

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x`` channel-wise with the weighted cosine/sine expansion.

        Raises:
            TypeError: if ``x`` or the ``emb0`` buffer is not a tensor.
            ValueError: if ``x`` is not of shape Bx1xHxW.
        """
        if not isinstance(x, Tensor):
            raise TypeError(f"Input type is not a Tensor. Got {type(x)}")

        if not len(x.shape) == 4 or x.shape[1] != 1:
            raise ValueError(f"Invalid input shape, we expect Bx1xHxW. Got: {x.shape}")

        if not isinstance(self.emb0, Tensor):
            # Report the type of the offending buffer, not of the (already validated) input.
            raise TypeError(f"Emb0 type is not a Tensor. Got {type(self.emb0)}")

        # expand() gives a broadcast view; torch.cat below materializes the result anyway.
        emb0 = self.emb0.to(x).expand(x.size(0), -1, -1, -1)
        frange = self.frange.to(x) * x
        emb1 = cos(frange)
        emb2 = sin(frange)
        embedding = torch.cat([emb0, emb1, emb2], dim=1)
        embedding = self.weights * embedding

        return embedding

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(patch_size={self.patch_size}, n={self.n}, d={self.d}, coeffs={self.coeffs})"
class EmbedGradients(nn.Module):
    r"""Module that embeds gradient orientations and weights them by the square root of the magnitudes.

    Orientations are embedded with a :class:`VonMisesKernel` built from ``COEFFS["theta"]``.
    When ``relative`` is set, the polar angle of each pixel position (precomputed on a
    normalized meshgrid and stored in the ``phi`` buffer) is subtracted from the orientations
    first.

    Args:
        patch_size: Input patch size in pixels.
        relative: absolute or relative gradients.

    Returns:
        Gradient embedding.

    Shape:
        - Input: (B, 2, patch_size, patch_size)
        - Output: (B, 7, patch_size, patch_size)

    Examples:
        >>> grads = torch.rand(23, 2, 32, 32)
        >>> emb_grads = EmbedGradients(patch_size=32,
        ...                            relative=False)
        >>> emb = emb_grads(grads) # 23x7x32x32

    """

    def __init__(self, patch_size: int = 32, relative: bool = False) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.relative = relative
        self.eps = 1e-8

        # Orientation ("theta") kernel.
        self.kernel = VonMisesKernel(patch_size=patch_size, coeffs=COEFFS["theta"])

        # Positional angle used to turn absolute orientations into relative ones.
        mesh = create_meshgrid(height=patch_size, width=patch_size, normalized_coordinates=True)
        polar = cart2pol(mesh[:, :, :, 0], mesh[:, :, :, 1])
        self.register_buffer("phi", polar[1])

    def emb_mags(self, mags: Tensor) -> Tensor:
        """Return ``sqrt(mags + eps)``; ``eps`` is added for numerical reasons."""
        return torch.sqrt(mags + self.eps)

    def forward(self, grads: Tensor) -> Tensor:
        """Embed ``grads`` whose channel 0 holds magnitudes and remaining channels orientations.

        Raises:
            TypeError: if ``grads`` is not a tensor.
            ValueError: if ``grads`` is not 4-dimensional.
        """
        if not isinstance(grads, Tensor):
            raise TypeError(f"Input type is not a Tensor. Got {type(grads)}")
        if len(grads.shape) != 4:
            raise ValueError(f"Invalid input shape, we expect Bx2xHxW. Got: {grads.shape}")

        magnitudes, orientations = grads[:, :1, :, :], grads[:, 1:, :, :]
        if self.relative:
            orientations = orientations - self.phi.to(orientations)
        return self.kernel(orientations) * self.emb_mags(magnitudes)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(patch_size={self.patch_size}, relative={self.relative})"
def spatial_kernel_embedding(kernel_type: str, grids: Dict[str, Tensor]) -> Tensor:
    r"""Compute embeddings for cartesian and polar parametrizations.

    The two coordinate grids of the chosen parametrization are rescaled, embedded with a
    :class:`VonMisesKernel` and combined channel-by-channel in Kronecker order.

    Args:
        kernel_type: ``'cart'`` (uses the ``"x"`` and ``"y"`` grids) or ``'polar'`` (uses the
            ``"phi"`` and ``"rho"`` grids).
        grids: Coordinate grids keyed by name, each of shape (patch_size, patch_size). Only the
            two grids required by ``kernel_type`` are read.

    Returns:
        Tensor of shape (d * d, patch_size, patch_size), where ``d`` is the dimensionality of
        the Von Mises embedding of a single coordinate.

    Raises:
        NotImplementedError: if ``kernel_type`` is neither ``'cart'`` nor ``'polar'``.
    """
    # Validate first: previously an unknown kernel_type surfaced later as an UnboundLocalError.
    if kernel_type == "cart":
        coeffs_ = "xy"
        params_ = ["x", "y"]
    elif kernel_type == "polar":
        coeffs_ = "rhophi"
        params_ = ["phi", "rho"]
    else:
        raise NotImplementedError(f"{kernel_type} is not valid, use polar or cart.")

    factors = {"phi": 1.0, "rho": pi / sqrt2, "x": pi / 2, "y": pi / 2}

    # Infer patch_size from the first grid that is actually used.
    patch_size = grids[params_[0]].shape[-1]

    # Scale appropriately and add the batch and channel axes expected by the kernel.
    grids_normed = {k: (grids[k] * factors[k]).unsqueeze(0).unsqueeze(0).float() for k in params_}

    # x,y/rho,phi kernels. Both coordinates use the same coefficients, so one kernel serves both.
    vm = VonMisesKernel(patch_size=patch_size, coeffs=COEFFS[coeffs_])

    # Drop only the batch axis; a bare squeeze() would also remove size-1 channel/spatial axes.
    emb_a = vm(grids_normed[params_[0]]).squeeze(0)
    emb_b = vm(grids_normed[params_[1]]).squeeze(0)

    # Final precomputed position embedding.
    kron_order = get_kron_order(vm.d, vm.d)
    spatial_kernel = emb_a.index_select(0, kron_order[:, 0]) * emb_b.index_select(0, kron_order[:, 1])
    return spatial_kernel
class ExplicitSpacialEncoding(nn.Module):
    r"""Module that computes explicit cartesian or polar embedding.

    A positional embedding with ``d_emb`` channels is precomputed at construction time
    (optionally multiplied by a Gaussian mask). In ``forward`` every input channel is multiplied
    with every positional channel and the product is summed over the spatial axes, giving
    ``out_dims = in_dims * d_emb`` values per sample.

    Args:
        kernel_type: Parametrization of kernel ``'polar'`` or ``'cart'``.
        fmap_size: Input feature map size in pixels.
        in_dims: Dimensionality of input feature map.
        do_gmask: Apply gaussian mask.
        do_l2: Apply l2-normalization.

    Returns:
        Explicit cartesian or polar embedding.

    Shape:
        - Input: (B, in_dims, fmap_size, fmap_size)
        - Output: (B, out_dims)

    Example:
        >>> emb_ori = torch.rand(23, 7, 32, 32)
        >>> ese = ExplicitSpacialEncoding(kernel_type='polar',
        ...                               fmap_size=32,
        ...                               in_dims=7,
        ...                               do_gmask=True,
        ...                               do_l2=True)
        >>> desc = ese(emb_ori) # 23x175

    """

    def __init__(
        self,
        kernel_type: str = "polar",
        fmap_size: int = 32,
        in_dims: int = 7,
        do_gmask: bool = True,
        do_l2: bool = True,
    ) -> None:
        super().__init__()

        if kernel_type not in ["polar", "cart"]:
            raise NotImplementedError(f"{kernel_type} is not valid, use polar or cart).")

        self.kernel_type = kernel_type
        self.fmap_size = fmap_size
        self.in_dims = in_dims
        self.do_gmask = do_gmask
        self.do_l2 = do_l2
        self.grid = get_grid_dict(fmap_size)
        self.gmask = None

        # Precompute embedding.
        emb = spatial_kernel_embedding(self.kernel_type, self.grid)

        # Gaussian mask.
        if self.do_gmask:
            self.gmask = self.get_gmask(sigma=1.0)
            emb = emb * self.gmask

        # Store precomputed embedding (with a leading batch axis).
        self.register_buffer("emb", emb.unsqueeze(0))
        self.d_emb: int = emb.shape[0]
        self.out_dims: int = self.in_dims * self.d_emb
        self.odims: int = self.out_dims

        # Store kronecker form.
        emb2, idx1 = self.init_kron()
        self.register_buffer("emb2", emb2)
        self.register_buffer("idx1", idx1)

    def get_gmask(self, sigma: float) -> Tensor:
        """Compute Gaussian mask ``exp(-(rho / max(rho))^2 / sigma^2)`` on the stored grid."""
        norm_rho = self.grid["rho"] / self.grid["rho"].max()
        gmask = torch.exp(-1 * norm_rho**2 / sigma**2)
        return gmask

    def init_kron(self) -> Tuple[Tensor, Tensor]:
        """Initialize helper variables to calculate kronecker.

        Returns:
            The positional embedding with its channels repeated in Kronecker order, and the
            matching input-channel index for each output channel.
        """
        kron = get_kron_order(self.in_dims, self.d_emb)
        _emb = torch.jit.annotate(Tensor, self.emb)
        emb2 = torch.index_select(_emb, 1, kron[:, 1])
        return emb2, kron[:, 0]

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` spatially.

        Raises:
            TypeError: if ``x`` is not a tensor.
            ValueError: if ``x`` is not 4-dimensional or its channel count differs from ``in_dims``.
        """
        if not isinstance(x, Tensor):
            raise TypeError(f"Input type is not a Tensor. Got {type(x)}")

        # Both conditions must hold. The short-circuit also keeps x.shape[1] from being
        # evaluated on inputs that have fewer than two dimensions.
        if len(x.shape) != 4 or x.shape[1] != self.in_dims:
            raise ValueError(f"Invalid input shape, we expect Bx{self.in_dims}xHxW. Got: {x.shape}")

        idx1 = torch.jit.annotate(Tensor, self.idx1)
        # Repeat the input channels so that they line up with the repeated positional channels.
        emb1 = torch.index_select(x, 1, idx1)
        output = emb1 * self.emb2
        # Aggregate over the spatial axes.
        output = output.sum(dim=(2, 3))
        if self.do_l2:
            output = F.normalize(output, dim=1)
        return output

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"kernel_type={self.kernel_type}, "
            f"fmap_size={self.fmap_size}, "
            f"in_dims={self.in_dims}, "
            f"out_dims={self.out_dims}, "
            f"do_gmask={self.do_gmask}, "
            f"do_l2={self.do_l2})"
        )
class Whitening(nn.Module):
    r"""Module, performs supervised or unsupervised whitening.

    This is based on the paper "Understanding and Improving Kernel Local Descriptors".
    See :cite:`mukundan2019understanding` for more details.

    Descriptors are centered, projected onto the first ``output_dims`` eigenvectors,
    passed through a signed power law and l2-normalized.

    Args:
        xform: Variant of whitening to use. None, 'lw', 'pca', 'pcaws', 'pcawt'. Must be one of
            the four named variants whenever ``whitening_model`` is given.
        whitening_model: Dictionary with keys 'mean', 'eigvecs', 'eigvals' holding Tensors,
            nested under the key ``'lw'`` (for ``xform='lw'``) or ``'pca'`` (otherwise).
            With ``None`` the module is left as an identity projection.
        in_dims: Dimensionality of input descriptors.
        output_dims: (int) Dimensionality reduction. Clamped to ``in_dims``.
        keval: Shrinkage parameter: index of the eigenvalue used as shrinkage level by 'pcaws'.
        t: Attenuation parameter used by 'pcawt'.

    Returns:
        l2-normalized, whitened descriptors.

    Shape:
        - Input: (B, in_dims)
        - Output: (B, output_dims)

    Examples:
        >>> descs = torch.rand(23, 238)
        >>> whitening_model = {'pca': {'mean': torch.zeros(238),
        ...                            'eigvecs': torch.eye(238),
        ...                            'eigvals': torch.ones(238)}}
        >>> whitening = Whitening(xform='pcawt',
        ...                       whitening_model=whitening_model,
        ...                       in_dims=238,
        ...                       output_dims=128,
        ...                       keval=40,
        ...                       t=0.7)
        >>> wdescs = whitening(descs) # 23x128

    """

    def __init__(
        self,
        xform: str,
        whitening_model: Union[Dict[str, Dict[str, Tensor]], None],
        in_dims: int,
        output_dims: int = 128,
        keval: int = 40,
        t: float = 0.7,
    ) -> None:
        super().__init__()

        self.xform = xform
        self.in_dims = in_dims
        self.keval = keval
        self.t = t
        self.pval = 1.0

        # Compute true output_dims.
        output_dims = min(output_dims, in_dims)
        self.output_dims = output_dims

        # Initialize identity transform.
        self.mean = nn.Parameter(zeros(in_dims), requires_grad=True)
        self.evecs = nn.Parameter(torch.eye(in_dims)[:, :output_dims], requires_grad=True)
        self.evals = nn.Parameter(torch.ones(in_dims)[:output_dims], requires_grad=True)

        if whitening_model is not None:
            self.load_whitening_parameters(whitening_model)

    def load_whitening_parameters(self, whitening_model: Dict[str, Dict[str, Tensor]]) -> None:
        """Load mean/eigenvectors/eigenvalues and apply the modification selected by ``xform``.

        Raises:
            KeyError: if the required model entry is missing or ``xform`` is not one of
                'pca', 'lw', 'pcaws', 'pcawt'.
        """
        algo = "lw" if self.xform == "lw" else "pca"
        wh_model = whitening_model[algo]
        self.mean.data = wh_model["mean"]
        self.evecs.data = wh_model["eigvecs"][:, : self.output_dims]
        self.evals.data = wh_model["eigvals"][: self.output_dims]

        modifications = {
            "pca": self._modify_pca,
            "lw": self._modify_lw,
            # The shrinkage level is looked up in the untruncated spectrum, so that
            # keval may be larger than output_dims.
            "pcaws": lambda: self._modify_pcaws(wh_model["eigvals"]),
            "pcawt": self._modify_pcawt,
        }

        # Call modification.
        modifications[self.xform]()

    def _modify_pca(self) -> None:
        """Modify powerlaw parameter."""
        self.pval = 0.5

    def _modify_lw(self) -> None:
        """No modification required."""

    def _modify_pcaws(self, full_evals: Union[Tensor, None] = None) -> None:
        """Shrinkage for eigenvalues.

        Args:
            full_evals: Eigenvalues before truncation to ``output_dims``. When omitted the
                truncated ``self.evals`` are used, which requires ``keval < output_dims``.
        """
        spectrum = self.evals if full_evals is None else full_evals
        alpha = spectrum[self.keval]
        evals = ((1 - alpha) * self.evals) + alpha
        self.evecs.data = self.evecs @ torch.diag(torch.pow(evals, -0.5))

    def _modify_pcawt(self) -> None:
        """Attenuation for eigenvalues."""
        m = -0.5 * self.t
        self.evecs.data = self.evecs @ torch.diag(torch.pow(self.evals, m))

    def forward(self, x: Tensor) -> Tensor:
        """Whiten a batch of descriptors.

        Raises:
            TypeError: if ``x`` is not a tensor.
            ValueError: if ``x`` is not 2-dimensional.
        """
        if not isinstance(x, Tensor):
            raise TypeError(f"Input type is not a Tensor. Got {type(x)}")

        if not len(x.shape) == 2:
            raise ValueError(f"Invalid input shape, we expect NxD. Got: {x.shape}")

        x = x - self.mean  # Center the data.
        x = x @ self.evecs  # Apply rotation and/or scaling.
        x = torch.sign(x) * torch.pow(torch.abs(x), self.pval)  # Powerlaw.
        return F.normalize(x, dim=1)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(xform={self.xform}, in_dims={self.in_dims}, output_dims={self.output_dims})"
class MKDDescriptor(nn.Module):
    r"""Module that computes Multiple Kernel local descriptors.

    This is based on the paper "Understanding and Improving Kernel Local Descriptors".
    See :cite:`mukundan2019understanding` for more details.

    Patches are smoothed, differentiated, embedded per parametrization (gradient embedding
    followed by explicit spatial encoding), concatenated, l2-normalized and optionally whitened.

    Args:
        patch_size: Input patch size in pixels.
        kernel_type: Parametrization of kernel ``'concat'``, ``'cart'``, ``'polar'``.
        whitening: Whitening transform to apply ``None``, ``'lw'``, ``'pca'``, ``'pcawt'``, ``'pcaws'``.
            With ``None`` no whitening model is downloaded and no whitening is applied.
        training_set: Set that model was trained on ``'liberty'``, ``'notredame'``, ``'yosemite'``.
        output_dims: Dimensionality reduction. Only takes effect when whitening is applied.

    Returns:
        Multiple Kernel local descriptors of the given patches.

    Shape:
        - Input: :math:`(B, 1, patch_{size}, patch_{size})`.
        - Output: :math:`(B, out_{dims})`,

    Examples:
        >>> patches = torch.rand(23, 1, 32, 32)
        >>> mkd = MKDDescriptor(patch_size=32,
        ...                     kernel_type='concat',
        ...                     whitening='pcawt',
        ...                     training_set='liberty',
        ...                     output_dims=128)
        >>> desc = mkd(patches) # 23x128

    """

    def __init__(
        self,
        patch_size: int = 32,
        kernel_type: str = "concat",
        whitening: Union[str, None] = "pcawt",
        training_set: str = "liberty",
        output_dims: int = 128,
    ) -> None:
        super().__init__()

        self.patch_size: int = patch_size
        self.kernel_type: str = kernel_type
        self.whitening: Union[str, None] = whitening
        self.training_set: str = training_set

        self.sigma = 1.4 * (patch_size / 64)
        self.smoothing = GaussianBlur2d((5, 5), (self.sigma, self.sigma), "replicate")
        self.gradients = MKDGradients()

        # Explicitly typed local strings; the original author notes they are needed for jitting.
        polar_s: str = "polar"
        cart_s: str = "cart"
        self.parametrizations = [polar_s, cart_s] if self.kernel_type == "concat" else [self.kernel_type]

        # Initialize cartesian/polar embedding with absolute/relative gradients.
        self.odims: int = 0
        relative_orientations = {polar_s: True, cart_s: False}
        # NOTE(review): this is a plain dict, so these sub-modules are not registered on the
        # parent module; forward() moves them to the input device explicitly.
        self.feats = {}
        for parametrization in self.parametrizations:
            gradient_embedding = EmbedGradients(patch_size=patch_size, relative=relative_orientations[parametrization])
            spatial_encoding = ExplicitSpacialEncoding(
                kernel_type=parametrization, fmap_size=patch_size, in_dims=gradient_embedding.kernel.d
            )

            self.feats[parametrization] = nn.Sequential(gradient_embedding, spatial_encoding)
            self.odims += spatial_encoding.odims

        # Compute true output_dims.
        self.output_dims: int = min(output_dims, self.odims)

        # Load supervised(lw)/unsupervised(pca) model trained on training_set.
        if self.whitening is not None:
            # Shared loader, so that the checkpoint URL and device mapping live in one place.
            whitening_model = load_whitening_model(self.kernel_type, training_set)
            self.whitening_layer = Whitening(
                self.whitening, whitening_model, in_dims=self.odims, output_dims=self.output_dims
            )
            self.odims = self.output_dims

        self.eval()

    def forward(self, patches: Tensor) -> Tensor:
        """Compute the descriptors of ``patches``.

        Raises:
            TypeError: if ``patches`` is not a tensor.
            ValueError: if ``patches`` is not 4-dimensional.
        """
        if not isinstance(patches, Tensor):
            raise TypeError(f"Input type is not a Tensor. Got {type(patches)}")
        if not len(patches.shape) == 4:
            raise ValueError(f"Invalid input shape, we expect Bx1xHxW. Got: {patches.shape}")

        # Extract gradients.
        g = self.smoothing(patches)
        g = self.gradients(g)

        # Extract polar/cart features.
        features = []
        for parametrization in self.parametrizations:
            self.feats[parametrization].to(g.device)
            features.append(self.feats[parametrization](g))

        # Concatenate.
        y = torch.cat(features, dim=1)

        # l2-normalize.
        y = F.normalize(y, dim=1)

        # Whiten descriptors.
        if self.whitening is not None:
            y = self.whitening_layer(y)

        return y

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"patch_size={self.patch_size}, "
            f"kernel_type={self.kernel_type}, "
            f"whitening={self.whitening}, "
            f"training_set={self.training_set}, "
            f"output_dims={self.output_dims})"
        )
def load_whitening_model(kernel_type: str, training_set: str) -> Dict[str, Any]:
    """Fetch the whitening checkpoint of ``kernel_type`` and return its ``training_set`` entry.

    Args:
        kernel_type: Key into ``urls`` selecting the checkpoint to download.
        training_set: Key of the entry to pick from the loaded checkpoint.

    Returns:
        The whitening model stored under ``training_set``; tensors are mapped to the CPU.
    """
    checkpoint = torch.hub.load_state_dict_from_url(urls[kernel_type], map_location=torch.device("cpu"))
    return checkpoint[training_set]
class SimpleKD(nn.Module):
    """Example to write custom Kernel Descriptors.

    Chains Gaussian smoothing, gradient extraction, gradient embedding, explicit spatial
    encoding and whitening into a single sequential pipeline.

    Args:
        patch_size: Input patch size in pixels.
        kernel_type: ``'cart'`` or ``'polar'``; relative gradients are used for ``'polar'``.
        whitening: ``'lw'``, ``'pca'``, ``'pcaws'`` or ``'pcawt'``.
        training_set: ``'liberty'``, ``'notredame'`` or ``'yosemite'``.
        output_dims: Dimensionality requested from the whitening stage.
    """

    def __init__(
        self,
        patch_size: int = 32,
        kernel_type: str = "polar",  # 'cart' 'polar'
        whitening: str = "pcawt",  # 'lw', 'pca', 'pcaws', 'pcawt'
        training_set: str = "liberty",  # 'liberty', 'notredame', 'yosemite'
        output_dims: int = 128,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size

        blur_sigma: float = 1.4 * (patch_size / 64)
        use_relative: bool = kernel_type == "polar"

        # Stages, in the order in which they are applied.
        blur = GaussianBlur2d((5, 5), (blur_sigma, blur_sigma), "replicate")
        grads = MKDGradients()
        grad_embedding = EmbedGradients(patch_size=patch_size, relative=use_relative)
        spatial = ExplicitSpacialEncoding(
            kernel_type=kernel_type, fmap_size=patch_size, in_dims=grad_embedding.kernel.d
        )
        model = load_whitening_model(kernel_type, training_set)
        whiten = Whitening(whitening, model, in_dims=spatial.odims, output_dims=output_dims)

        self.features = nn.Sequential(blur, grads, grad_embedding, spatial, whiten)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the whole pipeline."""
        return self.features(x)
|