# pascal_semsegm_test_fcn.py (9.3 KB)
from __future__ import print_function

import argparse
import os
import sys
import time
from abc import ABCMeta, abstractmethod

import numpy as np

from imagenet_cls_test_alexnet import CaffeModel, DNNOnnxModel
  8. try:
  9. import cv2 as cv
  10. except ImportError:
  11. raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '
  12. 'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')
  13. def get_metrics(conf_mat):
  14. pix_accuracy = np.trace(conf_mat) / np.sum(conf_mat)
  15. t = np.sum(conf_mat, 1)
  16. num_cl = np.count_nonzero(t)
  17. assert num_cl
  18. mean_accuracy = np.sum(np.nan_to_num(np.divide(np.diagonal(conf_mat), t))) / num_cl
  19. col_sum = np.sum(conf_mat, 0)
  20. mean_iou = np.sum(
  21. np.nan_to_num(np.divide(np.diagonal(conf_mat), (t + col_sum - np.diagonal(conf_mat))))) / num_cl
  22. return pix_accuracy, mean_accuracy, mean_iou
  23. def eval_segm_result(net_out):
  24. assert type(net_out) is np.ndarray
  25. assert len(net_out.shape) == 4
  26. channels_dim = 1
  27. y_dim = channels_dim + 1
  28. x_dim = y_dim + 1
  29. res = np.zeros(net_out.shape).astype(int)
  30. for i in range(net_out.shape[y_dim]):
  31. for j in range(net_out.shape[x_dim]):
  32. max_ch = np.argmax(net_out[..., i, j])
  33. res[0, max_ch, i, j] = 1
  34. return res
  35. def get_conf_mat(gt, prob):
  36. assert type(gt) is np.ndarray
  37. assert type(prob) is np.ndarray
  38. conf_mat = np.zeros((gt.shape[0], gt.shape[0]))
  39. for ch_gt in range(conf_mat.shape[0]):
  40. gt_channel = gt[ch_gt, ...]
  41. for ch_pr in range(conf_mat.shape[1]):
  42. prob_channel = prob[ch_pr, ...]
  43. conf_mat[ch_gt][ch_pr] = np.count_nonzero(np.multiply(gt_channel, prob_channel))
  44. return conf_mat
  45. class MeanChannelsPreproc:
  46. def __init__(self):
  47. pass
  48. @staticmethod
  49. def process(img, framework):
  50. image_data = None
  51. if framework == "Caffe":
  52. image_data = cv.dnn.blobFromImage(img, scalefactor=1.0, mean=(123.0, 117.0, 104.0), swapRB=True)
  53. elif framework == "DNN (ONNX)":
  54. image_data = cv.dnn.blobFromImage(img, scalefactor=0.019, mean=(123.675, 116.28, 103.53), swapRB=True)
  55. else:
  56. raise ValueError("Unknown framework")
  57. return image_data
  58. class DatasetImageFetch(object):
  59. __metaclass__ = ABCMeta
  60. data_prepoc = object
  61. @abstractmethod
  62. def __iter__(self):
  63. pass
  64. @abstractmethod
  65. def next(self):
  66. pass
  67. @staticmethod
  68. def pix_to_c(pix):
  69. return pix[0] * 256 * 256 + pix[1] * 256 + pix[2]
  70. @staticmethod
  71. def color_to_gt(color_img, colors):
  72. num_classes = len(colors)
  73. gt = np.zeros((num_classes, color_img.shape[0], color_img.shape[1])).astype(int)
  74. for img_y in range(color_img.shape[0]):
  75. for img_x in range(color_img.shape[1]):
  76. c = DatasetImageFetch.pix_to_c(color_img[img_y][img_x])
  77. if c in colors:
  78. cls = colors.index(c)
  79. gt[cls][img_y][img_x] = 1
  80. return gt
  81. class PASCALDataFetch(DatasetImageFetch):
  82. img_dir = ''
  83. segm_dir = ''
  84. names = []
  85. colors = []
  86. i = 0
  87. def __init__(self, img_dir, segm_dir, names_file, segm_cls_colors, preproc):
  88. self.img_dir = img_dir
  89. self.segm_dir = segm_dir
  90. self.colors = self.read_colors(segm_cls_colors)
  91. self.data_prepoc = preproc
  92. self.i = 0
  93. with open(names_file) as f:
  94. for l in f.readlines():
  95. self.names.append(l.rstrip())
  96. @staticmethod
  97. def read_colors(colors):
  98. result = []
  99. for color in colors:
  100. result.append(DatasetImageFetch.pix_to_c(color))
  101. return result
  102. def __iter__(self):
  103. return self
  104. def __next__(self):
  105. if self.i < len(self.names):
  106. name = self.names[self.i]
  107. self.i += 1
  108. segm_file = self.segm_dir + name + ".png"
  109. img_file = self.img_dir + name + ".jpg"
  110. gt = self.color_to_gt(cv.imread(segm_file, cv.IMREAD_COLOR)[:, :, ::-1], self.colors)
  111. img = cv.imread(img_file, cv.IMREAD_COLOR)
  112. img_caffe = self.data_prepoc.process(img[:, :, ::-1], "Caffe")
  113. img_dnn = self.data_prepoc.process(img[:, :, ::-1], "DNN (ONNX)")
  114. img_dict = {
  115. "Caffe": img_caffe,
  116. "DNN (ONNX)": img_dnn
  117. }
  118. return img_dict, gt
  119. else:
  120. self.i = 0
  121. raise StopIteration
  122. def get_num_classes(self):
  123. return len(self.colors)
  124. class SemSegmEvaluation:
  125. log = sys.stdout
  126. def __init__(self, log_path,):
  127. self.log = open(log_path, 'w')
  128. def process(self, frameworks, data_fetcher):
  129. samples_handled = 0
  130. conf_mats = [np.zeros((data_fetcher.get_num_classes(), data_fetcher.get_num_classes())) for i in range(len(frameworks))]
  131. blobs_l1_diff = [0] * len(frameworks)
  132. blobs_l1_diff_count = [0] * len(frameworks)
  133. blobs_l_inf_diff = [sys.float_info.min] * len(frameworks)
  134. inference_time = [0.0] * len(frameworks)
  135. for in_blob_dict, gt in data_fetcher:
  136. frameworks_out = []
  137. samples_handled += 1
  138. for i in range(len(frameworks)):
  139. start = time.time()
  140. framework_name = frameworks[i].get_name()
  141. out = frameworks[i].get_output(in_blob_dict[framework_name])
  142. end = time.time()
  143. segm = eval_segm_result(out)
  144. conf_mats[i] += get_conf_mat(gt, segm[0])
  145. frameworks_out.append(out)
  146. inference_time[i] += end - start
  147. pix_acc, mean_acc, miou = get_metrics(conf_mats[i])
  148. name = frameworks[i].get_name()
  149. print(samples_handled, 'Pixel accuracy, %s:' % name, 100 * pix_acc, file=self.log)
  150. print(samples_handled, 'Mean accuracy, %s:' % name, 100 * mean_acc, file=self.log)
  151. print(samples_handled, 'Mean IOU, %s:' % name, 100 * miou, file=self.log)
  152. print("Inference time, ms ", \
  153. frameworks[i].get_name(), inference_time[i] / samples_handled * 1000, file=self.log)
  154. for i in range(1, len(frameworks)):
  155. log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'
  156. diff = np.abs(frameworks_out[0] - frameworks_out[i])
  157. l1_diff = np.sum(diff) / diff.size
  158. print(samples_handled, "L1 difference", log_str, l1_diff, file=self.log)
  159. blobs_l1_diff[i] += l1_diff
  160. blobs_l1_diff_count[i] += 1
  161. if np.max(diff) > blobs_l_inf_diff[i]:
  162. blobs_l_inf_diff[i] = np.max(diff)
  163. print(samples_handled, "L_INF difference", log_str, blobs_l_inf_diff[i], file=self.log)
  164. self.log.flush()
  165. for i in range(1, len(blobs_l1_diff)):
  166. log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'
  167. print('Final l1 diff', log_str, blobs_l1_diff[i] / blobs_l1_diff_count[i], file=self.log)
  168. # PASCAL VOC 2012 classes colors
  169. colors_pascal_voc_2012 = [
  170. [0, 0, 0],
  171. [128, 0, 0],
  172. [0, 128, 0],
  173. [128, 128, 0],
  174. [0, 0, 128],
  175. [128, 0, 128],
  176. [0, 128, 128],
  177. [128, 128, 128],
  178. [64, 0, 0],
  179. [192, 0, 0],
  180. [64, 128, 0],
  181. [192, 128, 0],
  182. [64, 0, 128],
  183. [192, 0, 128],
  184. [64, 128, 128],
  185. [192, 128, 128],
  186. [0, 64, 0],
  187. [128, 64, 0],
  188. [0, 192, 0],
  189. [128, 192, 0],
  190. [0, 64, 128],
  191. ]
  192. if __name__ == "__main__":
  193. parser = argparse.ArgumentParser()
  194. parser.add_argument("--imgs_dir", help="path to PASCAL VOC 2012 images dir, data/VOC2012/JPEGImages")
  195. parser.add_argument("--segm_dir", help="path to PASCAL VOC 2012 segmentation dir, data/VOC2012/SegmentationClass/")
  196. parser.add_argument("--val_names", help="path to file with validation set image names, download it here: "
  197. "https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/data/pascal/seg11valid.txt")
  198. parser.add_argument("--prototxt", help="path to caffe prototxt, download it here: "
  199. "https://github.com/opencv/opencv/blob/4.x/samples/data/dnn/fcn8s-heavy-pascal.prototxt")
  200. parser.add_argument("--caffemodel", help="path to caffemodel file, download it here: "
  201. "http://dl.caffe.berkeleyvision.org/fcn8s-heavy-pascal.caffemodel")
  202. parser.add_argument("--onnxmodel", help="path to onnx model file, download it here: "
  203. "https://github.com/onnx/models/raw/491ce05590abb7551d7fae43c067c060eeb575a6/validated/vision/object_detection_segmentation/fcn/model/fcn-resnet50-12.onnx")
  204. parser.add_argument("--log", help="path to logging file", default='log.txt')
  205. parser.add_argument("--in_blob", help="name for input blob", default='data')
  206. parser.add_argument("--out_blob", help="name for output blob", default='score')
  207. args = parser.parse_args()
  208. prep = MeanChannelsPreproc()
  209. df = PASCALDataFetch(args.imgs_dir, args.segm_dir, args.val_names, colors_pascal_voc_2012, prep)
  210. fw = [CaffeModel(args.prototxt, args.caffemodel, args.in_blob, args.out_blob, True),
  211. DNNOnnxModel(args.onnxmodel, args.in_blob, args.out_blob)]
  212. segm_eval = SemSegmEvaluation(args.log)
  213. segm_eval.process(fw, df)