prepare.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. from typing import Optional
  4. import torch
  5. from torch.backends._nnapi.serializer import _NnapiSerializer
  6. ANEURALNETWORKS_PREFER_LOW_POWER = 0
  7. ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
  8. ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2
  9. class NnapiModule(torch.nn.Module):
  10. """Torch Module that wraps an NNAPI Compilation.
  11. This module handles preparing the weights, initializing the
  12. NNAPI TorchBind object, and adjusting the memory formats
  13. of all inputs and outputs.
  14. """
  15. # _nnapi.Compilation is defined
  16. comp: Optional[torch.classes._nnapi.Compilation] # type: ignore[name-defined]
  17. weights: list[torch.Tensor]
  18. out_templates: list[torch.Tensor]
  19. def __init__(
  20. self,
  21. shape_compute_module: torch.nn.Module,
  22. ser_model: torch.Tensor,
  23. weights: list[torch.Tensor],
  24. inp_mem_fmts: list[int],
  25. out_mem_fmts: list[int],
  26. compilation_preference: int,
  27. relax_f32_to_f16: bool,
  28. ):
  29. super().__init__()
  30. self.shape_compute_module = shape_compute_module
  31. self.ser_model = ser_model
  32. self.weights = weights
  33. self.inp_mem_fmts = inp_mem_fmts
  34. self.out_mem_fmts = out_mem_fmts
  35. self.out_templates = []
  36. self.comp = None
  37. self.compilation_preference = compilation_preference
  38. self.relax_f32_to_f16 = relax_f32_to_f16
  39. @torch.jit.export
  40. def init(self, args: list[torch.Tensor]):
  41. if self.comp is not None:
  42. raise AssertionError("comp must be None before initialization")
  43. self.out_templates = self.shape_compute_module.prepare(self.ser_model, args) # type: ignore[operator]
  44. self.weights = [w.contiguous() for w in self.weights]
  45. comp = torch.classes._nnapi.Compilation()
  46. comp.init2(
  47. self.ser_model,
  48. self.weights,
  49. self.compilation_preference,
  50. self.relax_f32_to_f16,
  51. )
  52. self.comp = comp
  53. def forward(self, args: list[torch.Tensor]) -> list[torch.Tensor]:
  54. if self.comp is None:
  55. self.init(args)
  56. comp = self.comp
  57. if comp is None:
  58. raise AssertionError("comp must not be None")
  59. outs = [torch.empty_like(out) for out in self.out_templates]
  60. if len(args) != len(self.inp_mem_fmts):
  61. raise AssertionError(
  62. f"args length {len(args)} != inp_mem_fmts length {len(self.inp_mem_fmts)}"
  63. )
  64. fixed_args = []
  65. for idx in range(len(args)):
  66. fmt = self.inp_mem_fmts[idx]
  67. # These constants match the values in DimOrder in serializer.py
  68. # TODO: See if it's possible to use those directly.
  69. if fmt == 0:
  70. fixed_args.append(args[idx].contiguous())
  71. elif fmt == 1:
  72. fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous())
  73. else:
  74. raise ValueError("Invalid mem_fmt")
  75. comp.run(fixed_args, outs)
  76. if len(outs) != len(self.out_mem_fmts):
  77. raise AssertionError(
  78. f"outs length {len(outs)} != out_mem_fmts length {len(self.out_mem_fmts)}"
  79. )
  80. for idx in range(len(self.out_templates)):
  81. fmt = self.out_mem_fmts[idx]
  82. # These constants match the values in DimOrder in serializer.py
  83. # TODO: See if it's possible to use those directly.
  84. if fmt in (0, 2):
  85. pass
  86. elif fmt == 1:
  87. outs[idx] = outs[idx].permute(0, 3, 1, 2)
  88. else:
  89. raise ValueError("Invalid mem_fmt")
  90. return outs
  91. def convert_model_to_nnapi(
  92. model,
  93. inputs,
  94. serializer=None,
  95. return_shapes=None,
  96. use_int16_for_qint16=False,
  97. compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
  98. relax_f32_to_f16=False,
  99. ):
  100. (
  101. shape_compute_module,
  102. ser_model_tensor,
  103. used_weights,
  104. inp_mem_fmts,
  105. out_mem_fmts,
  106. retval_count,
  107. ) = process_for_nnapi(
  108. model, inputs, serializer, return_shapes, use_int16_for_qint16
  109. )
  110. nnapi_model = NnapiModule(
  111. shape_compute_module,
  112. ser_model_tensor,
  113. used_weights,
  114. inp_mem_fmts,
  115. out_mem_fmts,
  116. compilation_preference,
  117. relax_f32_to_f16,
  118. )
  119. class NnapiInterfaceWrapper(torch.nn.Module):
  120. """NNAPI list-ifying and de-list-ifying wrapper.
  121. NNAPI always expects a list of inputs and provides a list of outputs.
  122. This module allows us to accept inputs as separate arguments.
  123. It returns results as either a single tensor or tuple,
  124. matching the original module.
  125. """
  126. def __init__(self, mod):
  127. super().__init__()
  128. self.mod = mod
  129. wrapper_model_py = NnapiInterfaceWrapper(nnapi_model)
  130. wrapper_model = torch.jit.script(wrapper_model_py)
  131. # TODO: Maybe make these names match the original.
  132. arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
  133. if retval_count < 0:
  134. ret_expr = "retvals[0]"
  135. else:
  136. ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count))
  137. wrapper_model.define(
  138. f"def forward(self, {arg_list}):\n"
  139. f" retvals = self.mod([{arg_list}])\n"
  140. f" return {ret_expr}\n"
  141. )
  142. return wrapper_model
  143. def process_for_nnapi(
  144. model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False
  145. ):
  146. model = torch.jit.freeze(model)
  147. if isinstance(inputs, torch.Tensor):
  148. inputs = [inputs]
  149. serializer = serializer or _NnapiSerializer(
  150. config=None, use_int16_for_qint16=use_int16_for_qint16
  151. )
  152. (
  153. ser_model,
  154. used_weights,
  155. inp_mem_fmts,
  156. out_mem_fmts,
  157. shape_compute_lines,
  158. retval_count,
  159. ) = serializer.serialize_model(model, inputs, return_shapes)
  160. ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)
  161. # We have to create a new class here every time this function is called
  162. # because module.define adds a method to the *class*, not the instance.
  163. class ShapeComputeModule(torch.nn.Module):
  164. """Code-gen-ed module for tensor shape computation.
  165. module.prepare will mutate ser_model according to the computed operand
  166. shapes, based on the shapes of args. Returns a list of output templates.
  167. """
  168. shape_compute_module = torch.jit.script(ShapeComputeModule())
  169. real_shape_compute_lines = [
  170. "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n",
  171. ] + [f" {line}\n" for line in shape_compute_lines]
  172. shape_compute_module.define("".join(real_shape_compute_lines))
  173. return (
  174. shape_compute_module,
  175. ser_model_tensor,
  176. used_weights,
  177. inp_mem_fmts,
  178. out_mem_fmts,
  179. retval_count,
  180. )