item.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. from typing import Any, Dict, List, Union
  2. from ray.rllib.utils.metrics.stats.base import StatsBase
  3. from ray.rllib.utils.metrics.stats.utils import single_value_to_cpu
  4. from ray.util.annotations import DeveloperAPI
  5. @DeveloperAPI
  6. class ItemStats(StatsBase):
  7. """A Stats object that tracks a single item.
  8. Note the follwing limitation: That, when calling `ItemStats.merge()`, we replace the current item.
  9. This is because there can only be a single item tracked by definition.
  10. This class will check if the logged item is a GPU tensor.
  11. If it is, it will be converted to CPU memory.
  12. Use this if you want to track a single item that should not be reduced.
  13. An example would be to log the total loss.
  14. """
  15. stats_cls_identifier = "item"
  16. def __init__(self, *args, **kwargs):
  17. """Initializes a ItemStats instance."""
  18. super().__init__(*args, **kwargs)
  19. self._item = None
  20. def get_state(self) -> Dict[str, Any]:
  21. state = super().get_state()
  22. state["item"] = self._item
  23. return state
  24. def set_state(self, state: Dict[str, Any]) -> None:
  25. super().set_state(state)
  26. self._item = state["item"]
  27. def __len__(self) -> int:
  28. return 1
  29. def reduce(self, compile: bool = True) -> Union[Any, "ItemStats"]:
  30. item = self._item
  31. self._item = None
  32. item = single_value_to_cpu(item)
  33. if compile:
  34. return item
  35. return_stats = self.clone()
  36. return_stats._item = item
  37. return return_stats
  38. def push(self, item: Any) -> None:
  39. """Pushes a value into this Stats object.
  40. Args:
  41. item: The value to push. Can be of any type.
  42. GPU tensors are moved to CPU memory.
  43. Returns:
  44. None
  45. """
  46. # Put directly onto CPU memory. peek(), reduce() and merge() don't handle GPU tensors.
  47. self._item = single_value_to_cpu(item)
  48. def merge(self, incoming_stats: List["ItemStats"]) -> None:
  49. """Merges ItemStats objects.
  50. Args:
  51. incoming_stats: The list of ItemStats objects to merge.
  52. Returns:
  53. None. The merge operation modifies self in place.
  54. """
  55. assert (
  56. len(incoming_stats) == 1
  57. ), "ItemStats should only be merged with one other ItemStats object which replaces the current item"
  58. self._item = incoming_stats[0]._item
  59. def peek(
  60. self, compile: bool = True, latest_merged_only: bool = False
  61. ) -> Union[Any, List[Any]]:
  62. """Returns the internal item.
  63. This does not alter the internal item.
  64. Args:
  65. compile: If True, return the internal item directly.
  66. If False, return the internal item as a single-element list.
  67. latest_merged_only: This parameter is ignored for ItemStats.
  68. ItemStats tracks a single item, not a series of merged values.
  69. The current item is always returned regardless of this parameter.
  70. Returns:
  71. The internal item.
  72. """
  73. # ItemStats doesn't support latest_merged_only since it tracks a single item
  74. # Just return the current item regardless
  75. item = single_value_to_cpu(self._item)
  76. if compile:
  77. return item
  78. return [item]
  79. def __repr__(self) -> str:
  80. return f"ItemStats({self.peek()})"