format.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. from enum import Enum
  2. from typing import Union
  3. import torch
  4. class Format(str, Enum):
  5. NCHW = 'NCHW'
  6. NHWC = 'NHWC'
  7. NCL = 'NCL'
  8. NLC = 'NLC'
  9. FormatT = Union[str, Format]
  10. def get_spatial_dim(fmt: FormatT):
  11. """Return spatial dimension indices for a given tensor format.
  12. Args:
  13. fmt: Tensor format (NCHW, NHWC, NCL, or NLC).
  14. Returns:
  15. Tuple of spatial dimension indices.
  16. """
  17. fmt = Format(fmt)
  18. if fmt is Format.NLC:
  19. dim = (1,)
  20. elif fmt is Format.NCL:
  21. dim = (2,)
  22. elif fmt is Format.NHWC:
  23. dim = (1, 2)
  24. else:
  25. dim = (2, 3)
  26. return dim
  27. def get_channel_dim(fmt: FormatT):
  28. """Return channel dimension index for a given tensor format.
  29. Args:
  30. fmt: Tensor format (NCHW, NHWC, NCL, or NLC).
  31. Returns:
  32. Channel dimension index.
  33. """
  34. fmt = Format(fmt)
  35. if fmt is Format.NHWC:
  36. dim = 3
  37. elif fmt is Format.NLC:
  38. dim = 2
  39. else:
  40. dim = 1
  41. return dim
  42. def nchw_to(x: torch.Tensor, fmt: Format):
  43. """Convert tensor from NCHW format to specified format.
  44. Args:
  45. x: Input tensor in NCHW format.
  46. fmt: Target format.
  47. Returns:
  48. Tensor in target format.
  49. """
  50. if fmt == Format.NHWC:
  51. x = x.permute(0, 2, 3, 1)
  52. elif fmt == Format.NLC:
  53. x = x.flatten(2).transpose(1, 2)
  54. elif fmt == Format.NCL:
  55. x = x.flatten(2)
  56. return x
  57. def nhwc_to(x: torch.Tensor, fmt: Format):
  58. """Convert tensor from NHWC format to specified format.
  59. Args:
  60. x: Input tensor in NHWC format.
  61. fmt: Target format.
  62. Returns:
  63. Tensor in target format.
  64. """
  65. if fmt == Format.NCHW:
  66. x = x.permute(0, 3, 1, 2)
  67. elif fmt == Format.NLC:
  68. x = x.flatten(1, 2)
  69. elif fmt == Format.NCL:
  70. x = x.flatten(1, 2).transpose(1, 2)
  71. return x