| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #include <c10/util/Exception.h>
- #include <utility>
- namespace at {
- /*
- [collapse dims] Updates sizes, and strides to reflect a "collapse" of
- the info, possibly excluding the optional excludeDim. A "collapsed" version
- of the info is the fewest dims that order the tensor's elements in the same
- way as the original info. If excludeDim is specified, the collapse is the
- fewest dims that order the tensor's elements as the original and preserve the
- excluded dimension, unless the tensor collapses to a point.
- This function returns a pair of values.
- 1) The (new) index of the preserved dimension if excludeDim is
- specified. 0 if the tensor is collapsed to a point. -1
- otherwise.
- 2) The new number of dimensions.
- */
- template <typename T>
- inline std::pair<int64_t, int64_t> collapse_dims(
- T* sizes,
- T* strides,
- int64_t dims,
- const int excludeDim = -1) {
- TORCH_CHECK(
- excludeDim >= -1 && excludeDim < dims,
- "expected excluded dim between -1 and dims - 1");
- int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
- int64_t newIndex = -1;
- int64_t oldIndex = 0;
- int64_t remappedExcludedDim = -1;
- while (oldIndex < dims) {
- // Finds a dimension to collapse into
- for (; oldIndex < stopDim; ++oldIndex) {
- if (sizes[oldIndex] == 1) {
- continue;
- }
- ++newIndex;
- sizes[newIndex] = sizes[oldIndex];
- strides[newIndex] = strides[oldIndex];
- ++oldIndex;
- break;
- }
- // Collapses dims
- for (; oldIndex < stopDim; ++oldIndex) {
- if (sizes[oldIndex] == 1) {
- continue;
- }
- if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
- sizes[newIndex] *= sizes[oldIndex];
- strides[newIndex] = strides[oldIndex];
- } else {
- ++newIndex;
- sizes[newIndex] = sizes[oldIndex];
- strides[newIndex] = strides[oldIndex];
- }
- }
- // Handles excludeDim being set (oldIndex == excludeDim)
- if (oldIndex != dims) {
- // Preserves excluded dimension
- ++newIndex;
- sizes[newIndex] = sizes[oldIndex];
- strides[newIndex] = strides[oldIndex];
- remappedExcludedDim = newIndex;
- // Restarts iteration after excludeDim
- ++oldIndex;
- stopDim = dims;
- }
- }
- // Handles special case of all dims size 1
- if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
- dims = 1;
- sizes[0] = 1;
- strides[0] = 1;
- return std::pair<int64_t, int64_t>(0, 1);
- }
- dims = newIndex + 1;
- return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
- }
- } // namespace at
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|