| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
- namespace at::mps {
// Metal Shading Language source for the strided "scatter" copy kernels, kept
// as a format-string template rather than ready-to-compile source.
//
// Placeholders (as used by the template text itself):
//   {0} - element type read through `src`
//   {1} - element type written through `dst`
//   {2} - expression returned by the `cast<{1}, {0}>` specialization, i.e.
//         how a value `x` of type {0} is turned into a {1}
// Literal braces of the Metal code are doubled (`{{` / `}}`); presumably the
// text is expanded with a fmt-style formatter before it is compiled - confirm
// at the call site.
//
// Each thread handles the element at its grid position `linear_index`. It
// returns immediately when `linear_index >= numel`; otherwise it decomposes
// `linear_index` into per-dimension indices using `size` (the last dimension
// varies fastest), multiplies those indices by `stride` (counted in elements
// of {1}), and stores `cast<{1}>(src[linear_index])` at the resulting offset
// in `dst`.
//
// Buffer bindings shared by all variants:
//   buffer(0) src, buffer(1) dst, buffer(2) size, buffer(3) stride,
//   buffer(4) numel; scatter_kernel_n additionally reads ndim from buffer(5).
//
// Variants:
//   scatter_kernel_n        - rank given by `ndim`; size/stride are uint32_t
//                             arrays indexed 0..ndim-1.
//   scatter_kernel_4/_3/_2  - fixed rank; size/stride passed as packed_uintN.
//   scatter_kernel_1        - rank 1; size/stride passed as single ints.
//
// scatter_kernel_n widens the stride to 64 bits before multiplying, so the
// per-dimension product is not truncated to 32 bits before it is added to the
// 64-bit offset accumulator.
// NOTE(review): the fixed-rank variants still compute the offset in 32-bit
// arithmetic (signed int for rank 1); presumably callers only dispatch them
// when the offset fits - confirm.
static const char* SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
  return {2};
}}

kernel void scatter_kernel_n(uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant uint32_t * size [[buffer(2)]],
    constant uint32_t * stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]],
    constant int32_t & ndim [[buffer(5)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  uint64_t dst_offs = 0;
  auto dst_idx = linear_index;
  for(int dim = ndim - 1; dim >= 0; --dim) {{
    dst_offs += static_cast<uint64_t>(stride[dim]) * (dst_idx % size[dim]);
    dst_idx /= size[dim];
  }}

  dst[dst_offs] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint4 & size [[buffer(2)]],
    constant packed_uint4 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint4 local_index;
  local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
  local_index.y = linear_index / (size[3] * size[2]) % size[1];
  local_index.z = linear_index / size[3] % size[2];
  local_index.w = linear_index % size[3];

  const packed_uint4 strided_index = local_index * stride;
  dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_3(uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint3 & size [[buffer(2)]],
    constant packed_uint3 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint3 local_index;
  local_index.x = linear_index / (size[2] * size[1]) % size[0];
  local_index.y = linear_index / size[2] % size[1];
  local_index.z = linear_index % size[2];

  const packed_uint3 strided_index = local_index * stride;
  dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_2(uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint2 & size [[buffer(2)]],
    constant packed_uint2 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint2 local_index;
  local_index.x = linear_index / size[1] % size[0];
  local_index.y = linear_index % size[1];

  const packed_uint2 strided_index = local_index * stride;
  dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_1(uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant int & size [[buffer(2)]],
    constant int & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  const int local_index = linear_index % size;
  const int strided_index = local_index * stride;
  dst[strided_index] = cast<{1}>(src[linear_index]);
}}
)METAL_SCATTER";
// Metal Shading Language source for the strided "gather" copy kernels, kept
// as a format-string template rather than ready-to-compile source.
//
// Placeholders (as used by the template text itself):
//   {0} - element type read through `src`
//   {1} - element type written through `dst`
//   {2} - expression returned by the `cast<{1}, {0}>` specialization, i.e.
//         how a value `x` of type {0} is turned into a {1}
// Literal braces of the Metal code are doubled (`{{` / `}}`); presumably the
// text is expanded with a fmt-style formatter before it is compiled - confirm
// at the call site.
//
// Each thread handles the element at its grid position `linear_index`. It
// returns immediately when `linear_index >= numel`; otherwise it decomposes
// `linear_index` into per-dimension indices using `size` (the last dimension
// varies fastest), multiplies those indices by `stride` (counted in elements
// of {0}), reads `src` at the resulting offset and stores the converted value
// in `dst[linear_index]`.
//
// Buffer bindings shared by all variants:
//   buffer(0) src, buffer(1) dst, buffer(2) size, buffer(3) stride,
//   buffer(4) numel; gather_kernel_n additionally reads ndim from buffer(5).
//
// Variants:
//   gather_kernel_n        - rank given by `ndim`; size/stride are uint32_t
//                            arrays indexed 0..ndim-1.
//   gather_kernel_4/_3/_2  - fixed rank; size/stride passed as packed_uintN.
//   gather_kernel_1        - rank 1; size/stride passed as single ints.
//
// gather_kernel_n widens the stride to 64 bits before multiplying, so the
// per-dimension product is not truncated to 32 bits before it is added to the
// 64-bit offset accumulator.
// NOTE(review): the fixed-rank variants still compute the offset in 32-bit
// arithmetic (signed int for rank 1); presumably callers only dispatch them
// when the offset fits - confirm.
static const char* GATHER_OPS_TEMPLATE = R"METAL_GATHER(
template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
  return {2};
}}

kernel void gather_kernel_n(uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant uint32_t * size [[buffer(2)]],
    constant uint32_t * stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]],
    constant int32_t & ndim [[buffer(5)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  uint64_t src_offs = 0;
  auto src_idx = linear_index;
  for(int dim = ndim - 1; dim >= 0; --dim) {{
    src_offs += static_cast<uint64_t>(stride[dim]) * (src_idx % size[dim]);
    src_idx /= size[dim];
  }}

  dst[linear_index] = cast<{1}>(src[src_offs]);
}}

kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint4 & size [[buffer(2)]],
    constant packed_uint4 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint4 local_index;
  local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
  local_index.y = linear_index / (size[3] * size[2]) % size[1];
  local_index.z = linear_index / size[3] % size[2];
  local_index.w = linear_index % size[3];

  const packed_uint4 strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
}}

kernel void gather_kernel_3(uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint3 & size [[buffer(2)]],
    constant packed_uint3 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint3 local_index;
  local_index.x = linear_index / (size[2] * size[1]) % size[0];
  local_index.y = linear_index / size[2] % size[1];
  local_index.z = linear_index % size[2];

  const packed_uint3 strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
}}

kernel void gather_kernel_2(uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint2 & size [[buffer(2)]],
    constant packed_uint2 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint2 local_index;
  local_index.x = linear_index / size[1] % size[0];
  local_index.y = linear_index % size[1];

  const packed_uint2 strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
}}

kernel void gather_kernel_1(uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant int & size [[buffer(2)]],
    constant int & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  const int local_index = linear_index % size;
  const int strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index]);
}}
)METAL_GATHER";
- } // namespace at::mps
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|