validation_config.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from dataclasses import dataclass
  2. from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol
  3. from ray.util.annotations import PublicAPI
  4. if TYPE_CHECKING:
  5. from ray.train import Checkpoint
  6. @PublicAPI(stability="alpha")
  7. class ValidationFn(Protocol):
  8. """Protocol for a function that validates a checkpoint."""
  9. def __call__(self, checkpoint: "Checkpoint", **kwargs: Any) -> Dict:
  10. ...
  11. @dataclass
  12. @PublicAPI(stability="alpha")
  13. class ValidationTaskConfig:
  14. """Configuration for a specific validation task, passed to report().
  15. Args:
  16. fn_kwargs: json-serializable keyword arguments to pass to the validation function.
  17. Note that we always pass `checkpoint` as the first argument to the
  18. validation function.
  19. """
  20. fn_kwargs: Optional[Dict[str, Any]] = None
  21. def __post_init__(self):
  22. if self.fn_kwargs is None:
  23. self.fn_kwargs = {}
  24. @PublicAPI(stability="alpha")
  25. class ValidationConfig:
  26. """Configuration for validation, passed to the trainer.
  27. Args:
  28. fn: The validation function to run on checkpoints.
  29. This function should accept a checkpoint as the first argument
  30. and return a dictionary of metrics.
  31. task_config: Default configuration for validation tasks.
  32. The fn_kwargs in this config can be overridden by
  33. ValidationTaskConfig passed to report().
  34. ray_remote_kwargs: Keyword arguments to pass to `ray.remote()` for the validation task.
  35. This can be used to specify resource requirements, number of retries, etc.
  36. """
  37. def __init__(
  38. self,
  39. fn: ValidationFn,
  40. task_config: Optional[ValidationTaskConfig] = None,
  41. ray_remote_kwargs: Optional[Dict[str, Any]] = None,
  42. ):
  43. self.fn = fn
  44. if task_config is None:
  45. self.task_config = ValidationTaskConfig()
  46. else:
  47. self.task_config = task_config
  48. # TODO: ray_remote_kwargs is not json-serializable because retry_exceptions
  49. # can be a list of exception types. If ray core makes ray_remote_kwargs json-serializable
  50. # we can move this to ValidationTaskConfig.
  51. if ray_remote_kwargs is None:
  52. self.ray_remote_kwargs = {}
  53. else:
  54. self.ray_remote_kwargs = ray_remote_kwargs