# debug_prims.py
import contextlib
from collections.abc import Generator, Sequence

import torch
from torch.utils._content_store import ContentStoreReader
  5. LOAD_TENSOR_READER: ContentStoreReader | None = None
  6. @contextlib.contextmanager
  7. def load_tensor_reader(loc: str) -> Generator[None, None, None]:
  8. global LOAD_TENSOR_READER
  9. if LOAD_TENSOR_READER is not None:
  10. raise AssertionError("LOAD_TENSOR_READER is already set")
  11. # load_tensor is an "op", and we will play merry hell on
  12. # Inductor's memory planning if we return a tensor that
  13. # aliases another tensor that we previously returned from
  14. # an operator. So unlike standard ContentStoreReader use,
  15. # we disable the cache so that you always get fresh storages
  16. # (no aliasing for you!)
  17. LOAD_TENSOR_READER = ContentStoreReader(loc, cache=False)
  18. try:
  19. yield
  20. finally:
  21. LOAD_TENSOR_READER = None
  22. def register_debug_prims() -> None:
  23. torch.library.define(
  24. "debugprims::load_tensor",
  25. "(str name, int[] size, int[] stride, *, ScalarType dtype, Device device) -> Tensor",
  26. )
  27. @torch.library.impl("debugprims::load_tensor", "BackendSelect")
  28. def load_tensor_factory(
  29. name: str,
  30. size: Sequence[int],
  31. stride: Sequence[int],
  32. dtype: torch.dtype,
  33. device: torch.device,
  34. ) -> torch.Tensor:
  35. if LOAD_TENSOR_READER is None:
  36. from torch._dynamo.testing import rand_strided
  37. return rand_strided(size, stride, dtype, device)
  38. else:
  39. from torch._dynamo.utils import clone_input
  40. # device argument here takes care of coercion
  41. r = LOAD_TENSOR_READER.read_tensor(name, device=device)
  42. if list(r.size()) != size:
  43. raise AssertionError(f"{r.size()} != {size}")
  44. if list(r.stride()) != stride:
  45. raise AssertionError(f"{r.stride()} != {stride}")
  46. if r.device != device:
  47. raise AssertionError(f"{r.device} != {device}")
  48. # Unlike the other properties, we will do coercions for dtype
  49. # mismatch
  50. if r.dtype != dtype:
  51. r = clone_input(r, dtype=dtype) # type: ignore[no-untyped-call]
  52. return r