_ast.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
  4. from __future__ import annotations
  5. import ast
  6. from typing import NamedTuple
  7. from astroid.const import Context
  8. class FunctionType(NamedTuple):
  9. argtypes: list[ast.expr]
  10. returns: ast.expr
  11. class ParserModule(NamedTuple):
  12. unary_op_classes: dict[type[ast.unaryop], str]
  13. cmp_op_classes: dict[type[ast.cmpop], str]
  14. bool_op_classes: dict[type[ast.boolop], str]
  15. bin_op_classes: dict[type[ast.operator], str]
  16. context_classes: dict[type[ast.expr_context], Context]
  17. def parse(
  18. self, string: str, type_comments: bool = True, filename: str | None = None
  19. ) -> ast.Module:
  20. if filename:
  21. return ast.parse(string, filename=filename, type_comments=type_comments)
  22. return ast.parse(string, type_comments=type_comments)
  23. def parse_function_type_comment(type_comment: str) -> FunctionType | None:
  24. """Given a correct type comment, obtain a FunctionType object."""
  25. func_type = ast.parse(type_comment, "<type_comment>", "func_type")
  26. return FunctionType(argtypes=func_type.argtypes, returns=func_type.returns)
  27. def get_parser_module(type_comments: bool = True) -> ParserModule:
  28. unary_op_classes = _unary_operators_from_module()
  29. cmp_op_classes = _compare_operators_from_module()
  30. bool_op_classes = _bool_operators_from_module()
  31. bin_op_classes = _binary_operators_from_module()
  32. context_classes = _contexts_from_module()
  33. return ParserModule(
  34. unary_op_classes,
  35. cmp_op_classes,
  36. bool_op_classes,
  37. bin_op_classes,
  38. context_classes,
  39. )
  40. def _unary_operators_from_module() -> dict[type[ast.unaryop], str]:
  41. return {ast.UAdd: "+", ast.USub: "-", ast.Not: "not", ast.Invert: "~"}
  42. def _binary_operators_from_module() -> dict[type[ast.operator], str]:
  43. return {
  44. ast.Add: "+",
  45. ast.BitAnd: "&",
  46. ast.BitOr: "|",
  47. ast.BitXor: "^",
  48. ast.Div: "/",
  49. ast.FloorDiv: "//",
  50. ast.MatMult: "@",
  51. ast.Mod: "%",
  52. ast.Mult: "*",
  53. ast.Pow: "**",
  54. ast.Sub: "-",
  55. ast.LShift: "<<",
  56. ast.RShift: ">>",
  57. }
  58. def _bool_operators_from_module() -> dict[type[ast.boolop], str]:
  59. return {ast.And: "and", ast.Or: "or"}
  60. def _compare_operators_from_module() -> dict[type[ast.cmpop], str]:
  61. return {
  62. ast.Eq: "==",
  63. ast.Gt: ">",
  64. ast.GtE: ">=",
  65. ast.In: "in",
  66. ast.Is: "is",
  67. ast.IsNot: "is not",
  68. ast.Lt: "<",
  69. ast.LtE: "<=",
  70. ast.NotEq: "!=",
  71. ast.NotIn: "not in",
  72. }
  73. def _contexts_from_module() -> dict[type[ast.expr_context], Context]:
  74. return {
  75. ast.Load: Context.Load,
  76. ast.Store: Context.Store,
  77. ast.Del: Context.Del,
  78. ast.Param: Context.Store,
  79. }