signature.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import inspect
  2. import logging
  3. from inspect import Parameter
  4. from typing import Any, Dict, List, Tuple
  5. from ray._private.inspect_util import is_cython
  6. # Logger for this module. It should be configured at the entry point
  7. # into the program using Ray. Ray provides a default configuration at
  8. # entry/init points.
  9. logger = logging.getLogger(__name__)
  10. # This dummy type is also defined in ArgumentsBuilder.java. Please keep it
  11. # synced.
  12. DUMMY_TYPE = b"__RAY_DUMMY__"
  13. def get_signature(func: Any) -> inspect.Signature:
  14. """Get signature parameters.
  15. Support Cython functions by grabbing relevant attributes from the Cython
  16. function and attaching to a no-op function. This is somewhat brittle, since
  17. inspect may change, but given that inspect is written to a PEP, we hope
  18. it is relatively stable. Future versions of Python may allow overloading
  19. the inspect 'isfunction' and 'ismethod' functions / create ABC for Python
  20. functions. Until then, it appears that Cython won't do anything about
  21. compatability with the inspect module.
  22. Args:
  23. func: The function whose signature should be checked.
  24. Returns:
  25. A function signature object, which includes the names of the keyword
  26. arguments as well as their default values.
  27. Raises:
  28. TypeError: A type error if the signature is not supported
  29. """
  30. # The first condition for Cython functions, the latter for Cython instance
  31. # methods
  32. if is_cython(func):
  33. attrs = ["__code__", "__annotations__", "__defaults__", "__kwdefaults__"]
  34. if all(hasattr(func, attr) for attr in attrs):
  35. original_func = func
  36. def func():
  37. return
  38. for attr in attrs:
  39. setattr(func, attr, getattr(original_func, attr))
  40. else:
  41. raise TypeError(f"{func!r} is not a Python function we can process")
  42. return inspect.signature(func)
  43. def extract_signature(func: Any, ignore_first: bool = False) -> List[Parameter]:
  44. """Extract the function signature from the function.
  45. Args:
  46. func: The function whose signature should be extracted.
  47. ignore_first: True if the first argument should be ignored. This should
  48. be used when func is a method of a class.
  49. Returns:
  50. List of Parameter objects representing the function signature.
  51. """
  52. signature_parameters = list(get_signature(func).parameters.values())
  53. if ignore_first:
  54. if len(signature_parameters) == 0:
  55. raise ValueError(
  56. "Methods must take a 'self' argument, but the "
  57. f"method '{func.__name__}' does not have one."
  58. )
  59. signature_parameters = signature_parameters[1:]
  60. return signature_parameters
  61. def validate_args(
  62. signature_parameters: List[Parameter], args: Tuple[Any, ...], kwargs: Dict[str, Any]
  63. ) -> None:
  64. """Validates the arguments against the signature.
  65. Args:
  66. signature_parameters: The list of Parameter objects
  67. representing the function signature, obtained from
  68. `extract_signature`.
  69. args: The positional arguments passed into the function.
  70. kwargs: The keyword arguments passed into the function.
  71. Raises:
  72. TypeError: Raised if arguments do not fit in the function signature.
  73. """
  74. reconstructed_signature = inspect.Signature(parameters=signature_parameters)
  75. try:
  76. reconstructed_signature.bind(*args, **kwargs)
  77. except TypeError as exc: # capture a friendlier stacktrace
  78. raise TypeError(str(exc)) from None
  79. def flatten_args(
  80. signature_parameters: List[Parameter], args: Tuple[Any, ...], kwargs: Dict[str, Any]
  81. ) -> List[Any]:
  82. """Validates the arguments against the signature and flattens them.
  83. The flat list representation is a serializable format for arguments.
  84. Since the flatbuffer representation of function arguments is a list, we
  85. combine both keyword arguments and positional arguments. We represent
  86. this with two entries per argument value - [DUMMY_TYPE, x] for positional
  87. arguments and [KEY, VALUE] for keyword arguments. See the below example.
  88. See `recover_args` for logic restoring the flat list back to args/kwargs.
  89. Args:
  90. signature_parameters: The list of Parameter objects
  91. representing the function signature, obtained from
  92. `extract_signature`.
  93. args: The positional arguments passed into the function.
  94. kwargs: The keyword arguments passed into the function.
  95. Returns:
  96. List of args and kwargs. Non-keyword arguments are prefixed
  97. by internal enum DUMMY_TYPE.
  98. Raises:
  99. TypeError: Raised if arguments do not fit in the function signature.
  100. """
  101. validate_args(signature_parameters, args, kwargs)
  102. list_args = []
  103. for arg in args:
  104. list_args += [DUMMY_TYPE, arg]
  105. for keyword, arg in kwargs.items():
  106. list_args += [keyword, arg]
  107. return list_args
  108. def recover_args(flattened_args: List[Any]) -> Tuple[List[Any], Dict[str, Any]]:
  109. """Recreates `args` and `kwargs` from the flattened arg list.
  110. Args:
  111. flattened_args: List of args and kwargs. This should be the output of
  112. `flatten_args`.
  113. Returns:
  114. args: The non-keyword arguments passed into the function.
  115. kwargs: The keyword arguments passed into the function.
  116. """
  117. assert (
  118. len(flattened_args) % 2 == 0
  119. ), "Flattened arguments need to be even-numbered. See `flatten_args`."
  120. args = []
  121. kwargs = {}
  122. for name_index in range(0, len(flattened_args), 2):
  123. name, arg = flattened_args[name_index], flattened_args[name_index + 1]
  124. if name == DUMMY_TYPE:
  125. args.append(arg)
  126. else:
  127. kwargs[name] = arg
  128. return args, kwargs