pgo.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003
  1. """
  2. Profile Guided Optimization (PGO) implementation for Dynamo.
  3. This module provides functionality for caching and managing code state profiles
  4. that guide optimization decisions in Dynamo. It implements both local and remote
  5. caching mechanisms for storing profile information across runs, handles profile
  6. merging across distributed ranks, and manages the lifecycle of profile data
  7. during compilation. The profiles track dynamic vs static properties of tensors
  8. and help Dynamo make better specialization decisions.
  9. """
  10. from __future__ import annotations
  11. import base64
  12. import copy
  13. import dataclasses
  14. import enum
  15. import functools
  16. import logging
  17. import os
  18. import pickle
  19. import re
  20. import zlib
  21. from collections import defaultdict
  22. from typing import Optional, TYPE_CHECKING, TypeVar, Union
  23. from typing_extensions import override, Self
  24. import torch._dynamo.config
  25. import torch._utils_internal
  26. import torch.compiler.config
  27. import torch.distributed as dist
  28. from torch._dynamo.utils import (
  29. CompileEventLogger,
  30. dynamo_timed,
  31. set_feature_use,
  32. warn_once,
  33. )
  34. from torch._environment import is_fbcode
  35. from torch._logging._internal import trace_structured_artifact
  36. from torch.compiler._cache import (
  37. CacheArtifact,
  38. CacheArtifactFactory,
  39. CacheArtifactManager,
  40. )
  41. from torch.utils._ordered_set import OrderedSet
  42. if TYPE_CHECKING:
  43. import types
  44. from torch._dynamo.symbolic_convert import InstructionTranslator
  45. from torch._inductor.remote_cache import JsonDataTy, RemoteCache
  46. class ReservedWorkflowIdUserError(ValueError):
  47. pass
  48. log = logging.getLogger(__name__)
  49. LOCK_TIMEOUT = 10
  50. # How does in memory representation work? Concretely, this module is
  51. # responsible for holding GLOBAL state representing the state it holds, no
  52. # other copies permitted. So we retire frame_state entirely and store it
  53. # here. This should be reset when Dynamo is reset. We never GC information
  54. # (similar to how the filesystem doesn't get cleaned up except by tmp
  55. # cleaner), so the expectation is the information is relatively cheap and we
  56. # don't mind leaking it.
  57. # How exactly did we design the cache key? Here are some of the questions:
  58. #
  59. # - JOB_ID: Do we have a unique identifier for the "training run" (such that
  60. # it stays the same if we're running the same code, and changes if we're
  61. # running something different).
  62. #
  63. # - RANK: Are we sharing the cache across ranks, or does each rank get
  64. # an individual cache?
  65. #
  66. # We choose to require job_id for PGO cache. This is to prevent
  67. # situations where unrelated invocations of PyTorch unpredictably cause
  68. # changes to each other's behavior. With a job_id, at least you know there
  69. # is some "state" associated with it. (State dict might be another way to
  70. # tell if a run is related or not.) You can opt in to YOLO mode, where
  71. # everything aliases everything, by passing a shared job_id for all your invocations.
  72. #
  73. # We choose to NOT share PGO cache across ranks. With no RANK_SHARING, there
  74. # is never contention between runs, so we can leisurely update a bundle with
  75. # information we need. Because we are grouped by job_id, we can have a single
  76. # consolidated bundle for everything (or not; maybe worry about O(n^2) IO if
  77. # we updated every compile--let's just instrument this.) Can even take a
  78. # filelock for extra safety (expect no contention); expect 50ns overhead from
  79. # uncontended filelock.
  80. #
  81. # If we did share ranks, everyone is storming to modify the same cache files.
  82. # We can do this by having folks atomic write to a CAS-store and then having
  83. # readers do on-the-fly merging (this can be implemented in remote using
  84. # prefix iteration). As an optional optimization, one rank can be elected to
  85. # handle bundling post facto (ideally, this is done async, after quiescence,
  86. # without the compiler collective needing to wait for everyone to finish writing
  87. # their bits.) Not sure how you can avoid a listdir because if some rank shows
  88. # up with some new entries we need to pull them in ASAP (unless you want to
  89. # delay bundling).
  90. #
  91. # But compiler collectives fill a similar niche: compilers chat with each
  92. # other so rank 0 has collected everything. So elect rank 0 only to write the
  93. # bundle. Don't even need CAS-store atomic write; just one rank writing and
  94. # updating bundles. The point is to use compiler collectives to share
  95. # profiles across ranks, but use the PGO cache to persist profiles per rank
  96. # across attempts. No need to have one mechanism to do everything.
  97. @functools.cache
  98. def _hash_containing_file(filepath: str) -> str:
  99. # if the file does not exists we consider filepath to be the hash.
  100. if not os.path.exists(filepath):
  101. return filepath
  102. with open(filepath, "rb") as file:
  103. content = file.read()
  104. crc32_value = zlib.crc32(content)
  105. hash = format(crc32_value & 0xFFFFFFFF, "08x")
  106. return hash
  107. @dataclasses.dataclass(frozen=True)
  108. class CodeId:
  109. filename: str
  110. firstlineno: int
  111. name: str
  112. # When a job restart, the code can be copied to a different path than the previous attempt. In that case
  113. # self.filename will have a different value, we do not want to consider those differences. Instead we
  114. # hash the content of the file and use it as an identifier of the file.
  115. #
  116. # self.filename is kept in the object to give readable information/pointer to the actual file, in a local
  117. # code state it will refer to the first seen file path.
  118. file_hash: str
  119. # Exclude file name.
  120. def __eq__(self, other: object) -> bool:
  121. if not isinstance(other, CodeId):
  122. return False
  123. return (
  124. self.file_hash == other.file_hash
  125. and self.firstlineno == other.firstlineno
  126. and self.name == other.name
  127. )
  128. # Ensure if two CodeIds are the same, then they have the same hash by excluding filename.
  129. def __hash__(self) -> int:
  130. return hash((self.file_hash, self.name, self.firstlineno))
  131. def __str__(self) -> str:
  132. return f"hash({self.file_hash}){self.filename}:{self.firstlineno}:{self.name}"
  133. @staticmethod
  134. def make(code: types.CodeType) -> CodeId:
  135. return CodeId(
  136. code.co_filename,
  137. code.co_firstlineno,
  138. code.co_name,
  139. _hash_containing_file(code.co_filename),
  140. )
  141. @dataclasses.dataclass
  142. class CodeState:
  143. automatic_dynamic: defaultdict[str, FrameStateSizeEntry] = dataclasses.field(
  144. default_factory=lambda: defaultdict(FrameStateSizeEntry)
  145. )
  146. _INIT_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None
  147. _CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None
  148. _LOGGED_DYNAMIC_ALLOWLIST: bool = False
  149. _KNOWN_DYNAMIC_SOURCES: set[str] = set()
  150. @dataclasses.dataclass(frozen=True)
  151. class InferStride:
  152. """
  153. Denotes the quantity stride[dim] * size[dim], which is what the stride would
  154. be for the next physical dimension that results in a contiguous layout.
  155. For example, given size = [2, 3], stride = [3, 1], we can replace this with
  156. stride = [InferStride(1), 1], because InferStride(1) = stride[1] * size[1] = 1 * 3 = 3
  157. Indirecting the representation in this way is important for the join operation
  158. on strides as if we join [2, 3][3, 1] and [2, 4][4, 1],
  159. we don't want [2, None][None, 1] which would get eventually symbolized into
  160. [2, s0][s1, 1] (notice that the relationship between s0 and s1 is broken).
  161. If we instead rewrite the expressions as InferStride so we have [2, 3][InferStride(1), 1]
  162. and [2, 4][InferStride(1), 1] we now join to [2, None][InferStride(1), 1] will
  163. result in [2, s0][s0, 1], as desired.
  164. """
  165. dim: int
  166. _T = TypeVar("_T")
  167. class AutoUnset(enum.Enum):
  168. """
  169. The identity element of our semilattice, a generic "don't know" element that
  170. is always subsumed when we get more information.
  171. """
  172. token = 0
  173. auto_unset = AutoUnset.token
  174. class AutoDynamic(enum.Enum):
  175. """
  176. The top element of our (bounded) semilattice, whenever you merge this with
  177. any other element you always get it again
  178. """
  179. token = 0
  180. auto_dynamic = AutoDynamic.token
  181. @dataclasses.dataclass
  182. class FrameStateSizeEntry:
  183. scalar: Union[int, AutoDynamic, AutoUnset] = dataclasses.field(default=auto_unset)
  184. # NB: We don't have cases where we have a known dimensionality but
  185. # we know NOTHING about the individual sizes
  186. size: Union[AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic], ...]] = (
  187. dataclasses.field(default=auto_unset)
  188. )
  189. stride: Union[
  190. AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic, InferStride], ...]
  191. ] = dataclasses.field(default=auto_unset)
  192. def render(self) -> str:
  193. # Special cases
  194. def render_single(s: Union[int, AutoDynamic, AutoUnset, InferStride]) -> str:
  195. if s is auto_dynamic:
  196. return "?"
  197. elif s is auto_unset:
  198. # This basically shouldn't happen, this is for debugging
  199. return "auto unset"
  200. elif isinstance(s, InferStride):
  201. return f"S({s.dim})"
  202. else:
  203. return str(s)
  204. def render_tuple(ss: tuple[Union[int, AutoDynamic, InferStride], ...]) -> str:
  205. return "[" + ", ".join(render_single(s) for s in ss) + "]"
  206. # Common cases
  207. if self.size is auto_dynamic and self.stride is auto_dynamic:
  208. if self.scalar is auto_dynamic:
  209. return "fully dynamic scalar or tensor"
  210. else:
  211. return f"scalar {self.scalar}"
  212. elif self.scalar is auto_dynamic:
  213. if isinstance(self.size, tuple) and isinstance(self.stride, tuple):
  214. return f"tensor size={render_tuple(self.size)} stride={render_tuple(self.stride)}"
  215. # Fallback
  216. return f"unusual {repr(self)}"
  217. def __post_init__(self) -> None:
  218. assert not isinstance(self.scalar, torch.SymInt), self.scalar
  219. if isinstance(self.size, tuple):
  220. for s in self.size:
  221. assert not isinstance(s, torch.SymInt), s
  222. if isinstance(self.stride, tuple):
  223. for s1 in self.stride:
  224. assert not isinstance(s1, torch.SymInt), s1
  225. def is_size_dynamic(self, dim: int) -> bool:
  226. if self.size is auto_dynamic:
  227. return True
  228. if self.size is auto_unset:
  229. return False
  230. return self.size[dim] is auto_dynamic
  231. def is_stride_dynamic(self, dim: int) -> bool:
  232. # At the moment, dynamic strides is a bit buggy. Good test case
  233. # here is `PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py
  234. # TestAutograd.test_gradcheck_jacobian_mismatch`
  235. #
  236. # This if statement preserves historical behavior, which is that we
  237. # ONLY make strides dynamic if the size is exactly static everywhere.
  238. # We could potentially relax this but in general we should be very
  239. # careful about when to infer dynamic strides.
  240. #
  241. # Actually, the existing algorithm is already somewhat problematic.
  242. # Suppose a tensor that is sometimes:
  243. # f32[2, 3, 5][15, 5, 1] and other times
  244. # f32[2, 3, 5][5, 10, 1] (specifically, dim 0 and 1 are physically transposed).
  245. # If we infer strides should be (DYNAMIC, DYNAMIC, 1). But this is
  246. # silly: we really should have just guarded on dim order.
  247. if not (
  248. isinstance(self.size, tuple) and all(type(s) is int for s in self.size)
  249. ):
  250. return False
  251. if self.stride is auto_dynamic:
  252. return True
  253. if self.stride is auto_unset:
  254. return False
  255. return self.stride[dim] is auto_dynamic
  256. @staticmethod
  257. def _munge_symint(xs: tuple[int, ...]) -> tuple[Union[AutoDynamic, int], ...]:
  258. return tuple(auto_dynamic if isinstance(x, torch.SymInt) else x for x in xs)
  259. @classmethod
  260. def make_scalar(cls, x: int) -> FrameStateSizeEntry:
  261. return FrameStateSizeEntry(scalar=x, size=auto_dynamic, stride=auto_dynamic)
  262. @classmethod
  263. def make_tensor(
  264. cls, size: tuple[int, ...], stride: tuple[int, ...]
  265. ) -> FrameStateSizeEntry:
  266. return FrameStateSizeEntry(
  267. scalar=auto_dynamic,
  268. size=cls._munge_symint(size),
  269. stride=cls._munge_symint(stride),
  270. )
  271. @classmethod
  272. def make_size(cls, size: tuple[int, ...]) -> FrameStateSizeEntry:
  273. return FrameStateSizeEntry(
  274. scalar=auto_unset,
  275. size=cls._munge_symint(size),
  276. stride=auto_unset,
  277. )
  278. @staticmethod
  279. def _merge_atom(x: _T, y: _T) -> Union[AutoDynamic, _T]:
  280. if x is auto_unset:
  281. return y
  282. if y is auto_unset:
  283. return x
  284. if x is auto_dynamic or y is auto_dynamic or x != y:
  285. return auto_dynamic
  286. return x
  287. @classmethod
  288. def _merge_atom_tup(
  289. cls,
  290. xs: Union[AutoDynamic, AutoUnset, tuple[_T, ...]],
  291. ys: Union[AutoDynamic, AutoUnset, tuple[_T, ...]],
  292. ) -> Union[AutoDynamic, AutoUnset, tuple[Union[AutoDynamic, _T], ...]]:
  293. if xs is auto_unset:
  294. return ys
  295. if ys is auto_unset:
  296. return xs
  297. if xs is auto_dynamic or ys is auto_dynamic:
  298. return auto_dynamic
  299. if len(xs) != len(ys):
  300. return auto_dynamic
  301. return tuple(cls._merge_atom(x, y) for x, y in zip(xs, ys))
  302. def __ior__(self, other: Self) -> Self:
  303. self.scalar = self._merge_atom(self.scalar, other.scalar)
  304. self.size = self._merge_atom_tup(self.size, other.size)
  305. self.stride = self._merge_atom_tup(self.stride, other.stride)
  306. return self
  307. def update_automatic_dynamic(
  308. tx: InstructionTranslator,
  309. name: str,
  310. entry: FrameStateSizeEntry,
  311. *,
  312. is_unspecialized_nn_module: bool = False,
  313. ) -> FrameStateSizeEntry:
  314. code_id = CodeId.make(tx.f_code)
  315. frame_state = get_code_state()[code_id]
  316. if torch._dynamo.config.automatic_dynamic_shapes:
  317. is_update = name in frame_state.automatic_dynamic
  318. mut_entry = frame_state.automatic_dynamic[name]
  319. old_entry = copy.copy(mut_entry)
  320. mut_entry |= entry
  321. # Do some logs (damn, I spend more code logging than I do actually doing
  322. # the updates lol)
  323. if is_update and old_entry.scalar != mut_entry.scalar:
  324. log.debug(
  325. "automatic dynamic int %s val %s != %s",
  326. name,
  327. entry.scalar,
  328. old_entry.scalar,
  329. )
  330. CompileEventLogger.instant(
  331. "automatic_dynamic",
  332. {
  333. "name": name,
  334. "dim_changed": "scalar",
  335. "reason": "scalar change",
  336. "cached": str(old_entry.scalar),
  337. "new": str(entry.scalar),
  338. },
  339. )
  340. if is_unspecialized_nn_module:
  341. log.info(
  342. "%s is converted to a symbolic integer. It is an attribute of a "
  343. "user defined nn module class. If you wish to keep it static, you can "
  344. "mark the nn module class as `torch._dynamo.mark_static`.",
  345. name,
  346. )
  347. def log_tup(
  348. tup_name: str, short_reason: str, long_reason: str, i: Optional[int] = None
  349. ) -> None:
  350. entry_tup = (
  351. getattr(entry, tup_name) if i is None else getattr(entry, tup_name)[i]
  352. )
  353. old_entry_tup = (
  354. getattr(old_entry, tup_name)
  355. if i is None
  356. else getattr(old_entry, tup_name)[i]
  357. )
  358. log.debug(
  359. "automatic dynamic %s %s %s %s != %s",
  360. tup_name,
  361. name,
  362. short_reason,
  363. # NB: We used to only report len(...) here for dim mismatch
  364. entry_tup,
  365. old_entry_tup,
  366. )
  367. CompileEventLogger.instant(
  368. "automatic_dynamic",
  369. {
  370. "name": name,
  371. "dim_changed": "all" if i is None else i,
  372. "reason": long_reason,
  373. "cached": str(old_entry_tup),
  374. "new": str(entry_tup),
  375. },
  376. )
  377. if is_update and old_entry.size != mut_entry.size:
  378. if isinstance(old_entry.size, tuple) and isinstance(entry.size, tuple):
  379. if len(old_entry.size) != len(entry.size):
  380. log_tup("size", "dim", "dimensionality change")
  381. else:
  382. for i in range(len(entry.size)):
  383. if old_entry.size[i] != entry.size[i]:
  384. log_tup("size", f"size({i})", "size change", i)
  385. else:
  386. log_tup("size", "other", "other")
  387. if is_update and old_entry.stride != mut_entry.stride:
  388. if isinstance(old_entry.stride, tuple) and isinstance(entry.stride, tuple):
  389. if len(old_entry.stride) != len(entry.stride):
  390. log_tup("stride", "dim", "dimensionality change")
  391. else:
  392. for i in range(len(entry.stride)):
  393. if old_entry.stride[i] != entry.stride[i]:
  394. log_tup("stride", f"stride({i})", "stride change", i)
  395. else:
  396. log_tup("stride", "other", "other")
  397. else:
  398. old_entry = frame_state.automatic_dynamic[name]
  399. log.debug(
  400. "automatic dynamic is off, overwriting int %s val %s -> %s",
  401. name,
  402. old_entry.scalar,
  403. entry.scalar,
  404. )
  405. frame_state.automatic_dynamic[name] = entry
  406. mut_entry = entry
  407. return mut_entry
  408. def process_automatic_dynamic(
  409. tx: InstructionTranslator,
  410. name: str,
  411. entry: FrameStateSizeEntry,
  412. *,
  413. is_unspecialized_nn_module: bool = False,
  414. ) -> FrameStateSizeEntry:
  415. if (st := tx.distributed_state) is None:
  416. return update_automatic_dynamic(
  417. tx,
  418. name,
  419. entry,
  420. is_unspecialized_nn_module=is_unspecialized_nn_module,
  421. )
  422. elif st.all_states is None:
  423. # Preflight, always pretend as if it's static. The point here
  424. # is we want to get through the preflight quickly, and static
  425. # will run faster. The preexisting frame state will get
  426. # applied anyway after we do compiler collectives.
  427. # TODO: I'm not sure if we should just bong the entire pgo
  428. # state here, it kind of depends if we're going to have other
  429. # things that talk in compiler collective. Also, the PGO
  430. # state, if we've already inferred something is automatic
  431. # dynamic, will have lost the actual input sizes, which might
  432. # be useful for debugging purposes (e.g., observing 0/1
  433. # specialization). Bonging the entire PGO state here would
  434. # let us delete this logic here; the compiler collective
  435. # would just directly update_automatic_dynamic
  436. st.local_state.automatic_dynamic[name] = entry
  437. return entry
  438. else:
  439. # Apply the updates. NB: all_states includes the local state
  440. # too.
  441. res = None
  442. for sub_state in st.all_states:
  443. if name in sub_state.automatic_dynamic:
  444. res = update_automatic_dynamic(
  445. tx,
  446. name,
  447. sub_state.automatic_dynamic[name],
  448. is_unspecialized_nn_module=is_unspecialized_nn_module,
  449. )
  450. assert res is not None
  451. return res
  452. def format_cache_key(key: str) -> str:
  453. # NB: We always use global rank for keys, even though they are overkill
  454. # for local only cache
  455. rank = None
  456. if dist.is_available() and dist.is_initialized():
  457. rank = dist.get_rank()
  458. tag = torch.compiler.config.cache_key_tag
  459. return f"{key}:{rank}:{tag}"
  460. def get_cache_key() -> Optional[str]:
  461. # TODO: info versions of these logs that log only once
  462. if torch.compiler.config.force_disable_caches:
  463. warn_once(
  464. "dynamo_pgo force disabled by torch.compiler.config.force_disable_caches"
  465. )
  466. return None
  467. # NB: We namespace the cache keys so that only user-specified job id
  468. # can alias with each other.
  469. if (r := torch.compiler.config.job_id) is not None:
  470. if r.startswith("mast:"):
  471. raise ReservedWorkflowIdUserError(
  472. "torch.compiler.config.job_id with prefix 'mast:' is reserved for "
  473. "automatically generated job id associated with a specific MAST job "
  474. "name and version."
  475. )
  476. return format_cache_key(r)
  477. if (name_version := torch._utils_internal.get_mast_job_name_version()) is not None:
  478. mast_job_name, mast_job_version = name_version
  479. return format_cache_key(f"mast:{mast_job_name}:{mast_job_version}")
  480. return None
  481. def get_extra_cache_key(sticky_key: str) -> Optional[str]:
  482. if torch.compiler.config.force_disable_caches:
  483. warn_once(
  484. "dynamo_pgo force disabled by torch.compiler.config.force_disable_caches"
  485. )
  486. return None
  487. return format_cache_key(sticky_key)
  488. # This solely controls local PGO
  489. def code_state_path(cache_key: str) -> Optional[str]:
  490. if not torch._dynamo.config.automatic_dynamic_local_pgo:
  491. log.debug("automatic_dynamic_local_pgo not enabled")
  492. return None
  493. from torch._inductor.runtime.runtime_utils import cache_dir
  494. code_state_key = re.sub(r'[<>:"/\\|?*]', "_", f"code_state_{cache_key}.pkl")
  495. return os.path.join(cache_dir(), "dynamo", code_state_key)
  496. def should_use_remote_dynamo_pgo_cache() -> bool:
  497. if torch.compiler.config.force_disable_caches:
  498. return False
  499. if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None:
  500. return r
  501. if not is_fbcode():
  502. return False
  503. if torch._utils_internal.is_fb_unit_test():
  504. return False
  505. try:
  506. from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
  507. except ModuleNotFoundError:
  508. return False
  509. return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
  510. "pytorch/remote_cache:dynamo_pgo_version"
  511. )
  512. def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
  513. from torch._inductor.remote_cache import create_cache
  514. if not should_use_remote_dynamo_pgo_cache():
  515. return None
  516. return create_cache(
  517. "dynamo-pgo",
  518. is_fbcode(),
  519. "FbRemoteDynamoPGOCache",
  520. "RemoteDynamoPGOCache",
  521. )
  522. def _collect_dynamic_sources(code_state: CodeState) -> OrderedSet[str]:
  523. dynamic_sources: OrderedSet[str] = OrderedSet()
  524. for src, fs in code_state.automatic_dynamic.items():
  525. dynamic = False
  526. if isinstance(fs.size, tuple):
  527. dynamic = auto_dynamic in fs.size # type: ignore[operator]
  528. elif fs.scalar == auto_dynamic:
  529. dynamic = True
  530. if dynamic:
  531. dynamic_sources.add(src)
  532. return dynamic_sources
  533. def _collect_missing_sources(all_sources: OrderedSet[str]) -> OrderedSet[str]:
  534. from torch._dynamo.variables.builder import is_dynamic_source
  535. global _KNOWN_DYNAMIC_SOURCES
  536. missing_sources: OrderedSet[str] = OrderedSet()
  537. for src in all_sources:
  538. if src in _KNOWN_DYNAMIC_SOURCES:
  539. continue
  540. elif is_dynamic_source(src):
  541. _KNOWN_DYNAMIC_SOURCES.add(src)
  542. continue
  543. missing_sources.add(src)
  544. return missing_sources
  545. def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None:
  546. global _KNOWN_DYNAMIC_SOURCES
  547. code_id = CodeId.make(f_code)
  548. frame_state = get_code_state()[code_id]
  549. all_dynamic_sources = _collect_dynamic_sources(frame_state)
  550. frame_whitelist = ",".join(all_dynamic_sources)
  551. missing_whitelist = ",".join(_collect_missing_sources(all_dynamic_sources))
  552. if frame_whitelist:
  553. with dynamo_timed(name := "pgo.dynamic_whitelist", log_pt2_compile_event=True):
  554. CompileEventLogger.pt2_compile(
  555. name,
  556. recompile_dynamic_whitelist=frame_whitelist,
  557. missing_dynamic_whitelist=missing_whitelist,
  558. )
  559. def _log_size_mismatch_recompile() -> None:
  560. global _LOGGED_DYNAMIC_ALLOWLIST
  561. if not _LOGGED_DYNAMIC_ALLOWLIST:
  562. torch._utils_internal.add_mlhub_insight(
  563. category="dynamic_shapes_analysis",
  564. insight="Dynamic shape recompilation detected",
  565. insight_description="PGO detected a recompilation due to dynamic shapes. \
  566. Please follow the instruction from the action link to reduce \
  567. recompilation overhead.",
  568. )
  569. # add mlhub insight only once per rank
  570. _LOGGED_DYNAMIC_ALLOWLIST = True
  571. def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str:
  572. code_state_str = "\n".join(
  573. f"{k}:\n"
  574. + "\n".join(
  575. f" {src}: {fs.render()}" for src, fs in v.automatic_dynamic.items()
  576. )
  577. for k, v in cs.items()
  578. )
  579. dynamic_sources: OrderedSet[str] = OrderedSet()
  580. for state in cs.values():
  581. dynamic_sources.update(_collect_dynamic_sources(state))
  582. if dynamic_sources:
  583. code_state_str += (
  584. "\n\nPGO detected a recompilation due to dynamic shapes. "
  585. "To reduce shape recompilations by compiling dynamically to start, "
  586. f'set environment variable TORCH_COMPILE_DYNAMIC_SOURCES="{",".join(dynamic_sources)}"'
  587. )
  588. return code_state_str
  589. @CacheArtifactFactory.register
  590. class PGOCacheArtifact(CacheArtifact):
  591. @override
  592. def populate_cache(self) -> None:
  593. meta = write_local_impl(
  594. self._rewrite_cache_key_for_mega_cache(self.key), self.content
  595. )
  596. assert meta is not None
  597. @override
  598. @staticmethod
  599. def type() -> str:
  600. return "pgo"
  601. @staticmethod
  602. def _rewrite_cache_key_for_mega_cache(original_key: str) -> str:
  603. """
  604. The PGO cache artifact key for a MAST job contains the job name and the version.
  605. When we want to use the cache artifact on a different MAST job, we need to
  606. update the key to use the new MAST job's name and version.
  607. """
  608. if not original_key.startswith("mast:"):
  609. # if original_key is overridden, then dont change it
  610. return original_key
  611. if (new_key := get_cache_key()) is not None:
  612. return new_key
  613. return original_key
  614. def hit(key: str, ty: str) -> defaultdict[CodeId, CodeState]:
  615. global _INIT_CODE_STATE
  616. assert isinstance(_CODE_STATE, defaultdict)
  617. log.info("get_code_state %s hit %s, %d entries", key, ty, len(_CODE_STATE))
  618. trace_structured_artifact(
  619. f"get_{ty}_code_state",
  620. "string",
  621. lambda: render_code_state(_CODE_STATE), # type: ignore[arg-type]
  622. )
  623. set_feature_use("pgo", True)
  624. _INIT_CODE_STATE = copy.deepcopy(_CODE_STATE)
  625. return _CODE_STATE
  626. def get_local_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeState]]:
  627. global _CODE_STATE
  628. path = code_state_path(cache_key)
  629. if path is not None and os.path.exists(path):
  630. with dynamo_timed(
  631. name := "pgo.get_local_code_state", log_pt2_compile_event=True
  632. ):
  633. CompileEventLogger.pt2_compile(name, cache_key=cache_key)
  634. # Read lock not necessary as we always write atomically write to
  635. # the actual location
  636. with open(path, "rb") as f:
  637. try:
  638. content = f.read()
  639. _CODE_STATE = pickle.loads(content)
  640. CompileEventLogger.pt2_compile(name, cache_size_bytes=f.tell())
  641. except Exception:
  642. log.warning(
  643. "get_code_state failed while reading %s", path, exc_info=True
  644. )
  645. else:
  646. CacheArtifactManager.record_artifact(
  647. PGOCacheArtifact.type(), cache_key, content
  648. )
  649. return hit(path, "local")
  650. return None
  651. def lookup_remote_cache_entry(
  652. remote_cache: RemoteCache[JsonDataTy],
  653. cache_key: str,
  654. event_name: Optional[str] = None,
  655. ) -> Optional[defaultdict[CodeId, CodeState]]:
  656. code_state = None
  657. try:
  658. cache_data = remote_cache.get(cache_key)
  659. except Exception:
  660. log.warning("get_code_state failed remote read on %s", cache_key, exc_info=True)
  661. else:
  662. if cache_data is not None:
  663. try:
  664. assert isinstance(cache_data, dict)
  665. data = cache_data["data"]
  666. assert isinstance(data, str)
  667. payload = base64.b64decode(data)
  668. if event_name is not None:
  669. CompileEventLogger.pt2_compile(
  670. event_name, cache_size_bytes=len(payload)
  671. )
  672. code_state = pickle.loads(payload)
  673. except Exception:
  674. log.warning(
  675. "get_code_state failed parsing remote result on %s",
  676. cache_key,
  677. exc_info=True,
  678. )
  679. else:
  680. CacheArtifactManager.record_artifact(
  681. PGOCacheArtifact.type(), cache_key, payload
  682. )
  683. else:
  684. log.info("get_code_state remote miss on %s", cache_key)
  685. return code_state
  686. def get_remote_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeState]]:
  687. global _CODE_STATE
  688. remote_cache = get_remote_cache()
  689. if remote_cache is not None:
  690. with dynamo_timed(
  691. name := "pgo.get_remote_code_state",
  692. log_pt2_compile_event=True,
  693. dynamo_compile_column_us="pgo_get_remote_code_state_time_us",
  694. ):
  695. CompileEventLogger.pt2_compile(name, cache_key=cache_key)
  696. code_state = lookup_remote_cache_entry(remote_cache, cache_key, name)
  697. if code_state is not None:
  698. _CODE_STATE = code_state
  699. return hit(cache_key, "remote")
  700. return None
  701. def get_extra_remote_code_state(cache_key: str) -> None:
  702. """
  703. Reads an additional PGO profile from the given cache key, and merges it with the default PGO profile.
  704. """
  705. global _CODE_STATE
  706. assert _CODE_STATE is not None
  707. remote_cache = get_remote_cache()
  708. if remote_cache is not None:
  709. with dynamo_timed(
  710. name := "pgo.get_extra_remote_code_state",
  711. log_pt2_compile_event=True,
  712. dynamo_compile_column_us="pgo_get_remote_code_state_time_us",
  713. ):
  714. CompileEventLogger.pt2_compile(name, cache_key=cache_key)
  715. code_state = lookup_remote_cache_entry(remote_cache, cache_key)
  716. log.info(
  717. "get_extra_code_state %s hit, %d entries",
  718. cache_key,
  719. len(code_state) if code_state is not None else 0,
  720. )
  721. if code_state is not None:
  722. assert not _CODE_STATE
  723. _CODE_STATE = code_state
  724. # log to tlparse
  725. trace_structured_artifact(
  726. "get_extra_remote_code_state",
  727. "string",
  728. lambda: render_code_state(code_state),
  729. )
  730. def get_code_state() -> defaultdict[CodeId, CodeState]:
  731. global _CODE_STATE, _INIT_CODE_STATE
  732. if _CODE_STATE is not None:
  733. return _CODE_STATE
  734. # Initialize it (even if we don't look up profile)
  735. _CODE_STATE = defaultdict(CodeState)
  736. cache_key = get_cache_key()
  737. if cache_key is None:
  738. return _CODE_STATE
  739. # Attempt local
  740. local_code_state = get_local_code_state(cache_key)
  741. # Attempt remote
  742. if local_code_state is None:
  743. get_remote_code_state(cache_key)
  744. # Attempt additional remote if neither local/default remote succeeded
  745. if (
  746. not _CODE_STATE
  747. and (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None
  748. ):
  749. extra_read_key = get_extra_cache_key(sticky_read)
  750. if extra_read_key is not None:
  751. get_extra_remote_code_state(extra_read_key)
  752. log.info("get_code_state using default")
  753. assert _CODE_STATE is not None
  754. return _CODE_STATE
  755. def put_code_state() -> None:
  756. if _CODE_STATE is None:
  757. log.info("put_code_state: never initialized, will not write")
  758. return
  759. if _CODE_STATE == _INIT_CODE_STATE:
  760. log.info("put_code_state: no change, skipping")
  761. return
  762. cache_key = get_cache_key()
  763. if cache_key is None:
  764. log.info("put_code_state: no cache key, skipping")
  765. return
  766. put_local_code_state(cache_key)
  767. put_remote_code_state(cache_key)
  768. if (sticky_write := torch.compiler.config.pgo_extra_write_key) is not None:
  769. extra_write_key = get_extra_cache_key(sticky_write)
  770. if extra_write_key is not None:
  771. put_remote_code_state(extra_write_key)
  772. def write_local_impl(cache_key: str, pickled_code: bytes) -> Optional[tuple[str, int]]:
  773. path = code_state_path(cache_key)
  774. if path is None:
  775. return None
  776. # If the user isn't misusing our API, we should have exclusive access to
  777. # this directory. But it's not too hard
  778. tmp_path = path + ".tmp"
  779. lock_path = path + ".lock"
  780. # We /mostly/ don't need the lock but the tmp file could be clobbered
  781. # TODO: use a safe tempfile create to eliminate lock
  782. from torch.utils._filelock import FileLock
  783. os.makedirs(os.path.dirname(path), exist_ok=True)
  784. with FileLock(lock_path, timeout=LOCK_TIMEOUT):
  785. with open(tmp_path, "wb") as f:
  786. f.write(pickled_code)
  787. size = f.tell()
  788. os.replace(tmp_path, path)
  789. return path, size
  790. def put_local_code_state(cache_key: str) -> None:
  791. with dynamo_timed(name := "pgo.put_local_code_state", log_pt2_compile_event=True):
  792. CompileEventLogger.pt2_compile(name, cache_key=cache_key)
  793. assert _CODE_STATE is not None
  794. pickled_code = pickle.dumps(_CODE_STATE)
  795. CacheArtifactManager.record_artifact(
  796. PGOCacheArtifact.type(), cache_key, pickled_code
  797. )
  798. meta = write_local_impl(cache_key, pickled_code)
  799. if meta is None:
  800. log.info("put_code_state: local cache disabled")
  801. return
  802. path, size = meta
  803. CompileEventLogger.pt2_compile(name, cache_size_bytes=size)
  804. log.info("put_code_state: wrote local %s, %d entries", path, len(_CODE_STATE))
  805. trace_structured_artifact(
  806. "put_local_code_state",
  807. "string",
  808. lambda: render_code_state(_CODE_STATE),
  809. )
  810. def put_remote_code_state(cache_key: str, extra_code_state: bool = False) -> None:
  811. event_name = (
  812. "put_remote_code_state"
  813. if not extra_code_state
  814. else "put_extra_remote_code_state"
  815. )
  816. with dynamo_timed(
  817. name := f"pgo.{event_name}",
  818. log_pt2_compile_event=True,
  819. dynamo_compile_column_us="pgo_put_remote_code_state_time_us",
  820. ):
  821. CompileEventLogger.pt2_compile(name, cache_key=cache_key)
  822. assert _CODE_STATE is not None
  823. remote_cache = get_remote_cache()
  824. if remote_cache is None:
  825. log.info("%s: remote cache disabled", event_name)
  826. return
  827. content = pickle.dumps(_CODE_STATE)
  828. CompileEventLogger.pt2_compile(name, cache_size_bytes=len(content))
  829. cache_data: JsonDataTy = {
  830. "data": base64.b64encode(content).decode("ascii"),
  831. }
  832. remote_cache.put(cache_key, cache_data)
  833. log.info(
  834. "%s: wrote remote %s, %d entries", event_name, cache_key, len(_CODE_STATE)
  835. )
  836. # TODO: don't log this multiple times
  837. trace_structured_artifact(
  838. event_name,
  839. "string",
  840. lambda: render_code_state(_CODE_STATE),
  841. )
  842. # NB: this does NOT reset the cached code state on disk
  843. def reset_code_state() -> None:
  844. global _CODE_STATE, _INIT_CODE_STATE, _LOGGED_DYNAMIC_ALLOWLIST
  845. _CODE_STATE = None
  846. _INIT_CODE_STATE = None
  847. _LOGGED_DYNAMIC_ALLOWLIST = False