utils.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from typing import Dict
  2. from ray.dag import (
  3. ClassMethodNode,
  4. ClassNode,
  5. DAGNode,
  6. FunctionNode,
  7. InputAttributeNode,
  8. InputNode,
  9. MultiOutputNode,
  10. )
  11. class _DAGNodeNameGenerator(object):
  12. """
  13. Generate unique suffix for each given Node in the DAG.
  14. Apply monotonic increasing id suffix for duplicated names.
  15. """
  16. def __init__(self):
  17. self.name_to_suffix: Dict[str, int] = dict()
  18. def get_node_name(self, node: DAGNode):
  19. # InputNode should be unique.
  20. if isinstance(node, InputNode):
  21. return "INPUT_NODE"
  22. if isinstance(node, MultiOutputNode):
  23. return "MultiOutputNode"
  24. # InputAttributeNode suffixes should match the user-defined key.
  25. elif isinstance(node, InputAttributeNode):
  26. return f"INPUT_ATTRIBUTE_NODE_{node._key}"
  27. # As class, method, and function nodes may have duplicated names,
  28. # generate unique suffixes for such nodes.
  29. if isinstance(node, ClassMethodNode):
  30. node_name = node.get_options().get("name", None) or node._method_name
  31. elif isinstance(node, (ClassNode, FunctionNode)):
  32. node_name = node.get_options().get("name", None) or node._body.__name__
  33. # we use instance class name check here to avoid importing ServeNodes as
  34. # serve components are not included in Ray Core.
  35. elif type(node).__name__ in ("DeploymentNode", "DeploymentFunctionNode"):
  36. node_name = node.get_deployment_name()
  37. elif type(node).__name__ == "DeploymentFunctionExecutorNode":
  38. node_name = node._deployment_function_handle.deployment_name
  39. else:
  40. raise ValueError(
  41. "get_node_name() should only be called on DAGNode instances."
  42. )
  43. if node_name not in self.name_to_suffix:
  44. self.name_to_suffix[node_name] = 0
  45. return node_name
  46. else:
  47. self.name_to_suffix[node_name] += 1
  48. suffix_num = self.name_to_suffix[node_name]
  49. return f"{node_name}_{suffix_num}"
  50. def reset(self):
  51. self.name_to_suffix = dict()
  52. def __enter__(self):
  53. return self
  54. def __exit__(self, *args):
  55. self.reset()