| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- """ Summary utilities
- Hacked together by / Copyright 2020 Ross Wightman
- """
- import csv
- import os
- from collections import OrderedDict
- try:
- import wandb
- except ImportError:
- pass
def get_outdir(path, *paths, inc=False):
    """Create an output directory and return its path.

    The target directory is ``os.path.join(path, *paths)``.

    Args:
        path: First path component.
        *paths: Additional components joined onto ``path``.
        inc: If the target already exists, create and return a sibling
            directory named ``<target>-<n>`` instead, where ``n`` is the
            smallest value in 1..99 whose name is not already taken.

    Returns:
        The target path if it was created (or already existed and ``inc`` is
        false), otherwise the path of the newly created suffixed directory.

    Raises:
        AssertionError: If ``inc`` is true and all 99 suffixed names exist.
    """
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
        # exist_ok tolerates another process creating the directory between
        # the existence check and this call.
        os.makedirs(outdir, exist_ok=True)
        return outdir
    if not inc:
        return outdir

    for count in range(1, 100):
        candidate = outdir + '-' + str(count)
        try:
            # Claim the name by creating it; a bare existence check could be
            # raced by a concurrent run picking the same suffix.
            os.makedirs(candidate)
        except FileExistsError:
            continue
        return candidate

    # Raised explicitly (not via an `assert` statement) so the limit still
    # holds under `python -O`; the exception type is kept for compatibility.
    raise AssertionError(
        'no free output directory found for ' + outdir + ' (tried suffixes -1 to -99)')
def update_summary(
        epoch,
        train_metrics,
        eval_metrics,
        filename,
        lr=None,
        write_header=False,
        log_wandb=False,
):
    """Append one row of per-epoch metrics to a CSV summary file.

    The row starts with ``epoch``, followed by every entry of
    ``train_metrics`` with its key prefixed by ``train_``, then every entry
    of ``eval_metrics`` prefixed by ``eval_``, then ``lr`` if given.

    Args:
        epoch: Value written to the ``epoch`` column.
        train_metrics: Mapping of metric name to value.
        eval_metrics: Mapping of metric name to value; skipped when falsy.
        filename: Path of the CSV file, opened in append mode.
        lr: Optional value for the ``lr`` column; omitted when ``None``.
        write_header: Write the header row before the data row. The caller
            decides this because ``epoch == 1`` can't be used to detect the
            first write.
        log_wandb: Also pass the row to ``wandb.log``. This needs the
            module-level ``wandb`` import to have succeeded.
    """
    rowd = OrderedDict(epoch=epoch)
    rowd.update(('train_' + k, v) for k, v in train_metrics.items())
    if eval_metrics:
        rowd.update(('eval_' + k, v) for k, v in eval_metrics.items())
    if lr is not None:
        rowd['lr'] = lr
    if log_wandb:
        wandb.log(rowd)
    # newline='' is required for files handed to csv writers; without it,
    # platforms that translate '\n' emit '\r\r\n' and rows appear double-spaced.
    with open(filename, mode='a', newline='') as cf:
        dw = csv.DictWriter(cf, fieldnames=rowd.keys())
        if write_header:
            dw.writeheader()
        dw.writerow(rowd)
|