dataset.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import os
  2. import random
  3. import numpy as np
  4. import cv2
  5. from tqdm import tqdm
  6. from PIL import Image
  7. from torch.utils import data
  8. from torchvision import transforms
  9. from image_proc import preproc
  10. from config import Config
  11. from utils import path_to_image
  # Lift Pillow's pixel-count ceiling so that very large images can be opened
  # without triggering a DecompressionBombWarning.
  Image.MAX_IMAGE_PIXELS = None

  # Project configuration object, read by the dataset class and the collate
  # function defined in this module.
  config = Config()
  # All class names as one ', '-separated string. The literal is split across
  # several source lines only for readability; adjacent string literals are
  # concatenated, so every line except the last must end with ', '.
  # The spelling of each name is kept exactly as is, because names are looked
  # up verbatim elsewhere in this module.
  _class_labels_TR_sorted = (
      'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, '
      'BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, '
      'CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, '
      'Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, '
      'Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, '
      'Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, '
      'KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, '
      'Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, '
      'OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, '
      'RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, '
      'ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, '
      'Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, '
      'TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, '
      'UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht'
  )

  # The same names as a list; a name's position in this list is the integer id
  # it is mapped to when a name -> id lookup is built with enumerate().
  class_labels_TR_sorted = _class_labels_TR_sorted.split(', ')
  class MyData(data.Dataset):
      """Dataset of images paired with ground-truth masks.

      Images are listed from ``<config.data_root_dir>/<config.task>/<dataset>/im``
      for every ``+``-separated name in ``datasets``. Each image must have a mask
      with the same file stem (any of the accepted extensions) in the sibling
      ``gt`` directory.

      Indexing returns ``(image, label, class_label)`` in training mode and
      ``(image, label, label_path)`` otherwise.
      """

      def __init__(self, datasets, data_size, is_train=True):
          """Collect the image/label paths and optionally preload every sample.

          Args:
              datasets: One dataset name, or several joined by ``+`` to train on
                  the combined sets.
              data_size: Size forwarded to ``path_to_image``. It is None when
                  using dynamic_size or when manually set to None (for inference
                  in the original size).
              is_train: Enables training-only processing (background synthesis,
                  ``preproc`` and class labels).

          Raises:
              ValueError: If at least one image has no matching label file.
          """
          self.is_train = is_train
          self.data_size = data_size
          self.load_all = config.load_all
          self.device = config.device
          # A tuple so that it can be handed to str.endswith() directly.
          valid_extensions = ('.png', '.jpg', '.PNG', '.JPG', '.JPEG')

          if self.is_train and config.auxiliary_classification:
              self.cls_name2id = {_name: _id for _id, _name in enumerate(class_labels_TR_sorted)}
          self.transform_image = transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
          ])
          self.transform_label = transforms.Compose([
              transforms.ToTensor(),
          ])

          dataset_root = os.path.join(config.data_root_dir, config.task)
          # datasets can be a list of different datasets for training on combined sets.
          self.image_paths = []
          for dataset in datasets.split('+'):
              image_root = os.path.join(dataset_root, dataset, 'im')
              self.image_paths += [
                  os.path.join(image_root, name)
                  for name in os.listdir(image_root)
                  if name.endswith(valid_extensions)
              ]

          # The label of an image may use a different extension than the image.
          self.label_paths = []
          for image_path in self.image_paths:
              candidates = self._label_path_candidates(image_path, valid_extensions)
              label_path = next((c for c in candidates if os.path.exists(c)), None)
              if label_path is None:
                  print('Not exists:', candidates[-1])
              else:
                  self.label_paths.append(label_path)

          if len(self.label_paths) != len(self.image_paths):
              image_stems = {os.path.splitext(os.path.basename(p))[0] for p in self.image_paths}
              label_stems = {os.path.splitext(os.path.basename(p))[0] for p in self.label_paths}
              print('Path diff:', image_stems - label_stems)
              raise ValueError(
                  f"There are different numbers of images ({len(self.image_paths)}) "
                  f"and labels ({len(self.label_paths)})"
              )

          if self.load_all:
              # Decode every sample once up front; __getitem__ then only indexes these lists.
              self.images_loaded, self.labels_loaded = [], []
              self.class_labels_loaded = []
              for image_path, label_path in tqdm(zip(self.image_paths, self.label_paths), total=len(self.image_paths)):
                  self.images_loaded.append(path_to_image(image_path, size=self.data_size, color_type='rgb'))
                  self.labels_loaded.append(path_to_image(label_path, size=self.data_size, color_type='gray'))
                  self.class_labels_loaded.append(self._class_label(label_path))

      @staticmethod
      def _label_path_candidates(image_path, extensions):
          """Return the possible label paths of ``image_path``, one per extension.

          The label is expected in the ``gt`` directory next to the image's own
          directory, under the image's file stem. ('im' and 'gt' may need
          modifying for other dataset layouts.)
          """
          image_dir, image_name = os.path.split(image_path)
          stem = os.path.splitext(image_name)[0]
          label_dir = os.path.join(os.path.dirname(image_dir), 'gt')
          return [os.path.join(label_dir, stem + ext) for ext in extensions]

      @staticmethod
      def _class_name_from_label_path(label_path):
          """Return the class name stored as the 4th '#'-separated field of the file name."""
          return os.path.basename(label_path).split('#')[3]

      def _class_label(self, label_path):
          """Return the class id of ``label_path``, or -1 when class labels are not in use."""
          if self.is_train and config.auxiliary_classification:
              return self.cls_name2id[self._class_name_from_label_path(label_path)]
          return -1

      @staticmethod
      def _blend_with_random_background(array_foreground, array_mask):
          """Composite the foreground over a randomly coloured flat background.

          Args:
              array_foreground: float array of shape (H, W, 3).
              array_mask: float array of shape (H, W, 1) with values in [0, 1];
                  1 keeps the foreground pixel, 0 shows the background.

          Returns:
              Float array of shape (H, W, 3).
          """
          array_background = np.zeros_like(array_foreground)
          choice = random.random()
          if choice < 0.4:
              # Black/Gray/White backgrounds
              array_background[:, :, :] = random.randint(0, 255)
          elif choice < 0.8:
              # Background color that is similar to the foreground object. Hard negative samples.
              # Clamp to 1 so that an empty mask cannot cause a division by zero
              # (which would fill the background with NaN).
              foreground_pixel_number = max(int(np.sum(array_mask > 0)), 1)
              # Mean over all pixels, rescaled to a mean over foreground pixels only.
              color_foreground_mean = np.mean(array_foreground * array_mask, axis=(0, 1)) * (np.prod(array_foreground.shape[:2]) / foreground_pixel_number)
              color_up_or_down = random.choice((-1, 1))
              # Up or down for 20% range from 255 or 0, respectively.
              color_foreground_mean += (255 - color_foreground_mean if color_up_or_down == 1 else color_foreground_mean) * (random.random() * 0.2) * color_up_or_down
              array_background[:, :, :] = color_foreground_mean
          else:
              # Any color
              for idx_channel in range(3):
                  array_background[:, :, idx_channel] = random.randint(0, 255)
          return array_foreground * array_mask + array_background * (1 - array_mask)

      @staticmethod
      def _synthesize_background(image, label):
          """Return a new image with the area outside ``label`` replaced by a flat colour.

          NOTE(review): assumes ``image`` and ``label`` are PIL images of the
          same size, as required by ``putalpha`` - confirm against ``path_to_image``.
          """
          # putalpha() modifies the image in place; work on a copy so that an
          # image cached by load_all is never altered.
          rgba = image.copy()
          rgba.putalpha(label)
          array_image = np.array(rgba)
          array_foreground = array_image[:, :, :3].astype(np.float32)
          array_mask = (array_image[:, :, 3:] / 255).astype(np.float32)
          blended = MyData._blend_with_random_background(array_foreground, array_mask)
          return Image.fromarray(blended.astype(np.uint8))

      def __getitem__(self, index):
          """Return the sample at ``index``.

          Training: ``(image, label, class_label)``. The pair is converted to
          tensors only when ``config.dynamic_size`` is None; otherwise it is
          returned as produced by ``preproc`` (presumably so that the batch can
          be resized as a whole by the collate function).

          Inference: ``(image, label, label_path)`` as tensors, with both sides
          of the image rounded down to a multiple of 32.
          """
          # loading image and label
          if self.load_all:
              image = self.images_loaded[index]
              label = self.labels_loaded[index]
              class_label = self.class_labels_loaded[index] if self.is_train and config.auxiliary_classification else -1
          else:
              image = path_to_image(self.image_paths[index], size=self.data_size, color_type='rgb')
              label = path_to_image(self.label_paths[index], size=self.data_size, color_type='gray')
              class_label = self._class_label(self.label_paths[index])

          if self.is_train:
              if config.background_color_synthesis:
                  image = self._synthesize_background(image, label)
              image, label = preproc(image, label, preproc_methods=config.preproc_methods)
              if config.dynamic_size is None:
                  image, label = self.transform_image(image), self.transform_label(label)
              return image, label, class_label

          # At present, we use fixed sizes in inference, instead of consistent dynamic size with training.
          size_div_32 = (int(image.size[0] // 32 * 32), int(image.size[1] // 32 * 32))
          if image.size != size_div_32:
              image = image.resize(size_div_32)
              label = label.resize(size_div_32)
          image, label = self.transform_image(image), self.transform_label(label)
          return image, label, self.label_paths[index]

      def __len__(self):
          """Return the number of images found."""
          return len(self.image_paths)
  def custom_collate_fn(batch):
      """Resize every sample of a batch to one common size and collate the result.

      With ``config.dynamic_size`` set, one size is drawn at random per batch,
      each side rounded down to a multiple of 32; otherwise ``config.size`` is
      used.

      Args:
          batch: Iterable of ``(image, label, class_label)`` samples.

      Returns:
          The output of torch's default collate on the resized samples.
      """
      if not config.dynamic_size:
          data_size = config.size
      else:
          # NOTE(review): sorted() orders the two ranges themselves, which could
          # swap them - confirm this is intended.
          size_ranges = tuple(sorted(config.dynamic_size))
          # Select a value randomly in the range of [size_ranges[0/1][0], size_ranges[0/1][1]].
          first_extent = random.randint(size_ranges[0][0], size_ranges[0][1]) // 32 * 32
          second_extent = random.randint(size_ranges[1][0], size_ranges[1][1]) // 32 * 32
          data_size = (first_extent, second_extent)

      # Resize receives the size with its two entries in reverse order.
      resize_target = data_size[::-1]
      to_image_tensor = transforms.Compose([
          transforms.Resize(resize_target),
          transforms.ToTensor(),
          transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
      ])
      to_label_tensor = transforms.Compose([
          transforms.Resize(resize_target),
          transforms.ToTensor(),
      ])

      resized_samples = [
          (to_image_tensor(image), to_label_tensor(label), class_label)
          for image, label, class_label in batch
      ]
      return data._utils.collate.default_collate(resized_samples)