| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438 |
- import json
- import os
- import time
- from wandb_gql import gql
- import wandb
- from wandb import util
- from wandb.apis.internal import Api
- from wandb.sdk.data_types.utils import val_to_json
- from wandb.sdk.lib import filenames
# File name of the h5 store used for large summary values. It is joined with
# the run directory to build the local h5 path and is also the name requested
# when the file is downloaded.
- DEEP_SUMMARY_FNAME = "wandb.h5"
# "_type" tags that mark a JSON summary entry as a tensor-like value whose
# data is read from / deleted from the h5 file instead of the JSON itself.
- H5_TYPES = ("numpy.ndarray", "tensorflow.Tensor", "torch.Tensor")
# Optional dependencies resolved through util.get_module. The code below tests
# `h5py` for truthiness before using it - presumably get_module returns a
# falsy value when the package is missing (TODO confirm in wandb.util).
- h5py = util.get_module("h5py")
- np = util.get_module("numpy")
class SummarySubDict:
    """Nested dict-like object that proxies reads and writes through a root object.

    This lets us do synchronous serialization and lazy loading of large values.

    Each node keeps two mappings:

    * ``_json_dict``: the JSON-encoded subtree for this node. It is the same
      dict object that lives inside the root's ``_json_dict`` tree, and it
      holds the full key set, including values not loaded yet.
    * ``_dict``: decoded values that have been loaded or assigned so far.

    String keys are normalised with ``.strip()``. Attribute access on names
    that do not start with ``_`` is forwarded to item access.
    """

    def __init__(self, root=None, path=()):
        """Create a node.

        Args:
            root: The root summary object that implements ``_root_get``,
                ``_root_set``, ``_root_del`` and ``_write``. ``None`` makes
                this object its own root.
            path: Sequence of keys leading from the root to this node.
        """
        self._path = tuple(path)
        if root is None:
            self._root = self
            self._json_dict = {}
        else:
            self._root = root
            # Walk the root's JSON tree down to this node. A missing level
            # yields a fresh, detached dict, so callers must create the JSON
            # node in the root *before* constructing the child view.
            json_dict = root._json_dict
            for k in path:
                json_dict = json_dict.get(k, {})
            self._json_dict = json_dict
        self._dict = {}
        # We use this to track which keys the user has set explicitly
        # so that we don't automatically overwrite them when we update
        # the summary from the history.
        self._locked_keys = set()

    def __setattr__(self, k, v):
        """Store private (``_``-prefixed) names as attributes, others as items."""
        k = k.strip()
        if k.startswith("_"):
            object.__setattr__(self, k, v)
        else:
            self[k] = v

    def __getattr__(self, k):
        """Resolve missing attributes: private names normally, others as items."""
        k = k.strip()
        if k.startswith("_"):
            return object.__getattribute__(self, k)
        else:
            return self[k]

    def _root_get(self, path, child_dict):
        """Load a value at a particular path from the root.

        This should only be implemented by the "_root" child class.

        We pass the child_dict so the item can be set on it or not as
        appropriate. Returning None for a nonexistent path wouldn't be
        distinguishable from that path being set to the value None.
        """
        raise NotImplementedError

    def _root_set(self, path, new_keys_values):
        """Set a value at a particular path in the root.

        This should only be implemented by the "_root" child class.
        """
        raise NotImplementedError

    def _root_del(self, path):
        """Delete a value at a particular path in the root.

        This should only be implemented by the "_root" child class.
        """
        raise NotImplementedError

    def _write(self, commit=False):
        """Persist the summary; should only be implemented on the root summary."""
        raise NotImplementedError

    def keys(self):
        """Return all keys of this node, loaded or not."""
        # _json_dict has the full set of keys, including those for h5 objects
        # that may not have been loaded yet
        return self._json_dict.keys()

    def get(self, k, default=None):
        """Return the value for ``k``, loading it through the root if needed."""
        if isinstance(k, str):
            k = k.strip()
        if k not in self._dict:
            self._root._root_get(self._path + (k,), self._dict)
        return self._dict.get(k, default)

    def items(self):
        """Yield ``(key, value)`` pairs, loading each value on demand."""
        # not all items may be loaded into self._dict, so we
        # have to build the sequence of items from scratch
        for k in self.keys():
            yield k, self[k]

    def __getitem__(self, k):
        """Return the value for ``k``; raises ``KeyError`` if it does not exist."""
        if isinstance(k, str):
            k = k.strip()
        self.get(k)  # load the value into _dict if it should be there
        return self._dict[k]

    def __contains__(self, k):
        """Report whether ``k`` exists in this node's JSON subtree."""
        if isinstance(k, str):
            k = k.strip()
        return k in self._json_dict

    def __setitem__(self, k, v):
        """Assign ``v`` to ``k``, lock the key, and write through the root.

        Dict values become nested ``SummarySubDict`` views.
        """
        if isinstance(k, str):
            k = k.strip()
        path = self._path
        if isinstance(v, dict):
            # Create the empty JSON node in the root first. The child binds
            # its `_json_dict` at construction time; building it earlier
            # would bind it to a detached dict (or to the stale subtree being
            # replaced), breaking keys(), `in` and items() on the child.
            self._root._root_set(path, [(k, {})])
            self._dict[k] = SummarySubDict(self._root, path + (k,))
            self._dict[k].update(v)
        else:
            self._dict[k] = v
            self._root._root_set(path, [(k, v)])
        self._locked_keys.add(k)
        self._root._write()
        return v

    def __delitem__(self, k):
        """Delete ``k`` from this node and from the root, then write.

        Raises:
            KeyError: If the key is neither loaded nor present in the JSON
                subtree.
        """
        if isinstance(k, str):
            k = k.strip()
        if k not in self._dict and k not in self._json_dict:
            raise KeyError(k)
        # The key may exist only in the JSON tree (never lazily loaded), so
        # its absence from the decoded cache is not an error.
        self._dict.pop(k, None)
        self._root._root_del(self._path + (k,))
        self._root._write()

    def __repr__(self):
        """Return a dict-style repr with ``"..."`` for unloaded h5 values."""
        # use a copy of _dict, except add placeholders for h5 objects, etc.
        # that haven't been loaded yet
        repr_dict = dict(self._dict)
        for k, v in self._json_dict.items():
            if (
                k not in repr_dict
                and isinstance(v, dict)
                and v.get("_type") in H5_TYPES
            ):
                # unloaded h5 objects may be very large. use a placeholder for them
                # if we haven't already loaded them
                repr_dict[k] = "..."
            else:
                repr_dict[k] = self[k]
        return repr(repr_dict)

    def update(self, key_vals=None, overwrite=True):
        """Merge ``key_vals`` into this node and commit through the root.

        Locked keys will be overwritten unless overwrite=False.
        Otherwise, written keys will be added to the "locked" list.

        The commit happens even when ``key_vals`` is empty or ``None``.
        """
        if key_vals:
            write_items = self._select_items(key_vals, overwrite)
            # Write the JSON tree before building nested child views so each
            # child binds to the node that actually lives in the root.
            self._root._root_set(self._path, write_items)
            self._store_items(write_items, overwrite)
        self._root._write(commit=True)

    def _update(self, key_vals, overwrite):
        """Apply ``key_vals`` to the decoded cache without touching the root.

        Returns:
            The list of ``(key, value)`` pairs that were applied, or ``None``
            when ``key_vals`` is empty.
        """
        if not key_vals:
            return None
        write_items = self._select_items(key_vals, overwrite)
        self._store_items(write_items, overwrite)
        return write_items

    def _select_items(self, key_vals, overwrite):
        """Normalise keys and pick the pairs to write, honouring locked keys."""
        key_vals = {
            (k.strip() if isinstance(k, str) else k): v for k, v in key_vals.items()
        }
        if overwrite:
            self._locked_keys.update(key_vals.keys())
            return list(key_vals.items())
        write_keys = set(key_vals.keys()) - self._locked_keys
        return [(k, key_vals[k]) for k in write_keys]

    def _store_items(self, write_items, overwrite):
        """Store pairs in the decoded cache, nesting dict values as child views."""
        for key, value in write_items:
            if isinstance(value, dict):
                child = SummarySubDict(self._root, self._path + (key,))
                self._dict[key] = child
                child._update(value, overwrite)
            else:
                self._dict[key] = value
