_deprecated.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from collections.abc import Collection
  2. from torch import Tensor
  3. from torchmetrics.functional.detection.panoptic_qualities import modified_panoptic_quality, panoptic_quality
  4. from torchmetrics.utilities.prints import _deprecated_root_import_func
  5. def _modified_panoptic_quality(
  6. preds: Tensor,
  7. target: Tensor,
  8. things: Collection[int],
  9. stuffs: Collection[int],
  10. allow_unknown_preds_category: bool = False,
  11. ) -> Tensor:
  12. """Wrapper for deprecated import.
  13. >>> from torch import tensor
  14. >>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]])
  15. >>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]])
  16. >>> _modified_panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7})
  17. tensor(0.7667, dtype=torch.float64)
  18. """
  19. _deprecated_root_import_func("modified_panoptic_quality", "detection")
  20. return modified_panoptic_quality(
  21. preds=preds,
  22. target=target,
  23. things=things,
  24. stuffs=stuffs,
  25. allow_unknown_preds_category=allow_unknown_preds_category,
  26. )
  27. def _panoptic_quality(
  28. preds: Tensor,
  29. target: Tensor,
  30. things: Collection[int],
  31. stuffs: Collection[int],
  32. allow_unknown_preds_category: bool = False,
  33. ) -> Tensor:
  34. """Wrapper for deprecated import.
  35. >>> from torch import tensor
  36. >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
  37. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  38. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  39. ... [[0, 0], [7, 0], [6, 0], [1, 0]],
  40. ... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
  41. >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
  42. ... [[0, 1], [0, 1], [6, 0], [0, 1]],
  43. ... [[0, 1], [0, 1], [6, 0], [1, 0]],
  44. ... [[0, 1], [7, 0], [1, 0], [1, 0]],
  45. ... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
  46. >>> _panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7})
  47. tensor(0.5463, dtype=torch.float64)
  48. """
  49. _deprecated_root_import_func("panoptic_quality", "detection")
  50. return panoptic_quality(
  51. preds=preds,
  52. target=target,
  53. things=things,
  54. stuffs=stuffs,
  55. allow_unknown_preds_category=allow_unknown_preds_category,
  56. )