| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- # mypy: allow-untyped-defs
- import copyreg
- import os.path as _osp
- import weakref
- import torch
- from torch.utils import (
- backcompat as backcompat,
- collect_env as collect_env,
- data as data,
- deterministic as deterministic,
- hooks as hooks,
- )
- from torch.utils.backend_registration import (
- generate_methods_for_privateuse1_backend,
- rename_privateuse1_backend,
- )
- from torch.utils.cpp_backtrace import get_cpp_backtrace
- from torch.utils.throughput_benchmark import ThroughputBenchmark
def set_module(obj, mod):
    """
    Overwrite ``obj.__module__`` with ``mod`` so that ``obj`` is displayed
    under that module name when printed.

    Args:
        obj: any object whose ``__module__`` attribute can be assigned.
        mod: the module name to record; must be a ``str``.

    Raises:
        TypeError: if ``mod`` is not a ``str`` (``obj`` is left untouched).
    """
    if isinstance(mod, str):
        setattr(obj, "__module__", mod)
        return
    raise TypeError("The mod argument should be a string")
# ``share/cmake`` under the directory that contains this module's package
# (two ``dirname`` steps up from this file).
# NOTE(review): presumably intended for use as CMAKE_PREFIX_PATH when building
# against this package — confirm against packaging.
cmake_prefix_path = _osp.join(
    _osp.dirname(_osp.dirname(__file__)),
    "share",
    "cmake",
)
- def swap_tensors(t1, t2):
- """
- This function swaps the content of the two Tensor objects.
- At a high level, this will make t1 have the content of t2 while preserving
- its identity.
- This will not work if t1 and t2 have different slots.
- """
- # Ensure there are no weakrefs
- if weakref.getweakrefs(t1):
- raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
- if weakref.getweakrefs(t2):
- raise RuntimeError("Cannot swap t2 because it has weakref associated with it")
- t1_slots = set(copyreg._slotnames(t1.__class__)) # type: ignore[attr-defined]
- t2_slots = set(copyreg._slotnames(t2.__class__)) # type: ignore[attr-defined]
- if t1_slots != t2_slots:
- raise RuntimeError("Cannot swap t1 and t2 if they have different slots")
- def swap_attr(name):
- tmp = getattr(t1, name)
- setattr(t1, name, (getattr(t2, name)))
- setattr(t2, name, tmp)
- def error_pre_hook(grad_outputs):
- raise RuntimeError(
- "Trying to execute AccumulateGrad node that was poisoned by swap_tensors "
- "this can happen when you try to run backward on a tensor that was swapped. "
- "For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)` "
- "you should not change the device or dtype of the module (e.g. `m.cpu()` or `m.half()`) "
- "between running forward and backward. To resolve this, please only change the "
- "device/dtype before running forward (or after both forward and backward)."
- )
- def check_use_count(t, name="t1"):
- use_count = t._use_count()
- error_str = (
- f"Expected use_count of {name} to be 1 or 2 with an AccumulateGrad node but got {use_count} "
- f"make sure you are not holding references to the tensor in other places."
- )
- if use_count > 1:
- if use_count == 2 and t.is_leaf:
- accum_grad_node = torch.autograd.graph.get_gradient_edge(t).node
- # Make sure that the accumulate_grad node was not lazy_init-ed by get_gradient_edge
- if t._use_count() == 2:
- accum_grad_node.register_prehook(error_pre_hook)
- else:
- raise RuntimeError(error_str)
- else:
- raise RuntimeError(error_str)
- check_use_count(t1, "t1")
- check_use_count(t2, "t2")
- # Swap the types
- # Note that this will fail if there are mismatched slots
- swap_attr("__class__")
- # Swap the dynamic attributes
- swap_attr("__dict__")
- # Swap the slots
- for slot in t1_slots:
- if hasattr(t1, slot) and hasattr(t2, slot):
- swap_attr(slot)
- elif hasattr(t1, slot):
- setattr(t2, slot, (getattr(t1, slot)))
- delattr(t1, slot)
- elif hasattr(t2, slot):
- setattr(t1, slot, (getattr(t2, slot)))
- delattr(t2, slot)
- # Swap the at::Tensor they point to
- torch._C._swap_tensor_impl(t1, t2)
|