_mode_utils.py 255 B

123456789101112131415
  1. # mypy: allow-untyped-defs
  2. from typing import TypeVar
  3. import torch
  4. T = TypeVar("T")  # generic type variable; not referenced in the visible part of this file
  5. # Returns True if every mode in `modes` compares equal to the first one.
  6. def all_same_mode(modes):
  7. return all(tuple(mode == modes[0] for mode in modes))
  8. no_dispatch = torch._C._DisableTorchDispatch  # alias for the private torch._C binding; presumably a guard that disables torch dispatch while active — confirm