class Summary(SummarySubDict):
    """Root summary object storing a run's summary metrics (eg. accuracy).

    It can be used like a Python dictionary, but keys are mangled:
    ``.strip()`` is called on them, so leading and trailing spaces are
    removed.
    """

    def __init__(self, run, summary=None):
        """Bind the summary to ``run`` and optionally seed its JSON state.

        Args:
            run: Run object; its ``dir`` attribute locates the h5 file.
            summary: Optional pre-existing JSON dict to use as the encoded
                state instead of an empty one.
        """
        super().__init__()
        self._run = run
        self._h5_path = os.path.join(self._run.dir, DEEP_SUMMARY_FNAME)
        # The h5 file is opened lazily, on first use.
        self._h5 = None
        # Mirror of self._dict holding the JSON-ready versions of the values;
        # self._root_set() and self._root_del() keep it up to date.
        self._json_dict = {} if summary is None else summary

    def _json_get(self, path):
        """Placeholder; does nothing and returns ``None``."""
        return None

    def _root_get(self, path, child_dict):
        """Decode the value at ``path`` into ``child_dict`` if it exists."""
        node = self._json_dict
        for part in path[:-1]:
            node = node[part]
        leaf = path[-1]
        if leaf in node:
            child_dict[leaf] = self._decode(path, node[leaf])

    def _root_del(self, path):
        """Remove the value at ``path`` from the JSON tree and, for tensors, from h5."""
        node = self._json_dict
        for part in path[:-1]:
            node = node[part]
        removed = node.pop(path[-1])
        is_h5_value = isinstance(removed, dict) and removed.get("_type") in H5_TYPES
        if not is_h5_value:
            return
        if not h5py:
            wandb.termerror("Deleting tensors in summary requires h5py")
            return
        self.open_h5()
        del self._h5["summary/" + ".".join(path)]
        self._h5.flush()

    def _root_set(self, path, new_keys_values):
        """Encode and store ``(key, value)`` pairs under the node at ``path``."""
        node = self._json_dict
        for part in path:
            node = node[part]
        for name, raw in new_keys_values:
            node[name] = self._encode(raw, path + (name,))

    def write_h5(self, path, val):
        """Store ``val`` in the h5 file under ``summary/<dotted path>``."""
        # ensure the file is open
        self.open_h5()
        if not self._h5:
            wandb.termerror("Storing tensors in summary requires h5py")
            return
        # Drop any previous entry first so the assignment below replaces it.
        try:
            del self._h5["summary/" + ".".join(path)]
        except KeyError:
            pass
        self._h5["summary/" + ".".join(path)] = val
        self._h5.flush()

    def read_h5(self, path, val=None):
        """Return the h5 entry for ``path``, or ``val`` when it is absent."""
        # ensure the file is open
        self.open_h5()
        if self._h5:
            return self._h5.get("summary/" + ".".join(path), val)
        wandb.termerror("Reading tensors from summary requires h5py")
        return None

    def open_h5(self):
        """Open the h5 file in append mode unless already open or h5py is missing."""
        if self._h5 or not h5py:
            return
        self._h5 = h5py.File(self._h5_path, "a", libver="latest")

    def _decode(self, path, json_value):
        """Decode a value encoded by `Summary._encode()`, loading h5 objects.

        h5 objects may be very large, so they are not loaded automatically;
        they are read here, on access.
        """
        if not isinstance(json_value, dict):
            return json_value
        if json_value.get("_type") in H5_TYPES:
            return self.read_h5(path, json_value)
        if json_value.get("_type") == "data-frame":
            wandb.termerror(
                "This data frame was saved via the wandb data API. Contact support@wandb.com for help."
            )
            return None
        # TODO: transform wandb objects and plots
        return SummarySubDict(self, path)

    def _encode(self, value, path_from_root):
        """Normalize, compress, and encode sub-objects for backend storage.

        value: Object to encode.
        path_from_root: `tuple` of key strings from the top-level summary to
            the current `value`.

        Returns:
            A new tree of dict's with large objects replaced with dictionaries
            with "_type" entries that say which type the original data was.
        """
        # Builds a new `dict` tree that discards and/or encodes objects that
        # aren't JSON serializable.
        if isinstance(value, dict):
            return {
                key: self._encode(item, path_from_root + (key,))
                for key, item in value.items()
            }
        dotted = ".".join(path_from_root)
        friendly_value, _converted = util.json_friendly(
            val_to_json(self._run, dotted, value, namespace="summary")
        )
        json_value, compressed = util.maybe_compress_summary(
            friendly_value, util.get_h5_typename(value)
        )
        if compressed:
            self.write_h5(path_from_root, friendly_value)
        return json_value
