etree.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. """Shim module exporting the same ElementTree API for lxml and
  2. xml.etree backends.
  3. When lxml is installed, it is automatically preferred over the built-in
  4. xml.etree module.
  5. On Python 2.7, the cElementTree module is preferred over the pure-python
  6. ElementTree module.
  7. Besides exporting a unified interface, this also defines extra functions
  8. or subclasses built-in ElementTree classes to add features that are
  9. only availble in lxml, like OrderedDict for attributes, pretty_print and
  10. iterwalk.
  11. """
  12. from fontTools.misc.textTools import tostr
  13. XML_DECLARATION = """<?xml version='1.0' encoding='%s'?>"""
  14. __all__ = [
  15. # public symbols
  16. "Comment",
  17. "dump",
  18. "Element",
  19. "ElementTree",
  20. "fromstring",
  21. "fromstringlist",
  22. "iselement",
  23. "iterparse",
  24. "parse",
  25. "ParseError",
  26. "PI",
  27. "ProcessingInstruction",
  28. "QName",
  29. "SubElement",
  30. "tostring",
  31. "tostringlist",
  32. "TreeBuilder",
  33. "XML",
  34. "XMLParser",
  35. "register_namespace",
  36. ]
  37. try:
  38. from lxml.etree import *
  39. _have_lxml = True
  40. except ImportError:
  41. try:
  42. from xml.etree.cElementTree import *
  43. # the cElementTree version of XML function doesn't support
  44. # the optional 'parser' keyword argument
  45. from xml.etree.ElementTree import XML
  46. except ImportError: # pragma: no cover
  47. from xml.etree.ElementTree import *
  48. _have_lxml = False
  49. _Attrib = dict
  50. if isinstance(Element, type):
  51. _Element = Element
  52. else:
  53. # in py27, cElementTree.Element cannot be subclassed, so
  54. # we need to import the pure-python class
  55. from xml.etree.ElementTree import Element as _Element
  56. class Element(_Element):
  57. """Element subclass that keeps the order of attributes."""
  58. def __init__(self, tag, attrib=_Attrib(), **extra):
  59. super(Element, self).__init__(tag)
  60. self.attrib = _Attrib()
  61. if attrib:
  62. self.attrib.update(attrib)
  63. if extra:
  64. self.attrib.update(extra)
  65. def SubElement(parent, tag, attrib=_Attrib(), **extra):
  66. """Must override SubElement as well otherwise _elementtree.SubElement
  67. fails if 'parent' is a subclass of Element object.
  68. """
  69. element = parent.__class__(tag, attrib, **extra)
  70. parent.append(element)
  71. return element
  72. def _iterwalk(element, events, tag):
  73. include = tag is None or element.tag == tag
  74. if include and "start" in events:
  75. yield ("start", element)
  76. for e in element:
  77. for item in _iterwalk(e, events, tag):
  78. yield item
  79. if include:
  80. yield ("end", element)
  81. def iterwalk(element_or_tree, events=("end",), tag=None):
  82. """A tree walker that generates events from an existing tree as
  83. if it was parsing XML data with iterparse().
  84. Drop-in replacement for lxml.etree.iterwalk.
  85. """
  86. if iselement(element_or_tree):
  87. element = element_or_tree
  88. else:
  89. element = element_or_tree.getroot()
  90. if tag == "*":
  91. tag = None
  92. for item in _iterwalk(element, events, tag):
  93. yield item
# Keep a handle on the backend's ElementTree class: the subclass defined
# below reuses (and shadows) the public name.
_ElementTree = ElementTree


