ZEDCustomObjDetection.cs

#if ZED_OPENCV_FOR_UNITY

using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using OpenCVForUnity.CoreModule;
using OpenCVForUnity.DnnModule;
using OpenCVForUnity.ImgprocModule;
using sl;

/// <summary>
/// Example that shows how to use the custom object detection module of the ZED SDK.
/// Uses YOLOv4 through OpenCV's DNN module, and therefore requires the OpenCVForUnity package.
/// </summary>
public class ZEDCustomObjDetection : MonoBehaviour
{
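    // Overall flow: ZEDToOpenCVRetriever hands us the left RGBA image each time the ZED grabs a frame,
    // Run() performs a forward pass with OpenCV's DNN module, postprocess() filters the raw output and
    // applies per-class non-maximum suppression, and ingestCustomData() feeds the resulting 2D boxes to
    // the ZED SDK so it can handle 3D placement and tracking of the ingested objects.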
    [TooltipAttribute("Path to the binary file containing the model's trained weights. It can be a file with the extension .caffemodel (Caffe), .pb (TensorFlow), .t7 or .net (Torch), or .weights (Darknet).")]
    public string model;

    [TooltipAttribute("Path to the text file describing the network configuration. It can be a file with the extension .prototxt (Caffe), .pbtxt (TensorFlow), or .cfg (Darknet).")]
    public string config;

    [TooltipAttribute("Path to a text file with the names of the classes used to label detected objects.")]
    public string classes;

    [TooltipAttribute("Optional class filter. Add the names of the classes you want to keep; leave empty to keep all of them.")]
    public List<string> classesFilter;

    [TooltipAttribute("Confidence threshold.")]
    public float confThreshold = 0.5f;

    [TooltipAttribute("Non-maximum suppression threshold.")]
    public float nmsThreshold = 0.4f;

    private List<string> classNames;
    private List<string> outBlobNames;
    private List<string> outBlobTypes;

    private Net net;

    public int inferenceWidth = 416;
    public int inferenceHeight = 416;
    public float scale = 1.0f;
    public Scalar mean = new Scalar(0, 0, 0, 0);

    private Mat bgrMat;

    public ZEDManager zedManager;

    /// <summary>
    /// Scene's ZEDToOpenCVRetriever, which creates OpenCV mats and deploys events each time the ZED grabs an image.
    /// It's how we get the image and the matrices we feed to the detector.
    /// </summary>
    public ZEDToOpenCVRetriever imageRetriever;

    public delegate void onNewIngestCustomODDelegate();
    public event onNewIngestCustomODDelegate OnIngestCustomOD;
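    // OnIngestCustomOD is raised right after a batch of boxes has been ingested into the ZED SDK.
    // Illustrative usage from another script (hypothetical subscriber, not part of this sample):
    //     GetComponent<ZEDCustomObjDetection>().OnIngestCustomOD += () => Debug.Log("New custom boxes ingested.");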
    public void Start()
    {
        if (!zedManager) zedManager = FindObjectOfType<ZEDManager>();

        if (zedManager.objectDetectionModel != DETECTION_MODEL.CUSTOM_BOX_OBJECTS)
        {
            Debug.LogWarning("sl.DETECTION_MODEL.CUSTOM_BOX_OBJECTS is mandatory for this sample.");
        }
        else
        {
            // We'll listen for updates from a ZEDToOpenCVRetriever, which will call an event whenever it has a new image from the ZED.
            if (!imageRetriever) imageRetriever = ZEDToOpenCVRetriever.GetInstance();
            imageRetriever.OnImageUpdated_LeftRGBA += Run;
        }

        Init();
    }
    public void OnDestroy()
    {
        if (imageRetriever != null)
            imageRetriever.OnImageUpdated_LeftRGBA -= Run;

        if (net != null)
            net.Dispose();
        if (bgrMat != null)
            bgrMat.Dispose();
    }
    public void OnValidate()
    {
        if (classesFilter.Count > 0)
        {
            classNames = classesFilter;
        }
        else
        {
            classNames = readClassNames(classes);
        }
    }
    public void Init()
    {
        if (!string.IsNullOrEmpty(classes))
        {
            classNames = readClassNames(classes);
            if (classNames == null)
            {
                Debug.LogError("Classes file is not loaded. Please see \"StreamingAssets/dnn/setup_dnn_module.pdf\".");
            }
        }
        else if (classesFilter.Count > 0)
        {
            classNames = classesFilter;
        }

        if (string.IsNullOrEmpty(model))
        {
            Debug.LogError("Model file is not loaded. Please see \"StreamingAssets/dnn/setup_dnn_module.pdf\".");
        }
        else if (string.IsNullOrEmpty(config))
        {
            Debug.LogError("Config file is not loaded. Please see \"StreamingAssets/dnn/setup_dnn_module.pdf\".");
        }
        else
        {
            net = Dnn.readNet(model, config);
            if (net == null) Debug.LogWarning("Network is null.");

            outBlobNames = getOutputsNames(net);
            outBlobTypes = getOutputsTypes(net);
        }
    }
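    // Illustrative setup (assumption, matching the tooltips above): for a Darknet YOLOv4 network, "model" would
    // point to a .weights file, "config" to the matching .cfg file, and "classes" to a .names text file with one
    // label per line, for example placed under StreamingAssets as described in setup_dnn_module.pdf.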
    public void Run(Camera cam, Mat camera_matrix, Mat rgbaMat)
    {
        if (!zedManager.IsObjectDetectionRunning) return;

        Mat bgrMat = new Mat(rgbaMat.rows(), rgbaMat.cols(), CvType.CV_8UC3);
        Imgproc.cvtColor(rgbaMat, bgrMat, Imgproc.COLOR_RGBA2BGR);

        // Create a 4D blob from the frame.
        Size infSize = new Size(inferenceWidth > 0 ? inferenceWidth : bgrMat.cols(),
            inferenceHeight > 0 ? inferenceHeight : bgrMat.rows());
        Mat blob = Dnn.blobFromImage(bgrMat, scale, infSize, mean, true, false);
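        // The blob is a 4D NCHW tensor of shape 1 x 3 x inferenceHeight x inferenceWidth; swapRB = true
        // reorders BGR to RGB, and crop = false resizes (rather than center-crops) the frame to the inference size.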
        // Run the model.
        net.setInput(blob);

        if (net.getLayer(new DictValue(0)).outputNameToIndex("im_info") != -1)
        {
            // Faster-RCNN or R-FCN
            Imgproc.resize(bgrMat, bgrMat, infSize);
            Mat imInfo = new Mat(1, 3, CvType.CV_32FC1);
            imInfo.put(0, 0, new float[] {
                (float)infSize.height,
                (float)infSize.width,
                1.6f
            });
            net.setInput(imInfo, "im_info");
        }

        List<Mat> outs = new List<Mat>();
        net.forward(outs, outBlobNames);

        postprocess(rgbaMat, outs, net, Dnn.DNN_BACKEND_OPENCV);

        for (int i = 0; i < outs.Count; i++)
        {
            outs[i].Dispose();
        }
        blob.Dispose();
        // bgrMat is recreated for every frame, so release its native memory here to avoid leaking it.
        bgrMat.Dispose();
    }
    /// <summary>
    /// Postprocesses the specified frame, outs and net.
    /// </summary>
    /// <param name="frame">Frame.</param>
    /// <param name="outs">Outs.</param>
    /// <param name="net">Net.</param>
    /// <param name="backend">Backend.</param>
    protected virtual void postprocess(Mat frame, List<Mat> outs, Net net, int backend = Dnn.DNN_BACKEND_OPENCV)
    {
        List<int> classIdsList = new List<int>();
        List<float> confidencesList = new List<float>();
        List<Rect2d> boxesList = new List<Rect2d>();

        for (int i = 0; i < outs.Count; ++i)
        {
            // The network produces an output blob of shape N x C, where N is the number of detected objects
            // and C is the number of classes + 5: the first 5 values are
            // [center_x, center_y, width, height, objectness], followed by one score per class.
            float[] positionData = new float[5];
            float[] confidenceData = new float[outs[i].cols() - 5];

            for (int p = 0; p < outs[i].rows(); p++)
            {
                outs[i].get(p, 0, positionData);
                outs[i].get(p, 5, confidenceData);

                // Index of the class with the highest score for this detection.
                int maxIdx = confidenceData.Select((val, idx) => new { V = val, I = idx }).Aggregate((max, working) => (max.V > working.V) ? max : working).I;
                float confidence = confidenceData[maxIdx];

                if (confidence > confThreshold)
                {
                    // Box coordinates are normalized; convert them to pixel coordinates in the source frame.
                    float centerX = positionData[0] * frame.cols();
                    float centerY = positionData[1] * frame.rows();
                    float width = positionData[2] * frame.cols();
                    float height = positionData[3] * frame.rows();
                    float left = centerX - width / 2;
                    float top = centerY - height / 2;

                    classIdsList.Add(maxIdx);
                    confidencesList.Add(confidence);
                    boxesList.Add(new Rect2d(left, top, width, height));
                }
            }
        }
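        // Group detections by class so that non-maximum suppression is applied per class:
        // boxes of different classes never suppress each other.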
        Dictionary<int, List<int>> class2indices = new Dictionary<int, List<int>>();
        for (int i = 0; i < classIdsList.Count; i++)
        {
            if (confidencesList[i] >= confThreshold)
            {
                if (!class2indices.ContainsKey(classIdsList[i]))
                    class2indices.Add(classIdsList[i], new List<int>());
                class2indices[classIdsList[i]].Add(i);
            }
        }

        List<Rect2d> nmsBoxesList = new List<Rect2d>();
        List<float> nmsConfidencesList = new List<float>();
        List<int> nmsClassIdsList = new List<int>();
        foreach (int key in class2indices.Keys)
        {
            List<Rect2d> localBoxesList = new List<Rect2d>();
            List<float> localConfidencesList = new List<float>();
            List<int> classIndicesList = class2indices[key];
            for (int i = 0; i < classIndicesList.Count; i++)
            {
                localBoxesList.Add(boxesList[classIndicesList[i]]);
                localConfidencesList.Add(confidencesList[classIndicesList[i]]);
            }

            using (MatOfRect2d localBoxes = new MatOfRect2d(localBoxesList.ToArray()))
            using (MatOfFloat localConfidences = new MatOfFloat(localConfidencesList.ToArray()))
            using (MatOfInt nmsIndices = new MatOfInt())
            {
                Dnn.NMSBoxes(localBoxes, localConfidences, confThreshold, nmsThreshold, nmsIndices);
                for (int i = 0; i < nmsIndices.total(); i++)
                {
                    int idx = (int)nmsIndices.get(i, 0)[0];
                    nmsBoxesList.Add(localBoxesList[idx]);
                    nmsConfidencesList.Add(localConfidencesList[idx]);
                    nmsClassIdsList.Add(key);
                }
            }
        }
        boxesList = nmsBoxesList;
        classIdsList = nmsClassIdsList;
        confidencesList = nmsConfidencesList;

        ingestCustomData(boxesList, confidencesList, classIdsList);
    }
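    // In CUSTOM_BOX_OBJECTS mode the ZED SDK only needs the 2D box, the label id and the probability of each
    // detection; from those it computes the 3D position of each object and tracks it across frames.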
    private void ingestCustomData(List<Rect2d> boxesList, List<float> confidencesList, List<int> classIdsList)
    {
        List<CustomBoxObjectData> objects_in = new List<CustomBoxObjectData>();
        for (int idx = 0; idx < boxesList.Count; ++idx)
        {
            if (classNames != null && classNames.Count != 0)
            {
                if (classesFilter.Count == 0 || (classIdsList[idx] < classNames.Count && classesFilter.Contains(classNames[classIdsList[idx]])))
                {
                    CustomBoxObjectData tmp = new CustomBoxObjectData();
                    tmp.uniqueObjectID = sl.ZEDCamera.GenerateUniqueID();
                    tmp.label = classIdsList[idx];
                    tmp.probability = confidencesList[idx];

                    // 2D bounding box corners, in pixels, ordered clockwise starting from the top-left.
                    Vector2[] bbox = new Vector2[4];
                    bbox[0] = new Vector2((float)boxesList[idx].x, (float)boxesList[idx].y);                                                              // Top-left.
                    bbox[1] = new Vector2((float)boxesList[idx].x + (float)boxesList[idx].width, (float)boxesList[idx].y);                                // Top-right.
                    bbox[2] = new Vector2((float)boxesList[idx].x + (float)boxesList[idx].width, (float)boxesList[idx].y + (float)boxesList[idx].height); // Bottom-right.
                    bbox[3] = new Vector2((float)boxesList[idx].x, (float)boxesList[idx].y + (float)boxesList[idx].height);                               // Bottom-left.
                    tmp.boundingBox2D = bbox;

                    objects_in.Add(tmp);
                }
            }
        }

        zedManager.zedCamera.IngestCustomBoxObjects(objects_in);

        if (OnIngestCustomOD != null)
            OnIngestCustomOD();
    }
    /// <summary>
    /// Reads the class names.
    /// </summary>
    /// <returns>The class names, or null if the file could not be read.</returns>
    /// <param name="filename">Filename.</param>
    private List<string> readClassNames(string filename)
    {
        List<string> classNames = new List<string>();

        System.IO.StreamReader cReader = null;
        try
        {
            cReader = new System.IO.StreamReader(filename, System.Text.Encoding.Default);
            while (cReader.Peek() >= 0)
            {
                string name = cReader.ReadLine();
                classNames.Add(name);
            }
        }
        catch (System.Exception ex)
        {
            Debug.LogError(ex.Message);
            return null;
        }
        finally
        {
            if (cReader != null)
                cReader.Close();
        }

        return classNames;
    }
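    // The two helpers below query the network for its unconnected (i.e. output) layer names and types;
    // the names are what net.forward() is asked to compute in Run().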
    /// <summary>
    /// Gets the output layer names.
    /// </summary>
    /// <returns>The output layer names.</returns>
    /// <param name="net">Net.</param>
    protected List<string> getOutputsNames(Net net)
    {
        List<string> names = new List<string>();

        MatOfInt outLayers = net.getUnconnectedOutLayers();
        for (int i = 0; i < outLayers.total(); ++i)
        {
            names.Add(net.getLayer(new DictValue((int)outLayers.get(i, 0)[0])).get_name());
        }
        outLayers.Dispose();

        return names;
    }
    /// <summary>
    /// Gets the output layer types.
    /// </summary>
    /// <returns>The output layer types.</returns>
    /// <param name="net">Net.</param>
    protected virtual List<string> getOutputsTypes(Net net)
    {
        List<string> types = new List<string>();

        MatOfInt outLayers = net.getUnconnectedOutLayers();
        for (int i = 0; i < outLayers.total(); ++i)
        {
            types.Add(net.getLayer(new DictValue((int)outLayers.get(i, 0)[0])).get_type());
        }
        outLayers.Dispose();

        return types;
    }
}

#endif