def download_h5(run_id, entity=None, project=None, out_dir=None):
    """Download the run's h5 summary file, if the backend reports one.

    Args:
        run_id: Id of the run whose summary file is requested.
        entity: Entity name; falls back to the API settings when falsy.
        project: Project name; falls back to the API settings when falsy.
        out_dir: Directory handed to the API for the downloaded file.

    Returns:
        The path returned by the API download call, or ``None`` when no file
        metadata with a non-null ``md5`` is available.
    """
    api = Api()
    project_name = project or api.settings("project")
    entity_name = entity or api.settings("entity")
    meta = api.download_url(
        project_name,
        DEEP_SUMMARY_FNAME,
        entity=entity_name,
        run=run_id,
    )
    if not meta or "md5" not in meta or meta["md5"] is None:
        return None
    # TODO: make this non-blocking
    wandb.termlog("Downloading summary data...")
    downloaded_path, _response = api.download_write_file(meta, out_dir=out_dir)
    return downloaded_path
def upload_h5(file, run_id, entity=None, project=None):
    """Push the h5 summary file at ``file`` to the given run.

    Args:
        file: Path of the local file; its base name is used as the upload name.
        run_id: Id of the run that receives the file.
        entity: Optional entity name forwarded to the API.
        project: Optional project name forwarded to the API.
    """
    api = Api()
    wandb.termlog("Uploading summary data...")
    with open(file, "rb") as handle:
        payload = {os.path.basename(file): handle}
        api.push(payload, run=run_id, project=project, entity=entity)
