tf_run_builder.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import logging
  2. import os
  3. import time
  4. from ray.rllib.utils.annotations import OldAPIStack
  5. from ray.rllib.utils.framework import try_import_tf
  6. from ray.util.debug import log_once
  7. tf1, tf, tfv = try_import_tf()
  8. logger = logging.getLogger(__name__)
  9. @OldAPIStack
  10. class _TFRunBuilder:
  11. """Used to incrementally build up a TensorFlow run.
  12. This is particularly useful for batching ops from multiple different
  13. policies in the multi-agent setting.
  14. """
  15. def __init__(self, session, debug_name):
  16. self.session = session
  17. self.debug_name = debug_name
  18. self.feed_dict = {}
  19. self.fetches = []
  20. self._executed = None
  21. def add_feed_dict(self, feed_dict):
  22. assert not self._executed
  23. for k in feed_dict:
  24. if k in self.feed_dict:
  25. raise ValueError("Key added twice: {}".format(k))
  26. self.feed_dict.update(feed_dict)
  27. def add_fetches(self, fetches):
  28. assert not self._executed
  29. base_index = len(self.fetches)
  30. self.fetches.extend(fetches)
  31. return list(range(base_index, len(self.fetches)))
  32. def get(self, to_fetch):
  33. if self._executed is None:
  34. try:
  35. self._executed = _run_timeline(
  36. self.session,
  37. self.fetches,
  38. self.debug_name,
  39. self.feed_dict,
  40. os.environ.get("TF_TIMELINE_DIR"),
  41. )
  42. except Exception as e:
  43. logger.exception(
  44. "Error fetching: {}, feed_dict={}".format(
  45. self.fetches, self.feed_dict
  46. )
  47. )
  48. raise e
  49. if isinstance(to_fetch, int):
  50. return self._executed[to_fetch]
  51. elif isinstance(to_fetch, list):
  52. return [self.get(x) for x in to_fetch]
  53. elif isinstance(to_fetch, tuple):
  54. return tuple(self.get(x) for x in to_fetch)
  55. else:
  56. raise ValueError("Unsupported fetch type: {}".format(to_fetch))
  57. _count = 0
  58. def _run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None):
  59. if feed_dict is None:
  60. feed_dict = {}
  61. if timeline_dir:
  62. from tensorflow.python.client import timeline
  63. try:
  64. run_options = tf1.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  65. except AttributeError:
  66. run_options = None
  67. # In local mode, tf1.RunOptions is not available, see #26511
  68. if log_once("tf1.RunOptions_not_available"):
  69. logger.exception(
  70. "Can not access tf.RunOptions.FULL_TRACE. This may be because "
  71. "you have used `ray.init(local_mode=True)`. RLlib will use "
  72. "timeline without `options=tf.RunOptions.FULL_TRACE`."
  73. )
  74. run_metadata = tf1.RunMetadata()
  75. start = time.time()
  76. fetches = sess.run(
  77. ops, options=run_options, run_metadata=run_metadata, feed_dict=feed_dict
  78. )
  79. trace = timeline.Timeline(step_stats=run_metadata.step_stats)
  80. global _count
  81. outf = os.path.join(
  82. timeline_dir,
  83. "timeline-{}-{}-{}.json".format(debug_name, os.getpid(), _count % 10),
  84. )
  85. _count += 1
  86. trace_file = open(outf, "w")
  87. logger.info(
  88. "Wrote tf timeline ({} s) to {}".format(
  89. time.time() - start, os.path.abspath(outf)
  90. )
  91. )
  92. trace_file.write(trace.generate_chrome_trace_format())
  93. else:
  94. if log_once("tf_timeline"):
  95. logger.info(
  96. "Executing TF run without tracing. To dump TF timeline traces "
  97. "to disk, set the TF_TIMELINE_DIR environment variable."
  98. )
  99. fetches = sess.run(ops, feed_dict=feed_dict)
  100. return fetches