""" Eval metrics and related

Hacked together by / Copyright 2020 Ross Wightman
"""
  4. class AverageMeter:
  5. """Computes and stores the average and current value"""
  6. def __init__(self):
  7. self.reset()
  8. def reset(self):
  9. self.val = 0
  10. self.avg = 0
  11. self.sum = 0
  12. self.count = 0
  13. def update(self, val, n=1):
  14. self.val = val
  15. self.sum += val * n
  16. self.count += n
  17. self.avg = self.sum / self.count
  18. def accuracy(output, target, topk=(1,)):
  19. """Computes the accuracy over the k top predictions for the specified values of k"""
  20. maxk = min(max(topk), output.size()[1])
  21. batch_size = target.size(0)
  22. _, pred = output.topk(maxk, 1, True, True)
  23. pred = pred.t()
  24. correct = pred.eq(target.reshape(1, -1).expand_as(pred))
  25. return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]