testing.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. """Testing utilities and infrastructure for Dynamo.
  2. This module provides a comprehensive set of testing utilities including:
  3. - Test result collection and validation
  4. - Graph manipulation and comparison tools
  5. - Test case management and execution helpers
  6. - Specialized test decorators for different Python versions and features
  7. - RNG state management
  8. - Compilation counting and monitoring
  9. - Debug utilities for bytecode transformation
  10. The utilities in this module are used across Dynamo's test suite to ensure
  11. consistent testing patterns and proper test isolation.
  12. """
  13. import contextlib
  14. import dis
  15. import functools
  16. import logging
  17. import os.path
  18. import random
  19. import re
  20. import sys
  21. import types
  22. import unittest
  23. from collections.abc import Callable, Generator, Sequence
  24. from typing import Any, Optional, overload, TypeVar, Union
  25. from typing_extensions import ParamSpec
  26. from unittest.mock import patch
  27. import torch
  28. from torch import fx
  29. from torch._dynamo.backends.debugging import aot_eager
  30. from torch._dynamo.output_graph import OutputGraph
  31. from . import config, eval_frame, optimize_assert, reset
  32. from .bytecode_transformation import (
  33. create_instruction,
  34. debug_checks,
  35. is_generator,
  36. transform_code_object,
  37. )
  38. from .guards import CheckFunctionManager, CompileId, GuardedCode
  39. from .types import ConvertFrameReturn, DynamoFrameType, wrap_guarded_code
  40. from .utils import CompileCounterInt, same
# numpy is optional: ``np`` is left as None when numpy is not installed, so
# code in this module must check it before use (see reset_rng_state).
np: Optional[types.ModuleType] = None
try:
    import numpy as np
except ModuleNotFoundError:
    np = None
# Module-level alias of ``eval_frame.unsupported``; presumably re-exported so
# tests can import it from here -- verify against callers.
unsupported = eval_frame.unsupported
# NOTE(review): not referenced in this file's visible code; presumably read as a
# global by tests -- confirm before removing.
three = 3
log = logging.getLogger(__name__)
# ParamSpec used (together with _T) to type the signature-preserving
# decorators and wrappers in this module.
_P = ParamSpec("_P")
  50. def clone_me(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
  51. if x is None:
  52. return None
  53. return x.detach().clone().requires_grad_(x.requires_grad)
  54. def remove_optimized_module_prefix(name: str) -> str:
  55. return re.sub(r"^_orig_mod[.]", "", name)
  56. def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-def]
  57. from torch._dynamo.symbolic_convert import InstructionTranslator
  58. gm = None
  59. region_tracker = None
  60. def extract_graph_backend(_gm, *args, **kwargs): # type: ignore[no-untyped-def]
  61. nonlocal gm
  62. nonlocal region_tracker
  63. gm = _gm
  64. region_tracker = InstructionTranslator.current_tx().output.region_tracker
  65. return _gm
  66. torch.compile(backend=extract_graph_backend, fullgraph=True)(fn)(*args, **kwargs)
  67. return gm.graph, region_tracker # type: ignore[union-attr]
  68. def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def]
  69. backend = AotEagerAndRecordGraphs()
  70. result = torch.compile(backend=backend)(fn)(*args, **kwargs)
  71. return result, backend.graphs, backend.fw_graphs, backend.bw_graphs
  72. def collect_results(
  73. model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
  74. ) -> list[Any]:
  75. results = []
  76. results.append(prediction)
  77. results.append(loss)
  78. # if isinstance(loss, torch.Tensor) and loss.item() > 1:
  79. # log.warning(
  80. # f"High loss value alert - {loss:.2f}. Can result in unstable gradients."
  81. # )
  82. grads = {}
  83. params = {}
  84. for name, param in model.named_parameters():
  85. if isinstance(model, eval_frame.OptimizedModule):
  86. name = remove_optimized_module_prefix(name)
  87. param_copy = param
  88. grad = param.grad
  89. # Treat None and zero grad as same
  90. if param.grad is None:
  91. grad = torch.zeros_like(param)
  92. grads[name + ".grad"] = grad
  93. params[name] = param_copy
  94. results.append(grads)
  95. results.append(params)
  96. buffers = {}
  97. for name, buffer in model.named_buffers():
  98. if isinstance(model, eval_frame.OptimizedModule):
  99. name = remove_optimized_module_prefix(name)
  100. buffers[name] = buffer
  101. results.append(buffers)
  102. for example in example_inputs:
  103. if isinstance(example, (tuple, list)):
  104. results.extend(inp.grad for inp in example if isinstance(inp, torch.Tensor))
  105. else:
  106. if isinstance(example, torch.Tensor):
  107. results.append(example.grad)
  108. return results
  109. def requires_bwd_pass(out: Any) -> bool:
  110. if isinstance(out, torch.Tensor):
  111. return out.requires_grad
  112. elif isinstance(out, (list, tuple)):
  113. return any(requires_bwd_pass(x) for x in out)
  114. elif out is None:
  115. return False
  116. elif isinstance(out, int):
  117. return False
  118. raise NotImplementedError("Don't know how to reduce", type(out))
  119. @overload
  120. def reduce_to_scalar_loss(out: torch.Tensor) -> torch.Tensor: ...
  121. @overload
  122. def reduce_to_scalar_loss(
  123. out: Union[list[Any], tuple[Any, ...], dict[Any, Any]],
  124. ) -> float: ...
  125. def reduce_to_scalar_loss(out: Any) -> Union[torch.Tensor, float]:
  126. """Reduce the output of a model to get scalar loss"""
  127. if isinstance(out, torch.Tensor):
  128. # Mean does not work on integer tensors
  129. return out.sum() / out.numel()
  130. elif isinstance(out, (list, tuple)):
  131. return sum(reduce_to_scalar_loss(x) for x in out) / len(out)
  132. elif type(out).__name__ in (
  133. "MaskedLMOutput",
  134. "Seq2SeqLMOutput",
  135. "CausalLMOutputWithCrossAttentions",
  136. ):
  137. return reduce_to_scalar_loss(out.logits)
  138. elif type(out).__name__ == "SquashedNormal":
  139. return out.mean.sum()
  140. elif isinstance(out, dict):
  141. return sum(reduce_to_scalar_loss(value) for value in out.values()) / len(
  142. out.keys()
  143. )
  144. raise NotImplementedError("Don't know how to reduce", type(out))
  145. def debug_dir() -> str:
  146. path = os.path.join(os.path.dirname(__file__), "../debug")
  147. if not os.path.exists(path):
  148. os.mkdir(path)
  149. return path
  150. def debug_dump(name: str, code: types.CodeType, extra: str = "") -> None:
  151. with open(os.path.join(debug_dir(), name), "w") as fd:
  152. fd.write(
  153. f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n"
  154. )
