| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
7771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886 |
- # mypy: allow-untyped-decorators
- # mypy: allow-untyped-defs
- import math
- import numbers
- import warnings
- import weakref
- from typing import overload
- from typing_extensions import deprecated
- import torch
- from torch import _VF, Tensor
- from torch.nn import init
- from torch.nn.parameter import Parameter
- from torch.nn.utils.rnn import PackedSequence
- from .module import Module
# Public API of this module: the multi-step recurrent layers first, followed
# by their single-step "cell" counterparts.
__all__ = [
    # Sequence-level modules.
    "RNNBase",
    "RNN",
    "LSTM",
    "GRU",
    # Single time-step modules.
    "RNNCellBase",
    "RNNCell",
    "LSTMCell",
    "GRUCell",
]
# Maps an Elman-RNN mode string to the `_VF` kernel that implements it.
_rnn_impls = dict(
    RNN_TANH=_VF.rnn_tanh,
    RNN_RELU=_VF.rnn_relu,
)
def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    """Return ``tensor`` with its entries along ``dim`` reordered by ``permutation``.

    Args:
        tensor: The tensor to reorder.
        permutation: Index tensor passed to ``index_select``; entry ``i`` of the
            result along ``dim`` is entry ``permutation[i]`` of ``tensor``.
        dim: The dimension along which to select. Default: 1

    Returns:
        A new tensor produced by ``index_select``.
    """
    return torch.index_select(tensor, dim, permutation)
@deprecated(
    "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead",
    category=FutureWarning,
)
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    """Deprecated public alias of :func:`_apply_permutation`.

    Calling it emits a ``FutureWarning`` (via the ``deprecated`` decorator) and
    then forwards all arguments unchanged to the private helper.

    Args:
        tensor: The tensor to reorder.
        permutation: Index tensor selecting entries along ``dim``.
        dim: The dimension along which to select. Default: 1

    Returns:
        The reordered tensor.
    """
    return _apply_permutation(tensor, permutation, dim=dim)
