test_case.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. """Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.
  2. This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
  3. It includes:
  4. - A custom TestCase class that handles Dynamo-specific setup/teardown
  5. - Test running utilities with dependency checking
  6. - Automatic reset of Dynamo state between tests
  7. - Proper handling of gradient mode state
  8. """
  9. import contextlib
  10. import importlib
  11. import inspect
  12. import logging
  13. import os
  14. import re
  15. import sys
  16. import unittest
  17. from collections.abc import Callable
  18. from typing import Any, Union
  19. import torch
  20. import torch.testing
  21. from torch._dynamo import polyfills
  22. from torch._logging._internal import trace_log
  23. from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
  24. IS_WINDOWS,
  25. TEST_WITH_CROSSREF,
  26. TEST_WITH_TORCHDYNAMO,
  27. TestCase as TorchTestCase,
  28. )
  29. from . import config, reset, utils
# Module-level logger; TestCase.tearDown uses it to warn when a test leaves
# grad mode changed.
log = logging.getLogger(__name__)
  31. def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
  32. from torch.testing._internal.common_utils import run_tests
  33. if TEST_WITH_TORCHDYNAMO or TEST_WITH_CROSSREF:
  34. return # skip testing
  35. if (
  36. not torch.xpu.is_available()
  37. and IS_WINDOWS
  38. and os.environ.get("TORCHINDUCTOR_WINDOWS_TESTS", "0") == "0"
  39. ):
  40. return
  41. if isinstance(needs, str):
  42. needs = (needs,)
  43. for need in needs:
  44. if need == "cuda":
  45. if not torch.cuda.is_available():
  46. return
  47. else:
  48. try:
  49. importlib.import_module(need)
  50. except ImportError:
  51. return
  52. run_tests()
  53. class TestCase(TorchTestCase):
  54. _exit_stack: contextlib.ExitStack
  55. @classmethod
  56. def tearDownClass(cls) -> None:
  57. cls._exit_stack.close()
  58. super().tearDownClass()
  59. @classmethod
  60. def setUpClass(cls) -> None:
  61. super().setUpClass()
  62. cls._exit_stack = contextlib.ExitStack() # type: ignore[attr-defined]
  63. cls._exit_stack.enter_context( # type: ignore[attr-defined]
  64. config.patch(
  65. raise_on_ctx_manager_usage=True,
  66. suppress_errors=False,
  67. log_compilation_metrics=False,
  68. ),
  69. )
  70. def setUp(self) -> None:
  71. self._prior_is_grad_enabled = torch.is_grad_enabled()
  72. super().setUp()
  73. reset()
  74. utils.counters.clear()
  75. self.handler = logging.NullHandler()
  76. trace_log.addHandler(self.handler)
  77. def tearDown(self) -> None:
  78. trace_log.removeHandler(self.handler)
  79. for k, v in utils.counters.items():
  80. print(k, v.most_common())
  81. reset()
  82. utils.counters.clear()
  83. torch._C._autograd._saved_tensors_hooks_enable()
  84. super().tearDown()
  85. if self._prior_is_grad_enabled is not torch.is_grad_enabled():
  86. log.warning("Running test changed grad mode")
  87. torch.set_grad_enabled(self._prior_is_grad_enabled)
  88. def assertEqual(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
  89. if (
  90. config.debug_disable_compile_counter
  91. and isinstance(x, utils.CompileCounterInt)
  92. or isinstance(y, utils.CompileCounterInt)
  93. ):
  94. return
  95. return super().assertEqual(x, y, *args, **kwargs)
  96. # assertExpectedInline might also need to be disabled for wrapped nested
  97. # graph break tests
  98. # NB: multiple inheritance with LoggingTestCase is possible - this should be fine
  99. # since there is no overlap in overridden methods.
  100. class TestCaseWithNestedGraphBreaks(TestCase):
  101. def setUp(self) -> None:
  102. super().setUp()
  103. self.prev_nested_graph_breaks = torch._dynamo.config.nested_graph_breaks
  104. # pyrefly: ignore [bad-assignment]
  105. torch._dynamo.config.nested_graph_breaks = True
  106. def tearDown(self) -> None:
  107. super().tearDown()
  108. # pyrefly: ignore [bad-assignment]
  109. torch._dynamo.config.nested_graph_breaks = self.prev_nested_graph_breaks
  110. class CPythonTestCase(TestCase):
  111. """
  112. Test class for CPython tests located in "test/dynamo/CPython/Py_version/*".
  113. This class enables specific features that are disabled by default, such as
  114. tracing through unittest methods.
  115. """
  116. _stack: contextlib.ExitStack
  117. dynamo_strict_nopython = True
  118. # Restore original unittest methods to simplify tracing CPython test cases.
  119. assertEqual = unittest.TestCase.assertEqual # type: ignore[assignment]
  120. assertNotEqual = unittest.TestCase.assertNotEqual # type: ignore[assignment]
  121. assertTrue = unittest.TestCase.assertTrue
  122. assertFalse = unittest.TestCase.assertFalse
  123. assertIs = unittest.TestCase.assertIs
  124. assertIsNot = unittest.TestCase.assertIsNot
  125. assertIsNone = unittest.TestCase.assertIsNone
  126. assertIsNotNone = unittest.TestCase.assertIsNotNone
  127. assertIn = unittest.TestCase.assertIn
  128. assertNotIn = unittest.TestCase.assertNotIn
  129. assertIsInstance = unittest.TestCase.assertIsInstance
  130. assertNotIsInstance = unittest.TestCase.assertNotIsInstance
  131. assertAlmostEqual = unittest.TestCase.assertAlmostEqual
  132. assertNotAlmostEqual = unittest.TestCase.assertNotAlmostEqual
  133. assertGreater = unittest.TestCase.assertGreater
  134. assertGreaterEqual = unittest.TestCase.assertGreaterEqual
  135. assertLess = unittest.TestCase.assertLess
  136. assertLessEqual = unittest.TestCase.assertLessEqual
  137. assertRegex = unittest.TestCase.assertRegex
  138. assertNotRegex = unittest.TestCase.assertNotRegex
  139. assertCountEqual = unittest.TestCase.assertCountEqual
  140. assertMultiLineEqual = polyfills.assert_multi_line_equal
  141. assertSequenceEqual = polyfills.assert_sequence_equal
  142. assertListEqual = unittest.TestCase.assertListEqual
  143. assertTupleEqual = unittest.TestCase.assertTupleEqual
  144. assertSetEqual = unittest.TestCase.assertSetEqual
  145. # pyrefly: ignore [bad-override]
  146. assertDictEqual = polyfills.assert_dict_equal
  147. # pyrefly: ignore [bad-override]
  148. assertRaises = unittest.TestCase.assertRaises
  149. # pyrefly: ignore [bad-override]
  150. assertRaisesRegex = unittest.TestCase.assertRaisesRegex
  151. assertWarns = unittest.TestCase.assertWarns
  152. assertWarnsRegex = unittest.TestCase.assertWarnsRegex
  153. assertLogs = unittest.TestCase.assertLogs
  154. fail = unittest.TestCase.fail
  155. failureException = unittest.TestCase.failureException
  156. def compile_fn(
  157. self,
  158. fn: Callable[..., Any],
  159. backend: Union[str, Callable[..., Any]],
  160. nopython: bool,
  161. ) -> Callable[..., Any]:
  162. # We want to compile only the test function, excluding any setup code
  163. # from unittest
  164. method = getattr(self, self._testMethodName)
  165. method = torch._dynamo.optimize(backend, error_on_graph_break=nopython)(method)
  166. setattr(self, self._testMethodName, method)
  167. return fn
  168. def _dynamo_test_key(self) -> str:
  169. suffix = super()._dynamo_test_key()
  170. test_cls = self.__class__
  171. test_file = inspect.getfile(test_cls).split(os.sep)[-1].split(".")[0]
  172. py_ver = re.search(r"/([\d_]+)/", inspect.getfile(test_cls))
  173. if py_ver:
  174. py_ver = py_ver.group().strip(os.sep).replace("_", "") # type: ignore[assignment]
  175. else:
  176. return suffix
  177. return f"CPython{py_ver}-{test_file}-{suffix}"
  178. @classmethod
  179. def tearDownClass(cls) -> None:
  180. cls._stack.close()
  181. super().tearDownClass()
  182. @classmethod
  183. def setUpClass(cls) -> None:
  184. # Skip test if python versions doesn't match
  185. prefix = os.path.join("dynamo", "cpython") + os.path.sep
  186. regex = re.escape(prefix) + r"\d_\d{2}"
  187. search_path = inspect.getfile(cls)
  188. m = re.search(regex, search_path)
  189. if m:
  190. test_py_ver = tuple(map(int, m.group().removeprefix(prefix).split("_")))
  191. py_ver = sys.version_info[:2]
  192. if py_ver != test_py_ver:
  193. expected = ".".join(map(str, test_py_ver))
  194. got = ".".join(map(str, py_ver))
  195. raise unittest.SkipTest(
  196. f"Test requires Python {expected} but got Python {got}"
  197. )
  198. else:
  199. raise unittest.SkipTest(
  200. f"Test requires a specific Python version but not found in path {inspect.getfile(cls)}"
  201. )
  202. super().setUpClass()
  203. cls._stack = contextlib.ExitStack() # type: ignore[attr-defined]
  204. cls._stack.enter_context( # type: ignore[attr-defined]
  205. config.patch(
  206. enable_trace_unittest=True,
  207. ),
  208. )
  209. # pyrefly: ignore [implicit-any]
  210. def wrap_with_policy(self, method_name: str, policy: Callable) -> None:
  211. pass