IndexKernels.h 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. namespace at::mps {
  // Metal kernel source template for the "scatter" copy kernels, stored as a
  // raw string.
  //
  // Placeholders, as they are used in the text below:
  //   {0} - element type the source buffer is read as
  //   {1} - element type the destination buffer is written as
  //   {2} - expression returned by the cast<{1}, {0}> specialization, i.e. how
  //         a value `x` of type {0} is turned into a {1}
  // Literal braces are written doubled, which suggests the text is expanded by
  // a fmt-style formatter before it is compiled -- confirm against the caller.
  //
  // Kernels defined: scatter_kernel_n (rank supplied at run time via `ndim`)
  // and scatter_kernel_4 / _3 / _2 / _1 (fixed rank). Every thread handles one
  // `linear_index`: it returns immediately when linear_index >= numel,
  // otherwise it reads src[linear_index] and stores the cast value into dst at
  // an offset equal to the sum over dimensions of
  // (index along that dimension) * stride. The per-dimension indices are
  // recovered from linear_index with `size`, the last dimension varying
  // fastest.
  //
  // Buffer bindings: 0 = src, 1 = dst, 2 = size, 3 = stride, 4 = numel,
  // 5 = ndim (scatter_kernel_n only).
  //
  // NOTE(review): scatter_kernel_n accumulates into a uint64_t, but each
  // stride[dim] * (dst_idx % size[dim]) term is formed from 32-bit unsigned
  // operands; the fixed-rank kernels use 32-bit (or int) offsets throughout.
  // Confirm that callers guarantee offsets fit in 32 bits.
  4. static const char* SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
  5. template<typename Y, typename X>
  6. Y cast(const X x);
  7. template<>
  8. {1} cast<{1}, {0}>(const {0} x) {{
  9. return {2};
  10. }}
  11. kernel void scatter_kernel_n(uint linear_index [[thread_position_in_grid]],
  12. constant void * src_ [[buffer(0)]],
  13. device void * dst_ [[buffer(1)]],
  14. constant uint32_t * size [[buffer(2)]],
  15. constant uint32_t * stride [[buffer(3)]],
  16. constant uint32_t & numel [[buffer(4)]],
  17. constant int32_t & ndim [[buffer(5)]]) {{
  18. if (linear_index >= numel) return;
  19. constant {0} * src = (constant {0} *)src_;
  20. device {1} * dst = (device {1} *)dst_;
  21. uint64_t dst_offs = 0;
  22. auto dst_idx = linear_index;
  23. for(int dim = ndim - 1; dim >= 0; --dim) {{
  24. dst_offs += stride[dim] * (dst_idx % size[dim]);
  25. dst_idx /= size[dim];
  26. }}
  27. dst[dst_offs] = cast<{1}>(src[linear_index]);
  28. }}
  29. kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]],
  30. constant void * src_ [[buffer(0)]],
  31. device void * dst_ [[buffer(1)]],
  32. constant packed_uint4 & size [[buffer(2)]],
  33. constant packed_uint4 & stride [[buffer(3)]],
  34. constant uint32_t & numel [[buffer(4)]]) {{
  35. if (linear_index >= numel) return;
  36. constant {0} * src = (constant {0} *)src_;
  37. device {1} * dst = (device {1} *)dst_;
  38. packed_uint4 local_index;
  39. local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
  40. local_index.y = linear_index / (size[3] * size[2]) % size[1];
  41. local_index.z = linear_index / size[3] % size[2];
  42. local_index.w = linear_index % size[3];
  43. const packed_uint4 strided_index = local_index * stride;
  44. dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
  45. }}
  46. kernel void scatter_kernel_3(uint linear_index [[thread_position_in_grid]],
  47. constant void * src_ [[buffer(0)]],
  48. device void * dst_ [[buffer(1)]],
  49. constant packed_uint3 & size [[buffer(2)]],
  50. constant packed_uint3 & stride [[buffer(3)]],
  51. constant uint32_t & numel [[buffer(4)]]) {{
  52. if (linear_index >= numel) return;
  53. constant {0} * src = (constant {0} *)src_;
  54. device {1} * dst = (device {1} *)dst_;
  55. packed_uint3 local_index;
  56. local_index.x = linear_index / (size[2] * size[1]) % size[0];
  57. local_index.y = linear_index / size[2] % size[1];
  58. local_index.z = linear_index % size[2];
  59. const packed_uint3 strided_index = local_index * stride;
  60. dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
  61. }}
  62. kernel void scatter_kernel_2(uint linear_index [[thread_position_in_grid]],
  63. constant void * src_ [[buffer(0)]],
  64. device void * dst_ [[buffer(1)]],
  65. constant packed_uint2 & size [[buffer(2)]],
  66. constant packed_uint2 & stride [[buffer(3)]],
  67. constant uint32_t & numel [[buffer(4)]]) {{
  68. if (linear_index >= numel) return;
  69. constant {0} * src = (constant {0} *)src_;
  70. device {1} * dst = (device {1} *)dst_;
  71. packed_uint2 local_index;
  72. local_index.x = linear_index / size[1] % size[0];
  73. local_index.y = linear_index % size[1];
  74. const packed_uint2 strided_index = local_index * stride;
  75. dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
  76. }}
  77. kernel void scatter_kernel_1(uint linear_index [[thread_position_in_grid]],
  78. constant void * src_ [[buffer(0)]],
  79. device void * dst_ [[buffer(1)]],
  80. constant int & size [[buffer(2)]],
  81. constant int & stride [[buffer(3)]],
  82. constant uint32_t & numel [[buffer(4)]]) {{
  83. if (linear_index >= numel) return;
  84. constant {0} * src = (constant {0} *)src_;
  85. device {1} * dst = (device {1} *)dst_;
  86. const int local_index = linear_index % size;
  87. const int strided_index = local_index * stride;
  88. dst[strided_index] = cast<{1}>(src[linear_index]);
  89. }}
  90. )METAL_SCATTER";
  // Metal kernel source template for the "gather" copy kernels, stored as a
  // raw string.
  //
  // Placeholders, as they are used in the text below:
  //   {0} - element type the source buffer is read as
  //   {1} - element type the destination buffer is written as
  //   {2} - expression returned by the cast<{1}, {0}> specialization, i.e. how
  //         a value `x` of type {0} is turned into a {1}
  // Literal braces are written doubled, which suggests the text is expanded by
  // a fmt-style formatter before it is compiled -- confirm against the caller.
  //
  // Kernels defined: gather_kernel_n (rank supplied at run time via `ndim`)
  // and gather_kernel_4 / _3 / _2 / _1 (fixed rank). Every thread handles one
  // `linear_index`: it returns immediately when linear_index >= numel,
  // otherwise it reads src at an offset equal to the sum over dimensions of
  // (index along that dimension) * stride and stores the cast value into
  // dst[linear_index]. The per-dimension indices are recovered from
  // linear_index with `size`, the last dimension varying fastest.
  //
  // Buffer bindings: 0 = src, 1 = dst, 2 = size, 3 = stride, 4 = numel,
  // 5 = ndim (gather_kernel_n only).
  //
  // NOTE(review): gather_kernel_n accumulates into a uint64_t, but each
  // stride[dim] * (src_idx % size[dim]) term is formed from 32-bit unsigned
  // operands; the fixed-rank kernels use 32-bit (or int) offsets throughout.
  // Confirm that callers guarantee offsets fit in 32 bits.
  91. static const char* GATHER_OPS_TEMPLATE = R"METAL_GATHER(
  92. template<typename Y, typename X>
  93. Y cast(const X x);
  94. template<>
  95. {1} cast<{1}, {0}>(const {0} x) {{
  96. return {2};
  97. }}
  98. kernel void gather_kernel_n(uint linear_index [[thread_position_in_grid]],
  99. constant void * src_ [[buffer(0)]],
  100. device void * dst_ [[buffer(1)]],
  101. constant uint32_t * size [[buffer(2)]],
  102. constant uint32_t * stride [[buffer(3)]],
  103. constant uint32_t & numel [[buffer(4)]],
  104. constant int32_t & ndim [[buffer(5)]]) {{
  105. if (linear_index >= numel) return;
  106. constant {0} * src = (constant {0} *)src_;
  107. device {1} * dst = (device {1} *)dst_;
  108. uint64_t src_offs = 0;
  109. auto src_idx = linear_index;
  110. for(int dim = ndim - 1; dim >= 0; --dim) {{
  111. src_offs += stride[dim] * (src_idx % size[dim]);
  112. src_idx /= size[dim];
  113. }}
  114. dst[linear_index] = cast<{1}>(src[src_offs]);
  115. }}
  116. kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]],
  117. constant void * src_ [[buffer(0)]],
  118. device void * dst_ [[buffer(1)]],
  119. constant packed_uint4 & size [[buffer(2)]],
  120. constant packed_uint4 & stride [[buffer(3)]],
  121. constant uint32_t & numel [[buffer(4)]]) {{
  122. if (linear_index >= numel) return;
  123. constant {0} * src = (constant {0} *)src_;
  124. device {1} * dst = (device {1} *)dst_;
  125. packed_uint4 local_index;
  126. local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
  127. local_index.y = linear_index / (size[3] * size[2]) % size[1];
  128. local_index.z = linear_index / size[3] % size[2];
  129. local_index.w = linear_index % size[3];
  130. const packed_uint4 strided_index = local_index * stride;
  131. dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
  132. }}
  133. kernel void gather_kernel_3(uint linear_index [[thread_position_in_grid]],
  134. constant void * src_ [[buffer(0)]],
  135. device void * dst_ [[buffer(1)]],
  136. constant packed_uint3 & size [[buffer(2)]],
  137. constant packed_uint3 & stride [[buffer(3)]],
  138. constant uint32_t & numel [[buffer(4)]]) {{
  139. if (linear_index >= numel) return;
  140. constant {0} * src = (constant {0} *)src_;
  141. device {1} * dst = (device {1} *)dst_;
  142. packed_uint3 local_index;
  143. local_index.x = linear_index / (size[2] * size[1]) % size[0];
  144. local_index.y = linear_index / size[2] % size[1];
  145. local_index.z = linear_index % size[2];
  146. const packed_uint3 strided_index = local_index * stride;
  147. dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
  148. }}
  149. kernel void gather_kernel_2(uint linear_index [[thread_position_in_grid]],
  150. constant void * src_ [[buffer(0)]],
  151. device void * dst_ [[buffer(1)]],
  152. constant packed_uint2 & size [[buffer(2)]],
  153. constant packed_uint2 & stride [[buffer(3)]],
  154. constant uint32_t & numel [[buffer(4)]]) {{
  155. if (linear_index >= numel) return;
  156. constant {0} * src = (constant {0} *)src_;
  157. device {1} * dst = (device {1} *)dst_;
  158. packed_uint2 local_index;
  159. local_index.x = linear_index / size[1] % size[0];
  160. local_index.y = linear_index % size[1];
  161. const packed_uint2 strided_index = local_index * stride;
  162. dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
  163. }}
  164. kernel void gather_kernel_1(uint linear_index [[thread_position_in_grid]],
  165. constant void * src_ [[buffer(0)]],
  166. device void * dst_ [[buffer(1)]],
  167. constant int & size [[buffer(2)]],
  168. constant int & stride [[buffer(3)]],
  169. constant uint32_t & numel [[buffer(4)]]) {{
  170. if (linear_index >= numel) return;
  171. constant {0} * src = (constant {0} *)src_;
  172. device {1} * dst = (device {1} *)dst_;
  173. const int local_index = linear_index % size;
  174. const int strided_index = local_index * stride;
  175. dst[linear_index] = cast<{1}>(src[strided_index]);
  176. }}
  177. )METAL_GATHER";
  178. } // namespace at::mps
  179. #else
  180. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  181. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)