supported_ops.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. import textwrap
  4. import torch.jit
  5. from torch.jit._builtins import _find_builtin
# This file is for generating documentation using Sphinx autodoc.
# > help(torch.jit.supported_ops) will also give a nicely formatted list of
# the supported ops programmatically.
  9. def _hidden(name):
  10. return name.startswith("_") and not name.startswith("__")
  11. def _emit_type(type):
  12. return str(type)
  13. def _emit_arg(indent, i, arg):
  14. v = f"{arg.name} : {_emit_type(arg.type)}"
  15. default = arg.default_value
  16. if default is not None:
  17. v = f"{v}={str(default)}"
  18. if i > 0:
  19. v = f"\n{' ' * indent}{v}"
  20. return v
  21. def _emit_args(indent, arguments):
  22. return ",".join(_emit_arg(indent, i, arg) for i, arg in enumerate(arguments))
  23. def _emit_ret(ret):
  24. return _emit_type(ret.type)
  25. def _emit_rets(returns):
  26. if len(returns) == 1:
  27. return _emit_ret(returns[0])
  28. return f"Tuple[{', '.join(_emit_ret(r) for r in returns)}]"
  29. def _emit_schema(mod, name, schema, arg_start=0, padding=4):
  30. if mod is None:
  31. qualified_name = name
  32. else:
  33. qualified_name = f"{mod}.{name}"
  34. schema_str = (
  35. f"{qualified_name}"
  36. f"({_emit_args(len(qualified_name) + 1 + padding, schema.arguments[arg_start:])}) "
  37. f"-> {_emit_rets(schema.returns)}"
  38. )
  39. return schema_str
  40. def _get_tensor_ops():
  41. def is_tensor_method(schema) -> bool:
  42. if len(schema.arguments) == 0:
  43. return False
  44. self = schema.arguments[0]
  45. if self.name != "self":
  46. return False
  47. if not self.type.isSubtypeOf(torch._C.TensorType.get()):
  48. return False
  49. return True
  50. methods = []
  51. # discover methods
  52. for elem in dir(torch.Tensor):
  53. if not _hidden(elem):
  54. schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem)
  55. for schema in schemas:
  56. if is_tensor_method(schema):
  57. methods.append(_emit_schema("Tensor", elem, schema, arg_start=1))
  58. return "Supported Tensor Methods", methods
  59. def _get_nn_functional_ops():
  60. functions = []
  61. # Iterate over torch.nn.functional
  62. mod = torch.nn.functional
  63. name = mod.__name__
  64. for elem in dir(torch.nn.functional):
  65. attr = getattr(mod, elem)
  66. if not inspect.isfunction(attr) or _hidden(elem[0]):
  67. # Ignore non-functions and internal methods
  68. continue
  69. attr_module = inspect.getmodule(attr)
  70. if not attr_module:
  71. raise RuntimeError(f"Module for {attr} not found")
  72. if "torch.nn.functional" not in attr_module.__name__:
  73. # Ignore functions from outside torch.nn.functional
  74. continue
  75. try:
  76. # compile fn, get schema
  77. scripted = torch.jit.script(attr)
  78. scripted_schema = scripted.schema
  79. functions.append(_emit_schema(name, elem, scripted_schema))
  80. except: # noqa: B001,E722
  81. # Skip interpolate / boolean dispatched things
  82. pass
  83. # Iterate over modules that we know contain a lot of builtins
  84. for mod in torch.jit._builtins._modules_containing_builtins:
  85. name = mod.__name__
  86. for elem in dir(mod):
  87. builtin = _find_builtin(getattr(mod, elem))
  88. if builtin is not None:
  89. schemas = torch._C._jit_get_schemas_for_operator(builtin)
  90. for schema in schemas:
  91. # remove _tan but not __and__
  92. if not _hidden(elem):
  93. functions.append(_emit_schema(name, elem, schema))
  94. return "Supported PyTorch Functions", functions
  95. def _get_builtins_helper():
  96. builtins = []
  97. for fn, _builtin_name in torch.jit._builtins._builtin_ops:
  98. mod = inspect.getmodule(fn)
  99. if not hasattr(fn, "__name__"):
  100. # typing classes
  101. continue
  102. if not mod:
  103. continue
  104. if _hidden(fn.__name__) or _hidden(fn.__qualname__) or _hidden(mod.__name__):
  105. # skip internal-only methods
  106. continue
  107. if "torch._C" in mod.__name__:
  108. continue
  109. builtins.append((fn, _builtin_name))
  110. return builtins
  111. def _is_math_fn(fn):
  112. mod = inspect.getmodule(fn)
  113. if not mod:
  114. raise RuntimeError(f"Module for {fn} not found")
  115. return mod.__name__ == "math"
  116. def _get_torchscript_builtins():
  117. functions = []
  118. builtins = filter(lambda fn: not _is_math_fn(fn[0]), _get_builtins_helper())
  119. builtins_list = list(builtins)
  120. # Iterate over the specially added builtins
  121. for fn, _builtin_name in builtins_list:
  122. mod = inspect.getmodule(fn)
  123. if not mod:
  124. raise RuntimeError(f"Module for {fn} not found")
  125. builtin = _find_builtin(fn)
  126. if builtin is not None:
  127. schemas = torch._C._jit_get_schemas_for_operator(builtin)
  128. for schema in schemas:
  129. functions.append(_emit_schema(mod.__name__, fn.__name__, schema))
  130. return "TorchScript Builtin Functions", functions
  131. def _get_math_builtins():
  132. functions = []
  133. builtins = filter(lambda fn: _is_math_fn(fn[0]), _get_builtins_helper())
  134. builtins_list = list(builtins)
  135. # Iterate over the specially added builtins
  136. for fn, _builtin_name in builtins_list:
  137. mod = inspect.getmodule(fn)
  138. if not mod:
  139. raise RuntimeError(f"Module for {fn} not found")
  140. builtin = _find_builtin(fn)
  141. if builtin is not None:
  142. schemas = torch._C._jit_get_schemas_for_operator(builtin)
  143. for schema in schemas:
  144. schema_str = _emit_schema(mod.__name__, fn.__name__, schema)
  145. if "Tensor" in schema_str:
  146. # Skip Tensor ops that have the same name as math functions
  147. # (they will show up in the tensor methods section)
  148. continue
  149. functions.append(schema)
  150. return "``math`` Module", functions
  151. def _get_global_builtins():
  152. # Taken from the 'globals' map in torch/csrc/jit/frontend/ir_emitter.cpp
  153. supported_builtins = [
  154. "print",
  155. "tuple",
  156. "float",
  157. "complex",
  158. "int",
  159. "bool",
  160. "str",
  161. "getattr",
  162. "hasattr",
  163. "isinstance",
  164. "len",
  165. "hex",
  166. "oct",
  167. "round",
  168. "hash",
  169. "min",
  170. "max",
  171. "abs",
  172. "all",
  173. "divmod",
  174. "list",
  175. "ord",
  176. "chr",
  177. "bin",
  178. "range",
  179. "zip",
  180. "enumerate",
  181. "sorted",
  182. ]
  183. op_renames = {
  184. "bool": "aten::Bool",
  185. "int": "aten::Int",
  186. "float": "aten::Float",
  187. "complex": "aten::Complex",
  188. "abs": "prim::abs",
  189. "max": "prim::max",
  190. "min": "prim::min",
  191. "range": "fake::does_not_exist",
  192. }
  193. schemaless_op_explanations = {
  194. "print": "Print any value",
  195. "tuple": "Lists cannot be converted to tuples with this method since their size is not statically known",
  196. "getattr": "Attribute name must be a literal string",
  197. "hasattr": "Attribute name must be a literal string",
  198. "isinstance": "Result is static",
  199. "zip": "Arguments must be iterable.",
  200. "enumerate": "Arguments must be iterable.",
  201. "range": "Can only be used as an iterator in a for loop",
  202. }
  203. magic_methods = [
  204. ("complex", "__complex__"),
  205. ("float", "__float__"),
  206. ("int", "__int__"),
  207. ("bool", "__bool__"),
  208. ("str", "__str__"),
  209. ("len", "__len__"),
  210. ("hex", "__hex__"),
  211. ("oct", "__oct__"),
  212. ]
  213. magic_methods_rows = []
  214. for fn, magic_method in magic_methods:
  215. # pyrefly: ignore [bad-argument-type]
  216. magic_methods_rows.append(f'"{fn}", "``{magic_method}``"')
  217. schematized_ops = []
  218. schemaless_ops = []
  219. for fn in supported_builtins:
  220. op_name = f"aten::{fn}"
  221. if fn in op_renames:
  222. op_name = op_renames[fn]
  223. schemas = torch._C._jit_get_schemas_for_operator(op_name)
  224. for s in schemas:
  225. schematized_ops.append(_emit_schema(None, fn, s, padding=0))
  226. if len(schemas) > 0:
  227. schematized_ops.append("")
  228. else:
  229. table_row = (
  230. f'":external+python:py:obj:`{fn}`", "{schemaless_op_explanations[fn]}"'
  231. )
  232. # pyrefly: ignore [bad-argument-type]
  233. schemaless_ops.append(table_row)
  234. schematized_ops_str = "\n".join(schematized_ops)
  235. schemaless_ops_str = "\n".join(schemaless_ops)
  236. magic_methods_rows_str = "\n".join(magic_methods_rows)
  237. schematized_ops_str = textwrap.indent(schematized_ops_str, "\t")
  238. schemaless_ops_str = textwrap.indent(schemaless_ops_str, "\t")
  239. magic_methods_rows_str = textwrap.indent(magic_methods_rows_str, "\t")
  240. section = f"""
  241. The functions in the following table are supported but do not have a static schema
  242. .. csv-table::
  243. :header: "Function", "Note"
  244. {schemaless_ops_str}
  245. The following functions will use the corresponding magic method on TorchScript classes
  246. .. csv-table::
  247. :header: "Function", "Magic Method"
  248. {magic_methods_rows_str}
  249. These built-in functions use the schema
  250. .. rst-class:: codeblock-height-limiter
  251. ::
  252. {schematized_ops_str}
  253. """
  254. return "Python Built-in Functions", section
  255. def _list_supported_ops():
  256. def emit_block(decls):
  257. return "\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{}\n".format(
  258. "".join(f" {d}\n\n" for d in decls)
  259. )
  260. body = ""
  261. op_gathering_fns = (
  262. _get_tensor_ops,
  263. _get_nn_functional_ops,
  264. _get_torchscript_builtins,
  265. _get_global_builtins,
  266. _get_math_builtins,
  267. )
  268. for fn in op_gathering_fns:
  269. header, items = fn()
  270. link_target = header.replace("`", "").replace("-", "").lower().replace(" ", "-")
  271. if isinstance(items, str):
  272. section = f"{header}\n{'~' * len(header)}\n{items}\n"
  273. else:
  274. section = f"{header}\n{'~' * len(header)}\n{emit_block(items)}"
  275. section = f".. _{link_target}:" + "\n\n" + section
  276. body += section
  277. return body
  278. __doc__ = _list_supported_ops()