TensorNames.h 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/WrapDimUtils.h>
  4. namespace at::namedinference {
  5. // TensorName and TensorNames are wrappers around Dimname and DimnameList
  6. // that contain helper functions to make writing name inference rules easier.
  7. //
  8. // A TensorName represents a Dimname associated with some DimnameList (from a
  9. // Tensor). This encapsulates all the information that is needed to check if
  10. // names *match* and to *unify* names.
  11. //
  12. // Definition: Two names in two tensors *match* if they are equal, or if at
  13. // least one of them is a wildcard that can be *refined* to the other name.
  14. //
  15. // Definition: unify(name, other) fails if the names do not match. Otherwise,
  16. // it returns the most refined of name and other.
  17. //
  18. // Here is an example of checking if two names match.
  19. // tensor: Tensor[A, None]
  20. // other: Tensor[A]
  21. //
  22. // Let's say we wish to check if tensor.names[-1] matches other.names[-1].
  23. // None (in tensor) cannot match A (in other) because if the None were refined
  24. // to A, `tensor` would have duplicate names [A, A]. Therefore we need to check
  25. // tensor.names [A, None] for the existence of A.
  26. struct TORCH_API TensorName {
  27. explicit TensorName(ArrayRef<Dimname> origin, int origin_idx)
  28. : origin_(origin),
  29. name_(origin[maybe_wrap_dim(
  30. origin_idx,
  31. static_cast<int64_t>(origin.size()))]),
  32. origin_idx_(origin_idx) {}
  33. // op_name is only used for error reporting.
  34. const TensorName& unify(const TensorName& other, const char* op_name) const;
  35. Dimname toDimname() const;
  36. private:
  37. ArrayRef<Dimname> origin_;
  38. Dimname name_;
  39. int origin_idx_; // A named tensor can have at most 64 dims.
  40. TORCH_API friend std::ostream& operator<<(
  41. std::ostream& out,
  42. const TensorName& tensorname);
  43. };
  44. using TensorNameVec = SmallVector<TensorName, 10>;
  45. struct TORCH_API TensorNames {
  46. explicit TensorNames(ArrayRef<Dimname> names);
  47. // Create TensorNames from names[start:end]. Each individual TensorName stores
  48. // `names`, NOT names[start:end], because the original tensor's names are
  49. // `names`.
  50. explicit TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end);
  51. // op_name is only used for error reporting.
  52. TensorNames& unifyFromRightInplace(
  53. const TensorNames& other,
  54. const char* op_name = "unify");
  55. void checkUnique(const char* op_name) const;
  56. void append(TensorName name);
  57. std::vector<Dimname> toDimnameVec() const;
  58. private:
  59. explicit TensorNames(TensorNameVec&& names) : names_(std::move(names)) {}
  60. TensorNameVec names_;
  61. };
  62. } // namespace at::namedinference
  63. #else
  64. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  65. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)