operator.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. # This class holds information about a single operator used to determine
  4. # the outcome of a selective/custom PyTorch build that doesn't include
  5. # registration code for all the supported operators. This is done to
  6. # reduce the size of the generated binary so that it can be deployed in
  7. # situations where binary size comes at a premium.
  8. #
  9. @dataclass(frozen=True)
  10. class SelectiveBuildOperator:
  11. # The name of the operator. This includes the aten::, etc... prefix
  12. # The operator name may or may not have the overload name. If this
  13. # operator name does not specify an overload name, the way to determine
  14. # if this entry refers to the family of operators with this base name
  15. # or just the operator with this name is to look at the value of the
  16. # 'include_all_overloads' flag in this class.
  17. name: str
  18. # True if this is a root operator (i.e. called directly from a
  19. # TorchScript model, etc...). An operator is considered to be a
  20. # root operator if it is called directly from any one of the models
  21. # that this instance of the pytorch library was built for. Hence, it
  22. # may not be a root operator in all of the models that are used in
  23. # this instance of the pytorch library.
  24. is_root_operator: bool
  25. # Is this operator used for on-device training? If True, then we need to
  26. # use the information to generate code in VariableType_N.cpp for registration
  27. # of training related operators. Again, this is True if this operator
  28. # is used for training in one or more models used by this instance of the
  29. # pytorch library.
  30. is_used_for_training: bool
  31. # If True, it indicates that this operator instance (object) refers to an
  32. # operator without the overload name and should apply to all overloads
  33. # which have this operator name as the base name. This flag is applicable
  34. # only for objects that have operator names without a DOT (period) character
  35. # in them.
  36. #
  37. # Note: This flag is a temporary workaround to grandfather in the current
  38. # static selective (custom) build mechanism, which largely ignores overload
  39. # names when determining whether to select operators for registration
  40. # purposes.
  41. include_all_overloads: bool
  42. # Debug Information at the operator level
  43. _debug_info: tuple[str, ...] | None
  44. @staticmethod
  45. def from_yaml_dict(
  46. op_name: str, op_info: dict[str, object]
  47. ) -> SelectiveBuildOperator:
  48. allowed_keys = {
  49. "name",
  50. "is_root_operator",
  51. "is_used_for_training",
  52. "include_all_overloads",
  53. "debug_info",
  54. }
  55. if len(set(op_info.keys()) - allowed_keys) > 0:
  56. raise Exception( # noqa: TRY002
  57. "Got unexpected top level keys: {}".format(
  58. ",".join(set(op_info.keys()) - allowed_keys),
  59. )
  60. )
  61. if "name" in op_info:
  62. if op_name != op_info["name"]:
  63. raise AssertionError(
  64. f"op_name mismatch: {op_name} != {op_info['name']}"
  65. )
  66. is_root_operator = op_info.get("is_root_operator", True)
  67. if not isinstance(is_root_operator, bool):
  68. raise AssertionError(
  69. f"Expected 'is_root_operator' to be bool, got {type(is_root_operator)}"
  70. )
  71. is_used_for_training = op_info.get("is_used_for_training", True)
  72. if not isinstance(is_used_for_training, bool):
  73. raise AssertionError(
  74. f"Expected 'is_used_for_training' to be bool, got {type(is_used_for_training)}"
  75. )
  76. include_all_overloads = op_info.get("include_all_overloads", True)
  77. if not isinstance(include_all_overloads, bool):
  78. raise AssertionError(
  79. f"Expected 'include_all_overloads' to be bool, got {type(include_all_overloads)}"
  80. )
  81. debug_info: tuple[str, ...] | None = None
  82. if "debug_info" in op_info:
  83. di_list = op_info["debug_info"]
  84. if not isinstance(di_list, list):
  85. raise AssertionError(
  86. f"Expected 'debug_info' to be list, got {type(di_list)}"
  87. )
  88. debug_info = tuple(str(x) for x in di_list)
  89. return SelectiveBuildOperator(
  90. name=op_name,
  91. is_root_operator=is_root_operator,
  92. is_used_for_training=is_used_for_training,
  93. include_all_overloads=include_all_overloads,
  94. _debug_info=debug_info,
  95. )
  96. @staticmethod
  97. def from_legacy_operator_name_without_overload(
  98. name: str,
  99. ) -> SelectiveBuildOperator:
  100. return SelectiveBuildOperator(
  101. name=name,
  102. is_root_operator=True,
  103. is_used_for_training=True,
  104. include_all_overloads=True,
  105. _debug_info=None,
  106. )
  107. def to_dict(self) -> dict[str, object]:
  108. ret: dict[str, object] = {
  109. "is_root_operator": self.is_root_operator,
  110. "is_used_for_training": self.is_used_for_training,
  111. "include_all_overloads": self.include_all_overloads,
  112. }
  113. if self._debug_info is not None:
  114. ret["debug_info"] = self._debug_info
  115. return ret
  116. def merge_debug_info(
  117. lhs: tuple[str, ...] | None,
  118. rhs: tuple[str, ...] | None,
  119. ) -> tuple[str, ...] | None:
  120. # Ensure that when merging, each entry shows up just once.
  121. if lhs is None and rhs is None:
  122. return None
  123. return tuple(set((lhs or ()) + (rhs or ())))
  124. def combine_operators(
  125. lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator
  126. ) -> SelectiveBuildOperator:
  127. if str(lhs.name) != str(rhs.name):
  128. raise Exception( # noqa: TRY002
  129. f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead"
  130. )
  131. return SelectiveBuildOperator(
  132. name=lhs.name,
  133. # Consider this operator to be a root operator if it is a
  134. # root operator in any of the models used in this instance of
  135. # the pytorch library.
  136. is_root_operator=lhs.is_root_operator or rhs.is_root_operator,
  137. # Consider this operator to be a training operator if it is
  138. # an operator used for training in any of the models used
  139. # in this instance of the pytorch library.
  140. is_used_for_training=lhs.is_used_for_training or rhs.is_used_for_training,
  141. include_all_overloads=lhs.include_all_overloads or rhs.include_all_overloads,
  142. _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info),
  143. )
  144. def merge_operator_dicts(
  145. lhs: dict[str, SelectiveBuildOperator],
  146. rhs: dict[str, SelectiveBuildOperator],
  147. ) -> dict[str, SelectiveBuildOperator]:
  148. operators: dict[str, SelectiveBuildOperator] = {}
  149. for op_name, op in list(lhs.items()) + list(rhs.items()):
  150. new_op = op
  151. if op_name in operators:
  152. new_op = combine_operators(operators[op_name], op)
  153. operators[op_name] = new_op
  154. return operators
  155. def strip_operator_overload_name(op_name: str) -> str:
  156. return op_name.split(".", maxsplit=1)[0]