container.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. import operator
  4. from collections import abc as container_abcs, OrderedDict
  5. from itertools import chain, islice
  6. from typing import Any, overload, TYPE_CHECKING, TypeVar
  7. from typing_extensions import deprecated, Self
  8. import torch
  9. from torch._jit_internal import _copy_to_script_wrapper
  10. from torch.nn.parameter import Parameter
  11. from .module import Module
  12. if TYPE_CHECKING:
  13. from collections.abc import Iterable, Iterator, Mapping
  14. __all__ = [
  15. "Container",
  16. "Sequential",
  17. "ModuleList",
  18. "ModuleDict",
  19. "ParameterList",
  20. "ParameterDict",
  21. ]
  22. T = TypeVar("T", bound=Module)
  23. _V = TypeVar("_V")
  24. # Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
  25. def _addindent(s_, numSpaces):
  26. s = s_.split("\n")
  27. # don't do anything for single-line stuff
  28. if len(s) == 1:
  29. return s_
  30. first = s.pop(0)
  31. s = [(numSpaces * " ") + line for line in s]
  32. s = "\n".join(s)
  33. s = first + "\n" + s
  34. return s
  35. @deprecated(
  36. "`nn.Container` is deprecated. "
  37. "All of it's functionality is now implemented in `nn.Module`. Subclass that instead.",
  38. category=FutureWarning,
  39. )
  40. class Container(Module):
  41. def __init__(self, **kwargs: Any) -> None:
  42. super().__init__()
  43. for key, value in kwargs.items():
  44. self.add_module(key, value)
  45. class Sequential(Module):
  46. r"""A sequential container.
  47. Modules will be added to it in the order they are passed in the
  48. constructor. Alternatively, an ``OrderedDict`` of modules can be
  49. passed in. The ``forward()`` method of ``Sequential`` accepts any
  50. input and forwards it to the first module it contains. It then
  51. "chains" outputs to inputs sequentially for each subsequent module,
  52. finally returning the output of the last module.
  53. The value a ``Sequential`` provides over manually calling a sequence
  54. of modules is that it allows treating the whole container as a
  55. single module, such that performing a transformation on the
  56. ``Sequential`` applies to each of the modules it stores (which are
  57. each a registered submodule of the ``Sequential``).
  58. What's the difference between a ``Sequential`` and a
  59. :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
  60. sounds like--a list for storing ``Module`` s! On the other hand,
  61. the layers in a ``Sequential`` are connected in a cascading way.
  62. Example::
  63. # Using Sequential to create a small model. When `model` is run,
  64. # input will first be passed to `Conv2d(1,20,5)`. The output of
  65. # `Conv2d(1,20,5)` will be used as the input to the first
  66. # `ReLU`; the output of the first `ReLU` will become the input
  67. # for `Conv2d(20,64,5)`. Finally, the output of
  68. # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
  69. model = nn.Sequential(
  70. nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()
  71. )
  72. # Using Sequential with OrderedDict. This is functionally the
  73. # same as the above code
  74. model = nn.Sequential(
  75. OrderedDict(
  76. [
  77. ("conv1", nn.Conv2d(1, 20, 5)),
  78. ("relu1", nn.ReLU()),
  79. ("conv2", nn.Conv2d(20, 64, 5)),
  80. ("relu2", nn.ReLU()),
  81. ]
  82. )
  83. )
  84. """
  85. _modules: dict[str, Module] # type: ignore[assignment]
  86. @overload
  87. def __init__(self, *args: Module) -> None: ...
  88. @overload
  89. # pyrefly: ignore [inconsistent-overload]
  90. def __init__(self, arg: OrderedDict[str, Module]) -> None: ...
  91. def __init__(self, *args):
  92. super().__init__()
  93. if len(args) == 1 and isinstance(args[0], OrderedDict):
  94. for key, module in args[0].items():
  95. self.add_module(key, module)
  96. else:
  97. for idx, module in enumerate(args):
  98. self.add_module(str(idx), module)
  99. def _get_item_by_idx(self, iterator: Iterable[_V], idx: int) -> _V:
  100. """Get the idx-th item of the iterator."""
  101. size = len(self)
  102. idx = operator.index(idx)
  103. if not -size <= idx < size:
  104. raise IndexError(f"index {idx} is out of range")
  105. idx %= size
  106. return next(islice(iterator, idx, None))
  107. @_copy_to_script_wrapper
  108. def __getitem__(self, idx: slice | int) -> Sequential | Module:
  109. if isinstance(idx, slice):
  110. return self.__class__(OrderedDict(list(self._modules.items())[idx]))
  111. else:
  112. return self._get_item_by_idx(self._modules.values(), idx)
  113. def __setitem__(self, idx: int, module: Module) -> None:
  114. key: str = self._get_item_by_idx(self._modules.keys(), idx)
  115. return setattr(self, key, module)
  116. def __delitem__(self, idx: slice | int) -> None:
  117. if isinstance(idx, slice):
  118. for key in list(self._modules.keys())[idx]:
  119. delattr(self, key)
  120. else:
  121. key = self._get_item_by_idx(self._modules.keys(), idx)
  122. delattr(self, key)
  123. # To preserve numbering
  124. str_indices = [str(i) for i in range(len(self._modules))]
  125. self._modules = OrderedDict(
  126. zip(str_indices, self._modules.values(), strict=True)
  127. )
  128. @_copy_to_script_wrapper
  129. def __len__(self) -> int:
  130. return len(self._modules)
  131. def __add__(self, other) -> Sequential:
  132. if isinstance(other, Sequential):
  133. ret = Sequential()
  134. for layer in self:
  135. ret.append(layer)
  136. for layer in other:
  137. ret.append(layer)
  138. return ret
  139. else:
  140. raise ValueError(
  141. "add operator supports only objects "
  142. f"of Sequential class, but {str(type(other))} is given."
  143. )
  144. def pop(self, key: int | slice) -> Module:
  145. """
  146. Pop ``key`` from self.
  147. """
  148. v = self[key]
  149. del self[key]
  150. return v
  151. def __iadd__(self, other) -> Self:
  152. if isinstance(other, Sequential):
  153. offset = len(self)
  154. for i, module in enumerate(other):
  155. self.add_module(str(i + offset), module)
  156. return self
  157. else:
  158. raise ValueError(
  159. "add operator supports only objects "
  160. f"of Sequential class, but {str(type(other))} is given."
  161. )
  162. def __mul__(self, other: int) -> Sequential:
  163. if not isinstance(other, int):
  164. raise TypeError(
  165. f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
  166. )
  167. elif other <= 0:
  168. raise ValueError(
  169. f"Non-positive multiplication factor {other} for {type(self)}"
  170. )
  171. else:
  172. combined = Sequential()
  173. offset = 0
  174. for _ in range(other):
  175. for module in self:
  176. combined.add_module(str(offset), module)
  177. offset += 1
  178. return combined
  179. def __rmul__(self, other: int) -> Sequential:
  180. return self.__mul__(other)
  181. def __imul__(self, other: int) -> Self:
  182. if not isinstance(other, int):
  183. raise TypeError(
  184. f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
  185. )
  186. elif other <= 0:
  187. raise ValueError(
  188. f"Non-positive multiplication factor {other} for {type(self)}"
  189. )
  190. else:
  191. len_original = len(self)
  192. offset = len(self)
  193. for _ in range(other - 1):
  194. for i in range(len_original):
  195. self.add_module(str(i + offset), self._modules[str(i)])
  196. offset += len_original
  197. return self
  198. @_copy_to_script_wrapper
  199. def __dir__(self) -> list[str]:
  200. keys = super().__dir__()
  201. keys = [key for key in keys if not key.isdigit()]
  202. return keys
  203. @_copy_to_script_wrapper
  204. def __iter__(self) -> Iterator[Module]:
  205. return iter(self._modules.values())
  206. # NB: We can't really type check this function as the type of input
  207. # may change dynamically (as is tested in
  208. # TestScript.test_sequential_intermediary_types). Cannot annotate
  209. # with Any as TorchScript expects a more precise type
  210. def forward(self, input):
  211. """
  212. Runs the forward pass.
  213. """
  214. for module in self:
  215. input = module(input)
  216. return input
  217. def append(self, module: Module) -> Self:
  218. r"""Append a given module to the end.
  219. Args:
  220. module (nn.Module): module to append
  221. Example::
  222. >>> import torch.nn as nn
  223. >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
  224. >>> n.append(nn.Linear(3, 4))
  225. Sequential(
  226. (0): Linear(in_features=1, out_features=2, bias=True)
  227. (1): Linear(in_features=2, out_features=3, bias=True)
  228. (2): Linear(in_features=3, out_features=4, bias=True)
  229. )
  230. """
  231. self.add_module(str(len(self)), module)
  232. return self
  233. def insert(self, index: int, module: Module) -> Self:
  234. """
  235. Inserts a module into the Sequential container at the specified index.
  236. Args:
  237. index (int): The index to insert the module.
  238. module (Module): The module to be inserted.
  239. Example::
  240. >>> import torch.nn as nn
  241. >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
  242. >>> n.insert(0, nn.Linear(3, 4))
  243. Sequential(
  244. (0): Linear(in_features=3, out_features=4, bias=True)
  245. (1): Linear(in_features=1, out_features=2, bias=True)
  246. (2): Linear(in_features=2, out_features=3, bias=True)
  247. )
  248. """
  249. if not isinstance(module, Module):
  250. raise AssertionError(f"module should be of type: {Module}")
  251. n = len(self._modules)
  252. if not (-n <= index <= n):
  253. raise IndexError(f"Index out of range: {index}")
  254. if index < 0:
  255. index += n
  256. for i in range(n, index, -1):
  257. self._modules[str(i)] = self._modules[str(i - 1)]
  258. self._modules[str(index)] = module
  259. return self
  260. def extend(self, sequential: Iterable[Module]) -> Self:
  261. """
  262. Extends the current Sequential container with layers from another Sequential container.
  263. Args:
  264. sequential (Sequential): A Sequential container whose layers will be added to the current container.
  265. Example::
  266. >>> import torch.nn as nn
  267. >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
  268. >>> other = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5))
  269. >>> n.extend(other) # or `n + other`
  270. Sequential(
  271. (0): Linear(in_features=1, out_features=2, bias=True)
  272. (1): Linear(in_features=2, out_features=3, bias=True)
  273. (2): Linear(in_features=3, out_features=4, bias=True)
  274. (3): Linear(in_features=4, out_features=5, bias=True)
  275. )
  276. """
  277. for layer in sequential:
  278. self.append(layer)
  279. return self
  280. class ModuleList(Module):
  281. r"""Holds submodules in a list.
  282. :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
  283. modules it contains are properly registered, and will be visible by all
  284. :class:`~torch.nn.Module` methods.
  285. Args:
  286. modules (iterable, optional): an iterable of modules to add
  287. Example::
  288. class MyModule(nn.Module):
  289. def __init__(self) -> None:
  290. super().__init__()
  291. self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
  292. def forward(self, x):
  293. # ModuleList can act as an iterable, or be indexed using ints
  294. for i, l in enumerate(self.linears):
  295. x = self.linears[i // 2](x) + l(x)
  296. return x
  297. """
  298. _modules: dict[str, Module] # type: ignore[assignment]
  299. def __init__(self, modules: Iterable[Module] | None = None) -> None:
  300. super().__init__()
  301. if modules is not None:
  302. self += modules
  303. def _get_abs_string_index(self, idx):
  304. """Get the absolute index for the list of modules."""
  305. idx = operator.index(idx)
  306. if not (-len(self) <= idx < len(self)):
  307. raise IndexError(f"index {idx} is out of range")
  308. if idx < 0:
  309. idx += len(self)
  310. return str(idx)
  311. @overload
  312. def __getitem__(self, idx: slice) -> ModuleList: ...
  313. @overload
  314. def __getitem__(self, idx: int) -> Module: ...
  315. @_copy_to_script_wrapper
  316. def __getitem__(self, idx: int | slice) -> Module | ModuleList:
  317. if isinstance(idx, slice):
  318. return self.__class__(list(self._modules.values())[idx])
  319. else:
  320. return self._modules[self._get_abs_string_index(idx)]
  321. def __setitem__(self, idx: int, module: Module) -> None:
  322. idx = self._get_abs_string_index(idx)
  323. return setattr(self, str(idx), module)
  324. def __delitem__(self, idx: int | slice) -> None:
  325. if isinstance(idx, slice):
  326. for k in range(len(self._modules))[idx]:
  327. delattr(self, str(k))
  328. else:
  329. delattr(self, self._get_abs_string_index(idx))
  330. # To preserve numbering, self._modules is being reconstructed with modules after deletion
  331. str_indices = [str(i) for i in range(len(self._modules))]
  332. self._modules = OrderedDict(
  333. zip(str_indices, self._modules.values(), strict=True)
  334. )
  335. @_copy_to_script_wrapper
  336. def __len__(self) -> int:
  337. return len(self._modules)
  338. @_copy_to_script_wrapper
  339. def __iter__(self) -> Iterator[Module]:
  340. return iter(self._modules.values())
  341. def __iadd__(self, modules: Iterable[Module]) -> Self:
  342. return self.extend(modules)
  343. def __add__(self, other: Iterable[Module]) -> ModuleList:
  344. combined = ModuleList()
  345. for i, module in enumerate(chain(self, other)):
  346. combined.add_module(str(i), module)
  347. return combined
  348. def __repr__(self) -> str:
  349. """Return a custom repr for ModuleList that compresses repeated module representations."""
  350. list_of_reprs = [repr(item) for item in self]
  351. if len(list_of_reprs) == 0:
  352. return self._get_name() + "()"
  353. start_end_indices = [[0, 0]]
  354. repeated_blocks = [list_of_reprs[0]]
  355. for i, r in enumerate(list_of_reprs[1:], 1):
  356. if r == repeated_blocks[-1]:
  357. start_end_indices[-1][1] += 1
  358. continue
  359. start_end_indices.append([i, i])
  360. repeated_blocks.append(r)
  361. lines = []
  362. main_str = self._get_name() + "("
  363. for (start_id, end_id), b in zip(
  364. start_end_indices, repeated_blocks, strict=True
  365. ):
  366. local_repr = f"({start_id}): {b}" # default repr
  367. if start_id != end_id:
  368. n = end_id - start_id + 1
  369. local_repr = f"({start_id}-{end_id}): {n} x {b}"
  370. local_repr = _addindent(local_repr, 2)
  371. lines.append(local_repr)
  372. main_str += "\n " + "\n ".join(lines) + "\n"
  373. main_str += ")"
  374. return main_str
  375. @_copy_to_script_wrapper
  376. def __dir__(self) -> list[str]:
  377. keys = super().__dir__()
  378. keys = [key for key in keys if not key.isdigit()]
  379. return keys
  380. def insert(self, index: int, module: Module) -> None:
  381. r"""Insert a given module before a given index in the list.
  382. Args:
  383. index (int): index to insert.
  384. module (nn.Module): module to insert
  385. """
  386. for i in range(len(self._modules), index, -1):
  387. self._modules[str(i)] = self._modules[str(i - 1)]
  388. self._modules[str(index)] = module
  389. def append(self, module: Module) -> Self:
  390. r"""Append a given module to the end of the list.
  391. Args:
  392. module (nn.Module): module to append
  393. """
  394. self.add_module(str(len(self)), module)
  395. return self
  396. def pop(self, key: int | slice) -> Module:
  397. v = self[key]
  398. del self[key]
  399. return v
  400. def extend(self, modules: Iterable[Module]) -> Self:
  401. r"""Append modules from a Python iterable to the end of the list.
  402. Args:
  403. modules (iterable): iterable of modules to append
  404. """
  405. if not isinstance(modules, container_abcs.Iterable):
  406. raise TypeError(
  407. "ModuleList.extend should be called with an "
  408. "iterable, but got " + type(modules).__name__
  409. )
  410. offset = len(self)
  411. for i, module in enumerate(modules):
  412. self.add_module(str(offset + i), module)
  413. return self
  414. # remove forward altogether to fallback on Module's _forward_unimplemented
  415. class ModuleDict(Module):
  416. r"""Holds submodules in a dictionary.
  417. :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
  418. but modules it contains are properly registered, and will be visible by all
  419. :class:`~torch.nn.Module` methods.
  420. :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects
  421. * the order of insertion, and
  422. * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
  423. ``OrderedDict``, ``dict`` or another
  424. :class:`~torch.nn.ModuleDict` (the argument to
  425. :meth:`~torch.nn.ModuleDict.update`).
  426. Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
  427. types does not preserve the order of the merged mapping.
  428. Args:
  429. modules (iterable, optional): a mapping (dictionary) of (string: module)
  430. or an iterable of key-value pairs of type (string, module)
  431. Example::
  432. class MyModule(nn.Module):
  433. def __init__(self) -> None:
  434. super().__init__()
  435. self.choices = nn.ModuleDict(
  436. {"conv": nn.Conv2d(10, 10, 3), "pool": nn.MaxPool2d(3)}
  437. )
  438. self.activations = nn.ModuleDict(
  439. [["lrelu", nn.LeakyReLU()], ["prelu", nn.PReLU()]]
  440. )
  441. def forward(self, x, choice, act):
  442. x = self.choices[choice](x)
  443. x = self.activations[act](x)
  444. return x
  445. """
  446. _modules: dict[str, Module] # type: ignore[assignment]
  447. def __init__(self, modules: Mapping[str, Module] | None = None) -> None:
  448. super().__init__()
  449. if modules is not None:
  450. self.update(modules)
  451. @_copy_to_script_wrapper
  452. def __getitem__(self, key: str) -> Module:
  453. return self._modules[key]
  454. def __setitem__(self, key: str, module: Module) -> None:
  455. self.add_module(key, module)
  456. def __delitem__(self, key: str) -> None:
  457. del self._modules[key]
  458. @_copy_to_script_wrapper
  459. def __len__(self) -> int:
  460. return len(self._modules)
  461. @_copy_to_script_wrapper
  462. def __iter__(self) -> Iterator[str]:
  463. return iter(self._modules)
  464. @_copy_to_script_wrapper
  465. def __contains__(self, key: str) -> bool:
  466. return key in self._modules
  467. def clear(self) -> None:
  468. """Remove all items from the ModuleDict."""
  469. self._modules.clear()
  470. def pop(self, key: str) -> Module:
  471. r"""Remove key from the ModuleDict and return its module.
  472. Args:
  473. key (str): key to pop from the ModuleDict
  474. """
  475. v = self[key]
  476. del self[key]
  477. return v
  478. @_copy_to_script_wrapper
  479. def keys(self) -> container_abcs.KeysView[str]:
  480. r"""Return an iterable of the ModuleDict keys."""
  481. return self._modules.keys()
  482. @_copy_to_script_wrapper
  483. def items(self) -> container_abcs.ItemsView[str, Module]:
  484. r"""Return an iterable of the ModuleDict key/value pairs."""
  485. return self._modules.items()
  486. @_copy_to_script_wrapper
  487. def values(self) -> container_abcs.ValuesView[Module]:
  488. r"""Return an iterable of the ModuleDict values."""
  489. return self._modules.values()
  490. def update(self, modules: Mapping[str, Module]) -> None:
  491. r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.
  492. .. note::
  493. If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
  494. an iterable of key-value pairs, the order of new elements in it is preserved.
  495. Args:
  496. modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
  497. or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
  498. """
  499. if not isinstance(modules, container_abcs.Iterable):
  500. raise TypeError(
  501. "ModuleDict.update should be called with an "
  502. "iterable of key/value pairs, but got " + type(modules).__name__
  503. )
  504. if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
  505. for key, module in modules.items():
  506. self[key] = module
  507. else:
  508. # modules here can be a list with two items
  509. for j, m in enumerate(modules):
  510. if not isinstance(m, container_abcs.Iterable):
  511. raise TypeError(
  512. "ModuleDict update sequence element "
  513. "#" + str(j) + " should be Iterable; is" + type(m).__name__
  514. )
  515. # pyrefly: ignore [bad-argument-type]
  516. if not len(m) == 2:
  517. raise ValueError(
  518. "ModuleDict update sequence element "
  519. # pyrefly: ignore [bad-argument-type]
  520. "#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
  521. )
  522. # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
  523. # that's too cumbersome to type correctly with overloads, so we add an ignore here
  524. self[m[0]] = m[1] # type: ignore[assignment]
  525. # remove forward altogether to fallback on Module's _forward_unimplemented
  526. class ParameterList(Module):
  527. r"""Holds parameters in a list.
  528. :class:`~torch.nn.ParameterList` can be used like a regular Python
  529. list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
  530. and will be visible by all :class:`~torch.nn.Module` methods.
  531. Note that the constructor, assigning an element of the list, the
  532. :meth:`~torch.nn.ParameterList.append` method and the :meth:`~torch.nn.ParameterList.extend`
  533. method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.
  534. Args:
  535. parameters (iterable, optional): an iterable of elements to add to the list.
  536. Example::
  537. class MyModule(nn.Module):
  538. def __init__(self) -> None:
  539. super().__init__()
  540. self.params = nn.ParameterList(
  541. [nn.Parameter(torch.randn(10, 10)) for i in range(10)]
  542. )
  543. def forward(self, x):
  544. # ParameterList can act as an iterable, or be indexed using ints
  545. for i, p in enumerate(self.params):
  546. x = self.params[i // 2].mm(x) + p.mm(x)
  547. return x
  548. """
  549. def __init__(self, values: Iterable[Any] | None = None) -> None:
  550. super().__init__()
  551. self._size = 0
  552. if values is not None:
  553. self += values
  554. def _get_abs_string_index(self, idx):
  555. """Get the absolute index for the list of modules."""
  556. idx = operator.index(idx)
  557. if not (-len(self) <= idx < len(self)):
  558. raise IndexError(f"index {idx} is out of range")
  559. if idx < 0:
  560. idx += len(self)
  561. return str(idx)
  562. @overload
  563. def __getitem__(self, idx: int) -> Any: ...
  564. @overload
  565. # pyrefly: ignore [inconsistent-overload]
  566. def __getitem__(self: T, idx: slice) -> T: ...
  567. def __getitem__(self, idx):
  568. if isinstance(idx, slice):
  569. start, stop, step = idx.indices(len(self))
  570. out = self.__class__()
  571. for i in range(start, stop, step):
  572. out.append(self[i])
  573. return out
  574. else:
  575. idx = self._get_abs_string_index(idx)
  576. return getattr(self, str(idx))
  577. def __setitem__(self, idx: int, param: Any) -> None:
  578. # Note that all other function that add an entry to the list part of
  579. # the ParameterList end up here. So this is the only place where we need
  580. # to wrap things into Parameter if needed.
  581. # Objects added via setattr() are not in the list part and thus won't
  582. # call into this function.
  583. idx = self._get_abs_string_index(idx)
  584. if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
  585. param = Parameter(param)
  586. return setattr(self, str(idx), param)
  587. def __len__(self) -> int:
  588. return self._size
  589. def __iter__(self) -> Iterator[Any]:
  590. return iter(self[i] for i in range(len(self)))
  591. def __iadd__(self, parameters: Iterable[Any]) -> Self:
  592. return self.extend(parameters)
  593. def __dir__(self) -> list[str]:
  594. keys = super().__dir__()
  595. keys = [key for key in keys if not key.isdigit()]
  596. return keys
  597. def append(self, value: Any) -> Self:
  598. """Append a given value at the end of the list.
  599. Args:
  600. value (Any): value to append
  601. """
  602. new_idx = len(self)
  603. self._size += 1
  604. self[new_idx] = value
  605. return self
  606. def extend(self, values: Iterable[Any]) -> Self:
  607. """Append values from a Python iterable to the end of the list.
  608. Args:
  609. values (iterable): iterable of values to append
  610. """
  611. # Tensor is an iterable but we never want to unpack it here
  612. if not isinstance(values, container_abcs.Iterable) or isinstance(
  613. values, torch.Tensor
  614. ):
  615. raise TypeError(
  616. "ParameterList.extend should be called with an "
  617. "iterable, but got " + type(values).__name__
  618. )
  619. for value in values:
  620. self.append(value)
  621. return self
  622. def extra_repr(self) -> str:
  623. """
  624. Return the extra representation of the module.
  625. """
  626. child_lines = []
  627. for k, p in enumerate(self):
  628. if isinstance(p, torch.Tensor):
  629. size_str = "x".join(str(size) for size in p.size())
  630. if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
  631. device_str = f" ({p.device})"
  632. else:
  633. device_str = ""
  634. parastr = "{} containing: [{} of size {}{}]".format(
  635. "Parameter" if isinstance(p, Parameter) else "Tensor",
  636. p.dtype,
  637. size_str,
  638. device_str,
  639. )
  640. # pyrefly: ignore [bad-argument-type]
  641. child_lines.append(" (" + str(k) + "): " + parastr)
  642. else:
  643. child_lines.append(
  644. # pyrefly: ignore [bad-argument-type]
  645. " (" + str(k) + "): Object of type: " + type(p).__name__
  646. )
  647. tmpstr = "\n".join(child_lines)
  648. return tmpstr
  649. def __call__(self, *args, **kwargs):
  650. raise RuntimeError("ParameterList should not be called.")
  651. class ParameterDict(Module):
  652. r"""Holds parameters in a dictionary.
  653. ParameterDict can be indexed like a regular Python dictionary, but Parameters it
  654. contains are properly registered, and will be visible by all Module methods.
  655. Other objects are treated as would be done by a regular Python dictionary
  656. :class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
  657. :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
  658. types (e.g., Python's plain ``dict``) does not preserve the order of the
  659. merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
  660. will preserve their ordering.
  661. Note that the constructor, assigning an element of the dictionary and the
  662. :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
  663. :class:`~torch.nn.Parameter`.
  664. Args:
  665. values (iterable, optional): a mapping (dictionary) of
  666. (string : Any) or an iterable of key-value pairs
  667. of type (string, Any)
  668. Example::
  669. class MyModule(nn.Module):
  670. def __init__(self) -> None:
  671. super().__init__()
  672. self.params = nn.ParameterDict(
  673. {
  674. "left": nn.Parameter(torch.randn(5, 10)),
  675. "right": nn.Parameter(torch.randn(5, 10)),
  676. }
  677. )
  678. def forward(self, x, choice):
  679. x = self.params[choice].mm(x)
  680. return x
  681. """
  682. def __init__(self, parameters: Any = None) -> None:
  683. super().__init__()
  684. self._keys: dict[str, None] = {}
  685. if parameters is not None:
  686. self.update(parameters)
  687. def _key_to_attr(self, key: str) -> str:
  688. if not isinstance(key, str):
  689. raise TypeError(
  690. "Index given to ParameterDict cannot be used as a key as it is "
  691. f"not a string (type is '{type(key).__name__}'). Open an issue on "
  692. "github if you need non-string keys."
  693. )
  694. else:
  695. # Use the key as-is so that `.named_parameters()` returns the right thing
  696. return key
  697. def __getitem__(self, key: str) -> Any:
  698. attr = self._key_to_attr(key)
  699. return getattr(self, attr)
  700. def __setitem__(self, key: str, value: Any) -> None:
  701. # Note that all other function that add an entry to the dictionary part of
  702. # the ParameterDict end up here. So this is the only place where we need
  703. # to wrap things into Parameter if needed.
  704. # Objects added via setattr() are not in the dictionary part and thus won't
  705. # call into this function.
  706. self._keys[key] = None
  707. attr = self._key_to_attr(key)
  708. if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
  709. value = Parameter(value)
  710. setattr(self, attr, value)
  711. def __delitem__(self, key: str) -> None:
  712. del self._keys[key]
  713. attr = self._key_to_attr(key)
  714. delattr(self, attr)
  715. def __len__(self) -> int:
  716. return len(self._keys)
  717. def __iter__(self) -> Iterator[str]:
  718. return iter(self._keys)
  719. def __reversed__(self) -> Iterator[str]:
  720. return reversed(self._keys)
  721. def copy(self) -> ParameterDict:
  722. """Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
  723. # We have to use an OrderedDict because the ParameterDict constructor
  724. # behaves differently on plain dict vs OrderedDict
  725. return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))
  726. def __contains__(self, key: str) -> bool:
  727. return key in self._keys
  728. def setdefault(self, key: str, default: Any | None = None) -> Any:
  729. """Set the default for a key in the Parameterdict.
  730. If key is in the ParameterDict, return its value.
  731. If not, insert `key` with a parameter `default` and return `default`.
  732. `default` defaults to `None`.
  733. Args:
  734. key (str): key to set default for
  735. default (Any): the parameter set to the key
  736. """
  737. if key not in self:
  738. self[key] = default
  739. return self[key]
  740. def clear(self) -> None:
  741. """Remove all items from the ParameterDict."""
  742. for k in self._keys.copy():
  743. del self[k]
  744. def pop(self, key: str) -> Any:
  745. r"""Remove key from the ParameterDict and return its parameter.
  746. Args:
  747. key (str): key to pop from the ParameterDict
  748. """
  749. v = self[key]
  750. del self[key]
  751. return v
  752. def popitem(self) -> tuple[str, Any]:
  753. """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
  754. k, _ = self._keys.popitem()
  755. # We need the key in the _keys to be able to access/del
  756. self._keys[k] = None
  757. val = self[k]
  758. del self[k]
  759. return k, val
  760. def get(self, key: str, default: Any | None = None) -> Any:
  761. r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.
  762. Args:
  763. key (str): key to get from the ParameterDict
  764. default (Parameter, optional): value to return if key not present
  765. """
  766. return self[key] if key in self else default # noqa: SIM401
  767. def fromkeys(
  768. self, keys: Iterable[str], default: Any | None = None
  769. ) -> ParameterDict:
  770. r"""Return a new ParameterDict with the keys provided.
  771. Args:
  772. keys (iterable, string): keys to make the new ParameterDict from
  773. default (Parameter, optional): value to set for all keys
  774. """
  775. return ParameterDict((k, default) for k in keys)
  776. def keys(self) -> container_abcs.KeysView[str]:
  777. r"""Return an iterable of the ParameterDict keys."""
  778. return self._keys.keys()
  779. def items(self) -> Iterable[tuple[str, Any]]:
  780. r"""Return an iterable of the ParameterDict key/value pairs."""
  781. return ((k, self[k]) for k in self._keys)
  782. def values(self) -> Iterable[Any]:
  783. r"""Return an iterable of the ParameterDict values."""
  784. return (self[k] for k in self._keys)
  785. def update(self, parameters: Mapping[str, Any] | ParameterDict) -> None:
  786. r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.
  787. .. note::
  788. If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
  789. an iterable of key-value pairs, the order of new elements in it is preserved.
  790. Args:
  791. parameters (iterable): a mapping (dictionary) from string to
  792. :class:`~torch.nn.Parameter`, or an iterable of
  793. key-value pairs of type (string, :class:`~torch.nn.Parameter`)
  794. """
  795. if not isinstance(parameters, container_abcs.Iterable):
  796. raise TypeError(
  797. "ParametersDict.update should be called with an "
  798. "iterable of key/value pairs, but got " + type(parameters).__name__
  799. )
  800. if isinstance(parameters, (OrderedDict, ParameterDict)):
  801. for key, parameter in parameters.items():
  802. self[key] = parameter
  803. elif isinstance(parameters, container_abcs.Mapping):
  804. for key, parameter in sorted(parameters.items()):
  805. self[key] = parameter
  806. else:
  807. for j, p in enumerate(parameters):
  808. if not isinstance(p, container_abcs.Iterable):
  809. raise TypeError(
  810. "ParameterDict update sequence element "
  811. "#" + str(j) + " should be Iterable; is" + type(p).__name__
  812. )
  813. # pyrefly: ignore [bad-argument-type]
  814. if not len(p) == 2:
  815. raise ValueError(
  816. "ParameterDict update sequence element "
  817. # pyrefly: ignore [bad-argument-type]
  818. "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
  819. )
  820. # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
  821. self[p[0]] = p[1] # type: ignore[assignment]
  822. def extra_repr(self) -> str:
  823. child_lines = []
  824. for k, p in self.items():
  825. if isinstance(p, torch.Tensor):
  826. size_str = "x".join(str(size) for size in p.size())
  827. if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
  828. device_str = f" ({p.device})"
  829. else:
  830. device_str = ""
  831. parastr = "{} containing: [{} of size {}{}]".format(
  832. "Parameter" if isinstance(p, Parameter) else "Tensor",
  833. torch.typename(p),
  834. size_str,
  835. device_str,
  836. )
  837. # pyrefly: ignore [bad-argument-type]
  838. child_lines.append(" (" + str(k) + "): " + parastr)
  839. else:
  840. child_lines.append(
  841. # pyrefly: ignore [bad-argument-type]
  842. " (" + str(k) + "): Object of type: " + type(p).__name__
  843. )
  844. tmpstr = "\n".join(child_lines)
  845. return tmpstr
  846. def __call__(self, input):
  847. raise RuntimeError("ParameterDict should not be called.")
  848. def __or__(self, other: ParameterDict) -> ParameterDict:
  849. copy = self.copy()
  850. copy.update(other)
  851. return copy
  852. def __ror__(self, other: ParameterDict) -> ParameterDict:
  853. copy = other.copy()
  854. copy.update(self)
  855. return copy
  856. def __ior__(self, other: ParameterDict) -> Self:
  857. self.update(other)
  858. return self