subclasses.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # mypy: ignore-errors
  2. from typing import Any, Optional
  3. import torch
  4. import torch.utils._pytree as pytree
  5. from torch._subclasses.fake_tensor import is_fake
  6. from torch.testing._internal.two_tensor import TwoTensor
  7. from torch.utils._python_dispatch import return_and_correct_aliasing
  8. class WrapperSubclass(torch.Tensor):
  9. @staticmethod
  10. def __new__(cls, a, outer_size=None, outer_stride=None):
  11. if outer_size is None:
  12. outer_size = a.size()
  13. if outer_stride is None:
  14. outer_stride = a.stride()
  15. kwargs = {}
  16. kwargs["strides"] = outer_stride
  17. kwargs["storage_offset"] = a.storage_offset()
  18. kwargs["device"] = a.device
  19. kwargs["layout"] = a.layout
  20. kwargs["requires_grad"] = a.requires_grad
  21. kwargs["dtype"] = a.dtype
  22. out = torch.Tensor._make_wrapper_subclass(cls, outer_size, **kwargs)
  23. return out
  24. def __init__(self, a, outer_size=None, outer_stride=None):
  25. self.a = a
  26. def __repr__(self):
  27. return f"WrapperSubclass({repr(self.a)})"
  28. def __tensor_flatten__(self):
  29. return ["a"], None
  30. @staticmethod
  31. def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
  32. if meta is not None:
  33. raise AssertionError("Expected meta to be None")
  34. a = inner_tensors["a"]
  35. if is_fake(a):
  36. if outer_size is None:
  37. raise AssertionError("Expected outer_size to not be None")
  38. if outer_stride is None:
  39. raise AssertionError("Expected outer_stride to not be None")
  40. return WrapperSubclass(a, outer_size, outer_stride)
  41. @classmethod
  42. def __torch_dispatch__(cls, func, types, args, kwargs):
  43. if kwargs is None:
  44. kwargs = {}
  45. args_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, args)
  46. kwargs_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, kwargs)
  47. out_a = func(*args_a, **kwargs_a)
  48. out_a_flat, spec = pytree.tree_flatten(out_a)
  49. out_flat = [
  50. WrapperSubclass(o_a) if isinstance(o_a, torch.Tensor) else o_a
  51. for o_a in out_a_flat
  52. ]
  53. out = pytree.tree_unflatten(out_flat, spec)
  54. from torch._higher_order_ops.cond import cond_op
  55. if func is cond_op:
  56. return out
  57. else:
  58. return return_and_correct_aliasing(func, args, kwargs, out)
  59. def __coerce_same_metadata_as_tangent__(
  60. self, expected_metadata: Any, expected_type: Optional[type] = None
  61. ):
  62. if expected_type is type(self.a):
  63. return self.a
  64. elif expected_type is TwoTensor:
  65. return TwoTensor(self.a, self.a.clone())
  66. return None