- class RNNBase(Module):
- r"""Base class for RNN modules (RNN, LSTM, GRU).
- Implements aspects of RNNs shared by the RNN, LSTM, and GRU classes, such as module initialization
- and utility methods for parameter storage management.
- .. note::
- The forward method is not implemented by the RNNBase class.
- .. note::
- LSTM and GRU classes override some methods implemented by RNNBase.
- """
- __constants__ = [
- "mode",
- "input_size",
- "hidden_size",
- "num_layers",
- "bias",
- "batch_first",
- "dropout",
- "bidirectional",
- "proj_size",
- ]
- __jit_unused_properties__ = ["all_weights"]
- mode: str
- input_size: int
- hidden_size: int
- num_layers: int
- bias: bool
- batch_first: bool
- dropout: float
- bidirectional: bool
- proj_size: int
- def __init__(
- self,
- mode: str,
- input_size: int,
- hidden_size: int,
- num_layers: int = 1,
- bias: bool = True,
- batch_first: bool = False,
- dropout: float = 0.0,
- bidirectional: bool = False,
- proj_size: int = 0,
- device=None,
- dtype=None,
- ) -> None:
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__()
- self.mode = mode
- self.input_size = input_size
- self.hidden_size = hidden_size
- self.num_layers = num_layers
- self.bias = bias
- self.batch_first = batch_first
- self.dropout = float(dropout)
- self.bidirectional = bidirectional
- self.proj_size = proj_size
- self._flat_weight_refs: list[weakref.ReferenceType[Parameter] | None] = []
- num_directions = 2 if bidirectional else 1
- if (
- not isinstance(dropout, numbers.Number)
- or not 0 <= dropout <= 1
- or isinstance(dropout, bool)
- ):
- raise ValueError(
- "dropout should be a number in range [0, 1] "
- "representing the probability of an element being "
- "zeroed"
- )
- if dropout > 0 and num_layers == 1:
- warnings.warn(
- "dropout option adds dropout after all but last "
- "recurrent layer, so non-zero dropout expects "
- f"num_layers greater than 1, but got dropout={dropout} and "
- f"num_layers={num_layers}",
- stacklevel=2,
- )
- if not isinstance(bias, bool):
- raise TypeError(f"bias should be of type bool, got: {type(bias).__name__}")
- if not isinstance(batch_first, bool):
- raise TypeError(
- f"batch_first should be of type bool, got: {type(batch_first).__name__}"
- )
- if not isinstance(input_size, int):
- raise TypeError(
- f"input_size should be of type int, got: {type(input_size).__name__}"
- )
- if input_size <= 0:
- raise ValueError("input_size must be greater than zero")
- if not isinstance(hidden_size, int):
- raise TypeError(
- f"hidden_size should be of type int, got: {type(hidden_size).__name__}"
- )
- if hidden_size <= 0:
- raise ValueError("hidden_size must be greater than zero")
- if num_layers <= 0:
- raise ValueError("num_layers must be greater than zero")
- if proj_size < 0:
- raise ValueError(
- "proj_size should be a positive integer or zero to disable projections"
- )
- if proj_size >= hidden_size:
- raise ValueError("proj_size has to be smaller than hidden_size")
- if mode == "LSTM":
- gate_size = 4 * hidden_size
- elif mode == "GRU":
- gate_size = 3 * hidden_size
- elif mode == "RNN_TANH":
- gate_size = hidden_size
- elif mode == "RNN_RELU":
- gate_size = hidden_size
- else:
- raise ValueError("Unrecognized RNN mode: " + mode)
- self._flat_weights_names = []
- self._all_weights = []
- for layer in range(num_layers):
- for direction in range(num_directions):
- real_hidden_size = proj_size if proj_size > 0 else hidden_size
- layer_input_size = (
- input_size if layer == 0 else real_hidden_size * num_directions
- )
- w_ih = Parameter(
- torch.empty((gate_size, layer_input_size), **factory_kwargs)
- )
- w_hh = Parameter(
- torch.empty((gate_size, real_hidden_size), **factory_kwargs)
- )
- b_ih = Parameter(torch.empty(gate_size, **factory_kwargs))
- # Second bias vector included for CuDNN compatibility. Only one
- # bias vector is needed in standard definition.
- b_hh = Parameter(torch.empty(gate_size, **factory_kwargs))
- layer_params: tuple[Tensor, ...] = ()
- if self.proj_size == 0:
- if bias:
- layer_params = (w_ih, w_hh, b_ih, b_hh)
- else:
- layer_params = (w_ih, w_hh)
- else:
- w_hr = Parameter(
- torch.empty((proj_size, hidden_size), **factory_kwargs)
- )
- if bias:
- layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)
- else:
- layer_params = (w_ih, w_hh, w_hr)
- suffix = "_reverse" if direction == 1 else ""
- param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"]
- if bias:
- param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"]
- if self.proj_size > 0:
- param_names += ["weight_hr_l{}{}"]
- param_names = [x.format(layer, suffix) for x in param_names]
- for name, param in zip(param_names, layer_params, strict=True):
- setattr(self, name, param)
- self._flat_weights_names.extend(param_names)
- self._all_weights.append(param_names)
- self._init_flat_weights()
- self.reset_parameters()
- def _init_flat_weights(self) -> None:
- self._flat_weights = [
- getattr(self, wn) if hasattr(self, wn) else None
- for wn in self._flat_weights_names
- ]
- self._flat_weight_refs = [
- weakref.ref(w) if w is not None else None for w in self._flat_weights
- ]
- self.flatten_parameters()
- def __setattr__(self, attr, value) -> None:
- if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names:
- # keep self._flat_weights up to date if you do self.weight = ...
- idx = self._flat_weights_names.index(attr)
- self._flat_weights[idx] = value
- super().__setattr__(attr, value)
- def flatten_parameters(self) -> None:
- """Reset parameter data pointer so that they can use faster code paths.
- Right now, this works only if the module is on the GPU and cuDNN is enabled.
- Otherwise, it's a no-op.
- """
- # Short-circuits if _flat_weights is only partially instantiated
- if len(self._flat_weights) != len(self._flat_weights_names):
- return
- for w in self._flat_weights:
- if not isinstance(w, Tensor):
- return
- # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
- # or the tensors in _flat_weights are of different dtypes
- first_fw = self._flat_weights[0] # type: ignore[union-attr]
- dtype = first_fw.dtype # type: ignore[union-attr]
- for fw in self._flat_weights:
- if (
- not isinstance(fw, Tensor)
- or fw.dtype != dtype
- or not fw.is_cuda
- or not torch.backends.cudnn.is_acceptable(fw)
- ):
- return
- # If any parameters alias, we fall back to the slower, copying code path. This is
- # a sufficient check, because overlapping parameter buffers that don't completely
- # alias would break the assumptions of the uniqueness check in
- # Module.named_parameters().
- unique_data_ptrs = {
- p.data_ptr() # type: ignore[union-attr]
- for p in self._flat_weights
- }
- if len(unique_data_ptrs) != len(self._flat_weights):
- return
- with torch.cuda.device_of(first_fw):
- import torch.backends.cudnn.rnn as rnn
- # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
- # an inplace operation on self._flat_weights
- with torch.no_grad():
- if torch._use_cudnn_rnn_flatten_weight():
- num_weights = 4 if self.bias else 2
- if self.proj_size > 0:
- num_weights += 1
- torch._cudnn_rnn_flatten_weight(
- self._flat_weights, # type: ignore[arg-type]
- num_weights,
- self.input_size,
- rnn.get_cudnn_mode(self.mode),
- self.hidden_size,
- self.proj_size,
- self.num_layers,
- self.batch_first,
- bool(self.bidirectional),
- )
- def _apply(self, fn, recurse=True):
- self._flat_weight_refs = []
- ret = super()._apply(fn, recurse)
- # Resets _flat_weights
- # Note: be v. careful before removing this, as 3rd party device types
- # likely rely on this behavior to properly .to() modules like LSTM.
- self._init_flat_weights()
- return ret
- def reset_parameters(self) -> None:
- stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
- for weight in self.parameters():
- init.uniform_(weight, -stdv, stdv)
- def check_input(self, input: Tensor, batch_sizes: Tensor | None) -> None:
- if not torch.jit.is_scripting():
- if (
- input.dtype != self._flat_weights[0].dtype # type: ignore[union-attr]
- and not torch._C._is_any_autocast_enabled()
- ):
- raise ValueError(
- f"RNN input dtype ({input.dtype}) does not match weight dtype ({self._flat_weights[0].dtype}). " # type: ignore[union-attr]
- f"Convert input: input.to({self._flat_weights[0].dtype}), or convert model: model.to({input.dtype})" # type: ignore[union-attr]
- )
- expected_input_dim = 2 if batch_sizes is not None else 3
- if input.dim() != expected_input_dim:
- raise RuntimeError(
- f"input must have {expected_input_dim} dimensions, got {input.dim()}"
- )
- if self.input_size != input.size(-1):
- raise RuntimeError(
- f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
- )
- def get_expected_hidden_size(
- self, input: Tensor, batch_sizes: Tensor | None
- ) -> tuple[int, int, int]:
- if batch_sizes is not None:
- mini_batch = int(batch_sizes[0])
- else:
- mini_batch = input.size(0) if self.batch_first else input.size(1)
- num_directions = 2 if self.bidirectional else 1
- if self.proj_size > 0:
- expected_hidden_size = (
- self.num_layers * num_directions,
- mini_batch,
- self.proj_size,
- )
- else:
- expected_hidden_size = (
- self.num_layers * num_directions,
- mini_batch,
- self.hidden_size,
- )
- return expected_hidden_size
- def check_hidden_size(
- self,
- hx: Tensor,
- expected_hidden_size: tuple[int, int, int],
- msg: str = "Expected hidden size {}, got {}",
- ) -> None:
- if hx.size() != expected_hidden_size:
- raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
- def _weights_have_changed(self):
- # Returns True if the weight tensors have changed since the last forward pass.
- # This is the case when used with torch.func.functional_call(), for example.
- weights_changed = False
- for ref, name in zip(
- self._flat_weight_refs, self._flat_weights_names, strict=True
- ):
- weight = getattr(self, name) if hasattr(self, name) else None
- if weight is not None and ref is not None and ref() is not weight:
- weights_changed = True
- break
- return weights_changed
- def check_forward_args(
- self, input: Tensor, hidden: Tensor, batch_sizes: Tensor | None
- ) -> None:
- self.check_input(input, batch_sizes)
- expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
- self.check_hidden_size(hidden, expected_hidden_size)
- def permute_hidden(self, hx: Tensor, permutation: Tensor | None):
- if permutation is None:
- return hx
- return _apply_permutation(hx, permutation)
- def extra_repr(self) -> str:
- s = "{input_size}, {hidden_size}"
- if self.proj_size != 0:
- s += ", proj_size={proj_size}"
- if self.num_layers != 1:
- s += ", num_layers={num_layers}"
- if self.bias is not True:
- s += ", bias={bias}"
- if self.batch_first is not False:
- s += ", batch_first={batch_first}"
- if self.dropout != 0:
- s += ", dropout={dropout}"
- if self.bidirectional is not False:
- s += ", bidirectional={bidirectional}"
- return s.format(**self.__dict__)
- def _update_flat_weights(self) -> None:
- if not torch.jit.is_scripting():
- if self._weights_have_changed():
- self._init_flat_weights()
- def __getstate__(self):
- # If weights have been changed, update the _flat_weights in __getstate__ here.
- self._update_flat_weights()
- # Don't serialize the weight references.
- state = self.__dict__.copy()
- del state["_flat_weight_refs"]
- return state
- def __setstate__(self, d):
- super().__setstate__(d)
- if "all_weights" in d:
- self._all_weights = d["all_weights"]
- # In PyTorch 1.8 we added a proj_size member variable to LSTM.
- # LSTMs that were serialized via torch.save(module) before PyTorch 1.8
- # don't have it, so to preserve compatibility we set proj_size here.
- if "proj_size" not in d:
- self.proj_size = 0
- if not isinstance(self._all_weights[0][0], str):
- num_layers = self.num_layers
- num_directions = 2 if self.bidirectional else 1
- self._flat_weights_names = []
- self._all_weights = []
- for layer in range(num_layers):
- for direction in range(num_directions):
- suffix = "_reverse" if direction == 1 else ""
- weights = [
- "weight_ih_l{}{}",
- "weight_hh_l{}{}",
- "bias_ih_l{}{}",
- "bias_hh_l{}{}",
- "weight_hr_l{}{}",
- ]
- weights = [x.format(layer, suffix) for x in weights]
- if self.bias:
- if self.proj_size > 0:
- self._all_weights += [weights]
- self._flat_weights_names.extend(weights)
- else:
- self._all_weights += [weights[:4]]
- self._flat_weights_names.extend(weights[:4])
- else:
- if self.proj_size > 0:
- self._all_weights += [weights[:2]] + [weights[-1:]]
- self._flat_weights_names.extend(
- weights[:2] + [weights[-1:]]
- )
- else:
- self._all_weights += [weights[:2]]
- self._flat_weights_names.extend(weights[:2])
- self._flat_weights = [
- getattr(self, wn) if hasattr(self, wn) else None
- for wn in self._flat_weights_names
- ]
- self._flat_weight_refs = [
- weakref.ref(w) if w is not None else None for w in self._flat_weights
- ]
- @property
- def all_weights(self) -> list[list[Parameter]]:
- return [
- [getattr(self, weight) for weight in weights]
- for weights in self._all_weights
- ]
- def _replicate_for_data_parallel(self):
- replica = super()._replicate_for_data_parallel()
- # Need to copy these caches, otherwise the replica will share the same
- # flat weights list.
- replica._flat_weights = replica._flat_weights[:]
- replica._flat_weights_names = replica._flat_weights_names[:]
- return replica
- class RNN(RNNBase):
- r"""__init__(input_size,hidden_size,num_layers=1,nonlinearity='tanh',bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None)
- Apply a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}`
- non-linearity to an input sequence. For each element in the input sequence,
- each layer computes the following function:
- .. math::
- h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})
- where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
- the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
- previous layer at time `t-1` or the initial hidden state at time `0`.
- If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.
- .. code-block:: python
- # Efficient implementation equivalent to the following with bidirectional=False
- rnn = nn.RNN(input_size, hidden_size, num_layers)
- params = dict(rnn.named_parameters())
- def forward(x, hx=None, batch_first=False):
- if batch_first:
- x = x.transpose(0, 1)
- seq_len, batch_size, _ = x.size()
- if hx is None:
- hx = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size)
- h_t_minus_1 = hx.clone()
- h_t = hx.clone()
- output = []
- for t in range(seq_len):
- for layer in range(rnn.num_layers):
- input_t = x[t] if layer == 0 else h_t[layer - 1]
- h_t[layer] = torch.tanh(
- input_t @ params[f"weight_ih_l{layer}"].T
- + h_t_minus_1[layer] @ params[f"weight_hh_l{layer}"].T
- + params[f"bias_hh_l{layer}"]
- + params[f"bias_ih_l{layer}"]
- )
- output.append(h_t[-1].clone())
- h_t_minus_1 = h_t.clone()
- output = torch.stack(output)
- if batch_first:
- output = output.transpose(0, 1)
- return output, h_t
- Args:
- input_size: The number of expected features in the input `x`
- hidden_size: The number of features in the hidden state `h`
- num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
- would mean stacking two RNNs together to form a `stacked RNN`,
- with the second RNN taking in outputs of the first RNN and
- computing the final results. Default: 1
- nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
- bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
- Default: ``True``
- batch_first: If ``True``, then the input and output tensors are provided
- as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
- Note that this does not apply to hidden or cell states. See the
- Inputs/Outputs sections below for details. Default: ``False``
- dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
- RNN layer except the last layer, with dropout probability equal to
- :attr:`dropout`. Default: 0
- bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``
- Inputs: input, hx
- * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
- :math:`(L, N, H_{in})` when ``batch_first=False`` or
- :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
- the input sequence. The input can also be a packed variable length sequence.
- See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
- :func:`torch.nn.utils.rnn.pack_sequence` for details.
- * **hx**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
- :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden
- state for the input sequence batch. Defaults to zeros if not provided.
- where:
- .. math::
- \begin{aligned}
- N ={} & \text{batch size} \\
- L ={} & \text{sequence length} \\
- D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
- H_{in} ={} & \text{input\_size} \\
- H_{out} ={} & \text{hidden\_size}
- \end{aligned}
- Outputs: output, h_n
- * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
- :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
- :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
- `(h_t)` from the last layer of the RNN, for each `t`. If a
- :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
- will also be a packed sequence.
- * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
- :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
- for each element in the batch.
- Attributes:
- weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
- of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is
- `(hidden_size, num_directions * hidden_size)`
- weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
- of shape `(hidden_size, hidden_size)`
- bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
- of shape `(hidden_size)`
- bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
- of shape `(hidden_size)`
- .. note::
- All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
- where :math:`k = \frac{1}{\text{hidden\_size}}`
- .. note::
- For bidirectional RNNs, forward and backward are directions 0 and 1 respectively.
- Example of splitting the output layers when ``batch_first=False``:
- ``output.view(seq_len, batch, num_directions, hidden_size)``.
- .. note::
- ``batch_first`` argument is ignored for unbatched inputs.
- .. include:: ../cudnn_rnn_determinism.rst
- .. include:: ../cudnn_persistent_rnn.rst
- Examples::
- >>> rnn = nn.RNN(10, 20, 2)
- >>> input = torch.randn(5, 3, 10)
- >>> h0 = torch.randn(2, 3, 20)
- >>> output, hn = rnn(input, h0)
- """
- @overload
- def __init__(
- self,
- input_size: int,
- hidden_size: int,
- num_layers: int = 1,
- nonlinearity: str = "tanh",
- bias: bool = True,
- batch_first: bool = False,
- dropout: float = 0.0,
- bidirectional: bool = False,
- device=None,
- dtype=None,
- ) -> None: ...
- @overload
- def __init__(self, *args, **kwargs) -> None: ...
- def __init__(self, *args, **kwargs):
- if "proj_size" in kwargs:
- raise ValueError(
- "proj_size argument is only supported for LSTM, not RNN or GRU"
- )
- if len(args) > 3:
- self.nonlinearity = args[3]
- args = args[:3] + args[4:]
- else:
- self.nonlinearity = kwargs.pop("nonlinearity", "tanh")
- if self.nonlinearity == "tanh":
- mode = "RNN_TANH"
- elif self.nonlinearity == "relu":
- mode = "RNN_RELU"
- else:
- raise ValueError(
- f"Unknown nonlinearity '{self.nonlinearity}'. Select from 'tanh' or 'relu'."
- )
- super().__init__(mode, *args, **kwargs)
- @overload
- @torch._jit_internal._overload_method # noqa: F811
- # pyrefly: ignore [bad-override]
- def forward(
- self,
- input: Tensor,
- hx: Tensor | None = None,
- # pyrefly: ignore [bad-return]
- ) -> tuple[Tensor, Tensor]:
- pass
- @overload
- @torch._jit_internal._overload_method # noqa: F811
- def forward(
- self,
- input: PackedSequence,
- hx: Tensor | None = None,
- # pyrefly: ignore [bad-return]
- ) -> tuple[PackedSequence, Tensor]:
- pass
- def forward(self, input, hx=None): # noqa: F811
- """
- Runs the forward pass.
- """
- self._update_flat_weights()
- num_directions = 2 if self.bidirectional else 1
- orig_input = input
- if isinstance(orig_input, PackedSequence):
- input, batch_sizes, sorted_indices, unsorted_indices = input
- max_batch_size = batch_sizes[0]
- # script() is unhappy when max_batch_size is different type in cond branches, so we duplicate
- if hx is None:
- hx = torch.zeros(
- self.num_layers * num_directions,
- max_batch_size,
- self.hidden_size,
- dtype=input.dtype,
- device=input.device,
- )
- else:
- # Each batch of the hidden state should match the input sequence that
- # the user believes he/she is passing in.
- hx = self.permute_hidden(hx, sorted_indices)
- else:
- batch_sizes = None
- if input.dim() not in (2, 3):
- raise ValueError(
- f"RNN: Expected input to be 2D or 3D, got {input.dim()}D tensor instead"
- )
- is_batched = input.dim() == 3
- batch_dim = 0 if self.batch_first else 1
- if not is_batched:
- input = input.unsqueeze(batch_dim)
- if hx is not None:
- if hx.dim() != 2:
- raise RuntimeError(
- f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
- )
- hx = hx.unsqueeze(1)
- else:
- if hx is not None and hx.dim() != 3:
- raise RuntimeError(
- f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
- )
- max_batch_size = input.size(0) if self.batch_first else input.size(1)
- sorted_indices = None
- unsorted_indices = None
- if hx is None:
- hx = torch.zeros(
- self.num_layers * num_directions,
- max_batch_size,
- self.hidden_size,
- dtype=input.dtype,
- device=input.device,
- )
- else:
- # Each batch of the hidden state should match the input sequence that
- # the user believes he/she is passing in.
- hx = self.permute_hidden(hx, sorted_indices)
- if hx is None:
- raise AssertionError("hx must not be None")
- self.check_forward_args(input, hx, batch_sizes)
- if self.mode != "RNN_TANH" and self.mode != "RNN_RELU":
- raise AssertionError(f"mode must be RNN_TANH or RNN_RELU, got {self.mode}")
- if batch_sizes is None:
- if self.mode == "RNN_TANH":
- # pyrefly: ignore [no-matching-overload]
- result = _VF.rnn_tanh(
- input,
- hx,
- self._flat_weights, # type: ignore[arg-type]
- self.bias,
- self.num_layers,
- self.dropout,
- self.training,
- self.bidirectional,
- self.batch_first,
- )
- else:
- # pyrefly: ignore [no-matching-overload]
- result = _VF.rnn_relu(
- input,
- hx,
- self._flat_weights, # type: ignore[arg-type]
- self.bias,
- self.num_layers,
- self.dropout,
- self.training,
- self.bidirectional,
- self.batch_first,
- )
- else:
- if self.mode == "RNN_TANH":
- # pyrefly: ignore [no-matching-overload]
- result = _VF.rnn_tanh(
- input,
- batch_sizes,
- hx,
- self._flat_weights, # type: ignore[arg-type]
- self.bias,
- self.num_layers,
- self.dropout,
- self.training,
- self.bidirectional,
- )
- else:
- # pyrefly: ignore [no-matching-overload]
- result = _VF.rnn_relu(
- input,
- batch_sizes,
- hx,
- self._flat_weights, # type: ignore[arg-type]
- self.bias,
- self.num_layers,
- self.dropout,
- self.training,
- self.bidirectional,
- )
- output = result[0]
- hidden = result[1]
- if isinstance(orig_input, PackedSequence):
- output_packed = PackedSequence(
- output,
- # pyrefly: ignore [bad-argument-type]
- batch_sizes,
- sorted_indices,
- unsorted_indices,
- )
- return output_packed, self.permute_hidden(hidden, unsorted_indices)
- if not is_batched: # type: ignore[possibly-undefined]
- output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
- hidden = hidden.squeeze(1)
- return output, self.permute_hidden(hidden, unsorted_indices)
# XXX: The LSTM and GRU implementations differ from RNNBase because:
#   1. We want to support nn.LSTM and nn.GRU in TorchScript, and TorchScript in
#      its current state does not support the Python Union or Any types.
#   2. TorchScript's static typing does not allow a Function or Callable type in
#      Dict values, so we have to call _VF separately instead of using _rnn_impls.
#   3. This is a temporary, transitional arrangement that was adopted to make it
#      in time for the release.
#
# More discussion details in https://github.com/pytorch/pytorch/pull/23266
#
# TODO: remove the overriding implementations for LSTM and GRU when TorchScript
# supports expressing these two modules generally.
class LSTM(RNNBase):
    r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,proj_size=0,device=None,dtype=None)

    Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence.
    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
            c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
            h_t = o_t \odot \tanh(c_t) \\
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
    state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
    is the hidden state of the layer at time `t-1` or the initial hidden
    state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
    :math:`o_t` are the input, forget, cell, and output gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes
    the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from
    ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly).
    Second, the output hidden state of each layer will be multiplied by a learnable projection
    matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output
    of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact
    dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two LSTMs together to form a `stacked LSTM`,
            with the second LSTM taking in outputs of the first LSTM and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
            Note that this does not apply to hidden or cell states. See the
            Inputs/Outputs sections below for details.  Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            LSTM layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
        proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0

    Inputs: input, (h_0, c_0)
        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
          :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
          the input sequence.  The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the
          initial hidden state for each element in the input sequence.
          Defaults to zeros if (h_0, c_0) is not provided.
        * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
          initial cell state for each element in the input sequence.
          Defaults to zeros if (h_0, c_0) is not provided.

        where:

        .. math::
            \begin{aligned}
                N ={} & \text{batch size} \\
                L ={} & \text{sequence length} \\
                D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
                H_{in} ={} & \text{input\_size} \\
                H_{cell} ={} & \text{hidden\_size} \\
                H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
            \end{aligned}

    Outputs: output, (h_n, c_n)
        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
          `(h_t)` from the last layer of the LSTM, for each `t`. If a
          :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
          will also be a packed sequence. When ``bidirectional=True``, `output` will contain
          a concatenation of the forward and reverse hidden states at each time step in the sequence.
        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the
          final hidden state for each element in the sequence. When ``bidirectional=True``,
          `h_n` will contain a concatenation of the final forward and reverse hidden states, respectively.
        * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
          final cell state for each element in the sequence. When ``bidirectional=True``,
          `c_n` will contain a concatenation of the final forward and reverse cell states, respectively.

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If
            ``proj_size > 0`` was specified, the shape will be
            `(4*hidden_size, num_directions * proj_size)` for `k > 0`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0``
            was specified, the shape will be `(4*hidden_size, proj_size)`.
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`
        weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer
            of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was
            specified.
        weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction.
            Only present when ``bidirectional=True``.
        weight_hh_l[k]_reverse:  Analogous to `weight_hh_l[k]` for the reverse direction.
            Only present when ``bidirectional=True``.
        bias_ih_l[k]_reverse:  Analogous to `bias_ih_l[k]` for the reverse direction.
            Only present when ``bidirectional=True``.
        bias_hh_l[k]_reverse:  Analogous to `bias_hh_l[k]` for the reverse direction.
            Only present when ``bidirectional=True``.
        weight_hr_l[k]_reverse:  Analogous to `weight_hr_l[k]` for the reverse direction.
            Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified.

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. note::
        For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively.
        Example of splitting the output layers when ``batch_first=False``:
        ``output.view(seq_len, batch, num_directions, hidden_size)``.

    .. note::
        For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the
        former contains the final forward and reverse hidden states, while the latter contains the
        final forward hidden state and the initial reverse hidden state.

    .. note::
        ``batch_first`` argument is ignored for unbatched inputs.

    .. note::
        ``proj_size`` should be smaller than ``hidden_size``.

    .. include:: ../cudnn_rnn_determinism.rst

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
    """

    # Typed signature for type checkers; the runtime constructor below accepts
    # arbitrary positional/keyword arguments and hands them to the base class.
    @overload
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        proj_size: int = 0,
        device=None,
        dtype=None,
    ) -> None: ...

    @overload
    def __init__(self, *args, **kwargs) -> None: ...

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged; only the mode string "LSTM" is
        # fixed here.
        super().__init__("LSTM", *args, **kwargs)

    def get_expected_cell_size(
        self, input: Tensor, batch_sizes: Tensor | None
    ) -> tuple[int, int, int]:
        # Returns the shape the cell state is expected to have:
        # (num_layers * num_directions, mini_batch, hidden_size).
        # The last dimension is always ``hidden_size`` (never ``proj_size``),
        # matching how ``forward`` allocates the default ``c_zeros``.
        if batch_sizes is not None:
            # Packed input: the first entry of batch_sizes is used as the batch size.
            mini_batch = int(batch_sizes[0])
        else:
            # Padded input: the batch dimension depends on ``batch_first``.
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (
            self.num_layers * num_directions,
            mini_batch,
            self.hidden_size,
        )
        return expected_hidden_size

    # Validates ``input`` and both halves of the ``(h, c)`` state: ``hidden[0]``
    # against the expected hidden size and ``hidden[1]`` against the expected
    # cell size. The checks themselves live in helpers defined outside this class.
    #
    # In the future, we should prevent mypy from applying contravariance rules here.
    # See torch/nn/modules/module.py::_forward_unimplemented
    # pyrefly: ignore [bad-override]
    def check_forward_args(
        self,
        input: Tensor,
        hidden: tuple[Tensor, Tensor],  # type: ignore[override]
        batch_sizes: Tensor | None,
    ) -> None:
        self.check_input(input, batch_sizes)
        self.check_hidden_size(
            hidden[0],
            self.get_expected_hidden_size(input, batch_sizes),
            "Expected hidden[0] size {}, got {}",
        )
        self.check_hidden_size(
            hidden[1],
            self.get_expected_cell_size(input, batch_sizes),
            "Expected hidden[1] size {}, got {}",
        )

    # Applies ``permutation`` to both tensors of the ``(h, c)`` pair via
    # ``_apply_permutation``; returns ``hx`` untouched when ``permutation`` is None.
    #
    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    def permute_hidden(  # type: ignore[override]
        self,
        hx: tuple[Tensor, Tensor],
        permutation: Tensor | None,
    ) -> tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(
            hx[1], permutation
        )

    # Overload stub for plain Tensor input. The body must remain a bare ``pass``
    # (no docstring) because it is registered through ``_overload_method``.
    #
    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    @overload  # type: ignore[override]
    @torch._jit_internal._overload_method  # noqa: F811
    # pyrefly: ignore [bad-override]
    def forward(
        self,
        input: Tensor,
        hx: tuple[Tensor, Tensor] | None = None,
        # pyrefly: ignore [bad-return]
    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    # Overload stub for PackedSequence input; same constraint on the body.
    #
    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self,
        input: PackedSequence,
        hx: tuple[Tensor, Tensor] | None = None,
        # pyrefly: ignore [bad-return]
    ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    def forward(self, input, hx=None):  # noqa: F811
        # Runs the LSTM over ``input`` (a Tensor or a PackedSequence) and returns
        # ``(output, (h_n, c_n))``. The output is re-wrapped as a PackedSequence
        # when the input was packed, and the batch dimension is squeezed back out
        # when the input was unbatched (2-D).
        #
        # Helper defined outside this class; presumably refreshes
        # ``self._flat_weights`` before they are passed to ``_VF.lstm`` — confirm
        # against RNNBase.
        self._update_flat_weights()

        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        batch_sizes = None
        num_directions = 2 if self.bidirectional else 1
        # With projections the hidden state has ``proj_size`` features; the cell
        # state always keeps ``hidden_size`` features (see ``c_zeros`` below).
        real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
        if isinstance(orig_input, PackedSequence):
            # A PackedSequence unpacks into its data tensor, the per-step batch
            # sizes and the two index tensors used to (un)sort the batch.
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            if hx is None:
                # No initial state supplied: start from zeros on the input's
                # dtype and device.
                h_zeros = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    real_hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
                c_zeros = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
                hx = (h_zeros, c_zeros)
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                hx = self.permute_hidden(hx, sorted_indices)
            # NOTE(review): unlike the padded branch below, this branch never
            # calls ``check_forward_args`` — confirm that validation of packed
            # inputs is intentionally left to the native op.
        else:
            if input.dim() not in (2, 3):
                raise ValueError(
                    f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead"
                )
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                # Unbatched input gets a singleton batch dimension so the rest of
                # the method only deals with 3-D input.
                input = input.unsqueeze(batch_dim)
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
            if hx is None:
                h_zeros = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    real_hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
                c_zeros = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
                hx = (h_zeros, c_zeros)
                self.check_forward_args(input, hx, batch_sizes)
            else:
                # A user-supplied state must have the same "batchedness" as the
                # input: 3-D tensors for batched input, 2-D for unbatched input.
                if is_batched:
                    if hx[0].dim() != 3 or hx[1].dim() != 3:
                        msg = (
                            "For batched 3-D input, hx and cx should "
                            f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
                        )
                        raise RuntimeError(msg)
                else:
                    if hx[0].dim() != 2 or hx[1].dim() != 2:
                        msg = (
                            "For unbatched 2-D input, hx and cx should "
                            f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
                        )
                        raise RuntimeError(msg)
                    # The state's batch dimension is always dim 1, independent
                    # of ``batch_first``.
                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                self.check_forward_args(input, hx, batch_sizes)
                # sorted_indices is None on this branch, so this returns hx unchanged.
                hx = self.permute_hidden(hx, sorted_indices)

        # ``_VF.lstm`` is called with two argument layouts: the padded call takes
        # ``batch_first``, the packed call takes ``batch_sizes`` instead.
        if batch_sizes is None:
            # pyrefly: ignore [no-matching-overload]
            result = _VF.lstm(
                input,
                hx,
                self._flat_weights,  # type: ignore[arg-type]
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            # pyrefly: ignore [no-matching-overload]
            result = _VF.lstm(
                input,
                batch_sizes,
                hx,
                self._flat_weights,  # type: ignore[arg-type]
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
            )
        # First entry is the output; the remaining entries are used as the
        # (hidden, cell) state pair.
        output = result[0]
        hidden = result[1:]
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(
                output,
                # pyrefly: ignore [bad-argument-type]
                batch_sizes,
                sorted_indices,
                unsorted_indices,
            )
            # The returned state is permuted with ``unsorted_indices``.
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:  # type: ignore[possibly-undefined]
                # Drop the singleton batch dimension added for unbatched input.
                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
            return output, self.permute_hidden(hidden, unsorted_indices)
class GRU(RNNBase):
    r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None)

    Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
            n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\
            h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
    at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
    at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
    :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two GRUs together to form a `stacked GRU`,
            with the second GRU taking in outputs of the first GRU and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
            Note that this does not apply to hidden or cell states. See the
            Inputs/Outputs sections below for details.  Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            GRU layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``

    Inputs: input, h_0
        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
          :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
          the input sequence.  The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
          :math:`(D * \text{num\_layers}, N, H_{out})`
          containing the initial hidden state for the input sequence. Defaults to zeros if not provided.

        where:

        .. math::
            \begin{aligned}
                N ={} & \text{batch size} \\
                L ={} & \text{sequence length} \\
                D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
                H_{in} ={} & \text{input\_size} \\
                H_{out} ={} & \text{hidden\_size}
            \end{aligned}

    Outputs: output, h_n
        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
          `(h_t)` from the last layer of the GRU, for each `t`. If a
          :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
          will also be a packed sequence.
        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
          for the input sequence.

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. note::
        For bidirectional GRUs, forward and backward are directions 0 and 1 respectively.
        Example of splitting the output layers when ``batch_first=False``:
        ``output.view(seq_len, batch, num_directions, hidden_size)``.

    .. note::
        ``batch_first`` argument is ignored for unbatched inputs.

    .. note::
        The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks.
        In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the
        previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix
        `W` and addition of bias:

        .. math::
            \begin{aligned}
                n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn})
            \end{aligned}

        This is in contrast to PyTorch implementation, which is done after :math:`W_{hn} h_{(t-1)}`

        .. math::
            \begin{aligned}
                n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn}))
            \end{aligned}

        This implementation differs on purpose for efficiency.

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.GRU(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """

    # Typed signature for type checkers; the runtime constructor below accepts
    # arbitrary positional/keyword arguments and hands them to the base class.
    @overload
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None: ...

    @overload
    def __init__(self, *args, **kwargs) -> None: ...

    def __init__(self, *args, **kwargs):
        # Only a keyword ``proj_size`` is rejected here; positional arguments
        # are not inspected by this check.
        if "proj_size" in kwargs:
            raise ValueError(
                "proj_size argument is only supported for LSTM, not RNN or GRU"
            )
        super().__init__("GRU", *args, **kwargs)

    # Overload stub for plain Tensor input. The body must remain a bare ``pass``
    # (no docstring) because it is registered through ``_overload_method``.
    @overload  # type: ignore[override]
    @torch._jit_internal._overload_method  # noqa: F811
    # pyrefly: ignore [bad-override]
    def forward(
        self,
        input: Tensor,
        hx: Tensor | None = None,
        # pyrefly: ignore [bad-return]
    ) -> tuple[Tensor, Tensor]:  # noqa: F811
        pass

    # Overload stub for PackedSequence input; same constraint on the body.
    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self,
        input: PackedSequence,
        hx: Tensor | None = None,
        # pyrefly: ignore [bad-return]
    ) -> tuple[PackedSequence, Tensor]:  # noqa: F811
        pass

    def forward(self, input, hx=None):  # noqa: F811
        # Runs the GRU over ``input`` (a Tensor or a PackedSequence) and returns
        # ``(output, h_n)``. The output is re-wrapped as a PackedSequence when
        # the input was packed, and the batch dimension is squeezed back out when
        # the input was unbatched (2-D).
        #
        # Helper defined outside this class; presumably refreshes
        # ``self._flat_weights`` before they are passed to ``_VF.gru`` — confirm
        # against RNNBase.
        self._update_flat_weights()

        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            # A PackedSequence unpacks into its data tensor, the per-step batch
            # sizes and the two index tensors used to (un)sort the batch.
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            if hx is None:
                # No initial state supplied: start from zeros on the input's
                # dtype and device.
                num_directions = 2 if self.bidirectional else 1
                hx = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                hx = self.permute_hidden(hx, sorted_indices)
        else:
            batch_sizes = None
            if input.dim() not in (2, 3):
                raise ValueError(
                    f"GRU: Expected input to be 2D or 3D, got {input.dim()}D instead"
                )
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                # Unbatched input (and state, if given) gets a singleton batch
                # dimension so the rest of the method only deals with 3-D tensors.
                input = input.unsqueeze(batch_dim)
                if hx is not None:
                    if hx.dim() != 2:
                        raise RuntimeError(
                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
                        )
                    # The state's batch dimension is always dim 1, independent
                    # of ``batch_first``.
                    hx = hx.unsqueeze(1)
            else:
                if hx is not None and hx.dim() != 3:
                    raise RuntimeError(
                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
                    )
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
            if hx is None:
                num_directions = 2 if self.bidirectional else 1
                hx = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                # sorted_indices is None on this branch, so permute_hidden
                # receives no permutation here.
                hx = self.permute_hidden(hx, sorted_indices)

        # Validation runs for both the packed and the padded path.
        self.check_forward_args(input, hx, batch_sizes)
        # ``_VF.gru`` is called with two argument layouts: the padded call takes
        # ``batch_first``, the packed call takes ``batch_sizes`` instead.
        if batch_sizes is None:
            # pyrefly: ignore [no-matching-overload]
            result = _VF.gru(
                input,
                hx,
                self._flat_weights,  # type: ignore[arg-type]
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            # pyrefly: ignore [no-matching-overload]
            result = _VF.gru(
                input,
                batch_sizes,
                hx,
                self._flat_weights,  # type: ignore[arg-type]
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
            )
        output = result[0]
        hidden = result[1]

        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(
                output,
                # pyrefly: ignore [bad-argument-type]
                batch_sizes,
                sorted_indices,
                unsorted_indices,
            )
            # The returned state is permuted with ``unsorted_indices``.
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:  # type: ignore[possibly-undefined]
                # Drop the singleton batch dimension added for unbatched input.
                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
                hidden = hidden.squeeze(1)

            return output, self.permute_hidden(hidden, unsorted_indices)
class RNNCellBase(Module):
    """Common parameter handling for the single-step recurrent cells.

    Creates ``weight_ih`` of shape ``(num_chunks * hidden_size, input_size)`` and
    ``weight_hh`` of shape ``(num_chunks * hidden_size, hidden_size)``. When
    ``bias`` is truthy it also creates ``bias_ih`` and ``bias_hh`` with
    ``num_chunks * hidden_size`` entries each; otherwise both are registered as
    ``None``. Every parameter is then drawn uniformly by ``reset_parameters``.
    """

    __constants__ = ["input_size", "hidden_size", "bias"]

    input_size: int
    hidden_size: int
    bias: bool
    weight_ih: Tensor
    weight_hh: Tensor
    # WARNING: bias_ih and bias_hh purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool,
        num_chunks: int,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        tensor_opts = {"device": device, "dtype": dtype}
        # All gate chunks are stacked along the first dimension.
        stacked_rows = num_chunks * hidden_size

        # Registration order (weight_ih, weight_hh, bias_ih, bias_hh) determines
        # the order in which reset_parameters() visits the parameters.
        self.weight_ih = Parameter(
            torch.empty((stacked_rows, input_size), **tensor_opts)
        )
        self.weight_hh = Parameter(
            torch.empty((stacked_rows, hidden_size), **tensor_opts)
        )
        if not bias:
            self.register_parameter("bias_ih", None)
            self.register_parameter("bias_hh", None)
        else:
            self.bias_ih = Parameter(torch.empty(stacked_rows, **tensor_opts))
            self.bias_hh = Parameter(torch.empty(stacked_rows, **tensor_opts))

        self.reset_parameters()

    def extra_repr(self) -> str:
        """Return the argument summary shown inside the module's repr.

        ``bias`` is listed only when it is not ``True`` and ``nonlinearity`` only
        when it is set and differs from ``"tanh"``.
        """
        state = self.__dict__
        template = "{input_size}, {hidden_size}"
        if "bias" in state and self.bias is not True:
            template += ", bias={bias}"
        if "nonlinearity" in state and self.nonlinearity != "tanh":
            template += ", nonlinearity={nonlinearity}"
        return template.format(**state)

    def reset_parameters(self) -> None:
        """Re-draw every parameter from ``U(-b, b)`` with ``b = 1/sqrt(hidden_size)``.

        ``b`` is 0 when ``hidden_size`` is not positive, which avoids dividing
        by zero.
        """
        if self.hidden_size > 0:
            bound = 1.0 / math.sqrt(self.hidden_size)
        else:
            bound = 0
        for param in self.parameters():
            init.uniform_(param, -bound, bound)
- class RNNCell(RNNCellBase):
- r"""An Elman RNN cell with tanh or ReLU non-linearity.
- .. math::
- h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
- If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.
- Args:
- input_size: The number of expected features in the input `x`
- hidden_size: The number of features in the hidden state `h`
- bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
- Default: ``True``
- nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
- Inputs: input, hidden
- - **input**: tensor containing input features
- - **hidden**: tensor containing the initial hidden state
- Defaults to zero if not provided.
- Outputs: h'
- - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
- for each element in the batch
- Shape:
- - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
- :math:`H_{in}` = `input_size`.
- - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
- state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
- - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
- Attributes:
- weight_ih: the learnable input-hidden weights, of shape
- `(hidden_size, input_size)`
- weight_hh: the learnable hidden-hidden weights, of shape
- `(hidden_size, hidden_size)`
- bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
- bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`
- .. note::
- All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
- where :math:`k = \frac{1}{\text{hidden\_size}}`
- Examples::
- >>> rnn = nn.RNNCell(10, 20)
- >>> input = torch.randn(6, 3, 10)
- >>> hx = torch.randn(3, 20)
- >>> output = []
- >>> for i in range(6):
- ... hx = rnn(input[i], hx)
- ... output.append(hx)
- """
- __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"]
- nonlinearity: str
- def __init__(
- self,
- input_size: int,
- hidden_size: int,
- bias: bool = True,
- nonlinearity: str = "tanh",
- device=None,
- dtype=None,
- ) -> None:
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
- self.nonlinearity = nonlinearity
- def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor:
- if input.dim() not in (1, 2):
- raise ValueError(
- f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
- )
- if hx is not None and hx.dim() not in (1, 2):
- raise ValueError(
- f"RNNCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead"
- )
- is_batched = input.dim() == 2
- if not is_batched:
- input = input.unsqueeze(0)
- if hx is None:
- hx = torch.zeros(
- input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
- )
- else:
- hx = hx.unsqueeze(0) if not is_batched else hx
- if self.nonlinearity == "tanh":
- ret = _VF.rnn_tanh_cell(
- input,
- hx,
- self.weight_ih,
- self.weight_hh,
- self.bias_ih,
- self.bias_hh,
- )
- elif self.nonlinearity == "relu":
- ret = _VF.rnn_relu_cell(
- input,
- hx,
- self.weight_ih,
- self.weight_hh,
- self.bias_ih,
- self.bias_hh,
- )
- else:
- ret = input # TODO: remove when jit supports exception flow
- raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
- if not is_batched:
- ret = ret.squeeze(0)
- return ret
- class LSTMCell(RNNCellBase):
- r"""A long short-term memory (LSTM) cell.
- .. math::
- \begin{array}{ll}
- i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
- f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
- g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
- o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
- c' = f \odot c + i \odot g \\
- h' = o \odot \tanh(c') \\
- \end{array}
- where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
- Args:
- input_size: The number of expected features in the input `x`
- hidden_size: The number of features in the hidden state `h`
- bias: If ``False``, then the layer does not use bias weights `b_ih` and
- `b_hh`. Default: ``True``
- Inputs: input, (h_0, c_0)
- - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features
- - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state
- - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state
- If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
- Outputs: (h_1, c_1)
- - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state
- - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state
- Attributes:
- weight_ih: the learnable input-hidden weights, of shape
- `(4*hidden_size, input_size)`
- weight_hh: the learnable hidden-hidden weights, of shape
- `(4*hidden_size, hidden_size)`
- bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
- bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`
- .. note::
- All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
- where :math:`k = \frac{1}{\text{hidden\_size}}`
- On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
- Examples::
- >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
- >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
- >>> hx = torch.randn(3, 20) # (batch, hidden_size)
- >>> cx = torch.randn(3, 20)
- >>> output = []
- >>> for i in range(input.size()[0]):
- ... hx, cx = rnn(input[i], (hx, cx))
- ... output.append(hx)
- >>> output = torch.stack(output, dim=0)
- """
def __init__(
    self,
    input_size: int,
    hidden_size: int,
    bias: bool = True,
    device=None,
    dtype=None,
) -> None:
    """Set up the cell by delegating to the base-class constructor.

    Args:
        input_size: forwarded unchanged to the base constructor.
        hidden_size: forwarded unchanged to the base constructor.
        bias: forwarded unchanged to the base constructor.
        device: passed through as the ``device`` keyword argument.
        dtype: passed through as the ``dtype`` keyword argument.
    """
    # num_chunks is fixed at 4 for this cell; everything else is pass-through.
    super().__init__(
        input_size,
        hidden_size,
        bias,
        num_chunks=4,
        device=device,
        dtype=dtype,
    )
- def forward(
- self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None
- ) -> tuple[Tensor, Tensor]:
- if input.dim() not in (1, 2):
- raise ValueError(
- f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
- )
- if hx is not None:
- for idx, value in enumerate(hx):
- if value.dim() not in (1, 2):
- raise ValueError(
- f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead"
- )
- is_batched = input.dim() == 2
- if not is_batched:
- input = input.unsqueeze(0)
- if hx is None:
- zeros = torch.zeros(
- input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
- )
- hx = (zeros, zeros)
- else:
- hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
- ret = _VF.lstm_cell(
- input,
- hx,
- self.weight_ih,
- self.weight_hh,
- self.bias_ih,
- self.bias_hh,
- )
- if not is_batched:
- ret = (ret[0].squeeze(0), ret[1].squeeze(0))
- return ret
- class GRUCell(RNNCellBase):
- r"""A gated recurrent unit (GRU) cell.
- .. math::
- \begin{array}{ll}
- r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
- z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
- n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\
- h' = (1 - z) \odot n + z \odot h
- \end{array}
- where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
- Args:
- input_size: The number of expected features in the input `x`
- hidden_size: The number of features in the hidden state `h`
- bias: If ``False``, then the layer does not use bias weights `b_ih` and
- `b_hh`. Default: ``True``
- Inputs: input, hidden
- - **input** : tensor containing input features
- - **hidden** : tensor containing the initial hidden
- state for each element in the batch.
- Defaults to zero if not provided.
- Outputs: h'
- - **h'** : tensor containing the next hidden state
- for each element in the batch
- Shape:
- - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
- :math:`H_{in}` = `input_size`.
- - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
- state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
- - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
- Attributes:
- weight_ih: the learnable input-hidden weights, of shape
- `(3*hidden_size, input_size)`
- weight_hh: the learnable hidden-hidden weights, of shape
- `(3*hidden_size, hidden_size)`
- bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
- bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`
- .. note::
- All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
- where :math:`k = \frac{1}{\text{hidden\_size}}`
- On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
- Examples::
- >>> rnn = nn.GRUCell(10, 20)
- >>> input = torch.randn(6, 3, 10)
- >>> hx = torch.randn(3, 20)
- >>> output = []
- >>> for i in range(6):
- ... hx = rnn(input[i], hx)
- ... output.append(hx)
- """
- def __init__(
- self,
- input_size: int,
- hidden_size: int,
- bias: bool = True,
- device=None,
- dtype=None,
- ) -> None:
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
- def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor:
- if input.dim() not in (1, 2):
- raise ValueError(
- f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
- )
- if hx is not None and hx.dim() not in (1, 2):
- raise ValueError(
- f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead"
- )
- is_batched = input.dim() == 2
- if not is_batched:
- input = input.unsqueeze(0)
- if hx is None:
- hx = torch.zeros(
- input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
- )
- else:
- hx = hx.unsqueeze(0) if not is_batched else hx
- ret = _VF.gru_cell(
- input,
- hx,
- self.weight_ih,
- self.weight_hh,
- self.bias_ih,
- self.bias_hh,
- )
- if not is_batched:
- ret = ret.squeeze(0)
- return ret
|