api.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. #!/usr/bin/env python3
  2. # mypy: allow-untyped-defs
  3. # Copyright (c) Facebook, Inc. and its affiliates.
  4. # All rights reserved.
  5. #
  6. # This source code is licensed under the BSD-style license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. import json
  9. from dataclasses import asdict, dataclass, field
  10. from enum import Enum
  11. from typing import Union
# Public names exported by a star-import of this module.
__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"]

# Alias for the value types allowed in an ``Event.metadata`` mapping:
# JSON-friendly scalars or ``None``.
EventMetadataValue = Union[str, int, float, bool, None]
  14. class EventSource(str, Enum):
  15. """Known identifiers of the event producers."""
  16. AGENT = "AGENT"
  17. WORKER = "WORKER"
  18. @dataclass
  19. class Event:
  20. """
  21. The class represents the generic event that occurs during the torchelastic job execution.
  22. The event can be any kind of meaningful action.
  23. Args:
  24. name: event name.
  25. source: the event producer, e.g. agent or worker
  26. timestamp: timestamp in milliseconds when event occurred.
  27. metadata: additional data that is associated with the event.
  28. """
  29. name: str
  30. source: EventSource
  31. timestamp: int = 0
  32. metadata: dict[str, EventMetadataValue] = field(default_factory=dict)
  33. def __str__(self):
  34. return self.serialize()
  35. @staticmethod
  36. def deserialize(data: Union[str, "Event"]) -> "Event":
  37. if isinstance(data, Event):
  38. return data
  39. if isinstance(data, str):
  40. data_dict = json.loads(data)
  41. data_dict["source"] = EventSource[data_dict["source"]] # type: ignore[possibly-undefined]
  42. # pyrefly: ignore [unbound-name]
  43. return Event(**data_dict)
  44. def serialize(self) -> str:
  45. return json.dumps(asdict(self))
  46. class NodeState(str, Enum):
  47. """The states that a node can be in rendezvous."""
  48. INIT = "INIT"
  49. RUNNING = "RUNNING"
  50. SUCCEEDED = "SUCCEEDED"
  51. FAILED = "FAILED"
  52. @dataclass
  53. class RdzvEvent:
  54. """
  55. Dataclass to represent any rendezvous event.
  56. Args:
  57. name: Event name. (E.g. Current action being performed)
  58. run_id: The run id of the rendezvous
  59. message: The message describing the event
  60. hostname: Hostname of the node
  61. pid: The process id of the node
  62. node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED)
  63. master_endpoint: The master endpoint for the rendezvous store, if known
  64. rank: The rank of the node, if known
  65. local_id: The local_id of the node, if defined in dynamic_rendezvous.py
  66. error_trace: Error stack trace, if this is an error event.
  67. """
  68. name: str
  69. run_id: str
  70. message: str
  71. hostname: str
  72. pid: int
  73. node_state: NodeState
  74. master_endpoint: str = ""
  75. rank: int | None = None
  76. local_id: int | None = None
  77. error_trace: str = ""
  78. def __str__(self):
  79. return self.serialize()
  80. @staticmethod
  81. def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent":
  82. if isinstance(data, RdzvEvent):
  83. return data
  84. if isinstance(data, str):
  85. data_dict = json.loads(data)
  86. data_dict["node_state"] = NodeState[data_dict["node_state"]] # type: ignore[possibly-undefined]
  87. # pyrefly: ignore [unbound-name]
  88. return RdzvEvent(**data_dict)
  89. def serialize(self) -> str:
  90. return json.dumps(asdict(self))