test_parsing.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import pytest
  2. from einops import EinopsError
  3. from einops.parsing import AnonymousAxis, ParsedExpression, _ellipsis
  4. __author__ = "Alex Rogozhnikov"
  5. class AnonymousAxisPlaceholder:
  6. def __init__(self, value: int):
  7. self.value = value
  8. assert isinstance(self.value, int)
  9. def __eq__(self, other):
  10. return isinstance(other, AnonymousAxis) and self.value == other.value
  11. def test_anonymous_axes():
  12. a, b = AnonymousAxis("2"), AnonymousAxis("2")
  13. assert a != b
  14. c, d = AnonymousAxisPlaceholder(2), AnonymousAxisPlaceholder(3)
  15. assert a == c and b == c
  16. assert a != d and b != d
  17. assert [a, 2, b] == [c, 2, c]
  18. def test_elementary_axis_name():
  19. for name in [
  20. "a",
  21. "b",
  22. "h",
  23. "dx",
  24. "h1",
  25. "zz",
  26. "i9123",
  27. "somelongname",
  28. "Alex",
  29. "camelCase",
  30. "u_n_d_e_r_score",
  31. "unreasonablyLongAxisName",
  32. ]:
  33. assert ParsedExpression.check_axis_name(name)
  34. for name in ["", "2b", "12", "_startWithUnderscore", "endWithUnderscore_", "_", "...", _ellipsis]:
  35. assert not ParsedExpression.check_axis_name(name)
  36. def test_invalid_expressions():
  37. # double ellipsis should raise an error
  38. ParsedExpression("... a b c d")
  39. with pytest.raises(EinopsError):
  40. ParsedExpression("... a b c d ...")
  41. with pytest.raises(EinopsError):
  42. ParsedExpression("... a b c (d ...)")
  43. with pytest.raises(EinopsError):
  44. ParsedExpression("(... a) b c (d ...)")
  45. # double/missing/enclosed parenthesis
  46. ParsedExpression("(a) b c (d ...)")
  47. with pytest.raises(EinopsError):
  48. ParsedExpression("(a)) b c (d ...)")
  49. with pytest.raises(EinopsError):
  50. ParsedExpression("(a b c (d ...)")
  51. with pytest.raises(EinopsError):
  52. ParsedExpression("(a) (()) b c (d ...)")
  53. with pytest.raises(EinopsError):
  54. ParsedExpression("(a) ((b c) (d ...))")
  55. # invalid identifiers
  56. ParsedExpression("camelCase under_scored cApiTaLs ß ...")
  57. with pytest.raises(EinopsError):
  58. ParsedExpression("1a")
  59. with pytest.raises(EinopsError):
  60. ParsedExpression("_pre")
  61. with pytest.raises(EinopsError):
  62. ParsedExpression("...pre")
  63. with pytest.raises(EinopsError):
  64. ParsedExpression("pre...")
  65. def test_parse_expression():
  66. parsed = ParsedExpression("a1 b1 c1 d1")
  67. assert parsed.identifiers == {"a1", "b1", "c1", "d1"}
  68. assert parsed.composition == [["a1"], ["b1"], ["c1"], ["d1"]]
  69. assert not parsed.has_non_unitary_anonymous_axes
  70. assert not parsed.has_ellipsis
  71. parsed = ParsedExpression("() () () ()")
  72. assert parsed.identifiers == set()
  73. assert parsed.composition == [[], [], [], []]
  74. assert not parsed.has_non_unitary_anonymous_axes
  75. assert not parsed.has_ellipsis
  76. parsed = ParsedExpression("1 1 1 ()")
  77. assert parsed.identifiers == set()
  78. assert parsed.composition == [[], [], [], []]
  79. assert not parsed.has_non_unitary_anonymous_axes
  80. assert not parsed.has_ellipsis
  81. aap = AnonymousAxisPlaceholder
  82. parsed = ParsedExpression("5 (3 4)")
  83. assert len(parsed.identifiers) == 3 and {i.value for i in parsed.identifiers} == {3, 4, 5}
  84. assert parsed.composition == [[aap(5)], [aap(3), aap(4)]]
  85. assert parsed.has_non_unitary_anonymous_axes
  86. assert not parsed.has_ellipsis
  87. parsed = ParsedExpression("5 1 (1 4) 1")
  88. assert len(parsed.identifiers) == 2 and {i.value for i in parsed.identifiers} == {4, 5}
  89. assert parsed.composition == [[aap(5)], [], [aap(4)], []]
  90. parsed = ParsedExpression("name1 ... a1 12 (name2 14)")
  91. assert len(parsed.identifiers) == 6
  92. assert parsed.identifiers.difference({"name1", _ellipsis, "a1", "name2"}).__len__() == 2
  93. assert parsed.composition == [["name1"], _ellipsis, ["a1"], [aap(12)], ["name2", aap(14)]]
  94. assert parsed.has_non_unitary_anonymous_axes
  95. assert parsed.has_ellipsis
  96. assert not parsed.has_ellipsis_parenthesized
  97. parsed = ParsedExpression("(name1 ... a1 12) name2 14")
  98. assert len(parsed.identifiers) == 6
  99. assert parsed.identifiers.difference({"name1", _ellipsis, "a1", "name2"}).__len__() == 2
  100. assert parsed.composition == [["name1", _ellipsis, "a1", aap(12)], ["name2"], [aap(14)]]
  101. assert parsed.has_non_unitary_anonymous_axes
  102. assert parsed.has_ellipsis
  103. assert parsed.has_ellipsis_parenthesized