context.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. """Context Keeper."""
  2. from __future__ import annotations
  3. import logging
  4. import threading
  5. from wandb.proto.wandb_internal_pb2 import Record, Result
  6. logger = logging.getLogger(__name__)
  7. class Context:
  8. _cancel_event: threading.Event
  9. # TODO(debug_context) add debug setting to enable this
  10. # _debug_record: Optional[Record]
  11. def __init__(self) -> None:
  12. self._cancel_event = threading.Event()
  13. # TODO(debug_context) see above
  14. # self._debug_record = None
  15. def cancel(self) -> None:
  16. self._cancel_event.set()
  17. @property
  18. def cancel_event(self) -> threading.Event:
  19. return self._cancel_event
  20. def context_id_from_record(record: Record) -> str:
  21. context_id = record.control.mailbox_slot
  22. return context_id
  23. def context_id_from_result(result: Result) -> str:
  24. context_id = result.control.mailbox_slot
  25. return context_id
  26. class ContextKeeper:
  27. _active_items: dict[str, Context]
  28. def __init__(self) -> None:
  29. self._active_items = {}
  30. def add_from_record(self, record: Record) -> Context | None:
  31. context_id = context_id_from_record(record)
  32. if not context_id:
  33. return None
  34. context_obj = self.add(context_id)
  35. # TODO(debug_context) see above
  36. # context_obj._debug_record = record
  37. return context_obj
  38. def add(self, context_id: str) -> Context:
  39. assert context_id
  40. context_obj = Context()
  41. self._active_items[context_id] = context_obj
  42. return context_obj
  43. def get(self, context_id: str) -> Context | None:
  44. item = self._active_items.get(context_id)
  45. return item
  46. def release(self, context_id: str) -> None:
  47. if not context_id:
  48. return
  49. _ = self._active_items.pop(context_id, None)
  50. def cancel(self, context_id: str) -> bool:
  51. item = self.get(context_id)
  52. if item:
  53. item.cancel()
  54. return True
  55. return False
  56. # TODO(debug_context) see above
  57. # def _debug_print_orphans(self, print_to_stdout: bool) -> None:
  58. # for context_id, context in self._active_items.items():
  59. # record = context._debug_record
  60. # record_type = record.WhichOneof("record_type") if record else "unknown"
  61. # message = (
  62. # f"Context: {context_id} {context.cancel_event.is_set()} {record_type}"
  63. # )
  64. # logger.warning(message)
  65. # if print_to_stdout:
  66. # print(message)