__init__.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. """
  2. Common utils for testing.
  3. These functions allow testing only some frameworks, not all.
  4. """
  5. import logging
  6. import os
  7. import warnings
  8. from functools import lru_cache
  9. from typing import List, Tuple
  10. from einops import _backends
  11. __author__ = "Alex Rogozhnikov"
  12. # minimize noise in tests logging
  13. logging.getLogger("tensorflow").disabled = True
  14. logging.getLogger("matplotlib").disabled = True
  15. FLOAT_REDUCTIONS = ("min", "max", "sum", "mean", "prod") # not includes any/all
  16. def find_names_of_all_frameworks() -> List[str]:
  17. backend_subclasses = []
  18. backends = _backends.AbstractBackend.__subclasses__()
  19. while backends:
  20. backend = backends.pop()
  21. backends += backend.__subclasses__()
  22. backend_subclasses.append(backend)
  23. return [b.framework_name for b in backend_subclasses]
  24. ENVVAR_NAME = "EINOPS_TEST_BACKENDS"
  25. def unparse_backends(backend_names: List[str]) -> Tuple[str, str]:
  26. _known_backends = find_names_of_all_frameworks()
  27. for backend_name in backend_names:
  28. if backend_name not in _known_backends:
  29. raise RuntimeError(f"Unknown framework: {backend_name}")
  30. return ENVVAR_NAME, ",".join(backend_names)
  31. @lru_cache(maxsize=1)
  32. def parse_backends_to_test() -> List[str]:
  33. if ENVVAR_NAME not in os.environ:
  34. raise RuntimeError(f"Testing frameworks were not specified, env var {ENVVAR_NAME} not set")
  35. parsed_backends = os.environ[ENVVAR_NAME].split(",")
  36. _known_backends = find_names_of_all_frameworks()
  37. for backend_name in parsed_backends:
  38. if backend_name not in _known_backends:
  39. raise RuntimeError(f"Unknown framework: {backend_name}")
  40. return parsed_backends
  41. def is_backend_tested(backend: str) -> bool:
  42. """Used to skip test if corresponding backend is not tested"""
  43. if backend not in find_names_of_all_frameworks():
  44. raise RuntimeError(f"Unknown framework {backend}")
  45. return backend in parse_backends_to_test()
  46. def collect_test_backends(symbolic=False, layers=False) -> List[_backends.AbstractBackend]:
  47. """
  48. :param symbolic: symbolic or imperative frameworks?
  49. :param layers: layers or operations?
  50. :return: list of backends satisfying set conditions
  51. """
  52. if not symbolic:
  53. if not layers:
  54. backend_types = [
  55. _backends.NumpyBackend,
  56. _backends.JaxBackend,
  57. _backends.TorchBackend,
  58. _backends.TensorflowBackend,
  59. _backends.OneFlowBackend,
  60. _backends.PaddleBackend,
  61. _backends.CupyBackend,
  62. ]
  63. else:
  64. backend_types = [
  65. _backends.TorchBackend,
  66. _backends.OneFlowBackend,
  67. _backends.PaddleBackend,
  68. ]
  69. else:
  70. if not layers:
  71. backend_types = [
  72. _backends.PyTensorBackend,
  73. ]
  74. else:
  75. backend_types = [
  76. _backends.TFKerasBackend,
  77. ]
  78. backend_names_to_test = parse_backends_to_test()
  79. result = []
  80. for backend_type in backend_types:
  81. if backend_type.framework_name not in backend_names_to_test:
  82. continue
  83. try:
  84. result.append(backend_type())
  85. except ImportError:
  86. # problem with backend installation fails a specific test function,
  87. # but will be skipped in all other test cases
  88. warnings.warn(f"backend could not be initialized for tests: {backend_type}", stacklevel=1)
  89. return result