linear.py 743 B

12345678910111213141516171819
  1. """ Linear layer (alternate definition)
  2. """
  3. import torch
  4. import torch.nn.functional as F
  5. from torch import nn as nn
  6. class Linear(nn.Linear):
  7. r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
  8. Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
  9. weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
  10. """
  11. def forward(self, input: torch.Tensor) -> torch.Tensor:
  12. if torch.jit.is_scripting():
  13. bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
  14. return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
  15. else:
  16. return F.linear(input, self.weight, self.bias)