checker_test_case.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt
  4. from __future__ import annotations
  5. import contextlib
  6. from collections.abc import Generator, Iterator
  7. from typing import Any
  8. from astroid import nodes
  9. from pylint.testutils.global_test_linter import linter
  10. from pylint.testutils.output_line import MessageTest
  11. from pylint.testutils.unittest_linter import UnittestLinter
  12. from pylint.utils import ASTWalker
  13. class CheckerTestCase:
  14. """A base testcase class for unit testing individual checker classes."""
  15. # TODO: Figure out way to type this as type[BaseChecker] while also
  16. # setting self.checker correctly.
  17. CHECKER_CLASS: Any
  18. CONFIG: dict[str, Any] = {}
  19. def setup_method(self) -> None:
  20. self.linter = UnittestLinter()
  21. self.checker = self.CHECKER_CLASS(self.linter)
  22. for key, value in self.CONFIG.items():
  23. setattr(self.checker.linter.config, key, value)
  24. self.checker.open()
  25. @contextlib.contextmanager
  26. def assertNoMessages(self) -> Iterator[None]:
  27. """Assert that no messages are added by the given method."""
  28. with self.assertAddsMessages():
  29. yield
  30. @contextlib.contextmanager
  31. def assertAddsMessages(
  32. self, *messages: MessageTest, ignore_position: bool = False
  33. ) -> Generator[None]:
  34. """Assert that exactly the given method adds the given messages.
  35. The list of messages must exactly match *all* the messages added by the
  36. method. Additionally, we check to see whether the args in each message can
  37. actually be substituted into the message string.
  38. Using the keyword argument `ignore_position`, all checks for position
  39. arguments (line, col_offset, ...) will be skipped. This can be used to
  40. just test messages for the correct node.
  41. """
  42. yield
  43. got = self.linter.release_messages()
  44. no_msg = "No message."
  45. expected = "\n".join(repr(m) for m in messages) or no_msg
  46. got_str = "\n".join(repr(m) for m in got) or no_msg
  47. msg = (
  48. "Expected messages did not match actual.\n"
  49. f"\nExpected:\n{expected}\n\nGot:\n{got_str}\n"
  50. )
  51. assert len(messages) == len(got), msg
  52. for expected_msg, gotten_msg in zip(messages, got):
  53. assert expected_msg.msg_id == gotten_msg.msg_id, msg
  54. assert expected_msg.node == gotten_msg.node, msg
  55. assert expected_msg.args == gotten_msg.args, msg
  56. assert expected_msg.confidence == gotten_msg.confidence, msg
  57. if ignore_position:
  58. # Do not check for line, col_offset etc...
  59. continue
  60. assert expected_msg.line == gotten_msg.line, msg
  61. assert expected_msg.col_offset == gotten_msg.col_offset, msg
  62. assert expected_msg.end_line == gotten_msg.end_line, msg
  63. assert expected_msg.end_col_offset == gotten_msg.end_col_offset, msg
  64. def walk(self, node: nodes.NodeNG) -> None:
  65. """Recursive walk on the given node."""
  66. walker = ASTWalker(linter)
  67. walker.add_checker(self.checker)
  68. walker.walk(node)