summary.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import pprint
  2. from typing import Any
  3. import numpy as np
  4. from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
  5. from ray.rllib.utils.annotations import DeveloperAPI
  6. _printer = pprint.PrettyPrinter(indent=2, width=60)
  7. @DeveloperAPI
  8. def summarize(obj: Any) -> Any:
  9. """Return a pretty-formatted string for an object.
  10. This has special handling for pretty-formatting of commonly used data types
  11. in RLlib, such as SampleBatch, numpy arrays, etc.
  12. Args:
  13. obj: The object to format.
  14. Returns:
  15. The summarized object.
  16. """
  17. return _printer.pformat(_summarize(obj))
  18. def _summarize(obj):
  19. if isinstance(obj, dict):
  20. return {k: _summarize(v) for k, v in obj.items()}
  21. elif hasattr(obj, "_asdict"):
  22. return {
  23. "type": obj.__class__.__name__,
  24. "data": _summarize(obj._asdict()),
  25. }
  26. elif isinstance(obj, list):
  27. return [_summarize(x) for x in obj]
  28. elif isinstance(obj, tuple):
  29. return tuple(_summarize(x) for x in obj)
  30. elif isinstance(obj, np.ndarray):
  31. if obj.size == 0:
  32. return _StringValue("np.ndarray({}, dtype={})".format(obj.shape, obj.dtype))
  33. elif obj.dtype == object or obj.dtype.type is np.str_:
  34. return _StringValue(
  35. "np.ndarray({}, dtype={}, head={})".format(
  36. obj.shape, obj.dtype, _summarize(obj[0])
  37. )
  38. )
  39. else:
  40. return _StringValue(
  41. "np.ndarray({}, dtype={}, min={}, max={}, mean={})".format(
  42. obj.shape,
  43. obj.dtype,
  44. round(float(np.min(obj)), 3),
  45. round(float(np.max(obj)), 3),
  46. round(float(np.mean(obj)), 3),
  47. )
  48. )
  49. elif isinstance(obj, MultiAgentBatch):
  50. return {
  51. "type": "MultiAgentBatch",
  52. "policy_batches": _summarize(obj.policy_batches),
  53. "count": obj.count,
  54. }
  55. elif isinstance(obj, SampleBatch):
  56. return {
  57. "type": "SampleBatch",
  58. "data": {k: _summarize(v) for k, v in obj.items()},
  59. }
  60. else:
  61. return obj
  62. class _StringValue:
  63. def __init__(self, value):
  64. self.value = value
  65. def __repr__(self):
  66. return self.value