def debug_insert_nops(
    frame: DynamoFrameType, cache_size: int, hooks: Any, _: Any, *, skip: int = 0
) -> ConvertFrameReturn:
    """Frame-conversion callback used to debug jump updates.

    Rewrites ``frame.f_code`` with two ``NOP`` instructions inserted at the very
    start, which shifts every later instruction. Presumably this is meant to
    check that jump targets are re-computed correctly by
    ``transform_code_object`` -- confirm. No FX graph is captured. Guards come
    from a minimal ``OutputGraph`` built with no compiler and empty code options.

    Args:
        frame: the frame being converted; only ``frame.f_code`` is read.
        cache_size, hooks, _, skip: unused in this body. Presumably present so
            the signature matches Dynamo's frame-callback interface -- verify.

    Returns:
        ``ConvertFrameReturn()`` for generator code objects; otherwise the
        rewritten code wrapped by ``wrap_guarded_code``.
    """

    def insert_nops(instructions: list[Any], code_options: Any) -> None:
        # Prepend two NOPs; code_options is left untouched.
        instructions.insert(0, create_instruction("NOP"))
        instructions.insert(0, create_instruction("NOP"))

    metrics_context = torch._dynamo.utils.get_metrics_context()
    with torch._dynamo.utils.dynamo_timed("debug_insert_nops"), metrics_context:
        # Generators are not rewritten; an empty result is returned instead.
        if is_generator(frame.f_code):
            return ConvertFrameReturn()

        debug_checks(frame.f_code)
        # Note: this rebinds the unused ``_`` parameter; the second value
        # returned by transform_code_object is discarded.
        code, _ = transform_code_object(frame.f_code, insert_nops)
        graph = OutputGraph(
            code_options={},
            compiler_fn=None,
            root_tx=None,  # type: ignore[arg-type]
            export=False,
            export_constraints=[],
            frame_state={"_id": 0},
            # TODO: shouldn't this be f_locals/f_globals from frame?
            local_scope=locals(),
            global_scope=globals(),
            f_code=frame.f_code,
            torch_function_mode_stack=[],
            package=None,
        )

        # Fixed compile id: this debug path does not track real compile ids.
        return wrap_guarded_code(
            GuardedCode(
                code,
                CheckFunctionManager(frame.f_code, graph).guard_manager,  # type: ignore[arg-type]
                CompileId(frame_id=0, frame_compile_id=0),
            )
        )
  189. class CompileCounter:
  190. def __init__(self) -> None:
  191. self.frame_count: Union[int, CompileCounterInt] = 0
  192. self.clear()
  193. def __call__(
  194. self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  195. ) -> Callable[..., Any]:
  196. self.frame_count += 1
  197. for node in gm.graph.nodes:
  198. if "call" in node.op:
  199. self.op_count += 1
  200. return gm.forward
  201. def clear(self) -> None:
  202. if config.debug_disable_compile_counter:
  203. self.frame_count = CompileCounterInt(0)
  204. else:
  205. self.frame_count = 0
  206. self.op_count = 0
  207. class CompileCounterWithBackend:
  208. def __init__(self, backend: str) -> None:
  209. self.frame_count: Union[int, CompileCounterInt] = 0
  210. self.backend = backend
  211. self.graphs: list[torch.fx.GraphModule] = []
  212. self.clear()
  213. def __call__(
  214. self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  215. ) -> Callable[..., Any]:
  216. from .backends.registry import lookup_backend
  217. self.frame_count += 1
  218. for node in gm.graph.nodes:
  219. if "call" in node.op:
  220. self.op_count += 1
  221. self.graphs.append(gm)
  222. return lookup_backend(self.backend)(gm, example_inputs)
  223. def clear(self) -> None:
  224. if config.debug_disable_compile_counter:
  225. self.frame_count = CompileCounterInt(0)
  226. else:
  227. self.frame_count = 0
  228. self.op_count = 0
  229. self.graphs = []
  230. # Equivalent to backend="eager", but also records graphs that
  231. # we can assert on
  232. class EagerAndRecordGraphs:
  233. def __init__(self) -> None:
  234. self.graphs: list[torch.fx.GraphModule] = []
  235. def __call__(
  236. self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  237. ) -> Callable[..., Any]:
  238. self.graphs.append(gm)
  239. return gm.forward
  240. class AotEagerAndRecordGraphs:
  241. def __init__(self) -> None:
  242. self.graphs: list[torch.fx.GraphModule] = []
  243. self.fw_graphs: list[torch.fx.GraphModule] = []
  244. self.bw_graphs: list[torch.fx.GraphModule] = []
  245. def __call__(
  246. self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  247. ) -> Callable[..., Any]:
  248. self.graphs.append(gm)
  249. def fw_compiler(
  250. gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  251. ) -> Callable[..., Any]:
  252. self.fw_graphs.append(gm)
  253. return gm.forward
  254. def bw_compiler(
  255. gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  256. ) -> Callable[..., Any]:
  257. self.bw_graphs.append(gm)
  258. return gm.forward
  259. return aot_eager(
  260. gm,
  261. example_inputs,
  262. fw_compiler=fw_compiler,
  263. bw_compiler=bw_compiler,
  264. )
  265. class InductorAndRecordGraphs:
  266. def __init__(self) -> None:
  267. self.graphs: list[torch.fx.GraphModule] = []
  268. self.inductor_graphs: list[torch.fx.GraphModule] = []
  269. def __call__(self, gm, example_inputs): # type: ignore[no-untyped-def]
  270. import torch._inductor.compile_fx as compile_fx_mod
  271. self.graphs.append(gm)
  272. old_compile_fx_inner = compile_fx_mod._compile_fx_inner
  273. def patched(*args, **kwargs): # type: ignore[no-untyped-def]
  274. self.inductor_graphs.append(args[0])
  275. return old_compile_fx_inner(*args, **kwargs)
  276. with patch.object(compile_fx_mod, "_compile_fx_inner", new=patched):
  277. return compile_fx_mod.compile_fx(gm, example_inputs)
  278. def strip_comment(code: str) -> str:
  279. return re.sub(r"(?m)^ *#.*\n?", "", code)
  280. def remove_trailing_space(code: str) -> str:
  281. return "\n".join([line.rstrip() for line in code.split("\n")])
  282. def _squash_blank_lines(code: str) -> str:
  283. lines = code.split("\n")
  284. result: list[str] = []
  285. saw_blank = False
  286. for line in lines:
  287. if line.strip() == "":
  288. if saw_blank:
  289. continue
  290. saw_blank = True
  291. else:
  292. saw_blank = False
  293. result.append(line)
  294. return "\n".join(result)
  295. def normalize_gm(gm_str: str) -> str:
  296. # strip comments as comments have path to files which may differ from
  297. # system to system.
  298. stripped = strip_comment(gm_str)
  299. no_trailing = remove_trailing_space(stripped)
  300. return _squash_blank_lines(no_trailing)
  301. def empty_line_normalizer(code: str) -> str:
  302. """
  303. Normalize code: remove empty lines.
  304. """
  305. normal_code = re.sub(r"[\r\n]+", "\n", code)
  306. return normal_code
  307. def standard_test(
  308. self: Any,
  309. fn: Callable[..., Any],
  310. nargs: int,
  311. expected_ops: Optional[int] = None,
  312. expected_ops_dynamic: Optional[int] = None,
  313. expected_frame_count: int = 1,
  314. ) -> None:
  315. if not config.assume_static_by_default and expected_ops_dynamic is not None:
  316. expected_ops = expected_ops_dynamic
  317. actual = CompileCounter()
  318. args1 = [torch.randn(10, 10) for _ in range(nargs)]
  319. args2 = [torch.randn(10, 10) for _ in range(nargs)]
  320. correct1 = fn(*args1)
  321. correct2 = fn(*args2)
  322. reset()
  323. opt_fn = optimize_assert(actual)(fn)
  324. val1a = opt_fn(*args1)
  325. val2a = opt_fn(*args2)
  326. val1b = opt_fn(*args1)
  327. val2b = opt_fn(*args2)
  328. reset()
  329. self.assertTrue(same(val1a, correct1))
  330. self.assertTrue(same(val1b, correct1))
  331. self.assertTrue(same(val2a, correct2))
  332. self.assertTrue(same(val2b, correct2))
  333. self.assertEqual(actual.frame_count, expected_frame_count)
  334. if expected_ops is not None:
  335. self.assertEqual(actual.op_count, expected_ops)
  336. def dummy_fx_compile(
  337. gm: fx.GraphModule, example_inputs: list[torch.Tensor]
  338. ) -> Callable[..., Any]:
  339. return gm.forward
  340. def format_speedup(
  341. speedup: float,
  342. pvalue: float,
  343. is_correct: bool = True,
  344. pvalue_threshold: float = 0.1,
  345. ) -> str:
  346. if not is_correct:
  347. return "ERROR"
  348. if pvalue > pvalue_threshold:
  349. return f"{speedup:.3f}x SAME"
  350. return f"{speedup:.3f}x p={pvalue:.2f}"
  351. def rand_strided(
  352. size: Sequence[int],
  353. stride: Sequence[int],
  354. dtype: torch.dtype = torch.float32,
  355. device: Union[str, torch.device] = "cpu",
  356. extra_size: int = 0,
  357. ) -> torch.Tensor:
  358. needed_size = extra_size
  359. if all(s > 0 for s in size):
  360. # only need to allocate if all sizes are non-zero
  361. needed_size += (
  362. sum((shape - 1) * stride for shape, stride in zip(size, stride)) + 1
  363. )
  364. if dtype.is_floating_point:
  365. if dtype.itemsize == 1:
  366. """
  367. normal distribution kernel is not implemented for fp8..
  368. Workaround that by creating a fp16 tensor and then cast.
  369. """
  370. buffer = torch.randn(needed_size, dtype=torch.float16, device=device).to(
  371. dtype=dtype
  372. )
  373. else:
  374. buffer = torch.randn(needed_size, dtype=dtype, device=device)
  375. else:
  376. buffer = torch.zeros(size=[needed_size], dtype=dtype, device=device)
  377. return torch.as_strided(buffer, size, stride)
