operators.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
"""This file provides a location for operators that help export models via ONNX.

E.g. `shape_as_tensor` and `reshape_from_tensor_shape`
make all dynamic-size operations traceable.

NOTE: at one point these functions were implemented differently.
Since then we have implemented these directly in ATen, so this
file is kept purely for backward-compatibility.
"""
  8. from __future__ import annotations
  9. __all__: list[str] = []
  10. import torch
  11. """Get the shape of a tensor as a tensor.
  12. Args:
  13. x (Tensor): The input tensor.
  14. Returns:
  15. Tensor: A tensor of shape [len(x.shape)] containing the size of each dimension of x.
  16. Example:
  17. >>> x = torch.randn(2, 3)
  18. >>> shape_as_tensor(x)
  19. tensor([2, 3])
  20. """
  21. shape_as_tensor = torch._shape_as_tensor
  22. """Reshape a tensor to the given shape.
  23. This function is used to make dynamic size operations traceable when exporting models via ONNX.
  24. This function is kept for backward-compatibility. It is implemented directly in ATen.
  25. Parameters:
  26. x (Tensor): the tensor to be reshaped.
  27. shape (Tensor): the target shape.
  28. Returns:
  29. Tensor: the reshaped tensor.
  30. """
  31. reshape_from_tensor_shape = torch._reshape_from_tensor