cond_conv2d.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. """ PyTorch Conditionally Parameterized Convolution (CondConv)
  2. Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference
  3. (https://arxiv.org/abs/1904.04971)
  4. Hacked together by / Copyright 2020 Ross Wightman
  5. """
  6. import math
  7. from functools import partial
  8. from typing import Union, Tuple
  9. import torch
  10. from torch import nn as nn
  11. from torch.nn import functional as F
  12. from ._fx import register_notrace_module
  13. from .helpers import to_2tuple
  14. from .conv2d_same import conv2d_same
  15. from .padding import get_padding_value
  16. def get_condconv_initializer(initializer, num_experts, expert_shape):
  17. def condconv_initializer(weight):
  18. """CondConv initializer function."""
  19. num_params = math.prod(expert_shape)
  20. if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
  21. weight.shape[1] != num_params):
  22. raise (ValueError(
  23. 'CondConv variables must have shape [num_experts, num_params]'))
  24. for i in range(num_experts):
  25. initializer(weight[i].view(expert_shape))
  26. return condconv_initializer
  27. @register_notrace_module
  28. class CondConv2d(nn.Module):
  29. """ Conditionally Parameterized Convolution
  30. Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
  31. Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
  32. https://github.com/pytorch/pytorch/issues/17983
  33. """
  34. __constants__ = ['in_channels', 'out_channels', 'dynamic_padding']
  35. def __init__(
  36. self,
  37. in_channels: int,
  38. out_channels: int,
  39. kernel_size: Union[int, Tuple[int, int]] = 3,
  40. stride: Union[int, Tuple[int, int]] = 1,
  41. padding: Union[int, Tuple[int, int], str] = '',
  42. dilation: Union[int, Tuple[int, int]] = 1,
  43. groups: int = 1,
  44. bias: bool = False,
  45. num_experts: int = 4,
  46. device=None,
  47. dtype=None,
  48. ):
  49. dd = {'device': device, 'dtype': dtype}
  50. super().__init__()
  51. self.in_channels = in_channels
  52. self.out_channels = out_channels
  53. self.kernel_size = to_2tuple(kernel_size)
  54. self.stride = to_2tuple(stride)
  55. padding_val, is_padding_dynamic = get_padding_value(
  56. padding, kernel_size, stride=stride, dilation=dilation)
  57. self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
  58. self.padding = to_2tuple(padding_val)
  59. self.dilation = to_2tuple(dilation)
  60. self.groups = groups
  61. self.num_experts = num_experts
  62. self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
  63. weight_num_param = 1
  64. for wd in self.weight_shape:
  65. weight_num_param *= wd
  66. self.weight = torch.nn.Parameter(torch.empty(self.num_experts, weight_num_param, **dd))
  67. if bias:
  68. self.bias_shape = (self.out_channels,)
  69. self.bias = torch.nn.Parameter(torch.empty(self.num_experts, self.out_channels, **dd))
  70. else:
  71. self.register_parameter('bias', None)
  72. self.reset_parameters()
  73. def reset_parameters(self):
  74. init_weight = get_condconv_initializer(
  75. partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
  76. init_weight(self.weight)
  77. if self.bias is not None:
  78. fan_in = math.prod(self.weight_shape[1:])
  79. bound = 1 / math.sqrt(fan_in)
  80. init_bias = get_condconv_initializer(
  81. partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
  82. init_bias(self.bias)
  83. def forward(self, x, routing_weights):
  84. B, C, H, W = x.shape
  85. weight = torch.matmul(routing_weights, self.weight)
  86. new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
  87. weight = weight.view(new_weight_shape)
  88. bias = None
  89. if self.bias is not None:
  90. bias = torch.matmul(routing_weights, self.bias)
  91. bias = bias.view(B * self.out_channels)
  92. # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
  93. # reshape instead of view to work with channels_last input
  94. x = x.reshape(1, B * C, H, W)
  95. if self.dynamic_padding:
  96. out = conv2d_same(
  97. x, weight, bias, stride=self.stride, padding=self.padding,
  98. dilation=self.dilation, groups=self.groups * B)
  99. else:
  100. out = F.conv2d(
  101. x, weight, bias, stride=self.stride, padding=self.padding,
  102. dilation=self.dilation, groups=self.groups * B)
  103. out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
  104. # Literal port (from TF definition)
  105. # x = torch.split(x, 1, 0)
  106. # weight = torch.split(weight, 1, 0)
  107. # if self.bias is not None:
  108. # bias = torch.matmul(routing_weights, self.bias)
  109. # bias = torch.split(bias, 1, 0)
  110. # else:
  111. # bias = [None] * B
  112. # out = []
  113. # for xi, wi, bi in zip(x, weight, bias):
  114. # wi = wi.view(*self.weight_shape)
  115. # if bi is not None:
  116. # bi = bi.view(*self.bias_shape)
  117. # out.append(self.conv_fn(
  118. # xi, wi, bi, stride=self.stride, padding=self.padding,
  119. # dilation=self.dilation, groups=self.groups))
  120. # out = torch.cat(out, 0)
  121. return out