mkldnn.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # mypy: allow-untyped-defs
  2. import torch
  class MkldnnLinear(torch.jit.ScriptModule):
      """TorchScript replacement for ``nn.Linear`` that calls ``mkldnn_linear``.

      The weight is stored as an MKLDNN tensor converted to ``dtype``. The bias
      is converted without a dtype argument, so it keeps the dense bias dtype.
      """

      def __init__(self, dense_module, dtype) -> None:
          super().__init__()
          self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))
          # oneDNN's bf16 path accepts an fp32 or bf16 bias; the bias is not cast
          # to ``dtype`` here, which keeps an fp32 bias in fp32 for accuracy.
          if dense_module.bias is None:
              # TODO: Remove this once ScriptModule supports registering None buffer.
              # Until then a missing bias becomes an fp32 zero vector with one
              # entry per weight row.
              dense_bias = torch.zeros([dense_module.weight.size(0)], dtype=torch.float)
          else:
              dense_bias = dense_module.bias
          self.register_buffer('bias', dense_bias.to_mkldnn())

      @torch.jit.script_method
      def __getstate__(self):
          # Serialize dense copies of both buffers plus the training flag.
          dense_weight = self.weight.to_dense()
          dense_bias = self.bias.to_dense()
          return (dense_weight, dense_bias, self.training)

      @torch.jit.script_method
      def __setstate__(self, state):
          # Inverse of __getstate__: convert the dense tensors back to MKLDNN.
          dense_weight, dense_bias, training = state
          self.weight = dense_weight.to_mkldnn()
          self.bias = dense_bias.to_mkldnn()
          self.training = training

      @torch.jit.script_method
      def forward(self, x):
          # MKLDNN input yields MKLDNN output; dense input is converted on the
          # way in and the result converted back to dense on the way out.
          if x.is_mkldnn:
              return torch._C._nn.mkldnn_linear(x, self.weight, self.bias)
          out = torch._C._nn.mkldnn_linear(x.to_mkldnn(), self.weight, self.bias)
          return out.to_dense()
  class _MkldnnConvNd(torch.jit.ScriptModule):
      """Common base of MkldnnConv1d, MkldnnConv2d and MkldnnConv3d.

      Copies the convolution hyper-parameters and the bias from ``dense_module``.
      Subclasses must register the ``weight`` buffer and define ``__setstate__``.
      """

      __constants__ = ['stride', 'padding', 'dilation', 'groups']

      def __init__(self, dense_module) -> None:
          super().__init__()
          # Hyper-parameters are listed in __constants__ for TorchScript.
          self.stride = dense_module.stride
          self.padding = dense_module.padding
          self.dilation = dense_module.dilation
          self.groups = dense_module.groups
          # oneDNN's bf16 path accepts an fp32 or bf16 bias; the bias is not cast
          # to the weight dtype, which keeps an fp32 bias in fp32 for accuracy.
          if dense_module.bias is None:
              # TODO: Remove this once ScriptModule supports registering None buffer.
              # Until then a missing bias becomes an fp32 zero vector with one
              # entry per output channel (weight dim 0).
              dense_bias = torch.zeros([dense_module.weight.size(0)], dtype=torch.float)
          else:
              dense_bias = dense_module.bias
          self.register_buffer('bias', dense_bias.to_mkldnn())

      @torch.jit.script_method
      def __getstate__(self):
          # Serialize dense copies of both buffers plus the training flag.
          dense_weight = self.weight.to_dense()
          dense_bias = self.bias.to_dense()
          return (dense_weight, dense_bias, self.training)

      @torch.jit.script_method
      def forward(self, x):
          # Note the argument order of mkldnn_convolution: padding before stride.
          return torch.mkldnn_convolution(
              x, self.weight, self.bias,
              self.padding, self.stride, self.dilation, self.groups,
          )
  class MkldnnConv1d(_MkldnnConvNd):
      """MKLDNN replacement for ``nn.Conv1d``.

      The weight is converted with ``to_mkldnn(dtype)``; no layout reorder is
      applied to it.
      """

      def __init__(self, dense_module, dtype) -> None:
          super().__init__(dense_module)
          mkldnn_weight = dense_module.weight.to_mkldnn(dtype)
          self.register_buffer('weight', mkldnn_weight)

      @torch.jit.script_method
      def __setstate__(self, state):
          # Inverse of the inherited __getstate__: (weight, bias, training).
          dense_weight, dense_bias, training = state
          self.weight = dense_weight.to_mkldnn()
          self.bias = dense_bias.to_mkldnn()
          self.training = training
  class MkldnnConv2d(_MkldnnConvNd):
      """MKLDNN replacement for ``nn.Conv2d``.

      The weight is passed through ``mkldnn_reorder_conv2d_weight`` with this
      layer's padding/stride/dilation/groups, both at construction and when
      restored by ``__setstate__``.
      """

      def __init__(self, dense_module, dtype) -> None:
          super().__init__(dense_module)
          # Presumably reorders into a layout suited to these conv parameters.
          mkldnn_weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
              dense_module.weight.to_mkldnn(dtype),
              self.padding, self.stride, self.dilation, self.groups,
          )
          self.register_buffer('weight', mkldnn_weight)

      @torch.jit.script_method
      def __setstate__(self, state):
          # Inverse of the inherited __getstate__: (weight, bias, training).
          dense_weight, dense_bias, training = state
          self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
              dense_weight.to_mkldnn(),
              self.padding, self.stride, self.dilation, self.groups,
          )
          self.bias = dense_bias.to_mkldnn()
          self.training = training
  class MkldnnConv3d(_MkldnnConvNd):
      """MKLDNN replacement for ``nn.Conv3d``.

      The weight is passed through ``mkldnn_reorder_conv3d_weight`` with this
      layer's padding/stride/dilation/groups, both at construction and when
      restored by ``__setstate__``.
      """

      def __init__(self, dense_module, dtype) -> None:
          super().__init__(dense_module)
          # Presumably reorders into a layout suited to these conv parameters.
          mkldnn_weight = torch._C._nn.mkldnn_reorder_conv3d_weight(
              dense_module.weight.to_mkldnn(dtype),
              self.padding, self.stride, self.dilation, self.groups,
          )
          self.register_buffer('weight', mkldnn_weight)

      @torch.jit.script_method
      def __setstate__(self, state):
          # Inverse of the inherited __getstate__: (weight, bias, training).
          dense_weight, dense_bias, training = state
          self.weight = torch._C._nn.mkldnn_reorder_conv3d_weight(
              dense_weight.to_mkldnn(),
              self.padding, self.stride, self.dilation, self.groups,
          )
          self.bias = dense_bias.to_mkldnn()
          self.training = training
  class MkldnnBatchNorm(torch.jit.ScriptModule):
      """Inference-only MKLDNN replacement for a BatchNorm module.

      Requires an eval-mode module with ``track_running_stats=True`` and
      ``affine=True``; otherwise ``AssertionError`` is raised. The weight, bias,
      and running statistics are converted with ``to_mkldnn()`` and no dtype
      argument, so they keep the dense module's dtype.
      """

      __constants__ = ['exponential_average_factor', 'eps']

      def __init__(self, dense_module) -> None:
          super().__init__()
          # Guard clauses: only the inference configuration is supported.
          if dense_module.training:
              raise AssertionError("Only support eval mode batchnorm for mkldnn path now")
          if not dense_module.track_running_stats:
              raise AssertionError("Only support track_running_stats=True for mkldnn path now")
          if not dense_module.affine:
              raise AssertionError("Only support affine=True for mkldnn path now")
          momentum = dense_module.momentum
          self.exponential_average_factor = 0.0 if momentum is None else momentum
          self.eps = dense_module.eps
          # Registered in this fixed order: weight, bias, running_mean, running_var.
          for buffer_name in ('weight', 'bias', 'running_mean', 'running_var'):
              dense_tensor = getattr(dense_module, buffer_name)
              self.register_buffer(buffer_name, dense_tensor.to_mkldnn())

      @torch.jit.script_method
      def __getstate__(self):
          # Serialize dense copies of every buffer plus the training flag.
          return (
              self.weight.to_dense(),
              self.bias.to_dense(),
              self.running_mean.to_dense(),
              self.running_var.to_dense(),
              self.training,
          )

      @torch.jit.script_method
      def __setstate__(self, state):
          # Inverse of __getstate__.
          dense_weight, dense_bias, dense_mean, dense_var, training = state
          self.weight = dense_weight.to_mkldnn()
          self.bias = dense_bias.to_mkldnn()
          self.running_mean = dense_mean.to_mkldnn()
          self.running_var = dense_var.to_mkldnn()
          self.training = training

      @torch.jit.script_method
      def forward(self, x):
          return torch.batch_norm(
              x,
              self.weight,
              self.bias,
              self.running_mean,
              self.running_var,
              False,  # training: normalize with the stored running statistics
              self.exponential_average_factor,  # momentum
              self.eps,
              False,  # cudnn_enabled
          )
  class MkldnnPrelu(torch.jit.ScriptModule):
      """MKLDNN replacement for ``nn.PReLU``; the weight is converted to ``dtype``."""

      def __init__(self, dense_module, dtype) -> None:
          super().__init__()
          mkldnn_weight = dense_module.weight.to_mkldnn(dtype)
          self.register_buffer('weight', mkldnn_weight)

      @torch.jit.script_method
      def __getstate__(self):
          # Serialize a dense copy of the weight plus the training flag.
          dense_weight = self.weight.to_dense()
          return (dense_weight, self.training)

      @torch.jit.script_method
      def __setstate__(self, state):
          # Inverse of __getstate__.
          dense_weight, training = state
          self.weight = dense_weight.to_mkldnn()
          self.training = training

      @torch.jit.script_method
      def forward(self, x):
          # MKLDNN input yields MKLDNN output; dense input round-trips through MKLDNN.
          if x.is_mkldnn:
              return torch.prelu(x, self.weight)
          return torch.prelu(x.to_mkldnn(), self.weight).to_dense()
  def to_mkldnn(module, dtype=torch.float):
      """Recursively replace supported submodules with their MKLDNN counterparts.

      Supported types: Linear, Conv1d/2d/3d, BatchNorm2d/3d and PReLU. Any other
      module is kept as-is, but its children are still visited and replaced on
      it through ``setattr``, so container modules are modified in place.

      Args:
          module: root module to convert.
          dtype: weight dtype; one of ``torch.float``, ``torch.bfloat16`` or
              ``torch.half``. BatchNorm ignores it.

      Returns:
          The converted root; ``module`` itself when the root type is unsupported.

      Raises:
          AssertionError: if ``dtype`` is not one of the supported dtypes.
      """
      if dtype not in (torch.float, torch.bfloat16, torch.half):
          raise AssertionError("MKLDNN only support float, bfloat16, and half path now")

      # Checked in order with isinstance, so subclasses of these types match too.
      converters = (
          (torch.nn.Linear, MkldnnLinear),
          (torch.nn.Conv1d, MkldnnConv1d),
          (torch.nn.Conv2d, MkldnnConv2d),
          (torch.nn.Conv3d, MkldnnConv3d),
          # For the batchnorm bf16 path oneDNN needs fp32 weight and bias, so
          # MkldnnBatchNorm takes no dtype argument.
          ((torch.nn.BatchNorm2d, torch.nn.BatchNorm3d), lambda mod, _d: MkldnnBatchNorm(mod)),
          (torch.nn.PReLU, MkldnnPrelu),
      )

      def convert_one(mod, d):
          for module_types, factory in converters:
              if isinstance(mod, module_types):
                  return factory(mod, d)
          return mod

      def convert_tree(mod, d):
          converted = convert_one(mod, d)
          for child_name, child in mod.named_children():
              setattr(converted, child_name, convert_tree(child, d))
          return converted

      return convert_tree(module, dtype)