| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import Any, Callable, Optional, Tuple
- from torch import Tensor, nn
- from torch.autograd import Function
- __all__ = ["STEFunction", "StraightThroughEstimator"]
class STEFunction(Function):
    """Straight-Through Estimation (STE) function.

    STE bridges the gradients between the input tensor and output tensor as if the function
    was an identity function. Meanwhile, advanced gradient functions are also supported. e.g.
    the output gradients can be mapped into [-1, 1] with ``F.hardtanh`` function.

    Args:
        input: the tensor that should receive the estimated gradient.
        output: the tensor returned as-is by the forward pass.
        grad_fn: function to restrain the gradient received. If None, no mapping will be performed.

    Example:
        Let the gradients of ``torch.sign`` be estimated by STE.

        >>> input = torch.randn(4, requires_grad = True)
        >>> output = torch.sign(input)
        >>> loss = output.mean()
        >>> loss.backward()
        >>> input.grad
        tensor([0., 0., 0., 0.])

        >>> with torch.no_grad():
        ...     output = torch.sign(input)
        >>> out_est = STEFunction.apply(input, output)
        >>> loss = out_est.mean()
        >>> loss.backward()
        >>> input.grad
        tensor([0.2500, 0.2500, 0.2500, 0.2500])

    """

    @staticmethod
    def forward(ctx: Any, input: Tensor, output: Tensor, grad_fn: Optional[Callable[..., Any]] = None) -> Tensor:
        """Return ``output`` unchanged and remember what ``backward`` needs.

        Only the two shapes and the optional ``grad_fn`` are stored on ``ctx``;
        no tensor is saved for the backward pass.
        """
        ctx.in_shape = input.shape
        ctx.out_shape = output.shape
        ctx.grad_fn = grad_fn
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], Optional[Tensor], None]:
        """Route ``grad_output`` to ``input`` and ``output``.

        The incoming gradient is reduced with ``sum_to_size`` to the shape of each
        tensor and, when a ``grad_fn`` was given, mapped through it. A gradient is
        only computed for a tensor that actually requires one; the others get
        ``None``. ``grad_fn`` itself never receives a gradient.
        """
        # ``apply`` may have been called with two or three arguments, so only
        # the first two flags are read.
        needs_input, needs_output = ctx.needs_input_grad[:2]
        grad_fn = ctx.grad_fn

        def _estimate(shape: Any) -> Tensor:
            # Reduce to the target shape first, then apply the optional mapping.
            grad = grad_output.sum_to_size(shape)
            return grad if grad_fn is None else grad_fn(grad)

        # Skipping unneeded gradients avoids a redundant reduction and a
        # redundant ``grad_fn`` call (e.g. when ``output`` was built under ``no_grad``).
        grad_input = _estimate(ctx.in_shape) if needs_input else None
        grad_out = _estimate(ctx.out_shape) if needs_output else None
        return grad_input, grad_out, None

    # https://pytorch.org/docs/1.10.0/onnx.html#torch-autograd-functions
    # @staticmethod
    # def symbolic(g: torch._C.graph, input: torch._C.Value) -> torch._C.Value:
    #     raise NotImplementedError(
    #         "ONNX support is not implemented at the moment."
    #         "Feel free to contribute to https://github.com/kornia/kornia.")
class StraightThroughEstimator(nn.Module):
    """Straight-Through Estimation (STE) module.

    This module runs ``target_fn`` on its input and passes the input and the result through
    ``STEFunction`` to aid the back propagation of non-differentiable modules. It may also be
    used to avoid gradient computation for differentiable operations. By default, STE bridges
    the gradients between the input tensor and output tensor as if the function was an identity
    function. Advanced gradient functions are supported as well, e.g. the output gradients can
    be mapped into [-1, 1] with the ``F.hardtanh`` function.

    Args:
        target_fn: the target function to wrap with.
        grad_fn: function to restrain the gradient received. If None, no mapping will be performed.

    Example:
        ``RandomPosterize`` is a non-differentiable operation. Let STE estimate the gradients.

        >>> import kornia.augmentation as K
        >>> import torch.nn.functional as F
        >>> input = torch.randn(1, 1, 4, 4, requires_grad = True)
        >>> estimator = StraightThroughEstimator(K.RandomPosterize(3, p=1.), grad_fn=F.hardtanh)
        >>> out = estimator(input)
        >>> out.mean().backward()
        >>> input.grad
        tensor([[[[0.0625, 0.0625, 0.0625, 0.0625],
                  [0.0625, 0.0625, 0.0625, 0.0625],
                  [0.0625, 0.0625, 0.0625, 0.0625],
                  [0.0625, 0.0625, 0.0625, 0.0625]]]])

        This can be used to chain up the gradients within a ``Sequential`` block.

        >>> import kornia.augmentation as K
        >>> input = torch.randn(1, 1, 4, 4, requires_grad = True)
        >>> aug = K.ImageSequential(
        ...     K.RandomAffine((77, 77)),
        ...     StraightThroughEstimator(K.RandomPosterize(3, p=1.), grad_fn=None),
        ...     K.RandomRotation((15, 15)),
        ... )
        >>> aug(input).mean().backward()
        >>> input.grad
        tensor([[[[0.0422, 0.0626, 0.0566, 0.0422],
                  [0.0566, 0.0626, 0.0626, 0.0626],
                  [0.0626, 0.0626, 0.0626, 0.0566],
                  [0.0422, 0.0566, 0.0626, 0.0422]]]])

    """

    def __init__(self, target_fn: nn.Module, grad_fn: Optional[Callable[..., Any]] = None) -> None:
        super().__init__()
        # Wrapped operation whose result is returned by ``forward``.
        self.target_fn = target_fn
        # Optional mapping handed to ``STEFunction`` together with the tensors.
        self.grad_fn = grad_fn

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}(target_fn={self.target_fn}, grad_fn={self.grad_fn})"

    def forward(self, input: Tensor) -> Tensor:
        """Apply ``target_fn`` to ``input`` and route the result through ``STEFunction``.

        Raises:
            NotImplementedError: if ``target_fn`` returns anything other than a ``Tensor``.
        """
        target_out = self.target_fn(input)

        # Guard: only plain tensor results can be passed on to ``STEFunction``.
        if isinstance(target_out, Tensor):
            return STEFunction.apply(input, target_out, self.grad_fn)

        raise NotImplementedError(
            "Only Tensor is supported at the moment. Feel free to contribute to https://github.com/kornia/kornia."
        )
|