tvm.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
"""
This module provides the TVM backend integration for TorchDynamo.

Apache TVM is a deep learning compiler framework that can optimize and execute
models on various hardware backends. This module enables:

- Compilation of PyTorch models to TVM's computation graphs
- Multiple scheduling options:
  - Default scheduler
  - Auto-scheduler for automatic optimization
  - Meta-schedule for evolutionary search-based tuning
- Hardware-specific optimizations:
  - CUDA GPU support
  - CPU support with LLVM targeting and architecture-specific tuning
  - Automatic detection of CPU capabilities (AVX2, AVX512)
- Tensor conversion utilities between PyTorch and TVM formats
- Configurable optimization levels and tuning trials

The backend can be used with torch.compile():

    model = torch.compile(model, backend="tvm")
"""
  19. import functools
  20. import importlib
  21. import logging
  22. import os
  23. import sys
  24. import tempfile
  25. from collections.abc import Callable
  26. from pathlib import Path
  27. from types import MappingProxyType
  28. from typing import Any, Optional
  29. import torch
  30. from torch import fx
  31. from .common import device_from_inputs, fake_tensor_unsupported
  32. from .registry import register_backend
  33. log = logging.getLogger(__name__)
  34. @register_backend
  35. @fake_tensor_unsupported # type: ignore[arg-type]
  36. def tvm(
  37. gm: fx.GraphModule,
  38. example_inputs: list[torch.Tensor],
  39. *,
  40. options: Optional[MappingProxyType[str, Any]] = None,
  41. ) -> Callable[..., Any]:
  42. if options is None:
  43. options = MappingProxyType({"scheduler": None, "trials": 20000, "opt_level": 3})
  44. assert options is not None
  45. import tvm # type: ignore[import]
  46. from tvm import relay # type: ignore[import]
  47. from tvm.contrib import graph_executor # type: ignore[import]
  48. jit_mod = torch.jit.trace(gm, example_inputs)
  49. device = device_from_inputs(example_inputs)
  50. shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
  51. example_outputs = gm(*example_inputs)
  52. if len(example_outputs) == 0:
  53. log.warning("Explicitly fall back to eager due to zero output")
  54. return gm.forward
  55. mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
  56. if device.type == "cuda":
  57. dev = tvm.cuda(device.index)
  58. target = tvm.target.cuda()
  59. else:
  60. dev = tvm.cpu(0)
  61. target = tvm.target.Target(llvm_target())
  62. scheduler = options.get("scheduler", None)
  63. if scheduler is None:
  64. scheduler = os.environ.get("TVM_SCHEDULER", None)
  65. trials = options.get("trials", 20000)
  66. opt_level = options.get("opt_level", 3)
  67. if scheduler == "auto_scheduler":
  68. # pyrefly: ignore [import-error, missing-import]
  69. from tvm import auto_scheduler
  70. with (
  71. tempfile.NamedTemporaryFile() as log_file,
  72. auto_scheduler.ApplyHistoryBest(log_file),
  73. tvm.transform.PassContext(
  74. opt_level=opt_level, config={"relay.backend.use_auto_scheduler": True}
  75. ),
  76. ):
  77. lib = relay.build(mod, target=target, params=params)
  78. elif scheduler == "meta_schedule":
  79. # pyrefly: ignore [import-error, missing-import]
  80. from tvm import meta_schedule as ms
  81. with tempfile.TemporaryDirectory() as work_dir:
  82. if device.type != "cuda":
  83. # meta_schedule needs num-cores to be specified
  84. # here we use the maximum core count
  85. target = tvm.target.Target(
  86. f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}"
  87. )
  88. # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch
  89. # once USE_PT_TVMDSOOP is updated and turned on by default in TVM.
  90. assert trials > 0
  91. database = ms.relay_integration.tune_relay(
  92. mod=mod,
  93. target=target,
  94. work_dir=work_dir,
  95. max_trials_global=trials,
  96. num_trials_per_iter=64,
  97. params=params,
  98. strategy="evolutionary",
  99. opt_level=opt_level,
  100. )
  101. lib = ms.relay_integration.compile_relay(
  102. database=database,
  103. mod=mod,
  104. target=target,
  105. params=params,
  106. opt_level=opt_level,
  107. )
  108. elif scheduler == "default" or not scheduler:
  109. # no autotuning
  110. with tvm.transform.PassContext(opt_level=opt_level):
  111. lib = relay.build(mod, target=target, params=params)
  112. else:
  113. raise NotImplementedError(
  114. "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. "
  115. "There are three available options: default, auto_scheduler and meta_schedule."
  116. )
  117. m = graph_executor.GraphModule(lib["default"](dev))
  118. def to_torch_tensor(nd_tensor: tvm.nd.array) -> torch.Tensor:
  119. """A helper function to transfer a NDArray to torch.tensor."""
  120. if nd_tensor.dtype == "bool":
  121. # DLPack does not support boolean so it can't be handled by
  122. # torch.utils.dlpack.from_pack. Workaround by going through
  123. # numpy, although this brings additional data copy overhead.
  124. return torch.from_numpy(nd_tensor.numpy())
  125. return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())
  126. def to_tvm_tensor(torch_tensor: torch.Tensor) -> tvm.nd.array:
  127. """A helper function to transfer a torch.tensor to NDArray."""
  128. if torch_tensor.dtype == torch.bool:
  129. # same reason as above, fallback to numpy conversion which
  130. # could introduce data copy overhead
  131. return tvm.nd.array(torch_tensor.cpu().numpy())
  132. return tvm.nd.from_dlpack(torch_tensor)
  133. def exec_tvm(*i_args: torch.Tensor) -> list[torch.Tensor]:
  134. args = [a.contiguous() for a in i_args]
  135. shape_info, _ = m.get_input_info()
  136. active_inputs = {name for name, _ in shape_info.items()}
  137. for idx, arg in enumerate(args, 0):
  138. if arg.dim() != 0:
  139. if arg.requires_grad:
  140. arg = arg.detach()
  141. inp_name = f"inp_{idx}"
  142. if inp_name not in active_inputs:
  143. log.warning(
  144. "input %s skipped as not found in tvm's runtime library",
  145. inp_name,
  146. )
  147. continue
  148. m.set_input(
  149. inp_name,
  150. to_tvm_tensor(arg),
  151. )
  152. m.run()
  153. return [to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())]
  154. return exec_tvm
  155. tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule")
  156. tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler")
  157. def has_tvm() -> bool:
  158. try:
  159. importlib.import_module("tvm")
  160. return True
  161. except ImportError:
  162. return False
  163. @functools.cache
  164. def llvm_target() -> str:
  165. if sys.platform == "linux":
  166. cpuinfo = Path("/proc/cpuinfo").read_text()
  167. if "avx512" in cpuinfo:
  168. return "llvm -mcpu=skylake-avx512"
  169. elif "avx2" in cpuinfo:
  170. return "llvm -mcpu=core-avx2"
  171. return "llvm"