| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- from typing import Dict
- from ray.dag import (
- ClassMethodNode,
- ClassNode,
- DAGNode,
- FunctionNode,
- InputAttributeNode,
- InputNode,
- MultiOutputNode,
- )
class _DAGNodeNameGenerator(object):
    """Generate a unique name for each DAG node passed to ``get_node_name``.

    The first node seen with a given base name keeps that name unchanged.
    Later nodes sharing the base name get a monotonically increasing numeric
    suffix (``name_1``, ``name_2``, ...). Generated names are registered as
    well, so a node whose base name happens to equal an earlier generated
    name (e.g. a function literally named ``name_1``) still receives a
    distinct name instead of a duplicate.

    Usable as a context manager; the name registry is cleared on exit.
    """

    def __init__(self):
        # Every name handed out so far (base names and generated suffixed
        # names) mapped to the last suffix number used for that name.
        self.name_to_suffix: Dict[str, int] = dict()

    def get_node_name(self, node: DAGNode):
        """Return a name for ``node`` that is unique among names returned so far.

        InputNode, MultiOutputNode and InputAttributeNode get fixed names and
        are not registered for de-duplication.

        Args:
            node: The DAG node to name.

        Returns:
            The name for the node.

        Raises:
            ValueError: If ``node`` is not one of the supported node types.
        """
        # InputNode should be unique.
        if isinstance(node, InputNode):
            return "INPUT_NODE"
        if isinstance(node, MultiOutputNode):
            return "MultiOutputNode"
        # InputAttributeNode suffixes should match the user-defined key.
        if isinstance(node, InputAttributeNode):
            return f"INPUT_ATTRIBUTE_NODE_{node._key}"
        # As class, method, and function nodes may have duplicated names,
        # generate unique suffixes for such nodes.
        return self._make_unique(self._get_base_name(node))

    @staticmethod
    def _get_base_name(node: DAGNode):
        """Return the (possibly duplicated) base name of a nameable node."""
        if isinstance(node, ClassMethodNode):
            # A truthy user-supplied "name" option wins over the method name.
            return node.get_options().get("name", None) or node._method_name
        if isinstance(node, (ClassNode, FunctionNode)):
            return node.get_options().get("name", None) or node._body.__name__
        # we use instance class name check here to avoid importing ServeNodes as
        # serve components are not included in Ray Core.
        node_type_name = type(node).__name__
        if node_type_name in ("DeploymentNode", "DeploymentFunctionNode"):
            return node.get_deployment_name()
        if node_type_name == "DeploymentFunctionExecutorNode":
            return node._deployment_function_handle.deployment_name
        raise ValueError(
            "get_node_name() should only be called on DAGNode instances."
        )

    def _make_unique(self, node_name):
        """Register ``node_name`` and return it, suffixed if already taken."""
        if node_name not in self.name_to_suffix:
            self.name_to_suffix[node_name] = 0
            return node_name
        suffix_num = self.name_to_suffix[node_name]
        # Skip suffixes whose resulting name is already taken (for example a
        # node whose base name is literally "foo_1"), so no duplicate is
        # ever returned.
        while True:
            suffix_num += 1
            candidate = f"{node_name}_{suffix_num}"
            if candidate not in self.name_to_suffix:
                break
        self.name_to_suffix[node_name] = suffix_num
        # Reserve the generated name so a later node cannot reuse it.
        self.name_to_suffix[candidate] = 0
        return candidate

    def reset(self):
        """Forget every name handed out so far."""
        self.name_to_suffix = dict()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Clear the registry; returning None lets any exception propagate.
        self.reset()
|