_utils.py 955 B

123456789101112131415161718192021222324252627282930313233
  1. import enum
  2. from collections.abc import Sequence
  3. from typing import TypeVar
  4. T = TypeVar("T", bound=enum.Enum)
  5. class StrEnumMeta(enum.EnumMeta):
  6. auto = enum.auto
  7. def from_str(self: type[T], member: str) -> T: # type: ignore[misc]
  8. try:
  9. return self[member]
  10. except KeyError:
  11. # TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
  12. # soon as it is migrated.
  13. raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None
  14. class StrEnum(enum.Enum, metaclass=StrEnumMeta):
  15. pass
  16. def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
  17. if not seq:
  18. return ""
  19. if len(seq) == 1:
  20. return f"'{seq[0]}'"
  21. head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
  22. tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
  23. return head + tail