| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246 |
- """Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.
- This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
- It includes:
- - A custom TestCase class that handles Dynamo-specific setup/teardown
- - Test running utilities with dependency checking
- - Automatic reset of Dynamo state between tests
- - Proper handling of gradient mode state
- """
- import contextlib
- import importlib
- import inspect
- import logging
- import os
- import re
- import sys
- import unittest
- from collections.abc import Callable
- from typing import Any, Union
- import torch
- import torch.testing
- from torch._dynamo import polyfills
- from torch._logging._internal import trace_log
- from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
- IS_WINDOWS,
- TEST_WITH_CROSSREF,
- TEST_WITH_TORCHDYNAMO,
- TestCase as TorchTestCase,
- )
- from . import config, reset, utils
# Module-level logger, looked up by this module's dotted name.
log = logging.getLogger(name=__name__)
def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
    """Run this module's tests with torch's common ``run_tests``, unless skipped.

    Returns without running anything when:

    * ``TEST_WITH_TORCHDYNAMO`` or ``TEST_WITH_CROSSREF`` is set;
    * running on Windows with no XPU available and the environment variable
      ``TORCHINDUCTOR_WINDOWS_TESTS`` unset or equal to ``"0"``;
    * any requirement in ``needs`` is unmet: ``"cuda"`` requires
      ``torch.cuda.is_available()``, and every other entry must be an
      importable module name.

    Args:
        needs: A single requirement name or a tuple of requirement names.
    """
    from torch.testing._internal.common_utils import run_tests as _run_all_tests

    if TEST_WITH_TORCHDYNAMO or TEST_WITH_CROSSREF:
        return  # skip testing

    # Windows runs are opt-in via the env var unless an XPU device is present.
    if not torch.xpu.is_available() and IS_WINDOWS:
        if os.environ.get("TORCHINDUCTOR_WINDOWS_TESTS", "0") == "0":
            return

    requirements = (needs,) if isinstance(needs, str) else needs

    def _satisfied(requirement: str) -> bool:
        if requirement == "cuda":
            return torch.cuda.is_available()
        try:
            importlib.import_module(requirement)
        except ImportError:
            return False
        return True

    # all() short-circuits, so checking stops at the first unmet requirement.
    if all(_satisfied(req) for req in requirements):
        _run_all_tests()
