| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- import contextlib
- from collections.abc import Generator
- import torch
- from torch._decomp import global_decomposition_table
- from torch._decomp.decompositions import _rnn_helper, gather_params, gru_cell, lstm_cell
- from torch._higher_order_ops.while_loop import while_loop
- def one_layer_while_loop_lstm(inp, hidden, params, has_biases, reverse=False):
- """
- 1 layer fn for while loop LSTM
- Args:
- inp: Input tensor of shape (seq_len, batch, input_size)
- hidden: Tuple of (hx, cx) hidden states
- params: List of weight and bias tensors
- has_biases: Whether biases are included
- reverse: Whether to process sequence in reverse
- Returns:
- Tuple of (output, (final_hx, final_cx))
- """
- ih_weight = params[0]
- hh_weight = params[1]
- ih_bias = params[2] if has_biases else None
- hh_bias = params[3] if has_biases else None
- hr_weight = (
- params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
- )
- hx = hidden[0].unsqueeze(0)
- cx = hidden[1].unsqueeze(0)
- precomputed_input = torch.nn.functional.linear(inp, ih_weight, ih_bias)
- precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
- # while loop rewrite
- step_output = torch.empty(
- precomputed_input.size(0),
- *tuple(hx.shape[1:]),
- dtype=hx.dtype,
- device=hx.device,
- )
- def cond_fn(i, out, hx, cx):
- return i < precomputed_input.size(0)
- def body_fn(idx, out, hx, cx):
- # Extract the integer value from idx and constrain it for data-dependent indexing
- i = idx.item()
- torch._check_is_size(i)
- torch._check_is_size(i, max=precomputed_input.size(0) - 1)
- hx, cx = lstm_cell(
- precomputed_input[i], hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=2
- )
- out = out.clone()
- # Squeeze the first dimension before storing (lstm_cell preserves the unsqueezed dim)
- out[i] = hx.squeeze(0)
- return idx + 1, out, hx, cx
- cnt = torch.tensor(0, dtype=torch.int64)
- _, out, final_hx, final_cx = while_loop(
- cond_fn, body_fn, [cnt, step_output, hx, cx]
- )
- if reverse:
- out = out.flip(0)
- # Use squeeze(1) to match original implementation
- return out, (final_hx.squeeze(1), final_cx.squeeze(1))
- def lstm_while_loop_impl(
- input,
- hx,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- ):
- """
- LSTM implementation using while_loop for export compatibility.
- This is a drop-in replacement for the default LSTM decomposition that uses
- while_loop instead of Python loops, making it more suitable for torch.export.
- Args:
- input: Input tensor
- hx: Tuple of (h0, c0) hidden states
- params: List of weight and bias tensors
- has_biases: Whether biases are included
- num_layers: Number of LSTM layers
- dropout: Dropout probability
- train: Training mode
- bidirectional: Whether to use bidirectional LSTM
- batch_first: Whether batch dimension is first
- Returns:
- Tuple of (output, h_n, c_n)
- """
- if len(hx) != 2:
- raise AssertionError("lstm expects two hidden states")
- params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
- hidden = list(zip(hx[0], hx[1]))
- layer_fn = one_layer_while_loop_lstm
- out, final_hiddens = _rnn_helper(
- input,
- hidden,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- layer_fn,
- )
- final_hiddens = list(zip(*final_hiddens))
- return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
- def one_layer_while_loop_gru(inp, hidden, params, has_biases, reverse=False):
- """
- 1 layer fn for while loop GRU
- Args:
- inp: Input tensor of shape (seq_len, batch, input_size)
- hidden: Hidden state tensor
- params: List of weight and bias tensors
- has_biases: Whether biases are included
- reverse: Whether to process sequence in reverse
- Returns:
- Tuple of (output, final_hidden)
- """
- ih_weight = params[0]
- hh_weight = params[1]
- ih_bias = params[2] if has_biases else None
- hh_bias = params[3] if has_biases else None
- precomputed_input = torch.nn.functional.linear(inp, ih_weight, ih_bias)
- precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
- cur_hidden = hidden.unsqueeze(0)
- # while loop rewrite
- step_output = torch.empty(
- precomputed_input.size(0),
- *tuple(cur_hidden.shape[1:]),
- dtype=cur_hidden.dtype,
- device=cur_hidden.device,
- )
- def cond_fn(i, out, cur_hidden):
- return i < precomputed_input.size(0)
- def body_fn(idx, out, cur_hidden):
- # Extract the integer value from idx and constrain it for data-dependent indexing
- i = idx.item()
- torch._check_is_size(i)
- torch._check_is_size(i, max=precomputed_input.size(0) - 1)
- cur_hidden = gru_cell(
- precomputed_input[i], cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias
- )
- out = out.clone()
- out[i] = cur_hidden.squeeze(0)
- return idx + 1, out, cur_hidden
- cnt = torch.tensor(0, dtype=torch.int64)
- _, out, final_hidden = while_loop(cond_fn, body_fn, [cnt, step_output, cur_hidden])
- if reverse:
- out = out.flip(0)
- return out, final_hidden.squeeze(0)
- def gru_while_loop_impl(
- input,
- hx,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- ):
- """
- GRU implementation using while_loop for export compatibility.
- This is a drop-in replacement for the default GRU decomposition that uses
- while_loop instead of Python loops, making it more suitable for torch.export.
- Args:
- input: Input tensor
- hx: Hidden state tensor
- params: List of weight and bias tensors
- has_biases: Whether biases are included
- num_layers: Number of GRU layers
- dropout: Dropout probability
- train: Training mode
- bidirectional: Whether to use bidirectional GRU
- batch_first: Whether batch dimension is first
- Returns:
- Tuple of (output, h_n)
- """
- params = gather_params(params, has_biases, False)
- hidden = list(hx.unbind(0))
- layer_fn = one_layer_while_loop_gru
- out, final_hiddens = _rnn_helper(
- input,
- hidden,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- layer_fn,
- )
- return out, torch.stack(final_hiddens, 0)
- @contextlib.contextmanager
- def _register_rnn_while_loop_decomposition(
- rnn_op, rnn_impl
- ) -> Generator[None, None, None]:
- """
- Generic context manager for registering while_loop-based RNN decompositions.
- Args:
- rnn_op: The aten operation to patch (e.g., torch.ops.aten.lstm.input)
- rnn_impl: The while_loop-based implementation function
- Note:
- This is an internal helper. Use register_lstm_while_loop_decomposition()
- or register_gru_while_loop_decomposition() instead.
- """
- registry = global_decomposition_table["post_autograd"]
- # Save the original decomposition if it exists
- original_decomp = registry.get(rnn_op, None)
- # Save the original py_kernel if it exists
- original_py_kernel = rnn_op.py_kernels.get(
- torch._C.DispatchKey.CompositeImplicitAutograd, None
- )
- try:
- # Register our while_loop-based implementation
- registry[rnn_op] = rnn_impl
- rnn_op.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd] = rnn_impl
- yield
- finally:
- # Restore the original decomposition
- if original_decomp is not None:
- registry[rnn_op] = original_decomp
- else:
- # If there was no original, remove our registration
- registry.pop(rnn_op, None)
- # Restore the original py_kernel
- if original_py_kernel is not None:
- rnn_op.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd] = (
- original_py_kernel
- )
- else:
- # If there was no original, remove our registration
- rnn_op.py_kernels.pop(torch._C.DispatchKey.CompositeImplicitAutograd, None)
- @contextlib.contextmanager
- def register_lstm_while_loop_decomposition() -> Generator[None, None, None]:
- """
- Context manager that temporarily registers the while_loop-based LSTM decomposition.
- The while_loop-based decomposition is more suitable for export and graph-based
- execution, as it avoids Python control flow that cannot be captured in the graph.
- This should support dynamic sequence lengths, however as while_loop does not
- support Autograd yet, an ExportedProgram created with this will not be trainable.
- Usage::
- from torch.export._patches import register_lstm_while_loop_decomposition
- from torch.export import export
- with register_lstm_while_loop_decomposition():
- # Export your model with LSTM
- ep = export(model, (x, h0, c0))
- Note:
- This context manager temporarily modifies the global decomposition table
- and py_kernels registration. The original registrations are restored when
- exiting the context.
- """
- with _register_rnn_while_loop_decomposition(
- torch.ops.aten.lstm.input, lstm_while_loop_impl
- ):
- yield
- @contextlib.contextmanager
- def register_gru_while_loop_decomposition() -> Generator[None, None, None]:
- """
- Context manager that temporarily registers the while_loop-based GRU decomposition.
- The while_loop-based decomposition is more suitable for export and graph-based
- execution, as it avoids Python control flow that cannot be captured in the graph.
- This should support dynamic sequence lengths, however as while_loop does not
- support Autograd yet, an ExportedProgram created with this will not be trainable.
- Usage::
- from torch.export._patches import register_gru_while_loop_decomposition
- from torch.export import export
- with register_gru_while_loop_decomposition():
- # Export your model with GRU
- ep = export(model, (x, h0))
- Note:
- This context manager temporarily modifies the global decomposition table
- and py_kernels registration. The original registrations are restored when
- exiting the context.
- """
- with _register_rnn_while_loop_decomposition(
- torch.ops.aten.gru.input, gru_while_loop_impl
- ):
- yield
|