custom_ops.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
# mypy: allow-untyped-defs
import importlib

import torch


# Handle for the "export" operator namespace. The kernels registered in this
# file (via @torch.library.impl(lib, ...)) attach to this handle.
lib = torch.library.Library("export", "FRAGMENT")  # noqa: TOR901

# Declare the schema of export::access_subclass_inner_tensor: it takes a tensor
# (expected to be a traceable wrapper subclass) plus an attribute name, and
# returns a Tensor.
lib.define(
    "access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor"
)
  8. @torch.library.impl(lib, "access_subclass_inner_tensor", "Autograd")
  9. # When running under torch.inference_mode(), we seem to skip AUtograd key
  10. # so we should desugar this op as soon as we start tracing to post-dispatch.
  11. @torch.library.impl(lib, "access_subclass_inner_tensor", "Python")
  12. def _access_subclass_inner_tensor(
  13. src_subclass_tensor: torch.Tensor, attr: str
  14. ) -> torch.Tensor:
  15. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  16. if not is_traceable_wrapper_subclass(src_subclass_tensor):
  17. raise AssertionError(
  18. f"Expected src_subclass_tensor to be a traceable wrapper subclass, "
  19. f"but got {type(src_subclass_tensor)}"
  20. )
  21. val = getattr(src_subclass_tensor, attr, None)
  22. if val is None or not isinstance(val, torch.Tensor):
  23. raise RuntimeError(
  24. f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}"
  25. )
  26. return val
  27. def _call_custom_autograd_function_in_pre_dispatch(function_cls_name, *args, **kwargs):
  28. """
  29. Import a custom autograd function by string name and call it. This is pretty bad
  30. because:
  31. 1) There is no schema
  32. Ideally we should automatically wrap custom autograd functions with a custom op, but
  33. that is too much work because we need to schematize custom autograd functions. For now,
  34. we just hackily put it in the IR.
  35. """
  36. # Parse module and class name
  37. module_name, class_name = function_cls_name.rsplit(".", 1)
  38. # Import the module and get the class
  39. module = importlib.import_module(module_name)
  40. function_cls = getattr(module, class_name)
  41. if not hasattr(function_cls, "apply"):
  42. raise AssertionError(
  43. f"Expected function class {function_cls_name} to have 'apply' method"
  44. )
  45. return function_cls.apply(*args, **kwargs)