diagrams.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt
  4. """Diagram objects."""
  5. from __future__ import annotations
  6. from collections.abc import Iterable
  7. from typing import Any
  8. import astroid
  9. from astroid import nodes, objects, util
  10. from pylint.checkers.utils import decorated_with_property, in_type_checking_block
  11. from pylint.pyreverse.utils import FilterMixIn, get_annotation_label
  12. class Figure:
  13. """Base class for counter handling."""
  14. def __init__(self) -> None:
  15. self.fig_id: str = ""
  16. class Relationship(Figure):
  17. """A relationship from an object in the diagram to another."""
  18. def __init__(
  19. self,
  20. from_object: DiagramEntity,
  21. to_object: DiagramEntity,
  22. relation_type: str,
  23. name: str | None = None,
  24. ):
  25. super().__init__()
  26. self.from_object = from_object
  27. self.to_object = to_object
  28. self.type = relation_type
  29. self.name = name
  30. class DiagramEntity(Figure):
  31. """A diagram object, i.e. a label associated to an astroid node."""
  32. default_shape = ""
  33. def __init__(
  34. self, title: str = "No name", node: nodes.NodeNG | None = None
  35. ) -> None:
  36. super().__init__()
  37. self.title = title
  38. self.node: nodes.NodeNG = node or nodes.NodeNG(
  39. lineno=None,
  40. col_offset=None,
  41. end_lineno=None,
  42. end_col_offset=None,
  43. parent=None,
  44. )
  45. self.shape = self.default_shape
  46. class PackageEntity(DiagramEntity):
  47. """A diagram object representing a package."""
  48. default_shape = "package"
  49. class ClassEntity(DiagramEntity):
  50. """A diagram object representing a class."""
  51. default_shape = "class"
  52. def __init__(self, title: str, node: nodes.ClassDef) -> None:
  53. super().__init__(title=title, node=node)
  54. self.attrs: list[str] = []
  55. self.methods: list[nodes.FunctionDef] = []
  56. class ClassDiagram(Figure, FilterMixIn):
  57. """Main class diagram handling."""
  58. TYPE = "class"
  59. def __init__(self, title: str, mode: str) -> None:
  60. FilterMixIn.__init__(self, mode)
  61. Figure.__init__(self)
  62. self.title = title
  63. # TODO: Specify 'Any' after refactor of `DiagramEntity`
  64. self.objects: list[Any] = []
  65. self.relationships: dict[str, list[Relationship]] = {}
  66. self._nodes: dict[nodes.NodeNG, DiagramEntity] = {}
  67. def get_relationships(self, role: str) -> Iterable[Relationship]:
  68. # sorted to get predictable (hence testable) results
  69. return sorted(
  70. self.relationships.get(role, ()),
  71. key=lambda x: (x.from_object.fig_id, x.to_object.fig_id),
  72. )
  73. def add_relationship(
  74. self,
  75. from_object: DiagramEntity,
  76. to_object: DiagramEntity,
  77. relation_type: str,
  78. name: str | None = None,
  79. ) -> None:
  80. """Create a relationship."""
  81. rel = Relationship(from_object, to_object, relation_type, name)
  82. self.relationships.setdefault(relation_type, []).append(rel)
  83. def get_relationship(
  84. self, from_object: DiagramEntity, relation_type: str
  85. ) -> Relationship:
  86. """Return a relationship or None."""
  87. for rel in self.relationships.get(relation_type, ()):
  88. if rel.from_object is from_object:
  89. return rel
  90. raise KeyError(relation_type)
  91. def get_attrs(self, node: nodes.ClassDef) -> list[str]:
  92. """Return visible attributes, possibly with class name."""
  93. attrs = []
  94. # Collect functions decorated with @property
  95. properties = {
  96. local_name: local_node
  97. for local_name, local_node in node.items()
  98. if isinstance(local_node, nodes.FunctionDef)
  99. and decorated_with_property(local_node)
  100. }
  101. # Add instance attributes to properties
  102. for attr_name, attr_type in list(node.locals_type.items()) + list(
  103. node.instance_attrs_type.items()
  104. ):
  105. if attr_name not in properties:
  106. properties[attr_name] = attr_type
  107. for node_name, associated_nodes in properties.items():
  108. if not self.show_attr(node_name):
  109. continue
  110. # Handle property methods differently to correctly extract return type
  111. if isinstance(
  112. associated_nodes, nodes.FunctionDef
  113. ) and decorated_with_property(associated_nodes):
  114. if associated_nodes.returns:
  115. type_annotation = get_annotation_label(associated_nodes.returns)
  116. node_name = f"{node_name} : {type_annotation}"
  117. # Handle regular attributes
  118. else:
  119. names = self.class_names(associated_nodes)
  120. if names:
  121. node_name = f"{node_name} : {', '.join(names)}"
  122. attrs.append(node_name)
  123. return sorted(attrs)
  124. def get_methods(self, node: nodes.ClassDef) -> list[nodes.FunctionDef]:
  125. """Return visible methods."""
  126. methods = [
  127. m
  128. for m in node.values()
  129. if isinstance(m, nodes.FunctionDef)
  130. and not isinstance(m, objects.Property)
  131. and not decorated_with_property(m)
  132. and self.show_attr(m.name)
  133. ]
  134. return sorted(methods, key=lambda n: n.name)
  135. def add_object(self, title: str, node: nodes.ClassDef) -> None:
  136. """Create a diagram object."""
  137. assert node not in self._nodes
  138. ent = ClassEntity(title, node)
  139. self._nodes[node] = ent
  140. self.objects.append(ent)
  141. def class_names(self, nodes_lst: Iterable[nodes.NodeNG]) -> list[str]:
  142. """Return class names if needed in diagram."""
  143. names = []
  144. for node in nodes_lst:
  145. if isinstance(node, astroid.Instance):
  146. node = node._proxied
  147. if (
  148. isinstance(
  149. node, (nodes.ClassDef, nodes.Name, nodes.Subscript, nodes.BinOp)
  150. )
  151. and hasattr(node, "name")
  152. and not self.has_node(node)
  153. ):
  154. if node.name not in names:
  155. node_name = node.name
  156. names.append(node_name)
  157. # sorted to get predictable (hence testable) results
  158. return sorted(
  159. name
  160. for name in names
  161. if all(name not in other or name == other for other in names)
  162. )
  163. def has_node(self, node: nodes.NodeNG) -> bool:
  164. """Return true if the given node is included in the diagram."""
  165. return node in self._nodes
  166. def object_from_node(self, node: nodes.NodeNG) -> DiagramEntity:
  167. """Return the diagram object mapped to node."""
  168. return self._nodes[node]
  169. def classes(self) -> list[ClassEntity]:
  170. """Return all class nodes in the diagram."""
  171. return [o for o in self.objects if isinstance(o, ClassEntity)]
  172. def classe(self, name: str) -> ClassEntity:
  173. """Return a class by its name, raise KeyError if not found."""
  174. for klass in self.classes():
  175. if klass.node.name == name:
  176. return klass
  177. raise KeyError(name)
  178. def extract_relationships(self) -> None:
  179. """Extract relationships between nodes in the diagram."""
  180. for obj in self.classes():
  181. node = obj.node
  182. obj.attrs = self.get_attrs(node)
  183. obj.methods = self.get_methods(node)
  184. obj.shape = "class"
  185. # inheritance link
  186. for par_node in node.ancestors(recurs=False):
  187. try:
  188. par_obj = self.object_from_node(par_node)
  189. self.add_relationship(obj, par_obj, "specialization")
  190. except KeyError:
  191. continue
  192. # Track processed attributes to avoid duplicates
  193. processed_attrs = set()
  194. # Process in priority order: Composition > Aggregation > Association
  195. # 1. Composition links (highest priority)
  196. for name, values in list(node.compositions_type.items()):
  197. if not self.show_attr(name):
  198. continue
  199. for value in values:
  200. self.assign_association_relationship(
  201. value, obj, name, "composition"
  202. )
  203. processed_attrs.add(name)
  204. # 2. Aggregation links (medium priority)
  205. for name, values in list(node.aggregations_type.items()):
  206. if not self.show_attr(name) or name in processed_attrs:
  207. continue
  208. for value in values:
  209. self.assign_association_relationship(
  210. value, obj, name, "aggregation"
  211. )
  212. processed_attrs.add(name)
  213. # 3. Association links (lowest priority)
  214. associations = node.associations_type.copy()
  215. for name, values in node.locals_type.items():
  216. if name not in associations:
  217. associations[name] = values
  218. for name, values in associations.items():
  219. if not self.show_attr(name) or name in processed_attrs:
  220. continue
  221. for value in values:
  222. self.assign_association_relationship(
  223. value, obj, name, "association"
  224. )
  225. def assign_association_relationship(
  226. self, value: nodes.NodeNG, obj: ClassEntity, name: str, type_relationship: str
  227. ) -> None:
  228. if isinstance(value, util.UninferableBase):
  229. return
  230. if isinstance(value, astroid.Instance):
  231. value = value._proxied
  232. try:
  233. associated_obj = self.object_from_node(value)
  234. self.add_relationship(associated_obj, obj, type_relationship, name)
  235. except KeyError:
  236. return
  237. class PackageDiagram(ClassDiagram):
  238. """Package diagram handling."""
  239. TYPE = "package"
  240. def modules(self) -> list[PackageEntity]:
  241. """Return all module nodes in the diagram."""
  242. return [o for o in self.objects if isinstance(o, PackageEntity)]
  243. def module(self, name: str) -> PackageEntity:
  244. """Return a module by its name, raise KeyError if not found."""
  245. for mod in self.modules():
  246. if mod.node.name == name:
  247. return mod
  248. raise KeyError(name)
  249. def add_object(self, title: str, node: nodes.Module) -> None:
  250. """Create a diagram object."""
  251. assert node not in self._nodes
  252. ent = PackageEntity(title, node)
  253. self._nodes[node] = ent
  254. self.objects.append(ent)
  255. def get_module(self, name: str, node: nodes.Module) -> PackageEntity:
  256. """Return a module by its name, looking also for relative imports;
  257. raise KeyError if not found.
  258. """
  259. for mod in self.modules():
  260. mod_name = mod.node.name
  261. if mod_name == name:
  262. return mod
  263. # search for fullname of relative import modules
  264. package = node.root().name
  265. if mod_name == f"{package}.{name}":
  266. return mod
  267. if mod_name == f"{package.rsplit('.', 1)[0]}.{name}":
  268. return mod
  269. raise KeyError(name)
  270. def add_from_depend(self, node: nodes.ImportFrom, from_module: str) -> None:
  271. """Add dependencies created by from-imports."""
  272. mod_name = node.root().name
  273. package = self.module(mod_name).node
  274. if from_module in package.depends:
  275. return
  276. if not in_type_checking_block(node):
  277. package.depends.append(from_module)
  278. elif from_module not in package.type_depends:
  279. package.type_depends.append(from_module)
  280. def extract_relationships(self) -> None:
  281. """Extract relationships between nodes in the diagram."""
  282. super().extract_relationships()
  283. for class_obj in self.classes():
  284. # ownership
  285. try:
  286. mod = self.object_from_node(class_obj.node.root())
  287. self.add_relationship(class_obj, mod, "ownership")
  288. except KeyError:
  289. continue
  290. for package_obj in self.modules():
  291. package_obj.shape = "package"
  292. # dependencies
  293. for dep_name in package_obj.node.depends:
  294. try:
  295. dep = self.get_module(dep_name, package_obj.node)
  296. except KeyError:
  297. continue
  298. self.add_relationship(package_obj, dep, "depends")
  299. for dep_name in package_obj.node.type_depends:
  300. try:
  301. dep = self.get_module(dep_name, package_obj.node)
  302. except KeyError: # pragma: no cover
  303. continue
  304. self.add_relationship(package_obj, dep, "type_depends")