semantic_segmentation.cpp 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. #include <opencv2/imgproc.hpp>
  2. #include <opencv2/gapi/infer/ie.hpp>
  3. #include <opencv2/gapi/cpu/gcpukernel.hpp>
  4. #include <opencv2/gapi/streaming/cap.hpp>
  5. #include <opencv2/gapi/operators.hpp>
  6. #include <opencv2/highgui.hpp>
  7. #include <opencv2/gapi/streaming/desync.hpp>
  8. #include <opencv2/gapi/streaming/format.hpp>
  9. #include <iomanip>
  10. const std::string keys =
  11. "{ h help | | Print this help message }"
  12. "{ desync | false | Desynchronize inference }"
  13. "{ input | | Path to the input video file }"
  14. "{ output | | Path to the output video file }"
  15. "{ ssm | semantic-segmentation-adas-0001.xml | Path to OpenVINO IE semantic segmentation model (.xml) }";
  16. // 20 colors for 20 classes of semantic-segmentation-adas-0001
  17. static std::vector<cv::Vec3b> colors = {
  18. { 0, 0, 0 },
  19. { 0, 0, 128 },
  20. { 0, 128, 0 },
  21. { 0, 128, 128 },
  22. { 128, 0, 0 },
  23. { 128, 0, 128 },
  24. { 128, 128, 0 },
  25. { 128, 128, 128 },
  26. { 0, 0, 64 },
  27. { 0, 0, 192 },
  28. { 0, 128, 64 },
  29. { 0, 128, 192 },
  30. { 128, 0, 64 },
  31. { 128, 0, 192 },
  32. { 128, 128, 64 },
  33. { 128, 128, 192 },
  34. { 0, 64, 0 },
  35. { 0, 64, 128 },
  36. { 0, 192, 0 },
  37. { 0, 192, 128 },
  38. { 128, 64, 0 }
  39. };
  40. namespace {
  41. std::string get_weights_path(const std::string &model_path) {
  42. const auto EXT_LEN = 4u;
  43. const auto sz = model_path.size();
  44. CV_Assert(sz > EXT_LEN);
  45. auto ext = model_path.substr(sz - EXT_LEN);
  46. std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c){
  47. return static_cast<unsigned char>(std::tolower(c));
  48. });
  49. CV_Assert(ext == ".xml");
  50. return model_path.substr(0u, sz - EXT_LEN) + ".bin";
  51. }
  52. bool isNumber(const std::string &str) {
  53. return !str.empty() && std::all_of(str.begin(), str.end(),
  54. [](unsigned char ch) { return std::isdigit(ch); });
  55. }
  56. std::string toStr(double value) {
  57. std::stringstream ss;
  58. ss << std::fixed << std::setprecision(1) << value;
  59. return ss.str();
  60. }
  61. void classesToColors(const cv::Mat &out_blob,
  62. cv::Mat &mask_img) {
  63. const int H = out_blob.size[0];
  64. const int W = out_blob.size[1];
  65. mask_img.create(H, W, CV_8UC3);
  66. GAPI_Assert(out_blob.type() == CV_8UC1);
  67. const uint8_t* const classes = out_blob.ptr<uint8_t>();
  68. for (int rowId = 0; rowId < H; ++rowId) {
  69. for (int colId = 0; colId < W; ++colId) {
  70. uint8_t class_id = classes[rowId * W + colId];
  71. mask_img.at<cv::Vec3b>(rowId, colId) =
  72. class_id < colors.size()
  73. ? colors[class_id]
  74. : cv::Vec3b{0, 0, 0}; // NB: sample supports 20 classes
  75. }
  76. }
  77. }
  78. void probsToClasses(const cv::Mat& probs, cv::Mat& classes) {
  79. const int C = probs.size[1];
  80. const int H = probs.size[2];
  81. const int W = probs.size[3];
  82. classes.create(H, W, CV_8UC1);
  83. GAPI_Assert(probs.depth() == CV_32F);
  84. float* out_p = reinterpret_cast<float*>(probs.data);
  85. uint8_t* classes_p = reinterpret_cast<uint8_t*>(classes.data);
  86. for (int h = 0; h < H; ++h) {
  87. for (int w = 0; w < W; ++w) {
  88. double max = 0;
  89. int class_id = 0;
  90. for (int c = 0; c < C; ++c) {
  91. int idx = c * H * W + h * W + w;
  92. if (out_p[idx] > max) {
  93. max = out_p[idx];
  94. class_id = c;
  95. }
  96. }
  97. classes_p[h * W + w] = static_cast<uint8_t>(class_id);
  98. }
  99. }
  100. }
  101. } // anonymous namespace
  102. namespace vis {
  103. static void putText(cv::Mat& mat, const cv::Point &position, const std::string &message) {
  104. auto fontFace = cv::FONT_HERSHEY_COMPLEX;
  105. int thickness = 2;
  106. cv::Scalar color = {200, 10, 10};
  107. double fontScale = 0.65;
  108. cv::putText(mat, message, position, fontFace,
  109. fontScale, cv::Scalar(255, 255, 255), thickness + 1);
  110. cv::putText(mat, message, position, fontFace, fontScale, color, thickness);
  111. }
// Blends the color mask into the image in place: the result is the sum of
// the halved image and the halved mask, i.e. both contribute equally.
//   img        - image to draw on; it is overwritten with the blended result.
//   color_mask - mask to blend in; it is not modified.
// NOTE(review): presumably img and color_mask must have the same size and
// type for the element-wise sum to be valid - confirm against the caller.
static void drawResults(cv::Mat &img, const cv::Mat &color_mask) {
    img = img / 2 + color_mask / 2;
}
  115. } // namespace vis
  116. namespace custom {
// G-API operation which takes two cv::GMat inputs and produces a single
// cv::GMat output; it is registered under the id "sample.custom.post_processing".
G_API_OP(PostProcessing, <cv::GMat(cv::GMat, cv::GMat)>, "sample.custom.post_processing") {
    // The output is described by the metadata of the first input; the
    // metadata of the second input is not used.
    static cv::GMatDesc outMeta(const cv::GMatDesc &in, const cv::GMatDesc &) {
        return in;
    }
};
  122. GAPI_OCV_KERNEL(OCVPostProcessing, PostProcessing) {
  123. static void run(const cv::Mat &in, const cv::Mat &out_blob, cv::Mat &out) {
  124. int C = -1, H = -1, W = -1;
  125. if (out_blob.size.dims() == 4u) {
  126. C = 1; H = 2, W = 3;
  127. } else if (out_blob.size.dims() == 3u) {
  128. C = 0; H = 1, W = 2;
  129. } else {
  130. throw std::logic_error(
  131. "Number of dimmensions for model output must be 3 or 4!");
  132. }
  133. cv::Mat classes;
  134. // NB: If output has more than single plane, it contains probabilities
  135. // otherwise class id.
  136. if (out_blob.size[C] > 1) {
  137. probsToClasses(out_blob, classes);
  138. } else {
  139. if (out_blob.depth() != CV_32S) {
  140. throw std::logic_error(
  141. "Single channel output must have integer precision!");
  142. }
  143. cv::Mat view(out_blob.size[H], // cols
  144. out_blob.size[W], // rows
  145. CV_32SC1,
  146. out_blob.data);
  147. view.convertTo(classes, CV_8UC1);
  148. }
  149. cv::Mat mask_img;
  150. classesToColors(classes, mask_img);
  151. cv::resize(mask_img, out, in.size(), 0, 0, cv::INTER_NEAREST);
  152. }
  153. };
  154. } // namespace custom
