| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- import torch.nn as nn
def build_act_layer(act_layer):
    """Create a new activation module from its name.

    Args:
        act_layer: Activation name; one of ``'ReLU'``, ``'SiLU'`` or ``'GELU'``.

    Returns:
        A freshly constructed module: ``nn.ReLU(inplace=True)``,
        ``nn.SiLU(inplace=True)`` or ``nn.GELU()``.

    Raises:
        NotImplementedError: If ``act_layer`` matches none of the names above.
    """
    # Pairs are scanned with ``==`` (not a dict lookup) so that any value,
    # hashable or not, falls through to the NotImplementedError below.
    known_activations = (
        ('ReLU', lambda: nn.ReLU(inplace=True)),
        ('SiLU', lambda: nn.SiLU(inplace=True)),
        ('GELU', nn.GELU),
    )
    for activation_name, make_module in known_activations:
        if act_layer == activation_name:
            return make_module()
    raise NotImplementedError(f'build_act_layer does not support {act_layer}')
def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    """Build a normalization layer wrapped with any layout conversions it needs.

    Args:
        dim: Number of channels being normalized.
        norm_layer: ``'BN'`` for ``nn.BatchNorm2d`` or ``'LN'`` for ``nn.LayerNorm``.
        in_format: Layout of the incoming tensor, ``'channels_first'`` or
            ``'channels_last'``.
        out_format: Layout the returned module should produce, same choices.
        eps: Epsilon passed to ``nn.LayerNorm``. It is NOT forwarded to
            ``nn.BatchNorm2d``, which keeps its own default.

    Returns:
        An ``nn.Sequential`` of optional permute -> norm -> optional permute.

    Raises:
        NotImplementedError: If ``norm_layer`` is not ``'BN'`` or ``'LN'``.
        ValueError: If ``in_format`` or ``out_format`` is not a known layout.
            Without this check a typo would silently skip the permute.
    """
    if norm_layer not in ('BN', 'LN'):
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')
    valid_formats = ('channels_first', 'channels_last')
    for arg_name, fmt in (('in_format', in_format), ('out_format', out_format)):
        if fmt not in valid_formats:
            raise ValueError(
                f'{arg_name} must be one of {valid_formats}, got {fmt!r}')

    layers = []
    if norm_layer == 'BN':
        # BatchNorm2d normalizes over dim 1, so its input must be channels-first.
        if in_format == 'channels_last':
            layers.append(to_channels_first())
        layers.append(nn.BatchNorm2d(dim))
        if out_format == 'channels_last':
            layers.append(to_channels_last())
    else:  # 'LN'
        # LayerNorm(dim) normalizes over the last dim, so its input must be channels-last.
        if in_format == 'channels_first':
            layers.append(to_channels_last())
        layers.append(nn.LayerNorm(dim, eps=eps))
        if out_format == 'channels_first':
            layers.append(to_channels_first())
    return nn.Sequential(*layers)
class to_channels_first(nn.Module):
    """Move the trailing (channel) axis of a tensor to position 1.

    For a 4-D input ``(N, H, W, C)`` the output is ``(N, C, H, W)``. This is
    exactly what ``x.permute(0, 3, 1, 2)`` produces.

    ``movedim`` is used instead of a hard-coded permutation so that other ranks
    also work, e.g. ``(N, L, C) -> (N, C, L)`` or
    ``(N, D, H, W, C) -> (N, C, D, H, W)``. As with ``permute``, the result is
    a view of ``x``, not a copy. The input must have at least 2 dimensions.
    """

    def forward(self, x):
        return x.movedim(-1, 1)
class to_channels_last(nn.Module):
    """Move the channel axis of a tensor from position 1 to the end.

    For a 4-D input ``(N, C, H, W)`` the output is ``(N, H, W, C)``. This is
    exactly what ``x.permute(0, 2, 3, 1)`` produces.

    ``movedim`` is used instead of a hard-coded permutation so that other ranks
    also work, e.g. ``(N, C, L) -> (N, L, C)`` or
    ``(N, C, D, H, W) -> (N, D, H, W, C)``. As with ``permute``, the result is
    a view of ``x``, not a copy. The input must have at least 2 dimensions.
    """

    def forward(self, x):
        return x.movedim(1, -1)
|