face_detection.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. # based on: https://github.com/ShiqiYu/libfacedetection.train/blob/74f3aa77c63234dd954d21286e9a60703b8d0868/tasks/task1/yufacedetectnet.py # noqa
  18. import math
  19. from enum import Enum
  20. from typing import Dict, List, Optional, Tuple
  21. import torch
  22. import torch.nn.functional as F
  23. from torch import nn
  24. from kornia.geometry.bbox import nms as nms_kornia
  25. __all__ = ["FaceDetector", "FaceDetectorResult", "FaceKeypoint"]
  26. url: str = "https://github.com/kornia/data/raw/main/yunet_final.pth"
  27. class FaceKeypoint(Enum):
  28. r"""Define the keypoints detected in a face.
  29. The left/right convention is based on the screen viewer.
  30. """
  31. EYE_LEFT = 0
  32. EYE_RIGHT = 1
  33. NOSE = 2
  34. MOUTH_LEFT = 3
  35. MOUTH_RIGHT = 4
  36. class FaceDetectorResult:
  37. r"""Encapsulate the results obtained by the :py:class:`kornia.contrib.FaceDetector`.
  38. Args:
  39. data: the encoded results coming from the feature detector with shape :math:`(14,)`.
  40. """
  41. def __init__(self, data: torch.Tensor) -> None:
  42. if len(data) < 15:
  43. raise ValueError(f"Result must comes as vector of size(14). Got: {data.shape}.")
  44. self._data = data
  45. def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "FaceDetectorResult":
  46. """Like :func:`torch.nn.Module.to()` method."""
  47. self._data = self._data.to(device=device, dtype=dtype)
  48. return self
  49. @property
  50. def xmin(self) -> torch.Tensor:
  51. """The bounding box top-left x-coordinate."""
  52. return self._data[..., 0]
  53. @property
  54. def ymin(self) -> torch.Tensor:
  55. """The bounding box top-left y-coordinate."""
  56. return self._data[..., 1]
  57. @property
  58. def xmax(self) -> torch.Tensor:
  59. """The bounding box bottom-right x-coordinate."""
  60. return self._data[..., 2]
  61. @property
  62. def ymax(self) -> torch.Tensor:
  63. """The bounding box bottom-right y-coordinate."""
  64. return self._data[..., 3]
  65. def get_keypoint(self, keypoint: FaceKeypoint) -> torch.Tensor:
  66. """Get the [x y] position of a given facial keypoint.
  67. Args:
  68. keypoint: the keypoint type to return the position.
  69. """
  70. if keypoint == FaceKeypoint.EYE_LEFT:
  71. out = self._data[..., (4, 5)]
  72. elif keypoint == FaceKeypoint.EYE_RIGHT:
  73. out = self._data[..., (6, 7)]
  74. elif keypoint == FaceKeypoint.NOSE:
  75. out = self._data[..., (8, 9)]
  76. elif keypoint == FaceKeypoint.MOUTH_LEFT:
  77. out = self._data[..., (10, 11)]
  78. elif keypoint == FaceKeypoint.MOUTH_RIGHT:
  79. out = self._data[..., (12, 13)]
  80. else:
  81. raise ValueError(f"Not valid keypoint type. Got: {keypoint}.")
  82. return out
  83. @property
  84. def score(self) -> torch.Tensor:
  85. """The detection score."""
  86. return self._data[..., 14]
  87. @property
  88. def width(self) -> torch.Tensor:
  89. """The bounding box width."""
  90. return self.xmax - self.xmin
  91. @property
  92. def height(self) -> torch.Tensor:
  93. """The bounding box height."""
  94. return self.ymax - self.ymin
  95. @property
  96. def top_left(self) -> torch.Tensor:
  97. """The [x y] position of the top-left coordinate of the bounding box."""
  98. return self._data[..., (0, 1)]
  99. @property
  100. def top_right(self) -> torch.Tensor:
  101. """The [x y] position of the top-left coordinate of the bounding box."""
  102. out = self.top_left
  103. out[..., 0] += self.width
  104. return out
  105. @property
  106. def bottom_right(self) -> torch.Tensor:
  107. """The [x y] position of the bottom-right coordinate of the bounding box."""
  108. return self._data[..., (2, 3)]
  109. @property
  110. def bottom_left(self) -> torch.Tensor:
  111. """The [x y] position of the top-left coordinate of the bounding box."""
  112. out = self.top_left
  113. out[..., 1] += self.height
  114. return out
  115. class FaceDetector(nn.Module):
  116. r"""Detect faces in a given image using a CNN.
  117. By default, it uses the method described in :cite:`facedetect-yu`.
  118. Args:
  119. top_k: the maximum number of detections to return before the nms.
  120. confidence_threshold: the threshold used to discard detections.
  121. nms_threshold: the threshold used by the nms for iou.
  122. keep_top_k: the maximum number of detections to return after the nms.
  123. Return:
  124. A list of B tensors with shape :math:`(N,15)` to be used with :py:class:`kornia.contrib.FaceDetectorResult`.
  125. Example:
  126. >>> img = torch.rand(1, 3, 320, 320)
  127. >>> detect = FaceDetector()
  128. >>> res = detect(img)
  129. """
  130. def __init__(
  131. self, top_k: int = 5000, confidence_threshold: float = 0.3, nms_threshold: float = 0.3, keep_top_k: int = 750
  132. ) -> None:
  133. super().__init__()
  134. self.top_k = top_k
  135. self.confidence_threshold = confidence_threshold
  136. self.nms_threshold = nms_threshold
  137. self.keep_top_k = keep_top_k
  138. self.config = {
  139. "name": "YuFaceDetectNet",
  140. "min_sizes": [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]],
  141. "steps": [8, 16, 32, 64],
  142. "variance": [0.1, 0.2],
  143. "clip": False,
  144. }
  145. self.min_sizes = [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]]
  146. self.steps = [8, 16, 32, 64]
  147. self.variance = [0.1, 0.2]
  148. self.clip = False
  149. self.model = YuFaceDetectNet("test", pretrained=True)
  150. self.nms = nms_kornia
  151. def preprocess(self, image: torch.Tensor) -> torch.Tensor:
  152. return image
  153. def postprocess(self, data: Dict[str, torch.Tensor], height: int, width: int) -> List[torch.Tensor]:
  154. loc, conf, iou = data["loc"], data["conf"], data["iou"]
  155. scale = torch.tensor(
  156. [width, height, width, height, width, height, width, height, width, height, width, height, width, height],
  157. device=loc.device,
  158. dtype=loc.dtype,
  159. ) # 14
  160. priors = _PriorBox(self.min_sizes, self.steps, self.clip, image_size=(height, width))
  161. priors = priors.to(loc.device, loc.dtype)
  162. batched_dets: List[torch.Tensor] = []
  163. for batch_elem in range(loc.shape[0]):
  164. boxes = _decode(loc[batch_elem], priors(), self.variance) # Nx14
  165. boxes = boxes * scale
  166. # clamp here for the compatibility for ONNX
  167. cls_scores, iou_scores = conf[batch_elem, :, 1], iou[batch_elem, :, 0]
  168. scores = (cls_scores * iou_scores.clamp(0.0, 1.0)).sqrt()
  169. # ignore low scores
  170. inds = scores > self.confidence_threshold
  171. boxes, scores = boxes[inds], scores[inds]
  172. # keep top-K before NMS
  173. order = scores.sort(descending=True)[1][: self.top_k]
  174. boxes, scores = boxes[order], scores[order]
  175. # performd NMS
  176. # NOTE: nms need to be revise since does not export well to onnx
  177. dets = torch.cat((boxes, scores[:, None]), dim=-1) # Nx15
  178. keep = self.nms(boxes[:, :4], scores, self.nms_threshold)
  179. if len(keep) > 0:
  180. dets = dets[keep, :]
  181. # keep top-K faster NMS
  182. batched_dets.append(dets[: self.keep_top_k])
  183. return batched_dets
  184. def forward(self, image: torch.Tensor) -> List[torch.Tensor]:
  185. r"""Detect faces in a given batch of images.
  186. Args:
  187. image: batch of images :math:`(B,3,H,W)`
  188. Return:
  189. List[torch.Tensor]: list with the boxes found on each image. :math:`Bx(N,15)`.
  190. """
  191. img = self.preprocess(image)
  192. out = self.model(img)
  193. return self.postprocess(out, img.shape[-2], img.shape[-1])
  194. # utils for the network
  195. class ConvDPUnit(nn.Sequential):
  196. def __init__(self, in_channels: int, out_channels: int, withBNRelu: bool = True) -> None:
  197. super().__init__()
  198. self.add_module("conv1", nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=True, groups=1))
  199. self.add_module("conv2", nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=True, groups=out_channels))
  200. if withBNRelu:
  201. self.add_module("bn", nn.BatchNorm2d(out_channels))
  202. self.add_module("relu", nn.ReLU(inplace=True))
  203. class Conv_head(nn.Sequential):
  204. def __init__(self, in_channels: int, mid_channels: int, out_channels: int) -> None:
  205. super().__init__()
  206. self.add_module("conv1", nn.Conv2d(in_channels, mid_channels, 3, 2, 1, bias=True, groups=1))
  207. self.add_module("bn1", nn.BatchNorm2d(mid_channels))
  208. self.add_module("relu", nn.ReLU(inplace=True))
  209. self.add_module("conv2", ConvDPUnit(mid_channels, out_channels))
  210. class Conv4layerBlock(nn.Sequential):
  211. def __init__(self, in_channels: int, out_channels: int, withBNRelu: bool = True) -> None:
  212. super().__init__()
  213. self.add_module("conv1", ConvDPUnit(in_channels, in_channels, True))
  214. self.add_module("conv2", ConvDPUnit(in_channels, out_channels, withBNRelu))
  215. class YuFaceDetectNet(nn.Module):
  216. def __init__(self, phase: str, pretrained: bool) -> None:
  217. super().__init__()
  218. self.phase = phase
  219. self.num_classes = 2
  220. self.model0 = Conv_head(3, 16, 16)
  221. self.model1 = Conv4layerBlock(16, 64)
  222. self.model2 = Conv4layerBlock(64, 64)
  223. self.model3 = Conv4layerBlock(64, 64)
  224. self.model4 = Conv4layerBlock(64, 64)
  225. self.model5 = Conv4layerBlock(64, 64)
  226. self.model6 = Conv4layerBlock(64, 64)
  227. self.head = nn.Sequential(
  228. Conv4layerBlock(64, 3 * (14 + 2 + 1), False),
  229. Conv4layerBlock(64, 2 * (14 + 2 + 1), False),
  230. Conv4layerBlock(64, 2 * (14 + 2 + 1), False),
  231. Conv4layerBlock(64, 3 * (14 + 2 + 1), False),
  232. )
  233. if self.phase == "train":
  234. for m in self.modules():
  235. if isinstance(m, nn.Conv2d):
  236. if m.bias is not None:
  237. nn.init.xavier_normal_(m.weight.data)
  238. m.bias.data.fill_(0.02)
  239. else:
  240. m.weight.data.normal_(0, 0.01)
  241. elif isinstance(m, nn.BatchNorm2d):
  242. m.weight.data.fill_(1)
  243. m.bias.data.zero_()
  244. # use torch.hub to load pretrained model
  245. if pretrained:
  246. pretrained_dict = torch.hub.load_state_dict_from_url(url, map_location=torch.device("cpu"))
  247. self.load_state_dict(pretrained_dict, strict=True)
  248. self.eval()
  249. def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
  250. detection_sources, head_list = [], []
  251. x = self.model0(x)
  252. x = F.max_pool2d(x, 2)
  253. x = self.model1(x)
  254. x = self.model2(x)
  255. x = F.max_pool2d(x, 2)
  256. x = self.model3(x)
  257. detection_sources.append(x)
  258. x = F.max_pool2d(x, 2)
  259. x = self.model4(x)
  260. detection_sources.append(x)
  261. x = F.max_pool2d(x, 2)
  262. x = self.model5(x)
  263. detection_sources.append(x)
  264. x = F.max_pool2d(x, 2)
  265. x = self.model6(x)
  266. detection_sources.append(x)
  267. for i, h in enumerate(self.head):
  268. x_tmp = h(detection_sources[i])
  269. head_list.append(x_tmp.permute(0, 2, 3, 1).contiguous())
  270. head_data = torch.cat([o.view(o.size(0), -1) for o in head_list], 1)
  271. head_data = head_data.view(head_data.size(0), -1, 17)
  272. loc_data, conf_data, iou_data = head_data.split((14, 2, 1), dim=-1)
  273. if self.phase == "test":
  274. conf_data = torch.softmax(conf_data, dim=-1)
  275. else:
  276. loc_data = loc_data.view(loc_data.size(0), -1, 14)
  277. conf_data = conf_data.view(conf_data.size(0), -1, self.num_classes)
  278. iou_data = iou_data.view(iou_data.size(0), -1, 1)
  279. return {"loc": loc_data, "conf": conf_data, "iou": iou_data}
  280. # utils for post-processing
  281. # Adapted from https://github.com/Hakuyume/chainer-ssd
  282. def _decode(loc: torch.Tensor, priors: torch.Tensor, variances: List[float]) -> torch.Tensor:
  283. """Decode locations from predictions using priors to undo the encoding for offset regression at train time.
  284. Args:
  285. loc:location predictions for loc layers. Shape: [num_priors,4].
  286. priors: Prior boxes in center-offset form. Shape: [num_priors,4].
  287. variances: (list[float]) Variances of priorboxes.
  288. Return:
  289. Tensor containing decoded bounding box predictions.
  290. """
  291. boxes = torch.cat(
  292. (
  293. priors[:, 0:2] + loc[:, 0:2] * variances[0] * priors[:, 2:4],
  294. priors[:, 2:4] * torch.exp(loc[:, 2:4] * variances[1]),
  295. priors[:, 0:2] + loc[:, 4:6] * variances[0] * priors[:, 2:4],
  296. priors[:, 0:2] + loc[:, 6:8] * variances[0] * priors[:, 2:4],
  297. priors[:, 0:2] + loc[:, 8:10] * variances[0] * priors[:, 2:4],
  298. priors[:, 0:2] + loc[:, 10:12] * variances[0] * priors[:, 2:4],
  299. priors[:, 0:2] + loc[:, 12:14] * variances[0] * priors[:, 2:4],
  300. ),
  301. 1,
  302. )
  303. # prepare final output
  304. tmp = boxes[:, 0:2] - boxes[:, 2:4] / 2
  305. return torch.cat((tmp, boxes[:, 2:4] + tmp, boxes[:, 4:]), dim=-1)
  306. class _PriorBox:
  307. def __init__(self, min_sizes: List[List[int]], steps: List[int], clip: bool, image_size: Tuple[int, int]) -> None:
  308. self.min_sizes = min_sizes
  309. self.steps = steps
  310. self.clip = clip
  311. self.image_size = image_size
  312. self.device: torch.device = torch.device("cpu")
  313. self.dtype: torch.dtype = torch.float32
  314. for i in range(4):
  315. if self.steps[i] != math.pow(2, (i + 3)):
  316. raise ValueError("steps must be [8,16,32,64]")
  317. self.feature_map_2th = [int(int((self.image_size[0] + 1) / 2) / 2), int(int((self.image_size[1] + 1) / 2) / 2)]
  318. self.feature_map_3th = [int(self.feature_map_2th[0] / 2), int(self.feature_map_2th[1] / 2)]
  319. self.feature_map_4th = [int(self.feature_map_3th[0] / 2), int(self.feature_map_3th[1] / 2)]
  320. self.feature_map_5th = [int(self.feature_map_4th[0] / 2), int(self.feature_map_4th[1] / 2)]
  321. self.feature_map_6th = [int(self.feature_map_5th[0] / 2), int(self.feature_map_5th[1] / 2)]
  322. self.feature_maps = [self.feature_map_3th, self.feature_map_4th, self.feature_map_5th, self.feature_map_6th]
  323. def to(self, device: torch.device, dtype: torch.dtype) -> "_PriorBox":
  324. self.device = device
  325. self.dtype = dtype
  326. return self
  327. def __call__(self) -> torch.Tensor:
  328. anchors: List[float] = []
  329. for k, f in enumerate(self.feature_maps):
  330. min_sizes: List[int] = self.min_sizes[k]
  331. # NOTE: the nested loop it's to make torchscript happy
  332. for i in range(f[0]):
  333. for j in range(f[1]):
  334. for min_size in min_sizes:
  335. s_kx = min_size / self.image_size[1]
  336. s_ky = min_size / self.image_size[0]
  337. cx = (j + 0.5) * self.steps[k] / self.image_size[1]
  338. cy = (i + 0.5) * self.steps[k] / self.image_size[0]
  339. anchors += [cx, cy, s_kx, s_ky]
  340. # back to torch land
  341. output = torch.tensor(anchors, device=self.device, dtype=self.dtype).view(-1, 4)
  342. if self.clip:
  343. output = output.clamp(max=1, min=0)
  344. return output