effects.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. from enum import Enum
  2. from typing import Optional
  3. import torch
  4. class EffectType(Enum):
  5. ORDERED = "Ordered"
  6. from torch._library.utils import RegistrationHandle
  7. # These classes do not have side effects as they just store quantization
  8. # params, so we dont need to mark them as ordered
  9. skip_classes = (
  10. "__torch__.torch.classes.quantized.Conv2dPackedParamsBase",
  11. "__torch__.torch.classes.quantized.Conv3dPackedParamsBase",
  12. "__torch__.torch.classes.quantized.EmbeddingPackedParamsBase",
  13. "__torch__.torch.classes.quantized.LinearPackedParamsBase",
  14. "__torch__.torch.classes.xnnpack.Conv2dOpContext",
  15. "__torch__.torch.classes.xnnpack.LinearOpContext",
  16. "__torch__.torch.classes.xnnpack.TransposeConv2dOpContext",
  17. )
  18. class EffectHolder:
  19. """A holder where one can register an effect impl to."""
  20. def __init__(self, qualname: str):
  21. self.qualname: str = qualname
  22. self._set_default_effect()
  23. def _set_default_effect(self) -> None:
  24. self._effect: Optional[EffectType] = None
  25. # If the op contains a ScriptObject input, we want to mark it as having effects
  26. namespace, opname = torch._library.utils.parse_namespace(self.qualname)
  27. split = opname.split(".")
  28. if len(split) > 1:
  29. if len(split) != 2:
  30. raise AssertionError(
  31. f"Tried to split {opname} based on '.' but found more than 1 '.'"
  32. )
  33. opname, overload = split
  34. else:
  35. overload = ""
  36. if namespace == "higher_order":
  37. return
  38. opname = f"{namespace}::{opname}"
  39. if torch._C._get_operation_overload(opname, overload) is not None:
  40. # Since we call this when destroying the library, sometimes the
  41. # schema will be gone already at that time.
  42. schema = torch._C._get_schema(opname, overload)
  43. for arg in schema.arguments:
  44. if isinstance(arg.type, torch.ClassType):
  45. type_str = arg.type.str() # pyrefly: ignore[missing-attribute]
  46. if type_str in skip_classes:
  47. continue
  48. self._effect = EffectType.ORDERED
  49. return
  50. @property
  51. def effect(self) -> Optional[EffectType]:
  52. return self._effect
  53. @effect.setter
  54. def effect(self, _):
  55. raise RuntimeError("Unable to directly set kernel.")
  56. def register(self, effect: Optional[EffectType]) -> RegistrationHandle:
  57. """Register an effect
  58. Returns a RegistrationHandle that one can use to de-register this
  59. effect.
  60. """
  61. self._effect = effect
  62. def deregister_effect():
  63. self._set_default_effect()
  64. handle = RegistrationHandle(deregister_effect)
  65. return handle