class TestCase(TorchTestCase):
    """Base test case for Dynamo tests.

    For the whole class, patches the Dynamo config with
    ``raise_on_ctx_manager_usage=True``, ``suppress_errors=False`` and
    ``log_compilation_metrics=False``.

    Around every test it:

    * resets Dynamo state and clears ``utils.counters``;
    * attaches a ``NullHandler`` to ``trace_log``;
    * re-enables saved-tensor hooks afterwards;
    * restores grad mode if the test changed it.
    """

    _exit_stack: contextlib.ExitStack

    @classmethod
    def tearDownClass(cls) -> None:
        # Undo the class-wide config patch entered in setUpClass.
        cls._exit_stack.close()
        super().tearDownClass()

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        cls._exit_stack = contextlib.ExitStack()  # type: ignore[attr-defined]
        cls._exit_stack.enter_context(  # type: ignore[attr-defined]
            config.patch(
                raise_on_ctx_manager_usage=True,
                suppress_errors=False,
                log_compilation_metrics=False,
            ),
        )

    def setUp(self) -> None:
        # Record grad mode before anything runs so tearDown can detect changes.
        self._prior_is_grad_enabled = torch.is_grad_enabled()
        super().setUp()
        reset()
        utils.counters.clear()
        self.handler = logging.NullHandler()
        trace_log.addHandler(self.handler)

    def tearDown(self) -> None:
        trace_log.removeHandler(self.handler)
        # Print the counters collected during the test before clearing them.
        for k, v in utils.counters.items():
            print(k, v.most_common())
        reset()
        utils.counters.clear()
        torch._C._autograd._saved_tensors_hooks_enable()
        super().tearDown()
        if self._prior_is_grad_enabled != torch.is_grad_enabled():
            log.warning("Running test changed grad mode")
            torch.set_grad_enabled(self._prior_is_grad_enabled)

    def assertEqual(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        """Compare ``x`` and ``y`` with torch's ``assertEqual``.

        When ``config.debug_disable_compile_counter`` is set, a comparison in
        which either side is a ``utils.CompileCounterInt`` is skipped entirely.
        """
        # The isinstance checks must be grouped: without parentheses,
        # ``A and B or C`` parses as ``(A and B) or C``, which would skip any
        # comparison whose *second* argument is a CompileCounterInt even when
        # the debug flag is off.
        if config.debug_disable_compile_counter and (
            isinstance(x, utils.CompileCounterInt)
            or isinstance(y, utils.CompileCounterInt)
        ):
            return
        return super().assertEqual(x, y, *args, **kwargs)
- # assertExpectedInline might also need to be disabled for wrapped nested
- # graph break tests
- # NB: multiple inheritance with LoggingTestCase is possible - this should be fine
- # since there is no overlap in overridden methods.
class TestCaseWithNestedGraphBreaks(TestCase):
    """Dynamo ``TestCase`` that turns on ``config.nested_graph_breaks`` per test.

    The flag's previous value is saved in setUp and restored in tearDown.
    """

    def setUp(self) -> None:
        super().setUp()
        self.prev_nested_graph_breaks = torch._dynamo.config.nested_graph_breaks
        # pyrefly: ignore [bad-assignment]
        torch._dynamo.config.nested_graph_breaks = True

    def tearDown(self) -> None:
        try:
            super().tearDown()
        finally:
            # Restore the flag even if the base tearDown raises, so the
            # override cannot leak into tests that run afterwards.
            # pyrefly: ignore [bad-assignment]
            torch._dynamo.config.nested_graph_breaks = self.prev_nested_graph_breaks
class CPythonTestCase(TestCase):
    """
    Test class for CPython tests located in "test/dynamo/cpython/<major>_<minor>/*"
    (e.g. "test/dynamo/cpython/3_13/").

    This class enables specific features that are disabled by default, such as
    tracing through unittest methods. setUpClass skips the class unless the
    running interpreter's version matches the version directory in the test
    file's path.
    """

    _stack: contextlib.ExitStack
    dynamo_strict_nopython = True

    # Restore original unittest methods to simplify tracing CPython test cases.
    assertEqual = unittest.TestCase.assertEqual  # type: ignore[assignment]
    assertNotEqual = unittest.TestCase.assertNotEqual  # type: ignore[assignment]
    assertTrue = unittest.TestCase.assertTrue
    assertFalse = unittest.TestCase.assertFalse
    assertIs = unittest.TestCase.assertIs
    assertIsNot = unittest.TestCase.assertIsNot
    assertIsNone = unittest.TestCase.assertIsNone
    assertIsNotNone = unittest.TestCase.assertIsNotNone
    assertIn = unittest.TestCase.assertIn
    assertNotIn = unittest.TestCase.assertNotIn
    assertIsInstance = unittest.TestCase.assertIsInstance
    assertNotIsInstance = unittest.TestCase.assertNotIsInstance
    assertAlmostEqual = unittest.TestCase.assertAlmostEqual
    assertNotAlmostEqual = unittest.TestCase.assertNotAlmostEqual
    assertGreater = unittest.TestCase.assertGreater
    assertGreaterEqual = unittest.TestCase.assertGreaterEqual
    assertLess = unittest.TestCase.assertLess
    assertLessEqual = unittest.TestCase.assertLessEqual
    assertRegex = unittest.TestCase.assertRegex
    assertNotRegex = unittest.TestCase.assertNotRegex
    assertCountEqual = unittest.TestCase.assertCountEqual
    assertMultiLineEqual = polyfills.assert_multi_line_equal
    assertSequenceEqual = polyfills.assert_sequence_equal
    assertListEqual = unittest.TestCase.assertListEqual
    assertTupleEqual = unittest.TestCase.assertTupleEqual
    assertSetEqual = unittest.TestCase.assertSetEqual
    # pyrefly: ignore [bad-override]
    assertDictEqual = polyfills.assert_dict_equal
    # pyrefly: ignore [bad-override]
    assertRaises = unittest.TestCase.assertRaises
    # pyrefly: ignore [bad-override]
    assertRaisesRegex = unittest.TestCase.assertRaisesRegex
    assertWarns = unittest.TestCase.assertWarns
    assertWarnsRegex = unittest.TestCase.assertWarnsRegex
    assertLogs = unittest.TestCase.assertLogs
    fail = unittest.TestCase.fail
    failureException = unittest.TestCase.failureException

    def compile_fn(
        self,
        fn: Callable[..., Any],
        backend: Union[str, Callable[..., Any]],
        nopython: bool,
    ) -> Callable[..., Any]:
        """Compile the current test method instead of ``fn``.

        The bound test method named by ``self._testMethodName`` is wrapped
        with ``torch._dynamo.optimize(backend, error_on_graph_break=nopython)``
        and reassigned on the instance. ``fn`` itself is returned unchanged.
        """
        # We want to compile only the test function, excluding any setup code
        # from unittest
        method = getattr(self, self._testMethodName)
        method = torch._dynamo.optimize(backend, error_on_graph_break=nopython)(method)
        setattr(self, self._testMethodName, method)
        return fn

    def _dynamo_test_key(self) -> str:
        """Prefix the base test key with the CPython version and test file.

        The result has the form ``CPython<ver>-<file>-<base key>``, where
        ``<ver>`` is the first ``<digits/underscores>`` directory in the test
        file's path with the underscores removed (e.g. "3_13" -> "313").
        If no such directory exists, the base key is returned unchanged.
        """
        suffix = super()._dynamo_test_key()
        test_path = inspect.getfile(self.__class__)
        test_file = os.path.basename(test_path).split(".")[0]
        # Accept either separator: the regex previously hard-coded "/" while
        # stripping os.sep, so it never matched on Windows.
        match = re.search(r"[/\\]([\d_]+)[/\\]", test_path)
        if match is None:
            return suffix
        py_ver = match.group(1).replace("_", "")
        return f"CPython{py_ver}-{test_file}-{suffix}"

    @classmethod
    def tearDownClass(cls) -> None:
        # Undo the enable_trace_unittest patch entered in setUpClass.
        cls._stack.close()
        super().tearDownClass()

    @classmethod
    def setUpClass(cls) -> None:
        # Skip the test class if the Python version doesn't match the
        # "dynamo/cpython/<d>_<dd>" directory the test file lives in.
        prefix = os.path.join("dynamo", "cpython") + os.path.sep
        regex = re.escape(prefix) + r"\d_\d{2}"
        search_path = inspect.getfile(cls)
        m = re.search(regex, search_path)
        if m:
            test_py_ver = tuple(map(int, m.group().removeprefix(prefix).split("_")))
            py_ver = sys.version_info[:2]
            if py_ver != test_py_ver:
                expected = ".".join(map(str, test_py_ver))
                got = ".".join(map(str, py_ver))
                raise unittest.SkipTest(
                    f"Test requires Python {expected} but got Python {got}"
                )
        else:
            raise unittest.SkipTest(
                f"Test requires a specific Python version but not found in path {inspect.getfile(cls)}"
            )
        super().setUpClass()
        cls._stack = contextlib.ExitStack()  # type: ignore[attr-defined]
        cls._stack.enter_context(  # type: ignore[attr-defined]
            config.patch(
                enable_trace_unittest=True,
            ),
        )

    # pyrefly: ignore [implicit-any]
    def wrap_with_policy(self, method_name: str, policy: Callable) -> None:
        """No-op: this class applies no wrapping policy to test methods."""
        pass
|