summary.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. import json
  2. import os
  3. import time
  4. from wandb_gql import gql
  5. import wandb
  6. from wandb import util
  7. from wandb.apis.internal import Api
  8. from wandb.sdk.data_types.utils import val_to_json
  9. from wandb.sdk.lib import filenames
  # Name of the h5 side-car file, created in the run directory, that holds
  # summary values too large to keep inline in the JSON summary.
  DEEP_SUMMARY_FNAME = "wandb.h5"
  # "_type" markers of JSON summary entries whose actual data lives in the h5 file.
  H5_TYPES = (
      "numpy.ndarray",
      "tensorflow.Tensor",
      "torch.Tensor",
  )
  # Optional dependencies resolved through util.get_module; the code below
  # treats a falsy value as "not installed".
  h5py = util.get_module("h5py")
  np = util.get_module("numpy")
  class SummarySubDict:
      """Nested dict-like object that proxies read and write operations through a root object.

      Decoded values are cached in ``self._dict``; the serialized (JSON) form
      lives in the root's ``_json_dict`` tree. Routing every mutation through
      the root keeps that tree in sync and lets large values be loaded lazily
      on first access.
      """

      def __init__(self, root=None, path=()):
          self._path = tuple(path)
          if root is None:
              # This object is the root and owns the JSON tree.
              self._root = self
              self._json_dict = {}
          else:
              self._root = root
              # Snapshot of the JSON node at ``path``, kept for backward
              # compatibility. The methods below call ``_live_json_dict()``
              # instead, because the root may replace this node later (for
              # example when ``update()`` re-encodes a nested dict).
              self._json_dict = self._lookup_json_node(root._json_dict, self._path)
          self._dict = {}
          # We use this to track which keys the user has set explicitly
          # so that we don't automatically overwrite them when we update
          # the summary from the history.
          self._locked_keys = set()

      @staticmethod
      def _lookup_json_node(json_dict, path):
          """Walk ``path`` down ``json_dict``; return ``{}`` if any step is missing."""
          node = json_dict
          for key in path:
              if not isinstance(node, dict):
                  return {}
              node = node.get(key, {})
          return node if isinstance(node, dict) else {}

      def _live_json_dict(self):
          """Return the JSON node that currently backs this object.

          The root returns its own ``_json_dict``. A child re-resolves its node
          from the root on every call, so it never reads a detached copy.
          """
          if self._root is self:
              return self._json_dict
          return self._lookup_json_node(self._root._json_dict, self._path)

      def __setattr__(self, k, v):
          k = k.strip()
          if k.startswith("_"):
              object.__setattr__(self, k, v)
          else:
              self[k] = v

      def __getattr__(self, k):
          k = k.strip()
          if k.startswith("_"):
              return object.__getattribute__(self, k)
          else:
              return self[k]

      def _root_get(self, path, child_dict):
          """Load a value at a particular path from the root.

          This should only be implemented by the "_root" child class.
          We pass the child_dict so the item can be set on it or not as
          appropriate. Returning None for a nonexistent path wouldn't be
          distinguishable from that path being set to the value None.
          """
          raise NotImplementedError

      def _root_set(self, path, new_keys_values):
          """Set a value at a particular path in the root.

          This should only be implemented by the "_root" child class.
          """
          raise NotImplementedError

      def _root_del(self, path):
          """Delete a value at a particular path in the root.

          This should only be implemented by the "_root" child class.
          """
          raise NotImplementedError

      def _write(self, commit=False):
          # should only be implemented on the root summary
          raise NotImplementedError

      def keys(self):
          # The JSON node has the full set of keys, including those for h5
          # objects that may not have been loaded into self._dict yet.
          return self._live_json_dict().keys()

      def get(self, k, default=None):
          if isinstance(k, str):
              k = k.strip()
          if k not in self._dict:
              # Ask the root to decode the value into our cache, if it exists.
              self._root._root_get(self._path + (k,), self._dict)
          return self._dict.get(k, default)

      def items(self):
          # Not all items may be loaded into self._dict, so we have to build
          # the sequence of items from the full key set.
          for k in self.keys():
              yield k, self[k]

      def __getitem__(self, k):
          if isinstance(k, str):
              k = k.strip()
          self.get(k)  # load the value into _dict if it should be there
          return self._dict[k]

      def __contains__(self, k):
          if isinstance(k, str):
              k = k.strip()
          return k in self._live_json_dict()

      def __setitem__(self, k, v):
          if isinstance(k, str):
              k = k.strip()
          path = self._path
          if isinstance(v, dict):
              # Create the empty JSON node first so the child proxy resolves
              # to it, then fill it in through the child's update().
              self._root._root_set(path, [(k, {})])
              self._dict[k] = SummarySubDict(self._root, path + (k,))
              self._dict[k].update(v)
          else:
              self._dict[k] = v
              self._root._root_set(path, [(k, v)])
          self._locked_keys.add(k)
          self._root._write()
          return v

      def __delitem__(self, k):
          if isinstance(k, str):
              k = k.strip()
          # The key may exist only in the JSON tree (not lazily loaded yet),
          # so don't require it to be present in the self._dict cache.
          if k not in self._dict and k not in self._live_json_dict():
              raise KeyError(k)
          self._dict.pop(k, None)
          self._root._root_del(self._path + (k,))
          self._root._write()

      def __repr__(self):
          # Use a copy of _dict, but add placeholders for h5 objects etc.
          # that haven't been loaded yet.
          repr_dict = dict(self._dict)
          for k, v in self._live_json_dict().items():
              is_unloaded_h5 = (
                  k not in repr_dict
                  and isinstance(v, dict)
                  and v.get("_type") in H5_TYPES
              )
              if is_unloaded_h5:
                  # Unloaded h5 objects may be very large; don't load them
                  # just to print them.
                  repr_dict[k] = "..."
              else:
                  repr_dict[k] = self[k]
          return repr(repr_dict)

      def update(self, key_vals=None, overwrite=True):
          """Set several keys at once.

          Locked keys will be overwritten unless overwrite=False.
          Otherwise, written keys will be added to the "locked" list.
          """
          if key_vals:
              write_items = self._update(key_vals, overwrite)
              self._root._root_set(self._path, write_items)
              self._root._write(commit=True)

      def _update(self, key_vals, overwrite):
          """Update the local cache and return the (key, value) pairs to persist."""
          if not key_vals:
              return
          key_vals = {k.strip(): v for k, v in key_vals.items()}
          if overwrite:
              write_items = list(key_vals.items())
              self._locked_keys.update(key_vals.keys())
          else:
              write_keys = set(key_vals.keys()) - self._locked_keys
              write_items = [(k, key_vals[k]) for k in write_keys]
          for key, value in write_items:
              if isinstance(value, dict):
                  self._dict[key] = SummarySubDict(self._root, self._path + (key,))
                  self._dict[key]._update(value, overwrite)
              else:
                  self._dict[key] = value
          return write_items
  class Summary(SummarySubDict):
      """Store summary metrics (eg. accuracy) during and after a run.

      You can manipulate this as if it's a Python dictionary but the keys
      get mangled. .strip() is called on them, so spaces at the beginning
      and end are removed.

      Values that ``util.maybe_compress_summary`` reports as compressed are
      written to an h5 file (``DEEP_SUMMARY_FNAME`` in the run directory).
      They are read back only when accessed.
      """

      def __init__(self, run, summary=None):
          super().__init__()
          self._run = run
          self._h5_path = os.path.join(self._run.dir, DEEP_SUMMARY_FNAME)
          # The h5 file is opened lazily by open_h5().
          self._h5 = None
          # JSON-serializable mirror of self._dict, kept up to date by
          # self._root_set() and self._root_del().
          self._json_dict = {} if summary is None else summary

      def _json_get(self, path):
          # Unimplemented placeholder kept for backward compatibility.
          pass

      @staticmethod
      def _h5_key(path):
          """Return the h5 dataset name used for the summary value at ``path``."""
          return "summary/" + ".".join(path)

      def _root_get(self, path, child_dict):
          """Decode the stored value at ``path`` into ``child_dict``, if present.

          ``child_dict`` is left untouched when the value or any of its
          ancestors is missing. Callers can therefore tell "absent" apart from
          a stored ``None``.
          """
          node = self._json_dict
          for key in path[:-1]:
              if not isinstance(node, dict) or key not in node:
                  # e.g. a child proxy whose parent key has since been deleted
                  return
              node = node[key]
          key = path[-1]
          if isinstance(node, dict) and key in node:
              child_dict[key] = self._decode(path, node[key])

      def _root_del(self, path):
          """Delete the stored value at ``path``, including its h5 data if any.

          Raises:
              KeyError: if ``path`` does not exist in the JSON tree.
          """
          node = self._json_dict
          for key in path[:-1]:
              node = node[key]
          val = node.pop(path[-1])
          if isinstance(val, dict) and val.get("_type") in H5_TYPES:
              if not h5py:
                  wandb.termerror("Deleting tensors in summary requires h5py")
              else:
                  self.open_h5()
                  h5_key = self._h5_key(path)
                  # Tolerate a missing dataset: the JSON entry is already
                  # gone, so raising here would leave the delete half-done.
                  if h5_key in self._h5:
                      del self._h5[h5_key]
                  self._h5.flush()

      def _root_set(self, path, new_keys_values):
          """Encode and store each (key, value) pair under the JSON node at ``path``."""
          json_dict = self._json_dict
          for key in path:
              json_dict = json_dict[key]
          for new_key, new_value in new_keys_values:
              json_dict[new_key] = self._encode(new_value, path + (new_key,))

      def write_h5(self, path, val):
          """Store ``val`` in the h5 file under the dataset for ``path``."""
          self.open_h5()  # ensure the file is open
          if not self._h5:
              wandb.termerror("Storing tensors in summary requires h5py")
              return
          h5_key = self._h5_key(path)
          # Drop any previous dataset first; presumably h5py will not
          # overwrite an existing name on assignment.
          try:
              del self._h5[h5_key]
          except KeyError:
              pass
          self._h5[h5_key] = val
          self._h5.flush()

      def read_h5(self, path, val=None):
          """Return the h5 entry for ``path``, or ``val`` if it is not stored.

          Returns ``None`` after printing an error when h5py is unavailable.
          """
          self.open_h5()  # ensure the file is open
          if not self._h5:
              wandb.termerror("Reading tensors from summary requires h5py")
              return None
          return self._h5.get(self._h5_key(path), val)

      def open_h5(self):
          """Open the h5 file in append mode if h5py is available and it isn't open."""
          if not self._h5 and h5py:
              self._h5 = h5py.File(self._h5_path, "a", libver="latest")

      def _decode(self, path, json_value):
          """Decode a `dict` encoded by `Summary._encode()`, loading h5 objects.

          h5 objects may be very large, so we won't have loaded them automatically.
          Nested plain dicts come back as ``SummarySubDict`` proxies.
          """
          if not isinstance(json_value, dict):
              return json_value
          value_type = json_value.get("_type")
          if value_type in H5_TYPES:
              return self.read_h5(path, json_value)
          if value_type == "data-frame":
              wandb.termerror(
                  "This data frame was saved via the wandb data API. Contact support@wandb.com for help."
              )
              return None
          # TODO: transform wandb objects and plots
          return SummarySubDict(self, path)

      def _encode(self, value, path_from_root):
          """Normalize, compress, and encode sub-objects for backend storage.

          Args:
              value: Object to encode.
              path_from_root: `tuple` of key strings from the top-level summary to
                  the current `value`.

          Returns:
              A new tree of dicts in which large objects are replaced by
              dictionaries whose "_type" entry says which type the original
              data was.
          """
          # Builds a new `dict` tree that discards and/or encodes objects that
          # aren't JSON serializable.
          if isinstance(value, dict):
              return {
                  key: self._encode(child, path_from_root + (key,))
                  for key, child in value.items()
              }
          path = ".".join(path_from_root)
          friendly_value, _converted = util.json_friendly(
              val_to_json(self._run, path, value, namespace="summary")
          )
          json_value, compressed = util.maybe_compress_summary(
              friendly_value, util.get_h5_typename(value)
          )
          if compressed:
              self.write_h5(path_from_root, friendly_value)
          return json_value
  def download_h5(run_id, entity=None, project=None, out_dir=None):
      """Download the run's deep-summary h5 file if the backend reports one.

      Args:
          run_id: Run whose ``DEEP_SUMMARY_FNAME`` file should be fetched.
          entity: Entity name; defaults to ``api.settings("entity")``.
          project: Project name; defaults to ``api.settings("project")``.
          out_dir: Directory passed to ``Api.download_write_file``.

      Returns:
          The path returned by ``Api.download_write_file``, or ``None`` when
          there is no file metadata or it has no ``md5``. Previously that
          case raised ``UnboundLocalError``.
      """
      api = Api()
      meta = api.download_url(
          project or api.settings("project"),
          DEEP_SUMMARY_FNAME,
          entity=entity or api.settings("entity"),
          run=run_id,
      )
      if not meta or "md5" not in meta or meta["md5"] is None:
          return None
      # TODO: make this non-blocking
      wandb.termlog("Downloading summary data...")
      path, _response = api.download_write_file(meta, out_dir=out_dir)
      return path
  def upload_h5(file, run_id, entity=None, project=None):
      """Push the local file ``file`` to run ``run_id`` via ``Api.push``.

      The file is uploaded under its base name.
      """
      api = Api()
      wandb.termlog("Uploading summary data...")
      with open(file, "rb") as handle:
          upload_name = os.path.basename(file)
          api.push(
              {upload_name: handle},
              run=run_id,
              project=project,
              entity=entity,
          )
  class FileSummary(Summary):
      """Summary persisted as a JSON file (``filenames.SUMMARY_FNAME``) in the run dir."""

      def __init__(self, run):
          super().__init__(run)
          self._fname = os.path.join(run.dir, filenames.SUMMARY_FNAME)
          self.load()

      def load(self):
          """(Re)load the JSON summary from disk, falling back to an empty dict.

          An unreadable file, invalid JSON, or a top-level value that is not a
          JSON object each yield ``{}``. The rest of the class indexes
          ``_json_dict`` as a dict, so anything else would break later.
          """
          try:
              with open(self._fname) as f:
                  data = json.load(f)
          except (OSError, ValueError):
              data = {}
          self._json_dict = data if isinstance(data, dict) else {}

      def _write(self, commit=False):
          """Write the JSON summary to disk (fsync'd) and close any open h5 file.

          ``commit`` is accepted for interface compatibility and ignored.
          """
          # TODO: we just ignore commit to ensure backward compatibility
          with open(self._fname, "w") as f:
              f.write(util.json_dumps_safer(self._json_dict))
              f.write("\n")
              f.flush()
              os.fsync(f.fileno())
          if self._h5:
              self._h5.close()
              self._h5 = None
  class HTTPSummary(Summary):
      """Summary that is pushed to the backend with an ``UpsertBucket`` GraphQL mutation."""

      def __init__(self, run, client, summary=None):
          super().__init__(run, summary=summary)
          self._run = run
          self._client = client
          # h5 files modified at or after this time are uploaded on commit.
          self._started = time.time()

      def __delitem__(self, key):
          """Remove ``key`` from the summary; it is sent to the backend on the next commit."""
          if isinstance(key, str):
              key = key.strip()
          if key not in self._json_dict:
              raise KeyError(key)
          del self._json_dict[key]
          # Drop any cached decoded value so get()/[] don't return stale data.
          self._dict.pop(key, None)

      def load(self):
          pass

      def open_h5(self):
          """Download the run's h5 file (if not open yet) before opening it locally."""
          if not self._h5 and h5py:
              download_h5(
                  self._run.id,
                  entity=self._run.entity,
                  project=self._run.project,
                  out_dir=self._run.dir,
              )
          super().open_h5()

      def _write(self, commit=False):
          """Push the summary to the backend; does nothing unless ``commit`` is true.

          Returns ``False`` when ``commit`` is false. On commit it also uploads
          the h5 file if it was modified since this object was created.

          Raises:
              AssertionError: if the mutation response carries no bucket id.
          """
          if not commit:
              return False
          mutation = gql(
              """
          mutation UpsertBucket( $id: String, $summaryMetrics: JSONString) {
              upsertBucket(input: { id: $id, summaryMetrics: $summaryMetrics}) {
                  bucket { id }
              }
          }
          """
          )
          # Close the h5 file; presumably so it is complete on disk before the
          # upload below.
          if self._h5:
              self._h5.close()
              self._h5 = None
          res = self._client.execute(
              mutation,
              variable_values={
                  "id": self._run.storage_id,
                  "summaryMetrics": util.json_dumps_safer(self._json_dict),
              },
          )
          # Explicit check rather than `assert`, so it still runs under `python -O`.
          if not res["upsertBucket"]["bucket"]["id"]:
              raise AssertionError("upsertBucket did not return a bucket id")
          entity, project, run = self._run.path
          if (
              os.path.exists(self._h5_path)
              and os.path.getmtime(self._h5_path) >= self._started
          ):
              upload_h5(self._h5_path, run, entity=entity, project=project)