creation.py 605 B

123456789101112131415161718192021222324
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. from .core import MaskedTensor
  3. __all__ = [
  4. "as_masked_tensor",
  5. "masked_tensor",
  6. ]
# These two factory functions are intended to mirror their torch counterparts:
#   - masked_tensor mirrors torch.tensor: the result is guaranteed to be a leaf node.
#   - as_masked_tensor mirrors torch.as_tensor: a differentiable constructor that
#     preserves the autograd history.
  10. def masked_tensor(
  11. data: object, mask: object, requires_grad: bool = False
  12. ) -> MaskedTensor:
  13. return MaskedTensor(data, mask, requires_grad)
  14. def as_masked_tensor(data: object, mask: object) -> MaskedTensor:
  15. return MaskedTensor._from_values(data, mask)