| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043 |
- # mypy: allow-untyped-defs
- from __future__ import annotations
- import operator
- from collections import abc as container_abcs, OrderedDict
- from itertools import chain, islice
- from typing import Any, overload, TYPE_CHECKING, TypeVar
- from typing_extensions import deprecated, Self
- import torch
- from torch._jit_internal import _copy_to_script_wrapper
- from torch.nn.parameter import Parameter
- from .module import Module
- if TYPE_CHECKING:
- from collections.abc import Iterable, Iterator, Mapping
# Public names exported by ``from ... import *``.
__all__ = [
    "Container",
    "Sequential",
    "ModuleList",
    "ModuleDict",
    "ParameterList",
    "ParameterDict",
]

# Type variable bound to ``Module``; annotations such as ``self: T -> T`` use it to
# say "returns an instance of the same container subclass".
T = TypeVar("T", bound=Module)
# Unconstrained element type, used to annotate ``Sequential._get_item_by_idx``.
_V = TypeVar("_V")
- # Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
- def _addindent(s_, numSpaces):
- s = s_.split("\n")
- # don't do anything for single-line stuff
- if len(s) == 1:
- return s_
- first = s.pop(0)
- s = [(numSpaces * " ") + line for line in s]
- s = "\n".join(s)
- s = first + "\n" + s
- return s
@deprecated(
    "`nn.Container` is deprecated. "
    "All of it's functionality is now implemented in `nn.Module`. Subclass that instead.",
    category=FutureWarning,
)
class Container(Module):
    """Deprecated container that registers each keyword argument as a submodule.

    Instantiating it emits a ``FutureWarning``; subclass :class:`torch.nn.Module`
    instead.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__()
        # Each keyword becomes a named child module, in keyword order.
        for name, submodule in kwargs.items():
            self.add_module(name, submodule)
- class Sequential(Module):
- r"""A sequential container.
- Modules will be added to it in the order they are passed in the
- constructor. Alternatively, an ``OrderedDict`` of modules can be
- passed in. The ``forward()`` method of ``Sequential`` accepts any
- input and forwards it to the first module it contains. It then
- "chains" outputs to inputs sequentially for each subsequent module,
- finally returning the output of the last module.
- The value a ``Sequential`` provides over manually calling a sequence
- of modules is that it allows treating the whole container as a
- single module, such that performing a transformation on the
- ``Sequential`` applies to each of the modules it stores (which are
- each a registered submodule of the ``Sequential``).
- What's the difference between a ``Sequential`` and a
- :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
- sounds like--a list for storing ``Module`` s! On the other hand,
- the layers in a ``Sequential`` are connected in a cascading way.
- Example::
- # Using Sequential to create a small model. When `model` is run,
- # input will first be passed to `Conv2d(1,20,5)`. The output of
- # `Conv2d(1,20,5)` will be used as the input to the first
- # `ReLU`; the output of the first `ReLU` will become the input
- # for `Conv2d(20,64,5)`. Finally, the output of
- # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
- model = nn.Sequential(
- nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()
- )
- # Using Sequential with OrderedDict. This is functionally the
- # same as the above code
- model = nn.Sequential(
- OrderedDict(
- [
- ("conv1", nn.Conv2d(1, 20, 5)),
- ("relu1", nn.ReLU()),
- ("conv2", nn.Conv2d(20, 64, 5)),
- ("relu2", nn.ReLU()),
- ]
- )
- )
- """
- _modules: dict[str, Module] # type: ignore[assignment]
- @overload
- def __init__(self, *args: Module) -> None: ...
- @overload
- # pyrefly: ignore [inconsistent-overload]
- def __init__(self, arg: OrderedDict[str, Module]) -> None: ...
- def __init__(self, *args):
- super().__init__()
- if len(args) == 1 and isinstance(args[0], OrderedDict):
- for key, module in args[0].items():
- self.add_module(key, module)
- else:
- for idx, module in enumerate(args):
- self.add_module(str(idx), module)
- def _get_item_by_idx(self, iterator: Iterable[_V], idx: int) -> _V:
- """Get the idx-th item of the iterator."""
- size = len(self)
- idx = operator.index(idx)
- if not -size <= idx < size:
- raise IndexError(f"index {idx} is out of range")
- idx %= size
- return next(islice(iterator, idx, None))
- @_copy_to_script_wrapper
- def __getitem__(self, idx: slice | int) -> Sequential | Module:
- if isinstance(idx, slice):
- return self.__class__(OrderedDict(list(self._modules.items())[idx]))
- else:
- return self._get_item_by_idx(self._modules.values(), idx)
- def __setitem__(self, idx: int, module: Module) -> None:
- key: str = self._get_item_by_idx(self._modules.keys(), idx)
- return setattr(self, key, module)
- def __delitem__(self, idx: slice | int) -> None:
- if isinstance(idx, slice):
- for key in list(self._modules.keys())[idx]:
- delattr(self, key)
- else:
- key = self._get_item_by_idx(self._modules.keys(), idx)
- delattr(self, key)
- # To preserve numbering
- str_indices = [str(i) for i in range(len(self._modules))]
- self._modules = OrderedDict(
- zip(str_indices, self._modules.values(), strict=True)
- )
- @_copy_to_script_wrapper
- def __len__(self) -> int:
- return len(self._modules)
- def __add__(self, other) -> Sequential:
- if isinstance(other, Sequential):
- ret = Sequential()
- for layer in self:
- ret.append(layer)
- for layer in other:
- ret.append(layer)
- return ret
- else:
- raise ValueError(
- "add operator supports only objects "
- f"of Sequential class, but {str(type(other))} is given."
- )
- def pop(self, key: int | slice) -> Module:
- """
- Pop ``key`` from self.
- """
- v = self[key]
- del self[key]
- return v
- def __iadd__(self, other) -> Self:
- if isinstance(other, Sequential):
- offset = len(self)
- for i, module in enumerate(other):
- self.add_module(str(i + offset), module)
- return self
- else:
- raise ValueError(
- "add operator supports only objects "
- f"of Sequential class, but {str(type(other))} is given."
- )
- def __mul__(self, other: int) -> Sequential:
- if not isinstance(other, int):
- raise TypeError(
- f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
- )
- elif other <= 0:
- raise ValueError(
- f"Non-positive multiplication factor {other} for {type(self)}"
- )
- else:
- combined = Sequential()
- offset = 0
- for _ in range(other):
- for module in self:
- combined.add_module(str(offset), module)
- offset += 1
- return combined
- def __rmul__(self, other: int) -> Sequential:
- return self.__mul__(other)
- def __imul__(self, other: int) -> Self:
- if not isinstance(other, int):
- raise TypeError(
- f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
- )
- elif other <= 0:
- raise ValueError(
- f"Non-positive multiplication factor {other} for {type(self)}"
- )
- else:
- len_original = len(self)
- offset = len(self)
- for _ in range(other - 1):
- for i in range(len_original):
- self.add_module(str(i + offset), self._modules[str(i)])
- offset += len_original
- return self
- @_copy_to_script_wrapper
- def __dir__(self) -> list[str]:
- keys = super().__dir__()
- keys = [key for key in keys if not key.isdigit()]
- return keys
- @_copy_to_script_wrapper
- def __iter__(self) -> Iterator[Module]:
- return iter(self._modules.values())
- # NB: We can't really type check this function as the type of input
- # may change dynamically (as is tested in
- # TestScript.test_sequential_intermediary_types). Cannot annotate
- # with Any as TorchScript expects a more precise type
- def forward(self, input):
- """
- Runs the forward pass.
- """
- for module in self:
- input = module(input)
- return input
- def append(self, module: Module) -> Self:
- r"""Append a given module to the end.
- Args:
- module (nn.Module): module to append
- Example::
- >>> import torch.nn as nn
- >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
- >>> n.append(nn.Linear(3, 4))
- Sequential(
- (0): Linear(in_features=1, out_features=2, bias=True)
- (1): Linear(in_features=2, out_features=3, bias=True)
- (2): Linear(in_features=3, out_features=4, bias=True)
- )
- """
- self.add_module(str(len(self)), module)
- return self
- def insert(self, index: int, module: Module) -> Self:
- """
- Inserts a module into the Sequential container at the specified index.
- Args:
- index (int): The index to insert the module.
- module (Module): The module to be inserted.
- Example::
- >>> import torch.nn as nn
- >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
- >>> n.insert(0, nn.Linear(3, 4))
- Sequential(
- (0): Linear(in_features=3, out_features=4, bias=True)
- (1): Linear(in_features=1, out_features=2, bias=True)
- (2): Linear(in_features=2, out_features=3, bias=True)
- )
- """
- if not isinstance(module, Module):
- raise AssertionError(f"module should be of type: {Module}")
- n = len(self._modules)
- if not (-n <= index <= n):
- raise IndexError(f"Index out of range: {index}")
- if index < 0:
- index += n
- for i in range(n, index, -1):
- self._modules[str(i)] = self._modules[str(i - 1)]
- self._modules[str(index)] = module
- return self
- def extend(self, sequential: Iterable[Module]) -> Self:
- """
- Extends the current Sequential container with layers from another Sequential container.
- Args:
- sequential (Sequential): A Sequential container whose layers will be added to the current container.
- Example::
- >>> import torch.nn as nn
- >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
- >>> other = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5))
- >>> n.extend(other) # or `n + other`
- Sequential(
- (0): Linear(in_features=1, out_features=2, bias=True)
- (1): Linear(in_features=2, out_features=3, bias=True)
- (2): Linear(in_features=3, out_features=4, bias=True)
- (3): Linear(in_features=4, out_features=5, bias=True)
- )
- """
- for layer in sequential:
- self.append(layer)
- return self
- class ModuleList(Module):
- r"""Holds submodules in a list.
- :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
- modules it contains are properly registered, and will be visible by all
- :class:`~torch.nn.Module` methods.
- Args:
- modules (iterable, optional): an iterable of modules to add
- Example::
- class MyModule(nn.Module):
- def __init__(self) -> None:
- super().__init__()
- self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
- def forward(self, x):
- # ModuleList can act as an iterable, or be indexed using ints
- for i, l in enumerate(self.linears):
- x = self.linears[i // 2](x) + l(x)
- return x
- """
- _modules: dict[str, Module] # type: ignore[assignment]
- def __init__(self, modules: Iterable[Module] | None = None) -> None:
- super().__init__()
- if modules is not None:
- self += modules
- def _get_abs_string_index(self, idx):
- """Get the absolute index for the list of modules."""
- idx = operator.index(idx)
- if not (-len(self) <= idx < len(self)):
- raise IndexError(f"index {idx} is out of range")
- if idx < 0:
- idx += len(self)
- return str(idx)
- @overload
- def __getitem__(self, idx: slice) -> ModuleList: ...
- @overload
- def __getitem__(self, idx: int) -> Module: ...
- @_copy_to_script_wrapper
- def __getitem__(self, idx: int | slice) -> Module | ModuleList:
- if isinstance(idx, slice):
- return self.__class__(list(self._modules.values())[idx])
- else:
- return self._modules[self._get_abs_string_index(idx)]
- def __setitem__(self, idx: int, module: Module) -> None:
- idx = self._get_abs_string_index(idx)
- return setattr(self, str(idx), module)
- def __delitem__(self, idx: int | slice) -> None:
- if isinstance(idx, slice):
- for k in range(len(self._modules))[idx]:
- delattr(self, str(k))
- else:
- delattr(self, self._get_abs_string_index(idx))
- # To preserve numbering, self._modules is being reconstructed with modules after deletion
- str_indices = [str(i) for i in range(len(self._modules))]
- self._modules = OrderedDict(
- zip(str_indices, self._modules.values(), strict=True)
- )
- @_copy_to_script_wrapper
- def __len__(self) -> int:
- return len(self._modules)
- @_copy_to_script_wrapper
- def __iter__(self) -> Iterator[Module]:
- return iter(self._modules.values())
- def __iadd__(self, modules: Iterable[Module]) -> Self:
- return self.extend(modules)
- def __add__(self, other: Iterable[Module]) -> ModuleList:
- combined = ModuleList()
- for i, module in enumerate(chain(self, other)):
- combined.add_module(str(i), module)
- return combined
- def __repr__(self) -> str:
- """Return a custom repr for ModuleList that compresses repeated module representations."""
- list_of_reprs = [repr(item) for item in self]
- if len(list_of_reprs) == 0:
- return self._get_name() + "()"
- start_end_indices = [[0, 0]]
- repeated_blocks = [list_of_reprs[0]]
- for i, r in enumerate(list_of_reprs[1:], 1):
- if r == repeated_blocks[-1]:
- start_end_indices[-1][1] += 1
- continue
- start_end_indices.append([i, i])
- repeated_blocks.append(r)
- lines = []
- main_str = self._get_name() + "("
- for (start_id, end_id), b in zip(
- start_end_indices, repeated_blocks, strict=True
- ):
- local_repr = f"({start_id}): {b}" # default repr
- if start_id != end_id:
- n = end_id - start_id + 1
- local_repr = f"({start_id}-{end_id}): {n} x {b}"
- local_repr = _addindent(local_repr, 2)
- lines.append(local_repr)
- main_str += "\n " + "\n ".join(lines) + "\n"
- main_str += ")"
- return main_str
- @_copy_to_script_wrapper
- def __dir__(self) -> list[str]:
- keys = super().__dir__()
- keys = [key for key in keys if not key.isdigit()]
- return keys
- def insert(self, index: int, module: Module) -> None:
- r"""Insert a given module before a given index in the list.
- Args:
- index (int): index to insert.
- module (nn.Module): module to insert
- """
- for i in range(len(self._modules), index, -1):
- self._modules[str(i)] = self._modules[str(i - 1)]
- self._modules[str(index)] = module
- def append(self, module: Module) -> Self:
- r"""Append a given module to the end of the list.
- Args:
- module (nn.Module): module to append
- """
- self.add_module(str(len(self)), module)
- return self
- def pop(self, key: int | slice) -> Module:
- v = self[key]
- del self[key]
- return v
- def extend(self, modules: Iterable[Module]) -> Self:
- r"""Append modules from a Python iterable to the end of the list.
- Args:
- modules (iterable): iterable of modules to append
- """
- if not isinstance(modules, container_abcs.Iterable):
- raise TypeError(
- "ModuleList.extend should be called with an "
- "iterable, but got " + type(modules).__name__
- )
- offset = len(self)
- for i, module in enumerate(modules):
- self.add_module(str(offset + i), module)
- return self
- # remove forward altogether to fallback on Module's _forward_unimplemented
class ModuleDict(Module):
    r"""Holds submodules in a dictionary.

    :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
    but modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
      ``OrderedDict``, ``dict`` or another
      :class:`~torch.nn.ModuleDict` (the argument to
      :meth:`~torch.nn.ModuleDict.update`).

    Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
    types does not preserve the order of the merged mapping.

    Args:
        modules (iterable, optional): a mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module)

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.choices = nn.ModuleDict(
                    {"conv": nn.Conv2d(10, 10, 3), "pool": nn.MaxPool2d(3)}
                )
                self.activations = nn.ModuleDict(
                    [["lrelu", nn.LeakyReLU()], ["prelu", nn.PReLU()]]
                )

            def forward(self, x, choice, act):
                x = self.choices[choice](x)
                x = self.activations[act](x)
                return x
    """

    _modules: dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Mapping[str, Module] | None = None) -> None:
        super().__init__()
        if modules is not None:
            self.update(modules)

    @_copy_to_script_wrapper
    def __getitem__(self, key: str) -> Module:
        return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        # Goes through add_module so the child is validated and registered.
        self.add_module(key, module)

    def __delitem__(self, key: str) -> None:
        del self._modules[key]

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[str]:
        return iter(self._modules)

    @_copy_to_script_wrapper
    def __contains__(self, key: str) -> bool:
        return key in self._modules

    def clear(self) -> None:
        """Remove all items from the ModuleDict."""
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (str): key to pop from the ModuleDict
        """
        v = self[key]
        del self[key]
        return v

    @_copy_to_script_wrapper
    def keys(self) -> container_abcs.KeysView[str]:
        r"""Return an iterable of the ModuleDict keys."""
        return self._modules.keys()

    @_copy_to_script_wrapper
    def items(self) -> container_abcs.ItemsView[str, Module]:
        r"""Return an iterable of the ModuleDict key/value pairs."""
        return self._modules.items()

    @_copy_to_script_wrapper
    def values(self) -> container_abcs.ValuesView[Module]:
        r"""Return an iterable of the ModuleDict values."""
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.

        .. note::
            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)

        Raises:
            TypeError: if ``modules`` or one of its elements is not iterable.
            ValueError: if an element of a pair sequence does not have length 2.
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError(
                "ModuleDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(modules).__name__
            )

        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
            for key, module in modules.items():
                self[key] = module
        else:
            # modules here can be a list with two items
            for j, m in enumerate(modules):
                if not isinstance(m, container_abcs.Iterable):
                    # The message previously lacked a space ("...; isint").
                    raise TypeError(
                        "ModuleDict update sequence element "
                        "#" + str(j) + " should be Iterable; is " + type(m).__name__
                    )
                # pyrefly: ignore [bad-argument-type]
                if not len(m) == 2:
                    raise ValueError(
                        "ModuleDict update sequence element "
                        # pyrefly: ignore [bad-argument-type]
                        "#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
                    )
                # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
                # that's too cumbersome to type correctly with overloads, so we add an ignore here
                self[m[0]] = m[1]  # type: ignore[assignment]

    # remove forward altogether to fallback on Module's _forward_unimplemented
