| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212 |
- import abc
- import collections
- import functools
- import inspect
- import json
- import logging
- import os
- import traceback
- from typing import Any
- from ray.dashboard.optional_deps import PathLike, RouteDef, aiohttp, hdrs
- from ray.dashboard.utils import CustomEncoder, HTTPStatusCode, to_google_style
- logger = logging.getLogger(__name__)
class BaseRouteTable(abc.ABC):
    """Abstract base for tables that bind HTTP routes to a target instance.

    Concrete subclasses implement ``_register_route`` (which decides how a
    handler interacts with ``_BindInfo.instance``) and ``bind``.

    Every subclass must declare its *own* ``_bind_map`` and ``_routes`` so that
    separate tables do not share state.
    """

    class _BindInfo:
        """Record of where a route handler was defined and what it is bound to."""

        def __init__(self, filename, lineno, instance):
            # Source location of the handler, plus the object it is bound to
            # (callers may pass None when nothing is bound yet).
            self.filename, self.lineno, self.instance = filename, lineno, instance

    # NOTE(review): stacking @classmethod on top of @property is deprecated
    # since Python 3.11 and no longer supported from 3.13 on. Subclasses in
    # this file shadow these with plain class attributes, so the chain is kept
    # as-is here -- confirm before changing.
    @classmethod
    @property
    @abc.abstractmethod
    def _bind_map(cls):
        """Mapping indexed as ``[method][path]`` that yields ``_BindInfo`` objects."""

    @classmethod
    @property
    @abc.abstractmethod
    def _routes(cls):
        """Route table object holding the registered route definitions."""

    @classmethod
    @abc.abstractmethod
    def _register_route(cls, method, path, **kwargs):
        """Register a handler for ``method`` and ``path``; defined by subclasses."""

    @classmethod
    @abc.abstractmethod
    def bind(cls, instance):
        """Attach ``instance`` to the routes it provides; defined by subclasses."""

    @classmethod
    def routes(cls):
        """Return the class's route table (``cls._routes``) unfiltered."""
        return cls._routes

    @classmethod
    def bound_routes(cls):
        """Return a new route table holding only the usable entries.

        Items of ``cls._routes`` that are not ``RouteDef`` objects are always
        kept. A ``RouteDef`` is kept only when the bind info recorded for its
        handler's method/path has a non-None ``instance``.
        """

        def _keep(item):
            if not isinstance(item, RouteDef):
                return True
            handler = item.handler
            verb = handler.__route_method__
            url = handler.__route_path__
            return cls._bind_map[verb][url].instance is not None

        kept = [item for item in cls._routes._items if _keep(item)]
        table = aiohttp.web.RouteTableDef()
        table._items = kept
        return table

    @classmethod
    def head(cls, path, **kwargs):
        """Register a HEAD route via ``_register_route``."""
        return cls._register_route(hdrs.METH_HEAD, path, **kwargs)

    @classmethod
    def get(cls, path, **kwargs):
        """Register a GET route via ``_register_route``."""
        return cls._register_route(hdrs.METH_GET, path, **kwargs)

    @classmethod
    def post(cls, path, **kwargs):
        """Register a POST route via ``_register_route``."""
        return cls._register_route(hdrs.METH_POST, path, **kwargs)

    @classmethod
    def put(cls, path, **kwargs):
        """Register a PUT route via ``_register_route``."""
        return cls._register_route(hdrs.METH_PUT, path, **kwargs)

    @classmethod
    def patch(cls, path, **kwargs):
        """Register a PATCH route via ``_register_route``."""
        return cls._register_route(hdrs.METH_PATCH, path, **kwargs)

    @classmethod
    def delete(cls, path, **kwargs):
        """Register a DELETE route via ``_register_route``."""
        return cls._register_route(hdrs.METH_DELETE, path, **kwargs)

    @classmethod
    def view(cls, path, **kwargs):
        """Register a route under ``hdrs.METH_ANY`` via ``_register_route``."""
        return cls._register_route(hdrs.METH_ANY, path, **kwargs)

    @classmethod
    def static(cls, prefix: str, path: PathLike, **kwargs: Any) -> None:
        """Forward a static-file registration to ``cls._routes.static``."""
        cls._routes.static(prefix, path, **kwargs)
