dataset_info.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from abc import ABC, abstractmethod
  2. from typing import Dict, List, Optional, Union
  3. class DatasetInfo(ABC):
  4. def __init__(self):
  5. pass
  6. @abstractmethod
  7. def num_classes(self):
  8. pass
  9. @abstractmethod
  10. def label_names(self):
  11. pass
  12. @abstractmethod
  13. def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
  14. pass
  15. @abstractmethod
  16. def index_to_label_name(self, index) -> str:
  17. pass
  18. @abstractmethod
  19. def index_to_description(self, index: int, detailed: bool = False) -> str:
  20. pass
  21. @abstractmethod
  22. def label_name_to_description(self, label: str, detailed: bool = False) -> str:
  23. pass
  24. class CustomDatasetInfo(DatasetInfo):
  25. """ DatasetInfo that wraps passed values for custom datasets."""
  26. def __init__(
  27. self,
  28. label_names: Union[List[str], Dict[int, str]],
  29. label_descriptions: Optional[Dict[str, str]] = None
  30. ):
  31. super().__init__()
  32. assert len(label_names) > 0
  33. self._label_names = label_names # label index => label name mapping
  34. self._label_descriptions = label_descriptions # label name => label description mapping
  35. if self._label_descriptions is not None:
  36. # validate descriptions (label names required)
  37. assert isinstance(self._label_descriptions, dict)
  38. for n in self._label_names:
  39. assert n in self._label_descriptions
  40. def num_classes(self):
  41. return len(self._label_names)
  42. def label_names(self):
  43. return self._label_names
  44. def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
  45. return self._label_descriptions
  46. def label_name_to_description(self, label: str, detailed: bool = False) -> str:
  47. if self._label_descriptions:
  48. return self._label_descriptions[label]
  49. return label # return label name itself if a descriptions is not present
  50. def index_to_label_name(self, index) -> str:
  51. assert 0 <= index < len(self._label_names)
  52. return self._label_names[index]
  53. def index_to_description(self, index: int, detailed: bool = False) -> str:
  54. label = self.index_to_label_name(index)
  55. return self.label_name_to_description(label, detailed=detailed)