# wandb_summary.py (4.4 KB)
  1. import abc
  2. import typing as t
  3. from .interface.summary_record import SummaryItem, SummaryRecord
  4. def _get_dict(d):
  5. if isinstance(d, dict):
  6. return d
  7. # assume argparse Namespace
  8. return vars(d)
  9. class SummaryDict(metaclass=abc.ABCMeta):
  10. """dict-like wrapper for the nested dictionaries in a SummarySubDict.
  11. Triggers self._root._callback on property changes.
  12. """
  13. @abc.abstractmethod
  14. def _as_dict(self):
  15. raise NotImplementedError
  16. @abc.abstractmethod
  17. def _update(self, record: SummaryRecord):
  18. raise NotImplementedError
  19. def keys(self):
  20. return [k for k in self._as_dict() if k != "_wandb"]
  21. def get(self, key, default=None):
  22. return self._as_dict().get(key, default)
  23. def __getitem__(self, key):
  24. item = self._as_dict()[key]
  25. if isinstance(item, dict):
  26. # this nested dict needs to be wrapped:
  27. wrapped_item = SummarySubDict()
  28. object.__setattr__(wrapped_item, "_items", item)
  29. object.__setattr__(wrapped_item, "_parent", self)
  30. object.__setattr__(wrapped_item, "_parent_key", key)
  31. return wrapped_item
  32. # this item isn't a nested dict
  33. return item
  34. __getattr__ = __getitem__
  35. def __setitem__(self, key, val):
  36. self.update({key: val})
  37. __setattr__ = __setitem__
  38. def __delattr__(self, key):
  39. record = SummaryRecord()
  40. item = SummaryItem()
  41. item.key = (key,)
  42. record.remove = (item,)
  43. self._update(record)
  44. __delitem__ = __delattr__
  45. def update(self, d: dict):
  46. record = SummaryRecord()
  47. for key, value in d.items():
  48. item = SummaryItem()
  49. item.key = (key,)
  50. item.value = value
  51. record.update.append(item)
  52. self._update(record)
  53. class Summary(SummaryDict):
  54. """Track single values for each metric for each run.
  55. By default, a metric's summary is the last value of its History.
  56. For example, `wandb.log({'accuracy': 0.9})` will add a new step to History and
  57. update Summary to the latest value. In some cases, it's more useful to have
  58. the maximum or minimum of a metric instead of the final value. You can set
  59. history manually `(wandb.summary['accuracy'] = best_acc)`.
  60. In the UI, summary metrics appear in the table to compare across runs.
  61. Summary metrics are also used in visualizations like the scatter plot and
  62. parallel coordinates chart.
  63. After training has completed, you may want to save evaluation metrics to a
  64. run. Summary can handle numpy arrays and PyTorch/TensorFlow tensors. When
  65. you save one of these types to Summary, we persist the entire tensor in a
  66. binary file and store high level metrics in the summary object, such as min,
  67. mean, variance, and 95th percentile.
  68. Examples:
  69. ```python
  70. wandb.init(config=args)
  71. best_accuracy = 0
  72. for epoch in range(1, args.epochs + 1):
  73. test_loss, test_accuracy = test()
  74. if test_accuracy > best_accuracy:
  75. wandb.run.summary["best_accuracy"] = test_accuracy
  76. best_accuracy = test_accuracy
  77. ```
  78. """
  79. _update_callback: t.Callable
  80. _get_current_summary_callback: t.Callable
  81. def __init__(self, get_current_summary_callback: t.Callable):
  82. super().__init__()
  83. object.__setattr__(self, "_update_callback", None)
  84. object.__setattr__(
  85. self, "_get_current_summary_callback", get_current_summary_callback
  86. )
  87. def _set_update_callback(self, update_callback: t.Callable):
  88. object.__setattr__(self, "_update_callback", update_callback)
  89. def _as_dict(self):
  90. return self._get_current_summary_callback()
  91. def _update(self, record: SummaryRecord):
  92. if self._update_callback: # type: ignore
  93. self._update_callback(record)
  94. class SummarySubDict(SummaryDict):
  95. """Non-root node of the summary data structure.
  96. Contains a path to itself from the root.
  97. """
  98. _items: dict
  99. _parent: SummaryDict
  100. _parent_key: str
  101. def __init__(self):
  102. object.__setattr__(self, "_items", dict())
  103. object.__setattr__(self, "_parent", None)
  104. object.__setattr__(self, "_parent_key", None)
  105. def _as_dict(self):
  106. return self._items
  107. def _update(self, record: SummaryRecord):
  108. return self._parent._update(record._add_next_parent(self._parent_key))