summary.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
""" Summary utilities: output-directory creation and per-epoch CSV / wandb metric logging.
Hacked together by / Copyright 2020 Ross Wightman
"""
  4. import csv
  5. import os
  6. from collections import OrderedDict
  7. try:
  8. import wandb
  9. except ImportError:
  10. pass
  11. def get_outdir(path, *paths, inc=False):
  12. outdir = os.path.join(path, *paths)
  13. if not os.path.exists(outdir):
  14. os.makedirs(outdir)
  15. elif inc:
  16. count = 1
  17. outdir_inc = outdir + '-' + str(count)
  18. while os.path.exists(outdir_inc):
  19. count = count + 1
  20. outdir_inc = outdir + '-' + str(count)
  21. assert count < 100
  22. outdir = outdir_inc
  23. os.makedirs(outdir)
  24. return outdir
  25. def update_summary(
  26. epoch,
  27. train_metrics,
  28. eval_metrics,
  29. filename,
  30. lr=None,
  31. write_header=False,
  32. log_wandb=False,
  33. ):
  34. rowd = OrderedDict(epoch=epoch)
  35. rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
  36. if eval_metrics:
  37. rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
  38. if lr is not None:
  39. rowd['lr'] = lr
  40. if log_wandb:
  41. wandb.log(rowd)
  42. with open(filename, mode='a') as cf:
  43. dw = csv.DictWriter(cf, fieldnames=rowd.keys())
  44. if write_header: # first iteration (epoch == 1 can't be used)
  45. dw.writeheader()
  46. dw.writerow(rowd)