trainer_jit_checkpoint.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import os
  2. import signal
  3. import threading
  4. from .trainer_callback import TrainerCallback
  5. from .trainer_utils import PREFIX_CHECKPOINT_DIR
  6. from .utils import logging
  7. logger = logging.get_logger(__name__)
  8. class CheckpointManager:
  9. def __init__(self, trainer, kill_wait: int = 3):
  10. """
  11. Initialize the CheckpointManager for Just-In-Time checkpoint handling.
  12. Args:
  13. trainer: The Trainer instance that will be used to save checkpoints when SIGTERM is received.
  14. kill_wait (`int`, *optional*, defaults to 3): Grace period to distinguish between SIGTERM and SIGKILL.
  15. """
  16. self.trainer = trainer
  17. self.is_checkpoint_requested = False
  18. self._original_sigterm_handler = None
  19. self.kill_wait = kill_wait
  20. def setup_signal_handler(self):
  21. self._original_sigterm_handler = signal.signal(signal.SIGTERM, self._sigterm_handler)
  22. logger.info("JIT checkpoint signal handler registered for SIGTERM")
  23. def _sigterm_handler(self, signum, frame):
  24. if self.is_checkpoint_requested:
  25. return
  26. logger.info(f"SIGTERM received, will request JIT checkpoint after {self.kill_wait}s")
  27. threading.Timer(self.kill_wait, self._enable_checkpoint).start()
  28. def _enable_checkpoint(self):
  29. logger.info("Kill wait period elapsed, requesting checkpoint")
  30. self.is_checkpoint_requested = True
  31. def execute_jit_checkpoint(self):
  32. try:
  33. # Set checkpoint flag to False to avoid multiple checkpoints getting triggered by other callbacks
  34. self.is_checkpoint_requested = False
  35. logger.info("Starting JIT checkpointing...")
  36. current_step = self.trainer.state.global_step
  37. logger.info(f"Saving JIT checkpoint at step {current_step}")
  38. output_dir = self.trainer._get_output_dir(trial=None)
  39. checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{current_step}"
  40. checkpoint_path = os.path.join(output_dir, checkpoint_folder)
  41. # Create checkpoint directory
  42. os.makedirs(checkpoint_path, exist_ok=True)
  43. # Create a sentinel file to indicate checkpointing is in progress
  44. sentinel_file = os.path.join(output_dir, checkpoint_folder, "checkpoint-is-incomplete.txt")
  45. with open(sentinel_file, "w") as f:
  46. f.write(f"Checkpoint started at step {current_step} and in progress...")
  47. logger.info(f"Created checkpoint progress sentinel marker file: {sentinel_file}")
  48. # Invoke the trainer's checkpoint method directly
  49. self.trainer._save_checkpoint(self.trainer.model, trial=None)
  50. # Remove sentinel file upon successful checkpointing
  51. if os.path.exists(sentinel_file):
  52. os.remove(sentinel_file)
  53. logger.info("Sentinel marker file removed")
  54. logger.info("Immediate JIT checkpoint completed successfully")
  55. except Exception as e:
  56. logger.error(f"Failed to save JIT checkpoint: {e}")
  57. raise
  58. class JITCheckpointCallback(TrainerCallback):
  59. """
  60. Callback for Just-In-Time checkpointing on SIGTERM signals.
  61. When SIGTERM is received, the checkpoint manager sets `is_checkpoint_requested=True`.
  62. The callbacks detect this flag and set `control.should_training_stop=True`, which signals
  63. the Trainer's training loop to exit gracefully after saving the checkpoint.
  64. """
  65. def __init__(self):
  66. self.trainer = None
  67. self.jit_manager: CheckpointManager | None = None
  68. def set_trainer(self, trainer):
  69. self.trainer = trainer
  70. if trainer.args.enable_jit_checkpoint:
  71. self.jit_manager = CheckpointManager(trainer=trainer)
  72. self.jit_manager.setup_signal_handler()
  73. logger.info("JIT checkpointing enabled")
  74. def on_pre_optimizer_step(self, args, state, control, **kwargs):
  75. if self.jit_manager and self.jit_manager.is_checkpoint_requested:
  76. control.should_training_stop = True
  77. self.jit_manager.execute_jit_checkpoint()
  78. def on_step_begin(self, args, state, control, **kwargs):
  79. if self.jit_manager and self.jit_manager.is_checkpoint_requested:
  80. control.should_training_stop = True
  81. self.jit_manager.execute_jit_checkpoint()
  82. def on_step_end(self, args, state, control, **kwargs):
  83. if self.jit_manager and self.jit_manager.is_checkpoint_requested:
  84. control.should_save = False
  85. control.should_training_stop = True
  86. self.jit_manager.execute_jit_checkpoint()
  87. def on_epoch_end(self, args, state, control, **kwargs):
  88. if self.jit_manager and self.jit_manager.is_checkpoint_requested:
  89. control.should_save = False
  90. control.should_training_stop = True
  91. self.jit_manager.execute_jit_checkpoint()
  92. def on_train_end(self, args, state, control, **kwargs):
  93. # Restore original SIGTERM handler
  94. if self.jit_manager and self.jit_manager._original_sigterm_handler is not None:
  95. signal.signal(signal.SIGTERM, self.jit_manager._original_sigterm_handler)
  96. logger.info("Restored original SIGTERM handler after training completion")