| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- // Copyright © 2022 Apple Inc.
- #pragma once
- #include <ATen/Generator.h>
- #include <ATen/detail/MPSHooksInterface.h>
- #include <ATen/mps/MPSEvent.h>
- #include <optional>
- namespace at::mps {
- // The real implementation of MPSHooksInterface
- struct MPSHooks : public at::MPSHooksInterface {
- MPSHooks(at::MPSHooksArgs) {}
- void init() const override;
- // MPSDevice interface
- bool hasMPS() const override;
- bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
- Device getDeviceFromPtr(void* data) const override;
- // MPSGeneratorImpl interface
- const Generator& getDefaultGenerator(
- DeviceIndex device_index = -1) const override;
- Generator getNewGenerator(DeviceIndex device_index = -1) const override;
- // MPSStream interface
- void deviceSynchronize() const override;
- void commitStream() const override;
- void* getCommandBuffer() const override;
- void* getDispatchQueue() const override;
- // MPSAllocator interface
- Allocator* getMPSDeviceAllocator() const override;
- void emptyCache() const override;
- size_t getCurrentAllocatedMemory() const override;
- size_t getDriverAllocatedMemory() const override;
- size_t getRecommendedMaxMemory() const override;
- void setMemoryFraction(double ratio) const override;
- bool isPinnedPtr(const void* data) const override;
- Allocator* getPinnedMemoryAllocator() const override;
- // MPSProfiler interface
- void profilerStartTrace(const std::string& mode, bool waitUntilCompleted)
- const override;
- void profilerStopTrace() const override;
- // MPSEvent interface
- uint32_t acquireEvent(bool enable_timing) const override;
- void releaseEvent(uint32_t event_id) const override;
- void recordEvent(uint32_t event_id) const override;
- void waitForEvent(uint32_t event_id) const override;
- void synchronizeEvent(uint32_t event_id) const override;
- bool queryEvent(uint32_t event_id) const override;
- double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id)
- const override;
- bool isBuilt() const override {
- return true;
- }
- bool isAvailable() const override {
- return hasMPS();
- }
- bool hasPrimaryContext(DeviceIndex device_index) const override {
- // When MPS is available, it is always in use for the one device.
- return true;
- }
- };
- } // namespace at::mps
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|