literal.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import ast
  2. from collections.abc import Callable
  3. from pprint import PrettyPrinter
  4. from typing import Any
  5. from isort.exceptions import (
  6. AssignmentsFormatMismatch,
  7. LiteralParsingFailure,
  8. LiteralSortTypeMismatch,
  9. )
  10. from isort.settings import DEFAULT_CONFIG, Config
  11. class ISortPrettyPrinter(PrettyPrinter):
  12. """an isort customized pretty printer for sorted literals"""
  13. def __init__(self, config: Config):
  14. super().__init__(width=config.line_length, compact=True)
  15. type_mapping: dict[str, tuple[type, Callable[[Any, ISortPrettyPrinter], str]]] = {}
  16. def assignments(code: str) -> str:
  17. values = {}
  18. for line in code.splitlines(keepends=True):
  19. if not line.strip():
  20. continue
  21. if " = " not in line:
  22. raise AssignmentsFormatMismatch(code)
  23. variable_name, value = line.split(" = ", 1)
  24. values[variable_name] = value
  25. return "".join(
  26. f"{variable_name} = {values[variable_name]}" for variable_name in sorted(values.keys())
  27. )
  28. def assignment(code: str, sort_type: str, extension: str, config: Config = DEFAULT_CONFIG) -> str:
  29. """Sorts the literal present within the provided code against the provided sort type,
  30. returning the sorted representation of the source code.
  31. """
  32. if sort_type == "assignments":
  33. return assignments(code)
  34. if sort_type not in type_mapping:
  35. raise ValueError(
  36. "Trying to sort using an undefined sort_type. "
  37. f"Defined sort types are {', '.join(type_mapping.keys())}."
  38. )
  39. variable_name, literal = code.split("=")
  40. variable_name = variable_name.strip()
  41. literal = literal.lstrip()
  42. try:
  43. value = ast.literal_eval(literal)
  44. except Exception as error:
  45. raise LiteralParsingFailure(code, error)
  46. expected_type, sort_function = type_mapping[sort_type]
  47. if type(value) is not expected_type:
  48. raise LiteralSortTypeMismatch(type(value), expected_type)
  49. printer = ISortPrettyPrinter(config)
  50. sorted_value_code = f"{variable_name} = {sort_function(value, printer)}"
  51. if config.formatting_function:
  52. sorted_value_code = config.formatting_function(
  53. sorted_value_code, extension, config
  54. ).rstrip()
  55. sorted_value_code += code[len(code.rstrip()) :]
  56. return sorted_value_code
  57. def register_type(
  58. name: str, kind: type
  59. ) -> Callable[[Callable[[Any, ISortPrettyPrinter], str]], Callable[[Any, ISortPrettyPrinter], str]]:
  60. """Registers a new literal sort type."""
  61. def wrap(
  62. function: Callable[[Any, ISortPrettyPrinter], str],
  63. ) -> Callable[[Any, ISortPrettyPrinter], str]:
  64. type_mapping[name] = (kind, function)
  65. return function
  66. return wrap
  67. @register_type("dict", dict)
  68. def _dict(value: dict[Any, Any], printer: ISortPrettyPrinter) -> str:
  69. return printer.pformat(dict(sorted(value.items(), key=lambda item: item[1])))
  70. @register_type("list", list)
  71. def _list(value: list[Any], printer: ISortPrettyPrinter) -> str:
  72. return printer.pformat(sorted(value))
  73. @register_type("unique-list", list)
  74. def _unique_list(value: list[Any], printer: ISortPrettyPrinter) -> str:
  75. return printer.pformat(sorted(set(value)))
  76. @register_type("set", set)
  77. def _set(value: set[Any], printer: ISortPrettyPrinter) -> str:
  78. return "{" + printer.pformat(tuple(sorted(value)))[1:-1] + "}"
  79. @register_type("tuple", tuple)
  80. def _tuple(value: tuple[Any, ...], printer: ISortPrettyPrinter) -> str:
  81. return printer.pformat(tuple(sorted(value)))
  82. @register_type("unique-tuple", tuple)
  83. def _unique_tuple(value: tuple[Any, ...], printer: ISortPrettyPrinter) -> str:
  84. return printer.pformat(tuple(sorted(set(value))))