class FileSummary(Summary):
    """Summary persisted as a JSON file inside the run directory."""

    def __init__(self, run):
        """Bind to ``run`` and load any summary file already in ``run.dir``."""
        super().__init__(run)
        self._fname = os.path.join(run.dir, filenames.SUMMARY_FNAME)
        self.load()

    def load(self):
        """Read the JSON summary file; fall back to an empty summary.

        A missing/unreadable file (``OSError``) or invalid JSON
        (``ValueError``) both result in an empty summary.
        """
        try:
            with open(self._fname) as f:
                self._json_dict = json.load(f)
        except (OSError, ValueError):
            self._json_dict = {}

    def _write(self, commit=False):
        """Write the JSON summary to disk and close the h5 file if open.

        Args:
            commit: Accepted for interface compatibility and ignored.
        """
        # TODO: we just ignore commit to ensure backward capability
        # Serialise before opening: "w" truncates the file on open, so a
        # serialisation error afterwards would wipe the existing summary.
        payload = util.json_dumps_safer(self._json_dict)
        with open(self._fname, "w") as f:
            f.write(payload)
            f.write("\n")
            f.flush()
            # Force the data to disk so readers of the file see it promptly.
            os.fsync(f.fileno())
        if self._h5:
            self._h5.close()
            self._h5 = None
class HTTPSummary(Summary):
    """Summary whose JSON state is pushed through a GraphQL client on commit."""

    def __init__(self, run, client, summary=None):
        """Bind to ``run`` and the client used to execute the upsert mutation.

        Args:
            run: Run object; ``id``, ``entity``, ``project``, ``dir``,
                ``storage_id`` and ``path`` are read from it.
            client: Object whose ``execute`` method runs the GraphQL mutation.
            summary: Optional JSON dict used as the initial encoded state.
        """
        super().__init__(run, summary=summary)
        self._run = run
        self._client = client
        # Creation time; _write only uploads the h5 file if it was modified
        # at or after this moment.
        self._started = time.time()

    def __delitem__(self, key):
        """Remove ``key`` locally; nothing is sent until a committing write.

        Raises:
            KeyError: If ``key`` is not present in the JSON state.
        """
        if key not in self._json_dict:
            raise KeyError(key)
        del self._json_dict[key]
        # Also evict the decoded copy, otherwise item access would keep
        # returning the deleted value from the cache.
        self._dict.pop(key, None)

    def load(self):
        """No-op; the JSON state is supplied through the constructor."""
        pass

    def open_h5(self):
        """Download the run's h5 file if needed, then open it."""
        if not self._h5 and h5py:
            download_h5(
                self._run.id,
                entity=self._run.entity,
                project=self._run.project,
                out_dir=self._run.dir,
            )
        super().open_h5()

    def _write(self, commit=False):
        """Push the JSON summary when ``commit`` is true.

        Returns:
            ``False`` when ``commit`` is false (nothing is sent); ``None``
            after a committing write.
        """
        if not commit:
            return False
        # Build the mutation only when it is executed: non-committing writes
        # happen on every item assignment and never use it.
        mutation = gql(
            """
            mutation UpsertBucket( $id: String, $summaryMetrics: JSONString) {
                upsertBucket(input: { id: $id, summaryMetrics: $summaryMetrics}) {
                    bucket { id }
                }
            }
            """
        )
        # Close the h5 file first so the copy on disk is complete before a
        # possible upload below.
        if self._h5:
            self._h5.close()
            self._h5 = None
        res = self._client.execute(
            mutation,
            variable_values={
                "id": self._run.storage_id,
                "summaryMetrics": util.json_dumps_safer(self._json_dict),
            },
        )
        assert res["upsertBucket"]["bucket"]["id"]
        entity, project, run = self._run.path
        if (
            os.path.exists(self._h5_path)
            and os.path.getmtime(self._h5_path) >= self._started
        ):
            upload_h5(self._h5_path, run, entity=entity, project=project)
        return None
|