| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
- #include <metal_atomic>
- namespace c10 {
- namespace metal {
// Atomic operations helper.
//
// AtomicType<T> maps an element type T to the atomic storage type (`type`)
// used to view a device buffer of T, and provides a static
// `atomic_add(data, offset, value)` for it. The primary template is left
// empty on purpose: only the explicit specializations below define `type`,
// so AtomicType_t<T> does not resolve for unsupported element types.
template <class T>
struct AtomicType {};

// Shorthand for the atomic storage type associated with element type T.
template <class T>
using AtomicType_t = typename AtomicType<T>::type;
// float: Metal provides a native atomic<float>, so the add is a single
// relaxed fetch-add on element `offset`.
template <>
struct AtomicType<float> {
  using type = ::metal::atomic<float>;
  static inline void atomic_add(device type* data, long offset, float value) {
    device type* slot = data + offset;
    ::metal::atomic_fetch_add_explicit(
        slot, value, ::metal::memory_order_relaxed);
  }
};
// int: Metal provides a native atomic<int>, so the add is a single relaxed
// fetch-add on element `offset`.
template <>
struct AtomicType<int> {
  using type = ::metal::atomic<int>;
  static inline void atomic_add(device type* data, long offset, int value) {
    device type* slot = data + offset;
    ::metal::atomic_fetch_add_explicit(
        slot, value, ::metal::memory_order_relaxed);
  }
};
// As of Metal 3.2, atomic operations are not supported on half-precision
// floats, so they must be simulated using atomic compare-and-exchange over a
// 32-bit atomic type.
//
// `data` is viewed as an array of 32-bit words, each packing
// sizeof(uint) / sizeof(T) consecutive elements of type T; `offset` is the
// element index. The loop builds the new word by adding `value` to the
// selected lane of the last observed word, and retries the weak
// compare-exchange (which refreshes `old` on failure) until it succeeds.
// All memory orders are relaxed.
template <typename T>
static inline void atomic_add_helper(
    device ::metal::atomic<uint>* data,
    long offset,
    T value) {
  // The lane is picked with a bit mask below and the word is overlaid with a
  // T[elems_per_word] array; both are only valid when T evenly divides a
  // 32-bit word (elems_per_word is then 1, 2 or 4, i.e. a power of two).
  static_assert(
      sizeof(uint) % sizeof(T) == 0,
      "atomic_add_helper: sizeof(T) must evenly divide a 32-bit word");
  constexpr auto elems_per_word = sizeof(uint) / sizeof(T);
  auto ptr = data + (offset / elems_per_word);
  auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
  union {
    uint i;
    T t[elems_per_word];
  } val;
  do {
    val.i = old;
    val.t[offset & (elems_per_word - 1)] += value;
  } while (!::metal::atomic_compare_exchange_weak_explicit(
      ptr,
      &old,
      val.i,
      ::metal::memory_order_relaxed,
      ::metal::memory_order_relaxed));
}
// half: no native half-precision atomics, so the add goes through the
// 32-bit compare-and-exchange emulation in atomic_add_helper.
template <>
struct AtomicType<half> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, half value) {
    atomic_add_helper<half>(data, offset, value);
  }
};
// short: stored packed in 32-bit words; the add goes through the
// compare-and-exchange emulation in atomic_add_helper.
template <>
struct AtomicType<short> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, short value) {
    atomic_add_helper<short>(data, offset, value);
  }
};
// char: stored packed in 32-bit words; the add goes through the
// compare-and-exchange emulation in atomic_add_helper.
template <>
struct AtomicType<char> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, char value) {
    atomic_add_helper<char>(data, offset, value);
  }
};
// uchar: stored packed in 32-bit words; the add goes through the
// compare-and-exchange emulation in atomic_add_helper.
//
// The operand is taken as `uchar` (previously `char`) so the helper is
// instantiated for the specialization's own element type. This avoids routing
// unsigned byte values above 127 through a signed char, which relied on
// implementation-defined narrowing. Callers passing `char` still compile via
// implicit conversion.
template <>
struct AtomicType<uchar> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, uchar value) {
    atomic_add_helper<uchar>(data, offset, value);
  }
};
// bfloat: no native bfloat atomics, so the add goes through the 32-bit
// compare-and-exchange emulation in atomic_add_helper (T deduced as bfloat).
template <>
struct AtomicType<bfloat> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, bfloat value) {
    atomic_add_helper(data, offset, value);
  }
};
// Metal supports atomic_store_explicit for bools, but
// sizeof(::metal::atomic_bool) is 4, so it cannot be used to atomically
// modify individual bytes of packed bool storage. Instead, fall back to the
// compare-and-exchange trick on the enclosing 32-bit word (4 bools per word).
// Accumulating booleans is just a logical OR, so adding `false` is a no-op
// and returns immediately; adding `true` sets the selected byte to true.
template <>
struct AtomicType<bool> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, bool value) {
    if (value) {
      device type* word = data + (offset >> 2);
      const long lane = offset & 3;
      uint expected =
          ::metal::atomic_load_explicit(word, ::metal::memory_order_relaxed);
      union {
        uint bits;
        bool flags[4];
      } desired;
      for (;;) {
        desired.bits = expected;
        desired.flags[lane] = true;
        // On failure `expected` is refreshed with the current word.
        if (::metal::atomic_compare_exchange_weak_explicit(
                word,
                &expected,
                desired.bits,
                ::metal::memory_order_relaxed,
                ::metal::memory_order_relaxed)) {
          break;
        }
      }
    }
  }
};
// ComplexHalf atomic op: a half2 occupies exactly one 32-bit word, so the
// whole pair is updated atomically with a compare-and-exchange loop that
// reinterprets the word as half2, adds `value`, and reinterprets back.
template <>
struct AtomicType<half2> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, half2 value) {
    device type* word = data + offset;
    uint expected =
        ::metal::atomic_load_explicit(word, ::metal::memory_order_relaxed);
    uint desired;
    do {
      desired = as_type<uint>(as_type<half2>(expected) + value);
    } while (!::metal::atomic_compare_exchange_weak_explicit(
        word,
        &expected,
        desired,
        ::metal::memory_order_relaxed,
        ::metal::memory_order_relaxed));
  }
};
// There is no atomic 64-bit add in Metal yet, so this specialization splits
// the operand into two 32-bit halves and adds them with two separate 32-bit
// atomic fetch-adds, propagating the carry out of the low word into the high
// word. The low word is stored at the lower address (ptr), the high word at
// ptr + 1.
//
// The update is not atomic as a whole: a concurrent reader may observe a
// torn value. It is, however, eventually consistent. If multiple threads
// modify the same 64-bit value, the result stored at the address ends up
// equal to its original value plus the sum of all operands.
template <>
struct AtomicType<long> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, long value) {
    const auto value_bits = as_type<ulong>(value);
    const uint low = static_cast<uint>(value_bits);
    uint high = static_cast<uint>(value_bits >> 32);
    // Each long spans two consecutive 32-bit words.
    auto ptr = data + (offset << 1);
    const uint old_low = ::metal::atomic_fetch_add_explicit(
        ptr, low, ::metal::memory_order_relaxed);
    // Unsigned 32-bit addition wraps, so a sum smaller than the previous low
    // word means this add carried out of the low word.
    high += (old_low + low < old_low) ? 1 : 0;
    // Qualified like the rest of this file instead of relying on ADL.
    ::metal::atomic_fetch_add_explicit(
        ptr + 1, high, ::metal::memory_order_relaxed);
  }
};
// ComplexFloat atomic op. The x and y components are stored as two
// consecutive floats and updated with two independent float atomic
// fetch-adds. The pair is therefore not updated atomically as a whole, but
// each component ends up with the sum of all operands (eventually
// consistent).
template <>
struct AtomicType<float2> {
  using type = ::metal::atomic<float>;
  static inline void atomic_add(device type* data, long offset, float2 value) {
    // Each float2 element spans two consecutive atomic<float> slots.
    auto ptr = data + (offset << 1);
    // Qualified like the rest of this file instead of relying on ADL.
    ::metal::atomic_fetch_add_explicit(
        ptr + 0, value.x, ::metal::memory_order_relaxed);
    ::metal::atomic_fetch_add_explicit(
        ptr + 1, value.y, ::metal::memory_order_relaxed);
  }
};
- } // namespace metal
- } // namespace c10
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|