NonEmptyUtils.h 853 B

1234567891011121314151617181920212223242526272829303132
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #include <ATen/core/TensorBase.h>
  3. #include <algorithm>
  4. #include <vector>
  5. namespace at::native {
  6. inline int64_t ensure_nonempty_dim(int64_t dim) {
  7. return std::max<int64_t>(dim, 1);
  8. }
  9. inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) {
  10. return t.dim() == 0 ? 1 : t.size(dim);
  11. }
  12. inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) {
  13. return t.dim() == 0 ? 1 : t.stride(dim);
  14. }
  15. using IdxVec = std::vector<int64_t>;
  16. inline IdxVec ensure_nonempty_vec(IdxVec vec) {
  17. if (vec.empty()) {
  18. vec.push_back(1);
  19. }
  20. return vec;
  21. }
  22. } // namespace at::native
  23. #else
  24. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  25. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)