rnn.py 75 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. import math
  4. import numbers
  5. import warnings
  6. import weakref
  7. from typing import overload
  8. from typing_extensions import deprecated
  9. import torch
  10. from torch import _VF, Tensor
  11. from torch.nn import init
  12. from torch.nn.parameter import Parameter
  13. from torch.nn.utils.rnn import PackedSequence
  14. from .module import Module
  15. __all__ = [
  16. "RNNBase",
  17. "RNN",
  18. "LSTM",
  19. "GRU",
  20. "RNNCellBase",
  21. "RNNCell",
  22. "LSTMCell",
  23. "GRUCell",
  24. ]
  25. _rnn_impls = {
  26. "RNN_TANH": _VF.rnn_tanh,
  27. "RNN_RELU": _VF.rnn_relu,
  28. }
  29. def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
  30. return tensor.index_select(dim, permutation)
  31. @deprecated(
  32. "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead",
  33. category=FutureWarning,
  34. )
  35. def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
  36. return _apply_permutation(tensor, permutation, dim)
  37. class RNNBase(Module):
  38. r"""Base class for RNN modules (RNN, LSTM, GRU).
  39. Implements aspects of RNNs shared by the RNN, LSTM, and GRU classes, such as module initialization
  40. and utility methods for parameter storage management.
  41. .. note::
  42. The forward method is not implemented by the RNNBase class.
  43. .. note::
  44. LSTM and GRU classes override some methods implemented by RNNBase.
  45. """
  46. __constants__ = [
  47. "mode",
  48. "input_size",
  49. "hidden_size",
  50. "num_layers",
  51. "bias",
  52. "batch_first",
  53. "dropout",
  54. "bidirectional",
  55. "proj_size",
  56. ]
  57. __jit_unused_properties__ = ["all_weights"]
  58. mode: str
  59. input_size: int
  60. hidden_size: int
  61. num_layers: int
  62. bias: bool
  63. batch_first: bool
  64. dropout: float
  65. bidirectional: bool
  66. proj_size: int
  67. def __init__(
  68. self,
  69. mode: str,
  70. input_size: int,
  71. hidden_size: int,
  72. num_layers: int = 1,
  73. bias: bool = True,
  74. batch_first: bool = False,
  75. dropout: float = 0.0,
  76. bidirectional: bool = False,
  77. proj_size: int = 0,
  78. device=None,
  79. dtype=None,
  80. ) -> None:
  81. factory_kwargs = {"device": device, "dtype": dtype}
  82. super().__init__()
  83. self.mode = mode
  84. self.input_size = input_size
  85. self.hidden_size = hidden_size
  86. self.num_layers = num_layers
  87. self.bias = bias
  88. self.batch_first = batch_first
  89. self.dropout = float(dropout)
  90. self.bidirectional = bidirectional
  91. self.proj_size = proj_size
  92. self._flat_weight_refs: list[weakref.ReferenceType[Parameter] | None] = []
  93. num_directions = 2 if bidirectional else 1
  94. if (
  95. not isinstance(dropout, numbers.Number)
  96. or not 0 <= dropout <= 1
  97. or isinstance(dropout, bool)
  98. ):
  99. raise ValueError(
  100. "dropout should be a number in range [0, 1] "
  101. "representing the probability of an element being "
  102. "zeroed"
  103. )
  104. if dropout > 0 and num_layers == 1:
  105. warnings.warn(
  106. "dropout option adds dropout after all but last "
  107. "recurrent layer, so non-zero dropout expects "
  108. f"num_layers greater than 1, but got dropout={dropout} and "
  109. f"num_layers={num_layers}",
  110. stacklevel=2,
  111. )
  112. if not isinstance(bias, bool):
  113. raise TypeError(f"bias should be of type bool, got: {type(bias).__name__}")
  114. if not isinstance(batch_first, bool):
  115. raise TypeError(
  116. f"batch_first should be of type bool, got: {type(batch_first).__name__}"
  117. )
  118. if not isinstance(input_size, int):
  119. raise TypeError(
  120. f"input_size should be of type int, got: {type(input_size).__name__}"
  121. )
  122. if input_size <= 0:
  123. raise ValueError("input_size must be greater than zero")
  124. if not isinstance(hidden_size, int):
  125. raise TypeError(
  126. f"hidden_size should be of type int, got: {type(hidden_size).__name__}"
  127. )
  128. if hidden_size <= 0:
  129. raise ValueError("hidden_size must be greater than zero")
  130. if num_layers <= 0:
  131. raise ValueError("num_layers must be greater than zero")
  132. if proj_size < 0:
  133. raise ValueError(
  134. "proj_size should be a positive integer or zero to disable projections"
  135. )
  136. if proj_size >= hidden_size:
  137. raise ValueError("proj_size has to be smaller than hidden_size")
  138. if mode == "LSTM":
  139. gate_size = 4 * hidden_size
  140. elif mode == "GRU":
  141. gate_size = 3 * hidden_size
  142. elif mode == "RNN_TANH":
  143. gate_size = hidden_size
  144. elif mode == "RNN_RELU":
  145. gate_size = hidden_size
  146. else:
  147. raise ValueError("Unrecognized RNN mode: " + mode)
  148. self._flat_weights_names = []
  149. self._all_weights = []
  150. for layer in range(num_layers):
  151. for direction in range(num_directions):
  152. real_hidden_size = proj_size if proj_size > 0 else hidden_size
  153. layer_input_size = (
  154. input_size if layer == 0 else real_hidden_size * num_directions
  155. )
  156. w_ih = Parameter(
  157. torch.empty((gate_size, layer_input_size), **factory_kwargs)
  158. )
  159. w_hh = Parameter(
  160. torch.empty((gate_size, real_hidden_size), **factory_kwargs)
  161. )
  162. b_ih = Parameter(torch.empty(gate_size, **factory_kwargs))
  163. # Second bias vector included for CuDNN compatibility. Only one
  164. # bias vector is needed in standard definition.
  165. b_hh = Parameter(torch.empty(gate_size, **factory_kwargs))
  166. layer_params: tuple[Tensor, ...] = ()
  167. if self.proj_size == 0:
  168. if bias:
  169. layer_params = (w_ih, w_hh, b_ih, b_hh)
  170. else:
  171. layer_params = (w_ih, w_hh)
  172. else:
  173. w_hr = Parameter(
  174. torch.empty((proj_size, hidden_size), **factory_kwargs)
  175. )
  176. if bias:
  177. layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)
  178. else:
  179. layer_params = (w_ih, w_hh, w_hr)
  180. suffix = "_reverse" if direction == 1 else ""
  181. param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"]
  182. if bias:
  183. param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"]
  184. if self.proj_size > 0:
  185. param_names += ["weight_hr_l{}{}"]
  186. param_names = [x.format(layer, suffix) for x in param_names]
  187. for name, param in zip(param_names, layer_params, strict=True):
  188. setattr(self, name, param)
  189. self._flat_weights_names.extend(param_names)
  190. self._all_weights.append(param_names)
  191. self._init_flat_weights()
  192. self.reset_parameters()
  193. def _init_flat_weights(self) -> None:
  194. self._flat_weights = [
  195. getattr(self, wn) if hasattr(self, wn) else None
  196. for wn in self._flat_weights_names
  197. ]
  198. self._flat_weight_refs = [
  199. weakref.ref(w) if w is not None else None for w in self._flat_weights
  200. ]
  201. self.flatten_parameters()
  202. def __setattr__(self, attr, value) -> None:
  203. if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names:
  204. # keep self._flat_weights up to date if you do self.weight = ...
  205. idx = self._flat_weights_names.index(attr)
  206. self._flat_weights[idx] = value
  207. super().__setattr__(attr, value)
  208. def flatten_parameters(self) -> None:
  209. """Reset parameter data pointer so that they can use faster code paths.
  210. Right now, this works only if the module is on the GPU and cuDNN is enabled.
  211. Otherwise, it's a no-op.
  212. """
  213. # Short-circuits if _flat_weights is only partially instantiated
  214. if len(self._flat_weights) != len(self._flat_weights_names):
  215. return
  216. for w in self._flat_weights:
  217. if not isinstance(w, Tensor):
  218. return
  219. # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
  220. # or the tensors in _flat_weights are of different dtypes
  221. first_fw = self._flat_weights[0] # type: ignore[union-attr]
  222. dtype = first_fw.dtype # type: ignore[union-attr]
  223. for fw in self._flat_weights:
  224. if (
  225. not isinstance(fw, Tensor)
  226. or fw.dtype != dtype
  227. or not fw.is_cuda
  228. or not torch.backends.cudnn.is_acceptable(fw)
  229. ):
  230. return
  231. # If any parameters alias, we fall back to the slower, copying code path. This is
  232. # a sufficient check, because overlapping parameter buffers that don't completely
  233. # alias would break the assumptions of the uniqueness check in
  234. # Module.named_parameters().
  235. unique_data_ptrs = {
  236. p.data_ptr() # type: ignore[union-attr]
  237. for p in self._flat_weights
  238. }
  239. if len(unique_data_ptrs) != len(self._flat_weights):
  240. return
  241. with torch.cuda.device_of(first_fw):
  242. import torch.backends.cudnn.rnn as rnn
  243. # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
  244. # an inplace operation on self._flat_weights
  245. with torch.no_grad():
  246. if torch._use_cudnn_rnn_flatten_weight():
  247. num_weights = 4 if self.bias else 2
  248. if self.proj_size > 0:
  249. num_weights += 1
  250. torch._cudnn_rnn_flatten_weight(
  251. self._flat_weights, # type: ignore[arg-type]
  252. num_weights,
  253. self.input_size,
  254. rnn.get_cudnn_mode(self.mode),
  255. self.hidden_size,
  256. self.proj_size,
  257. self.num_layers,
  258. self.batch_first,
  259. bool(self.bidirectional),
  260. )
  261. def _apply(self, fn, recurse=True):
  262. self._flat_weight_refs = []
  263. ret = super()._apply(fn, recurse)
  264. # Resets _flat_weights
  265. # Note: be v. careful before removing this, as 3rd party device types
  266. # likely rely on this behavior to properly .to() modules like LSTM.
  267. self._init_flat_weights()
  268. return ret
  269. def reset_parameters(self) -> None:
  270. stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
  271. for weight in self.parameters():
  272. init.uniform_(weight, -stdv, stdv)
  273. def check_input(self, input: Tensor, batch_sizes: Tensor | None) -> None:
  274. if not torch.jit.is_scripting():
  275. if (
  276. input.dtype != self._flat_weights[0].dtype # type: ignore[union-attr]
  277. and not torch._C._is_any_autocast_enabled()
  278. ):
  279. raise ValueError(
  280. f"RNN input dtype ({input.dtype}) does not match weight dtype ({self._flat_weights[0].dtype}). " # type: ignore[union-attr]
  281. f"Convert input: input.to({self._flat_weights[0].dtype}), or convert model: model.to({input.dtype})" # type: ignore[union-attr]
  282. )
  283. expected_input_dim = 2 if batch_sizes is not None else 3
  284. if input.dim() != expected_input_dim:
  285. raise RuntimeError(
  286. f"input must have {expected_input_dim} dimensions, got {input.dim()}"
  287. )
  288. if self.input_size != input.size(-1):
  289. raise RuntimeError(
  290. f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
  291. )
  292. def get_expected_hidden_size(
  293. self, input: Tensor, batch_sizes: Tensor | None
  294. ) -> tuple[int, int, int]:
  295. if batch_sizes is not None:
  296. mini_batch = int(batch_sizes[0])
  297. else:
  298. mini_batch = input.size(0) if self.batch_first else input.size(1)
  299. num_directions = 2 if self.bidirectional else 1
  300. if self.proj_size > 0:
  301. expected_hidden_size = (
  302. self.num_layers * num_directions,
  303. mini_batch,
  304. self.proj_size,
  305. )
  306. else:
  307. expected_hidden_size = (
  308. self.num_layers * num_directions,
  309. mini_batch,
  310. self.hidden_size,
  311. )
  312. return expected_hidden_size
  313. def check_hidden_size(
  314. self,
  315. hx: Tensor,
  316. expected_hidden_size: tuple[int, int, int],
  317. msg: str = "Expected hidden size {}, got {}",
  318. ) -> None:
  319. if hx.size() != expected_hidden_size:
  320. raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
  321. def _weights_have_changed(self):
  322. # Returns True if the weight tensors have changed since the last forward pass.
  323. # This is the case when used with torch.func.functional_call(), for example.
  324. weights_changed = False
  325. for ref, name in zip(
  326. self._flat_weight_refs, self._flat_weights_names, strict=True
  327. ):
  328. weight = getattr(self, name) if hasattr(self, name) else None
  329. if weight is not None and ref is not None and ref() is not weight:
  330. weights_changed = True
  331. break
  332. return weights_changed
  333. def check_forward_args(
  334. self, input: Tensor, hidden: Tensor, batch_sizes: Tensor | None
  335. ) -> None:
  336. self.check_input(input, batch_sizes)
  337. expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
  338. self.check_hidden_size(hidden, expected_hidden_size)
  339. def permute_hidden(self, hx: Tensor, permutation: Tensor | None):
  340. if permutation is None:
  341. return hx
  342. return _apply_permutation(hx, permutation)
  343. def extra_repr(self) -> str:
  344. s = "{input_size}, {hidden_size}"
  345. if self.proj_size != 0:
  346. s += ", proj_size={proj_size}"
  347. if self.num_layers != 1:
  348. s += ", num_layers={num_layers}"
  349. if self.bias is not True:
  350. s += ", bias={bias}"
  351. if self.batch_first is not False:
  352. s += ", batch_first={batch_first}"
  353. if self.dropout != 0:
  354. s += ", dropout={dropout}"
  355. if self.bidirectional is not False:
  356. s += ", bidirectional={bidirectional}"
  357. return s.format(**self.__dict__)
  358. def _update_flat_weights(self) -> None:
  359. if not torch.jit.is_scripting():
  360. if self._weights_have_changed():
  361. self._init_flat_weights()
  362. def __getstate__(self):
  363. # If weights have been changed, update the _flat_weights in __getstate__ here.
  364. self._update_flat_weights()
  365. # Don't serialize the weight references.
  366. state = self.__dict__.copy()
  367. del state["_flat_weight_refs"]
  368. return state
  369. def __setstate__(self, d):
  370. super().__setstate__(d)
  371. if "all_weights" in d:
  372. self._all_weights = d["all_weights"]
  373. # In PyTorch 1.8 we added a proj_size member variable to LSTM.
  374. # LSTMs that were serialized via torch.save(module) before PyTorch 1.8
  375. # don't have it, so to preserve compatibility we set proj_size here.
  376. if "proj_size" not in d:
  377. self.proj_size = 0
  378. if not isinstance(self._all_weights[0][0], str):
  379. num_layers = self.num_layers
  380. num_directions = 2 if self.bidirectional else 1
  381. self._flat_weights_names = []
  382. self._all_weights = []
  383. for layer in range(num_layers):
  384. for direction in range(num_directions):
  385. suffix = "_reverse" if direction == 1 else ""
  386. weights = [
  387. "weight_ih_l{}{}",
  388. "weight_hh_l{}{}",
  389. "bias_ih_l{}{}",
  390. "bias_hh_l{}{}",
  391. "weight_hr_l{}{}",
  392. ]
  393. weights = [x.format(layer, suffix) for x in weights]
  394. if self.bias:
  395. if self.proj_size > 0:
  396. self._all_weights += [weights]
  397. self._flat_weights_names.extend(weights)
  398. else:
  399. self._all_weights += [weights[:4]]
  400. self._flat_weights_names.extend(weights[:4])
  401. else:
  402. if self.proj_size > 0:
  403. self._all_weights += [weights[:2]] + [weights[-1:]]
  404. self._flat_weights_names.extend(
  405. weights[:2] + [weights[-1:]]
  406. )
  407. else:
  408. self._all_weights += [weights[:2]]
  409. self._flat_weights_names.extend(weights[:2])
  410. self._flat_weights = [
  411. getattr(self, wn) if hasattr(self, wn) else None
  412. for wn in self._flat_weights_names
  413. ]
  414. self._flat_weight_refs = [
  415. weakref.ref(w) if w is not None else None for w in self._flat_weights
  416. ]
  417. @property
  418. def all_weights(self) -> list[list[Parameter]]:
  419. return [
  420. [getattr(self, weight) for weight in weights]
  421. for weights in self._all_weights
  422. ]
  423. def _replicate_for_data_parallel(self):
  424. replica = super()._replicate_for_data_parallel()
  425. # Need to copy these caches, otherwise the replica will share the same
  426. # flat weights list.
  427. replica._flat_weights = replica._flat_weights[:]
  428. replica._flat_weights_names = replica._flat_weights_names[:]
  429. return replica
  430. class RNN(RNNBase):
  431. r"""__init__(input_size,hidden_size,num_layers=1,nonlinearity='tanh',bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None)
  432. Apply a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}`
  433. non-linearity to an input sequence. For each element in the input sequence,
  434. each layer computes the following function:
  435. .. math::
  436. h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})
  437. where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
  438. the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
  439. previous layer at time `t-1` or the initial hidden state at time `0`.
  440. If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.
  441. .. code-block:: python
  442. # Efficient implementation equivalent to the following with bidirectional=False
  443. rnn = nn.RNN(input_size, hidden_size, num_layers)
  444. params = dict(rnn.named_parameters())
  445. def forward(x, hx=None, batch_first=False):
  446. if batch_first:
  447. x = x.transpose(0, 1)
  448. seq_len, batch_size, _ = x.size()
  449. if hx is None:
  450. hx = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size)
  451. h_t_minus_1 = hx.clone()
  452. h_t = hx.clone()
  453. output = []
  454. for t in range(seq_len):
  455. for layer in range(rnn.num_layers):
  456. input_t = x[t] if layer == 0 else h_t[layer - 1]
  457. h_t[layer] = torch.tanh(
  458. input_t @ params[f"weight_ih_l{layer}"].T
  459. + h_t_minus_1[layer] @ params[f"weight_hh_l{layer}"].T
  460. + params[f"bias_hh_l{layer}"]
  461. + params[f"bias_ih_l{layer}"]
  462. )
  463. output.append(h_t[-1].clone())
  464. h_t_minus_1 = h_t.clone()
  465. output = torch.stack(output)
  466. if batch_first:
  467. output = output.transpose(0, 1)
  468. return output, h_t
  469. Args:
  470. input_size: The number of expected features in the input `x`
  471. hidden_size: The number of features in the hidden state `h`
  472. num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
  473. would mean stacking two RNNs together to form a `stacked RNN`,
  474. with the second RNN taking in outputs of the first RNN and
  475. computing the final results. Default: 1
  476. nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
  477. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  478. Default: ``True``
  479. batch_first: If ``True``, then the input and output tensors are provided
  480. as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
  481. Note that this does not apply to hidden or cell states. See the
  482. Inputs/Outputs sections below for details. Default: ``False``
  483. dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
  484. RNN layer except the last layer, with dropout probability equal to
  485. :attr:`dropout`. Default: 0
  486. bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``
  487. Inputs: input, hx
  488. * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
  489. :math:`(L, N, H_{in})` when ``batch_first=False`` or
  490. :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
  491. the input sequence. The input can also be a packed variable length sequence.
  492. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
  493. :func:`torch.nn.utils.rnn.pack_sequence` for details.
  494. * **hx**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
  495. :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden
  496. state for the input sequence batch. Defaults to zeros if not provided.
  497. where:
  498. .. math::
  499. \begin{aligned}
  500. N ={} & \text{batch size} \\
  501. L ={} & \text{sequence length} \\
  502. D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
  503. H_{in} ={} & \text{input\_size} \\
  504. H_{out} ={} & \text{hidden\_size}
  505. \end{aligned}
  506. Outputs: output, h_n
  507. * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
  508. :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
  509. :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
  510. `(h_t)` from the last layer of the RNN, for each `t`. If a
  511. :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
  512. will also be a packed sequence.
  513. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
  514. :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
  515. for each element in the batch.
  516. Attributes:
  517. weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
  518. of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is
  519. `(hidden_size, num_directions * hidden_size)`
  520. weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
  521. of shape `(hidden_size, hidden_size)`
  522. bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
  523. of shape `(hidden_size)`
  524. bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
  525. of shape `(hidden_size)`
  526. .. note::
  527. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  528. where :math:`k = \frac{1}{\text{hidden\_size}}`
  529. .. note::
  530. For bidirectional RNNs, forward and backward are directions 0 and 1 respectively.
  531. Example of splitting the output layers when ``batch_first=False``:
  532. ``output.view(seq_len, batch, num_directions, hidden_size)``.
  533. .. note::
  534. ``batch_first`` argument is ignored for unbatched inputs.
  535. .. include:: ../cudnn_rnn_determinism.rst
  536. .. include:: ../cudnn_persistent_rnn.rst
  537. Examples::
  538. >>> rnn = nn.RNN(10, 20, 2)
  539. >>> input = torch.randn(5, 3, 10)
  540. >>> h0 = torch.randn(2, 3, 20)
  541. >>> output, hn = rnn(input, h0)
  542. """
  543. @overload
  544. def __init__(
  545. self,
  546. input_size: int,
  547. hidden_size: int,
  548. num_layers: int = 1,
  549. nonlinearity: str = "tanh",
  550. bias: bool = True,
  551. batch_first: bool = False,
  552. dropout: float = 0.0,
  553. bidirectional: bool = False,
  554. device=None,
  555. dtype=None,
  556. ) -> None: ...
  557. @overload
  558. def __init__(self, *args, **kwargs) -> None: ...
  559. def __init__(self, *args, **kwargs):
  560. if "proj_size" in kwargs:
  561. raise ValueError(
  562. "proj_size argument is only supported for LSTM, not RNN or GRU"
  563. )
  564. if len(args) > 3:
  565. self.nonlinearity = args[3]
  566. args = args[:3] + args[4:]
  567. else:
  568. self.nonlinearity = kwargs.pop("nonlinearity", "tanh")
  569. if self.nonlinearity == "tanh":
  570. mode = "RNN_TANH"
  571. elif self.nonlinearity == "relu":
  572. mode = "RNN_RELU"
  573. else:
  574. raise ValueError(
  575. f"Unknown nonlinearity '{self.nonlinearity}'. Select from 'tanh' or 'relu'."
  576. )
  577. super().__init__(mode, *args, **kwargs)
  578. @overload
  579. @torch._jit_internal._overload_method # noqa: F811
  580. # pyrefly: ignore [bad-override]
  581. def forward(
  582. self,
  583. input: Tensor,
  584. hx: Tensor | None = None,
  585. # pyrefly: ignore [bad-return]
  586. ) -> tuple[Tensor, Tensor]:
  587. pass
  588. @overload
  589. @torch._jit_internal._overload_method # noqa: F811
  590. def forward(
  591. self,
  592. input: PackedSequence,
  593. hx: Tensor | None = None,
  594. # pyrefly: ignore [bad-return]
  595. ) -> tuple[PackedSequence, Tensor]:
  596. pass
  597. def forward(self, input, hx=None): # noqa: F811
  598. """
  599. Runs the forward pass.
  600. """
  601. self._update_flat_weights()
  602. num_directions = 2 if self.bidirectional else 1
  603. orig_input = input
  604. if isinstance(orig_input, PackedSequence):
  605. input, batch_sizes, sorted_indices, unsorted_indices = input
  606. max_batch_size = batch_sizes[0]
  607. # script() is unhappy when max_batch_size is different type in cond branches, so we duplicate
  608. if hx is None:
  609. hx = torch.zeros(
  610. self.num_layers * num_directions,
  611. max_batch_size,
  612. self.hidden_size,
  613. dtype=input.dtype,
  614. device=input.device,
  615. )
  616. else:
  617. # Each batch of the hidden state should match the input sequence that
  618. # the user believes he/she is passing in.
  619. hx = self.permute_hidden(hx, sorted_indices)
  620. else:
  621. batch_sizes = None
  622. if input.dim() not in (2, 3):
  623. raise ValueError(
  624. f"RNN: Expected input to be 2D or 3D, got {input.dim()}D tensor instead"
  625. )
  626. is_batched = input.dim() == 3
  627. batch_dim = 0 if self.batch_first else 1
  628. if not is_batched:
  629. input = input.unsqueeze(batch_dim)
  630. if hx is not None:
  631. if hx.dim() != 2:
  632. raise RuntimeError(
  633. f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
  634. )
  635. hx = hx.unsqueeze(1)
  636. else:
  637. if hx is not None and hx.dim() != 3:
  638. raise RuntimeError(
  639. f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
  640. )
  641. max_batch_size = input.size(0) if self.batch_first else input.size(1)
  642. sorted_indices = None
  643. unsorted_indices = None
  644. if hx is None:
  645. hx = torch.zeros(
  646. self.num_layers * num_directions,
  647. max_batch_size,
  648. self.hidden_size,
  649. dtype=input.dtype,
  650. device=input.device,
  651. )
  652. else:
  653. # Each batch of the hidden state should match the input sequence that
  654. # the user believes he/she is passing in.
  655. hx = self.permute_hidden(hx, sorted_indices)
  656. if hx is None:
  657. raise AssertionError("hx must not be None")
  658. self.check_forward_args(input, hx, batch_sizes)
  659. if self.mode != "RNN_TANH" and self.mode != "RNN_RELU":
  660. raise AssertionError(f"mode must be RNN_TANH or RNN_RELU, got {self.mode}")
  661. if batch_sizes is None:
  662. if self.mode == "RNN_TANH":
  663. # pyrefly: ignore [no-matching-overload]
  664. result = _VF.rnn_tanh(
  665. input,
  666. hx,
  667. self._flat_weights, # type: ignore[arg-type]
  668. self.bias,
  669. self.num_layers,
  670. self.dropout,
  671. self.training,
  672. self.bidirectional,
  673. self.batch_first,
  674. )
  675. else:
  676. # pyrefly: ignore [no-matching-overload]
  677. result = _VF.rnn_relu(
  678. input,
  679. hx,
  680. self._flat_weights, # type: ignore[arg-type]
  681. self.bias,
  682. self.num_layers,
  683. self.dropout,
  684. self.training,
  685. self.bidirectional,
  686. self.batch_first,
  687. )
  688. else:
  689. if self.mode == "RNN_TANH":
  690. # pyrefly: ignore [no-matching-overload]
  691. result = _VF.rnn_tanh(
  692. input,
  693. batch_sizes,
  694. hx,
  695. self._flat_weights, # type: ignore[arg-type]
  696. self.bias,
  697. self.num_layers,
  698. self.dropout,
  699. self.training,
  700. self.bidirectional,
  701. )
  702. else:
  703. # pyrefly: ignore [no-matching-overload]
  704. result = _VF.rnn_relu(
  705. input,
  706. batch_sizes,
  707. hx,
  708. self._flat_weights, # type: ignore[arg-type]
  709. self.bias,
  710. self.num_layers,
  711. self.dropout,
  712. self.training,
  713. self.bidirectional,
  714. )
  715. output = result[0]
  716. hidden = result[1]
  717. if isinstance(orig_input, PackedSequence):
  718. output_packed = PackedSequence(
  719. output,
  720. # pyrefly: ignore [bad-argument-type]
  721. batch_sizes,
  722. sorted_indices,
  723. unsorted_indices,
  724. )
  725. return output_packed, self.permute_hidden(hidden, unsorted_indices)
  726. if not is_batched: # type: ignore[possibly-undefined]
  727. output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
  728. hidden = hidden.squeeze(1)
  729. return output, self.permute_hidden(hidden, unsorted_indices)
