predispatch.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. """
  7. This module contains pre-dispatch wrappers for functorch operations
  8. that enable proper tracing in PT2 non-strict export/compile fx graph.
  9. """
from __future__ import annotations

import threading
from typing import TYPE_CHECKING

import torch
from torch._C._functorch import (
    _add_batch_dim as _add_batch_dim_impl,
    _remove_batch_dim as _remove_batch_dim_impl,
    _vmap_decrement_nesting as _vmap_decrement_nesting_impl,
    _vmap_increment_nesting as _vmap_increment_nesting_impl,
)


if TYPE_CHECKING:
    import threading
  21. def _add_batch_dim(self: torch.Tensor, batch_dim: int, level: int) -> torch.Tensor:
  22. """
  23. Thin wrapper around torch._C._add_batch_dim that is used to proxy in
  24. PT2 export/compile fx graph
  25. """
  26. from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
  27. mode = _maybe_find_pre_dispatch_tf_mode_for_export()
  28. batch_dim = self.ndim + batch_dim if batch_dim < 0 else batch_dim
  29. if mode:
  30. return torch.overrides.handle_torch_function(
  31. _add_batch_dim, (self,), self, batch_dim, level
  32. )
  33. res = _add_batch_dim_impl(self, batch_dim, level)
  34. return res
  35. def _remove_batch_dim(
  36. self: torch.Tensor, level: int, batch_size: int, out_dim: int
  37. ) -> torch.Tensor:
  38. """
  39. Thin wrapper around torch._C._remove_batch_dim that is used to proxy in
  40. PT2 export/compile fx graph
  41. """
  42. from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
  43. mode = _maybe_find_pre_dispatch_tf_mode_for_export()
  44. if mode:
  45. return torch.overrides.handle_torch_function(
  46. _remove_batch_dim, (self,), self, level, batch_size, out_dim
  47. )
  48. res = _remove_batch_dim_impl(self, level, batch_size, out_dim)
  49. return res
  50. def _vmap_increment_nesting(batch_size: int, randomness: str) -> int:
  51. """
  52. Thin wrapper around torch._C._vmap_increment_nesting that is used
  53. to proxy in export/compile graph
  54. """
  55. from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
  56. mode = _maybe_find_pre_dispatch_tf_mode_for_export()
  57. if mode:
  58. return torch.overrides.handle_torch_function(
  59. _vmap_increment_nesting, (batch_size,), batch_size, randomness
  60. )
  61. res = _vmap_increment_nesting_impl(batch_size, randomness)
  62. return res
  63. def _vmap_decrement_nesting() -> int:
  64. """
  65. Thin wrapper around torch._C._vmap_increment_nesting that is used
  66. to proxy in export/compile graph
  67. """
  68. from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
  69. mode = _maybe_find_pre_dispatch_tf_mode_for_export()
  70. if mode:
  71. return torch.overrides.handle_torch_function(
  72. _vmap_decrement_nesting,
  73. (),
  74. )
  75. return _vmap_decrement_nesting_impl()
  76. # Global variables for lazy_load_decompositions
  77. DECOMPOSITIONS_LOADED: bool = False
  78. DECOMPOSITIONS_LOCK: threading.Lock | None = None
  79. VMAP_DECOMPOSITIONS_LIB: torch.library.Library | None = None
  80. def lazy_load_decompositions() -> None:
  81. """
  82. Lazy loading of vmap decompositions with pre-dispatch support.
  83. """
  84. from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
  85. mode = _maybe_find_pre_dispatch_tf_mode_for_export()
  86. if mode:
  87. return torch.overrides.handle_torch_function(lazy_load_decompositions, ())
  88. global DECOMPOSITIONS_LOADED, DECOMPOSITIONS_LOCK, VMAP_DECOMPOSITIONS_LIB
  89. if DECOMPOSITIONS_LOADED:
  90. return
  91. # Initialize lock if needed
  92. if DECOMPOSITIONS_LOCK is None:
  93. import threading
  94. DECOMPOSITIONS_LOCK = threading.Lock()
  95. with DECOMPOSITIONS_LOCK:
  96. if DECOMPOSITIONS_LOADED:
  97. return
  98. import os
  99. if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__):
  100. DECOMPOSITIONS_LOADED = True
  101. return
  102. # use an alternate way to register an operator into the decomposition table
  103. # _register_jit_decomposition doesn't work for some operators, e.g. addr,
  104. # because the Tensor types generated cannot be unioned by torchscript
  105. # decomp should be type OpOverload
  106. VMAP_DECOMPOSITIONS_LIB = torch.library.Library(
  107. "aten", "IMPL", "FuncTorchBatched"
  108. )
  109. from torch._decomp import decomposition_table
  110. def _register_python_decomposition_vmap(decomp: torch._ops.OpOverload) -> None:
  111. if VMAP_DECOMPOSITIONS_LIB is None:
  112. raise AssertionError("VMAP_DECOMPOSITIONS_LIB must not be None")
  113. if decomp in decomposition_table:
  114. VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
  115. else:
  116. raise RuntimeError(f"could not find decomposition for {decomp}")
  117. _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
  118. _register_python_decomposition_vmap(
  119. torch.ops.aten.smooth_l1_loss_backward.default
  120. )
  121. _register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default)
  122. _register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default)
  123. _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default)
  124. _register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default)
  125. _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default)
  126. _register_python_decomposition_vmap(torch.ops.aten.addr.default)
  127. DECOMPOSITIONS_LOADED = True