create_norm.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. """ Norm Layer Factory
  2. Create norm modules by string (to mirror create_act and create_norm_act fns)
  3. Copyright 2022 Ross Wightman
  4. """
  5. import functools
  6. import types
  7. from typing import Type
  8. import torch.nn as nn
  9. from .norm import (
  10. GroupNorm,
  11. GroupNorm1,
  12. LayerNorm,
  13. LayerNorm2d,
  14. LayerNormFp32,
  15. LayerNorm2dFp32,
  16. RmsNorm,
  17. RmsNorm2d,
  18. RmsNormFp32,
  19. RmsNorm2dFp32,
  20. SimpleNorm,
  21. SimpleNorm2d,
  22. SimpleNormFp32,
  23. SimpleNorm2dFp32,
  24. )
  25. from torchvision.ops.misc import FrozenBatchNorm2d
  26. _NORM_MAP = dict(
  27. batchnorm=nn.BatchNorm2d,
  28. batchnorm2d=nn.BatchNorm2d,
  29. batchnorm1d=nn.BatchNorm1d,
  30. groupnorm=GroupNorm,
  31. groupnorm1=GroupNorm1,
  32. layernorm=LayerNorm,
  33. layernorm2d=LayerNorm2d,
  34. layernormfp32=LayerNormFp32,
  35. layernorm2dfp32=LayerNorm2dFp32,
  36. rmsnorm=RmsNorm,
  37. rmsnorm2d=RmsNorm2d,
  38. rmsnormfp32=RmsNormFp32,
  39. rmsnorm2dfp32=RmsNorm2dFp32,
  40. simplenorm=SimpleNorm,
  41. simplenorm2d=SimpleNorm2d,
  42. simplenormfp32=SimpleNormFp32,
  43. simplenorm2dfp32=SimpleNorm2dFp32,
  44. frozenbatchnorm2d=FrozenBatchNorm2d,
  45. )
  46. _NORM_TYPES = {m for n, m in _NORM_MAP.items()}
  47. def create_norm_layer(layer_name, num_features, **kwargs):
  48. layer = get_norm_layer(layer_name)
  49. layer_instance = layer(num_features, **kwargs)
  50. return layer_instance
  51. def get_norm_layer(norm_layer):
  52. if norm_layer is None:
  53. return None
  54. assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
  55. norm_kwargs = {}
  56. # unbind partial fn, so args can be rebound later
  57. if isinstance(norm_layer, functools.partial):
  58. norm_kwargs.update(norm_layer.keywords)
  59. norm_layer = norm_layer.func
  60. if isinstance(norm_layer, str):
  61. if not norm_layer:
  62. return None
  63. layer_name = norm_layer.replace('_', '').lower()
  64. norm_layer = _NORM_MAP[layer_name]
  65. else:
  66. norm_layer = norm_layer
  67. if norm_kwargs:
  68. norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args
  69. return norm_layer