ConvUtils.h 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  #include <ATen/core/Tensor.h>
  #include <ATen/TensorUtils.h>
  #include <ATen/detail/CUDAHooksInterface.h>
  #include <ATen/native/DispatchStub.h>
  #include <c10/util/env.h>
  #include <c10/util/irange.h>
  #include <algorithm>
  #include <atomic>
  #include <string>
  #include <utility>
  #include <vector>
  10. namespace at::native {
  // Function-pointer signatures and DispatchStub declarations for the
  // backend-specific convolution kernels (see ATen/native/DispatchStub.h).
  // Only the stubs are declared here; the kernels themselves are defined and
  // registered elsewhere.
  //
  // Most entries are backward kernels taking three tensors followed by the
  // convolution hyper-parameters. The trailing std::array<bool, N> is presumably
  // the output mask selecting which of the N gradients to compute -- TODO
  // confirm against the kernel implementations.
  using conv_depthwise2d_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, at::IntArrayRef, std::array<bool, 2>);
  DECLARE_DISPATCH(conv_depthwise2d_backward_fn, conv_depthwise2d_backward_stub)
  using conv_depthwise3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
  DECLARE_DISPATCH(conv_depthwise3d_backward_fn, conv_depthwise3d_backward_stub)
  using cudnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
  DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub)
  using mps_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, int64_t, std::array<bool,3>);
  DECLARE_DISPATCH(mps_convolution_backward_fn, mps_convolution_backward_stub)
  using cudnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
  DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub)
  using miopen_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
  DECLARE_DISPATCH(miopen_convolution_backward_fn, miopen_convolution_backward_stub)
  using miopen_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
  DECLARE_DISPATCH(miopen_convolution_transpose_backward_fn, miopen_convolution_transpose_backward_stub)
  using miopen_depthwise_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
  DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub)
  using mkldnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, int64_t, std::array<bool,3>);
  DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub)
  // The one forward (non-backward) entry in this list: mkldnn transposed
  // convolution, taking input, weight and an optional bias.
  using mkldnn_convolution_transpose_fn = Tensor(*)(const Tensor&, const Tensor&, const std::optional<Tensor>&,
      IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t);
  DECLARE_DISPATCH(mkldnn_convolution_transpose_fn, mkldnn_convolution_transpose_stub)
  using mkldnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, at::IntArrayRef, int64_t, std::array<bool,3>);
  DECLARE_DISPATCH(mkldnn_convolution_transpose_backward_fn, mkldnn_convolution_transpose_backward_stub)
  using slow_conv_dilated2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
  DECLARE_DISPATCH(slow_conv_dilated2d_backward_fn, slow_conv_dilated2d_backward_stub)
  using slow_conv_dilated3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
  DECLARE_DISPATCH(slow_conv_dilated3d_backward_fn, slow_conv_dilated3d_backward_stub)
  using slow_conv_transpose2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
  DECLARE_DISPATCH(slow_conv_transpose2d_backward_fn, slow_conv_transpose2d_backward_stub)
  using slow_conv_transpose3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
      const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
      at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
  DECLARE_DISPATCH(slow_conv_transpose3d_backward_fn, slow_conv_transpose3d_backward_stub)
  70. namespace {
  71. bool is_cudnnv8_heuristic_mode_b() {
  72. static const bool is_cudnnv8_heuristic_mode_b = c10::utils::check_env("TORCH_CUDNN_USE_HEURISTIC_MODE_B") == true;
  73. return is_cudnnv8_heuristic_mode_b;
  74. }
  75. }
  76. inline bool cudnnv8_enabled_check_debug() {
  77. static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
  78. static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
  79. static uint8_t cudnnv8_debugcount = 0;
  80. if (cudnnv8_debug == 1 && cudnnv8_debugcount < 10) {
  81. TORCH_WARN("TORCH_CUDNN_V8_DEBUG ON, V8 ON: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE B: ", is_cudnnv8_heuristic_mode_b());
  82. cudnnv8_debugcount++;
  83. }
  84. return cudnnv8_flag == 1;
  85. }
  86. inline bool cudnnv8_use_heur_mode_b() {
  87. return is_cudnnv8_heuristic_mode_b();
  88. }
  // Keep in sync with py::enum_ in Module.cpp
  //
  // Identifies the concrete convolution implementation chosen by
  // select_conv_backend(). The enumerators use implicit sequential values, so
  // their declaration order determines their underlying values. Do not reorder
  // them; append new backends instead, and mirror any change in the Python
  // binding noted above.
  enum class ConvBackend {
    CudaDepthwise2d,
    CudaDepthwise3d,
    Cudnn,
    CudnnTranspose,
    Empty,
    Miopen,
    MiopenDepthwise,
    MiopenTranspose,
    Mkldnn,
    MkldnnTranspose,
    MkldnnEmpty,
    NnpackSpatial,
    Overrideable,
    Slow2d,
    Slow3d,
    SlowDilated2d,
    SlowDilated3d,
    SlowTranspose2d,
    SlowTranspose3d,
    Winograd3x3Depthwise,
    Xnnpack2d,
    Mps,
    MpsTranspose,
  };
  // Overload for selecting the convolution backend from the full set of convolution inputs.
  // This overload is exposed to python for testing, etc.
  // (Defined outside this header.)
  TORCH_API ConvBackend select_conv_backend(
      const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
      SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation,
      bool transposed, SymIntArrayRef output_padding, c10::SymInt groups, const at::OptionalSymIntArrayRef bias_sizes_opt);
  // Returns the memory format to use for `input`/`weight` when running the
  // given backend. Presumably one of the *_conv_suggest_memory_format /
  // *_conv_use_channels_last helpers below is consulted per backend; the
  // definition is not in this header, so TODO confirm.
  TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input,
      const Tensor& weight,
      const ConvBackend backend);
  // ---------------------------------------------------------------------
  //
  // Math
  //
  // ---------------------------------------------------------------------

  // Axis indices into activation tensors laid out as (batch, channels, spatial...).
  constexpr int input_batch_size_dim{0};   // also grad_input
  constexpr int input_channels_dim{1};
  constexpr int output_batch_size_dim{0};  // also grad_output
  constexpr int output_channels_dim{1};

  // Axis indices into the weight tensor.
  constexpr int weight_output_channels_dim{0};
  constexpr int weight_input_channels_dim{1};

  // Maximum number of spatial dimensions. Full tensor rank is often written as
  // 2 + max_dim (the extra dims hold batch size and channels).
  constexpr int max_dim{3};
  137. // ---------------------------------------------------------------------
  138. //
  139. // Checking
  140. //
  141. // ---------------------------------------------------------------------
  142. // Used on pad, stride and dilation
  143. static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
  144. {
  145. TORCH_CHECK(args.size() <= expected_size,
  146. "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
  147. expected_size, " (while checking arguments for ", c, ")");
  148. TORCH_CHECK(args.size() >= expected_size,
  149. "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
  150. expected_size, " (while checking arguments for ", c, ")");
  151. auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;});
  152. if (num_negative_values > 0){
  153. std::stringstream ss;
  154. ss << arg_name << " should be greater than zero but got (";
  155. std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
  156. ss << args.back() << ")" << " (while checking arguments for " << c << ')';
  157. TORCH_CHECK(false, ss.str());
  158. }
  159. }
  // NOTE [ Convolution checks ]
  //
  // NB: For many call sites, it is not strictly necessary to check all of
  // these relationships (for example, for forward convolution, we compute
  // the size of output ourselves, so we don't actually need to check
  // output). However, writing a single function that does everything
  // means we get to reuse it for both forwards and all backwards
  // variants, even when the set of "real" inputs varies. The magic of
  // relational computing!
  //
  // (There is one downside, which is that it is slightly harder to write
  // error messages which are able to distinguish between real inputs
  // (which the user can change) and computed inputs (which the user can
  // only indirectly affect). It would be an interesting exercise to
  // come up with a general framework to handle such situations.)

  // Validates a convolution's input/weight/output geometry together with its
  // padding, stride, dilation and groups arguments. It raises on the first
  // violated condition, and the order of the checks below determines which
  // message the user sees:
  //   1. padding has input->dim() - 2 entries; stride and dilation have as
  //      many entries as padding. check_args also rejects negative entries.
  //   2. every stride and dilation is > 0 and every padding is >= 0;
  //   3. input has 3, 4 or 5 dims;
  //   4. input's channel size equals weight->size(1) * groups;
  //   5. weight and output have the same number of dims as input.
  //
  // NOTE(review): check_args takes a size_t. If input->dim() is signed (as it
  // presumably is), an input with fewer than 2 dims makes the expected padding
  // count wrap to a huge value. The padding-count error then fires before
  // checkDimRange can report the real problem. Consider moving checkDimRange
  // first, but note that doing so changes which error message callers observe.
  inline void convolution_shape_check(
      CheckedFrom c,
      const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
      IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
  {
    check_args(c, padding, input->dim() - 2, "padding");
    check_args(c, stride, padding.size(), "stride");
    check_args(c, dilation, padding.size(), "dilation");
    // check_args above already rejected negative values, so in practice these
    // two loops only catch zero strides / dilations...
    for (auto s : stride) {
      TORCH_CHECK(s > 0, "Stride must be greater than 0 but got ", s);
    }
    for (auto d : dilation) {
      TORCH_CHECK(d > 0, "Dilation must be greater than 0 but got ", d);
    }
    // ...and this padding check cannot fire after check_args (kept as a guard).
    for (auto p : padding) {
      TORCH_CHECK(p >= 0, "Padding must be non-negative but got ", p);
    }
    // Input
    checkDimRange(c, input, 3, 6 /* exclusive */);
    checkSize_symint(c, input, input_channels_dim, weight->size(1) * groups);
    // Weight
    checkSameDim(c, input, weight);
    // TODO: check that output->size() matches output_sizes
    // TODO: check that weight matches output->sizes()
    checkSameDim(c, input, output);
  }
  201. // NB: conv_output_size and conv_input_size are not bijections,
  202. // as conv_output_size loses information; this is why conv_input_size
  203. // takes an extra output_padding argument to resolve the ambiguity.
  204. template <typename T>
  205. inline std::vector<T> _conv_output_size(
  206. ArrayRef<T> input_size, ArrayRef<T> weight_size,
  207. ArrayRef<T> padding, ArrayRef<T> stride, ArrayRef<T> dilation = ArrayRef<T>()
  208. ) {
  209. // ASSERT(input_size.size() > 2)
  210. // ASSERT(input_size.size() == weight_size.size())
  211. bool has_dilation = !dilation.empty();
  212. auto dim = input_size.size();
  213. std::vector<T> output_size(dim);
  214. output_size[0] = input_size[input_batch_size_dim];
  215. output_size[1] = weight_size[weight_output_channels_dim];
  216. for (const auto d : c10::irange(2, dim)) {
  217. auto dilation_ = has_dilation ? dilation[d - 2] : 1;
  218. auto kernel = dilation_ * (weight_size[d] - 1) + 1;
  219. output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;
  220. }
  221. return output_size;
  222. }
  223. inline std::vector<int64_t> conv_output_size(
  224. IntArrayRef input_size, IntArrayRef weight_size,
  225. IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
  226. ) {
  227. return _conv_output_size(input_size, weight_size, padding, stride, dilation);
  228. }
  229. inline std::vector<c10::SymInt> conv_output_size(
  230. SymIntArrayRef input_size, SymIntArrayRef weight_size,
  231. SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef()
  232. ) {
  233. return _conv_output_size(input_size, weight_size, padding, stride, dilation);
  234. }
  235. template <typename T>
  236. std::vector<T> _conv_input_size(
  237. ArrayRef<T> output_size, ArrayRef<T> weight_size,
  238. ArrayRef<T> padding, ArrayRef<T> output_padding, ArrayRef<T> stride, ArrayRef<T> dilation, T groups
  239. ) {
  240. // ASSERT(output_size.size() > 2)
  241. // ASSERT(output_size.size() == weight_size.size())
  242. auto dim = output_size.size();
  243. std::vector<T> input_size(dim);
  244. input_size[0] = output_size[output_batch_size_dim];
  245. input_size[1] = weight_size[weight_input_channels_dim] * groups;
  246. for (const auto d : c10::irange(2, dim)) {
  247. auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1;
  248. input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) +
  249. kernel + output_padding[d - 2];
  250. }
  251. return input_size;
  252. }
  253. inline std::vector<c10::SymInt> conv_input_size(
  254. SymIntArrayRef output_size, SymIntArrayRef weight_size,
  255. SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups
  256. ) {
  257. return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, std::move(groups));
  258. }
  259. inline std::vector<int64_t> conv_input_size(
  260. IntArrayRef output_size, IntArrayRef weight_size,
  261. IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
  262. ) {
  263. return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
  264. }
  265. template <typename T>
  266. std::vector<T> _conv_weight_size(
  267. ArrayRef<T> input_size, ArrayRef<T> output_size,
  268. ArrayRef<T> padding, ArrayRef<T> output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
  269. ) {
  270. auto dim = input_size.size();
  271. std::vector<T> weight_size(dim);
  272. weight_size[0] = output_size[1];
  273. weight_size[1] = input_size[1] / groups;
  274. for (const auto d : c10::irange(2, dim)) {
  275. auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
  276. + padding[d - 2] * 2 - output_padding[d - 2];
  277. weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
  278. }
  279. return weight_size;
  280. }
  281. inline std::vector<c10::SymInt> conv_weight_size(
  282. SymIntArrayRef input_size, SymIntArrayRef output_size,
  283. SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
  284. ) {
  285. return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
  286. }
  287. inline std::vector<int64_t> conv_weight_size(
  288. IntArrayRef input_size, IntArrayRef output_size,
  289. IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
  290. ) {
  291. return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
  292. }
  293. inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
  294. std::vector<int64_t> shape(dim, 1);
  295. shape[1] = -1;
  296. return bias.reshape(shape);
  297. }
  298. inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
  299. // disable NHWC for float64 input.
  300. if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
  301. input.scalar_type() == at::kDouble ||
  302. weight.scalar_type() == at::kDouble) {
  303. return at::MemoryFormat::Contiguous;
  304. }
  305. long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
  306. auto input_memory_format = input.suggest_memory_format();
  307. auto weight_memory_format = weight.suggest_memory_format();
  308. auto weight_ndim = weight.ndimension();
  309. bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && (
  310. (input_memory_format == at::MemoryFormat::ChannelsLast) ||
  311. (weight_memory_format == at::MemoryFormat::ChannelsLast)
  312. );
  313. if (can_use_cudnn_channels_last_2d) {
  314. return at::MemoryFormat::ChannelsLast;
  315. }
  316. bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && (
  317. (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
  318. (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
  319. );
  320. if (can_use_cudnn_channels_last_3d) {
  321. return at::MemoryFormat::ChannelsLast3d;
  322. }
  323. return at::MemoryFormat::Contiguous;
  324. }
  // controls whether emptyCache will be called following cudnn conv benchmarking
  // Setter/getter pair for that switch; both are defined outside this header.
  TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
  TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
  328. inline at::MemoryFormat miopen_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
  329. // disable NHWC for float64 input.
  330. if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
  331. input.scalar_type() == at::kDouble ||
  332. weight.scalar_type() == at::kDouble) {
  333. return at::MemoryFormat::Contiguous;
  334. }
  335. // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
  336. // See https://github.com/pytorch/pytorch/issues/64427.
  337. // non static variable is used to be able to change environment variable in runtime for testing
  338. bool suggest_nhwc = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC").value_or(false);
  339. auto input_memory_format = input.suggest_memory_format();
  340. auto weight_memory_format = weight.suggest_memory_format();
  341. auto weight_ndim = weight.ndimension();
  342. bool can_use_miopen_channels_last_2d = suggest_nhwc && (weight_ndim == 4) && (
  343. (input_memory_format == at::MemoryFormat::ChannelsLast) ||
  344. (weight_memory_format == at::MemoryFormat::ChannelsLast)
  345. );
  346. if (can_use_miopen_channels_last_2d) {
  347. return at::MemoryFormat::ChannelsLast;
  348. }
  349. bool can_use_miopen_channels_last_3d = suggest_nhwc && (weight_ndim == 5) && (
  350. (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
  351. (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
  352. );
  353. if (can_use_miopen_channels_last_3d) {
  354. return at::MemoryFormat::ChannelsLast3d;
  355. }
  356. return at::MemoryFormat::Contiguous;
  357. }
  358. // deprecated, but to remove would be BC-breaking
  359. inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
  360. return miopen_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous;
  361. }
  362. inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
  363. // disable NHWC for float64 input.
  364. if (input.scalar_type() == at::kDouble ||
  365. weight.scalar_type() == at::kDouble) {
  366. return false;
  367. }
  368. // disable NHWC for MkldnnCPU tensor.
  369. if (input.is_mkldnn() || weight.is_mkldnn()) {
  370. return false;
  371. }
  372. auto input_memory_format = input.suggest_memory_format();
  373. auto weight_memory_format = weight.suggest_memory_format();
  374. bool can_use_mkldnn_channels_last_2d =
  375. (input_memory_format == at::MemoryFormat::ChannelsLast) ||
  376. (weight_memory_format == at::MemoryFormat::ChannelsLast);
  377. bool can_use_mkldnn_channels_last_3d =
  378. (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
  379. (weight_memory_format == at::MemoryFormat::ChannelsLast3d);
  380. return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
  381. }
  382. inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
  383. auto input_memory_format = input.suggest_memory_format();
  384. auto weight_memory_format = weight.suggest_memory_format();
  385. bool can_use_thnn_channels_last_2d = input.device().is_cpu() && (
  386. (input_memory_format == at::MemoryFormat::ChannelsLast) || (
  387. weight_memory_format == at::MemoryFormat::ChannelsLast));
  388. return can_use_thnn_channels_last_2d;
  389. }
  390. inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
  391. // check layout only for xpu tensor.
  392. if (!input.is_xpu() || !weight.is_xpu()) {
  393. return false;
  394. }
  395. if (!input.defined() || input.is_sparse()) {
  396. // suggest channels_first
  397. return false;
  398. }
  399. auto is_channel_last = [](const at::Tensor& t) {
  400. auto fmt = t.suggest_memory_format();
  401. return fmt == at::MemoryFormat::ChannelsLast || fmt == at::MemoryFormat::ChannelsLast3d;
  402. };
  403. return is_channel_last(input) || is_channel_last(weight);
  404. }
  405. inline bool mps_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
  406. // check layout only for mps tensor.
  407. if (!input.is_mps() || !weight.is_mps()) {
  408. return false;
  409. }
  410. if (!input.defined() || input.is_sparse()) {
  411. // suggest channels_first
  412. return false;
  413. }
  414. auto is_channel_last = [](const at::Tensor& t) {
  415. auto fmt = t.suggest_memory_format();
  416. return fmt == at::MemoryFormat::ChannelsLast || fmt == at::MemoryFormat::ChannelsLast3d;
  417. };
  418. return is_channel_last(input) || is_channel_last(weight);
  419. }
  420. } // namespace at::native
  421. #else
  422. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  423. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)