_dim_entry.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING, Union
  3. if TYPE_CHECKING:
  4. from collections.abc import Sequence
  5. from . import Dim
  6. import torch # noqa: TC002
  7. # NB: The old code represented dimension was from as negative number, so we
  8. # follow this convention even though it shouldn't be necessary now
  9. class DimEntry:
  10. # The dimension this is from the rhs, or a FCD
  11. data: Union[Dim, int]
  12. def __init__(self, data: Union[Dim, int, None] = None) -> None:
  13. from . import Dim
  14. if type(data) is int:
  15. if data >= 0:
  16. raise AssertionError(f"Expected negative int, got {data}")
  17. elif data is None:
  18. data = 0
  19. else:
  20. if not isinstance(data, Dim):
  21. raise AssertionError(f"Expected Dim, got {type(data)}")
  22. self.data = data
  23. def __eq__(self, other: object) -> bool:
  24. if not isinstance(other, DimEntry):
  25. return False
  26. # Use 'is' for Dim objects to avoid triggering __torch_function__
  27. # Use '==' only for positional (int) comparisons
  28. if self.is_positional() and other.is_positional():
  29. # Both are positional (ints)
  30. return self.data == other.data
  31. elif not self.is_positional() and not other.is_positional():
  32. # Both are Dim objects - use 'is' to avoid __eq__
  33. return self.data is other.data
  34. else:
  35. # One is positional, one is Dim - they can't be equal
  36. return False
  37. def is_positional(self) -> bool:
  38. return type(self.data) is int and self.data < 0
  39. def is_none(self) -> bool:
  40. # Use isinstance to check for Dim objects, avoid triggering __torch_function__
  41. from . import Dim
  42. if isinstance(self.data, Dim):
  43. # This is a Dim object, it can't be "none" (which is represented by 0)
  44. return False
  45. else:
  46. # This is an int or other type
  47. return self.data == 0
  48. def position(self) -> int:
  49. if not isinstance(self.data, int):
  50. raise AssertionError(f"Expected int, got {type(self.data)}")
  51. return self.data
  52. def dim(self) -> Dim:
  53. if isinstance(self.data, int):
  54. raise AssertionError("Expected Dim, got int")
  55. return self.data
  56. def __repr__(self) -> str:
  57. return repr(self.data)
  58. def ndim_of_levels(levels: Sequence[DimEntry]) -> int:
  59. r = 0
  60. for l in levels:
  61. if l.is_positional():
  62. r += 1
  63. return r
  64. def _match_levels(
  65. tensor: torch.Tensor,
  66. from_levels: list[DimEntry],
  67. to_levels: list[DimEntry],
  68. drop_levels: bool = False,
  69. ) -> torch.Tensor:
  70. """
  71. Reshape a tensor to match target levels using as_strided.
  72. Args:
  73. tensor: Input tensor to reshape
  74. from_levels: Current levels of the tensor
  75. to_levels: Target levels to match
  76. drop_levels: If True, missing dimensions are assumed to have stride 0
  77. Returns:
  78. Reshaped tensor
  79. """
  80. if from_levels == to_levels:
  81. return tensor
  82. sizes = tensor.size()
  83. strides = tensor.stride()
  84. if not drop_levels:
  85. if len(from_levels) > len(to_levels):
  86. raise AssertionError("Cannot expand dimensions without drop_levels")
  87. new_sizes = []
  88. new_strides = []
  89. for level in to_levels:
  90. # Find index of this level in from_levels
  91. try:
  92. idx = from_levels.index(level)
  93. except ValueError:
  94. # Level not found in from_levels
  95. if level.is_positional():
  96. new_sizes.append(1)
  97. else:
  98. new_sizes.append(level.dim().size)
  99. new_strides.append(0)
  100. else:
  101. new_sizes.append(sizes[idx])
  102. new_strides.append(strides[idx])
  103. return tensor.as_strided(new_sizes, new_strides, tensor.storage_offset())