| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- """ Real labels evaluator for ImageNet
- Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159
- Based on Numpy example at https://github.com/google-research/reassessed-imagenet
- Hacked together by / Copyright 2020 Ross Wightman
- """
- import os
- import json
- import numpy as np
- import pkgutil
class RealLabelsImagenet:
    """Score ImageNet validation predictions against the ReaL (reassessed) label set.

    Each validation image has a list of acceptable labels. Images whose list is
    empty are skipped entirely: they add no entry to any top-k score.

    Args:
        filenames: Indexable, sized sequence of image paths, in the same order as
            the rows later passed to :meth:`add_result`. Only the basename of each
            path is used for the label lookup, so basenames are expected to look
            like ``ILSVRC2012_val_00000001.JPEG``.
        real_json: Optional path to a JSON file containing a list whose i-th entry
            is the label list for validation image number ``i + 1``. When ``None``,
            ``_info/imagenet_real_labels.json`` is read from this module's package
            via :func:`pkgutil.get_data`.
        topk: Tuple of k values for which top-k accuracy is tracked.

    Raises:
        AssertionError: if ``len(filenames)`` differs from the number of label entries.
    """

    def __init__(self, filenames, real_json=None, topk=(1, 5)):
        if real_json is not None:
            # JSON text is UTF-8; don't depend on the platform's locale encoding.
            with open(real_json, encoding='utf-8') as fh:
                label_lists = json.load(fh)
        else:
            # pkgutil.get_data requires '/' as the resource separator on every
            # platform (os.path.join would produce '\\' on Windows).
            label_lists = json.loads(
                pkgutil.get_data(__name__, '_info/imagenet_real_labels.json').decode('utf-8'))
        # Entry i belongs to validation image i + 1 (1-based, zero-padded to 8 digits).
        self.real_labels = {
            f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(label_lists)}
        self.filenames = filenames
        assert len(self.filenames) == len(self.real_labels)
        self.topk = topk
        # One list of per-sample hit/miss booleans for each tracked k.
        self.is_correct = {k: [] for k in topk}
        # Index into self.filenames of the next row add_result will see.
        self.sample_idx = 0

    def add_result(self, output):
        """Record top-k hits for one batch of model outputs.

        Args:
            output: Tensor of class scores; top-k is taken along dim 1, so this is
                presumably ``(batch, num_classes)`` — confirm with callers. Rows
                must continue the order of ``filenames`` from previous calls.
        """
        maxk = max(self.topk)
        # Largest scores first, sorted, so pred[:k] is the top-k prediction set.
        _, pred_batch = output.topk(maxk, 1, True, True)
        pred_batch = pred_batch.cpu().numpy()
        for pred in pred_batch:
            filename = os.path.basename(self.filenames[self.sample_idx])
            labels = self.real_labels[filename]
            # Images with an empty label list are excluded from scoring.
            if labels:
                for k in self.topk:
                    self.is_correct[k].append(any(p in labels for p in pred[:k]))
            self.sample_idx += 1

    def get_accuracy(self, k=None):
        """Return top-k accuracy as a percentage.

        Args:
            k: One of the tracked ``topk`` values. When ``None``, a dict mapping
                every tracked k to its accuracy is returned instead.

        Returns:
            A float in ``[0, 100]`` (or a dict of them). The value is ``nan`` for a
            k that has no scored samples yet.
        """
        if k is None:
            return {kk: self._percent(self.is_correct[kk]) for kk in self.topk}
        return self._percent(self.is_correct[k])

    @staticmethod
    def _percent(flags):
        """Mean of a list of booleans times 100; ``nan`` when empty (without numpy's warning)."""
        if not flags:
            return float('nan')
        return float(np.mean(flags)) * 100
|