class ElementTree(_ElementTree):
    """ElementTree subclass that adds 'pretty_print' and 'doctype'
    arguments to the 'write' method.

    Currently these are only supported for the default XML serialization
    'method', and not also for "html" or "text", for these are delegated
    to the base class.
    """

    def write(
        self,
        file_or_filename,
        encoding=None,
        xml_declaration=False,
        method=None,
        doctype=None,
        pretty_print=False,
    ):
        """Serialize this tree to *file_or_filename*.

        Args:
            file_or_filename: a file name or a writable file-like object
                (handled by ``_get_writer``).
            encoding: output encoding; None means "ASCII". "unicode" (in any
                letter case) writes ``str`` to a text stream.
            xml_declaration: True/False forces the XML declaration on or off
                (default False). With None, it is written only when an
                *encoding* other than ASCII, UTF-8, UTF8 or US-ASCII is given.
            method: serialization method. Anything other than None/"xml" is
                delegated to the base class, and *doctype* and *pretty_print*
                are then ignored.
            doctype: optional string written before the root element.
            pretty_print: if true, indent the tree with ``_indent`` and put a
                newline after the declaration and the doctype.
                NOTE: the indentation modifies the tree in place.

        Raises:
            ValueError: if an XML declaration is requested together with the
                "unicode" encoding.
        """
        if method and method != "xml":
            # delegate to super-class (doctype and pretty_print are dropped)
            super(ElementTree, self).write(
                file_or_filename,
                encoding=encoding,
                xml_declaration=xml_declaration,
                method=method,
            )
            return
        if encoding is not None and encoding.lower() == "unicode":
            if xml_declaration:
                raise ValueError(
                    "Serialisation to unicode must not request an XML declaration"
                )
            write_declaration = False
            # normalize the spelling: _get_writer compares with == "unicode"
            encoding = "unicode"
        elif xml_declaration is None:
            # by default, write an XML declaration only for non-standard encodings
            write_declaration = encoding is not None and encoding.upper() not in (
                "ASCII",
                "UTF-8",
                "UTF8",
                "US-ASCII",
            )
        else:
            write_declaration = xml_declaration

        if encoding is None:
            encoding = "ASCII"

        if pretty_print:
            # NOTE this will modify the tree in-place
            _indent(self._root)

        with _get_writer(file_or_filename, encoding) as write:
            if write_declaration:
                write(XML_DECLARATION % encoding.upper())
                if pretty_print:
                    write("\n")
            if doctype:
                # validated (and bytes decoded) but otherwise written verbatim
                write(_tounicode(doctype))
                if pretty_print:
                    write("\n")

            qnames, namespaces = _namespaces(self._root)
            _serialize_xml(write, self._root, qnames, namespaces)
  153. import io
  154. def tostring(
  155. element,
  156. encoding=None,
  157. xml_declaration=None,
  158. method=None,
  159. doctype=None,
  160. pretty_print=False,
  161. ):
  162. """Custom 'tostring' function that uses our ElementTree subclass, with
  163. pretty_print support.
  164. """
  165. stream = io.StringIO() if encoding == "unicode" else io.BytesIO()
  166. ElementTree(element).write(
  167. stream,
  168. encoding=encoding,
  169. xml_declaration=xml_declaration,
  170. method=method,
  171. doctype=doctype,
  172. pretty_print=pretty_print,
  173. )
  174. return stream.getvalue()
  175. # serialization support
  176. import re
  177. # Valid XML strings can include any Unicode character, excluding control
  178. # characters, the surrogate blocks, FFFE, and FFFF:
  179. # Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]
  180. # Here we reversed the pattern to match only the invalid characters.
  181. _invalid_xml_string = re.compile(
  182. "[\u0000-\u0008\u000B-\u000C\u000E-\u001F\uD800-\uDFFF\uFFFE-\uFFFF]"
  183. )
  184. def _tounicode(s):
  185. """Test if a string is valid user input and decode it to unicode string
  186. using ASCII encoding if it's a bytes string.
  187. Reject all bytes/unicode input that contains non-XML characters.
  188. Reject all bytes input that contains non-ASCII characters.
  189. """
  190. try:
  191. s = tostr(s, encoding="ascii", errors="strict")
  192. except UnicodeDecodeError:
  193. raise ValueError(
  194. "Bytes strings can only contain ASCII characters. "
  195. "Use unicode strings for non-ASCII characters."
  196. )
  197. except AttributeError:
  198. _raise_serialization_error(s)
  199. if s and _invalid_xml_string.search(s):
  200. raise ValueError(
  201. "All strings must be XML compatible: Unicode or ASCII, "
  202. "no NULL bytes or control characters"
  203. )
  204. return s
  205. import contextlib
  206. @contextlib.contextmanager
  207. def _get_writer(file_or_filename, encoding):
  208. # returns text write method and release all resources after using
  209. try:
  210. write = file_or_filename.write
  211. except AttributeError:
  212. # file_or_filename is a file name
  213. f = open(
  214. file_or_filename,
  215. "w",
  216. encoding="utf-8" if encoding == "unicode" else encoding,
  217. errors="xmlcharrefreplace",
  218. )
  219. with f:
  220. yield f.write
  221. else:
  222. # file_or_filename is a file-like object
  223. # encoding determines if it is a text or binary writer
  224. if encoding == "unicode":
  225. # use a text writer as is
  226. yield write
  227. else:
  228. # wrap a binary writer with TextIOWrapper
  229. detach_buffer = False
  230. if isinstance(file_or_filename, io.BufferedIOBase):
  231. buf = file_or_filename
  232. elif isinstance(file_or_filename, io.RawIOBase):
  233. buf = io.BufferedWriter(file_or_filename)
  234. detach_buffer = True
  235. else:
  236. # This is to handle passed objects that aren't in the
  237. # IOBase hierarchy, but just have a write method
  238. buf = io.BufferedIOBase()
  239. buf.writable = lambda: True
  240. buf.write = write
  241. try:
  242. # TextIOWrapper uses this methods to determine
  243. # if BOM (for UTF-16, etc) should be added
  244. buf.seekable = file_or_filename.seekable
  245. buf.tell = file_or_filename.tell
  246. except AttributeError:
  247. pass
  248. wrapper = io.TextIOWrapper(
  249. buf,
  250. encoding=encoding,
  251. errors="xmlcharrefreplace",
  252. newline="\n",
  253. )
  254. try:
  255. yield wrapper.write
  256. finally:
  257. # Keep the original file open when the TextIOWrapper and
  258. # the BufferedWriter are destroyed
  259. wrapper.detach()
  260. if detach_buffer:
  261. buf.detach()
  262. from xml.etree.ElementTree import _namespace_map
  263. def _namespaces(elem):
  264. # identify namespaces used in this tree
  265. # maps qnames to *encoded* prefix:local names
  266. qnames = {None: None}
  267. # maps uri:s to prefixes
  268. namespaces = {}
  269. def add_qname(qname):
  270. # calculate serialized qname representation
  271. try:
  272. qname = _tounicode(qname)
  273. if qname[:1] == "{":
  274. uri, tag = qname[1:].rsplit("}", 1)
  275. prefix = namespaces.get(uri)
  276. if prefix is None:
  277. prefix = _namespace_map.get(uri)
  278. if prefix is None:
  279. prefix = "ns%d" % len(namespaces)
  280. else:
  281. prefix = _tounicode(prefix)
  282. if prefix != "xml":
  283. namespaces[uri] = prefix
  284. if prefix:
  285. qnames[qname] = "%s:%s" % (prefix, tag)
  286. else:
  287. qnames[qname] = tag # default element
  288. else:
  289. qnames[qname] = qname
  290. except TypeError:
  291. _raise_serialization_error(qname)
  292. # populate qname and namespaces table
  293. for elem in elem.iter():
  294. tag = elem.tag
  295. if isinstance(tag, QName):
  296. if tag.text not in qnames:
  297. add_qname(tag.text)
  298. elif isinstance(tag, str):
  299. if tag not in qnames:
  300. add_qname(tag)
  301. elif tag is not None and tag is not Comment and tag is not PI:
  302. _raise_serialization_error(tag)
  303. for key, value in elem.items():
  304. if isinstance(key, QName):
  305. key = key.text
  306. if key not in qnames:
  307. add_qname(key)
  308. if isinstance(value, QName) and value.text not in qnames:
  309. add_qname(value.text)
  310. text = elem.text
  311. if isinstance(text, QName) and text.text not in qnames:
  312. add_qname(text.text)
  313. return qnames, namespaces
  314. def _serialize_xml(write, elem, qnames, namespaces, **kwargs):
  315. tag = elem.tag
  316. text = elem.text
  317. if tag is Comment:
  318. write("<!--%s-->" % _tounicode(text))
  319. elif tag is ProcessingInstruction:
  320. write("<?%s?>" % _tounicode(text))
  321. else:
  322. tag = qnames[_tounicode(tag) if tag is not None else None]
  323. if tag is None:
  324. if text:
  325. write(_escape_cdata(text))
  326. for e in elem:
  327. _serialize_xml(write, e, qnames, None)
  328. else:
  329. write("<" + tag)
  330. if namespaces:
  331. for uri, prefix in sorted(
  332. namespaces.items(), key=lambda x: x[1]
  333. ): # sort on prefix
  334. if prefix:
  335. prefix = ":" + prefix
  336. write(' xmlns%s="%s"' % (prefix, _escape_attrib(uri)))
  337. attrs = elem.attrib
  338. if attrs:
  339. # try to keep existing attrib order
  340. if len(attrs) <= 1 or type(attrs) is _Attrib:
  341. items = attrs.items()
  342. else:
  343. # if plain dict, use lexical order
  344. items = sorted(attrs.items())
  345. for k, v in items:
  346. if isinstance(k, QName):
  347. k = _tounicode(k.text)
  348. else:
  349. k = _tounicode(k)
  350. if isinstance(v, QName):
  351. v = qnames[_tounicode(v.text)]
  352. else:
  353. v = _escape_attrib(v)
  354. write(' %s="%s"' % (qnames[k], v))
  355. if text is not None or len(elem):
  356. write(">")
  357. if text:
  358. write(_escape_cdata(text))
  359. for e in elem:
  360. _serialize_xml(write, e, qnames, None)
  361. write("</" + tag + ">")
  362. else:
  363. write("/>")
  364. if elem.tail:
  365. write(_escape_cdata(elem.tail))
  366. def _raise_serialization_error(text):
  367. raise TypeError("cannot serialize %r (type %s)" % (text, type(text).__name__))
  368. def _escape_cdata(text):
  369. # escape character data
  370. try:
  371. text = _tounicode(text)
  372. # it's worth avoiding do-nothing calls for short strings
  373. if "&" in text:
  374. text = text.replace("&", "&amp;")
  375. if "<" in text:
  376. text = text.replace("<", "&lt;")
  377. if ">" in text:
  378. text = text.replace(">", "&gt;")
  379. return text
  380. except (TypeError, AttributeError):
  381. _raise_serialization_error(text)
  382. def _escape_attrib(text):
  383. # escape attribute value
  384. try:
  385. text = _tounicode(text)
  386. if "&" in text:
  387. text = text.replace("&", "&amp;")
  388. if "<" in text:
  389. text = text.replace("<", "&lt;")
  390. if ">" in text:
  391. text = text.replace(">", "&gt;")
  392. if '"' in text:
  393. text = text.replace('"', "&quot;")
  394. if "\n" in text:
  395. text = text.replace("\n", "&#10;")
  396. return text
  397. except (TypeError, AttributeError):
  398. _raise_serialization_error(text)
  399. def _indent(elem, level=0):
  400. # From http://effbot.org/zone/element-lib.htm#prettyprint
  401. i = "\n" + level * " "
  402. if len(elem):
  403. if not elem.text or not elem.text.strip():
  404. elem.text = i + " "
  405. if not elem.tail or not elem.tail.strip():
  406. elem.tail = i
  407. for elem in elem:
  408. _indent(elem, level + 1)
  409. if not elem.tail or not elem.tail.strip():
  410. elem.tail = i
  411. else:
  412. if level and (not elem.tail or not elem.tail.strip()):
  413. elem.tail = i