decay_batch.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. """ Batch size decay and retry helpers.
  2. Copyright 2022 Ross Wightman
  3. """
  4. import math
  5. def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False):
  6. """ power of two batch-size decay with intra steps
  7. Decay by stepping between powers of 2:
  8. * determine power-of-2 floor of current batch size (base batch size)
  9. * divide above value by num_intra_steps to determine step size
  10. * floor batch_size to nearest multiple of step_size (from base batch size)
  11. Examples:
  12. num_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1
  13. num_steps (no_odd=True) == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2
  14. num_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1
  15. num_steps == 1 --> 64, 32, 16, 8, 4, 2, 1
  16. """
  17. if batch_size <= 1:
  18. # return 0 for stopping value so easy to use in loop
  19. return 0
  20. base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2)))
  21. step_size = max(base_batch_size // num_intra_steps, 1)
  22. batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size
  23. if no_odd and batch_size % 2:
  24. batch_size -= 1
  25. return batch_size
  26. def check_batch_size_retry(error_str):
  27. """ check failure error string for conditions where batch decay retry should not be attempted
  28. """
  29. error_str = error_str.lower()
  30. if 'required rank' in error_str:
  31. # Errors involving phrase 'required rank' typically happen when a conv is used that's
  32. # not compatible with channels_last memory format.
  33. return False
  34. if 'illegal' in error_str:
  35. # 'Illegal memory access' errors in CUDA typically leave process in unusable state
  36. return False
  37. return True