# XXX: The LSTM and GRU implementations differ from RNNBase because:
# 1. we want to support nn.LSTM and nn.GRU in TorchScript, and TorchScript in
#    its current state cannot support Python's Union or Any types;
# 2. TorchScript static typing does not allow a Function or Callable type in
#    Dict values, so we have to call _VF separately instead of using _rnn_impls;
# 3. this is a temporary, transitional arrangement made so the change could
#    ship in time for the release.
#
# More discussion details in https://github.com/pytorch/pytorch/pull/23266
#
# TODO: remove the overriding implementations for LSTM and GRU once TorchScript
# supports expressing these two modules generally.
  742. class LSTM(RNNBase):
  743. r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,proj_size=0,device=None,dtype=None)
  744. Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence.
  745. For each element in the input sequence, each layer computes the following
  746. function:
  747. .. math::
  748. \begin{array}{ll} \\
  749. i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
  750. f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
  751. g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
  752. o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
  753. c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
  754. h_t = o_t \odot \tanh(c_t) \\
  755. \end{array}
  756. where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
  757. state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
  758. is the hidden state of the layer at time `t-1` or the initial hidden
  759. state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
  760. :math:`o_t` are the input, forget, cell, and output gates, respectively.
  761. :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
  762. In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
  763. (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
  764. dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
  765. variable which is :math:`0` with probability :attr:`dropout`.
  766. If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes
  767. the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from
  768. ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly).
  769. Second, the output hidden state of each layer will be multiplied by a learnable projection
  770. matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output
  771. of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact
  772. dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128.
  773. Args:
  774. input_size: The number of expected features in the input `x`
  775. hidden_size: The number of features in the hidden state `h`
  776. num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
  777. would mean stacking two LSTMs together to form a `stacked LSTM`,
  778. with the second LSTM taking in outputs of the first LSTM and
  779. computing the final results. Default: 1
  780. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  781. Default: ``True``
  782. batch_first: If ``True``, then the input and output tensors are provided
  783. as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
  784. Note that this does not apply to hidden or cell states. See the
  785. Inputs/Outputs sections below for details. Default: ``False``
  786. dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
  787. LSTM layer except the last layer, with dropout probability equal to
  788. :attr:`dropout`. Default: 0
  789. bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
  790. proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0
  791. Inputs: input, (h_0, c_0)
  792. * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
  793. :math:`(L, N, H_{in})` when ``batch_first=False`` or
  794. :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
  795. the input sequence. The input can also be a packed variable length sequence.
  796. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
  797. :func:`torch.nn.utils.rnn.pack_sequence` for details.
  798. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
  799. :math:`(D * \text{num\_layers}, N, H_{out})` containing the
  800. initial hidden state for each element in the input sequence.
  801. Defaults to zeros if (h_0, c_0) is not provided.
  802. * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
  803. :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
  804. initial cell state for each element in the input sequence.
  805. Defaults to zeros if (h_0, c_0) is not provided.
  806. where:
  807. .. math::
  808. \begin{aligned}
  809. N ={} & \text{batch size} \\
  810. L ={} & \text{sequence length} \\
  811. D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
  812. H_{in} ={} & \text{input\_size} \\
  813. H_{cell} ={} & \text{hidden\_size} \\
  814. H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
  815. \end{aligned}
  816. Outputs: output, (h_n, c_n)
  817. * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
  818. :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
  819. :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
  820. `(h_t)` from the last layer of the LSTM, for each `t`. If a
  821. :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
  822. will also be a packed sequence. When ``bidirectional=True``, `output` will contain
  823. a concatenation of the forward and reverse hidden states at each time step in the sequence.
  824. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
  825. :math:`(D * \text{num\_layers}, N, H_{out})` containing the
  826. final hidden state for each element in the sequence. When ``bidirectional=True``,
  827. `h_n` will contain a concatenation of the final forward and reverse hidden states, respectively.
  828. * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
  829. :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
  830. final cell state for each element in the sequence. When ``bidirectional=True``,
  831. `c_n` will contain a concatenation of the final forward and reverse cell states, respectively.
  832. Attributes:
  833. weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
  834. `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.
  835. Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If
  836. ``proj_size > 0`` was specified, the shape will be
  837. `(4*hidden_size, num_directions * proj_size)` for `k > 0`
  838. weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
  839. `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0``
  840. was specified, the shape will be `(4*hidden_size, proj_size)`.
  841. bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
  842. `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
  843. bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
  844. `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`
  845. weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer
  846. of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was
  847. specified.
  848. weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction.
  849. Only present when ``bidirectional=True``.
  850. weight_hh_l[k]_reverse: Analogous to `weight_hh_l[k]` for the reverse direction.
  851. Only present when ``bidirectional=True``.
  852. bias_ih_l[k]_reverse: Analogous to `bias_ih_l[k]` for the reverse direction.
  853. Only present when ``bidirectional=True``.
  854. bias_hh_l[k]_reverse: Analogous to `bias_hh_l[k]` for the reverse direction.
  855. Only present when ``bidirectional=True``.
  856. weight_hr_l[k]_reverse: Analogous to `weight_hr_l[k]` for the reverse direction.
  857. Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified.
  858. .. note::
  859. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  860. where :math:`k = \frac{1}{\text{hidden\_size}}`
  861. .. note::
  862. For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively.
  863. Example of splitting the output layers when ``batch_first=False``:
  864. ``output.view(seq_len, batch, num_directions, hidden_size)``.
  865. .. note::
  866. For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the
  867. former contains the final forward and reverse hidden states, while the latter contains the
  868. final forward hidden state and the initial reverse hidden state.
  869. .. note::
  870. ``batch_first`` argument is ignored for unbatched inputs.
  871. .. note::
  872. ``proj_size`` should be smaller than ``hidden_size``.
  873. .. include:: ../cudnn_rnn_determinism.rst
  874. .. include:: ../cudnn_persistent_rnn.rst
  875. Examples::
  876. >>> rnn = nn.LSTM(10, 20, 2)
  877. >>> input = torch.randn(5, 3, 10)
  878. >>> h0 = torch.randn(2, 3, 20)
  879. >>> c0 = torch.randn(2, 3, 20)
  880. >>> output, (hn, cn) = rnn(input, (h0, c0))
  881. """
  882. @overload
  883. def __init__(
  884. self,
  885. input_size: int,
  886. hidden_size: int,
  887. num_layers: int = 1,
  888. bias: bool = True,
  889. batch_first: bool = False,
  890. dropout: float = 0.0,
  891. bidirectional: bool = False,
  892. proj_size: int = 0,
  893. device=None,
  894. dtype=None,
  895. ) -> None: ...
  896. @overload
  897. def __init__(self, *args, **kwargs) -> None: ...
  898. def __init__(self, *args, **kwargs):
  899. super().__init__("LSTM", *args, **kwargs)
  900. def get_expected_cell_size(
  901. self, input: Tensor, batch_sizes: Tensor | None
  902. ) -> tuple[int, int, int]:
  903. if batch_sizes is not None:
  904. mini_batch = int(batch_sizes[0])
  905. else:
  906. mini_batch = input.size(0) if self.batch_first else input.size(1)
  907. num_directions = 2 if self.bidirectional else 1
  908. expected_hidden_size = (
  909. self.num_layers * num_directions,
  910. mini_batch,
  911. self.hidden_size,
  912. )
  913. return expected_hidden_size
  914. # In the future, we should prevent mypy from applying contravariance rules here.
  915. # See torch/nn/modules/module.py::_forward_unimplemented
  916. # pyrefly: ignore [bad-override]
  917. def check_forward_args(
  918. self,
  919. input: Tensor,
  920. hidden: tuple[Tensor, Tensor], # type: ignore[override]
  921. batch_sizes: Tensor | None,
  922. ) -> None:
  923. self.check_input(input, batch_sizes)
  924. self.check_hidden_size(
  925. hidden[0],
  926. self.get_expected_hidden_size(input, batch_sizes),
  927. "Expected hidden[0] size {}, got {}",
  928. )
  929. self.check_hidden_size(
  930. hidden[1],
  931. self.get_expected_cell_size(input, batch_sizes),
  932. "Expected hidden[1] size {}, got {}",
  933. )
  934. # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
  935. def permute_hidden( # type: ignore[override]
  936. self,
  937. hx: tuple[Tensor, Tensor],
  938. permutation: Tensor | None,
  939. ) -> tuple[Tensor, Tensor]:
  940. if permutation is None:
  941. return hx
  942. return _apply_permutation(hx[0], permutation), _apply_permutation(
  943. hx[1], permutation
  944. )
  945. # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
  946. @overload # type: ignore[override]
  947. @torch._jit_internal._overload_method # noqa: F811
  948. # pyrefly: ignore [bad-override]
  949. def forward(
  950. self,
  951. input: Tensor,
  952. hx: tuple[Tensor, Tensor] | None = None,
  953. # pyrefly: ignore [bad-return]
  954. ) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811
  955. pass
  956. # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
  957. @overload
  958. @torch._jit_internal._overload_method # noqa: F811
  959. def forward(
  960. self,
  961. input: PackedSequence,
  962. hx: tuple[Tensor, Tensor] | None = None,
  963. # pyrefly: ignore [bad-return]
  964. ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811
  965. pass
  966. def forward(self, input, hx=None): # noqa: F811
  967. self._update_flat_weights()
  968. orig_input = input
  969. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  970. batch_sizes = None
  971. num_directions = 2 if self.bidirectional else 1
  972. real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
  973. if isinstance(orig_input, PackedSequence):
  974. input, batch_sizes, sorted_indices, unsorted_indices = input
  975. max_batch_size = batch_sizes[0]
  976. if hx is None:
  977. h_zeros = torch.zeros(
  978. self.num_layers * num_directions,
  979. max_batch_size,
  980. real_hidden_size,
  981. dtype=input.dtype,
  982. device=input.device,
  983. )
  984. c_zeros = torch.zeros(
  985. self.num_layers * num_directions,
  986. max_batch_size,
  987. self.hidden_size,
  988. dtype=input.dtype,
  989. device=input.device,
  990. )
  991. hx = (h_zeros, c_zeros)
  992. else:
  993. # Each batch of the hidden state should match the input sequence that
  994. # the user believes he/she is passing in.
  995. hx = self.permute_hidden(hx, sorted_indices)
  996. else:
  997. if input.dim() not in (2, 3):
  998. raise ValueError(
  999. f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead"
  1000. )
  1001. is_batched = input.dim() == 3
  1002. batch_dim = 0 if self.batch_first else 1
  1003. if not is_batched:
  1004. input = input.unsqueeze(batch_dim)
  1005. max_batch_size = input.size(0) if self.batch_first else input.size(1)
  1006. sorted_indices = None
  1007. unsorted_indices = None
  1008. if hx is None:
  1009. h_zeros = torch.zeros(
  1010. self.num_layers * num_directions,
  1011. max_batch_size,
  1012. real_hidden_size,
  1013. dtype=input.dtype,
  1014. device=input.device,
  1015. )
  1016. c_zeros = torch.zeros(
  1017. self.num_layers * num_directions,
  1018. max_batch_size,
  1019. self.hidden_size,
  1020. dtype=input.dtype,
  1021. device=input.device,
  1022. )
  1023. hx = (h_zeros, c_zeros)
  1024. self.check_forward_args(input, hx, batch_sizes)
  1025. else:
  1026. if is_batched:
  1027. if hx[0].dim() != 3 or hx[1].dim() != 3:
  1028. msg = (
  1029. "For batched 3-D input, hx and cx should "
  1030. f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
  1031. )
  1032. raise RuntimeError(msg)
  1033. else:
  1034. if hx[0].dim() != 2 or hx[1].dim() != 2:
  1035. msg = (
  1036. "For unbatched 2-D input, hx and cx should "
  1037. f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
  1038. )
  1039. raise RuntimeError(msg)
  1040. hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
  1041. # Each batch of the hidden state should match the input sequence that
  1042. # the user believes he/she is passing in.
  1043. self.check_forward_args(input, hx, batch_sizes)
  1044. hx = self.permute_hidden(hx, sorted_indices)
  1045. if batch_sizes is None:
  1046. # pyrefly: ignore [no-matching-overload]
  1047. result = _VF.lstm(
  1048. input,
  1049. hx,
  1050. self._flat_weights, # type: ignore[arg-type]
  1051. self.bias,
  1052. self.num_layers,
  1053. self.dropout,
  1054. self.training,
  1055. self.bidirectional,
  1056. self.batch_first,
  1057. )
  1058. else:
  1059. # pyrefly: ignore [no-matching-overload]
  1060. result = _VF.lstm(
  1061. input,
  1062. batch_sizes,
  1063. hx,
  1064. self._flat_weights, # type: ignore[arg-type]
  1065. self.bias,
  1066. self.num_layers,
  1067. self.dropout,
  1068. self.training,
  1069. self.bidirectional,
  1070. )
  1071. output = result[0]
  1072. hidden = result[1:]
  1073. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  1074. if isinstance(orig_input, PackedSequence):
  1075. output_packed = PackedSequence(
  1076. output,
  1077. # pyrefly: ignore [bad-argument-type]
  1078. batch_sizes,
  1079. sorted_indices,
  1080. unsorted_indices,
  1081. )
  1082. return output_packed, self.permute_hidden(hidden, unsorted_indices)
  1083. else:
  1084. if not is_batched: # type: ignore[possibly-undefined]
  1085. output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
  1086. hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
  1087. return output, self.permute_hidden(hidden, unsorted_indices)
  1088. class GRU(RNNBase):
  1089. r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None)
  1090. Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
  1091. For each element in the input sequence, each layer computes the following
  1092. function:
  1093. .. math::
  1094. \begin{array}{ll}
  1095. r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
  1096. z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
  1097. n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\
  1098. h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
  1099. \end{array}
  1100. where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
  1101. at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
  1102. at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
  1103. :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
  1104. :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
  1105. In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
  1106. (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
  1107. dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
  1108. variable which is :math:`0` with probability :attr:`dropout`.
  1109. Args:
  1110. input_size: The number of expected features in the input `x`
  1111. hidden_size: The number of features in the hidden state `h`
  1112. num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
  1113. would mean stacking two GRUs together to form a `stacked GRU`,
  1114. with the second GRU taking in outputs of the first GRU and
  1115. computing the final results. Default: 1
  1116. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  1117. Default: ``True``
  1118. batch_first: If ``True``, then the input and output tensors are provided
  1119. as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
  1120. Note that this does not apply to hidden or cell states. See the
  1121. Inputs/Outputs sections below for details. Default: ``False``
  1122. dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
  1123. GRU layer except the last layer, with dropout probability equal to
  1124. :attr:`dropout`. Default: 0
  1125. bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``
  1126. Inputs: input, h_0
  1127. * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
  1128. :math:`(L, N, H_{in})` when ``batch_first=False`` or
  1129. :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
  1130. the input sequence. The input can also be a packed variable length sequence.
  1131. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
  1132. :func:`torch.nn.utils.rnn.pack_sequence` for details.
  1133. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
  1134. :math:`(D * \text{num\_layers}, N, H_{out})`
  1135. containing the initial hidden state for the input sequence. Defaults to zeros if not provided.
  1136. where:
  1137. .. math::
  1138. \begin{aligned}
  1139. N ={} & \text{batch size} \\
  1140. L ={} & \text{sequence length} \\
  1141. D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
  1142. H_{in} ={} & \text{input\_size} \\
  1143. H_{out} ={} & \text{hidden\_size}
  1144. \end{aligned}
  1145. Outputs: output, h_n
  1146. * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
  1147. :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
  1148. :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
  1149. `(h_t)` from the last layer of the GRU, for each `t`. If a
  1150. :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
  1151. will also be a packed sequence.
  1152. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
  1153. :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
  1154. for the input sequence.
  1155. Attributes:
  1156. weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
  1157. (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
  1158. Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
  1159. weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
  1160. (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
  1161. bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
  1162. (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
  1163. bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
  1164. (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`
  1165. .. note::
  1166. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  1167. where :math:`k = \frac{1}{\text{hidden\_size}}`
  1168. .. note::
  1169. For bidirectional GRUs, forward and backward are directions 0 and 1 respectively.
  1170. Example of splitting the output layers when ``batch_first=False``:
  1171. ``output.view(seq_len, batch, num_directions, hidden_size)``.
  1172. .. note::
  1173. ``batch_first`` argument is ignored for unbatched inputs.
  1174. .. note::
  1175. The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks.
  1176. In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the
  1177. previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix
  1178. `W` and addition of bias:
  1179. .. math::
  1180. \begin{aligned}
  1181. n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn})
  1182. \end{aligned}
  1183. This is in contrast to PyTorch implementation, which is done after :math:`W_{hn} h_{(t-1)}`
  1184. .. math::
  1185. \begin{aligned}
  1186. n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn}))
  1187. \end{aligned}
  1188. This implementation differs on purpose for efficiency.
  1189. .. include:: ../cudnn_persistent_rnn.rst
  1190. Examples::
  1191. >>> rnn = nn.GRU(10, 20, 2)
  1192. >>> input = torch.randn(5, 3, 10)
  1193. >>> h0 = torch.randn(2, 3, 20)
  1194. >>> output, hn = rnn(input, h0)
  1195. """
  1196. @overload
  1197. def __init__(
  1198. self,
  1199. input_size: int,
  1200. hidden_size: int,
  1201. num_layers: int = 1,
  1202. bias: bool = True,
  1203. batch_first: bool = False,
  1204. dropout: float = 0.0,
  1205. bidirectional: bool = False,
  1206. device=None,
  1207. dtype=None,
  1208. ) -> None: ...
  1209. @overload
  1210. def __init__(self, *args, **kwargs) -> None: ...
  1211. def __init__(self, *args, **kwargs):
  1212. if "proj_size" in kwargs:
  1213. raise ValueError(
  1214. "proj_size argument is only supported for LSTM, not RNN or GRU"
  1215. )
  1216. super().__init__("GRU", *args, **kwargs)
  1217. @overload # type: ignore[override]
  1218. @torch._jit_internal._overload_method # noqa: F811
  1219. # pyrefly: ignore [bad-override]
  1220. def forward(
  1221. self,
  1222. input: Tensor,
  1223. hx: Tensor | None = None,
  1224. # pyrefly: ignore [bad-return]
  1225. ) -> tuple[Tensor, Tensor]: # noqa: F811
  1226. pass
  1227. @overload
  1228. @torch._jit_internal._overload_method # noqa: F811
  1229. def forward(
  1230. self,
  1231. input: PackedSequence,
  1232. hx: Tensor | None = None,
  1233. # pyrefly: ignore [bad-return]
  1234. ) -> tuple[PackedSequence, Tensor]: # noqa: F811
  1235. pass
  1236. def forward(self, input, hx=None): # noqa: F811
  1237. self._update_flat_weights()
  1238. orig_input = input
  1239. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  1240. if isinstance(orig_input, PackedSequence):
  1241. input, batch_sizes, sorted_indices, unsorted_indices = input
  1242. max_batch_size = batch_sizes[0]
  1243. if hx is None:
  1244. num_directions = 2 if self.bidirectional else 1
  1245. hx = torch.zeros(
  1246. self.num_layers * num_directions,
  1247. max_batch_size,
  1248. self.hidden_size,
  1249. dtype=input.dtype,
  1250. device=input.device,
  1251. )
  1252. else:
  1253. # Each batch of the hidden state should match the input sequence that
  1254. # the user believes he/she is passing in.
  1255. hx = self.permute_hidden(hx, sorted_indices)
  1256. else:
  1257. batch_sizes = None
  1258. if input.dim() not in (2, 3):
  1259. raise ValueError(
  1260. f"GRU: Expected input to be 2D or 3D, got {input.dim()}D instead"
  1261. )
  1262. is_batched = input.dim() == 3
  1263. batch_dim = 0 if self.batch_first else 1
  1264. if not is_batched:
  1265. input = input.unsqueeze(batch_dim)
  1266. if hx is not None:
  1267. if hx.dim() != 2:
  1268. raise RuntimeError(
  1269. f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
  1270. )
  1271. hx = hx.unsqueeze(1)
  1272. else:
  1273. if hx is not None and hx.dim() != 3:
  1274. raise RuntimeError(
  1275. f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
  1276. )
  1277. max_batch_size = input.size(0) if self.batch_first else input.size(1)
  1278. sorted_indices = None
  1279. unsorted_indices = None
  1280. if hx is None:
  1281. num_directions = 2 if self.bidirectional else 1
  1282. hx = torch.zeros(
  1283. self.num_layers * num_directions,
  1284. max_batch_size,
  1285. self.hidden_size,
  1286. dtype=input.dtype,
  1287. device=input.device,
  1288. )
  1289. else:
  1290. # Each batch of the hidden state should match the input sequence that
  1291. # the user believes he/she is passing in.
  1292. hx = self.permute_hidden(hx, sorted_indices)
  1293. self.check_forward_args(input, hx, batch_sizes)
  1294. if batch_sizes is None:
  1295. # pyrefly: ignore [no-matching-overload]
  1296. result = _VF.gru(
  1297. input,
  1298. hx,
  1299. self._flat_weights, # type: ignore[arg-type]
  1300. self.bias,
  1301. self.num_layers,
  1302. self.dropout,
  1303. self.training,
  1304. self.bidirectional,
  1305. self.batch_first,
  1306. )
  1307. else:
  1308. # pyrefly: ignore [no-matching-overload]
  1309. result = _VF.gru(
  1310. input,
  1311. batch_sizes,
  1312. hx,
  1313. self._flat_weights, # type: ignore[arg-type]
  1314. self.bias,
  1315. self.num_layers,
  1316. self.dropout,
  1317. self.training,
  1318. self.bidirectional,
  1319. )
  1320. output = result[0]
  1321. hidden = result[1]
  1322. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  1323. if isinstance(orig_input, PackedSequence):
  1324. output_packed = PackedSequence(
  1325. output,
  1326. # pyrefly: ignore [bad-argument-type]
  1327. batch_sizes,
  1328. sorted_indices,
  1329. unsorted_indices,
  1330. )
  1331. return output_packed, self.permute_hidden(hidden, unsorted_indices)
  1332. else:
  1333. if not is_batched: # type: ignore[possibly-undefined]
  1334. output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
  1335. hidden = hidden.squeeze(1)
  1336. return output, self.permute_hidden(hidden, unsorted_indices)
  class RNNCellBase(Module):
      """Parameter container shared by the single-step recurrent cells.

      Holds an input-to-hidden and a hidden-to-hidden weight matrix and, when
      ``bias`` is true, the matching bias vectors. Each of them stacks
      ``num_chunks`` blocks of ``hidden_size`` rows along dimension 0, so a
      subclass picks ``num_chunks`` to match the number of gate blocks it needs.
      """

      __constants__ = ["input_size", "hidden_size", "bias"]

      input_size: int
      hidden_size: int
      bias: bool
      weight_ih: Tensor
      weight_hh: Tensor
      # WARNING: bias_ih and bias_hh purposely not defined here.
      # See https://github.com/pytorch/pytorch/issues/39670

      def __init__(
          self,
          input_size: int,
          hidden_size: int,
          bias: bool,
          num_chunks: int,
          device=None,
          dtype=None,
      ) -> None:
          """Allocate (uninitialised) weights/biases, then randomise them.

          Args:
              input_size: number of features in the cell input.
              hidden_size: number of features in the hidden state.
              bias: whether to create ``bias_ih`` / ``bias_hh``; when false
                  both are registered as ``None``.
              num_chunks: how many ``hidden_size``-row blocks to stack.
              device: forwarded to ``torch.empty`` for every parameter.
              dtype: forwarded to ``torch.empty`` for every parameter.
          """
          factory_kwargs = {"device": device, "dtype": dtype}
          super().__init__()
          self.input_size = input_size
          self.hidden_size = hidden_size
          self.bias = bias

          stacked_rows = num_chunks * hidden_size
          self.weight_ih = Parameter(
              torch.empty((stacked_rows, input_size), **factory_kwargs)
          )
          self.weight_hh = Parameter(
              torch.empty((stacked_rows, hidden_size), **factory_kwargs)
          )
          # Registration order (bias_ih first, then bias_hh) fixes the order of
          # self.parameters(), which is the order reset_parameters() draws from
          # the RNG in.
          for bias_name in ("bias_ih", "bias_hh"):
              self.register_parameter(
                  bias_name,
                  Parameter(torch.empty(stacked_rows, **factory_kwargs))
                  if bias
                  else None,
              )
          self.reset_parameters()

      def extra_repr(self) -> str:
          """Summarise sizes, plus any option that differs from its default."""
          attrs = self.__dict__
          pieces = ["{input_size}, {hidden_size}"]
          if "bias" in attrs and self.bias is not True:
              pieces.append(", bias={bias}")
          # Only subclasses that set ``nonlinearity`` (e.g. RNNCell) report it.
          if "nonlinearity" in attrs and self.nonlinearity != "tanh":
              pieces.append(", nonlinearity={nonlinearity}")
          return "".join(pieces).format(**attrs)

      def reset_parameters(self) -> None:
          """Fill every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
          # A zero-width hidden state would divide by zero; collapse the range.
          if self.hidden_size > 0:
              bound = 1.0 / math.sqrt(self.hidden_size)
          else:
              bound = 0
          for param in self.parameters():
              init.uniform_(param, -bound, bound)
  1388. class RNNCell(RNNCellBase):
  1389. r"""An Elman RNN cell with tanh or ReLU non-linearity.
  1390. .. math::
  1391. h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
  1392. If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.
  1393. Args:
  1394. input_size: The number of expected features in the input `x`
  1395. hidden_size: The number of features in the hidden state `h`
  1396. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  1397. Default: ``True``
  1398. nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
  1399. Inputs: input, hidden
  1400. - **input**: tensor containing input features
  1401. - **hidden**: tensor containing the initial hidden state
  1402. Defaults to zero if not provided.
  1403. Outputs: h'
  1404. - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
  1405. for each element in the batch
  1406. Shape:
  1407. - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
  1408. :math:`H_{in}` = `input_size`.
  1409. - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
  1410. state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
  1411. - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
  1412. Attributes:
  1413. weight_ih: the learnable input-hidden weights, of shape
  1414. `(hidden_size, input_size)`
  1415. weight_hh: the learnable hidden-hidden weights, of shape
  1416. `(hidden_size, hidden_size)`
  1417. bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
  1418. bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`
  1419. .. note::
  1420. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  1421. where :math:`k = \frac{1}{\text{hidden\_size}}`
  1422. Examples::
  1423. >>> rnn = nn.RNNCell(10, 20)
  1424. >>> input = torch.randn(6, 3, 10)
  1425. >>> hx = torch.randn(3, 20)
  1426. >>> output = []
  1427. >>> for i in range(6):
  1428. ... hx = rnn(input[i], hx)
  1429. ... output.append(hx)
  1430. """
  1431. __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"]
  1432. nonlinearity: str
  1433. def __init__(
  1434. self,
  1435. input_size: int,
  1436. hidden_size: int,
  1437. bias: bool = True,
  1438. nonlinearity: str = "tanh",
  1439. device=None,
  1440. dtype=None,
  1441. ) -> None:
  1442. factory_kwargs = {"device": device, "dtype": dtype}
  1443. super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
  1444. self.nonlinearity = nonlinearity
  1445. def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor:
  1446. if input.dim() not in (1, 2):
  1447. raise ValueError(
  1448. f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
  1449. )
  1450. if hx is not None and hx.dim() not in (1, 2):
  1451. raise ValueError(
  1452. f"RNNCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead"
  1453. )
  1454. is_batched = input.dim() == 2
  1455. if not is_batched:
  1456. input = input.unsqueeze(0)
  1457. if hx is None:
  1458. hx = torch.zeros(
  1459. input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
  1460. )
  1461. else:
  1462. hx = hx.unsqueeze(0) if not is_batched else hx
  1463. if self.nonlinearity == "tanh":
  1464. ret = _VF.rnn_tanh_cell(
  1465. input,
  1466. hx,
  1467. self.weight_ih,
  1468. self.weight_hh,
  1469. self.bias_ih,
  1470. self.bias_hh,
  1471. )
  1472. elif self.nonlinearity == "relu":
  1473. ret = _VF.rnn_relu_cell(
  1474. input,
  1475. hx,
  1476. self.weight_ih,
  1477. self.weight_hh,
  1478. self.bias_ih,
  1479. self.bias_hh,
  1480. )
  1481. else:
  1482. ret = input # TODO: remove when jit supports exception flow
  1483. raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
  1484. if not is_batched:
  1485. ret = ret.squeeze(0)
  1486. return ret
  1487. class LSTMCell(RNNCellBase):
  1488. r"""A long short-term memory (LSTM) cell.
  1489. .. math::
  1490. \begin{array}{ll}
  1491. i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
  1492. f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
  1493. g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
  1494. o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
  1495. c' = f \odot c + i \odot g \\
  1496. h' = o \odot \tanh(c') \\
  1497. \end{array}
  1498. where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
  1499. Args:
  1500. input_size: The number of expected features in the input `x`
  1501. hidden_size: The number of features in the hidden state `h`
  1502. bias: If ``False``, then the layer does not use bias weights `b_ih` and
  1503. `b_hh`. Default: ``True``
  1504. Inputs: input, (h_0, c_0)
  1505. - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features
  1506. - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state
  1507. - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state
  1508. If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
  1509. Outputs: (h_1, c_1)
  1510. - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state
  1511. - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state
  1512. Attributes:
  1513. weight_ih: the learnable input-hidden weights, of shape
  1514. `(4*hidden_size, input_size)`
  1515. weight_hh: the learnable hidden-hidden weights, of shape
  1516. `(4*hidden_size, hidden_size)`
  1517. bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
  1518. bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`
  1519. .. note::
  1520. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  1521. where :math:`k = \frac{1}{\text{hidden\_size}}`
  1522. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  1523. Examples::
  1524. >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
  1525. >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
  1526. >>> hx = torch.randn(3, 20) # (batch, hidden_size)
  1527. >>> cx = torch.randn(3, 20)
  1528. >>> output = []
  1529. >>> for i in range(input.size()[0]):
  1530. ... hx, cx = rnn(input[i], (hx, cx))
  1531. ... output.append(hx)
  1532. >>> output = torch.stack(output, dim=0)
  1533. """
  1534. def __init__(
  1535. self,
  1536. input_size: int,
  1537. hidden_size: int,
  1538. bias: bool = True,
  1539. device=None,
  1540. dtype=None,
  1541. ) -> None:
  1542. factory_kwargs = {"device": device, "dtype": dtype}
  1543. super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
  1544. def forward(
  1545. self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None
  1546. ) -> tuple[Tensor, Tensor]:
  1547. if input.dim() not in (1, 2):
  1548. raise ValueError(
  1549. f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
  1550. )
  1551. if hx is not None:
  1552. for idx, value in enumerate(hx):
  1553. if value.dim() not in (1, 2):
  1554. raise ValueError(
  1555. f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead"
  1556. )
  1557. is_batched = input.dim() == 2
  1558. if not is_batched:
  1559. input = input.unsqueeze(0)
  1560. if hx is None:
  1561. zeros = torch.zeros(
  1562. input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
  1563. )
  1564. hx = (zeros, zeros)
  1565. else:
  1566. hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
  1567. ret = _VF.lstm_cell(
  1568. input,
  1569. hx,
  1570. self.weight_ih,
  1571. self.weight_hh,
  1572. self.bias_ih,
  1573. self.bias_hh,
  1574. )
  1575. if not is_batched:
  1576. ret = (ret[0].squeeze(0), ret[1].squeeze(0))
  1577. return ret
  1578. class GRUCell(RNNCellBase):
  1579. r"""A gated recurrent unit (GRU) cell.
  1580. .. math::
  1581. \begin{array}{ll}
  1582. r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
  1583. z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
  1584. n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\
  1585. h' = (1 - z) \odot n + z \odot h
  1586. \end{array}
  1587. where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
  1588. Args:
  1589. input_size: The number of expected features in the input `x`
  1590. hidden_size: The number of features in the hidden state `h`
  1591. bias: If ``False``, then the layer does not use bias weights `b_ih` and
  1592. `b_hh`. Default: ``True``
  1593. Inputs: input, hidden
  1594. - **input** : tensor containing input features
  1595. - **hidden** : tensor containing the initial hidden
  1596. state for each element in the batch.
  1597. Defaults to zero if not provided.
  1598. Outputs: h'
  1599. - **h'** : tensor containing the next hidden state
  1600. for each element in the batch
  1601. Shape:
  1602. - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
  1603. :math:`H_{in}` = `input_size`.
  1604. - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
  1605. state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
  1606. - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
  1607. Attributes:
  1608. weight_ih: the learnable input-hidden weights, of shape
  1609. `(3*hidden_size, input_size)`
  1610. weight_hh: the learnable hidden-hidden weights, of shape
  1611. `(3*hidden_size, hidden_size)`
  1612. bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
  1613. bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`
  1614. .. note::
  1615. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  1616. where :math:`k = \frac{1}{\text{hidden\_size}}`
  1617. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  1618. Examples::
  1619. >>> rnn = nn.GRUCell(10, 20)
  1620. >>> input = torch.randn(6, 3, 10)
  1621. >>> hx = torch.randn(3, 20)
  1622. >>> output = []
  1623. >>> for i in range(6):
  1624. ... hx = rnn(input[i], hx)
  1625. ... output.append(hx)
  1626. """
  1627. def __init__(
  1628. self,
  1629. input_size: int,
  1630. hidden_size: int,
  1631. bias: bool = True,
  1632. device=None,
  1633. dtype=None,
  1634. ) -> None:
  1635. factory_kwargs = {"device": device, "dtype": dtype}
  1636. super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
  1637. def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor:
  1638. if input.dim() not in (1, 2):
  1639. raise ValueError(
  1640. f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
  1641. )
  1642. if hx is not None and hx.dim() not in (1, 2):
  1643. raise ValueError(
  1644. f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead"
  1645. )
  1646. is_batched = input.dim() == 2
  1647. if not is_batched:
  1648. input = input.unsqueeze(0)
  1649. if hx is None:
  1650. hx = torch.zeros(
  1651. input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
  1652. )
  1653. else:
  1654. hx = hx.unsqueeze(0) if not is_batched else hx
  1655. ret = _VF.gru_cell(
  1656. input,
  1657. hx,
  1658. self.weight_ih,
  1659. self.weight_hh,
  1660. self.bias_ih,
  1661. self.bias_hh,
  1662. )
  1663. if not is_batched:
  1664. ret = ret.squeeze(0)
  1665. return ret