tensorboardx.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. import logging
  2. from typing import TYPE_CHECKING, Dict
  3. import numpy as np
  4. from ray.air.constants import TRAINING_ITERATION
  5. from ray.tune.logger.logger import _LOGGER_DEPRECATION_WARNING, Logger, LoggerCallback
  6. from ray.tune.result import TIME_TOTAL_S, TIMESTEPS_TOTAL
  7. from ray.tune.utils import flatten_dict
  8. from ray.util.annotations import Deprecated, PublicAPI
  9. from ray.util.debug import log_once
  10. if TYPE_CHECKING:
  11. from ray.tune.experiment.trial import Trial # noqa: F401
  12. logger = logging.getLogger(__name__)
  13. VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
  14. @Deprecated(
  15. message=_LOGGER_DEPRECATION_WARNING.format(
  16. old="TBXLogger", new="ray.tune.tensorboardx.TBXLoggerCallback"
  17. ),
  18. warning=True,
  19. )
  20. @PublicAPI
  21. class TBXLogger(Logger):
  22. """TensorBoardX Logger.
  23. Note that hparams will be written only after a trial has terminated.
  24. This logger automatically flattens nested dicts to show on TensorBoard:
  25. {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
  26. """
  27. VALID_HPARAMS = (str, bool, int, float, list, type(None))
  28. VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64)
  29. def _init(self):
  30. try:
  31. from tensorboardX import SummaryWriter
  32. except ImportError:
  33. if log_once("tbx-install"):
  34. logger.info('pip install "ray[tune]" to see TensorBoard files.')
  35. raise
  36. self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
  37. self.last_result = None
  38. def on_result(self, result: Dict):
  39. step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
  40. tmp = result.copy()
  41. for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]:
  42. if k in tmp:
  43. del tmp[k] # not useful to log these
  44. flat_result = flatten_dict(tmp, delimiter="/")
  45. path = ["ray", "tune"]
  46. valid_result = {}
  47. for attr, value in flat_result.items():
  48. full_attr = "/".join(path + [attr])
  49. if isinstance(value, tuple(VALID_SUMMARY_TYPES)) and not np.isnan(value):
  50. valid_result[full_attr] = value
  51. self._file_writer.add_scalar(full_attr, value, global_step=step)
  52. elif (isinstance(value, list) and len(value) > 0) or (
  53. isinstance(value, np.ndarray) and value.size > 0
  54. ):
  55. valid_result[full_attr] = value
  56. # Must be a single image.
  57. if isinstance(value, np.ndarray) and value.ndim == 3:
  58. self._file_writer.add_image(
  59. full_attr,
  60. value,
  61. global_step=step,
  62. )
  63. continue
  64. # Must be a batch of images.
  65. if isinstance(value, np.ndarray) and value.ndim == 4:
  66. self._file_writer.add_images(
  67. full_attr,
  68. value,
  69. global_step=step,
  70. )
  71. continue
  72. # Must be video
  73. if isinstance(value, np.ndarray) and value.ndim == 5:
  74. self._file_writer.add_video(
  75. full_attr, value, global_step=step, fps=20
  76. )
  77. continue
  78. try:
  79. self._file_writer.add_histogram(full_attr, value, global_step=step)
  80. # In case TensorboardX still doesn't think it's a valid value
  81. # (e.g. `[[]]`), warn and move on.
  82. except (ValueError, TypeError):
  83. if log_once("invalid_tbx_value"):
  84. logger.warning(
  85. "You are trying to log an invalid value ({}={}) "
  86. "via {}!".format(full_attr, value, type(self).__name__)
  87. )
  88. self.last_result = valid_result
  89. self._file_writer.flush()
  90. def flush(self):
  91. if self._file_writer is not None:
  92. self._file_writer.flush()
  93. def close(self):
  94. if self._file_writer is not None:
  95. if self.trial and self.trial.evaluated_params and self.last_result:
  96. flat_result = flatten_dict(self.last_result, delimiter="/")
  97. scrubbed_result = {
  98. k: value
  99. for k, value in flat_result.items()
  100. if isinstance(value, tuple(VALID_SUMMARY_TYPES))
  101. }
  102. self._try_log_hparams(scrubbed_result)
  103. self._file_writer.close()
  104. def _try_log_hparams(self, result):
  105. # TBX currently errors if the hparams value is None.
  106. flat_params = flatten_dict(self.trial.evaluated_params)
  107. scrubbed_params = {
  108. k: v for k, v in flat_params.items() if isinstance(v, self.VALID_HPARAMS)
  109. }
  110. np_params = {
  111. k: v.tolist()
  112. for k, v in flat_params.items()
  113. if isinstance(v, self.VALID_NP_HPARAMS)
  114. }
  115. scrubbed_params.update(np_params)
  116. removed = {
  117. k: v
  118. for k, v in flat_params.items()
  119. if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS)
  120. }
  121. if removed:
  122. logger.info(
  123. "Removed the following hyperparameter values when "
  124. "logging to tensorboard: %s",
  125. str(removed),
  126. )
  127. from tensorboardX.summary import hparams
  128. try:
  129. experiment_tag, session_start_tag, session_end_tag = hparams(
  130. hparam_dict=scrubbed_params, metric_dict=result
  131. )
  132. self._file_writer.file_writer.add_summary(experiment_tag)
  133. self._file_writer.file_writer.add_summary(session_start_tag)
  134. self._file_writer.file_writer.add_summary(session_end_tag)
  135. except Exception:
  136. logger.exception(
  137. "TensorboardX failed to log hparams. "
  138. "This may be due to an unsupported type "
  139. "in the hyperparameter values."
  140. )
  141. @PublicAPI
  142. class TBXLoggerCallback(LoggerCallback):
  143. """TensorBoardX Logger.
  144. Note that hparams will be written only after a trial has terminated.
  145. This logger automatically flattens nested dicts to show on TensorBoard:
  146. {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
  147. """
  148. _SAVED_FILE_TEMPLATES = ["events.out.tfevents.*"]
  149. VALID_HPARAMS = (str, bool, int, float, list, type(None))
  150. VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64)
  151. def __init__(self):
  152. try:
  153. from tensorboardX import SummaryWriter
  154. self._summary_writer_cls = SummaryWriter
  155. except ImportError:
  156. if log_once("tbx-install"):
  157. logger.info('pip install "ray[tune]" to see TensorBoard files.')
  158. raise
  159. self._trial_writer: Dict["Trial", SummaryWriter] = {}
  160. self._trial_result: Dict["Trial", Dict] = {}
  161. def log_trial_start(self, trial: "Trial"):
  162. if trial in self._trial_writer:
  163. self._trial_writer[trial].close()
  164. trial.init_local_path()
  165. self._trial_writer[trial] = self._summary_writer_cls(
  166. trial.local_path, flush_secs=30
  167. )
  168. self._trial_result[trial] = {}
  169. def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
  170. if trial not in self._trial_writer:
  171. self.log_trial_start(trial)
  172. step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
  173. tmp = result.copy()
  174. for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]:
  175. if k in tmp:
  176. del tmp[k] # not useful to log these
  177. flat_result = flatten_dict(tmp, delimiter="/")
  178. path = ["ray", "tune"]
  179. valid_result = {}
  180. for attr, value in flat_result.items():
  181. full_attr = "/".join(path + [attr])
  182. if isinstance(value, tuple(VALID_SUMMARY_TYPES)) and not np.isnan(value):
  183. valid_result[full_attr] = value
  184. self._trial_writer[trial].add_scalar(full_attr, value, global_step=step)
  185. elif (isinstance(value, list) and len(value) > 0) or (
  186. isinstance(value, np.ndarray) and value.size > 0
  187. ):
  188. valid_result[full_attr] = value
  189. # Must be a single image.
  190. if isinstance(value, np.ndarray) and value.ndim == 3:
  191. self._trial_writer[trial].add_image(
  192. full_attr,
  193. value,
  194. global_step=step,
  195. )
  196. continue
  197. # Must be a batch of images.
  198. if isinstance(value, np.ndarray) and value.ndim == 4:
  199. self._trial_writer[trial].add_images(
  200. full_attr,
  201. value,
  202. global_step=step,
  203. )
  204. continue
  205. # Must be video
  206. if isinstance(value, np.ndarray) and value.ndim == 5:
  207. self._trial_writer[trial].add_video(
  208. full_attr, value, global_step=step, fps=20
  209. )
  210. continue
  211. try:
  212. self._trial_writer[trial].add_histogram(
  213. full_attr, value, global_step=step
  214. )
  215. # In case TensorboardX still doesn't think it's a valid value
  216. # (e.g. `[[]]`), warn and move on.
  217. except (ValueError, TypeError):
  218. if log_once("invalid_tbx_value"):
  219. logger.warning(
  220. "You are trying to log an invalid value ({}={}) "
  221. "via {}!".format(full_attr, value, type(self).__name__)
  222. )
  223. self._trial_result[trial] = valid_result
  224. self._trial_writer[trial].flush()
  225. def log_trial_end(self, trial: "Trial", failed: bool = False):
  226. if trial in self._trial_writer:
  227. if trial and trial.evaluated_params and self._trial_result[trial]:
  228. flat_result = flatten_dict(self._trial_result[trial], delimiter="/")
  229. scrubbed_result = {
  230. k: value
  231. for k, value in flat_result.items()
  232. if isinstance(value, tuple(VALID_SUMMARY_TYPES))
  233. }
  234. self._try_log_hparams(trial, scrubbed_result)
  235. self._trial_writer[trial].close()
  236. del self._trial_writer[trial]
  237. del self._trial_result[trial]
  238. def _try_log_hparams(self, trial: "Trial", result: Dict):
  239. # TBX currently errors if the hparams value is None.
  240. flat_params = flatten_dict(trial.evaluated_params)
  241. scrubbed_params = {
  242. k: v for k, v in flat_params.items() if isinstance(v, self.VALID_HPARAMS)
  243. }
  244. np_params = {
  245. k: v.tolist()
  246. for k, v in flat_params.items()
  247. if isinstance(v, self.VALID_NP_HPARAMS)
  248. }
  249. scrubbed_params.update(np_params)
  250. removed = {
  251. k: v
  252. for k, v in flat_params.items()
  253. if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS)
  254. }
  255. if removed:
  256. logger.info(
  257. "Removed the following hyperparameter values when "
  258. "logging to tensorboard: %s",
  259. str(removed),
  260. )
  261. from tensorboardX.summary import hparams
  262. try:
  263. experiment_tag, session_start_tag, session_end_tag = hparams(
  264. hparam_dict=scrubbed_params, metric_dict=result
  265. )
  266. self._trial_writer[trial].file_writer.add_summary(experiment_tag)
  267. self._trial_writer[trial].file_writer.add_summary(session_start_tag)
  268. self._trial_writer[trial].file_writer.add_summary(session_end_tag)
  269. except Exception:
  270. logger.exception(
  271. "TensorboardX failed to log hparams. "
  272. "This may be due to an unsupported type "
  273. "in the hyperparameter values."
  274. )