def method_route_table_factory():
    """
    Return a method-based route table class, for in-process HeadModule objects.

    The class is defined inside this function, so every call produces a new
    class with its own ``_bind_map`` and ``_routes``.
    """

    class MethodRouteTable(BaseRouteTable):
        """Route table whose handlers are async methods of a class.

        Each ``_BindInfo.instance`` is the object that owns the handler; an
        inbound request is served by awaiting the handler with that object
        passed as its first argument.
        """

        _bind_map = collections.defaultdict(dict)
        _routes = aiohttp.web.RouteTableDef()

        @classmethod
        def _register_route(cls, method, path, **kwargs):
            """Return a decorator that registers a handler for ``method``/``path``.

            The decorator raises ``Exception`` if the same method/path pair
            was already registered.
            """

            def _wrapper(handler):
                per_method = cls._bind_map[method]
                if path in per_method:
                    prior = per_method[path]
                    raise Exception(
                        f"Duplicated route path: {path}, "
                        f"previous one registered at "
                        f"{prior.filename}:{prior.lineno}"
                    )

                # Remember where the handler was defined; the owning instance
                # is unknown until bind() is called, so start with None.
                code = handler.__code__
                info = cls._BindInfo(code.co_filename, code.co_firstlineno, None)

                @functools.wraps(handler)
                async def _handler_route(*args) -> aiohttp.web.Response:
                    try:
                        # Invoked either as (Request,) or as (self, Request);
                        # the request is the last positional argument in both
                        # cases, and the bound instance is supplied from the
                        # bind info instead of from args.
                        request = args[-1]
                        return await handler(info.instance, request)
                    except Exception:
                        logger.exception("Handle %s %s failed.", method, path)
                        return rest_response(
                            status_code=HTTPStatusCode.INTERNAL_ERROR,
                            message=traceback.format_exc(),
                        )

                per_method[path] = info
                # Tag the wrapper so bind() and bound_routes() can find the
                # matching bind info again.
                _handler_route.__route_method__ = method
                _handler_route.__route_path__ = path
                register = cls._routes.route(method, path, **kwargs)
                return register(_handler_route)

            return _wrapper

        @classmethod
        def bind(cls, instance):
            """Point the bind info of each route method of ``instance`` at it."""

            def _is_route_method(member):
                return (
                    inspect.ismethod(member)
                    and hasattr(member, "__route_method__")
                    and hasattr(member, "__route_path__")
                )

            for _name, bound in inspect.getmembers(instance, _is_route_method):
                func = bound.__func__
                cls._bind_map[func.__route_method__][
                    func.__route_path__
                ].instance = instance

    return MethodRouteTable
def rest_response(
    status_code: HTTPStatusCode,
    message: str,
    convert_google_style: bool = True,
    **kwargs,
) -> aiohttp.web.Response:
    """Build the JSON response envelope.

    The body has three keys: ``"result"`` (whether ``status_code`` equals
    ``HTTPStatusCode.OK``), ``"msg"`` and ``"data"``.

    Args:
        status_code: The HTTP status code of the response.
        message: The message of the response.
        convert_google_style: Whether to pass the data through
            ``to_google_style`` before serializing it.
        **kwargs: Arbitrary fields that make up the ``"data"`` payload.

    Returns:
        The ``aiohttp.web.Response`` produced by ``aiohttp.web.json_response``.
    """
    # In the dev context a dev server on a different port may consume the
    # API, so cross-origin access has to be allowed.
    dev_mode = os.environ.get("RAY_DASHBOARD_DEV") == "1"
    headers = {"Access-Control-Allow-Origin": "*"} if dev_mode else {}

    succeeded = status_code == HTTPStatusCode.OK
    data = kwargs
    if convert_google_style:
        data = to_google_style(kwargs)

    body = {
        "result": succeeded,
        "msg": message,
        "data": data,
    }
    # Serialize with CustomEncoder instead of the default JSON encoder.
    encode = functools.partial(json.dumps, cls=CustomEncoder)
    return aiohttp.web.json_response(
        body,
        dumps=encode,
        headers=headers,
        status=status_code,
    )
|