- class ParameterList(Module):
- r"""Holds parameters in a list.
- :class:`~torch.nn.ParameterList` can be used like a regular Python
- list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
- and will be visible by all :class:`~torch.nn.Module` methods.
- Note that the constructor, assigning an element of the list, the
- :meth:`~torch.nn.ParameterList.append` method and the :meth:`~torch.nn.ParameterList.extend`
- method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.
- Args:
- parameters (iterable, optional): an iterable of elements to add to the list.
- Example::
- class MyModule(nn.Module):
- def __init__(self) -> None:
- super().__init__()
- self.params = nn.ParameterList(
- [nn.Parameter(torch.randn(10, 10)) for i in range(10)]
- )
- def forward(self, x):
- # ParameterList can act as an iterable, or be indexed using ints
- for i, p in enumerate(self.params):
- x = self.params[i // 2].mm(x) + p.mm(x)
- return x
- """
- def __init__(self, values: Iterable[Any] | None = None) -> None:
- super().__init__()
- self._size = 0
- if values is not None:
- self += values
- def _get_abs_string_index(self, idx):
- """Get the absolute index for the list of modules."""
- idx = operator.index(idx)
- if not (-len(self) <= idx < len(self)):
- raise IndexError(f"index {idx} is out of range")
- if idx < 0:
- idx += len(self)
- return str(idx)
- @overload
- def __getitem__(self, idx: int) -> Any: ...
- @overload
- # pyrefly: ignore [inconsistent-overload]
- def __getitem__(self: T, idx: slice) -> T: ...
- def __getitem__(self, idx):
- if isinstance(idx, slice):
- start, stop, step = idx.indices(len(self))
- out = self.__class__()
- for i in range(start, stop, step):
- out.append(self[i])
- return out
- else:
- idx = self._get_abs_string_index(idx)
- return getattr(self, str(idx))
- def __setitem__(self, idx: int, param: Any) -> None:
- # Note that all other function that add an entry to the list part of
- # the ParameterList end up here. So this is the only place where we need
- # to wrap things into Parameter if needed.
- # Objects added via setattr() are not in the list part and thus won't
- # call into this function.
- idx = self._get_abs_string_index(idx)
- if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
- param = Parameter(param)
- return setattr(self, str(idx), param)
- def __len__(self) -> int:
- return self._size
- def __iter__(self) -> Iterator[Any]:
- return iter(self[i] for i in range(len(self)))
- def __iadd__(self, parameters: Iterable[Any]) -> Self:
- return self.extend(parameters)
- def __dir__(self) -> list[str]:
- keys = super().__dir__()
- keys = [key for key in keys if not key.isdigit()]
- return keys
- def append(self, value: Any) -> Self:
- """Append a given value at the end of the list.
- Args:
- value (Any): value to append
- """
- new_idx = len(self)
- self._size += 1
- self[new_idx] = value
- return self
- def extend(self, values: Iterable[Any]) -> Self:
- """Append values from a Python iterable to the end of the list.
- Args:
- values (iterable): iterable of values to append
- """
- # Tensor is an iterable but we never want to unpack it here
- if not isinstance(values, container_abcs.Iterable) or isinstance(
- values, torch.Tensor
- ):
- raise TypeError(
- "ParameterList.extend should be called with an "
- "iterable, but got " + type(values).__name__
- )
- for value in values:
- self.append(value)
- return self
- def extra_repr(self) -> str:
- """
- Return the extra representation of the module.
- """
- child_lines = []
- for k, p in enumerate(self):
- if isinstance(p, torch.Tensor):
- size_str = "x".join(str(size) for size in p.size())
- if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
- device_str = f" ({p.device})"
- else:
- device_str = ""
- parastr = "{} containing: [{} of size {}{}]".format(
- "Parameter" if isinstance(p, Parameter) else "Tensor",
- p.dtype,
- size_str,
- device_str,
- )
- # pyrefly: ignore [bad-argument-type]
- child_lines.append(" (" + str(k) + "): " + parastr)
- else:
- child_lines.append(
- # pyrefly: ignore [bad-argument-type]
- " (" + str(k) + "): Object of type: " + type(p).__name__
- )
- tmpstr = "\n".join(child_lines)
- return tmpstr
- def __call__(self, *args, **kwargs):
- raise RuntimeError("ParameterList should not be called.")
class ParameterDict(Module):
    r"""Holds parameters in a dictionary.

    ParameterDict can be indexed like a regular Python dictionary, but Parameters it
    contains are properly registered, and will be visible by all Module methods.
    Other objects are treated as would be done by a regular Python dictionary.

    :class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
    :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict``) does not preserve the order of the
    merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
    will preserve their ordering.

    Note that the constructor, assigning an element of the dictionary and the
    :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
    :class:`~torch.nn.Parameter`.

    Args:
        values (iterable, optional): a mapping (dictionary) of
            (string : Any) or an iterable of key-value pairs
            of type (string, Any)

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.params = nn.ParameterDict(
                    {
                        "left": nn.Parameter(torch.randn(5, 10)),
                        "right": nn.Parameter(torch.randn(5, 10)),
                    }
                )

            def forward(self, x, choice):
                x = self.params[choice].mm(x)
                return x
    """

    def __init__(self, parameters: Any = None) -> None:
        super().__init__()
        # Insertion-ordered set of the dictionary-part keys; the values themselves
        # are stored as attributes of the same name.
        self._keys: dict[str, None] = {}
        if parameters is not None:
            self.update(parameters)

    def _key_to_attr(self, key: str) -> str:
        """Validate ``key`` and return the attribute name used to store it.

        Raises:
            TypeError: if ``key`` is not a string.
        """
        if not isinstance(key, str):
            raise TypeError(
                "Index given to ParameterDict cannot be used as a key as it is "
                f"not a string (type is '{type(key).__name__}'). Open an issue on "
                "github if you need non-string keys."
            )
        else:
            # Use the key as-is so that `.named_parameters()` returns the right thing
            return key

    def __getitem__(self, key: str) -> Any:
        attr = self._key_to_attr(key)
        return getattr(self, attr)

    def __setitem__(self, key: str, value: Any) -> None:
        # Note that all other function that add an entry to the dictionary part of
        # the ParameterDict end up here. So this is the only place where we need
        # to wrap things into Parameter if needed.
        # Objects added via setattr() are not in the dictionary part and thus won't
        # call into this function.
        attr = self._key_to_attr(key)
        if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
            value = Parameter(value)
        setattr(self, attr, value)
        # Record the key only after the value is stored: a failure above (non-string
        # key, tensor that cannot become a Parameter, name rejected by
        # Module.__setattr__) must not leave a phantom key behind.
        self._keys[key] = None

    def __delitem__(self, key: str) -> None:
        del self._keys[key]
        attr = self._key_to_attr(key)
        delattr(self, attr)

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> Iterator[str]:
        return iter(self._keys)

    def __reversed__(self) -> Iterator[str]:
        return reversed(self._keys)

    def copy(self) -> ParameterDict:
        """Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
        # We have to use an OrderedDict because the ParameterDict constructor
        # behaves differently on plain dict vs OrderedDict
        return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))

    def __contains__(self, key: str) -> bool:
        return key in self._keys

    def setdefault(self, key: str, default: Any | None = None) -> Any:
        """Set the default for a key in the ParameterDict.

        If key is in the ParameterDict, return its value.
        If not, insert `key` with a parameter `default` and return `default`.
        `default` defaults to `None`.

        Args:
            key (str): key to set default for
            default (Any): the parameter set to the key
        """
        if key not in self:
            self[key] = default
        return self[key]

    def clear(self) -> None:
        """Remove all items from the ParameterDict."""
        for k in self._keys.copy():
            del self[k]

    def pop(self, key: str) -> Any:
        r"""Remove key from the ParameterDict and return its parameter.

        Args:
            key (str): key to pop from the ParameterDict
        """
        v = self[key]
        del self[key]
        return v

    def popitem(self) -> tuple[str, Any]:
        """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
        k, _ = self._keys.popitem()
        # We need the key in the _keys to be able to access/del
        self._keys[k] = None
        val = self[k]
        del self[k]
        return k, val

    def get(self, key: str, default: Any | None = None) -> Any:
        r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.

        Args:
            key (str): key to get from the ParameterDict
            default (Parameter, optional): value to return if key not present
        """
        return self[key] if key in self else default  # noqa: SIM401

    def fromkeys(
        self, keys: Iterable[str], default: Any | None = None
    ) -> ParameterDict:
        r"""Return a new ParameterDict with the keys provided.

        Args:
            keys (iterable, string): keys to make the new ParameterDict from
            default (Parameter, optional): value to set for all keys
        """
        return ParameterDict((k, default) for k in keys)

    def keys(self) -> container_abcs.KeysView[str]:
        r"""Return an iterable of the ParameterDict keys."""
        return self._keys.keys()

    def items(self) -> Iterable[tuple[str, Any]]:
        r"""Return an iterable of the ParameterDict key/value pairs."""
        return ((k, self[k]) for k in self._keys)

    def values(self) -> Iterable[Any]:
        r"""Return an iterable of the ParameterDict values."""
        return (self[k] for k in self._keys)

    def update(self, parameters: Mapping[str, Any] | ParameterDict) -> None:
        r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.

        .. note::
            If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.
            Any other mapping is inserted in sorted key order.

        Args:
            parameters (iterable): a mapping (dictionary) from string to
                :class:`~torch.nn.Parameter`, or an iterable of
                key-value pairs of type (string, :class:`~torch.nn.Parameter`)

        Raises:
            TypeError: if ``parameters`` or one of its elements is not iterable.
            ValueError: if an element of a pair sequence does not have length 2.
        """
        if not isinstance(parameters, container_abcs.Iterable):
            raise TypeError(
                "ParameterDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(parameters).__name__
            )

        if isinstance(parameters, (OrderedDict, ParameterDict)):
            for key, parameter in parameters.items():
                self[key] = parameter
        elif isinstance(parameters, container_abcs.Mapping):
            for key, parameter in sorted(parameters.items()):
                self[key] = parameter
        else:
            for j, p in enumerate(parameters):
                if not isinstance(p, container_abcs.Iterable):
                    raise TypeError(
                        "ParameterDict update sequence element "
                        "#" + str(j) + " should be Iterable; is " + type(p).__name__
                    )
                # pyrefly: ignore [bad-argument-type]
                if not len(p) == 2:
                    raise ValueError(
                        "ParameterDict update sequence element "
                        # pyrefly: ignore [bad-argument-type]
                        "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
                    )
                # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
                self[p[0]] = p[1]  # type: ignore[assignment]

    def extra_repr(self) -> str:
        """Return one line per entry describing its tensor type/size or object type."""
        child_lines = []
        for k, p in self.items():
            if isinstance(p, torch.Tensor):
                size_str = "x".join(str(size) for size in p.size())
                if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
                    device_str = f" ({p.device})"
                else:
                    device_str = ""
                parastr = "{} containing: [{} of size {}{}]".format(
                    "Parameter" if isinstance(p, Parameter) else "Tensor",
                    torch.typename(p),
                    size_str,
                    device_str,
                )
                # pyrefly: ignore [bad-argument-type]
                child_lines.append(" (" + str(k) + "): " + parastr)
            else:
                child_lines.append(
                    # pyrefly: ignore [bad-argument-type]
                    " (" + str(k) + "): Object of type: " + type(p).__name__
                )
        tmpstr = "\n".join(child_lines)
        return tmpstr

    def __call__(self, input):
        raise RuntimeError("ParameterDict should not be called.")

    def __or__(self, other: ParameterDict) -> ParameterDict:
        copy = self.copy()
        copy.update(other)
        return copy

    def __ror__(self, other: ParameterDict) -> ParameterDict:
        copy = other.copy()
        copy.update(self)
        return copy

    def __ior__(self, other: ParameterDict) -> Self:
        self.update(other)
        return self
|