ste.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, Callable, Optional, Tuple
  18. from torch import Tensor, nn
  19. from torch.autograd import Function
# Public API of this module: the autograd function and the nn.Module wrapper around it.
__all__ = ["STEFunction", "StraightThroughEstimator"]
  21. class STEFunction(Function):
  22. """Straight-Through Estimation (STE) function.
  23. STE bridges the gradients between the input tensor and output tensor as if the function
  24. was an identity function. Meanwhile, advanced gradient functions are also supported. e.g.
  25. the output gradients can be mapped into [-1, 1] with ``F.hardtanh`` function.
  26. Args:
  27. grad_fn: function to restrain the gradient received. If None, no mapping will performed.
  28. Example:
  29. Let the gradients of ``torch.sign`` estimated from STE.
  30. >>> input = torch.randn(4, requires_grad = True)
  31. >>> output = torch.sign(input)
  32. >>> loss = output.mean()
  33. >>> loss.backward()
  34. >>> input.grad
  35. tensor([0., 0., 0., 0.])
  36. >>> with torch.no_grad():
  37. ... output = torch.sign(input)
  38. >>> out_est = STEFunction.apply(input, output)
  39. >>> loss = out_est.mean()
  40. >>> loss.backward()
  41. >>> input.grad
  42. tensor([0.2500, 0.2500, 0.2500, 0.2500])
  43. """
  44. @staticmethod
  45. def forward(ctx: Any, input: Tensor, output: Tensor, grad_fn: Optional[Callable[..., Any]] = None) -> Tensor:
  46. ctx.in_shape = input.shape
  47. ctx.out_shape = output.shape
  48. ctx.grad_fn = grad_fn
  49. return output
  50. @staticmethod
  51. def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, Tensor, None]:
  52. if ctx.grad_fn is None:
  53. return grad_output.sum_to_size(ctx.in_shape), grad_output.sum_to_size(ctx.out_shape), None
  54. return (
  55. ctx.grad_fn(grad_output.sum_to_size(ctx.in_shape)),
  56. ctx.grad_fn(grad_output.sum_to_size(ctx.out_shape)),
  57. None,
  58. )
  59. # https://pytorch.org/docs/1.10.0/onnx.html#torch-autograd-functions
  60. # @staticmethod
  61. # def symbolic(g: torch._C.graph, input: torch._C.Value) -> torch._C.Value:
  62. # raise NotImplementedError(
  63. # "ONNX support is not implemented at the moment."
  64. # "Feel free to contribute to https://github.com/kornia/kornia.")
  65. class StraightThroughEstimator(nn.Module):
  66. """Straight-Through Estimation (STE) module.
  67. STE wraps the ``STEFunction`` to aid the back propagation of non-differentiable modules.
  68. It may also use to avoid gradient computation for differentiable operations. By default,
  69. STE bridges the gradients between the input tensor and output tensor as if the function
  70. was an identity function. Meanwhile, advanced gradient functions are also supported. e.g.
  71. the output gradients can be mapped into [-1, 1] with ``F.hardtanh`` function.
  72. Args:
  73. target_fn: the target function to wrap with.
  74. grad_fn: function to restrain the gradient received. If None, no mapping will performed.
  75. Example:
  76. ``RandomPosterize`` is a non-differentiable operation. Let STE estimate the gradients.
  77. >>> import kornia.augmentation as K
  78. >>> import torch.nn.functional as F
  79. >>> input = torch.randn(1, 1, 4, 4, requires_grad = True)
  80. >>> estimator = StraightThroughEstimator(K.RandomPosterize(3, p=1.), grad_fn=F.hardtanh)
  81. >>> out = estimator(input)
  82. >>> out.mean().backward()
  83. >>> input.grad
  84. tensor([[[[0.0625, 0.0625, 0.0625, 0.0625],
  85. [0.0625, 0.0625, 0.0625, 0.0625],
  86. [0.0625, 0.0625, 0.0625, 0.0625],
  87. [0.0625, 0.0625, 0.0625, 0.0625]]]])
  88. This can be used to chain up the gradients within a ``Sequential`` block.
  89. >>> import kornia.augmentation as K
  90. >>> input = torch.randn(1, 1, 4, 4, requires_grad = True)
  91. >>> aug = K.ImageSequential(
  92. ... K.RandomAffine((77, 77)),
  93. ... StraightThroughEstimator(K.RandomPosterize(3, p=1.), grad_fn=None),
  94. ... K.RandomRotation((15, 15)),
  95. ... )
  96. >>> aug(input).mean().backward()
  97. >>> input.grad
  98. tensor([[[[0.0422, 0.0626, 0.0566, 0.0422],
  99. [0.0566, 0.0626, 0.0626, 0.0626],
  100. [0.0626, 0.0626, 0.0626, 0.0566],
  101. [0.0422, 0.0566, 0.0626, 0.0422]]]])
  102. """
  103. def __init__(self, target_fn: nn.Module, grad_fn: Optional[Callable[..., Any]] = None) -> None:
  104. super().__init__()
  105. self.target_fn = target_fn
  106. self.grad_fn = grad_fn
  107. def __repr__(self) -> str:
  108. return f"{self.__class__.__name__}(target_fn={self.target_fn}, grad_fn={self.grad_fn})"
  109. def forward(self, input: Tensor) -> Tensor:
  110. out = self.target_fn(input)
  111. if not isinstance(out, Tensor):
  112. raise NotImplementedError(
  113. "Only Tensor is supported at the moment. Feel free to contribute to https://github.com/kornia/kornia."
  114. )
  115. output = STEFunction.apply(input, out, self.grad_fn)
  116. return output