torch_version.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from collections.abc import Iterable
  2. from typing import Any
  3. from torch._vendor.packaging.version import InvalidVersion, Version
  4. from torch.version import __version__ as internal_version
  5. __all__ = ["TorchVersion"]
  6. class TorchVersion(str):
  7. """A string with magic powers to compare to both Version and iterables!
  8. Prior to 1.10.0 torch.__version__ was stored as a str and so many did
  9. comparisons against torch.__version__ as if it were a str. In order to not
  10. break them we have TorchVersion which masquerades as a str while also
  11. having the ability to compare against both packaging.version.Version as
  12. well as tuples of values, eg. (1, 2, 1)
  13. Examples:
  14. Comparing a TorchVersion object to a Version object
  15. TorchVersion('1.10.0a') > Version('1.10.0a')
  16. Comparing a TorchVersion object to a Tuple object
  17. TorchVersion('1.10.0a') > (1, 2) # 1.2
  18. TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1
  19. Comparing a TorchVersion object against a string
  20. TorchVersion('1.10.0a') > '1.2'
  21. TorchVersion('1.10.0a') > '1.2.1'
  22. """
  23. __slots__ = ()
  24. # fully qualified type names here to appease mypy
  25. def _convert_to_version(self, inp: Any) -> Any:
  26. if isinstance(inp, Version):
  27. return inp
  28. elif isinstance(inp, str):
  29. return Version(inp)
  30. elif isinstance(inp, Iterable):
  31. # Ideally this should work for most cases by attempting to group
  32. # the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH)
  33. # Examples:
  34. # * (1) -> Version("1")
  35. # * (1, 20) -> Version("1.20")
  36. # * (1, 20, 1) -> Version("1.20.1")
  37. return Version(".".join(str(item) for item in inp))
  38. else:
  39. raise InvalidVersion(inp)
  40. def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
  41. try:
  42. return getattr(Version(self), method)(self._convert_to_version(cmp))
  43. except BaseException as e:
  44. if not isinstance(e, InvalidVersion):
  45. raise
  46. # Fall back to regular string comparison if dealing with an invalid
  47. # version like 'parrot'
  48. return getattr(super(), method)(cmp)
  49. for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]:
  50. setattr(
  51. TorchVersion,
  52. cmp_method,
  53. lambda x, y, method=cmp_method: x._cmp_wrapper(y, method),
  54. )
# Module-level version object: the internal version string imported from
# torch.version, wrapped in TorchVersion so it still behaves as a str while
# supporting comparisons against Version objects, tuples and strings.
__version__ = TorchVersion(internal_version)