# reader_image_folder.py (3.4 KB)
""" A dataset reader that extracts images from folders.

Folders are scanned recursively to find image files. Labels are based
on the folder hierarchy, just leaf folders by default.

Hacked together by / Copyright 2020 Ross Wightman
"""
  6. import os
  7. from typing import Dict, List, Optional, Set, Tuple, Union
  8. from timm.utils.misc import natural_key
  9. from .class_map import load_class_map
  10. from .img_extensions import get_img_extensions
  11. from .reader import Reader
  12. def find_images_and_targets(
  13. folder: str,
  14. types: Optional[Union[List, Tuple, Set]] = None,
  15. class_to_idx: Optional[Dict] = None,
  16. leaf_name_only: bool = True,
  17. sort: bool = True
  18. ):
  19. """ Walk folder recursively to discover images and map them to classes by folder names.
  20. Args:
  21. folder: root of folder to recursively search
  22. types: types (file extensions) to search for in path
  23. class_to_idx: specify mapping for class (folder name) to class index if set
  24. leaf_name_only: use only leaf-name of folder walk for class names
  25. sort: re-sort found images by name (for consistent ordering)
  26. Returns:
  27. A list of image and target tuples, class_to_idx mapping
  28. """
  29. types = get_img_extensions(as_set=True) if not types else set(types)
  30. labels = []
  31. filenames = []
  32. for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
  33. rel_path = os.path.relpath(root, folder) if (root != folder) else ''
  34. label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
  35. for f in files:
  36. base, ext = os.path.splitext(f)
  37. if ext.lower() in types:
  38. filenames.append(os.path.join(root, f))
  39. labels.append(label)
  40. if class_to_idx is None:
  41. # building class index
  42. unique_labels = set(labels)
  43. sorted_labels = list(sorted(unique_labels, key=natural_key))
  44. class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
  45. images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
  46. if sort:
  47. images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
  48. return images_and_targets, class_to_idx
  49. class ReaderImageFolder(Reader):
  50. def __init__(
  51. self,
  52. root,
  53. class_map='',
  54. input_key=None,
  55. ):
  56. super().__init__()
  57. self.root = root
  58. class_to_idx = None
  59. if class_map:
  60. class_to_idx = load_class_map(class_map, root)
  61. find_types = None
  62. if input_key:
  63. find_types = input_key.split(';')
  64. self.samples, self.class_to_idx = find_images_and_targets(
  65. root,
  66. class_to_idx=class_to_idx,
  67. types=find_types,
  68. )
  69. if len(self.samples) == 0:
  70. raise RuntimeError(
  71. f'Found 0 images in subfolders of {root}. '
  72. f'Supported image extensions are {", ".join(get_img_extensions())}')
  73. def __getitem__(self, index):
  74. path, target = self.samples[index]
  75. return open(path, 'rb'), target
  76. def __len__(self):
  77. return len(self.samples)
  78. def _filename(self, index, basename=False, absolute=False):
  79. filename = self.samples[index][0]
  80. if basename:
  81. filename = os.path.basename(filename)
  82. elif not absolute:
  83. filename = os.path.relpath(filename, self.root)
  84. return filename