// Entry point of the sample: parses the command line, builds a streaming
// G-API pipeline (frame copy -> inference -> post-processing), then pulls
// the results in a loop, blends the color mask into the frame, shows the
// result in a window and optionally writes it to a video file.
// Returns 0 both after printing the help message and after the stream ends.
int main(int argc, char *argv[]) {
    cv::CommandLineParser cmd(argc, argv, keys);
    if (cmd.has("help")) {
        cmd.printMessage();
        return 0;
    }
    // Prepare parameters first
    const std::string input = cmd.get<std::string>("input");
    const std::string output = cmd.get<std::string>("output");
    const auto model_path = cmd.get<std::string>("ssm");
    const bool desync = cmd.get<bool>("desync");
    // The weights file is expected next to the model: same path, ".bin" extension.
    const auto weights_path = get_weights_path(model_path);
    // The inference device is fixed, there is no command-line option for it.
    const auto device = "CPU";
    // Network declaration: a single cv::GMat input and a single cv::GMat output.
    G_API_NET(SemSegmNet, <cv::GMat(cv::GMat)>, "semantic-segmentation");
    const auto net = cv::gapi::ie::Params<SemSegmNet> {
        model_path, weights_path, device
    };
    const auto kernels = cv::gapi::kernels<custom::OCVPostProcessing>();
    const auto networks = cv::gapi::networks(net);
    // Now build the graph
    cv::GMat in;
    cv::GMat bgr = cv::gapi::copy(in);
    // With --desync the inference branch consumes a desynchronized view of
    // the frame, otherwise it consumes the very same frame object.
    cv::GMat frame = desync ? cv::gapi::streaming::desync(bgr) : bgr;
    cv::GMat out_blob = cv::gapi::infer<SemSegmNet>(frame);
    cv::GMat out = custom::PostProcessing::on(frame, out_blob);
    // The graph has two outputs: the copied frame and the color mask.
    cv::GStreamingCompiled pipeline = cv::GComputation(cv::GIn(in), cv::GOut(bgr, out))
        .compileStreaming(cv::compile_args(kernels, networks,
                          cv::gapi::streaming::queue_capacity{1}));
    std::shared_ptr<cv::gapi::wip::GCaptureSource> source;
    if (isNumber(input)) {
        // An input which consists of digits only is converted to an integer
        // and passed to the capture source along with a set of capture
        // properties (presumably it is a camera index - confirm against
        // GCaptureSource documentation).
        source = std::make_shared<cv::gapi::wip::GCaptureSource>(
            std::stoi(input),
            std::map<int, double> {
                {cv::CAP_PROP_FRAME_WIDTH, 1280},
                {cv::CAP_PROP_FRAME_HEIGHT, 720},
                {cv::CAP_PROP_BUFFERSIZE, 1},
                {cv::CAP_PROP_AUTOFOCUS, true}
            }
        );
    } else {
        // Any other input is passed to the capture source as a string.
        source = std::make_shared<cv::gapi::wip::GCaptureSource>(input);
    }
    auto inputs = cv::gin(
        static_cast<cv::gapi::wip::IStreamSource::Ptr>(source));
    // The execution part
    pipeline.setSource(std::move(inputs));
    cv::TickMeter tm;
    cv::VideoWriter writer;
    // Both outputs are pulled as optionals: the loop below checks each of
    // them for a value and keeps the most recent ones in last_image and
    // last_color_mask.
    cv::util::optional<cv::Mat> color_mask;
    cv::util::optional<cv::Mat> image;
    cv::Mat last_image;
    cv::Mat last_color_mask;
    pipeline.start();
    tm.start();
    // Counters of received frames and received masks, used for FPS reporting.
    std::size_t frames = 0u;
    std::size_t masks = 0u;
    while (pipeline.pull(cv::gout(image, color_mask))) {
        if (image.has_value()) {
            ++frames;
            last_image = std::move(*image);
        }
        if (color_mask.has_value()) {
            ++masks;
            last_color_mask = std::move(*color_mask);
        }
        // Nothing is drawn until at least one frame and one mask have arrived.
        if (!last_image.empty() && !last_color_mask.empty()) {
            // The timer is stopped here and started again at the end of this
            // branch, so drawing, displaying and writing are not counted in
            // the measured time.
            tm.stop();
            std::string stream_fps = "Stream FPS: " + toStr(frames / tm.getTimeSec());
            std::string inference_fps = "Inference FPS: " + toStr(masks / tm.getTimeSec());
            // Draw on a copy, so the cached frame stays clean for the next iterations.
            cv::Mat tmp = last_image.clone();
            vis::drawResults(tmp, last_color_mask);
            vis::putText(tmp, {10, 22}, stream_fps);
            vis::putText(tmp, {10, 22 + 30}, inference_fps);
            cv::imshow("Out", tmp);
            // NB: The value returned by waitKey is not used.
            cv::waitKey(1);
            if (!output.empty()) {
                // The writer is opened lazily on the first frame to be
                // written, with the size of that frame, MJPG codec and 25 FPS.
                if (!writer.isOpened()) {
                    const auto sz = cv::Size{tmp.cols, tmp.rows};
                    writer.open(output, cv::VideoWriter::fourcc('M','J','P','G'), 25.0, sz);
                    CV_Assert(writer.isOpened());
                }
                writer << tmp;
            }
            tm.start();
        }
    }
    tm.stop();
    // Final report: the number of received frames and the average stream FPS
    // over the measured time.
    std::cout << "Processed " << frames << " frames" << " ("
              << frames / tm.getTimeSec()<< " FPS)" << std::endl;
    return 0;
}