// DnnForwardAndRetrieve.java
  package org.opencv.test.dnn;

  import java.nio.charset.StandardCharsets;
  import java.util.ArrayList;
  import java.util.List;

  import org.opencv.core.Core;
  import org.opencv.core.CvType;
  import org.opencv.core.Mat;
  import org.opencv.core.MatOfByte;
  import org.opencv.core.Range;
  import org.opencv.dnn.Dnn;
  import org.opencv.dnn.Net;
  import org.opencv.test.OpenCVTestCase;
  12. public class DnnForwardAndRetrieve extends OpenCVTestCase {
  13. public void testForwardAndRetrieve()
  14. {
  15. // Create a simple Caffe prototxt with a Slice layer
  16. String prototxt =
  17. "input: \"data\"\n" +
  18. "layer {\n" +
  19. " name: \"testLayer\"\n" +
  20. " type: \"Slice\"\n" +
  21. " bottom: \"data\"\n" +
  22. " top: \"firstCopy\"\n" +
  23. " top: \"secondCopy\"\n" +
  24. " slice_param {\n" +
  25. " axis: 0\n" +
  26. " slice_point: 2\n" +
  27. " }\n" +
  28. "}";
  29. // Read network from prototxt
  30. MatOfByte bufferProto = new MatOfByte();
  31. bufferProto.fromArray(prototxt.getBytes());
  32. Net net = Dnn.readNetFromCaffe(bufferProto);
  33. net.setPreferableBackend(Dnn.DNN_BACKEND_OPENCV);
  34. // Create input data
  35. Mat inp = new Mat(4, 5, CvType.CV_32F);
  36. Core.randu(inp, -1, 1);
  37. net.setInput(inp);
  38. // Define output names
  39. List<String> outNames = new ArrayList<>();
  40. outNames.add("testLayer");
  41. // Forward and retrieve multiple outputs
  42. List<List<Mat>> outBlobs = new ArrayList<>();
  43. net.forwardAndRetrieve(outBlobs, outNames);
  44. // Verify results
  45. assertEquals(1, outBlobs.size());
  46. assertEquals(2, outBlobs.get(0).size());
  47. // Compare results
  48. Mat expectedFirst = inp.rowRange(0, 2);
  49. Mat expectedSecond = inp.rowRange(2, 4);
  50. Mat actualFirst = outBlobs.get(0).get(0);
  51. Mat actualSecond = outBlobs.get(0).get(1);
  52. assertEquals(0, Core.norm(expectedFirst, actualFirst, Core.NORM_INF), EPS);
  53. assertEquals(0, Core.norm(expectedSecond, actualSecond, Core.NORM_INF), EPS);
  54. }
  55. }