input_reader.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import logging
  2. import threading
  3. from abc import ABCMeta, abstractmethod
  4. from typing import Dict, List
  5. import numpy as np
  6. from ray.rllib.policy.sample_batch import MultiAgentBatch
  7. from ray.rllib.utils.annotations import PublicAPI
  8. from ray.rllib.utils.framework import try_import_tf
  9. from ray.rllib.utils.typing import SampleBatchType, TensorType
  10. tf1, tf, tfv = try_import_tf()
  11. logger = logging.getLogger(__name__)
  12. @PublicAPI
  13. class InputReader(metaclass=ABCMeta):
  14. """API for collecting and returning experiences during policy evaluation."""
  15. @abstractmethod
  16. @PublicAPI
  17. def next(self) -> SampleBatchType:
  18. """Returns the next batch of read experiences.
  19. Returns:
  20. The experience read (SampleBatch or MultiAgentBatch).
  21. """
  22. raise NotImplementedError
  23. @PublicAPI
  24. def tf_input_ops(self, queue_size: int = 1) -> Dict[str, TensorType]:
  25. """Returns TensorFlow queue ops for reading inputs from this reader.
  26. The main use of these ops is for integration into custom model losses.
  27. For example, you can use tf_input_ops() to read from files of external
  28. experiences to add an imitation learning loss to your model.
  29. This method creates a queue runner thread that will call next() on this
  30. reader repeatedly to feed the TensorFlow queue.
  31. Args:
  32. queue_size: Max elements to allow in the TF queue.
  33. .. testcode::
  34. :skipif: True
  35. from ray.rllib.models.modelv2 import ModelV2
  36. from ray.rllib.offline.json_reader import JsonReader
  37. imitation_loss = ...
  38. class MyModel(ModelV2):
  39. def custom_loss(self, policy_loss, loss_inputs):
  40. reader = JsonReader(...)
  41. input_ops = reader.tf_input_ops()
  42. logits, _ = self._build_layers_v2(
  43. {"obs": input_ops["obs"]},
  44. self.num_outputs, self.options)
  45. il_loss = imitation_loss(logits, input_ops["action"])
  46. return policy_loss + il_loss
  47. You can find a runnable version of this in examples/custom_loss.py.
  48. Returns:
  49. Dict of Tensors, one for each column of the read SampleBatch.
  50. """
  51. if hasattr(self, "_queue_runner"):
  52. raise ValueError(
  53. "A queue runner already exists for this input reader. "
  54. "You can only call tf_input_ops() once per reader."
  55. )
  56. logger.info("Reading initial batch of data from input reader.")
  57. batch = self.next()
  58. if isinstance(batch, MultiAgentBatch):
  59. raise NotImplementedError(
  60. "tf_input_ops() is not implemented for multi agent batches"
  61. )
  62. # Note on casting to `np.array(batch[k])`: In order to get all keys that
  63. # are numbers, we need to convert to numpy everything that is not a numpy array.
  64. # This is because SampleBatches used to only hold numpy arrays, but since our
  65. # RNN efforts under RLModules, we also allow lists.
  66. keys = [
  67. k
  68. for k in sorted(batch.keys())
  69. if np.issubdtype(np.array(batch[k]).dtype, np.number)
  70. ]
  71. dtypes = [batch[k].dtype for k in keys]
  72. shapes = {k: (-1,) + s[1:] for (k, s) in [(k, batch[k].shape) for k in keys]}
  73. queue = tf1.FIFOQueue(capacity=queue_size, dtypes=dtypes, names=keys)
  74. tensors = queue.dequeue()
  75. logger.info("Creating TF queue runner for {}".format(self))
  76. self._queue_runner = _QueueRunner(self, queue, keys, dtypes)
  77. self._queue_runner.enqueue(batch)
  78. self._queue_runner.start()
  79. out = {k: tf.reshape(t, shapes[k]) for k, t in tensors.items()}
  80. return out
  81. class _QueueRunner(threading.Thread):
  82. """Thread that feeds a TF queue from a InputReader."""
  83. def __init__(
  84. self,
  85. input_reader: InputReader,
  86. queue: "tf1.FIFOQueue",
  87. keys: List[str],
  88. dtypes: "tf.dtypes.DType",
  89. ):
  90. threading.Thread.__init__(self)
  91. self.sess = tf1.get_default_session()
  92. self.daemon = True
  93. self.input_reader = input_reader
  94. self.keys = keys
  95. self.queue = queue
  96. self.placeholders = [tf1.placeholder(dtype) for dtype in dtypes]
  97. self.enqueue_op = queue.enqueue(dict(zip(keys, self.placeholders)))
  98. def enqueue(self, batch: SampleBatchType):
  99. data = {self.placeholders[i]: batch[key] for i, key in enumerate(self.keys)}
  100. self.sess.run(self.enqueue_op, feed_dict=data)
  101. def run(self):
  102. while True:
  103. try:
  104. batch = self.input_reader.next()
  105. self.enqueue(batch)
  106. except Exception:
  107. logger.exception("Error reading from input")