utils.py 49 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import functools
  15. import math
  16. from typing import Optional, Union
  17. import torch
  18. from torch import Tensor
  19. from torch.nn.functional import conv2d, conv3d, pad, unfold
  20. from typing_extensions import Literal
  21. from torchmetrics.utilities.checks import _check_same_shape
  22. from torchmetrics.utilities.imports import _SCIPY_AVAILABLE
  23. def _ignore_background(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
  24. """Ignore the background class in the computation assuming it is the first, index 0."""
  25. preds = preds[:, 1:] if preds.shape[1] > 1 else preds
  26. target = target[:, 1:] if target.shape[1] > 1 else target
  27. return preds, target
  28. def _check_mixed_shape(preds: Tensor, target: Tensor) -> None:
  29. """Check that predictions and target have the same shape, else raise error."""
  30. if preds.dim() == (target.dim() + 1):
  31. if preds.shape[0] != target.shape[0] or preds.shape[2:] != target.shape[1:]:
  32. raise RuntimeError(
  33. f"Predictions and targets are expected to have the same shape, got {preds.shape} and {target.shape}."
  34. )
  35. elif (preds.dim() + 1) == target.dim():
  36. if preds.shape[0] != target.shape[0] or preds.shape[1:] != target.shape[2:]:
  37. raise RuntimeError(
  38. f"Predictions and targets are expected to have the same shape, got {preds.shape} and {target.shape}."
  39. )
  40. else:
  41. raise RuntimeError(
  42. f"Predictions and targets are expected to have the same shape, got {preds.shape} and {target.shape}."
  43. )
  44. def _segmentation_inputs_format(
  45. preds: Tensor,
  46. target: Tensor,
  47. include_background: bool,
  48. num_classes: Optional[int] = None,
  49. input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
  50. ) -> tuple[Tensor, Tensor]:
  51. """Check and format inputs to the one-hot encodings."""
  52. if input_format == "mixed":
  53. _check_mixed_shape(preds, target)
  54. else:
  55. _check_same_shape(preds, target)
  56. if input_format == "index":
  57. if num_classes is None:
  58. raise ValueError("Argument `num_classes` must be provided when `input_format='index'`.")
  59. preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
  60. target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
  61. elif input_format == "one-hot":
  62. if num_classes is None:
  63. num_classes = _get_num_classes(preds)
  64. preds = _format_logits(preds, num_classes)
  65. target = _format_logits(target, num_classes)
  66. elif input_format == "mixed":
  67. if preds.dim() == (target.dim() + 1):
  68. if num_classes is None:
  69. num_classes = _get_num_classes(preds)
  70. preds = _format_logits(preds, num_classes)
  71. target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
  72. elif (preds.dim() + 1) == target.dim():
  73. if num_classes is None:
  74. num_classes = _get_num_classes(target)
  75. target = _format_logits(target, num_classes)
  76. preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
  77. if preds.ndim < 3:
  78. raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")
  79. if not include_background:
  80. preds, target = _ignore_background(preds, target)
  81. return preds, target
  82. def _format_logits(tensor: Tensor, num_classes: int) -> Tensor:
  83. """Transform logits or probabilities into integer one-hot encodings."""
  84. if torch.is_floating_point(tensor):
  85. tensor = tensor.argmax(dim=1)
  86. tensor = torch.nn.functional.one_hot(tensor, num_classes=num_classes).movedim(-1, 1)
  87. return tensor
  88. def _get_num_classes(tensor: Tensor) -> int:
  89. """Get num classes from a tensor if it is not set."""
  90. try:
  91. num_classes = tensor.shape[1]
  92. except IndexError as err:
  93. raise IndexError(f"Cannot determine `num_classes` from tensor: {tensor}.") from err
  94. if num_classes == 0:
  95. raise ValueError(f"Expected argument `num_classes` to be a positive integer, but got {num_classes}.")
  96. return num_classes
  97. def check_if_binarized(x: Tensor) -> None:
  98. """Check if tensor is binarized.
  99. Example:
  100. >>> from torchmetrics.functional.segmentation.utils import check_if_binarized
  101. >>> import torch
  102. >>> check_if_binarized(torch.tensor([0, 1, 1, 0]))
  103. """
  104. if not torch.all(x.bool() == x):
  105. raise ValueError("Input x should be binarized")
  106. def _unfold(x: Tensor, kernel_size: tuple[int, ...]) -> Tensor:
  107. """Unfold the input tensor to a matrix. Function supports 3d images e.g. (B, C, D, H, W).
  108. Inspired by:
  109. https://github.com/f-dangel/unfoldNd/blob/main/unfoldNd/unfold.py
  110. Args:
  111. x: Input tensor to be unfolded.
  112. kernel_size: The size of the sliding blocks in each dimension.
  113. """
  114. batch_size, channels = x.shape[:2]
  115. n = x.ndim - 2
  116. if n == 2:
  117. return unfold(x, kernel_size)
  118. kernel_size_numel = kernel_size[0] * kernel_size[1] * kernel_size[2]
  119. repeat = [channels, 1] + [1 for _ in kernel_size]
  120. weight = torch.eye(kernel_size_numel, device=x.device, dtype=x.dtype)
  121. weight = weight.reshape(kernel_size_numel, 1, *kernel_size).repeat(*repeat)
  122. unfold_x = conv3d(x, weight=weight, bias=None)
  123. return unfold_x.reshape(batch_size, channels * kernel_size_numel, -1)
  124. def generate_binary_structure(rank: int, connectivity: int) -> Tensor:
  125. """Translated version of the function from scipy.ndimage.morphology.
  126. Args:
  127. rank: The rank of the structuring element.
  128. connectivity: The number of neighbors connected to a given pixel.
  129. Returns:
  130. The structuring element.
  131. Examples::
  132. >>> from torchmetrics.functional.segmentation.utils import generate_binary_structure
  133. >>> import torch
  134. >>> generate_binary_structure(2, 1)
  135. tensor([[False, True, False],
  136. [ True, True, True],
  137. [False, True, False]])
  138. >>> generate_binary_structure(2, 2)
  139. tensor([[True, True, True],
  140. [True, True, True],
  141. [True, True, True]])
  142. >>> generate_binary_structure(3, 2) # doctest: +NORMALIZE_WHITESPACE
  143. tensor([[[False, True, False],
  144. [ True, True, True],
  145. [False, True, False]],
  146. [[ True, True, True],
  147. [ True, True, True],
  148. [ True, True, True]],
  149. [[False, True, False],
  150. [ True, True, True],
  151. [False, True, False]]])
  152. """
  153. if connectivity < 1:
  154. connectivity = 1
  155. if rank < 1:
  156. return torch.tensor([1], dtype=torch.uint8)
  157. grids = torch.meshgrid([torch.arange(3) for _ in range(rank)], indexing="ij")
  158. output = torch.abs(torch.stack(grids, dim=0) - 1)
  159. output = torch.sum(output, dim=0)
  160. return output <= connectivity
  161. def binary_erosion(
  162. image: Tensor, structure: Optional[Tensor] = None, origin: Optional[tuple[int, ...]] = None, border_value: int = 0
  163. ) -> Tensor:
  164. """Binary erosion of a tensor image.
  165. Implementation inspired by answer to this question: https://stackoverflow.com/questions/56235733/
  166. Args:
  167. image: The image to be eroded, must be a binary tensor with shape ``(batch_size, channels, height, width)``.
  168. structure: The structuring element used for the erosion. If no structuring element is provided, an element
  169. is generated with a square connectivity equal to one.
  170. origin: The origin of the structuring element.
  171. border_value: The value to be used for the border.
  172. Examples::
  173. >>> from torchmetrics.functional.segmentation.utils import binary_erosion
  174. >>> import torch
  175. >>> image = torch.tensor([[[[0, 0, 0, 0, 0],
  176. ... [0, 1, 1, 1, 0],
  177. ... [0, 1, 1, 1, 0],
  178. ... [0, 1, 1, 1, 0],
  179. ... [0, 0, 0, 0, 0]]]])
  180. >>> binary_erosion(image)
  181. tensor([[[[0, 0, 0, 0, 0],
  182. [0, 0, 0, 0, 0],
  183. [0, 0, 1, 0, 0],
  184. [0, 0, 0, 0, 0],
  185. [0, 0, 0, 0, 0]]]], dtype=torch.uint8)
  186. >>> binary_erosion(image, structure=torch.ones(4, 4))
  187. tensor([[[[0, 0, 0, 0, 0],
  188. [0, 0, 0, 0, 0],
  189. [0, 0, 0, 0, 0],
  190. [0, 0, 0, 0, 0],
  191. [0, 0, 0, 0, 0]]]], dtype=torch.uint8)
  192. """
  193. if not isinstance(image, Tensor):
  194. raise TypeError(f"Expected argument `image` to be of type Tensor but found {type(image)}")
  195. if image.ndim not in [4, 5]:
  196. raise ValueError(f"Expected argument `image` to be of rank 4 or 5 but found rank {image.ndim}")
  197. check_if_binarized(image)
  198. # construct the structuring element if not provided
  199. if structure is None:
  200. structure = generate_binary_structure(image.ndim - 2, 1).int().to(image.device)
  201. check_if_binarized(structure)
  202. if origin is None:
  203. origin = structure.ndim * (1,)
  204. # first pad the image to have correct unfolding; here is where the origins is used
  205. image_pad = pad(
  206. image,
  207. [x for i in range(len(origin)) for x in [origin[i], structure.shape[i] - origin[i] - 1]],
  208. mode="constant",
  209. value=border_value,
  210. )
  211. # Unfold the image to be able to perform operation on neighborhoods
  212. image_unfold = _unfold(image_pad.float(), kernel_size=structure.shape)
  213. strel_flatten = torch.flatten(structure).unsqueeze(0).unsqueeze(-1)
  214. sums = image_unfold - strel_flatten.int()
  215. # Take minimum over the neighborhood
  216. result, _ = sums.min(dim=1)
  217. # Reshape the image to recover initial shape
  218. return (torch.reshape(result, image.shape) + 1).byte()
  219. def distance_transform(
  220. x: Tensor,
  221. sampling: Optional[Union[Tensor, list[float]]] = None,
  222. metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
  223. engine: Literal["pytorch", "scipy"] = "pytorch",
  224. ) -> Tensor:
  225. """Calculate distance transform of a binary tensor.
  226. This function calculates the distance transform of a binary tensor, replacing each foreground pixel with the
  227. distance to the closest background pixel. The distance is calculated using the euclidean, chessboard or taxicab
  228. distance.
  229. The memory consumption of this function is in the worst cast N/2**2 where N is the number of pixel. Since we need
  230. to compare all foreground pixels to all background pixels, the memory consumption is quadratic in the number of
  231. pixels. The memory consumption can be reduced by using the ``scipy`` engine, which is more memory efficient but
  232. should also be slower for larger images.
  233. Args:
  234. x: The binary tensor to calculate the distance transform of.
  235. sampling: The sampling refers to the pixel spacing in the image, i.e. the distance between two adjacent pixels.
  236. If not provided, the pixel spacing is assumed to be 1.
  237. metric: The distance to use for the distance transform. Can be one of ``"euclidean"``, ``"chessboard"``
  238. or ``"taxicab"``.
  239. engine: The engine to use for the distance transform. Can be one of ``["pytorch", "scipy"]``. In general,
  240. the ``pytorch`` engine is faster, but the ``scipy`` engine is more memory efficient.
  241. Returns:
  242. The distance transform of the input tensor.
  243. Examples::
  244. >>> from torchmetrics.functional.segmentation.utils import distance_transform
  245. >>> import torch
  246. >>> x = torch.tensor([[0, 0, 0, 0, 0],
  247. ... [0, 1, 1, 1, 0],
  248. ... [0, 1, 1, 1, 0],
  249. ... [0, 1, 1, 1, 0],
  250. ... [0, 0, 0, 0, 0]])
  251. >>> distance_transform(x)
  252. tensor([[0., 0., 0., 0., 0.],
  253. [0., 1., 1., 1., 0.],
  254. [0., 1., 2., 1., 0.],
  255. [0., 1., 1., 1., 0.],
  256. [0., 0., 0., 0., 0.]])
  257. """
  258. if not isinstance(x, Tensor):
  259. raise ValueError(f"Expected argument `x` to be of type `torch.Tensor` but got `{type(x)}`.")
  260. if x.ndim != 2:
  261. raise ValueError(f"Expected argument `x` to be of rank 2 but got rank `{x.ndim}`.")
  262. if sampling is not None and not isinstance(sampling, list):
  263. raise ValueError(
  264. f"Expected argument `sampling` to either be `None` or of type `list` but got `{type(sampling)}`."
  265. )
  266. if metric not in ["euclidean", "chessboard", "taxicab"]:
  267. raise ValueError(
  268. f"Expected argument `metric` to be one of `['euclidean', 'chessboard', 'taxicab']` but got `{metric}`."
  269. )
  270. if engine not in ["pytorch", "scipy"]:
  271. raise ValueError(f"Expected argument `engine` to be one of `['pytorch', 'scipy']` but got `{engine}`.")
  272. if sampling is None:
  273. sampling = [1, 1]
  274. else:
  275. if len(sampling) != 2:
  276. raise ValueError(f"Expected argument `sampling` to have length 2 but got length `{len(sampling)}`.")
  277. if engine == "pytorch":
  278. x = x.float()
  279. # calculate distance from every foreground pixel to every background pixel
  280. i0, j0 = torch.where(x == 0)
  281. i1, j1 = torch.where(x == 1)
  282. dis_row = (i1.view(-1, 1) - i0.view(1, -1)).abs()
  283. dis_col = (j1.view(-1, 1) - j0.view(1, -1)).abs()
  284. # # calculate distance
  285. h, _ = x.shape
  286. if metric == "euclidean":
  287. dis = ((sampling[0] * dis_row) ** 2 + (sampling[1] * dis_col) ** 2).sqrt()
  288. if metric == "chessboard":
  289. dis = torch.max(sampling[0] * dis_row, sampling[1] * dis_col).float()
  290. if metric == "taxicab":
  291. dis = (sampling[0] * dis_row + sampling[1] * dis_col).float()
  292. # select only the closest distance
  293. mindis, _ = torch.min(dis, dim=1)
  294. z = torch.zeros_like(x).view(-1)
  295. z[i1 * h + j1] = mindis
  296. return z.view(x.shape)
  297. if not _SCIPY_AVAILABLE:
  298. raise ValueError(
  299. "The `scipy` engine requires `scipy` to be installed. Either install `scipy` or use the `pytorch` engine."
  300. )
  301. from scipy import ndimage
  302. if metric == "euclidean":
  303. return ndimage.distance_transform_edt(x.cpu().numpy(), sampling)
  304. return ndimage.distance_transform_cdt(x.cpu().numpy(), sampling, metric=metric)
  305. def mask_edges(
  306. preds: Tensor,
  307. target: Tensor,
  308. crop: bool = True,
  309. spacing: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
  310. ) -> Union[tuple[Tensor, Tensor], tuple[Tensor, Tensor, Tensor, Tensor]]:
  311. """Get the edges of binary segmentation masks.
  312. Args:
  313. preds: The predicted binary segmentation mask
  314. target: The ground truth binary segmentation mask
  315. crop: Whether to crop the edges to the region of interest. If ``True``, the edges are cropped to the bounding
  316. spacing: The pixel spacing of the input images. If provided, the edges are calculated using the euclidean
  317. Returns:
  318. If spacing is not provided, a 2-tuple containing the edges of the predicted and target mask respectively is
  319. returned. If spacing is provided, a 4-tuple containing the edges and areas of the predicted and target mask
  320. respectively is returned.
  321. """
  322. _check_same_shape(preds, target)
  323. if preds.ndim not in [2, 3]:
  324. raise ValueError(f"Expected argument `preds` to be of rank 2 or 3 but got rank `{preds.ndim}`.")
  325. check_if_binarized(preds)
  326. check_if_binarized(target)
  327. if crop:
  328. or_val = preds | target
  329. if not or_val.any():
  330. p, t = torch.zeros_like(preds), torch.zeros_like(target)
  331. return p, t, p, t
  332. # this seems to be working but does not seem to be right
  333. preds, target = pad(preds, preds.ndim * [1, 1]), pad(target, target.ndim * [1, 1])
  334. if spacing is None:
  335. # no spacing, use binary erosion
  336. be_pred = binary_erosion(preds.unsqueeze(0).unsqueeze(0)).squeeze() ^ preds
  337. be_target = binary_erosion(target.unsqueeze(0).unsqueeze(0)).squeeze() ^ target
  338. return be_pred, be_target
  339. # use neighborhood to get edges
  340. table, kernel = get_neighbour_tables(spacing, device=preds.device)
  341. spatial_dims = len(spacing)
  342. conv_operator = conv2d if spatial_dims == 2 else conv3d
  343. volume = torch.stack([preds.unsqueeze(0), target.unsqueeze(0)], dim=0).float()
  344. code_preds, code_target = conv_operator(volume, kernel.to(volume))
  345. # edges
  346. all_ones = len(table) - 1
  347. edges_preds = (code_preds != 0) & (code_preds != all_ones)
  348. edges_target = (code_target != 0) & (code_target != all_ones)
  349. # # areas of edges
  350. areas_preds = torch.index_select(table, 0, code_preds.view(-1).int()).view_as(code_preds)
  351. areas_target = torch.index_select(table, 0, code_target.view(-1).int()).view_as(code_target)
  352. return edges_preds[0], edges_target[0], areas_preds[0], areas_target[0]
  353. def surface_distance(
  354. preds: Tensor,
  355. target: Tensor,
  356. distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
  357. spacing: Optional[Union[Tensor, list[float]]] = None,
  358. ) -> Tensor:
  359. """Calculate the surface distance between two binary edge masks.
  360. May return infinity if the predicted mask is empty and the target mask is not, or vice versa.
  361. Args:
  362. preds: The predicted binary edge mask.
  363. target: The target binary edge mask.
  364. distance_metric: The distance metric to use. One of `["euclidean", "chessboard", "taxicab"]`.
  365. spacing: The spacing between pixels along each spatial dimension.
  366. Returns:
  367. A tensor with length equal to the number of edges in predictions e.g. `preds.sum()`. Each element is the
  368. distance from the corresponding edge in `preds` to the closest edge in `target`.
  369. Example::
  370. >>> import torch
  371. >>> from torchmetrics.functional.segmentation.utils import surface_distance
  372. >>> preds = torch.tensor([[1, 1, 1, 1, 1],
  373. ... [1, 0, 0, 0, 1],
  374. ... [1, 0, 0, 0, 1],
  375. ... [1, 0, 0, 0, 1],
  376. ... [1, 1, 1, 1, 1]], dtype=torch.bool)
  377. >>> target = torch.tensor([[1, 1, 1, 1, 0],
  378. ... [1, 0, 0, 1, 0],
  379. ... [1, 0, 0, 1, 0],
  380. ... [1, 0, 0, 1, 0],
  381. ... [1, 1, 1, 1, 0]], dtype=torch.bool)
  382. >>> surface_distance(preds, target, distance_metric="euclidean", spacing=[1, 1])
  383. tensor([0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1.])
  384. """
  385. if not (preds.dtype == torch.bool and target.dtype == torch.bool):
  386. raise ValueError(f"Expected both inputs to be of type `torch.bool`, but got {preds.dtype} and {target.dtype}.")
  387. if not torch.any(target):
  388. dis = torch.inf * torch.ones_like(target)
  389. else:
  390. if not torch.any(preds):
  391. dis = torch.inf * torch.ones_like(preds)
  392. return dis[target]
  393. dis = distance_transform(~target, sampling=spacing, metric=distance_metric)
  394. return dis[preds]
  395. def edge_surface_distance(
  396. preds: Tensor,
  397. target: Tensor,
  398. distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
  399. spacing: Optional[Union[Tensor, list[float]]] = None,
  400. symmetric: bool = False,
  401. ) -> Union[Tensor, tuple[Tensor, Tensor]]:
  402. """Extracts the edges from the input masks and calculates the surface distance between them.
  403. Args:
  404. preds: The predicted binary edge mask.
  405. target: The target binary edge mask.
  406. distance_metric: The distance metric to use. One of `["euclidean", "chessboard", "taxicab"]`.
  407. spacing: The spacing between pixels along each spatial dimension.
  408. symmetric: Whether to calculate the symmetric distance between the edges.
  409. Returns:
  410. A tensor with length equal to the number of edges in predictions e.g. `preds.sum()`. Each element is the
  411. distance from the corresponding edge in `preds` to the closest edge in `target`. If `symmetric` is `True`, the
  412. function returns a tuple containing the distances from the predicted edges to the target edges and vice versa.
  413. """
  414. output = mask_edges(preds, target)
  415. edges_preds, edges_target = output[0].bool(), output[1].bool()
  416. if symmetric:
  417. return (
  418. surface_distance(edges_preds, edges_target, distance_metric=distance_metric, spacing=spacing),
  419. surface_distance(edges_target, edges_preds, distance_metric=distance_metric, spacing=spacing),
  420. )
  421. return surface_distance(edges_preds, edges_target, distance_metric=distance_metric, spacing=spacing)
  422. @functools.lru_cache
  423. def get_neighbour_tables(
  424. spacing: Union[tuple[int, int], tuple[int, int, int]], device: Optional[torch.device] = None
  425. ) -> tuple[Tensor, Tensor]:
  426. """Create a table that maps neighbour codes to the contour length or surface area of the corresponding contour.
  427. Args:
  428. spacing: The spacing between pixels along each spatial dimension.
  429. device: The device on which the table should be created.
  430. Returns:
  431. A tuple containing as its first element the table that maps neighbour codes to the contour length or surface
  432. area of the corresponding contour and as its second element the kernel used to compute the neighbour codes.
  433. """
  434. if isinstance(spacing, tuple) and len(spacing) == 2:
  435. return table_contour_length(spacing, device)
  436. if isinstance(spacing, tuple) and len(spacing) == 3:
  437. return table_surface_area(spacing, device)
  438. raise ValueError("The spacing must be a tuple of length 2 or 3.")
  439. def table_contour_length(spacing: tuple[int, int], device: Optional[torch.device] = None) -> tuple[Tensor, Tensor]:
  440. """Create a table that maps neighbour codes to the contour length of the corresponding contour.
  441. Adopted from:
  442. https://github.com/deepmind/surface-distance/blob/master/surface_distance/lookup_tables.py
  443. Args:
  444. spacing: The spacing between pixels along each spatial dimension. Should be a tuple of length 2.
  445. device: The device on which the table should be created.
  446. Returns:
  447. A tuple containing as its first element the table that maps neighbour codes to the contour length of the
  448. corresponding contour and as its second element the kernel used to compute the neighbour codes.
  449. Example::
  450. >>> from torchmetrics.functional.segmentation.utils import table_contour_length
  451. >>> table, kernel = table_contour_length((2,2))
  452. >>> table
  453. tensor([0.0000, 1.4142, 1.4142, 2.0000, 1.4142, 2.0000, 2.8284, 1.4142, 1.4142,
  454. 2.8284, 2.0000, 1.4142, 2.0000, 1.4142, 1.4142, 0.0000])
  455. >>> kernel
  456. tensor([[[[8, 4],
  457. [2, 1]]]])
  458. """
  459. if not isinstance(spacing, tuple) and len(spacing) != 2:
  460. raise ValueError("The spacing must be a tuple of length 2.")
  461. first, second = spacing # spacing along the first and second spatial dimension respectively
  462. diag = 0.5 * math.sqrt(first**2 + second**2)
  463. table = torch.zeros(16, dtype=torch.float32, device=device)
  464. for i in [1, 2, 4, 7, 8, 11, 13, 14]:
  465. table[i] = diag
  466. for i in [3, 12]:
  467. table[i] = second
  468. for i in [5, 10]:
  469. table[i] = first
  470. for i in [6, 9]:
  471. table[i] = 2 * diag
  472. kernel = torch.as_tensor([[[[8, 4], [2, 1]]]], device=device)
  473. return table, kernel
  474. @functools.lru_cache
  475. def table_surface_area(spacing: tuple[int, int, int], device: Optional[torch.device] = None) -> tuple[Tensor, Tensor]:
  476. """Create a table that maps neighbour codes to the surface area of the corresponding surface.
  477. Adopted from:
  478. https://github.com/deepmind/surface-distance/blob/master/surface_distance/lookup_tables.py
  479. Args:
  480. spacing: The spacing between pixels along each spatial dimension. Should be a tuple of length 3.
  481. device: The device on which the table should be created.
  482. Returns:
  483. A tuple containing as its first element the table that maps neighbour codes to the surface area of the
  484. corresponding surface and as its second element the kernel used to compute the neighbour codes.
  485. Example::
  486. >>> from torchmetrics.functional.segmentation.utils import table_surface_area
  487. >>> table, kernel = table_surface_area((2,2,2))
  488. >>> table
  489. tensor([0.0000, 0.8660, 0.8660, 2.8284, 0.8660, 2.8284, 1.7321, 4.5981, 0.8660,
  490. 1.7321, 2.8284, 4.5981, 2.8284, 4.5981, 4.5981, 4.0000, 0.8660, 2.8284,
  491. 1.7321, 4.5981, 1.7321, 4.5981, 2.5981, 5.1962, 1.7321, 3.6945, 3.6945,
  492. 6.2925, 3.6945, 6.2925, 5.4641, 4.5981, 0.8660, 1.7321, 2.8284, 4.5981,
  493. 1.7321, 3.6945, 3.6945, 6.2925, 1.7321, 2.5981, 4.5981, 5.1962, 3.6945,
  494. 5.4641, 6.2925, 4.5981, 2.8284, 4.5981, 4.5981, 4.0000, 3.6945, 6.2925,
  495. 5.4641, 4.5981, 3.6945, 5.4641, 6.2925, 4.5981, 5.6569, 3.6945, 3.6945,
  496. 2.8284, 0.8660, 1.7321, 1.7321, 3.6945, 2.8284, 4.5981, 3.6945, 6.2925,
  497. 1.7321, 2.5981, 3.6945, 5.4641, 4.5981, 5.1962, 6.2925, 4.5981, 2.8284,
  498. 4.5981, 3.6945, 6.2925, 4.5981, 4.0000, 5.4641, 4.5981, 3.6945, 5.4641,
  499. 5.6569, 3.6945, 6.2925, 4.5981, 3.6945, 2.8284, 1.7321, 2.5981, 3.6945,
  500. 5.4641, 3.6945, 5.4641, 5.6569, 3.6945, 2.5981, 3.4641, 5.4641, 2.5981,
  501. 5.4641, 2.5981, 3.6945, 1.7321, 4.5981, 5.1962, 6.2925, 4.5981, 6.2925,
  502. 4.5981, 3.6945, 2.8284, 5.4641, 2.5981, 3.6945, 1.7321, 3.6945, 1.7321,
  503. 1.7321, 0.8660, 0.8660, 1.7321, 1.7321, 3.6945, 1.7321, 3.6945, 2.5981,
  504. 5.4641, 2.8284, 3.6945, 4.5981, 6.2925, 4.5981, 6.2925, 5.1962, 4.5981,
  505. 1.7321, 3.6945, 2.5981, 5.4641, 2.5981, 5.4641, 3.4641, 2.5981, 3.6945,
  506. 5.6569, 5.4641, 3.6945, 5.4641, 3.6945, 2.5981, 1.7321, 2.8284, 3.6945,
  507. 4.5981, 6.2925, 3.6945, 5.6569, 5.4641, 3.6945, 4.5981, 5.4641, 4.0000,
  508. 4.5981, 6.2925, 3.6945, 4.5981, 2.8284, 4.5981, 6.2925, 5.1962, 4.5981,
  509. 5.4641, 3.6945, 2.5981, 1.7321, 6.2925, 3.6945, 4.5981, 2.8284, 3.6945,
  510. 1.7321, 1.7321, 0.8660, 2.8284, 3.6945, 3.6945, 5.6569, 4.5981, 6.2925,
  511. 5.4641, 3.6945, 4.5981, 5.4641, 6.2925, 3.6945, 4.0000, 4.5981, 4.5981,
  512. 2.8284, 4.5981, 6.2925, 5.4641, 3.6945, 5.1962, 4.5981, 2.5981, 1.7321,
  513. 6.2925, 3.6945, 3.6945, 1.7321, 4.5981, 2.8284, 1.7321, 0.8660, 4.5981,
  514. 5.4641, 6.2925, 3.6945, 6.2925, 3.6945, 3.6945, 1.7321, 5.1962, 2.5981,
  515. 4.5981, 1.7321, 4.5981, 1.7321, 2.8284, 0.8660, 4.0000, 4.5981, 4.5981,
  516. 2.8284, 4.5981, 2.8284, 1.7321, 0.8660, 4.5981, 1.7321, 2.8284, 0.8660,
  517. 2.8284, 0.8660, 0.8660, 0.0000])
  518. >>> kernel
  519. tensor([[[[[128, 64],
  520. [ 32, 16]],
  521. [[ 8, 4],
  522. [ 2, 1]]]]])
  523. """
  524. if not isinstance(spacing, tuple) and len(spacing) != 3:
  525. raise ValueError("The spacing must be a tuple of length 3.")
  526. zeros = [0.0, 0.0, 0.0]
  527. table = torch.tensor(
  528. [
  529. [zeros, zeros, zeros, zeros],
  530. [[0.125, 0.125, 0.125], zeros, zeros, zeros],
  531. [[-0.125, -0.125, 0.125], zeros, zeros, zeros],
  532. [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros, zeros],
  533. [[0.125, -0.125, 0.125], zeros, zeros, zeros],
  534. [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros, zeros],
  535. [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],
  536. [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros],
  537. [[-0.125, 0.125, 0.125], zeros, zeros, zeros],
  538. [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros],
  539. [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros, zeros],
  540. [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],
  541. [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros, zeros],
  542. [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], zeros],
  543. [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],
  544. [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0], zeros, zeros],
  545. [[0.125, -0.125, -0.125], zeros, zeros, zeros],
  546. [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], zeros, zeros],
  547. [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],
  548. [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros],
  549. [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],
  550. [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125], zeros],
  551. [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros],
  552. [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
  553. [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],
  554. [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],
  555. [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125], zeros],
  556. [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]],
  557. [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros],
  558. [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],
  559. [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]],
  560. [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], zeros],
  561. [[0.125, -0.125, 0.125], zeros, zeros, zeros],
  562. [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],
  563. [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros, zeros],
  564. [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], zeros],
  565. [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],
  566. [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros],
  567. [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125], zeros],
  568. [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]],
  569. [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],
  570. [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros],
  571. [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],
  572. [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]],
  573. [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros],
  574. [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]],
  575. [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]],
  576. [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],
  577. [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros, zeros],
  578. [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125], zeros],
  579. [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25], zeros],
  580. [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0], zeros, zeros],
  581. [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125], zeros],
  582. [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],
  583. [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],
  584. [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125], zeros],
  585. [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros],
  586. [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
  587. [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],
  588. [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros],
  589. [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
  590. [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125], zeros],
  591. [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros],
  592. [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros, zeros],
  593. [[-0.125, -0.125, 0.125], zeros, zeros, zeros],
  594. [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],
  595. [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],
  596. [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros],
  597. [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], zeros, zeros],
  598. [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros],
  599. [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],
  600. [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]],
  601. [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros],
  602. [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros],
  603. [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros],
  604. [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
  605. [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros],
  606. [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]],
  607. [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]],
  608. [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros],
  609. [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros, zeros],
  610. [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],
  611. [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros],
  612. [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]],
  613. [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros],
  614. [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5], zeros, zeros],
  615. [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
  616. [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros],
  617. [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125], zeros],
  618. [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
  619. [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],
  620. [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros],
  621. [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]],
  622. [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros],
  623. [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros],
  624. [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros, zeros],
  625. [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],
  626. [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros],
  627. [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros],
  628. [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
  629. [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125], zeros],
  630. [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
  631. [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],
  632. [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125], zeros],
  633. [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros],
  634. [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],
  635. [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
  636. [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros],
  637. [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
  638. [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros],
  639. [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125], zeros],
  640. [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],
  641. [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125], zeros],
  642. [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],
  643. [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],
  644. [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],
  645. [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]],
  646. [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros],
  647. [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros],
  648. [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros, zeros],
  649. [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],
  650. [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125], zeros],
  651. [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125], zeros],
  652. [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],
  653. [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125], zeros],
  654. [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],
  655. [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], zeros, zeros],
  656. [[0.125, 0.125, 0.125], zeros, zeros, zeros],
  657. [[0.125, 0.125, 0.125], zeros, zeros, zeros],
  658. [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], zeros, zeros],
  659. [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],
  660. [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125], zeros],
  661. [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],
  662. [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125], zeros],
  663. [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125], zeros],
  664. [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],
  665. [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros, zeros],
  666. [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros],
  667. [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros],
  668. [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]],
  669. [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],
  670. [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],
  671. [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],
  672. [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125], zeros],
  673. [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],
  674. [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125], zeros],
  675. [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros],
  676. [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
  677. [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros],
  678. [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
  679. [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],
  680. [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros],
  681. [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125], zeros],
  682. [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],
  683. [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
  684. [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125], zeros],
  685. [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
  686. [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros],
  687. [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros],
  688. [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],
  689. [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros, zeros],
  690. [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros],
  691. [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros],
  692. [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]],
  693. [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros],
  694. [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
  695. [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
  696. [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125], zeros],
  697. [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros],
  698. [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
  699. [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5], zeros, zeros],
  700. [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros],
  701. [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]],
  702. [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros],
  703. [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],
  704. [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros, zeros],
  705. [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros],
  706. [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]],
  707. [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]],
  708. [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros],
  709. [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
  710. [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros],
  711. [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros],
  712. [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros],
  713. [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]],
  714. [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],
  715. [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros],
  716. [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], zeros, zeros],
  717. [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros],
  718. [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],
  719. [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],
  720. [[-0.125, -0.125, 0.125], zeros, zeros, zeros],
  721. [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros, zeros],
  722. [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros],
  723. [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125], zeros],
  724. [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],
  725. [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros],
  726. [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],
  727. [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
  728. [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros],
  729. [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125], zeros],
  730. [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],
  731. [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],
  732. [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125], zeros],
  733. [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0], zeros, zeros],
  734. [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25], zeros],
  735. [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125], zeros],
  736. [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros, zeros],
  737. [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],
  738. [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]],
  739. [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]],
  740. [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros],
  741. [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]],
  742. [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],
  743. [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros],
  744. [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],
  745. [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]],
  746. [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125], zeros],
  747. [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros],
  748. [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],
  749. [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], zeros],
  750. [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros, zeros],
  751. [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],
  752. [[0.125, -0.125, 0.125], zeros, zeros, zeros],
  753. [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], zeros],
  754. [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]],
  755. [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],
  756. [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros],
  757. [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]],
  758. [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125], zeros],
  759. [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],
  760. [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],
  761. [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
  762. [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros],
  763. [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125], zeros],
  764. [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],
  765. [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros],
  766. [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],
  767. [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], zeros, zeros],
  768. [[0.125, -0.125, -0.125], zeros, zeros, zeros],
  769. [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0], zeros, zeros],
  770. [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],
  771. [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], zeros],
  772. [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros, zeros],
  773. [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],
  774. [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros, zeros],
  775. [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros],
  776. [[-0.125, 0.125, 0.125], zeros, zeros, zeros],
  777. [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros],
  778. [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],
  779. [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros, zeros],
  780. [[0.125, 0.125, 0.125], zeros, zeros, zeros],
  781. [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros, zeros],
  782. [[0.125, 0.125, 0.125], zeros, zeros, zeros],
  783. [[0.125, 0.125, 0.125], zeros, zeros, zeros],
  784. [zeros, zeros, zeros, zeros],
  785. ],
  786. dtype=torch.float32,
  787. device=device,
  788. )
  789. space = torch.as_tensor(
  790. [[[spacing[1] * spacing[2], spacing[0] * spacing[2], spacing[0] * spacing[1]]]],
  791. device=device,
  792. dtype=table.dtype,
  793. )
  794. norm = torch.linalg.norm(table * space, dim=-1)
  795. table = norm.sum(-1)
  796. kernel = torch.as_tensor([[[[[128, 64], [32, 16]], [[8, 4], [2, 1]]]]], device=device)
  797. return table, kernel