mixed_conv2d.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. """ PyTorch Mixed Convolution
  2. Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
  3. Hacked together by / Copyright 2020 Ross Wightman
  4. """
  5. from typing import List, Union
  6. import torch
  7. from torch import nn as nn
  8. from .conv2d_same import create_conv2d_pad
  9. def _split_channels(num_chan, num_groups):
  10. split = [num_chan // num_groups for _ in range(num_groups)]
  11. split[0] += num_chan - sum(split)
  12. return split
  13. class MixedConv2d(nn.ModuleDict):
  14. """ Mixed Grouped Convolution
  15. Based on MDConv and GroupedConv in MixNet impl:
  16. https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
  17. """
  18. def __init__(
  19. self,
  20. in_channels: int,
  21. out_channels: int,
  22. kernel_size: Union[int, List[int]] = 3,
  23. stride: int = 1,
  24. padding: str = '',
  25. dilation: int = 1,
  26. depthwise: bool = False,
  27. **kwargs
  28. ):
  29. super().__init__()
  30. kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
  31. num_groups = len(kernel_size)
  32. in_splits = _split_channels(in_channels, num_groups)
  33. out_splits = _split_channels(out_channels, num_groups)
  34. self.in_channels = sum(in_splits)
  35. self.out_channels = sum(out_splits)
  36. for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
  37. conv_groups = in_ch if depthwise else 1
  38. # use add_module to keep key space clean
  39. self.add_module(
  40. str(idx),
  41. create_conv2d_pad(
  42. in_ch,
  43. out_ch,
  44. k,
  45. stride=stride,
  46. padding=padding,
  47. dilation=dilation,
  48. groups=conv_groups,
  49. **kwargs,
  50. )
  51. )
  52. self.splits = in_splits
  53. def forward(self, x):
  54. x_split = torch.split(x, self.splits, 1)
  55. x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
  56. x = torch.cat(x_out, 1)
  57. return x