class_map.py 895 B

1234567891011121314151617181920212223
  1. import os
  2. import pickle
  3. def load_class_map(map_or_filename, root=''):
  4. if isinstance(map_or_filename, dict):
  5. assert dict, 'class_map dict must be non-empty'
  6. return map_or_filename
  7. class_map_path = map_or_filename
  8. if not os.path.exists(class_map_path):
  9. class_map_path = os.path.join(root, class_map_path)
  10. assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
  11. class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
  12. if class_map_ext == '.txt':
  13. with open(class_map_path) as f:
  14. class_to_idx = {v.strip(): k for k, v in enumerate(f)}
  15. elif class_map_ext == '.pkl':
  16. with open(class_map_path, 'rb') as f:
  17. class_to_idx = pickle.load(f)
  18. else:
  19. assert False, f'Unsupported class map file extension ({class_map_ext}).'
  20. return class_to_idx