utils.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import torch.nn as nn
  2. def build_act_layer(act_layer):
  3. if act_layer == 'ReLU':
  4. return nn.ReLU(inplace=True)
  5. elif act_layer == 'SiLU':
  6. return nn.SiLU(inplace=True)
  7. elif act_layer == 'GELU':
  8. return nn.GELU()
  9. raise NotImplementedError(f'build_act_layer does not support {act_layer}')
  10. def build_norm_layer(dim,
  11. norm_layer,
  12. in_format='channels_last',
  13. out_format='channels_last',
  14. eps=1e-6):
  15. layers = []
  16. if norm_layer == 'BN':
  17. if in_format == 'channels_last':
  18. layers.append(to_channels_first())
  19. layers.append(nn.BatchNorm2d(dim))
  20. if out_format == 'channels_last':
  21. layers.append(to_channels_last())
  22. elif norm_layer == 'LN':
  23. if in_format == 'channels_first':
  24. layers.append(to_channels_last())
  25. layers.append(nn.LayerNorm(dim, eps=eps))
  26. if out_format == 'channels_first':
  27. layers.append(to_channels_first())
  28. else:
  29. raise NotImplementedError(
  30. f'build_norm_layer does not support {norm_layer}')
  31. return nn.Sequential(*layers)
  32. class to_channels_first(nn.Module):
  33. def __init__(self):
  34. super().__init__()
  35. def forward(self, x):
  36. return x.permute(0, 3, 1, 2)
  37. class to_channels_last(nn.Module):
  38. def __init__(self):
  39. super().__init__()
  40. def forward(self, x):
  41. return x.permute(0, 2, 3, 1)