diadefslib.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt
  4. """Handle diagram generation options for class diagram or default diagrams."""
  5. from __future__ import annotations
  6. import argparse
  7. import warnings
  8. from collections.abc import Generator, Sequence
  9. from typing import Any
  10. import astroid
  11. from astroid import nodes
  12. from astroid.modutils import is_stdlib_module
  13. from pylint.pyreverse.diagrams import ClassDiagram, PackageDiagram
  14. from pylint.pyreverse.inspector import Linker, Project
  15. from pylint.pyreverse.utils import LocalsVisitor
  16. # diagram generators ##########################################################
  17. class DiaDefGenerator:
  18. """Handle diagram generation options."""
  19. def __init__(self, linker: Linker, handler: DiadefsHandler) -> None:
  20. """Common Diagram Handler initialization."""
  21. self.config = handler.config
  22. self.args = handler.args
  23. self.module_names: bool = False
  24. self._set_default_options()
  25. self.linker = linker
  26. self.classdiagram: ClassDiagram # defined by subclasses
  27. # Only pre-calculate depths if user has requested a max_depth
  28. if handler.config.max_depth is not None:
  29. # Detect which of the args are leaf nodes
  30. leaf_nodes = self.get_leaf_nodes()
  31. # Emit a warning if any of the args are not leaf nodes
  32. diff = set(self.args).difference(set(leaf_nodes))
  33. if len(diff) > 0:
  34. warnings.warn(
  35. "Detected nested names within the specified packages. "
  36. f"The following packages: {sorted(diff)} will be ignored for "
  37. f"depth calculations, using only: {sorted(leaf_nodes)} as the base for limiting "
  38. "package depth.",
  39. stacklevel=2,
  40. )
  41. self.args_depths = {module: module.count(".") for module in leaf_nodes}
  42. def get_title(self, node: nodes.ClassDef) -> str:
  43. """Get title for objects."""
  44. title = node.name
  45. if self.module_names:
  46. title = f"{node.root().name}.{title}"
  47. return title # type: ignore[no-any-return]
  48. def get_leaf_nodes(self) -> list[str]:
  49. """
  50. Get the leaf nodes from the list of args in the generator.
  51. A leaf node is one that is not a prefix (with an extra dot) of any other node.
  52. """
  53. leaf_nodes = [
  54. module
  55. for module in self.args
  56. if not any(
  57. other != module and other.startswith(module + ".")
  58. for other in self.args
  59. )
  60. ]
  61. return leaf_nodes
  62. def _set_option(self, option: bool | None) -> bool:
  63. """Activate some options if not explicitly deactivated."""
  64. # if we have a class diagram, we want more information by default;
  65. # so if the option is None, we return True
  66. if option is None:
  67. return bool(self.config.classes)
  68. return option
  69. def _set_default_options(self) -> None:
  70. """Set different default options with _default dictionary."""
  71. self.module_names = self._set_option(self.config.module_names)
  72. all_ancestors = self._set_option(self.config.all_ancestors)
  73. all_associated = self._set_option(self.config.all_associated)
  74. anc_level, association_level = (0, 0)
  75. if all_ancestors:
  76. anc_level = -1
  77. if all_associated:
  78. association_level = -1
  79. if self.config.show_ancestors is not None:
  80. anc_level = self.config.show_ancestors
  81. if self.config.show_associated is not None:
  82. association_level = self.config.show_associated
  83. self.anc_level, self.association_level = anc_level, association_level
  84. def _get_levels(self) -> tuple[int, int]:
  85. """Help function for search levels."""
  86. return self.anc_level, self.association_level
  87. def _should_include_by_depth(self, node: nodes.NodeNG) -> bool:
  88. """Check if a node should be included based on depth.
  89. A node will be included if it is at or below the max_depth relative to the
  90. specified base packages. A node is considered to be a base package if it is the
  91. deepest package in the list of specified packages. In other words the base nodes
  92. are the leaf nodes of the specified package tree.
  93. """
  94. # If max_depth is not set, include all nodes
  95. if self.config.max_depth is None:
  96. return True
  97. # Calculate the absolute depth of the node
  98. name = node.root().name
  99. absolute_depth = name.count(".")
  100. # Retrieve the base depth to compare against
  101. relative_depth = next(
  102. (v for k, v in self.args_depths.items() if name.startswith(k)), None
  103. )
  104. return relative_depth is not None and bool(
  105. (absolute_depth - relative_depth) <= self.config.max_depth
  106. )
  107. def show_node(self, node: nodes.ClassDef) -> bool:
  108. """Determine if node should be shown based on config."""
  109. if node.root().name == "builtins":
  110. return self.config.show_builtin # type: ignore[no-any-return]
  111. if is_stdlib_module(node.root().name):
  112. return self.config.show_stdlib # type: ignore[no-any-return]
  113. # Filter node by depth
  114. return self._should_include_by_depth(node)
  115. def add_class(self, node: nodes.ClassDef) -> None:
  116. """Visit one class and add it to diagram."""
  117. self.linker.visit(node)
  118. self.classdiagram.add_object(self.get_title(node), node)
  119. def get_ancestors(
  120. self, node: nodes.ClassDef, level: int
  121. ) -> Generator[nodes.ClassDef]:
  122. """Return ancestor nodes of a class node."""
  123. if level == 0:
  124. return
  125. for ancestor in node.ancestors(recurs=False):
  126. if not self.show_node(ancestor):
  127. continue
  128. yield ancestor
  129. def get_associated(
  130. self, klass_node: nodes.ClassDef, level: int
  131. ) -> Generator[nodes.ClassDef]:
  132. """Return associated nodes of a class node."""
  133. if level == 0:
  134. return
  135. for association_nodes in list(klass_node.instance_attrs_type.values()) + list(
  136. klass_node.locals_type.values()
  137. ):
  138. for node in association_nodes:
  139. if isinstance(node, astroid.Instance):
  140. node = node._proxied
  141. if not (isinstance(node, nodes.ClassDef) and self.show_node(node)):
  142. continue
  143. yield node
  144. def extract_classes(
  145. self, klass_node: nodes.ClassDef, anc_level: int, association_level: int
  146. ) -> None:
  147. """Extract recursively classes related to klass_node."""
  148. if self.classdiagram.has_node(klass_node) or not self.show_node(klass_node):
  149. return
  150. self.add_class(klass_node)
  151. for ancestor in self.get_ancestors(klass_node, anc_level):
  152. self.extract_classes(ancestor, anc_level - 1, association_level)
  153. for node in self.get_associated(klass_node, association_level):
  154. self.extract_classes(node, anc_level, association_level - 1)
  155. class DefaultDiadefGenerator(LocalsVisitor, DiaDefGenerator):
  156. """Generate minimum diagram definition for the project :
  157. * a package diagram including project's modules
  158. * a class diagram including project's classes
  159. """
  160. def __init__(self, linker: Linker, handler: DiadefsHandler) -> None:
  161. DiaDefGenerator.__init__(self, linker, handler)
  162. LocalsVisitor.__init__(self)
  163. def visit_project(self, node: Project) -> None:
  164. """Visit a pyreverse.utils.Project node.
  165. create a diagram definition for packages
  166. """
  167. mode = self.config.mode
  168. if len(node.modules) > 1:
  169. self.pkgdiagram: PackageDiagram | None = PackageDiagram(
  170. f"packages {node.name}", mode
  171. )
  172. else:
  173. self.pkgdiagram = None
  174. self.classdiagram = ClassDiagram(f"classes {node.name}", mode)
  175. def leave_project(self, _: Project) -> Any:
  176. """Leave the pyreverse.utils.Project node.
  177. return the generated diagram definition
  178. """
  179. if self.pkgdiagram:
  180. return self.pkgdiagram, self.classdiagram
  181. return (self.classdiagram,)
  182. def visit_module(self, node: nodes.Module) -> None:
  183. """Visit an nodes.Module node.
  184. add this class to the package diagram definition
  185. """
  186. if self.pkgdiagram and self._should_include_by_depth(node):
  187. self.linker.visit(node)
  188. self.pkgdiagram.add_object(node.name, node)
  189. def visit_classdef(self, node: nodes.ClassDef) -> None:
  190. """Visit an nodes.Class node.
  191. add this class to the class diagram definition
  192. """
  193. anc_level, association_level = self._get_levels()
  194. self.extract_classes(node, anc_level, association_level)
  195. def visit_importfrom(self, node: nodes.ImportFrom) -> None:
  196. """Visit nodes.ImportFrom and catch modules for package diagram."""
  197. if self.pkgdiagram and self._should_include_by_depth(node):
  198. self.pkgdiagram.add_from_depend(node, node.modname)
  199. class ClassDiadefGenerator(DiaDefGenerator):
  200. """Generate a class diagram definition including all classes related to a
  201. given class.
  202. """
  203. def class_diagram(self, project: Project, klass: nodes.ClassDef) -> ClassDiagram:
  204. """Return a class diagram definition for the class and related classes."""
  205. self.classdiagram = ClassDiagram(klass, self.config.mode)
  206. if len(project.modules) > 1:
  207. module, klass = klass.rsplit(".", 1)
  208. module = project.get_module(module)
  209. else:
  210. module = project.modules[0]
  211. klass = klass.split(".")[-1]
  212. klass = next(module.ilookup(klass))
  213. anc_level, association_level = self._get_levels()
  214. self.extract_classes(klass, anc_level, association_level)
  215. return self.classdiagram
  216. # diagram handler #############################################################
  217. class DiadefsHandler:
  218. """Get diagram definitions from user (i.e. xml files) or generate them."""
  219. def __init__(self, config: argparse.Namespace, args: Sequence[str]) -> None:
  220. self.config = config
  221. self.args = args
  222. def get_diadefs(self, project: Project, linker: Linker) -> list[ClassDiagram]:
  223. """Get the diagram's configuration data.
  224. :param project:The pyreverse project
  225. :type project: pyreverse.utils.Project
  226. :param linker: The linker
  227. :type linker: pyreverse.inspector.Linker(IdGeneratorMixIn, LocalsVisitor)
  228. :returns: The list of diagram definitions
  229. :rtype: list(:class:`pylint.pyreverse.diagrams.ClassDiagram`)
  230. """
  231. # read and interpret diagram definitions (Diadefs)
  232. diagrams = []
  233. generator = ClassDiadefGenerator(linker, self)
  234. for klass in self.config.classes:
  235. diagrams.append(generator.class_diagram(project, klass))
  236. if not diagrams:
  237. diagrams = DefaultDiadefGenerator(linker, self).visit(project)
  238. for diagram in diagrams:
  239. diagram.extract_relationships()
  240. return diagrams