_functions.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. # mypy: allow-untyped-defs
  2. import torch
  3. import torch.distributed as dist
  4. from torch.autograd.function import Function
class SyncBatchNorm(Function):
    """Autograd function for batch normalization synchronized over a process group.

    Forward computes this rank's per-channel statistics, all-gathers
    ``(mean, invstd, count)`` from every rank of ``process_group``, combines
    them with ``torch.batch_norm_gather_stats_with_counts`` and normalizes the
    local input with the global statistics. Backward all-reduces (SUM)
    ``sum_dy`` / ``sum_dy_xmu`` before computing the input gradient.

    Every rank issues the same collectives in the same order, even when its
    local input is empty, because peers would otherwise block.
    """

    @staticmethod
    # pyrefly: ignore [bad-override]
    def forward(
        self,
        input,
        weight,
        bias,
        running_mean,
        running_var,
        eps,
        momentum,
        process_group,
        world_size,
    ):
        """Normalize ``input`` using statistics gathered from every rank.

        Args:
            self: the autograd context (named ``self`` here); receives the
                saved tensors and a ``process_group`` attribute for backward.
            input: tensor whose dim 1 is treated as the channel dimension.
            weight: optional per-channel scale; may be ``None``.
            bias: optional per-channel shift, forwarded to ``batch_norm_elemt``.
            running_mean: running-mean buffer (may be ``None``), forwarded to
                ``batch_norm_gather_stats_with_counts``.
            running_var: running-variance buffer, forwarded likewise.
            eps: epsilon forwarded to the batch-norm kernels.
            momentum: forwarded to ``batch_norm_gather_stats_with_counts``.
            process_group: group used for the all-gather; its backend name
                selects between ``all_gather_into_tensor`` and ``all_gather``.
            world_size: number of ranks contributing to the gather; sizes the
                gather buffers.

        Returns:
            The normalized tensor, or ``torch.empty_like(input)`` when
            ``input`` has no elements.

        Raises:
            ValueError: if each channel holds a single value and
                ``world_size < 2``.
        """
        # Channels-last (2D or 3D) inputs are used as-is; any other layout is
        # made contiguous before the kernels below run.
        if not (
            input.is_contiguous(memory_format=torch.channels_last)
            or input.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        # Number of values per channel on this rank.
        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError(
                f"Expected more than 1 value per channel when training, got input size {size}"
            )

        num_channels = input.shape[1]
        if input.numel() > 0:
            # calculate mean/invstd for input.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device,
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # For empty input, set the stats and the count to zero. Entries
            # with zero count are filtered out later when computing the global
            # mean & invstd (except while a CUDA graph is being captured, see
            # below), but they still need to take part in the all_gather
            # collective to unblock the other peer processes.
            # NOTE(review): this buffer uses input.dtype while the non-empty
            # path uses mean.dtype (from batch_norm_stats); if those can
            # differ, ranks would gather tensors of different dtypes - confirm.
            combined = torch.zeros(
                2 * num_channels + 1, dtype=input.dtype, device=input.device
            )

        # Use allgather instead of allreduce because count could be different across
        # ranks, simple all reduce op can not give correct results.
        # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
        # all gathered mean, invstd and count.
        # for nccl backend, use the optimized version of all gather.
        # The Gloo backend does not support `all_gather_into_tensor`.
        if process_group._get_backend_name() != "gloo":
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(
                1,
                combined_size * world_size,
                dtype=combined.dtype,
                device=combined.device,
            )
            dist.all_gather_into_tensor(
                combined_flat, combined, process_group, async_op=False
            )
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [torch.empty_like(combined) for _ in range(world_size)]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
            # The lines below force a synchronization between CUDA and CPU, because
            # the shape of the result count_all depends on the values in mask tensor.
            # Such synchronizations break CUDA Graph capturing.
            # See https://github.com/pytorch/pytorch/issues/78549
            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
            # a better longer-term solution.

            # remove stats from empty inputs
            mask = count_all.squeeze(-1) >= 1
            count_all = count_all[mask]
            mean_all = mean_all[mask]
            invstd_all = invstd_all[mask]

        # calculate global mean & invstd
        counts = count_all.view(-1)
        # Match the counts' dtype to the running-stats buffer when one is given.
        if running_mean is not None and counts.dtype != running_mean.dtype:
            counts = counts.to(running_mean.dtype)
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            counts,
        )

        # The int32 per-rank counts are later passed to batch_norm_backward_elemt.
        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)

    @staticmethod
    def backward(self, grad_output):
        """Return gradients for every argument of ``forward``.

        Only ``input``, ``weight`` and ``bias`` can receive gradients; the
        remaining six positions are always ``None``. ``sum_dy`` and
        ``sum_dy_xmu`` are all-reduced (SUM) over the saved process group
        before the input gradient is computed; ``grad_weight`` / ``grad_bias``
        are returned without any reduction here.
        """
        if not (
            grad_output.is_contiguous(memory_format=torch.channels_last)
            or grad_output.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        if saved_input.numel() > 0:
            # calculate local stats as well as grad_weight / grad_bias
            (
                sum_dy,
                sum_dy_xmu,
                grad_weight,
                grad_bias,
            ) = torch.batch_norm_backward_reduce(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                self.needs_input_grad[0],
                self.needs_input_grad[1],
                self.needs_input_grad[2],
            )

            if self.needs_input_grad[0]:
                # synchronizing stats used to calculate input gradient.
                num_channels = sum_dy.shape[0]
                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                torch.distributed.all_reduce(
                    combined,
                    torch.distributed.ReduceOp.SUM,
                    process_group,
                    async_op=False,
                )
                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

                # backward pass for gradient calculation
                # Cast weight to the dtype of the saved statistics first.
                if weight is not None and weight.dtype != mean.dtype:
                    weight = weight.to(mean.dtype)
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output,
                    saved_input,
                    mean,
                    invstd,
                    weight,
                    sum_dy,
                    sum_dy_xmu,
                    count_tensor,
                )
            # synchronizing of grad_weight / grad_bias is not needed as distributed
            # training would handle all reduce.
            if weight is None or not self.needs_input_grad[1]:
                grad_weight = None

            if weight is None or not self.needs_input_grad[2]:
                grad_bias = None
        else:
            # This process got an empty input tensor in the forward pass.
            # Although this process can directly set grad_input as an empty
            # tensor of zeros, it still needs to participate in the collective
            # communication to unblock its peers, as other peer processes might
            # have received non-empty inputs.
            num_channels = saved_input.shape[1]
            if self.needs_input_grad[0]:
                # launch all_reduce to unblock other peer processes
                # NOTE(review): uses saved_input.dtype; peers contribute the
                # dtype returned by batch_norm_backward_reduce - confirm these
                # always agree.
                combined = torch.zeros(
                    2 * num_channels, dtype=saved_input.dtype, device=saved_input.device
                )
                torch.distributed.all_reduce(
                    combined,
                    torch.distributed.ReduceOp.SUM,
                    process_group,
                    async_op=False,
                )

            # Leave grad_input, grad_weight and grad_bias as None, which will be
            # interpreted by the autograd engine as Tensors full of zeros.

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
  191. class CrossMapLRN2d(Function):
  192. @staticmethod
  193. # pyrefly: ignore [bad-override]
  194. def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
  195. ctx.size = size
  196. ctx.alpha = alpha
  197. ctx.beta = beta
  198. ctx.k = k
  199. ctx.scale = None
  200. if input.dim() != 4:
  201. raise ValueError(
  202. f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead."
  203. )
  204. ctx.scale = ctx.scale or input.new()
  205. output = input.new()
  206. channels = input.size(1)
  207. output.resize_as_(input)
  208. ctx.scale.resize_as_(input)
  209. # use output storage as temporary buffer
  210. input_square = output
  211. torch.pow(input, 2, out=input_square)
  212. pre_pad = int((ctx.size - 1) / 2 + 1)
  213. pre_pad_crop = min(pre_pad, channels)
  214. scale_first = ctx.scale.select(1, 0)
  215. scale_first.zero_()
  216. # compute first feature map normalization
  217. for c in range(pre_pad_crop):
  218. scale_first.add_(input_square.select(1, c))
  219. # reuse computations for next feature maps normalization
  220. # by adding the next feature map and removing the previous
  221. for c in range(1, channels):
  222. scale_previous = ctx.scale.select(1, c - 1)
  223. scale_current = ctx.scale.select(1, c)
  224. scale_current.copy_(scale_previous)
  225. if c < channels - pre_pad + 1:
  226. square_next = input_square.select(1, c + pre_pad - 1)
  227. scale_current.add_(square_next, alpha=1)
  228. if c > pre_pad:
  229. square_previous = input_square.select(1, c - pre_pad)
  230. scale_current.add_(square_previous, alpha=-1)
  231. ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)
  232. torch.pow(ctx.scale, -ctx.beta, out=output)
  233. output.mul_(input)
  234. ctx.save_for_backward(input, output)
  235. return output
  236. @staticmethod
  237. # pyrefly: ignore [bad-override]
  238. def backward(ctx, grad_output):
  239. input, output = ctx.saved_tensors
  240. grad_input = grad_output.new()
  241. batch_size = input.size(0)
  242. channels = input.size(1)
  243. input_height = input.size(2)
  244. input_width = input.size(3)
  245. paddded_ratio = input.new(channels + ctx.size - 1, input_height, input_width)
  246. accum_ratio = input.new(input_height, input_width)
  247. cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
  248. inversePrePad = int(ctx.size - (ctx.size - 1) / 2)
  249. grad_input.resize_as_(input)
  250. torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)
  251. paddded_ratio.zero_()
  252. padded_ratio_center = paddded_ratio.narrow(0, inversePrePad, channels)
  253. for n in range(batch_size):
  254. torch.mul(grad_output[n], output[n], out=padded_ratio_center)
  255. padded_ratio_center.div_(ctx.scale[n])
  256. torch.sum(
  257. paddded_ratio.narrow(0, 0, ctx.size - 1),
  258. 0,
  259. keepdim=False,
  260. out=accum_ratio,
  261. )
  262. for c in range(channels):
  263. accum_ratio.add_(paddded_ratio[c + ctx.size - 1])
  264. grad_input[n][c].addcmul_(
  265. input[n][c], accum_ratio, value=-cache_ratio_value
  266. )
  267. accum_ratio.add_(paddded_ratio[c], alpha=-1)
  268. return grad_input, None, None, None, None
  269. class BackwardHookFunction(torch.autograd.Function):
  270. @staticmethod
  271. # pyrefly: ignore [bad-override]
  272. def forward(ctx, *args):
  273. ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
  274. return args
  275. @staticmethod
  276. def backward(ctx, *args):
  277. return args