train.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from tqdm import tqdm
  2. from romatch.utils.utils import to_cuda
  3. import romatch
  4. import torch
  5. import wandb
  6. def log_param_statistics(named_parameters, norm_type = 2):
  7. named_parameters = list(named_parameters)
  8. grads = [p.grad for n, p in named_parameters if p.grad is not None]
  9. weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None]
  10. names = [n for n,p in named_parameters if p.grad is not None]
  11. param_norm = torch.stack(weight_norms).norm(p=norm_type)
  12. device = grads[0].device
  13. grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads])
  14. nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
  15. nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
  16. total_grad_norm = torch.norm(grad_norms, norm_type)
  17. if torch.any(nans_or_infs):
  18. print(f"These params have nan or inf grads: {nan_inf_names}")
  19. wandb.log({"grad_norm": total_grad_norm.item()}, step = romatch.GLOBAL_STEP)
  20. wandb.log({"param_norm": param_norm.item()}, step = romatch.GLOBAL_STEP)
  21. def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs):
  22. optimizer.zero_grad()
  23. out = model(train_batch)
  24. l = objective(out, train_batch)
  25. grad_scaler.scale(l).backward()
  26. grad_scaler.unscale_(optimizer)
  27. log_param_statistics(model.named_parameters())
  28. torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm) # what should max norm be?
  29. grad_scaler.step(optimizer)
  30. grad_scaler.update()
  31. wandb.log({"grad_scale": grad_scaler._scale.item()}, step = romatch.GLOBAL_STEP)
  32. if grad_scaler._scale < 1.:
  33. grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale)
  34. romatch.GLOBAL_STEP = romatch.GLOBAL_STEP + romatch.STEP_SIZE # increment global step
  35. return {"train_out": out, "train_loss": l.item()}
  36. def train_k_steps(
  37. n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None, pbar_n_seconds = 1,
  38. ):
  39. for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or romatch.RANK > 0, mininterval=pbar_n_seconds):
  40. batch = next(dataloader)
  41. model.train(True)
  42. batch = to_cuda(batch)
  43. train_step(
  44. train_batch=batch,
  45. model=model,
  46. objective=objective,
  47. optimizer=optimizer,
  48. lr_scheduler=lr_scheduler,
  49. grad_scaler=grad_scaler,
  50. n=n,
  51. grad_clip_norm = grad_clip_norm,
  52. )
  53. if ema_model is not None:
  54. ema_model.update()
  55. if warmup is not None:
  56. with warmup.dampening():
  57. lr_scheduler.step()
  58. else:
  59. lr_scheduler.step()
  60. [wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())]
  61. def train_epoch(
  62. dataloader=None,
  63. model=None,
  64. objective=None,
  65. optimizer=None,
  66. lr_scheduler=None,
  67. epoch=None,
  68. ):
  69. model.train(True)
  70. print(f"At epoch {epoch}")
  71. for batch in tqdm(dataloader, mininterval=5.0):
  72. batch = to_cuda(batch)
  73. train_step(
  74. train_batch=batch, model=model, objective=objective, optimizer=optimizer
  75. )
  76. lr_scheduler.step()
  77. return {
  78. "model": model,
  79. "optimizer": optimizer,
  80. "lr_scheduler": lr_scheduler,
  81. "epoch": epoch,
  82. }
  83. def train_k_epochs(
  84. start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
  85. ):
  86. for epoch in range(start_epoch, end_epoch + 1):
  87. train_epoch(
  88. dataloader=dataloader,
  89. model=model,
  90. objective=objective,
  91. optimizer=optimizer,
  92. lr_scheduler=lr_scheduler,
  93. epoch=epoch,
  94. )