Gelu.h 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
#include <c10/util/Exception.h>

#include <string>
#include <string_view>
  5. namespace at::native {
  6. // These constants control the approximation behavior of gelu function.
  7. enum class GeluType {
  8. None, // Baseline Gelu
  9. Tanh, // Tanh Gelu Approximation
  10. END
  11. };
  12. inline GeluType get_gelutype_enum(const std::string_view approximate) {
  13. if (approximate == "none") {
  14. return GeluType::None;
  15. } else if (approximate == "tanh") {
  16. return GeluType::Tanh;
  17. } else {
  18. TORCH_CHECK(false, "approximate argument must be either none or tanh.");
  19. }
  20. }
  21. inline std::string gelutype_to_string(const GeluType type) {
  22. switch(type) {
  23. case GeluType::None: return "none";
  24. case GeluType::Tanh: return "tanh";
  25. default: TORCH_CHECK(false, "unknown GELU type: ", static_cast<int>(type));
  26. }
  27. }
  28. } // namespace at::native
  29. #else
  30. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  31. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)