csv.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import csv
  2. import logging
  3. from pathlib import Path
  4. from typing import TYPE_CHECKING, Dict, TextIO
  5. from ray.air.constants import EXPR_PROGRESS_FILE
  6. from ray.tune.logger.logger import _LOGGER_DEPRECATION_WARNING, Logger, LoggerCallback
  7. from ray.tune.utils import flatten_dict
  8. from ray.util.annotations import Deprecated, PublicAPI
  9. if TYPE_CHECKING:
  10. from ray.tune.experiment.trial import Trial # noqa: F401
  11. logger = logging.getLogger(__name__)
  12. @Deprecated(
  13. message=_LOGGER_DEPRECATION_WARNING.format(
  14. old="CSVLogger", new="ray.tune.csv.CSVLoggerCallback"
  15. ),
  16. warning=True,
  17. )
  18. @PublicAPI
  19. class CSVLogger(Logger):
  20. """Logs results to progress.csv under the trial directory.
  21. Automatically flattens nested dicts in the result dict before writing
  22. to csv:
  23. {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
  24. """
  25. def _init(self):
  26. self._initialized = False
  27. def _maybe_init(self):
  28. """CSV outputted with Headers as first set of results."""
  29. if not self._initialized:
  30. progress_file = Path(self.logdir, EXPR_PROGRESS_FILE)
  31. self._continuing = (
  32. progress_file.exists() and progress_file.stat().st_size > 0
  33. )
  34. self._file = progress_file.open("a")
  35. self._csv_out = None
  36. self._initialized = True
  37. def on_result(self, result: Dict):
  38. self._maybe_init()
  39. tmp = result.copy()
  40. if "config" in tmp:
  41. del tmp["config"]
  42. result = flatten_dict(tmp, delimiter="/")
  43. if self._csv_out is None:
  44. self._csv_out = csv.DictWriter(self._file, result.keys())
  45. if not self._continuing:
  46. self._csv_out.writeheader()
  47. self._csv_out.writerow(
  48. {k: v for k, v in result.items() if k in self._csv_out.fieldnames}
  49. )
  50. self._file.flush()
  51. def flush(self):
  52. if self._initialized and not self._file.closed:
  53. self._file.flush()
  54. def close(self):
  55. if self._initialized:
  56. self._file.close()
  57. @PublicAPI
  58. class CSVLoggerCallback(LoggerCallback):
  59. """Logs results to progress.csv under the trial directory.
  60. Automatically flattens nested dicts in the result dict before writing
  61. to csv:
  62. {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
  63. """
  64. _SAVED_FILE_TEMPLATES = [EXPR_PROGRESS_FILE]
  65. def __init__(self):
  66. self._trial_continue: Dict["Trial", bool] = {}
  67. self._trial_files: Dict["Trial", TextIO] = {}
  68. self._trial_csv: Dict["Trial", csv.DictWriter] = {}
  69. def _setup_trial(self, trial: "Trial"):
  70. if trial in self._trial_files:
  71. self._trial_files[trial].close()
  72. # Make sure logdir exists
  73. trial.init_local_path()
  74. local_file_path = Path(trial.local_path, EXPR_PROGRESS_FILE)
  75. # Resume the file from remote storage.
  76. self._restore_from_remote(EXPR_PROGRESS_FILE, trial)
  77. self._trial_continue[trial] = (
  78. local_file_path.exists() and local_file_path.stat().st_size > 0
  79. )
  80. self._trial_files[trial] = local_file_path.open("at")
  81. self._trial_csv[trial] = None
  82. def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
  83. if trial not in self._trial_files:
  84. self._setup_trial(trial)
  85. tmp = result.copy()
  86. tmp.pop("config", None)
  87. result = flatten_dict(tmp, delimiter="/")
  88. if not self._trial_csv[trial]:
  89. self._trial_csv[trial] = csv.DictWriter(
  90. self._trial_files[trial], result.keys()
  91. )
  92. if not self._trial_continue[trial]:
  93. self._trial_csv[trial].writeheader()
  94. self._trial_csv[trial].writerow(
  95. {k: v for k, v in result.items() if k in self._trial_csv[trial].fieldnames}
  96. )
  97. self._trial_files[trial].flush()
  98. def log_trial_end(self, trial: "Trial", failed: bool = False):
  99. if trial not in self._trial_files:
  100. return
  101. del self._trial_csv[trial]
  102. self._trial_files[trial].close()
  103. del self._trial_files[trial]