json.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import json
  2. import logging
  3. from pathlib import Path
  4. from typing import TYPE_CHECKING, Dict, TextIO
  5. import numpy as np
  6. import ray.cloudpickle as cloudpickle
  7. from ray.air.constants import EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE, EXPR_RESULT_FILE
  8. from ray.tune.logger.logger import _LOGGER_DEPRECATION_WARNING, Logger, LoggerCallback
  9. from ray.tune.utils.util import SafeFallbackEncoder
  10. from ray.util.annotations import Deprecated, PublicAPI
  11. if TYPE_CHECKING:
  12. from ray.tune.experiment.trial import Trial # noqa: F401
  13. logger = logging.getLogger(__name__)
  14. tf = None
  15. VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
  16. @Deprecated(
  17. message=_LOGGER_DEPRECATION_WARNING.format(
  18. old="JsonLogger", new="ray.tune.json.JsonLoggerCallback"
  19. ),
  20. warning=True,
  21. )
  22. @PublicAPI
  23. class JsonLogger(Logger):
  24. """Logs trial results in json format.
  25. Also writes to a results file and param.json file when results or
  26. configurations are updated. Experiments must be executed with the
  27. JsonLogger to be compatible with the ExperimentAnalysis tool.
  28. """
  29. def _init(self):
  30. self.update_config(self.config)
  31. local_file = Path(self.logdir, EXPR_RESULT_FILE)
  32. self.local_out = local_file.open("a")
  33. def on_result(self, result: Dict):
  34. json.dump(result, self, cls=SafeFallbackEncoder)
  35. self.write("\n")
  36. self.local_out.flush()
  37. def write(self, b):
  38. self.local_out.write(b)
  39. def flush(self):
  40. if not self.local_out.closed:
  41. self.local_out.flush()
  42. def close(self):
  43. self.local_out.close()
  44. def update_config(self, config: Dict):
  45. self.config = config
  46. config_out = Path(self.logdir, EXPR_PARAM_FILE)
  47. with open(config_out, "w") as f:
  48. json.dump(self.config, f, indent=2, sort_keys=True, cls=SafeFallbackEncoder)
  49. config_pkl = Path(self.logdir, EXPR_PARAM_PICKLE_FILE)
  50. with config_pkl.open("wb") as f:
  51. cloudpickle.dump(self.config, f)
  52. @PublicAPI
  53. class JsonLoggerCallback(LoggerCallback):
  54. """Logs trial results in json format.
  55. Also writes to a results file and param.json file when results or
  56. configurations are updated. Experiments must be executed with the
  57. JsonLoggerCallback to be compatible with the ExperimentAnalysis tool.
  58. """
  59. _SAVED_FILE_TEMPLATES = [EXPR_RESULT_FILE, EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE]
  60. def __init__(self):
  61. self._trial_configs: Dict["Trial", Dict] = {}
  62. self._trial_files: Dict["Trial", TextIO] = {}
  63. def log_trial_start(self, trial: "Trial"):
  64. if trial in self._trial_files:
  65. self._trial_files[trial].close()
  66. # Update config
  67. self.update_config(trial, trial.config)
  68. # Make sure logdir exists
  69. trial.init_local_path()
  70. local_file = Path(trial.local_path, EXPR_RESULT_FILE)
  71. # Resume the file from remote storage.
  72. self._restore_from_remote(EXPR_RESULT_FILE, trial)
  73. self._trial_files[trial] = local_file.open("at")
  74. def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
  75. if trial not in self._trial_files:
  76. self.log_trial_start(trial)
  77. json.dump(result, self._trial_files[trial], cls=SafeFallbackEncoder)
  78. self._trial_files[trial].write("\n")
  79. self._trial_files[trial].flush()
  80. def log_trial_end(self, trial: "Trial", failed: bool = False):
  81. if trial not in self._trial_files:
  82. return
  83. self._trial_files[trial].close()
  84. del self._trial_files[trial]
  85. def update_config(self, trial: "Trial", config: Dict):
  86. self._trial_configs[trial] = config
  87. config_out = Path(trial.local_path, EXPR_PARAM_FILE)
  88. with config_out.open("w") as f:
  89. json.dump(
  90. self._trial_configs[trial],
  91. f,
  92. indent=2,
  93. sort_keys=True,
  94. cls=SafeFallbackEncoder,
  95. )
  96. config_pkl = Path(trial.local_path, EXPR_PARAM_PICKLE_FILE)
  97. with config_pkl.open("wb") as f:
  98. cloudpickle.dump(self._trial_configs[trial], f)