| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- from typing import Any, Dict, List, Union
- from ray.rllib.utils.metrics.stats.base import StatsBase
- from ray.rllib.utils.metrics.stats.utils import single_value_to_cpu
- from ray.util.annotations import DeveloperAPI
@DeveloperAPI
class ItemStats(StatsBase):
    """Stats object that holds exactly one item and never reduces it.

    Because only a single item can be tracked by definition, calling
    `ItemStats.merge()` replaces the currently held item with the incoming one.

    Logged items are checked for being GPU tensors; if they are, they are
    converted to CPU memory.

    Use this class to track a single value that should not be reduced, for
    example the total loss.
    """

    stats_cls_identifier = "item"

    def __init__(self, *args, **kwargs):
        """Initializes an ItemStats instance that holds no item yet."""
        super().__init__(*args, **kwargs)
        # The one tracked item; None until something is pushed/merged/restored.
        self._item = None

    def get_state(self) -> Dict[str, Any]:
        """Returns the parent state extended by the tracked item (key "item")."""
        base_state = super().get_state()
        base_state["item"] = self._item
        return base_state

    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores the parent state and the tracked item from `state`.

        Args:
            state: A state dict that must contain the key "item".
        """
        super().set_state(state)
        self._item = state["item"]

    def __len__(self) -> int:
        # Always 1, independent of whether an item is currently held.
        return 1

    def reduce(self, compile: bool = True) -> Union[Any, "ItemStats"]:
        """Hands out the tracked item and clears it on this object.

        Args:
            compile: If True, return the item itself. If False, return a clone
                of this object that holds the item.

        Returns:
            The item, or a cloned ItemStats carrying it.
        """
        # Clear the slot first, then convert, so this object is already reset
        # by the time the CPU conversion runs.
        held, self._item = self._item, None
        held = single_value_to_cpu(held)

        if not compile:
            snapshot = self.clone()
            snapshot._item = held
            return snapshot
        return held

    def push(self, item: Any) -> None:
        """Stores `item` as the tracked item, replacing any previous one.

        Args:
            item: The value to store. Can be of any type.
                GPU tensors are moved to CPU memory.

        Returns:
            None
        """
        # Convert right away: peek(), reduce() and merge() don't handle GPU
        # tensors.
        self._item = single_value_to_cpu(item)

    def merge(self, incoming_stats: List["ItemStats"]) -> None:
        """Replaces the tracked item with the one from the incoming ItemStats.

        Args:
            incoming_stats: A list containing exactly one ItemStats object.

        Returns:
            None. The merge operation modifies self in place.
        """
        assert (
            len(incoming_stats) == 1
        ), "ItemStats should only be merged with one other ItemStats object which replaces the current item"

        self._item = incoming_stats[0]._item

    def peek(
        self, compile: bool = True, latest_merged_only: bool = False
    ) -> Union[Any, List[Any]]:
        """Returns the tracked item without altering it.

        Args:
            compile: If True, return the item directly.
                If False, return it wrapped in a single-element list.
            latest_merged_only: Ignored for ItemStats. A single item is
                tracked rather than a series of merged values, so the current
                item is returned regardless of this parameter.

        Returns:
            The tracked item, or a one-element list containing it.
        """
        # `latest_merged_only` is intentionally unused (see docstring).
        current = single_value_to_cpu(self._item)
        return current if compile else [current]

    def __repr__(self) -> str:
        current = self.peek()
        return f"ItemStats({current})"
|