inplace_abn.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import torch
  2. from torch import nn as nn
  3. try:
  4. from inplace_abn.functions import inplace_abn, inplace_abn_sync
  5. has_iabn = True
  6. except ImportError:
  7. has_iabn = False
  8. def inplace_abn(x, weight, bias, running_mean, running_var,
  9. training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
  10. raise ImportError(
  11. "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")
  12. def inplace_abn_sync(**kwargs):
  13. inplace_abn(**kwargs)
  14. from ._fx import register_notrace_module
  15. @register_notrace_module
  16. class InplaceAbn(nn.Module):
  17. """Activated Batch Normalization
  18. This gathers a BatchNorm and an activation function in a single module
  19. Parameters
  20. ----------
  21. num_features : int
  22. Number of feature channels in the input and output.
  23. eps : float
  24. Small constant to prevent numerical issues.
  25. momentum : float
  26. Momentum factor applied to compute running statistics.
  27. affine : bool
  28. If `True` apply learned scale and shift transformation after normalization.
  29. act_layer : str or nn.Module type
  30. Name or type of the activation functions, one of: `leaky_relu`, `elu`
  31. act_param : float
  32. Negative slope for the `leaky_relu` activation.
  33. """
  34. def __init__(
  35. self,
  36. num_features,
  37. eps=1e-5,
  38. momentum=0.1,
  39. affine=True,
  40. apply_act=True,
  41. act_layer="leaky_relu",
  42. act_param=0.01,
  43. drop_layer=None,
  44. ):
  45. super().__init__()
  46. self.num_features = num_features
  47. self.affine = affine
  48. self.eps = eps
  49. self.momentum = momentum
  50. if apply_act:
  51. if isinstance(act_layer, str):
  52. assert act_layer in ('leaky_relu', 'elu', 'identity', '')
  53. self.act_name = act_layer if act_layer else 'identity'
  54. else:
  55. # convert act layer passed as type to string
  56. if act_layer == nn.ELU:
  57. self.act_name = 'elu'
  58. elif act_layer == nn.LeakyReLU:
  59. self.act_name = 'leaky_relu'
  60. elif act_layer is None or act_layer == nn.Identity:
  61. self.act_name = 'identity'
  62. else:
  63. assert False, f'Invalid act layer {act_layer.__name__} for IABN'
  64. else:
  65. self.act_name = 'identity'
  66. self.act_param = act_param
  67. if self.affine:
  68. self.weight = nn.Parameter(torch.ones(num_features))
  69. self.bias = nn.Parameter(torch.zeros(num_features))
  70. else:
  71. self.register_parameter('weight', None)
  72. self.register_parameter('bias', None)
  73. self.register_buffer('running_mean', torch.zeros(num_features))
  74. self.register_buffer('running_var', torch.ones(num_features))
  75. self.reset_parameters()
  76. def reset_parameters(self):
  77. nn.init.constant_(self.running_mean, 0)
  78. nn.init.constant_(self.running_var, 1)
  79. if self.affine:
  80. nn.init.constant_(self.weight, 1)
  81. nn.init.constant_(self.bias, 0)
  82. def forward(self, x):
  83. output = inplace_abn(
  84. x, self.weight, self.bias, self.running_mean, self.running_var,
  85. self.training, self.momentum, self.eps, self.act_name, self.act_param)
  86. if isinstance(output, tuple):
  87. output = output[0]
  88. return output