_patches.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. import contextlib
  2. from collections.abc import Generator
  3. import torch
  4. from torch._decomp import global_decomposition_table
  5. from torch._decomp.decompositions import _rnn_helper, gather_params, gru_cell, lstm_cell
  6. from torch._higher_order_ops.while_loop import while_loop
  7. def one_layer_while_loop_lstm(inp, hidden, params, has_biases, reverse=False):
  8. """
  9. 1 layer fn for while loop LSTM
  10. Args:
  11. inp: Input tensor of shape (seq_len, batch, input_size)
  12. hidden: Tuple of (hx, cx) hidden states
  13. params: List of weight and bias tensors
  14. has_biases: Whether biases are included
  15. reverse: Whether to process sequence in reverse
  16. Returns:
  17. Tuple of (output, (final_hx, final_cx))
  18. """
  19. ih_weight = params[0]
  20. hh_weight = params[1]
  21. ih_bias = params[2] if has_biases else None
  22. hh_bias = params[3] if has_biases else None
  23. hr_weight = (
  24. params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
  25. )
  26. hx = hidden[0].unsqueeze(0)
  27. cx = hidden[1].unsqueeze(0)
  28. precomputed_input = torch.nn.functional.linear(inp, ih_weight, ih_bias)
  29. precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
  30. # while loop rewrite
  31. step_output = torch.empty(
  32. precomputed_input.size(0),
  33. *tuple(hx.shape[1:]),
  34. dtype=hx.dtype,
  35. device=hx.device,
  36. )
  37. def cond_fn(i, out, hx, cx):
  38. return i < precomputed_input.size(0)
  39. def body_fn(idx, out, hx, cx):
  40. # Extract the integer value from idx and constrain it for data-dependent indexing
  41. i = idx.item()
  42. torch._check_is_size(i)
  43. torch._check_is_size(i, max=precomputed_input.size(0) - 1)
  44. hx, cx = lstm_cell(
  45. precomputed_input[i], hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=2
  46. )
  47. out = out.clone()
  48. # Squeeze the first dimension before storing (lstm_cell preserves the unsqueezed dim)
  49. out[i] = hx.squeeze(0)
  50. return idx + 1, out, hx, cx
  51. cnt = torch.tensor(0, dtype=torch.int64)
  52. _, out, final_hx, final_cx = while_loop(
  53. cond_fn, body_fn, [cnt, step_output, hx, cx]
  54. )
  55. if reverse:
  56. out = out.flip(0)
  57. # Use squeeze(1) to match original implementation
  58. return out, (final_hx.squeeze(1), final_cx.squeeze(1))
  59. def lstm_while_loop_impl(
  60. input,
  61. hx,
  62. params,
  63. has_biases,
  64. num_layers,
  65. dropout,
  66. train,
  67. bidirectional,
  68. batch_first,
  69. ):
  70. """
  71. LSTM implementation using while_loop for export compatibility.
  72. This is a drop-in replacement for the default LSTM decomposition that uses
  73. while_loop instead of Python loops, making it more suitable for torch.export.
  74. Args:
  75. input: Input tensor
  76. hx: Tuple of (h0, c0) hidden states
  77. params: List of weight and bias tensors
  78. has_biases: Whether biases are included
  79. num_layers: Number of LSTM layers
  80. dropout: Dropout probability
  81. train: Training mode
  82. bidirectional: Whether to use bidirectional LSTM
  83. batch_first: Whether batch dimension is first
  84. Returns:
  85. Tuple of (output, h_n, c_n)
  86. """
  87. if len(hx) != 2:
  88. raise AssertionError("lstm expects two hidden states")
  89. params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
  90. hidden = list(zip(hx[0], hx[1]))
  91. layer_fn = one_layer_while_loop_lstm
  92. out, final_hiddens = _rnn_helper(
  93. input,
  94. hidden,
  95. params,
  96. has_biases,
  97. num_layers,
  98. dropout,
  99. train,
  100. bidirectional,
  101. batch_first,
  102. layer_fn,
  103. )
  104. final_hiddens = list(zip(*final_hiddens))
  105. return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
  106. def one_layer_while_loop_gru(inp, hidden, params, has_biases, reverse=False):
  107. """
  108. 1 layer fn for while loop GRU
  109. Args:
  110. inp: Input tensor of shape (seq_len, batch, input_size)
  111. hidden: Hidden state tensor
  112. params: List of weight and bias tensors
  113. has_biases: Whether biases are included
  114. reverse: Whether to process sequence in reverse
  115. Returns:
  116. Tuple of (output, final_hidden)
  117. """
  118. ih_weight = params[0]
  119. hh_weight = params[1]
  120. ih_bias = params[2] if has_biases else None
  121. hh_bias = params[3] if has_biases else None
  122. precomputed_input = torch.nn.functional.linear(inp, ih_weight, ih_bias)
  123. precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
  124. cur_hidden = hidden.unsqueeze(0)
  125. # while loop rewrite
  126. step_output = torch.empty(
  127. precomputed_input.size(0),
  128. *tuple(cur_hidden.shape[1:]),
  129. dtype=cur_hidden.dtype,
  130. device=cur_hidden.device,
  131. )
  132. def cond_fn(i, out, cur_hidden):
  133. return i < precomputed_input.size(0)
  134. def body_fn(idx, out, cur_hidden):
  135. # Extract the integer value from idx and constrain it for data-dependent indexing
  136. i = idx.item()
  137. torch._check_is_size(i)
  138. torch._check_is_size(i, max=precomputed_input.size(0) - 1)
  139. cur_hidden = gru_cell(
  140. precomputed_input[i], cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias
  141. )
  142. out = out.clone()
  143. out[i] = cur_hidden.squeeze(0)
  144. return idx + 1, out, cur_hidden
  145. cnt = torch.tensor(0, dtype=torch.int64)
  146. _, out, final_hidden = while_loop(cond_fn, body_fn, [cnt, step_output, cur_hidden])
  147. if reverse:
  148. out = out.flip(0)
  149. return out, final_hidden.squeeze(0)
  150. def gru_while_loop_impl(
  151. input,
  152. hx,
  153. params,
  154. has_biases,
  155. num_layers,
  156. dropout,
  157. train,
  158. bidirectional,
  159. batch_first,
  160. ):
  161. """
  162. GRU implementation using while_loop for export compatibility.
  163. This is a drop-in replacement for the default GRU decomposition that uses
  164. while_loop instead of Python loops, making it more suitable for torch.export.
  165. Args:
  166. input: Input tensor
  167. hx: Hidden state tensor
  168. params: List of weight and bias tensors
  169. has_biases: Whether biases are included
  170. num_layers: Number of GRU layers
  171. dropout: Dropout probability
  172. train: Training mode
  173. bidirectional: Whether to use bidirectional GRU
  174. batch_first: Whether batch dimension is first
  175. Returns:
  176. Tuple of (output, h_n)
  177. """
  178. params = gather_params(params, has_biases, False)
  179. hidden = list(hx.unbind(0))
  180. layer_fn = one_layer_while_loop_gru
  181. out, final_hiddens = _rnn_helper(
  182. input,
  183. hidden,
  184. params,
  185. has_biases,
  186. num_layers,
  187. dropout,
  188. train,
  189. bidirectional,
  190. batch_first,
  191. layer_fn,
  192. )
  193. return out, torch.stack(final_hiddens, 0)
  194. @contextlib.contextmanager
  195. def _register_rnn_while_loop_decomposition(
  196. rnn_op, rnn_impl
  197. ) -> Generator[None, None, None]:
  198. """
  199. Generic context manager for registering while_loop-based RNN decompositions.
  200. Args:
  201. rnn_op: The aten operation to patch (e.g., torch.ops.aten.lstm.input)
  202. rnn_impl: The while_loop-based implementation function
  203. Note:
  204. This is an internal helper. Use register_lstm_while_loop_decomposition()
  205. or register_gru_while_loop_decomposition() instead.
  206. """
  207. registry = global_decomposition_table["post_autograd"]
  208. # Save the original decomposition if it exists
  209. original_decomp = registry.get(rnn_op, None)
  210. # Save the original py_kernel if it exists
  211. original_py_kernel = rnn_op.py_kernels.get(
  212. torch._C.DispatchKey.CompositeImplicitAutograd, None
  213. )
  214. try:
  215. # Register our while_loop-based implementation
  216. registry[rnn_op] = rnn_impl
  217. rnn_op.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd] = rnn_impl
  218. yield
  219. finally:
  220. # Restore the original decomposition
  221. if original_decomp is not None:
  222. registry[rnn_op] = original_decomp
  223. else:
  224. # If there was no original, remove our registration
  225. registry.pop(rnn_op, None)
  226. # Restore the original py_kernel
  227. if original_py_kernel is not None:
  228. rnn_op.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd] = (
  229. original_py_kernel
  230. )
  231. else:
  232. # If there was no original, remove our registration
  233. rnn_op.py_kernels.pop(torch._C.DispatchKey.CompositeImplicitAutograd, None)
  234. @contextlib.contextmanager
  235. def register_lstm_while_loop_decomposition() -> Generator[None, None, None]:
  236. """
  237. Context manager that temporarily registers the while_loop-based LSTM decomposition.
  238. The while_loop-based decomposition is more suitable for export and graph-based
  239. execution, as it avoids Python control flow that cannot be captured in the graph.
  240. This should support dynamic sequence lengths, however as while_loop does not
  241. support Autograd yet, an ExportedProgram created with this will not be trainable.
  242. Usage::
  243. from torch.export._patches import register_lstm_while_loop_decomposition
  244. from torch.export import export
  245. with register_lstm_while_loop_decomposition():
  246. # Export your model with LSTM
  247. ep = export(model, (x, h0, c0))
  248. Note:
  249. This context manager temporarily modifies the global decomposition table
  250. and py_kernels registration. The original registrations are restored when
  251. exiting the context.
  252. """
  253. with _register_rnn_while_loop_decomposition(
  254. torch.ops.aten.lstm.input, lstm_while_loop_impl
  255. ):
  256. yield
  257. @contextlib.contextmanager
  258. def register_gru_while_loop_decomposition() -> Generator[None, None, None]:
  259. """
  260. Context manager that temporarily registers the while_loop-based GRU decomposition.
  261. The while_loop-based decomposition is more suitable for export and graph-based
  262. execution, as it avoids Python control flow that cannot be captured in the graph.
  263. This should support dynamic sequence lengths, however as while_loop does not
  264. support Autograd yet, an ExportedProgram created with this will not be trainable.
  265. Usage::
  266. from torch.export._patches import register_gru_while_loop_decomposition
  267. from torch.export import export
  268. with register_gru_while_loop_decomposition():
  269. # Export your model with GRU
  270. ep = export(model, (x, h0))
  271. Note:
  272. This context manager temporarily modifies the global decomposition table
  273. and py_kernels registration. The original registrations are restored when
  274. exiting the context.
  275. """
  276. with _register_rnn_while_loop_decomposition(
  277. torch.ops.aten.gru.input, gru_while_loop_impl
  278. ):
  279. yield