bundled_inputs.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. #!/usr/bin/env python3
  2. # mypy: allow-untyped-defs
  3. from typing import Any, TypeVar, NamedTuple
  4. from collections.abc import Callable, Sequence
  5. import textwrap
  6. import torch
  7. from torch._C import TupleType, ListType
  8. from torch.jit._recursive import wrap_cpp_module
  9. T = TypeVar("T")
  10. MAX_RAW_TENSOR_SIZE = 16
  11. class InflatableArg(NamedTuple):
  12. """Helper type for bundled inputs.
  13. 'value' is the compressed/deflated input that is stored in the model. Value
  14. must be of the same type as the argument to the function that it is a deflated
  15. input for.
  16. 'fmt' is a formattable code string that is executed to inflate the compressed data into
  17. the appropriate input. It can use 'value' as an input to the format str. It must result
  18. in a value of the same type as 'value'.
  19. 'fmt_fn' is a formattable function code string that is executed to inflate the compressed
  20. data into the appropriate input. It must result in a value of the same type as 'value'.
  21. The function name should be the formattable part of the string.
  22. Note: Only top level InflatableArgs can be inflated. i.e. you cannot place
  23. an inflatable arg inside of some other structure. You should instead create
  24. an inflatable arg such that the fmt code string returns the full structure
  25. of your input.
  26. """
  27. value: Any
  28. fmt: str = "{}"
  29. fmt_fn: str = ""
  30. def bundle_inputs(
  31. model: torch.jit.ScriptModule,
  32. inputs: Sequence[tuple[Any, ...]] | dict[Callable, Sequence[tuple[Any, ...]] | None] | None,
  33. info: list[str] | dict[Callable, list[str]] | None = None,
  34. *,
  35. _receive_inflate_expr: list[str] | None = None,
  36. ) -> torch.jit.ScriptModule:
  37. """Create and return a copy of the specified model with inputs attached.
  38. The original model is not mutated or changed in any way.
  39. Models with bundled inputs can be invoked in a uniform manner by
  40. benchmarking and code coverage tools.
  41. If inputs is passed in as a list then the inputs will be bundled for 'forward'.
  42. If inputs is instead passed in as a map then all the methods specified in the map
  43. will have their corresponding inputs bundled. Info should match watchever type is
  44. chosen for the inputs.
  45. The returned model will support the following methods:
  46. `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
  47. Returns a list of tuples suitable for passing to the model like
  48. `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`
  49. `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str: List[str]]]`
  50. Returns a dictionary mapping function names to a metadata dictionary.
  51. This nested dictionary maps preset strings like:
  52. 'get_inputs_function_name' -> the name of a function attribute in this model that can be
  53. run to get back a list of inputs corresponding to that function.
  54. 'info' -> the user provided extra information about the bundled inputs
  55. If forward has bundled inputs then these following functions will also be defined on the returned module:
  56. `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
  57. Returns a list of tuples suitable for passing to the model like
  58. `for inp in model.get_all_bundled_inputs(): model(*inp)`
  59. `get_num_bundled_inputs() -> int`
  60. Equivalent to `len(model.get_all_bundled_inputs())`,
  61. but slightly easier to call from C++.
  62. Inputs can be specified in one of two ways:
  63. - The model can define `_generate_bundled_inputs_for_<function_name>`.
  64. If the user chooses this method inputs[<function>] should map to None
  65. - The `inputs` argument to this function can be a dictionary mapping functions to a
  66. list of inputs, of the same form that will be returned by get_all_bundled_inputs_for_<function_name>.
  67. Alternatively if only bundling inputs for forward the map can be omitted and a singular list of inputs
  68. can be provided instead.
  69. The type of the inputs is List[Tuple[Any, ...]]. The outer list corresponds with a
  70. list of inputs, the inner tuple is the list of args that together make up one input.
  71. For inputs of functions that take one arg, this will be a tuple of length one. The Any, ...
  72. is the actual data that makes up the args, e.g. a tensor.
  73. Info is an optional parameter that maps functions to a list of strings providing extra information about that
  74. function's bundled inputs. Alternatively if only bundling inputs for forward the map can be omitted and
  75. a singular list of information can be provided instead. This could be descriptions, expected outputs, etc.
  76. - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}
  77. This function will attempt to optimize arguments so that (e.g.)
  78. arguments like `torch.zeros(1000)` will be represented compactly.
  79. Only top-level arguments will be optimized.
  80. Tensors in lists or tuples will not.
  81. """
  82. if not isinstance(model, torch.jit.ScriptModule):
  83. raise Exception("Only ScriptModule is supported.") # noqa: TRY002
  84. ignored_methods, ignored_attrs = _get_bundled_inputs_attributes_and_methods(model)
  85. clone = torch._C._hack_do_not_use_clone_module_with_class( # type: ignore[attr-defined]
  86. model._c,
  87. ignored_methods,
  88. ignored_attrs,
  89. )
  90. # The above cloning function returns a torch._C.scriptmodule and we need a torch.jit.scriptmodule.
  91. # Fortunately there is a function in _recursive that does exactly that conversion.
  92. cloned_module = wrap_cpp_module(clone)
  93. if isinstance(inputs, dict):
  94. if not isinstance(info, dict) and info is not None:
  95. raise AssertionError("If inputs is a dict, info must be a dict or None")
  96. augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
  97. else:
  98. if not isinstance(info, list) and info is not None:
  99. raise AssertionError("If inputs is a list, info must be a list or None")
  100. augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
  101. return cloned_module
  102. def augment_model_with_bundled_inputs(
  103. model: torch.jit.ScriptModule,
  104. inputs: Sequence[tuple[Any, ...]] | None = None,
  105. _receive_inflate_expr: list[str] | None = None, # For debugging.
  106. info: list[str] | None = None, # Optional argument to provide info about forward or its inputs
  107. skip_size_check=False,
  108. ) -> None:
  109. """Add bundled sample inputs to a model for the forward function.
  110. Models with bundled inputs can be invoked in a uniform manner by
  111. benchmarking and code coverage tools.
  112. Augmented models will support the following methods:
  113. `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
  114. Returns a list of tuples suitable for passing to the model like
  115. `for inp in model.get_all_bundled_inputs(): model(*inp)`
  116. `get_num_bundled_inputs() -> int`
  117. Equivalent to `len(model.get_all_bundled_inputs())`,
  118. but slightly easier to call from C++.
  119. `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str: List[str]]]`
  120. Returns a dictionary mapping function names to a metadata dictionary.
  121. This nested dictionary maps preset strings like:
  122. 'get_inputs_function_name' -> the name of a function attribute in this model that can be
  123. run to get back a list of inputs corresponding to that function.
  124. 'info' -> the user provided extra information about the bundled inputs
  125. Inputs can be specified in one of two ways:
  126. - The model can define `_generate_bundled_inputs_for_forward`.
  127. If the user chooses this method inputs should be None
  128. - `inputs` is a list of inputs of form List[Tuple[Any, ...]]. A list of tuples where the elements
  129. of each tuple are the args that make up one input.
  130. """
  131. if not isinstance(model, torch.jit.ScriptModule):
  132. raise Exception("Only ScriptModule is supported.") # noqa: TRY002
  133. forward: Callable = model.forward
  134. # Sometimes forward won't have a name attached so just in case
  135. if not hasattr(forward, "__name__"):
  136. forward.__name__ = 'forward'
  137. augment_many_model_functions_with_bundled_inputs(
  138. model,
  139. inputs={forward : inputs},
  140. _receive_inflate_expr=_receive_inflate_expr,
  141. info={forward : info} if info else None,
  142. skip_size_check=skip_size_check,
  143. )
  144. def augment_many_model_functions_with_bundled_inputs(
  145. model: torch.jit.ScriptModule,
  146. inputs: dict[Callable, Sequence[tuple[Any, ...]] | None],
  147. _receive_inflate_expr: list[str] | None = None, # For debugging.
  148. info: dict[Callable, list[str]] | None = None, # Optional argument to provide info about the function or its inputs
  149. skip_size_check=False,
  150. ) -> None:
  151. """Add bundled sample inputs to a model for an arbitrary list of public functions.
  152. Models with bundled inputs can be invoked in a uniform manner by
  153. benchmarking and code coverage tools.
  154. Augmented models will support the following methods:
  155. `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
  156. Returns a list of tuples suitable for passing to the model like
  157. `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`
  158. `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str: List[str]]]`
  159. Returns a dictionary mapping function names to a metadata dictionary.
  160. This nested dictionary maps preset strings like:
  161. 'get_inputs_function_name' -> the name of a function attribute in this model that can be
  162. run to get back a list of inputs corresponding to that function.
  163. 'info' -> the user provided extra information about the bundled inputs
  164. If forward has bundled inputs then these following functions are also defined:
  165. `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
  166. Returns a list of tuples suitable for passing to the model like
  167. `for inp in model.get_all_bundled_inputs(): model(*inp)`
  168. `get_num_bundled_inputs() -> int`
  169. Equivalent to `len(model.get_all_bundled_inputs())`,
  170. but slightly easier to call from C++.
  171. Inputs can be specified in one of two ways:
  172. - The model can define `_generate_bundled_inputs_for_<function_name>`.
  173. If the user chooses this method inputs[<function>] should map to None
  174. - The `inputs` argument to this function can be a dictionary mapping functions to a
  175. list of inputs, of the same form that will be returned by get_all_bundled_inputs_for_<function_name>.
  176. The type of the inputs is List[Tuple[Any, ...]]. The outer list corresponds with a
  177. list of inputs, the inner tuple is the list of args that together make up one input.
  178. For inputs of functions that take one arg, this will be a tuple of length one. The Any, ...
  179. is the actual data that makes up the args, e.g. a tensor.
  180. Info is an optional parameter that maps functions to a list of strings providing extra information about that
  181. function's bundled inputs. This could be descriptions, expected outputs, etc.
  182. - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}
  183. This function will attempt to optimize arguments so that (e.g.)
  184. arguments like `torch.zeros(1000)` will be represented compactly.
  185. Only top-level arguments will be optimized.
  186. Tensors in lists or tuples will not.
  187. """
  188. if not isinstance(model, torch.jit.ScriptModule):
  189. raise Exception("Only ScriptModule is supported.") # noqa: TRY002
  190. if not inputs:
  191. raise Exception("Please provide inputs for at least 1 function") # noqa: TRY002
  192. if hasattr(model, "get_all_bundled_inputs") or hasattr(model, "get_bundled_inputs_functions_and_info"):
  193. raise Exception( # noqa: TRY002
  194. "Models can only be augmented with bundled inputs once. "
  195. "This Model seems to have already been augmented with "
  196. "bundled inputs. Please start afresh with one that "
  197. "doesn't have bundled inputs.",
  198. )
  199. get_bundled_inputs_functions_and_info_template = ""
  200. for function, input_list in inputs.items():
  201. if hasattr(function, "__name__"):
  202. function_name = function.__name__
  203. else:
  204. if hasattr(function, "name"):
  205. function_name = function.name # type: ignore[attr-defined]
  206. else:
  207. raise Exception( # noqa: TRY002
  208. 'At least one of your functions has no attribute name please ensure all have one. m.foo.name = "foo"')
  209. if input_list is not None and not isinstance(input_list, Sequence):
  210. raise TypeError(f"Error inputs for function {function_name} is not a Sequence")
  211. function_arg_types = [arg.type for arg in function.schema.arguments[1:]] # type: ignore[attr-defined]
  212. deflated_inputs_type: ListType = ListType(TupleType(function_arg_types))
  213. model._c._register_attribute(f"_bundled_inputs_deflated_{function_name}", deflated_inputs_type, [])
  214. if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
  215. if input_list is not None:
  216. raise Exception( # noqa: TRY002
  217. f"inputs[{function_name}] is not None, but _generate_bundled_inputs_for_{function_name} is already defined"
  218. )
  219. # Model author already defined _generate_bundled_inputs_for_<function_name>.
  220. elif input_list is None or len(input_list) == 0:
  221. raise Exception( # noqa: TRY002
  222. f"inputs for {function_name} must be specified if "
  223. f"_generate_bundled_inputs_for_{function_name} is not already defined"
  224. )
  225. else:
  226. # Iterate over the inputs and args in each input.
  227. # Accumulate `deflated_inputs` as (possibly) compressed values
  228. # and `parts` to be joined into the expression that unpacks them.
  229. deflated_inputs = []
  230. parts = []
  231. for inp_idx, args in enumerate(input_list):
  232. if not isinstance(args, tuple) and not isinstance(args, list): # type: ignore[arg-type]
  233. raise TypeError(
  234. f"Error bundled input for function {function_name} idx: {inp_idx} is not a Tuple or a List"
  235. )
  236. deflated_args = []
  237. parts.append("(")
  238. for arg_idx, arg in enumerate(args):
  239. inflate_helper_fn_name = _get_inflate_helper_fn_name(arg_idx, inp_idx, function_name)
  240. deflated, inflater, helper_definition = _inflate_expr(
  241. arg,
  242. f"deflated[{inp_idx}][{arg_idx}]",
  243. inflate_helper_fn_name,
  244. skip_size_check=skip_size_check,
  245. )
  246. deflated_args.append(deflated)
  247. parts.append(f" {inflater},")
  248. if helper_definition:
  249. model.define(textwrap.dedent(helper_definition))
  250. deflated_inputs.append(tuple(deflated_args))
  251. parts.append("),")
  252. parts.append("")
  253. expr = "\n".join(parts)
  254. # Back-channel return this expr for debugging.
  255. if _receive_inflate_expr is not None:
  256. _receive_inflate_expr.append(expr)
  257. setattr(model, f"_bundled_inputs_deflated_{function_name}", deflated_inputs)
  258. definition = textwrap.dedent("""
  259. def _generate_bundled_inputs_for_{name}(self):
  260. deflated = self._bundled_inputs_deflated_{name}
  261. return [
  262. {expr}
  263. ]
  264. """).format(expr=expr, name=function_name)
  265. model.define(definition)
  266. # Define get_all_bundled_inputs_for_<function_name> that caches the generated inputs.
  267. model.define(textwrap.dedent("""
  268. def get_all_bundled_inputs_for_{name}(self):
  269. all_inputs = self._generate_bundled_inputs_for_{name}()
  270. assert all_inputs is not None
  271. return all_inputs
  272. """).format(name=function_name))
  273. # Add to the high level helper methods
  274. inputs_info = repr(info[function]) if info and function in info else '[]'
  275. get_bundled_inputs_functions_and_info_template += f"""
  276. temp_dict : Dict[str,List[str]] = {{}}
  277. info: List[str] = {inputs_info}
  278. temp_dict['info'] = info
  279. temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{function_name}']
  280. all_inputs['{function_name}'] = temp_dict
  281. """
  282. # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided
  283. if function_name == 'forward':
  284. model.define(textwrap.dedent("""
  285. def get_all_bundled_inputs(self):
  286. return self.get_all_bundled_inputs_for_forward()
  287. """))
  288. model.define(textwrap.dedent("""
  289. def get_num_bundled_inputs(self):
  290. return len(self.get_all_bundled_inputs_for_forward())
  291. """))
  292. # Define some high level helper methods that act on all bundled inputs
  293. model.define(textwrap.dedent(f"""
  294. def get_bundled_inputs_functions_and_info(self):
  295. all_inputs : Dict[str, Dict[str,List[str]]] = {{}}
  296. {get_bundled_inputs_functions_and_info_template}
  297. return all_inputs
  298. """))
  299. def _inflate_expr(
  300. arg: T, ref: str, inflate_helper_fn_name: str, skip_size_check: bool = False
  301. ) -> tuple[T | torch.Tensor, str, str | None]:
  302. # Allow custom inflation expressions any object.
  303. # For example, calling custom image-decoding ops.
  304. # Or just use "{}" as the format string to ignore size limits.
  305. if isinstance(arg, InflatableArg):
  306. if arg.fmt_fn:
  307. if arg.fmt not in ["{}", ""]:
  308. raise Exception( # noqa: TRY002
  309. f"Bundled input argument at position '{ref}' has "
  310. f"both arg.fmt_fn => \n{arg.fmt_fn} "
  311. f"\n and arg.fmt => {arg.fmt}. "
  312. "Please choose `arg.fmt` if the deflater is straightforward or "
  313. "`arg.fmt_fn` if you need a function."
  314. )
  315. helper_definition = arg.fmt_fn.format(inflate_helper_fn_name)
  316. expr = f"self.{inflate_helper_fn_name}({ref})"
  317. return arg.value, expr, helper_definition
  318. else:
  319. return arg.value, arg.fmt.format(ref), None
  320. if isinstance(arg, torch.Tensor):
  321. # Small-storage tensors can just be saved directly.
  322. if arg._typed_storage().size() <= MAX_RAW_TENSOR_SIZE or skip_size_check:
  323. return arg, ref, None
  324. # Small contiguous tensors can be cloned to have small storage.
  325. # TODO: Should we do this even for non-contiguous tensors?
  326. if arg.is_contiguous() and arg.numel() <= MAX_RAW_TENSOR_SIZE:
  327. return arg.clone(), ref, None
  328. # Example inputs commonly come from torch.zeros, torch.ones, or torch.full.
  329. # These can be represented compactly.
  330. for fmt in [torch.contiguous_format, torch.channels_last]:
  331. if arg.is_contiguous(memory_format=fmt) and (arg == arg.flatten()[0]).all().item():
  332. return (arg.flatten()[0].clone().expand(*arg.size()),
  333. f"{ref}.contiguous(memory_format={fmt})", None)
  334. # Prevent big tensors from being bundled by default.
  335. # TODO: Provide more useful diagnostics.
  336. raise Exception( # noqa: TRY002
  337. f"Bundled input argument at position '{ref}' is "
  338. f"a tensor with storage size {arg._typed_storage().size()}. "
  339. f"You probably don't want to bundle this as an input. "
  340. )
  341. else:
  342. return arg, ref, None
  343. def _get_bundled_inputs_attributes_and_methods(script_module: torch.jit.ScriptModule) -> tuple[list[str], list[str]]:
  344. methods: list[str] = []
  345. attributes: list[str] = []
  346. # Has bundled inputs for forward
  347. if hasattr(script_module, 'get_all_bundled_inputs'):
  348. methods.append('get_all_bundled_inputs')
  349. methods.append('get_num_bundled_inputs')
  350. methods.append('run_on_bundled_input')
  351. if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
  352. methods.append('get_bundled_inputs_functions_and_info')
  353. all_info = script_module.get_bundled_inputs_functions_and_info()
  354. for function_name in all_info:
  355. methods.append("get_all_bundled_inputs_for_" + function_name)
  356. methods.append("_generate_bundled_inputs_for_" + function_name)
  357. attributes.append("_bundled_inputs_deflated_" + function_name)
  358. bundled_inputs_fn = getattr(
  359. script_module,
  360. f"get_all_bundled_inputs_for_{function_name}"
  361. )
  362. num_bundled_inputs: int = len(bundled_inputs_fn())
  363. # Check inflate helper functions for each function, argument and bundled input
  364. func = getattr(script_module, function_name)
  365. for arg_idx in range(len(func.schema.arguments) - 1):
  366. for input_idx in range(num_bundled_inputs):
  367. helper_fn_name = _get_inflate_helper_fn_name(
  368. arg_idx=arg_idx,
  369. input_idx=input_idx,
  370. function_name=function_name
  371. )
  372. # if the arg has an InflatableArg with fmt_fn, add the helper function name
  373. if hasattr(script_module, helper_fn_name):
  374. methods.append(helper_fn_name)
  375. return (methods, attributes)
  376. def _get_inflate_helper_fn_name(
  377. arg_idx: int,
  378. input_idx: int,
  379. function_name: str,
  380. ) -> str:
  381. return f"_inflate_helper_for_{function_name}_input_{input_idx}_arg_{arg_idx}"
  382. def bundle_randn(*size, dtype=None):
  383. """Generate a tensor that will be inflated with torch.randn."""
  384. stub = torch.zeros(1, dtype=dtype).expand(*size)
  385. return InflatableArg(value=stub, fmt="torch.randn_like({})")
  386. def bundle_large_tensor(t):
  387. """Wrap a tensor to allow bundling regardless of size."""
  388. return InflatableArg(value=t, fmt="{}")