# ------------------------------------------------------------------------------
# pycparser: c_generator.py
#
# C code generator from pycparser AST nodes.
#
# Eli Bendersky [https://eli.thegreenplace.net/]
# License: BSD
# ------------------------------------------------------------------------------
from typing import Callable, List, Optional

from . import c_ast
  11. class CGenerator:
  12. """Uses the same visitor pattern as c_ast.NodeVisitor, but modified to
  13. return a value from each visit method, using string accumulation in
  14. generic_visit.
  15. """
  16. indent_level: int
  17. reduce_parentheses: bool
  18. def __init__(self, reduce_parentheses: bool = False) -> None:
  19. """Constructs C-code generator
  20. reduce_parentheses:
  21. if True, eliminates needless parentheses on binary operators
  22. """
  23. # Statements start with indentation of self.indent_level spaces, using
  24. # the _make_indent method.
  25. self.indent_level = 0
  26. self.reduce_parentheses = reduce_parentheses
  27. def _make_indent(self) -> str:
  28. return " " * self.indent_level
  29. def visit(self, node: c_ast.Node) -> str:
  30. method = "visit_" + node.__class__.__name__
  31. return getattr(self, method, self.generic_visit)(node)
  32. def generic_visit(self, node: Optional[c_ast.Node]) -> str:
  33. if node is None:
  34. return ""
  35. else:
  36. return "".join(self.visit(c) for c_name, c in node.children())
  37. def visit_Constant(self, n: c_ast.Constant) -> str:
  38. return n.value
  39. def visit_ID(self, n: c_ast.ID) -> str:
  40. return n.name
  41. def visit_Pragma(self, n: c_ast.Pragma) -> str:
  42. ret = "#pragma"
  43. if n.string:
  44. ret += " " + n.string
  45. return ret
  46. def visit_ArrayRef(self, n: c_ast.ArrayRef) -> str:
  47. arrref = self._parenthesize_unless_simple(n.name)
  48. return arrref + "[" + self.visit(n.subscript) + "]"
  49. def visit_StructRef(self, n: c_ast.StructRef) -> str:
  50. sref = self._parenthesize_unless_simple(n.name)
  51. return sref + n.type + self.visit(n.field)
  52. def visit_FuncCall(self, n: c_ast.FuncCall) -> str:
  53. fref = self._parenthesize_unless_simple(n.name)
  54. args = self.visit(n.args) if n.args is not None else ""
  55. return fref + "(" + args + ")"
  56. def visit_UnaryOp(self, n: c_ast.UnaryOp) -> str:
  57. match n.op:
  58. case "sizeof":
  59. # Always parenthesize the argument of sizeof since it can be
  60. # a name.
  61. return f"sizeof({self.visit(n.expr)})"
  62. case "p++":
  63. operand = self._parenthesize_unless_simple(n.expr)
  64. return f"{operand}++"
  65. case "p--":
  66. operand = self._parenthesize_unless_simple(n.expr)
  67. return f"{operand}--"
  68. case _:
  69. operand = self._parenthesize_unless_simple(n.expr)
  70. return f"{n.op}{operand}"
  71. # Precedence map of binary operators:
  72. precedence_map = {
  73. # Should be in sync with c_parser.CParser.precedence
  74. # Higher numbers are stronger binding
  75. "||": 0, # weakest binding
  76. "&&": 1,
  77. "|": 2,
  78. "^": 3,
  79. "&": 4,
  80. "==": 5,
  81. "!=": 5,
  82. ">": 6,
  83. ">=": 6,
  84. "<": 6,
  85. "<=": 6,
  86. ">>": 7,
  87. "<<": 7,
  88. "+": 8,
  89. "-": 8,
  90. "*": 9,
  91. "/": 9,
  92. "%": 9, # strongest binding
  93. }
  94. def visit_BinaryOp(self, n: c_ast.BinaryOp) -> str:
  95. # Note: all binary operators are left-to-right associative
  96. #
  97. # If `n.left.op` has a stronger or equally binding precedence in
  98. # comparison to `n.op`, no parenthesis are needed for the left:
  99. # e.g., `(a*b) + c` is equivalent to `a*b + c`, as well as
  100. # `(a+b) - c` is equivalent to `a+b - c` (same precedence).
  101. # If the left operator is weaker binding than the current, then
  102. # parentheses are necessary:
  103. # e.g., `(a+b) * c` is NOT equivalent to `a+b * c`.
  104. lval_str = self._parenthesize_if(
  105. n.left,
  106. lambda d: not (
  107. self._is_simple_node(d)
  108. or self.reduce_parentheses
  109. and isinstance(d, c_ast.BinaryOp)
  110. and self.precedence_map[d.op] >= self.precedence_map[n.op]
  111. ),
  112. )
  113. # If `n.right.op` has a stronger -but not equal- binding precedence,
  114. # parenthesis can be omitted on the right:
  115. # e.g., `a + (b*c)` is equivalent to `a + b*c`.
  116. # If the right operator is weaker or equally binding, then parentheses
  117. # are necessary:
  118. # e.g., `a * (b+c)` is NOT equivalent to `a * b+c` and
  119. # `a - (b+c)` is NOT equivalent to `a - b+c` (same precedence).
  120. rval_str = self._parenthesize_if(
  121. n.right,
  122. lambda d: not (
  123. self._is_simple_node(d)
  124. or self.reduce_parentheses
  125. and isinstance(d, c_ast.BinaryOp)
  126. and self.precedence_map[d.op] > self.precedence_map[n.op]
  127. ),
  128. )
  129. return f"{lval_str} {n.op} {rval_str}"
  130. def visit_Assignment(self, n: c_ast.Assignment) -> str:
  131. rval_str = self._parenthesize_if(
  132. n.rvalue, lambda n: isinstance(n, c_ast.Assignment)
  133. )
  134. return f"{self.visit(n.lvalue)} {n.op} {rval_str}"
  135. def visit_IdentifierType(self, n: c_ast.IdentifierType) -> str:
  136. return " ".join(n.names)
  137. def _visit_expr(self, n: c_ast.Node) -> str:
  138. match n:
  139. case c_ast.InitList():
  140. return "{" + self.visit(n) + "}"
  141. case c_ast.ExprList() | c_ast.Compound():
  142. return "(" + self.visit(n) + ")"
  143. case _:
  144. return self.visit(n)
  145. def visit_Decl(self, n: c_ast.Decl, no_type: bool = False) -> str:
  146. # no_type is used when a Decl is part of a DeclList, where the type is
  147. # explicitly only for the first declaration in a list.
  148. #
  149. s = n.name if no_type else self._generate_decl(n)
  150. if n.bitsize:
  151. s += " : " + self.visit(n.bitsize)
  152. if n.init:
  153. s += " = " + self._visit_expr(n.init)
  154. return s
  155. def visit_DeclList(self, n: c_ast.DeclList) -> str:
  156. s = self.visit(n.decls[0])
  157. if len(n.decls) > 1:
  158. s += ", " + ", ".join(
  159. self.visit_Decl(decl, no_type=True) for decl in n.decls[1:]
  160. )
  161. return s
  162. def visit_Typedef(self, n: c_ast.Typedef) -> str:
  163. s = ""
  164. if n.storage:
  165. s += " ".join(n.storage) + " "
  166. s += self._generate_type(n.type)
  167. return s
  168. def visit_Cast(self, n: c_ast.Cast) -> str:
  169. s = "(" + self._generate_type(n.to_type, emit_declname=False) + ")"
  170. return s + " " + self._parenthesize_unless_simple(n.expr)
  171. def visit_ExprList(self, n: c_ast.ExprList) -> str:
  172. visited_subexprs = []
  173. for expr in n.exprs:
  174. visited_subexprs.append(self._visit_expr(expr))
  175. return ", ".join(visited_subexprs)
  176. def visit_InitList(self, n: c_ast.InitList) -> str:
  177. visited_subexprs = []
  178. for expr in n.exprs:
  179. visited_subexprs.append(self._visit_expr(expr))
  180. return ", ".join(visited_subexprs)
  181. def visit_Enum(self, n: c_ast.Enum) -> str:
  182. return self._generate_struct_union_enum(n, name="enum")
  183. def visit_Alignas(self, n: c_ast.Alignas) -> str:
  184. return "_Alignas({})".format(self.visit(n.alignment))
  185. def visit_Enumerator(self, n: c_ast.Enumerator) -> str:
  186. if not n.value:
  187. return "{indent}{name},\n".format(
  188. indent=self._make_indent(),
  189. name=n.name,
  190. )
  191. else:
  192. return "{indent}{name} = {value},\n".format(
  193. indent=self._make_indent(),
  194. name=n.name,
  195. value=self.visit(n.value),
  196. )
  197. def visit_FuncDef(self, n: c_ast.FuncDef) -> str:
  198. decl = self.visit(n.decl)
  199. self.indent_level = 0
  200. body = self.visit(n.body)
  201. if n.param_decls:
  202. knrdecls = ";\n".join(self.visit(p) for p in n.param_decls)
  203. return decl + "\n" + knrdecls + ";\n" + body + "\n"
  204. else:
  205. return decl + "\n" + body + "\n"
  206. def visit_FileAST(self, n: c_ast.FileAST) -> str:
  207. s = ""
  208. for ext in n.ext:
  209. match ext:
  210. case c_ast.FuncDef():
  211. s += self.visit(ext)
  212. case c_ast.Pragma():
  213. s += self.visit(ext) + "\n"
  214. case _:
  215. s += self.visit(ext) + ";\n"
  216. return s
  217. def visit_Compound(self, n: c_ast.Compound) -> str:
  218. s = self._make_indent() + "{\n"
  219. self.indent_level += 2
  220. if n.block_items:
  221. s += "".join(self._generate_stmt(stmt) for stmt in n.block_items)
  222. self.indent_level -= 2
  223. s += self._make_indent() + "}\n"
  224. return s
  225. def visit_CompoundLiteral(self, n: c_ast.CompoundLiteral) -> str:
  226. return "(" + self.visit(n.type) + "){" + self.visit(n.init) + "}"
  227. def visit_EmptyStatement(self, n: c_ast.EmptyStatement) -> str:
  228. return ";"
  229. def visit_ParamList(self, n: c_ast.ParamList) -> str:
  230. return ", ".join(self.visit(param) for param in n.params)
  231. def visit_Return(self, n: c_ast.Return) -> str:
  232. s = "return"
  233. if n.expr:
  234. s += " " + self.visit(n.expr)
  235. return s + ";"
  236. def visit_Break(self, n: c_ast.Break) -> str:
  237. return "break;"
  238. def visit_Continue(self, n: c_ast.Continue) -> str:
  239. return "continue;"
  240. def visit_TernaryOp(self, n: c_ast.TernaryOp) -> str:
  241. s = "(" + self._visit_expr(n.cond) + ") ? "
  242. s += "(" + self._visit_expr(n.iftrue) + ") : "
  243. s += "(" + self._visit_expr(n.iffalse) + ")"
  244. return s
  245. def visit_If(self, n: c_ast.If) -> str:
  246. s = "if ("
  247. if n.cond:
  248. s += self.visit(n.cond)
  249. s += ")\n"
  250. s += self._generate_stmt(n.iftrue, add_indent=True)
  251. if n.iffalse:
  252. s += self._make_indent() + "else\n"
  253. s += self._generate_stmt(n.iffalse, add_indent=True)
  254. return s
  255. def visit_For(self, n: c_ast.For) -> str:
  256. s = "for ("
  257. if n.init:
  258. s += self.visit(n.init)
  259. s += ";"
  260. if n.cond:
  261. s += " " + self.visit(n.cond)
  262. s += ";"
  263. if n.next:
  264. s += " " + self.visit(n.next)
  265. s += ")\n"
  266. s += self._generate_stmt(n.stmt, add_indent=True)
  267. return s
  268. def visit_While(self, n: c_ast.While) -> str:
  269. s = "while ("
  270. if n.cond:
  271. s += self.visit(n.cond)
  272. s += ")\n"
  273. s += self._generate_stmt(n.stmt, add_indent=True)
  274. return s
  275. def visit_DoWhile(self, n: c_ast.DoWhile) -> str:
  276. s = "do\n"
  277. s += self._generate_stmt(n.stmt, add_indent=True)
  278. s += self._make_indent() + "while ("
  279. if n.cond:
  280. s += self.visit(n.cond)
  281. s += ");"
  282. return s
  283. def visit_StaticAssert(self, n: c_ast.StaticAssert) -> str:
  284. s = "_Static_assert("
  285. s += self.visit(n.cond)
  286. if n.message:
  287. s += ","
  288. s += self.visit(n.message)
  289. s += ")"
  290. return s
  291. def visit_Switch(self, n: c_ast.Switch) -> str:
  292. s = "switch (" + self.visit(n.cond) + ")\n"
  293. s += self._generate_stmt(n.stmt, add_indent=True)
  294. return s
  295. def visit_Case(self, n: c_ast.Case) -> str:
  296. s = "case " + self.visit(n.expr) + ":\n"
  297. for stmt in n.stmts:
  298. s += self._generate_stmt(stmt, add_indent=True)
  299. return s
  300. def visit_Default(self, n: c_ast.Default) -> str:
  301. s = "default:\n"
  302. for stmt in n.stmts:
  303. s += self._generate_stmt(stmt, add_indent=True)
  304. return s
  305. def visit_Label(self, n: c_ast.Label) -> str:
  306. return n.name + ":\n" + self._generate_stmt(n.stmt)
  307. def visit_Goto(self, n: c_ast.Goto) -> str:
  308. return "goto " + n.name + ";"
  309. def visit_EllipsisParam(self, n: c_ast.EllipsisParam) -> str:
  310. return "..."
  311. def visit_Struct(self, n: c_ast.Struct) -> str:
  312. return self._generate_struct_union_enum(n, "struct")
  313. def visit_Typename(self, n: c_ast.Typename) -> str:
  314. return self._generate_type(n.type)
  315. def visit_Union(self, n: c_ast.Union) -> str:
  316. return self._generate_struct_union_enum(n, "union")
  317. def visit_NamedInitializer(self, n: c_ast.NamedInitializer) -> str:
  318. s = ""
  319. for name in n.name:
  320. if isinstance(name, c_ast.ID):
  321. s += "." + name.name
  322. else:
  323. s += "[" + self.visit(name) + "]"
  324. s += " = " + self._visit_expr(n.expr)
  325. return s
  326. def visit_FuncDecl(self, n: c_ast.FuncDecl) -> str:
  327. return self._generate_type(n)
  328. def visit_ArrayDecl(self, n: c_ast.ArrayDecl) -> str:
  329. return self._generate_type(n, emit_declname=False)
  330. def visit_TypeDecl(self, n: c_ast.TypeDecl) -> str:
  331. return self._generate_type(n, emit_declname=False)
  332. def visit_PtrDecl(self, n: c_ast.PtrDecl) -> str:
  333. return self._generate_type(n, emit_declname=False)
  334. def _generate_struct_union_enum(
  335. self, n: c_ast.Struct | c_ast.Union | c_ast.Enum, name: str
  336. ) -> str:
  337. """Generates code for structs, unions, and enums. name should be
  338. 'struct', 'union', or 'enum'.
  339. """
  340. if name in ("struct", "union"):
  341. assert isinstance(n, (c_ast.Struct, c_ast.Union))
  342. members = n.decls
  343. body_function = self._generate_struct_union_body
  344. else:
  345. assert name == "enum"
  346. assert isinstance(n, c_ast.Enum)
  347. members = None if n.values is None else n.values.enumerators
  348. body_function = self._generate_enum_body
  349. s = name + " " + (n.name or "")
  350. if members is not None:
  351. # None means no members
  352. # Empty sequence means an empty list of members
  353. s += "\n"
  354. s += self._make_indent()
  355. self.indent_level += 2
  356. s += "{\n"
  357. s += body_function(members)
  358. self.indent_level -= 2
  359. s += self._make_indent() + "}"
  360. return s
  361. def _generate_struct_union_body(self, members: List[c_ast.Node]) -> str:
  362. return "".join(self._generate_stmt(decl) for decl in members)
  363. def _generate_enum_body(self, members: List[c_ast.Enumerator]) -> str:
  364. # `[:-2] + '\n'` removes the final `,` from the enumerator list
  365. return "".join(self.visit(value) for value in members)[:-2] + "\n"
  366. def _generate_stmt(self, n: c_ast.Node, add_indent: bool = False) -> str:
  367. """Generation from a statement node. This method exists as a wrapper
  368. for individual visit_* methods to handle different treatment of
  369. some statements in this context.
  370. """
  371. if add_indent:
  372. self.indent_level += 2
  373. indent = self._make_indent()
  374. if add_indent:
  375. self.indent_level -= 2
  376. match n:
  377. case (
  378. c_ast.Decl()
  379. | c_ast.Assignment()
  380. | c_ast.Cast()
  381. | c_ast.UnaryOp()
  382. | c_ast.BinaryOp()
  383. | c_ast.TernaryOp()
  384. | c_ast.FuncCall()
  385. | c_ast.ArrayRef()
  386. | c_ast.StructRef()
  387. | c_ast.Constant()
  388. | c_ast.ID()
  389. | c_ast.Typedef()
  390. | c_ast.ExprList()
  391. ):
  392. # These can also appear in an expression context so no semicolon
  393. # is added to them automatically
  394. #
  395. return indent + self.visit(n) + ";\n"
  396. case c_ast.Compound():
  397. # No extra indentation required before the opening brace of a
  398. # compound - because it consists of multiple lines it has to
  399. # compute its own indentation.
  400. #
  401. return self.visit(n)
  402. case c_ast.If():
  403. return indent + self.visit(n)
  404. case _:
  405. return indent + self.visit(n) + "\n"
  406. def _generate_decl(self, n: c_ast.Decl) -> str:
  407. """Generation from a Decl node."""
  408. s = ""
  409. if n.funcspec:
  410. s = " ".join(n.funcspec) + " "
  411. if n.storage:
  412. s += " ".join(n.storage) + " "
  413. if n.align:
  414. s += self.visit(n.align[0]) + " "
  415. s += self._generate_type(n.type)
  416. return s
  417. def _generate_type(
  418. self,
  419. n: c_ast.Node,
  420. modifiers: List[c_ast.Node] = [],
  421. emit_declname: bool = True,
  422. ) -> str:
  423. """Recursive generation from a type node. n is the type node.
  424. modifiers collects the PtrDecl, ArrayDecl and FuncDecl modifiers
  425. encountered on the way down to a TypeDecl, to allow proper
  426. generation from it.
  427. """
  428. # ~ print(n, modifiers)
  429. match n:
  430. case c_ast.TypeDecl():
  431. s = ""
  432. if n.quals:
  433. s += " ".join(n.quals) + " "
  434. s += self.visit(n.type)
  435. nstr = n.declname if n.declname and emit_declname else ""
  436. # Resolve modifiers.
  437. # Wrap in parens to distinguish pointer to array and pointer to
  438. # function syntax.
  439. #
  440. for i, modifier in enumerate(modifiers):
  441. match modifier:
  442. case c_ast.ArrayDecl():
  443. if i != 0 and isinstance(modifiers[i - 1], c_ast.PtrDecl):
  444. nstr = "(" + nstr + ")"
  445. nstr += "["
  446. if modifier.dim_quals:
  447. nstr += " ".join(modifier.dim_quals) + " "
  448. if modifier.dim is not None:
  449. nstr += self.visit(modifier.dim)
  450. nstr += "]"
  451. case c_ast.FuncDecl():
  452. if i != 0 and isinstance(modifiers[i - 1], c_ast.PtrDecl):
  453. nstr = "(" + nstr + ")"
  454. args = (
  455. self.visit(modifier.args)
  456. if modifier.args is not None
  457. else ""
  458. )
  459. nstr += "(" + args + ")"
  460. case c_ast.PtrDecl():
  461. if modifier.quals:
  462. quals = " ".join(modifier.quals)
  463. suffix = f" {nstr}" if nstr else ""
  464. nstr = f"* {quals}{suffix}"
  465. else:
  466. nstr = "*" + nstr
  467. if nstr:
  468. s += " " + nstr
  469. return s
  470. case c_ast.Decl():
  471. return self._generate_decl(n.type)
  472. case c_ast.Typename():
  473. return self._generate_type(n.type, emit_declname=emit_declname)
  474. case c_ast.IdentifierType():
  475. return " ".join(n.names) + " "
  476. case c_ast.ArrayDecl() | c_ast.PtrDecl() | c_ast.FuncDecl():
  477. return self._generate_type(
  478. n.type, modifiers + [n], emit_declname=emit_declname
  479. )
  480. case _:
  481. return self.visit(n)
  482. def _parenthesize_if(
  483. self, n: c_ast.Node, condition: Callable[[c_ast.Node], bool]
  484. ) -> str:
  485. """Visits 'n' and returns its string representation, parenthesized
  486. if the condition function applied to the node returns True.
  487. """
  488. s = self._visit_expr(n)
  489. if condition(n):
  490. return "(" + s + ")"
  491. else:
  492. return s
  493. def _parenthesize_unless_simple(self, n: c_ast.Node) -> str:
  494. """Common use case for _parenthesize_if"""
  495. return self._parenthesize_if(n, lambda d: not self._is_simple_node(d))
  496. def _is_simple_node(self, n: c_ast.Node) -> bool:
  497. """Returns True for nodes that are "simple" - i.e. nodes that always
  498. have higher precedence than operators.
  499. """
  500. return isinstance(
  501. n,
  502. (c_ast.Constant, c_ast.ID, c_ast.ArrayRef, c_ast.StructRef, c_ast.FuncCall),
  503. )