# Generic return type, paired with ``_P`` in the signature-preserving
# decorator/wrapper annotations in this module.
_T = TypeVar("_T")
  379. def check_dynamic_shape_capture() -> bool:
  380. # This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
  381. return not config.assume_static_by_default
  382. def _make_fn_with_patches(fn: Callable[_P, _T], *patches: Any) -> Callable[_P, _T]:
  383. @functools.wraps(fn)
  384. def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  385. with contextlib.ExitStack() as stack:
  386. for module, attr, val in patches:
  387. stack.enter_context(patch.object(module, attr, val))
  388. return fn(*args, **kwargs)
  389. return _fn
  390. def make_test_cls_with_patches(
  391. cls: type,
  392. cls_prefix: str,
  393. fn_suffix: str,
  394. *patches: Any,
  395. xfail_prop: Optional[str] = None,
  396. decorator: Callable[[Callable[..., Any]], Callable[..., Any]] = lambda x: x,
  397. ) -> type:
  398. DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
  399. DummyTestClass.__qualname__ = DummyTestClass.__name__
  400. for name in dir(cls):
  401. if name.startswith("test_"):
  402. fn = getattr(cls, name)
  403. if not callable(fn):
  404. setattr(DummyTestClass, name, getattr(cls, name))
  405. continue
  406. new_name = f"{name}{fn_suffix}"
  407. new_fn = _make_fn_with_patches(fn, *patches)
  408. new_fn.__name__ = new_name
  409. if xfail_prop is not None and hasattr(fn, xfail_prop):
  410. new_fn = unittest.expectedFailure(new_fn)
  411. setattr(DummyTestClass, new_name, decorator(new_fn))
  412. # NB: Doesn't handle slots correctly, but whatever
  413. elif not hasattr(DummyTestClass, name):
  414. setattr(DummyTestClass, name, getattr(cls, name))
  415. return DummyTestClass
  416. # test Python 3.11+ specific features
  417. def skipIfNotPy311(fn: Callable[_P, _T]) -> Callable[_P, _T]:
  418. if sys.version_info >= (3, 11):
  419. return fn
  420. # pyrefly: ignore [bad-return, bad-argument-type]
  421. return unittest.skip(fn)
  422. def skipIfNotPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
  423. if sys.version_info >= (3, 12):
  424. return fn
  425. return unittest.skip("Requires Python 3.12+")(fn)
  426. def skipIfOnlyNotPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
  427. if sys.version_info >= (3, 13) or sys.version_info < (3, 12):
  428. return unittest.skip("Requires Python 3.12")(fn)
  429. return fn
  430. def xfailIfPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
  431. if sys.version_info >= (3, 12):
  432. return unittest.expectedFailure(fn)
  433. return fn
  434. def skipIfPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
  435. if sys.version_info >= (3, 12):
  436. return unittest.skip("Not supported in Python 3.12+")(fn)
  437. return fn
  438. # Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
  439. # and test/dynamo/test_dynamic_shapes.py
  440. def expectedFailureDynamic(fn: Callable[_P, _T]) -> Callable[_P, _T]:
  441. fn._expected_failure_dynamic = True # type: ignore[attr-defined]
  442. return fn
  443. # Controls tests generated in test/inductor/test_torchinductor_codegen_dynamic_shapes.py
  444. def expectedFailureCodegenDynamic(fn: Callable[_P, _T]) -> Callable[_P, _T]:
  445. fn._expected_failure_codegen_dynamic = True # type: ignore[attr-defined]
  446. return fn
  447. # Controls test generated in test/inductor/test_cpp_wrapper.py
  448. def expectedFailureDynamicWrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
  449. fn._expected_failure_dynamic_wrapper = True # type: ignore[attr-defined]
  450. return fn
  451. def reset_rng_state(use_xla: bool = False) -> None:
  452. torch.manual_seed(1337)
  453. random.seed(1337)
  454. if np:
  455. np.random.seed(1337)
  456. if use_xla:
  457. import torch_xla.core.xla_model as xm
  458. xm.set_rng_state(1337, str(xm.xla_device()))
  459. def _skipped_function_for_test_reconstruct(
  460. f: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
  461. ) -> _T:
  462. return f(*args, **kwargs)
  463. _testing_invoke_subgraph_inductor_compile_captured_gms = None
  464. @contextlib.contextmanager
  465. def _testing_capture_invoke_subgraph_inductor_compile_gms() -> Generator[
  466. list[torch.fx.GraphModule]
  467. ]:
  468. """
  469. Context manager to capture graph modules compiled by invoke_subgraph_inductor_compile.
  470. Usage:
  471. with _testing_capture_invoke_subgraph_inductor_compile_gms() as captured_gms:
  472. # code that triggers invoke_subgraph_inductor_compile
  473. pass
  474. # captured_gms will contain the list of captured graph modules
  475. """
  476. global _testing_invoke_subgraph_inductor_compile_captured_gms
  477. _testing_invoke_subgraph_inductor_compile_captured_gms = []
  478. try:
  479. yield _testing_invoke_subgraph_inductor_compile_captured_gms
  480. finally:
  481. _testing_invoke_subgraph_inductor_compile_captured_gms = None