sparse.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import Tensor
  4. from torch.nn import functional as F, init
  5. from torch.nn.parameter import Parameter
  6. from .module import Module
  7. __all__ = ["Embedding", "EmbeddingBag"]
  8. class Embedding(Module):
  9. r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
  10. This module is often used to store word embeddings and retrieve them using indices.
  11. The input to the module is a list of indices, and the output is the corresponding
  12. word embeddings.
  13. Args:
  14. num_embeddings (int): size of the dictionary of embeddings
  15. embedding_dim (int): the size of each embedding vector
  16. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
  17. therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
  18. i.e. it remains as a fixed "pad". For a newly constructed Embedding,
  19. the embedding vector at :attr:`padding_idx` will default to all zeros,
  20. but can be updated to another value to be used as the padding vector.
  21. max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
  22. is renormalized to have norm :attr:`max_norm`.
  23. norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
  24. scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
  25. the words in the mini-batch. Default ``False``.
  26. sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
  27. See Notes for more details regarding sparse gradients.
  28. Attributes:
  29. weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
  30. initialized from :math:`\mathcal{N}(0, 1)`
  31. Shape:
  32. - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
  33. - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
  34. .. note::
  35. Keep in mind that only a limited number of optimizers support
  36. sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
  37. :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
  38. .. note::
  39. When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the
  40. :attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be
  41. modified in-place, performing a differentiable operation on ``Embedding.weight`` before
  42. calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when
  43. :attr:`max_norm` is not ``None``. For example::
  44. n, d, m = 3, 5, 7
  45. embedding = nn.Embedding(n, d, max_norm=1.0)
  46. W = torch.randn((m, d), requires_grad=True)
  47. idx = torch.tensor([1, 2])
  48. a = (
  49. embedding.weight.clone() @ W.t()
  50. ) # weight must be cloned for this to be differentiable
  51. b = embedding(idx) @ W.t() # modifies weight in-place
  52. out = a.unsqueeze(0) + b.unsqueeze(1)
  53. loss = out.sigmoid().prod()
  54. loss.backward()
  55. Examples::
  56. >>> # an Embedding module containing 10 tensors of size 3
  57. >>> embedding = nn.Embedding(10, 3)
  58. >>> # a batch of 2 samples of 4 indices each
  59. >>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
  60. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  61. >>> embedding(input)
  62. tensor([[[-0.0251, -1.6902, 0.7172],
  63. [-0.6431, 0.0748, 0.6969],
  64. [ 1.4970, 1.3448, -0.9685],
  65. [-0.3677, -2.7265, -0.1685]],
  66. [[ 1.4970, 1.3448, -0.9685],
  67. [ 0.4362, -0.4004, 0.9400],
  68. [-0.6431, 0.0748, 0.6969],
  69. [ 0.9124, -2.3616, 1.1151]]])
  70. >>> # example with padding_idx
  71. >>> embedding = nn.Embedding(10, 3, padding_idx=0)
  72. >>> input = torch.LongTensor([[0, 2, 0, 5]])
  73. >>> embedding(input)
  74. tensor([[[ 0.0000, 0.0000, 0.0000],
  75. [ 0.1535, -2.0309, 0.9315],
  76. [ 0.0000, 0.0000, 0.0000],
  77. [-0.1655, 0.9897, 0.0635]]])
  78. >>> # example of changing `pad` vector
  79. >>> padding_idx = 0
  80. >>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
  81. >>> embedding.weight
  82. Parameter containing:
  83. tensor([[ 0.0000, 0.0000, 0.0000],
  84. [-0.7895, -0.7089, -0.0364],
  85. [ 0.6778, 0.5803, 0.2678]], requires_grad=True)
  86. >>> with torch.no_grad():
  87. ... embedding.weight[padding_idx] = torch.ones(3)
  88. >>> embedding.weight
  89. Parameter containing:
  90. tensor([[ 1.0000, 1.0000, 1.0000],
  91. [-0.7895, -0.7089, -0.0364],
  92. [ 0.6778, 0.5803, 0.2678]], requires_grad=True)
  93. """
  94. __constants__ = [
  95. "num_embeddings",
  96. "embedding_dim",
  97. "padding_idx",
  98. "max_norm",
  99. "norm_type",
  100. "scale_grad_by_freq",
  101. "sparse",
  102. ]
  103. num_embeddings: int
  104. embedding_dim: int
  105. padding_idx: int | None
  106. max_norm: float | None
  107. norm_type: float
  108. scale_grad_by_freq: bool
  109. weight: Tensor
  110. freeze: bool
  111. sparse: bool
  112. def __init__(
  113. self,
  114. num_embeddings: int,
  115. embedding_dim: int,
  116. padding_idx: int | None = None,
  117. max_norm: float | None = None,
  118. norm_type: float = 2.0,
  119. scale_grad_by_freq: bool = False,
  120. sparse: bool = False,
  121. _weight: Tensor | None = None,
  122. _freeze: bool = False,
  123. device=None,
  124. dtype=None,
  125. ) -> None:
  126. factory_kwargs = {"device": device, "dtype": dtype}
  127. super().__init__()
  128. self.num_embeddings = num_embeddings
  129. self.embedding_dim = embedding_dim
  130. if padding_idx is not None:
  131. if padding_idx > 0:
  132. if padding_idx >= self.num_embeddings:
  133. raise AssertionError("Padding_idx must be within num_embeddings")
  134. elif padding_idx < 0:
  135. if padding_idx < -self.num_embeddings:
  136. raise AssertionError("Padding_idx must be within num_embeddings")
  137. padding_idx = self.num_embeddings + padding_idx
  138. self.padding_idx = padding_idx
  139. self.max_norm = max_norm
  140. self.norm_type = norm_type
  141. self.scale_grad_by_freq = scale_grad_by_freq
  142. if _weight is None:
  143. self.weight = Parameter(
  144. torch.empty((num_embeddings, embedding_dim), **factory_kwargs),
  145. requires_grad=not _freeze,
  146. )
  147. self.reset_parameters()
  148. else:
  149. if list(_weight.shape) != [num_embeddings, embedding_dim]:
  150. raise AssertionError(
  151. "Shape of weight does not match num_embeddings and embedding_dim"
  152. )
  153. self.weight = Parameter(_weight, requires_grad=not _freeze)
  154. self.sparse = sparse
  155. def reset_parameters(self) -> None:
  156. init.normal_(self.weight)
  157. self._fill_padding_idx_with_zero()
  158. def _fill_padding_idx_with_zero(self) -> None:
  159. if self.padding_idx is not None:
  160. with torch.no_grad():
  161. self.weight[self.padding_idx].fill_(0)
  162. def forward(self, input: Tensor) -> Tensor:
  163. return F.embedding(
  164. input,
  165. self.weight,
  166. self.padding_idx,
  167. self.max_norm,
  168. self.norm_type,
  169. self.scale_grad_by_freq,
  170. self.sparse,
  171. )
  172. def extra_repr(self) -> str:
  173. s = "{num_embeddings}, {embedding_dim}"
  174. if self.padding_idx is not None:
  175. s += ", padding_idx={padding_idx}"
  176. if self.max_norm is not None:
  177. s += ", max_norm={max_norm}"
  178. if self.norm_type != 2:
  179. s += ", norm_type={norm_type}"
  180. if self.scale_grad_by_freq is not False:
  181. s += ", scale_grad_by_freq={scale_grad_by_freq}"
  182. if self.sparse is not False:
  183. s += ", sparse=True"
  184. return s.format(**self.__dict__)
  185. @classmethod
  186. def from_pretrained(
  187. cls,
  188. embeddings,
  189. freeze=True,
  190. padding_idx=None,
  191. max_norm=None,
  192. norm_type=2.0,
  193. scale_grad_by_freq=False,
  194. sparse=False,
  195. ):
  196. r"""Create Embedding instance from given 2-dimensional FloatTensor.
  197. Args:
  198. embeddings (Tensor): FloatTensor containing weights for the Embedding.
  199. First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
  200. freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process.
  201. Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
  202. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
  203. therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
  204. i.e. it remains as a fixed "pad".
  205. max_norm (float, optional): See module initialization documentation.
  206. norm_type (float, optional): See module initialization documentation. Default ``2``.
  207. scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``.
  208. sparse (bool, optional): See module initialization documentation.
  209. Examples::
  210. >>> # FloatTensor containing pretrained weights
  211. >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
  212. >>> embedding = nn.Embedding.from_pretrained(weight)
  213. >>> # Get embeddings for index 1
  214. >>> input = torch.LongTensor([1])
  215. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  216. >>> embedding(input)
  217. tensor([[ 4.0000, 5.1000, 6.3000]])
  218. """
  219. if embeddings.dim() != 2:
  220. raise AssertionError("Embeddings parameter is expected to be 2-dimensional")
  221. rows, cols = embeddings.shape
  222. embedding = cls(
  223. num_embeddings=rows,
  224. embedding_dim=cols,
  225. _weight=embeddings,
  226. _freeze=freeze,
  227. padding_idx=padding_idx,
  228. max_norm=max_norm,
  229. norm_type=norm_type,
  230. scale_grad_by_freq=scale_grad_by_freq,
  231. sparse=sparse,
  232. )
  233. return embedding
  234. class EmbeddingBag(Module):
  235. r"""Compute sums or means of 'bags' of embeddings, without instantiating the intermediate embeddings.
  236. For bags of constant length, no :attr:`per_sample_weights`, no indices equal to :attr:`padding_idx`,
  237. and with 2D inputs, this class
  238. * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``,
  239. * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``,
  240. * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``.
  241. However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
  242. operations.
  243. EmbeddingBag also supports per-sample weights as an argument to the forward
  244. pass. This scales the output of the Embedding before performing a weighted
  245. reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
  246. only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
  247. :attr:`per_sample_weights`.
  248. Args:
  249. num_embeddings (int): size of the dictionary of embeddings
  250. embedding_dim (int): the size of each embedding vector
  251. max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
  252. is renormalized to have norm :attr:`max_norm`.
  253. norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
  254. scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of
  255. the words in the mini-batch. Default ``False``.
  256. Note: this option is not supported when ``mode="max"``.
  257. mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
  258. ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
  259. into consideration. ``"mean"`` computes the average of the values
  260. in the bag, ``"max"`` computes the max value over each bag.
  261. Default: ``"mean"``
  262. sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
  263. Notes for more details regarding sparse gradients. Note: this option is not
  264. supported when ``mode="max"``.
  265. include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1.
  266. The last element is the size of the input, or the ending index position
  267. of the last bag (sequence). This matches the CSR format. Ignored when
  268. input is 2D. Default ``False``.
  269. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the
  270. gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated
  271. during training, i.e. it remains as a fixed "pad". For a newly constructed
  272. EmbeddingBag, the embedding vector at :attr:`padding_idx` will default to all
  273. zeros, but can be updated to another value to be used as the padding vector.
  274. Note that the embedding vector at :attr:`padding_idx` is excluded from the
  275. reduction.
  276. Attributes:
  277. weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
  278. initialized from :math:`\mathcal{N}(0, 1)`.
  279. Examples::
  280. >>> # an EmbeddingBag module containing 10 tensors of size 3
  281. >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
  282. >>> # a batch of 2 samples of 4 indices each
  283. >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
  284. >>> offsets = torch.tensor([0, 4], dtype=torch.long)
  285. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  286. >>> embedding_sum(input, offsets)
  287. tensor([[-0.8861, -5.4350, -0.0523],
  288. [ 1.1306, -2.5798, -1.0044]])
  289. >>> # Example with padding_idx
  290. >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum', padding_idx=2)
  291. >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9], dtype=torch.long)
  292. >>> offsets = torch.tensor([0, 4], dtype=torch.long)
  293. >>> embedding_sum(input, offsets)
  294. tensor([[ 0.0000, 0.0000, 0.0000],
  295. [-0.7082, 3.2145, -2.6251]])
  296. >>> # An EmbeddingBag can be loaded from an Embedding like so
  297. >>> embedding = nn.Embedding(10, 3, padding_idx=2)
  298. >>> embedding_sum = nn.EmbeddingBag.from_pretrained(
  299. embedding.weight,
  300. padding_idx=embedding.padding_idx,
  301. mode='sum')
  302. """
  303. __constants__ = [
  304. "num_embeddings",
  305. "embedding_dim",
  306. "max_norm",
  307. "norm_type",
  308. "scale_grad_by_freq",
  309. "mode",
  310. "sparse",
  311. "include_last_offset",
  312. "padding_idx",
  313. ]
  314. num_embeddings: int
  315. embedding_dim: int
  316. max_norm: float | None
  317. norm_type: float
  318. scale_grad_by_freq: bool
  319. weight: Tensor
  320. mode: str
  321. sparse: bool
  322. include_last_offset: bool
  323. padding_idx: int | None
  324. def __init__(
  325. self,
  326. num_embeddings: int,
  327. embedding_dim: int,
  328. max_norm: float | None = None,
  329. norm_type: float = 2.0,
  330. scale_grad_by_freq: bool = False,
  331. mode: str = "mean",
  332. sparse: bool = False,
  333. _weight: Tensor | None = None,
  334. include_last_offset: bool = False,
  335. padding_idx: int | None = None,
  336. device=None,
  337. dtype=None,
  338. ) -> None:
  339. factory_kwargs = {"device": device, "dtype": dtype}
  340. super().__init__()
  341. self.num_embeddings = num_embeddings
  342. self.embedding_dim = embedding_dim
  343. self.max_norm = max_norm
  344. self.norm_type = norm_type
  345. self.scale_grad_by_freq = scale_grad_by_freq
  346. if padding_idx is not None:
  347. if padding_idx > 0:
  348. if padding_idx >= self.num_embeddings:
  349. raise AssertionError("padding_idx must be within num_embeddings")
  350. elif padding_idx < 0:
  351. if padding_idx < -self.num_embeddings:
  352. raise AssertionError("padding_idx must be within num_embeddings")
  353. padding_idx = self.num_embeddings + padding_idx
  354. self.padding_idx = padding_idx
  355. if _weight is None:
  356. self.weight = Parameter(
  357. torch.empty((num_embeddings, embedding_dim), **factory_kwargs)
  358. )
  359. self.reset_parameters()
  360. else:
  361. if list(_weight.shape) != [num_embeddings, embedding_dim]:
  362. raise AssertionError(
  363. "Shape of weight does not match num_embeddings and embedding_dim"
  364. )
  365. self.weight = Parameter(_weight)
  366. self.mode = mode
  367. self.sparse = sparse
  368. self.include_last_offset = include_last_offset
  369. def reset_parameters(self) -> None:
  370. init.normal_(self.weight)
  371. self._fill_padding_idx_with_zero()
  372. def _fill_padding_idx_with_zero(self) -> None:
  373. if self.padding_idx is not None:
  374. with torch.no_grad():
  375. self.weight[self.padding_idx].fill_(0)
  376. def forward(
  377. self,
  378. input: Tensor,
  379. offsets: Tensor | None = None,
  380. per_sample_weights: Tensor | None = None,
  381. ) -> Tensor:
  382. """Forward pass of EmbeddingBag.
  383. Args:
  384. input (Tensor): Tensor containing bags of indices into the embedding matrix.
  385. offsets (Tensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines
  386. the starting index position of each bag (sequence) in :attr:`input`.
  387. per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
  388. to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
  389. must have exactly the same shape as input and is treated as having the same
  390. :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.
  391. Returns:
  392. Tensor output shape of `(B, embedding_dim)`.
  393. .. note::
  394. A few notes about ``input`` and ``offsets``:
  395. - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long
  396. - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences)
  397. each of fixed length ``N``, and this will return ``B`` values aggregated in a way
  398. depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case.
  399. - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of
  400. multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing the
  401. starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` of shape `(B)`,
  402. :attr:`input` will be viewed as having ``B`` bags. Empty bags (i.e., having 0-length) will have
  403. returned vectors filled by zeros.
  404. """
  405. return F.embedding_bag(
  406. input,
  407. self.weight,
  408. offsets,
  409. self.max_norm,
  410. self.norm_type,
  411. self.scale_grad_by_freq,
  412. self.mode,
  413. self.sparse,
  414. per_sample_weights,
  415. self.include_last_offset,
  416. self.padding_idx,
  417. )
  418. def extra_repr(self) -> str:
  419. s = "{num_embeddings}, {embedding_dim}"
  420. if self.max_norm is not None:
  421. s += ", max_norm={max_norm}"
  422. if self.norm_type != 2:
  423. s += ", norm_type={norm_type}"
  424. if self.scale_grad_by_freq is not False:
  425. s += ", scale_grad_by_freq={scale_grad_by_freq}"
  426. s += ", mode={mode}"
  427. if self.padding_idx is not None:
  428. s += ", padding_idx={padding_idx}"
  429. return s.format(**{k: repr(v) for k, v in self.__dict__.items()})
  430. @classmethod
  431. def from_pretrained(
  432. cls,
  433. embeddings: Tensor,
  434. freeze: bool = True,
  435. max_norm: float | None = None,
  436. norm_type: float = 2.0,
  437. scale_grad_by_freq: bool = False,
  438. mode: str = "mean",
  439. sparse: bool = False,
  440. include_last_offset: bool = False,
  441. padding_idx: int | None = None,
  442. ) -> "EmbeddingBag":
  443. r"""Create EmbeddingBag instance from given 2-dimensional FloatTensor.
  444. Args:
  445. embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag.
  446. First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'.
  447. freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process.
  448. Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True``
  449. max_norm (float, optional): See module initialization documentation. Default: ``None``
  450. norm_type (float, optional): See module initialization documentation. Default ``2``.
  451. scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``.
  452. mode (str, optional): See module initialization documentation. Default: ``"mean"``
  453. sparse (bool, optional): See module initialization documentation. Default: ``False``.
  454. include_last_offset (bool, optional): See module initialization documentation. Default: ``False``.
  455. padding_idx (int, optional): See module initialization documentation. Default: ``None``.
  456. Examples::
  457. >>> # FloatTensor containing pretrained weights
  458. >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
  459. >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
  460. >>> # Get embeddings for index 1
  461. >>> input = torch.LongTensor([[1, 0]])
  462. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  463. >>> embeddingbag(input)
  464. tensor([[ 2.5000, 3.7000, 4.6500]])
  465. """
  466. if embeddings.dim() != 2:
  467. raise AssertionError("Embeddings parameter is expected to be 2-dimensional")
  468. rows, cols = embeddings.shape
  469. embeddingbag = cls(
  470. num_embeddings=rows,
  471. embedding_dim=cols,
  472. _weight=embeddings,
  473. max_norm=max_norm,
  474. norm_type=norm_type,
  475. scale_grad_by_freq=scale_grad_by_freq,
  476. mode=mode,
  477. sparse=sparse,
  478. include_last_offset=include_last_offset,
  479. padding_idx=padding_idx,
  480. )
  481. embeddingbag.weight.requires_grad = not freeze
  482. return embeddingbag