metadata.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import json
  2. from collections import deque
  3. from numbers import Number
  4. from typing import Optional, Tuple
  5. from ray.train._internal.checkpoint_manager import _CheckpointManager
  6. from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder
  7. class _TrainingRunMetadata:
  8. """Serializable struct for holding runtime trial metadata.
  9. Runtime metadata is data that changes and is updated on runtime. This includes
  10. e.g. the last result, the currently available checkpoints, and the number
  11. of errors encountered for a trial.
  12. """
  13. def __init__(self, n_steps: Tuple[int] = (5, 10)):
  14. # General metadata
  15. self.start_time = None
  16. # Errors
  17. self.num_failures = 0
  18. self.num_failures_after_restore = 0
  19. self.error_filename = None
  20. self.pickled_error_filename = None
  21. # Results and metrics
  22. self.last_result = {}
  23. self.last_result_time = -float("inf")
  24. # stores in memory max/min/avg/last-n-avg/last result for each
  25. # metric by trial
  26. self.metric_analysis = {}
  27. self._n_steps = n_steps
  28. self.metric_n_steps = {}
  29. # Checkpoints
  30. self.checkpoint_manager: Optional[_CheckpointManager] = None
  31. self._cached_json = None
  32. def invalidate_cache(self):
  33. self._cached_json = None
  34. def update_metric(self, metric: str, value: Number, step: Optional[int] = 1):
  35. if metric not in self.metric_analysis:
  36. self.metric_analysis[metric] = {
  37. "max": value,
  38. "min": value,
  39. "avg": value,
  40. "last": value,
  41. }
  42. self.metric_n_steps[metric] = {}
  43. for n in self._n_steps:
  44. key = "last-{:d}-avg".format(n)
  45. self.metric_analysis[metric][key] = value
  46. # Store n as string for correct restore.
  47. self.metric_n_steps[metric][str(n)] = deque([value], maxlen=n)
  48. else:
  49. step = step or 1
  50. self.metric_analysis[metric]["max"] = max(
  51. value, self.metric_analysis[metric]["max"]
  52. )
  53. self.metric_analysis[metric]["min"] = min(
  54. value, self.metric_analysis[metric]["min"]
  55. )
  56. self.metric_analysis[metric]["avg"] = (
  57. 1 / step * (value + (step - 1) * self.metric_analysis[metric]["avg"])
  58. )
  59. self.metric_analysis[metric]["last"] = value
  60. for n in self._n_steps:
  61. key = "last-{:d}-avg".format(n)
  62. self.metric_n_steps[metric][str(n)].append(value)
  63. self.metric_analysis[metric][key] = sum(
  64. self.metric_n_steps[metric][str(n)]
  65. ) / len(self.metric_n_steps[metric][str(n)])
  66. self.invalidate_cache()
  67. def __setattr__(self, key, value):
  68. super().__setattr__(key, value)
  69. if key not in {"_cached_json"}:
  70. self.invalidate_cache()
  71. def get_json_state(self) -> str:
  72. if self._cached_json is None:
  73. data = self.__dict__
  74. data.pop("_cached_json", None)
  75. self._cached_json = json.dumps(data, indent=2, cls=TuneFunctionEncoder)
  76. return self._cached_json
  77. @classmethod
  78. def from_json_state(cls, json_state: str) -> "_TrainingRunMetadata":
  79. state = json.loads(json_state, cls=TuneFunctionDecoder)
  80. run_metadata = cls()
  81. run_metadata.__dict__.update(state)
  82. return run_metadata