| 1234567891011121314151617181920212223242526272829303132333435363738 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
#include <c10/util/Exception.h>
#include <string>
#include <string_view>
- namespace at::native {
- // These constants control the approximation behavior of gelu function.
- enum class GeluType {
- None, // Baseline Gelu
- Tanh, // Tanh Gelu Approximation
- END
- };
- inline GeluType get_gelutype_enum(const std::string_view approximate) {
- if (approximate == "none") {
- return GeluType::None;
- } else if (approximate == "tanh") {
- return GeluType::Tanh;
- } else {
- TORCH_CHECK(false, "approximate argument must be either none or tanh.");
- }
- }
- inline std::string gelutype_to_string(const GeluType type) {
- switch(type) {
- case GeluType::None: return "none";
- case GeluType::Tanh: return "tanh";
- default: TORCH_CHECK(false, "unknown GELU type: ", static_cast<int>(type));
- }
- }
- } // namespace at::native
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|