glob_group.py 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # mypy: allow-untyped-defs
  2. import re
  3. from collections.abc import Iterable
  4. from typing import Union
  5. GlobPattern = Union[str, Iterable[str]]
  6. class GlobGroup:
  7. """A set of patterns that candidate strings will be matched against.
  8. A candidate is composed of a list of segments separated by ``separator``, e.g. "foo.bar.baz".
  9. A pattern contains one or more segments. Segments can be:
  10. - A literal string (e.g. "foo"), which matches exactly.
  11. - A string containing a wildcard (e.g. "torch*", or "foo*baz*"). The wildcard matches
  12. any string, including the empty string.
  13. - A double wildcard ("**"). This matches against zero or more complete segments.
  14. Examples:
  15. ``torch.**``: matches ``torch`` and all its submodules, e.g. ``torch.nn`` and ``torch.nn.functional``.
  16. ``torch.*``: matches ``torch.nn`` or ``torch.functional``, but not ``torch.nn.functional``.
  17. ``torch*.**``: matches ``torch``, ``torchvision``, and all their submodules.
  18. A candidates will match the ``GlobGroup`` if it matches any of the ``include`` patterns and
  19. none of the ``exclude`` patterns.
  20. Args:
  21. include (Union[str, Iterable[str]]): A string or list of strings,
  22. each representing a pattern to be matched against. A candidate
  23. will match if it matches *any* include pattern
  24. exclude (Union[str, Iterable[str]]): A string or list of strings,
  25. each representing a pattern to be matched against. A candidate
  26. will be excluded from matching if it matches *any* exclude pattern.
  27. separator (str): A string that delimits segments in candidates and
  28. patterns. By default this is "." which corresponds to how modules are
  29. named in Python. Another common value for this is "/", which is
  30. the Unix path separator.
  31. """
  32. def __init__(
  33. self, include: GlobPattern, *, exclude: GlobPattern = (), separator: str = "."
  34. ):
  35. self._dbg = f"GlobGroup(include={include}, exclude={exclude})"
  36. self.include = GlobGroup._glob_list(include, separator)
  37. self.exclude = GlobGroup._glob_list(exclude, separator)
  38. self.separator = separator
  39. def __str__(self):
  40. return self._dbg
  41. def __repr__(self):
  42. return self._dbg
  43. def matches(self, candidate: str) -> bool:
  44. candidate = self.separator + candidate
  45. return any(p.fullmatch(candidate) for p in self.include) and all(
  46. not p.fullmatch(candidate) for p in self.exclude
  47. )
  48. @staticmethod
  49. def _glob_list(elems: GlobPattern, separator: str = "."):
  50. if isinstance(elems, str):
  51. return [GlobGroup._glob_to_re(elems, separator)]
  52. else:
  53. return [GlobGroup._glob_to_re(e, separator) for e in elems]
  54. @staticmethod
  55. def _glob_to_re(pattern: str, separator: str = "."):
  56. # to avoid corner cases for the first component, we prefix the candidate string
  57. # with '.' so `import torch` will regex against `.torch`, assuming '.' is the separator
  58. def component_to_re(component):
  59. if "**" in component:
  60. if component == "**":
  61. return "(" + re.escape(separator) + "[^" + separator + "]+)*"
  62. else:
  63. raise ValueError("** can only appear as an entire path segment")
  64. else:
  65. return re.escape(separator) + ("[^" + separator + "]*").join(
  66. re.escape(x) for x in component.split("*")
  67. )
  68. result = "".join(component_to_re(c) for c in pattern.split(separator))
  69. return re.compile(result)