_deprecated.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from collections.abc import Collection
  2. from typing import Any
  3. from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality
  4. from torchmetrics.utilities.prints import _deprecated_root_import_class
  5. class _ModifiedPanopticQuality(ModifiedPanopticQuality):
  6. """Wrapper for deprecated import.
  7. >>> from torch import tensor
  8. >>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]])
  9. >>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]])
  10. >>> pq_modified = _ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
  11. >>> pq_modified(preds, target)
  12. tensor(0.7667, dtype=torch.float64)
  13. """
  14. def __init__(
  15. self,
  16. things: Collection[int],
  17. stuffs: Collection[int],
  18. allow_unknown_preds_category: bool = False,
  19. **kwargs: Any,
  20. ) -> None:
  21. _deprecated_root_import_class("ModifiedPanopticQuality", "detection")
  22. super().__init__(
  23. things=things, stuffs=stuffs, allow_unknown_preds_category=allow_unknown_preds_category, **kwargs
  24. )
  25. class _PanopticQuality(PanopticQuality):
  26. """Wrapper for deprecated import.
  27. >>> from torch import tensor
  28. >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
  29. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  30. ... [[0, 0], [0, 0], [6, 0], [0, 1]],
  31. ... [[0, 0], [7, 0], [6, 0], [1, 0]],
  32. ... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
  33. >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
  34. ... [[0, 1], [0, 1], [6, 0], [0, 1]],
  35. ... [[0, 1], [0, 1], [6, 0], [1, 0]],
  36. ... [[0, 1], [7, 0], [1, 0], [1, 0]],
  37. ... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
  38. >>> panoptic_quality = _PanopticQuality(things = {0, 1}, stuffs = {6, 7})
  39. >>> panoptic_quality(preds, target)
  40. tensor(0.5463, dtype=torch.float64)
  41. """
  42. def __init__(
  43. self,
  44. things: Collection[int],
  45. stuffs: Collection[int],
  46. allow_unknown_preds_category: bool = False,
  47. **kwargs: Any,
  48. ) -> None:
  49. _deprecated_root_import_class("PanopticQuality", "detection")
  50. super().__init__(
  51. things=things, stuffs=stuffs, allow_unknown_preds_category=allow_unknown_preds_category, **kwargs
  52. )