run_moment.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import Union
  4. from urllib import parse
  5. @dataclass
  6. class RunMoment:
  7. """A moment in a run.
  8. Defines a branching point in a finished run to fork or resume from.
  9. A run moment is identified by a run ID and a metric value.
  10. Currently, only the metric '_step' is supported.
  11. """
  12. run: str
  13. """run ID"""
  14. value: Union[int, float]
  15. """Value of the metric."""
  16. metric: str = "_step"
  17. """Metric to use to determine the moment in the run.
  18. Currently, only the metric '_step' is supported.
  19. In future, this will be relaxed to be any metric.
  20. """
  21. def __post_init__(self):
  22. if self.metric != "_step":
  23. raise ValueError(
  24. f"Only the metric '_step' is supported, got '{self.metric}'."
  25. )
  26. if not isinstance(self.value, (int, float)):
  27. raise TypeError(
  28. f"Only int or float values are supported, got '{self.value}'."
  29. )
  30. if not isinstance(self.run, str):
  31. raise TypeError(f"Only string run names are supported, got '{self.run}'.")
  32. @classmethod
  33. def from_uri(cls, uri: str) -> RunMoment:
  34. parsable = "runmoment://" + uri
  35. parse_err = ValueError(
  36. f"Could not parse passed run moment string '{uri}', "
  37. f"expected format '<run>?<metric>=<numeric_value>'. "
  38. f"Currently, only the metric '_step' is supported. "
  39. f"Example: 'ans3bsax?_step=123'."
  40. )
  41. try:
  42. parsed = parse.urlparse(parsable)
  43. except ValueError as e:
  44. raise parse_err from e
  45. if parsed.scheme != "runmoment":
  46. raise parse_err
  47. # extract run, metric, value from parsed
  48. if not parsed.netloc:
  49. raise parse_err
  50. run = parsed.netloc
  51. if parsed.path or parsed.params or parsed.fragment:
  52. raise parse_err
  53. query = parse.parse_qs(parsed.query)
  54. if len(query) != 1:
  55. raise parse_err
  56. metric = list(query.keys())[0]
  57. if metric != "_step":
  58. raise parse_err
  59. value: str = query[metric][0]
  60. try:
  61. num_value = int(value) if value.isdigit() else float(value)
  62. except ValueError as e:
  63. raise parse_err from e
  64. return cls(run=run, metric=metric, value=num_value)