real_labels.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. """ Real labels evaluator for ImageNet
  2. Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159
  3. Based on Numpy example at https://github.com/google-research/reassessed-imagenet
  4. Hacked together by / Copyright 2020 Ross Wightman
  5. """
  6. import os
  7. import json
  8. import numpy as np
  9. import pkgutil
  10. class RealLabelsImagenet:
  11. def __init__(self, filenames, real_json=None, topk=(1, 5)):
  12. if real_json is not None:
  13. with open(real_json) as real_labels:
  14. real_labels = json.load(real_labels)
  15. else:
  16. real_labels = json.loads(
  17. pkgutil.get_data(__name__, os.path.join('_info', 'imagenet_real_labels.json')).decode('utf-8'))
  18. real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
  19. self.real_labels = real_labels
  20. self.filenames = filenames
  21. assert len(self.filenames) == len(self.real_labels)
  22. self.topk = topk
  23. self.is_correct = {k: [] for k in topk}
  24. self.sample_idx = 0
  25. def add_result(self, output):
  26. maxk = max(self.topk)
  27. _, pred_batch = output.topk(maxk, 1, True, True)
  28. pred_batch = pred_batch.cpu().numpy()
  29. for pred in pred_batch:
  30. filename = self.filenames[self.sample_idx]
  31. filename = os.path.basename(filename)
  32. if self.real_labels[filename]:
  33. for k in self.topk:
  34. self.is_correct[k].append(
  35. any([p in self.real_labels[filename] for p in pred[:k]]))
  36. self.sample_idx += 1
  37. def get_accuracy(self, k=None):
  38. if k is None:
  39. return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk}
  40. else:
  41. return float(np.mean(self.is_correct[k])) * 100