| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- """ PyTorch Conditionally Parameterized Convolution (CondConv)
- Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference
- (https://arxiv.org/abs/1904.04971)
- Hacked together by / Copyright 2020 Ross Wightman
- """
- import math
- from functools import partial
- from typing import Union, Tuple
- import torch
- from torch import nn as nn
- from torch.nn import functional as F
- from ._fx import register_notrace_module
- from .helpers import to_2tuple
- from .conv2d_same import conv2d_same
- from .padding import get_padding_value
def get_condconv_initializer(initializer, num_experts, expert_shape):
    """Wrap a per-tensor initializer so it can fill a flattened CondConv parameter.

    Args:
        initializer: Callable invoked once per expert with that expert's slice of the
            parameter viewed as ``expert_shape``; its return value is ignored, so it
            must modify the tensor in place.
        num_experts: Number of experts, i.e. the expected number of rows of the parameter.
        expert_shape: Un-flattened shape of a single expert's parameter.

    Returns:
        A function that takes the ``[num_experts, prod(expert_shape)]`` parameter and
        initializes it one expert row at a time.
    """
    def condconv_initializer(weight):
        """Validate the flattened layout, then initialize every expert row.

        Raises:
            ValueError: If ``weight`` is not 2-D with shape ``[num_experts, num_params]``.
        """
        flat_size = math.prod(expert_shape)
        expected_layout = (num_experts, flat_size)
        if tuple(weight.shape) != expected_layout:
            raise ValueError(
                'CondConv variables must have shape [num_experts, num_params]')
        for expert_idx in range(num_experts):
            # Index (rather than iterate) so each row is a plain view of the parameter.
            expert_view = weight[expert_idx].view(expert_shape)
            initializer(expert_view)

    return condconv_initializer
@register_notrace_module
class CondConv2d(nn.Module):
    """ Conditionally Parameterized Convolution

    Stores ``num_experts`` flattened convolution kernels (and optionally biases). In ``forward``
    the experts are mixed per sample with the supplied routing weights, and every sample in the
    batch is convolved with its own mixed kernel.

    Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

    Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
    https://github.com/pytorch/pytorch/issues/17983
    """
    __constants__ = ['in_channels', 'out_channels', 'dynamic_padding']

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]] = 3,
            stride: Union[int, Tuple[int, int]] = 1,
            padding: Union[int, Tuple[int, int], str] = '',
            dilation: Union[int, Tuple[int, int]] = 1,
            groups: int = 1,
            bias: bool = False,
            num_experts: int = 4,
            device=None,
            dtype=None,
    ):
        """Create the expert parameters and resolve the padding configuration.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size, expanded to a 2-tuple.
            stride: Convolution stride, expanded to a 2-tuple.
            padding: Padding value or padding-mode string, resolved by ``get_padding_value``.
            dilation: Convolution dilation, expanded to a 2-tuple.
            groups: Number of convolution groups.
            bias: If True, allocate one bias vector per expert.
            num_experts: Number of expert kernels to mix between.
            device: Device passed to the parameter allocations.
            dtype: Dtype passed to the parameter allocations.

        Raises:
            ValueError: If ``in_channels`` or ``out_channels`` is not divisible by ``groups``.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()

        # Fail at construction instead of deep inside F.conv2d: the floor division used for
        # the weight shape below would otherwise silently build an unusable kernel.
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = to_2tuple(kernel_size)
        self.stride = to_2tuple(stride)
        padding_val, is_padding_dynamic = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation)
        self.dynamic_padding = is_padding_dynamic  # if in forward to work with torchscript
        self.padding = to_2tuple(padding_val)
        self.dilation = to_2tuple(dilation)
        self.groups = groups
        self.num_experts = num_experts

        # Shape of a single expert's kernel; experts are stored flattened, one per row.
        self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight_num_param = math.prod(self.weight_shape)
        self.weight = torch.nn.Parameter(torch.empty(self.num_experts, weight_num_param, **dd))

        # Always defined so the attribute exists regardless of whether a bias is allocated.
        self.bias_shape = (self.out_channels,)
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(self.num_experts, self.out_channels, **dd))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize every expert kernel (and bias, if present) independently.

        Each expert row is viewed in its un-flattened shape and initialized with
        ``kaiming_uniform_(a=sqrt(5))``; biases are drawn uniformly from
        ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
        """
        init_weight = get_condconv_initializer(
            partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
        init_weight(self.weight)
        if self.bias is not None:
            # fan_in of one expert kernel: (in_channels // groups) * kernel_h * kernel_w
            fan_in = math.prod(self.weight_shape[1:])
            bound = 1 / math.sqrt(fan_in)
            init_bias = get_condconv_initializer(
                partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
            init_bias(self.bias)

    def forward(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
        """Convolve each sample with its own routing-weighted mixture of expert kernels.

        Args:
            x: Input of shape ``(B, C, H, W)``.
            routing_weights: Per-sample expert mixing coefficients; the matmul and reshape
                below require shape ``(B, num_experts)``.

        Returns:
            Tensor of shape ``(B, out_channels, H_out, W_out)``.
        """
        B, C, H, W = x.shape

        # Mix the flattened experts per sample, then stack all per-sample kernels along the
        # output-channel dim so they form the weight of one large grouped convolution.
        weight = torch.matmul(routing_weights, self.weight)
        new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight = weight.view(new_weight_shape)
        bias = None
        if self.bias is not None:
            bias = torch.matmul(routing_weights, self.bias)
            bias = bias.view(B * self.out_channels)

        # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
        # reshape instead of view to work with channels_last input
        x = x.reshape(1, B * C, H, W)
        # groups is scaled by B so that each sample's channels only see that sample's kernels.
        if self.dynamic_padding:
            out = conv2d_same(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        else:
            out = F.conv2d(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        # Unfold the batch back out of the channel dim.
        out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])

        # NOTE: this is the batched form of the literal port from the TF definition, which
        # splits x / weight / bias along the batch dim, runs one convolution per sample with
        # that sample's kernel viewed as `weight_shape`, and concatenates the results.
        return out
|