state_api_utils.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import dataclasses
  2. from dataclasses import asdict, fields
  3. from typing import Awaitable, Callable, List, Tuple
  4. import aiohttp.web
  5. from ray.dashboard.optional_utils import rest_response
  6. from ray.dashboard.utils import HTTPStatusCode
  7. from ray.util.state.common import (
  8. DEFAULT_LIMIT,
  9. DEFAULT_RPC_TIMEOUT,
  10. RAY_MAX_LIMIT_FROM_API_SERVER,
  11. ListApiOptions,
  12. ListApiResponse,
  13. PredicateType,
  14. StateSchema,
  15. SummaryApiOptions,
  16. SummaryApiResponse,
  17. SupportedFilterType,
  18. filter_fields,
  19. )
  20. from ray.util.state.exception import DataSourceUnavailable
  21. from ray.util.state.util import convert_string_to_type
  22. def do_reply(
  23. status_code: HTTPStatusCode, error_message: str, result: ListApiResponse, **kwargs
  24. ):
  25. return rest_response(
  26. status_code=status_code,
  27. message=error_message,
  28. result=result,
  29. convert_google_style=False,
  30. **kwargs,
  31. )
  32. async def handle_list_api(
  33. list_api_fn: Callable[[ListApiOptions], Awaitable[ListApiResponse]],
  34. req: aiohttp.web.Request,
  35. ):
  36. try:
  37. result = await list_api_fn(option=options_from_req(req))
  38. return do_reply(
  39. status_code=HTTPStatusCode.OK,
  40. error_message="",
  41. result=asdict(result),
  42. )
  43. except ValueError as e:
  44. return do_reply(
  45. status_code=HTTPStatusCode.BAD_REQUEST,
  46. error_message=str(e),
  47. result=None,
  48. )
  49. except DataSourceUnavailable as e:
  50. return do_reply(
  51. status_code=HTTPStatusCode.INTERNAL_ERROR,
  52. error_message=str(e),
  53. result=None,
  54. )
  55. def _get_filters_from_req(
  56. req: aiohttp.web.Request,
  57. ) -> List[Tuple[str, PredicateType, SupportedFilterType]]:
  58. filter_keys = req.query.getall("filter_keys", [])
  59. filter_predicates = req.query.getall("filter_predicates", [])
  60. filter_values = req.query.getall("filter_values", [])
  61. assert len(filter_keys) == len(filter_values)
  62. filters = []
  63. for key, predicate, val in zip(filter_keys, filter_predicates, filter_values):
  64. filters.append((key, predicate, val))
  65. return filters
  66. def options_from_req(req: aiohttp.web.Request) -> ListApiOptions:
  67. """Obtain `ListApiOptions` from the aiohttp request."""
  68. limit = int(
  69. req.query.get("limit") if req.query.get("limit") is not None else DEFAULT_LIMIT
  70. )
  71. if limit > RAY_MAX_LIMIT_FROM_API_SERVER:
  72. raise ValueError(
  73. f"Given limit {limit} exceeds the supported "
  74. f"Given limit {limit} exceeds the supported "
  75. f"limit {RAY_MAX_LIMIT_FROM_API_SERVER}. Use a lower limit, or set the "
  76. f"`RAY_MAX_LIMIT_FROM_API_SERVER` environment variable to a larger value."
  77. )
  78. timeout = int(req.query.get("timeout", 30))
  79. filters = _get_filters_from_req(req)
  80. detail = convert_string_to_type(req.query.get("detail", False), bool)
  81. exclude_driver = convert_string_to_type(req.query.get("exclude_driver", True), bool)
  82. return ListApiOptions(
  83. limit=limit,
  84. timeout=timeout,
  85. filters=filters,
  86. detail=detail,
  87. exclude_driver=exclude_driver,
  88. )
  89. def summary_options_from_req(req: aiohttp.web.Request) -> SummaryApiOptions:
  90. timeout = int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT))
  91. filters = _get_filters_from_req(req)
  92. summary_by = req.query.get("summary_by", None)
  93. return SummaryApiOptions(timeout=timeout, filters=filters, summary_by=summary_by)
  94. async def handle_summary_api(
  95. summary_fn: Callable[[SummaryApiOptions], SummaryApiResponse],
  96. req: aiohttp.web.Request,
  97. ):
  98. result = await summary_fn(option=summary_options_from_req(req))
  99. return do_reply(
  100. status_code=HTTPStatusCode.OK,
  101. error_message="",
  102. result=asdict(result),
  103. )
  104. def convert_filters_type(
  105. filter: List[Tuple[str, PredicateType, SupportedFilterType]],
  106. schema: StateSchema,
  107. ) -> List[Tuple[str, PredicateType, SupportedFilterType]]:
  108. """Convert the given filter's type to SupportedFilterType.
  109. This method is necessary because click can only accept a single type
  110. for its tuple (which is string in this case).
  111. Args:
  112. filter: A list of filter which is a tuple of (key, val).
  113. schema: The state schema. It is used to infer the type of the column for filter.
  114. Returns:
  115. A new list of filters with correct types that match the schema.
  116. """
  117. new_filter = []
  118. if dataclasses.is_dataclass(schema):
  119. schema = {field.name: field.type for field in fields(schema)}
  120. else:
  121. schema = schema.schema_dict()
  122. for col, predicate, val in filter:
  123. if col in schema:
  124. column_type = schema[col]
  125. try:
  126. isinstance(val, column_type)
  127. except TypeError:
  128. # Calling `isinstance` to the Literal type raises a TypeError.
  129. # Ignore this case.
  130. pass
  131. else:
  132. if isinstance(val, column_type):
  133. # Do nothing.
  134. pass
  135. elif column_type is int or column_type == "integer":
  136. try:
  137. val = convert_string_to_type(val, int)
  138. except ValueError:
  139. raise ValueError(
  140. f"Invalid filter `--filter {col} {val}` for a int type "
  141. "column. Please provide an integer filter "
  142. f"`--filter {col} [int]`"
  143. )
  144. elif column_type is float or column_type == "number":
  145. try:
  146. val = convert_string_to_type(
  147. val,
  148. float,
  149. )
  150. except ValueError:
  151. raise ValueError(
  152. f"Invalid filter `--filter {col} {val}` for a float "
  153. "type column. Please provide an integer filter "
  154. f"`--filter {col} [float]`"
  155. )
  156. elif column_type is bool or column_type == "boolean":
  157. try:
  158. val = convert_string_to_type(val, bool)
  159. except ValueError:
  160. raise ValueError(
  161. f"Invalid filter `--filter {col} {val}` for a boolean "
  162. "type column. Please provide "
  163. f"`--filter {col} [True|true|1]` for True or "
  164. f"`--filter {col} [False|false|0]` for False."
  165. )
  166. new_filter.append((col, predicate, val))
  167. return new_filter
  168. def do_filter(
  169. data: List[dict],
  170. filters: List[Tuple[str, PredicateType, SupportedFilterType]],
  171. state_dataclass: StateSchema,
  172. detail: bool,
  173. ) -> List[dict]:
  174. """Return the filtered data given filters.
  175. Args:
  176. data: A list of state data.
  177. filters: A list of KV tuple to filter data (key, val). The data is filtered
  178. if data[key] != val.
  179. state_dataclass: The state schema.
  180. Returns:
  181. A list of filtered state data in dictionary. Each state data's
  182. unnecessary columns are filtered by the given state_dataclass schema.
  183. """
  184. filters = convert_filters_type(filters, state_dataclass)
  185. result = []
  186. for datum in data:
  187. match = True
  188. for filter_column, filter_predicate, filter_value in filters:
  189. filterable_columns = state_dataclass.filterable_columns()
  190. filter_column = filter_column.lower()
  191. if filter_column not in filterable_columns:
  192. raise ValueError(
  193. f"The given filter column {filter_column} is not supported. "
  194. "Enter filters with –-filter key=value "
  195. "or –-filter key!=value "
  196. f"Supported filter columns: {filterable_columns}"
  197. )
  198. if filter_column not in datum:
  199. match = False
  200. elif filter_predicate == "=":
  201. if isinstance(filter_value, str) and isinstance(
  202. datum[filter_column], str
  203. ):
  204. # Case insensitive match for string filter values.
  205. match = datum[filter_column].lower() == filter_value.lower()
  206. elif isinstance(filter_value, str) and isinstance(
  207. datum[filter_column], bool
  208. ):
  209. match = datum[filter_column] == convert_string_to_type(
  210. filter_value, bool
  211. )
  212. elif isinstance(filter_value, str) and isinstance(
  213. datum[filter_column], int
  214. ):
  215. match = datum[filter_column] == convert_string_to_type(
  216. filter_value, int
  217. )
  218. else:
  219. match = datum[filter_column] == filter_value
  220. elif filter_predicate == "!=":
  221. if isinstance(filter_value, str) and isinstance(
  222. datum[filter_column], str
  223. ):
  224. match = datum[filter_column].lower() != filter_value.lower()
  225. else:
  226. match = datum[filter_column] != filter_value
  227. else:
  228. raise ValueError(
  229. f"Unsupported filter predicate {filter_predicate} is given. "
  230. "Available predicates: =, !=."
  231. )
  232. if not match:
  233. break
  234. if match:
  235. result.append(filter_fields(datum, state_dataclass, detail))
  236. return result