_meta_registrations.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import random
  2. import torch
  3. from torch._C._distributed_c10d import FakeWork
  4. used_ids: set[int] = set()
  5. def generate_unique_id() -> int:
  6. while True:
  7. new_id = random.randint(1, 10**9)
  8. if new_id not in used_ids:
  9. used_ids.add(new_id)
  10. return new_id
  11. # Function to create and return FakeWork object
  12. def create_fakework(args, return_first_arg=True): # type: ignore[no-untyped-def]
  13. work = FakeWork()
  14. work.seq_id = generate_unique_id()
  15. fakework_script_obj = work.boxed()
  16. return (args[0], fakework_script_obj) if return_first_arg else fakework_script_obj
  17. # Dictionary mapping collective operations to their meta functions
  18. # All 20 ops from torch.csrc.distributed.c10d.Ops.cpp are included
  19. # _DEPRECATED_META_FUNCTIONS = {
  20. # "allreduce_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
  21. # "allgather_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
  22. # "allgather_into_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
  23. # "reduce_scatter_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
  24. # }
  25. _META_FUNCTIONS = {
  26. "broadcast_": lambda *args: create_fakework(args),
  27. "allreduce_": lambda *args: create_fakework(args),
  28. "allgather_": lambda *args: create_fakework(args),
  29. "_allgather_base_": lambda *args: create_fakework(args),
  30. "reduce_scatter_": lambda *args: create_fakework(args),
  31. "_reduce_scatter_base_": lambda *args: create_fakework(args),
  32. "reduce_": lambda *args: create_fakework(args, return_first_arg=False),
  33. "gather_": lambda *args: create_fakework(args, return_first_arg=False),
  34. "scatter_": lambda *args: create_fakework(args),
  35. "alltoall_": lambda *args: create_fakework(args),
  36. "alltoall_base_": lambda *args: create_fakework(args, return_first_arg=False),
  37. "barrier": lambda *args: create_fakework(args, return_first_arg=False),
  38. "monitored_barrier_": lambda *args: None,
  39. "send": lambda *args: create_fakework(args, return_first_arg=False),
  40. "recv_": lambda *args: create_fakework(args, return_first_arg=False),
  41. "recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False),
  42. }
  43. lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901
  44. for op, meta_func in _META_FUNCTIONS.items():
  45. lib_impl.impl(op, meta_func, "Meta")