CollapseDims.h 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #include <c10/util/Exception.h>
  3. #include <utility>
  4. namespace at {
  5. /*
  6. [collapse dims] Updates sizes, and strides to reflect a "collapse" of
  7. the info, possibly excluding the optional excludeDim. A "collapsed" version
  8. of the info is the fewest dims that order the tensor's elements in the same
  9. way as the original info. If excludeDim is specified, the collapse is the
  10. fewest dims that order the tensor's elements as the original and preserve the
  11. excluded dimension, unless the tensor collapses to a point.
  12. This function returns a pair of values.
  13. 1) The (new) index of the preserved dimension if excludeDim is
  14. specified. 0 if the tensor is collapsed to a point. -1
  15. otherwise.
  16. 2) The new number of dimensions.
  17. */
  18. template <typename T>
  19. inline std::pair<int64_t, int64_t> collapse_dims(
  20. T* sizes,
  21. T* strides,
  22. int64_t dims,
  23. const int excludeDim = -1) {
  24. TORCH_CHECK(
  25. excludeDim >= -1 && excludeDim < dims,
  26. "expected excluded dim between -1 and dims - 1");
  27. int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
  28. int64_t newIndex = -1;
  29. int64_t oldIndex = 0;
  30. int64_t remappedExcludedDim = -1;
  31. while (oldIndex < dims) {
  32. // Finds a dimension to collapse into
  33. for (; oldIndex < stopDim; ++oldIndex) {
  34. if (sizes[oldIndex] == 1) {
  35. continue;
  36. }
  37. ++newIndex;
  38. sizes[newIndex] = sizes[oldIndex];
  39. strides[newIndex] = strides[oldIndex];
  40. ++oldIndex;
  41. break;
  42. }
  43. // Collapses dims
  44. for (; oldIndex < stopDim; ++oldIndex) {
  45. if (sizes[oldIndex] == 1) {
  46. continue;
  47. }
  48. if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
  49. sizes[newIndex] *= sizes[oldIndex];
  50. strides[newIndex] = strides[oldIndex];
  51. } else {
  52. ++newIndex;
  53. sizes[newIndex] = sizes[oldIndex];
  54. strides[newIndex] = strides[oldIndex];
  55. }
  56. }
  57. // Handles excludeDim being set (oldIndex == excludeDim)
  58. if (oldIndex != dims) {
  59. // Preserves excluded dimension
  60. ++newIndex;
  61. sizes[newIndex] = sizes[oldIndex];
  62. strides[newIndex] = strides[oldIndex];
  63. remappedExcludedDim = newIndex;
  64. // Restarts iteration after excludeDim
  65. ++oldIndex;
  66. stopDim = dims;
  67. }
  68. }
  69. // Handles special case of all dims size 1
  70. if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
  71. dims = 1;
  72. sizes[0] = 1;
  73. strides[0] = 1;
  74. return std::pair<int64_t, int64_t>(0, 1);
  75. }
  76. dims = newIndex + 1;
  77. return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
  78. }
  79. } // namespace at
  80. #else
  81. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  82. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)