opaque_object.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. """
  2. Note [Opaque Objects]
  3. Opaque objects are the way we allow custom operators to accept a user-defined
  4. "black box" object as an input.
  5. There are two kinds of opaque types: VALUE type and REFERENCE type.
  6. The distinction determines how torch.compile handles the object.
  7. REFERENCE TYPES (default):
  8. Reference-typed opaque objects represent mutable stateful objects and are
  9. treated as black boxes. In torch.compile, since torch.compile cannot optimize
  10. anything (including tensors) within the object, the object must be an
  11. input to the graph.
  12. You can register a custom class as being a reference-based opaque object class
  13. through `register_opaque_type(MyClass, typ="reference")`.
  14. VALUE TYPES:
  15. Value-typed opaque objects represent constant values.
  16. In torch.compile, the graph specializes on the object like how other constants
  17. are. Therefore there are a couple of methods on the class that must be
  18. implemented before registering it as a value-typed opaque object class:
  19. - __eq__: torch.compile will create guards based on the equality of this
  20. object, meaning that a recompilation will happen if __eq__ returns False.
  21. - __hash__: This must be implemented for Fake Tensor caching
  22. - __fx_repr__: This must be implemented to provide an evaluable representation
  23. for FX graph codegen. It should return a tuple of (repr_string, dict[str, type])
  24. where repr_string can reconstruct the object and the dict maps names used in
  25. repr_string to their corresponding types.
  26. You can register a custom class as being a value-based opaque object class
  27. through `register_opaque_type(MyClass, typ="value")`.
  28. """
  29. from collections.abc import Callable
  30. from dataclasses import dataclass
  31. from enum import Enum
  32. from typing import Any, Literal, NewType, Optional
  33. from typing_extensions import TypeIs
  34. from weakref import WeakKeyDictionary
  35. import torch
  36. from torch._opaque_base import OpaqueBase, OpaqueBaseMeta # noqa: F401
  37. from .fake_class_registry import register_fake_class
  38. class MemberType(Enum):
  39. """
  40. Defines how a member (attribute/property/method) of an opaque object is handled
  41. during torch.compile tracing.
  42. """
  43. # Reads/calls the member at trace time with the real object and bakes the result as a constant
  44. USE_REAL = "use_real"
  45. # Inlines/traces the member
  46. INLINED = "inlined"
  47. @register_fake_class("aten::OpaqueObject")
  48. class FakeOpaqueObject:
  49. def __init__(self) -> None:
  50. pass
  51. @classmethod
  52. def __obj_unflatten__(cls, flattened_ctx: dict[str, Any]) -> None:
  53. raise RuntimeError(
  54. "FakeOpaqueObject should not be created through __obj_unflatten__ "
  55. "and should be special handled. Please file an issue to Github."
  56. )
  57. OpaqueTypeStr = "__torch__.torch.classes.aten.OpaqueObject"
  58. OpaqueType = NewType("OpaqueType", torch._C.ScriptObject)
  59. @dataclass
  60. class _OpaqueTypeInfo:
  61. class_name: str
  62. opaque_typ: Literal["reference", "value"]
  63. guard_fn: Callable[
  64. [Any], list[Any]
  65. ] # Callable that takes the object and returns list of values to guard on
  66. members: dict[str, MemberType] # Maps member name to how it should be handled
  67. hoist: bool
  68. # Mapping of type -> (string name, reference/value type)
  69. _OPAQUE_TYPES: WeakKeyDictionary[Any, _OpaqueTypeInfo] = WeakKeyDictionary()
  70. # Mapping of class_name -> (type, reference/value type)
  71. _OPAQUE_TYPES_BY_NAME: dict[str, _OpaqueTypeInfo] = {}
  72. def get_opaque_type_name(cls: Any) -> str:
  73. """
  74. Gets the registered opaque type name for a given class.
  75. Args:
  76. cls (type): The class to get the type name for.
  77. Returns:
  78. str: The registered type name for the class.
  79. Raises:
  80. ValueError: If the class is not registered as an opaque type.
  81. """
  82. if cls not in _OPAQUE_TYPES:
  83. raise ValueError(
  84. f"Class {cls} is not registered as an opaque type. "
  85. f"Call register_opaque_type({cls.__name__}) first."
  86. )
  87. return _OPAQUE_TYPES[cls].class_name
  88. def register_opaque_type(
  89. cls: Any,
  90. *,
  91. typ: str,
  92. hoist=False,
  93. guard_fn: Any = None,
  94. members: dict[str, MemberType] | None = None,
  95. ) -> None:
  96. """
  97. Registers the given type as an opaque type which allows this to be consumed
  98. by a custom operator.
  99. The type name will be automatically generated from the class's fully
  100. qualified name (ex. my_module.MyClass).
  101. Args:
  102. cls (type): The class to register as an opaque type.
  103. typ (str): Either "reference" or "value". See Note [Opaque Objects] for
  104. more details.
  105. hoist (bool): Only applies to value types. A hoist=True value type
  106. object is lifted as an input to the torch.compile'd graph, instead
  107. of being a constant baked into the graph. This is useful to
  108. improve compilation times in hierarchical compilation
  109. (e.g., change your custom ops to use hoisted strings to avoid
  110. baking the string into the Dynamo/AOTAutograd/FX graphs).
  111. This flag does nothing for reference types.
  112. guard_fn (callable | None): A function that takes an instance of the opaque
  113. object and returns a list of values to guard on. These values will be compared
  114. for equality on each function call, triggering recompilation if they change.
  115. Only applicable for reference types.
  116. Example: lambda obj: [obj.x, obj.y]
  117. members (dict[str, MemberType] | None): Dictionary mapping member names
  118. (attributes, properties, or methods) to their MemberType, which controls
  119. how they are handled during torch.compile tracing:
  120. - MemberType.USE_REAL: Evaluates with the real object at compile time and
  121. bakes the result as a constant
  122. - MemberType.INLINED: Inlines the method call into the trace
  123. """
  124. import torch.utils._pytree as pytree
  125. # Prevent registration of built-in types (int, str, list, dict, etc.) and torch.Tensor
  126. if cls.__module__ == "builtins" or cls is torch.Tensor:
  127. raise ValueError(
  128. f"Unable to register built-in type {cls} as an opaque type. "
  129. "Please wrap it in a custom class and register the custom class as opaque."
  130. )
  131. if cls in pytree.SUPPORTED_NODES:
  132. raise ValueError(
  133. f"{cls} cannot be registered as an opaque object as it has been "
  134. "registered as a pytree. Opaque objects must be pytree leaves."
  135. )
  136. if not isinstance(cls, OpaqueBaseMeta):
  137. raise TypeError(
  138. f"Opaque type {cls} must subclass torch._opaque_base.OpaqueBase "
  139. "or 'metaclass=torch._opaque_base.OpaqueBaseMeta'. "
  140. "This is required so that FakeScriptObject can be registered "
  141. "as a virtual subclass, allowing isinstance() checks to work "
  142. "during torch.compile tracing. "
  143. )
  144. if typ not in ["reference", "value"]:
  145. raise AssertionError(
  146. f"Opaque type must be either 'reference' or 'value', got {typ!r}"
  147. )
  148. if typ == "value":
  149. if cls.__eq__ is object.__eq__: # type: ignore[comparison-overlap]
  150. raise TypeError(
  151. f"Value-type opaque object of type {cls} is "
  152. "expected to have a non-default `__eq__` "
  153. "implementation as we will use this in torch.compile "
  154. "to guard on the equality of objects."
  155. )
  156. # Class with a custom `__eq__` without `__hash__` won't inherit the default
  157. # `__hash__` from object; see https://stackoverflow.com/a/1608907.
  158. if cls.__hash__ is None: # type: ignore[comparison-overlap]
  159. raise TypeError(
  160. f"Value-type opaque object of type {cls} is "
  161. "expected to have a non-default `__hash__` "
  162. "implementation as we will use this in torch.compile "
  163. "for FakeTensor caching."
  164. )
  165. if not hasattr(cls, "__fx_repr__"):
  166. raise TypeError(
  167. f"Value-type opaque object of type {cls} is "
  168. "expected to have a `__fx_repr__` method "
  169. "implementation as we will use this to reconstruct "
  170. "the object in the FX codegen. __fx_repr__ should return "
  171. "a tuple of (repr_string, set_of_types)."
  172. )
  173. if guard_fn is not None:
  174. raise TypeError(
  175. "No need to specify `guard_fn` for "
  176. f"value-type opaque class {cls} as it will be guarded based "
  177. "on `__eq__`."
  178. )
  179. # Generate a fully qualified name by combining module and qualname
  180. name = f"{cls.__module__}.{cls.__qualname__}"
  181. type_info = _OpaqueTypeInfo(name, typ, guard_fn, members or {}, hoist)
  182. _OPAQUE_TYPES[cls] = type_info
  183. _OPAQUE_TYPES_BY_NAME[name] = type_info
  184. torch._C._register_opaque_type(name)
  185. def is_opaque_value(value: object) -> TypeIs[OpaqueType]:
  186. return is_opaque_type(type(value))
  187. def should_hoist(cls: Any) -> bool:
  188. if cls not in _OPAQUE_TYPES:
  189. return False
  190. return _OPAQUE_TYPES[cls].hoist
  191. def has_members(cls: Any) -> bool:
  192. if cls not in _OPAQUE_TYPES:
  193. return False
  194. return len(_OPAQUE_TYPES[cls].members) > 0
  195. def is_opaque_type(cls: Any) -> bool:
  196. """
  197. Checks if the given type is an opaque type.
  198. """
  199. if isinstance(cls, str):
  200. return torch._C._is_opaque_type_registered(cls)
  201. if cls not in _OPAQUE_TYPES:
  202. return False
  203. return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls].class_name)
  204. def is_opaque_value_type(cls: Any) -> bool:
  205. """
  206. Checks if the given type is an opaque **value** type.
  207. See Note [Opaque Objects] for more information.
  208. """
  209. if not is_opaque_type(cls):
  210. return False
  211. if isinstance(cls, str):
  212. return _OPAQUE_TYPES_BY_NAME[cls].opaque_typ == "value"
  213. return _OPAQUE_TYPES[cls].opaque_typ == "value"
  214. def is_opaque_reference_type(cls: Any) -> bool:
  215. """
  216. Checks if the given type is an opaque **reference** type.
  217. See Note [Opaque Objects] for more information.
  218. """
  219. if not is_opaque_type(cls):
  220. return False
  221. if isinstance(cls, str):
  222. return _OPAQUE_TYPES_BY_NAME[cls].opaque_typ == "reference"
  223. return _OPAQUE_TYPES[cls].opaque_typ == "reference"
  224. def get_opaque_obj_repr(obj: Any) -> tuple[str, dict[str, type]]:
  225. """
  226. Get the FX-evaluable repr for an opaque object and collect required globals.
  227. Objects must implement __fx_repr__() which should return:
  228. (repr_string, dict_mapping_name_to_type)
  229. where repr_string is an evaluable string representation and
  230. dict_mapping_name_to_type maps the names used in repr_string to their types.
  231. For example, if repr_string is "Foo(bar=Bar(1))", the dict should be:
  232. {"Foo": Foo, "Bar": Bar}
  233. """
  234. if not hasattr(obj, "__fx_repr__"):
  235. raise TypeError(
  236. f"Value-type opaque object of type {obj} is "
  237. "expected to have a `__fx_repr__` method "
  238. "implementation as we will use this to reconstruct "
  239. "the object in the FX codegen. __fx_repr__ should return "
  240. "a tuple of (repr_string, dict[str, type])."
  241. )
  242. repr_str, globals_dict = obj.__fx_repr__()
  243. if not isinstance(repr_str, str):
  244. raise TypeError(
  245. f"__fx_repr__ for {type(obj).__name__} must return a string as the "
  246. f"first element, got {type(repr_str).__name__}"
  247. )
  248. if not isinstance(globals_dict, dict):
  249. raise TypeError(
  250. f"__fx_repr__ for {type(obj).__name__} must return a dict as the "
  251. f"second element, got {type(globals_dict).__name__}"
  252. )
  253. return repr_str, globals_dict
  254. def get_opaque_obj_info(cls: Any) -> Optional[_OpaqueTypeInfo]:
  255. if not is_opaque_type(cls):
  256. return None
  257. if isinstance(cls, str):
  258. return _OPAQUE_TYPES_BY_NAME[cls]
  259. return _OPAQUE_TYPES[cls]
  260. def get_member_type(cls: Any, member_name: str) -> Optional[MemberType]:
  261. """
  262. Get the MemberType for a specific member of an opaque object class.
  263. Args:
  264. cls: The opaque object class (or its string name)
  265. member_name: The name of the member to query
  266. Returns:
  267. MemberType if the member is registered, None otherwise
  268. """
  269. info = get_opaque_obj_info(cls)
  270. if info is None:
  271. return None
  272. return info.members.get(member_name)