pointcloud_io.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import os
  18. import torch
  19. def save_pointcloud_ply(filename: str, pointcloud: torch.Tensor) -> None:
  20. r"""Save to disk a pointcloud in PLY format.
  21. Args:
  22. filename: the path to save the pointcloud.
  23. pointcloud: tensor containing the pointcloud to save.
  24. The tensor must be in the shape of :math:`(*, 3)` where the last
  25. component is assumed to be a 3d point coordinate :math:`(X, Y, Z)`.
  26. """
  27. if not (isinstance(filename, str) and filename.lower().endswith(".ply")):
  28. raise TypeError(f"Input filename must be a string with the .ply extension. Got {filename!r}")
  29. if not torch.is_tensor(pointcloud):
  30. raise TypeError(f"Input pointcloud type is not a torch.Tensor. Got {type(pointcloud)}")
  31. if pointcloud.ndim < 2 or pointcloud.shape[-1] != 3:
  32. raise TypeError(f"Input pointcloud must have shape (..., 3). Got {tuple(pointcloud.shape)}")
  33. # Flatten points
  34. xyz = pointcloud.reshape(-1, 3)
  35. valid_mask = torch.isfinite(xyz).any(dim=1)
  36. valid_points = xyz[valid_mask]
  37. valid_count = valid_points.shape[0]
  38. with open(filename, "w", encoding="utf-8") as f:
  39. # Write PLY header
  40. f.writelines(
  41. [
  42. "ply\n",
  43. "format ascii 1.0\n",
  44. "comment arraiy generated\n",
  45. f"element vertex {valid_count}\n",
  46. "property double x\n",
  47. "property double y\n",
  48. "property double z\n",
  49. "end_header\n",
  50. ]
  51. )
  52. if valid_count > 0:
  53. # Move to CPU, convert to float64 for matching 'double' in header
  54. arr = valid_points.detach().cpu().to(torch.float64)
  55. # Write each row as space-separated floats
  56. for x, y, z in arr.tolist():
  57. f.write(f"{x:.9g} {y:.9g} {z:.9g}\n")
  58. def load_pointcloud_ply(filename: str, header_size: int = 8) -> torch.Tensor:
  59. r"""Load from disk a pointcloud in PLY format.
  60. Args:
  61. filename: the path to the pointcloud.
  62. header_size: the number of header lines to skip.
  63. Return:
  64. tensor containing the loaded points with shape :math:`(*, 3)` where
  65. :math:`*` represents the number of points.
  66. """
  67. if not (isinstance(filename, str) and filename.lower().endswith(".ply")):
  68. raise TypeError(f"Input filename must be a string with the .ply extension. Got {filename!r}")
  69. if not os.path.isfile(filename):
  70. raise ValueError("Input filename is not an existing file.")
  71. if not (isinstance(header_size, int) and header_size > 0):
  72. raise TypeError(f"Input header_size must be a positive integer. Got {header_size}.")
  73. # Read all file bytes
  74. with open(filename, "rb") as f:
  75. # Skip header lines
  76. for _ in range(header_size):
  77. f.readline()
  78. raw_data = f.read()
  79. # Decode once and split (faster than line-by-line parsing in Python)
  80. text = raw_data.decode("utf-8", errors="ignore")
  81. parts = text.split()
  82. # We only take the first 3 columns per point
  83. if len(parts) % 3 != 0:
  84. raise ValueError(f"Expected 3 columns per point, got a total of {len(parts)} values.")
  85. # Convert directly to a float32 tensor in one go
  86. tensor = torch.tensor(list(map(float, parts[: (len(parts) // 3) * 3])), dtype=torch.float32).view(-1, 3)
  87. return tensor