| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
- """
- This module contains pre-dispatch wrappers for functorch operations
- that enable proper tracing in PT2 non-strict export/compile fx graph.
- """
- from __future__ import annotations
- from typing import TYPE_CHECKING
- import torch
- from torch._C._functorch import (
- _add_batch_dim as _add_batch_dim_impl,
- _remove_batch_dim as _remove_batch_dim_impl,
- _vmap_decrement_nesting as _vmap_decrement_nesting_impl,
- _vmap_increment_nesting as _vmap_increment_nesting_impl,
- )
- if TYPE_CHECKING:
- import threading
- def _add_batch_dim(self: torch.Tensor, batch_dim: int, level: int) -> torch.Tensor:
- """
- Thin wrapper around torch._C._add_batch_dim that is used to proxy in
- PT2 export/compile fx graph
- """
- from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
- mode = _maybe_find_pre_dispatch_tf_mode_for_export()
- batch_dim = self.ndim + batch_dim if batch_dim < 0 else batch_dim
- if mode:
- return torch.overrides.handle_torch_function(
- _add_batch_dim, (self,), self, batch_dim, level
- )
- res = _add_batch_dim_impl(self, batch_dim, level)
- return res
- def _remove_batch_dim(
- self: torch.Tensor, level: int, batch_size: int, out_dim: int
- ) -> torch.Tensor:
- """
- Thin wrapper around torch._C._remove_batch_dim that is used to proxy in
- PT2 export/compile fx graph
- """
- from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
- mode = _maybe_find_pre_dispatch_tf_mode_for_export()
- if mode:
- return torch.overrides.handle_torch_function(
- _remove_batch_dim, (self,), self, level, batch_size, out_dim
- )
- res = _remove_batch_dim_impl(self, level, batch_size, out_dim)
- return res
- def _vmap_increment_nesting(batch_size: int, randomness: str) -> int:
- """
- Thin wrapper around torch._C._vmap_increment_nesting that is used
- to proxy in export/compile graph
- """
- from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
- mode = _maybe_find_pre_dispatch_tf_mode_for_export()
- if mode:
- return torch.overrides.handle_torch_function(
- _vmap_increment_nesting, (batch_size,), batch_size, randomness
- )
- res = _vmap_increment_nesting_impl(batch_size, randomness)
- return res
- def _vmap_decrement_nesting() -> int:
- """
- Thin wrapper around torch._C._vmap_increment_nesting that is used
- to proxy in export/compile graph
- """
- from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
- mode = _maybe_find_pre_dispatch_tf_mode_for_export()
- if mode:
- return torch.overrides.handle_torch_function(
- _vmap_decrement_nesting,
- (),
- )
- return _vmap_decrement_nesting_impl()
- # Global variables for lazy_load_decompositions
- DECOMPOSITIONS_LOADED: bool = False
- DECOMPOSITIONS_LOCK: threading.Lock | None = None
- VMAP_DECOMPOSITIONS_LIB: torch.library.Library | None = None
- def lazy_load_decompositions() -> None:
- """
- Lazy loading of vmap decompositions with pre-dispatch support.
- """
- from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
- mode = _maybe_find_pre_dispatch_tf_mode_for_export()
- if mode:
- return torch.overrides.handle_torch_function(lazy_load_decompositions, ())
- global DECOMPOSITIONS_LOADED, DECOMPOSITIONS_LOCK, VMAP_DECOMPOSITIONS_LIB
- if DECOMPOSITIONS_LOADED:
- return
- # Initialize lock if needed
- if DECOMPOSITIONS_LOCK is None:
- import threading
- DECOMPOSITIONS_LOCK = threading.Lock()
- with DECOMPOSITIONS_LOCK:
- if DECOMPOSITIONS_LOADED:
- return
- import os
- if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__):
- DECOMPOSITIONS_LOADED = True
- return
- # use an alternate way to register an operator into the decomposition table
- # _register_jit_decomposition doesn't work for some operators, e.g. addr,
- # because the Tensor types generated cannot be unioned by torchscript
- # decomp should be type OpOverload
- VMAP_DECOMPOSITIONS_LIB = torch.library.Library(
- "aten", "IMPL", "FuncTorchBatched"
- )
- from torch._decomp import decomposition_table
- def _register_python_decomposition_vmap(decomp: torch._ops.OpOverload) -> None:
- if VMAP_DECOMPOSITIONS_LIB is None:
- raise AssertionError("VMAP_DECOMPOSITIONS_LIB must not be None")
- if decomp in decomposition_table:
- VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
- else:
- raise RuntimeError(f"could not find decomposition for {decomp}")
- _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
- _register_python_decomposition_vmap(
- torch.ops.aten.smooth_l1_loss_backward.default
- )
- _register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default)
- _register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default)
- _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default)
- _register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default)
- _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default)
- _register_python_decomposition_vmap(torch.ops.aten.addr.default)
- DECOMPOSITIONS_LOADED = True
|