# routes.py
import abc
import collections
import functools
import inspect
import json
import logging
import os
import traceback
from typing import Any

from ray.dashboard.optional_deps import PathLike, RouteDef, aiohttp, hdrs
from ray.dashboard.utils import CustomEncoder, HTTPStatusCode, to_google_style
  12. logger = logging.getLogger(__name__)
  13. class BaseRouteTable(abc.ABC):
  14. """A base class to bind http route to a target instance. Subclass should implement
  15. the _register_route method. It should define how the handler interacts with
  16. _BindInfo.instance.
  17. Subclasses must declare their own _bind_map and _routes properties to avoid
  18. conflicts.
  19. """
  20. class _BindInfo:
  21. def __init__(self, filename, lineno, instance):
  22. self.filename = filename
  23. self.lineno = lineno
  24. self.instance = instance
  25. @classmethod
  26. @property
  27. @abc.abstractmethod
  28. def _bind_map(cls):
  29. pass
  30. @classmethod
  31. @property
  32. @abc.abstractmethod
  33. def _routes(cls):
  34. pass
  35. @classmethod
  36. @abc.abstractmethod
  37. def _register_route(cls, method, path, **kwargs):
  38. pass
  39. @classmethod
  40. @abc.abstractmethod
  41. def bind(cls, instance):
  42. pass
  43. @classmethod
  44. def routes(cls):
  45. return cls._routes
  46. @classmethod
  47. def bound_routes(cls):
  48. bound_items = []
  49. for r in cls._routes._items:
  50. if isinstance(r, RouteDef):
  51. route_method = r.handler.__route_method__
  52. route_path = r.handler.__route_path__
  53. instance = cls._bind_map[route_method][route_path].instance
  54. if instance is not None:
  55. bound_items.append(r)
  56. else:
  57. bound_items.append(r)
  58. routes = aiohttp.web.RouteTableDef()
  59. routes._items = bound_items
  60. return routes
  61. @classmethod
  62. def head(cls, path, **kwargs):
  63. return cls._register_route(hdrs.METH_HEAD, path, **kwargs)
  64. @classmethod
  65. def get(cls, path, **kwargs):
  66. return cls._register_route(hdrs.METH_GET, path, **kwargs)
  67. @classmethod
  68. def post(cls, path, **kwargs):
  69. return cls._register_route(hdrs.METH_POST, path, **kwargs)
  70. @classmethod
  71. def put(cls, path, **kwargs):
  72. return cls._register_route(hdrs.METH_PUT, path, **kwargs)
  73. @classmethod
  74. def patch(cls, path, **kwargs):
  75. return cls._register_route(hdrs.METH_PATCH, path, **kwargs)
  76. @classmethod
  77. def delete(cls, path, **kwargs):
  78. return cls._register_route(hdrs.METH_DELETE, path, **kwargs)
  79. @classmethod
  80. def view(cls, path, **kwargs):
  81. return cls._register_route(hdrs.METH_ANY, path, **kwargs)
  82. @classmethod
  83. def static(cls, prefix: str, path: PathLike, **kwargs: Any) -> None:
  84. cls._routes.static(prefix, path, **kwargs)
  85. def method_route_table_factory():
  86. """
  87. Return a method-based route table class, for in-process HeadModule objects.
  88. """
  89. class MethodRouteTable(BaseRouteTable):
  90. """A helper class to bind http route to class method. Each _BindInfo.instance
  91. is a class instance, and for an inbound request, we invoke the async handler
  92. method."""
  93. _bind_map = collections.defaultdict(dict)
  94. _routes = aiohttp.web.RouteTableDef()
  95. @classmethod
  96. def _register_route(cls, method, path, **kwargs):
  97. def _wrapper(handler):
  98. if path in cls._bind_map[method]:
  99. bind_info = cls._bind_map[method][path]
  100. raise Exception(
  101. f"Duplicated route path: {path}, "
  102. f"previous one registered at "
  103. f"{bind_info.filename}:{bind_info.lineno}"
  104. )
  105. bind_info = cls._BindInfo(
  106. handler.__code__.co_filename, handler.__code__.co_firstlineno, None
  107. )
  108. @functools.wraps(handler)
  109. async def _handler_route(*args) -> aiohttp.web.Response:
  110. try:
  111. # Make the route handler as a bound method.
  112. # The args may be:
  113. # * (Request, )
  114. # * (self, Request)
  115. req = args[-1]
  116. return await handler(bind_info.instance, req)
  117. except Exception:
  118. logger.exception("Handle %s %s failed.", method, path)
  119. return rest_response(
  120. status_code=HTTPStatusCode.INTERNAL_ERROR,
  121. message=traceback.format_exc(),
  122. )
  123. cls._bind_map[method][path] = bind_info
  124. _handler_route.__route_method__ = method
  125. _handler_route.__route_path__ = path
  126. return cls._routes.route(method, path, **kwargs)(_handler_route)
  127. return _wrapper
  128. @classmethod
  129. def bind(cls, instance):
  130. def predicate(o):
  131. if inspect.ismethod(o):
  132. return hasattr(o, "__route_method__") and hasattr(
  133. o, "__route_path__"
  134. )
  135. return False
  136. handler_routes = inspect.getmembers(instance, predicate)
  137. for _, h in handler_routes:
  138. cls._bind_map[h.__func__.__route_method__][
  139. h.__func__.__route_path__
  140. ].instance = instance
  141. return MethodRouteTable
  142. def rest_response(
  143. status_code: HTTPStatusCode,
  144. message: str,
  145. convert_google_style: bool = True,
  146. **kwargs,
  147. ) -> aiohttp.web.Response:
  148. """
  149. Args:
  150. status_code: HTTPStatusCode
  151. The HTTP status code of the response.
  152. message: str
  153. The message of the response.
  154. convert_google_style: bool
  155. Whether to convert the response to google style.
  156. Returns:
  157. aiohttp.web.Response
  158. """
  159. # In the dev context we allow a dev server running on a
  160. # different port to consume the API, meaning we need to allow
  161. # cross-origin access
  162. if os.environ.get("RAY_DASHBOARD_DEV") == "1":
  163. headers = {"Access-Control-Allow-Origin": "*"}
  164. else:
  165. headers = {}
  166. success = status_code == HTTPStatusCode.OK
  167. return aiohttp.web.json_response(
  168. {
  169. "result": success,
  170. "msg": message,
  171. "data": to_google_style(kwargs) if convert_google_style else kwargs,
  172. },
  173. dumps=functools.partial(json.dumps, cls=CustomEncoder),
  174. headers=headers,
  175. status=status_code,
  176. )