// PhiloxXpuState.h
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
#pragma once

#include <cstdint>
#include <tuple>

  3. namespace at {
  4. struct PhiloxXpuState {
  5. PhiloxXpuState() = default;
  6. PhiloxXpuState(uint64_t seed, uint64_t offset) {
  7. seed_.val = seed;
  8. offset_.val = offset;
  9. }
  10. // for graph capture
  11. PhiloxXpuState(
  12. int64_t* seed,
  13. int64_t* offset_extragraph,
  14. uint32_t offset_intragraph) {
  15. seed_.ptr = seed;
  16. offset_.ptr = offset_extragraph;
  17. offset_intragraph_ = offset_intragraph;
  18. captured_ = true;
  19. }
  20. union Payload {
  21. uint64_t val;
  22. int64_t* ptr;
  23. };
  24. Payload seed_{};
  25. Payload offset_{};
  26. uint32_t offset_intragraph_ = 0;
  27. bool captured_ = false;
  28. };
  29. namespace xpu::philox {
  30. inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxXpuState arg) {
  31. if (arg.captured_) {
  32. return std::make_tuple(
  33. static_cast<uint64_t>(*arg.seed_.ptr),
  34. static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
  35. } else {
  36. return std::make_tuple(arg.seed_.val, arg.offset_.val);
  37. }
  38. }
  39. } // namespace xpu::philox
  40. } // namespace at
  41. #else
  42. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  43. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)