SpectralOpsUtils.h 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <string>
  4. #include <stdexcept>
  5. #include <sstream>
  6. #include <c10/core/ScalarType.h>
  7. #include <c10/util/ArrayRef.h>
  8. #include <c10/util/Exception.h>
  9. #include <ATen/native/DispatchStub.h>
  10. #include <ATen/core/TensorBase.h>
  11. namespace at::native {
  // Normalization types used in _fft_with_size: how the output of a transform
  // is scaled relative to the signal size. The enumerator values are written
  // out explicitly so the numbering is visible at a glance.
  enum class fft_norm_mode {
    none = 0,       // leave the output unscaled
    by_root_n = 1,  // scale by 1 / sqrt(signal_size)
    by_n = 2,       // scale by 1 / signal_size
  };
  18. // NOTE [ Fourier Transform Conjugate Symmetry ]
  19. //
  20. // The real-to-complex Fourier transform satisfies conjugate symmetry. That
  21. // is, assuming X is the transform of a real K-dimensional signal, we have
  22. //
  23. // X[i_1, ..., i_K] = X[j_1, ..., j_K]*,
  24. //
  25. // where j_k = (N_k - i_k) mod N_k, N_k being the signal size at dim k, and
  26. // * is the complex conjugate operator.
  27. //
  28. // Therefore, in such cases, FFT libraries return only roughly half of the
  29. // values to avoid redundancy:
  30. //
  31. // X[:, :, ..., :floor(N / 2) + 1]
  32. //
  33. // This is also the assumption in cuFFT and MKL. In ATen SpectralOps, such a
  34. // halved signal is also returned by default (flag onesided=True).
  35. // The following infer_ft_real_to_complex_onesided_size function calculates the
  36. // onesided size from the twosided size.
  37. //
  38. // Note that this loses some information about the size of the signal at the
  39. // last dimension. E.g., both 11 and 10 map to 6. Hence, the following
  40. // infer_ft_complex_to_real_onesided_size function takes an optional parameter
  41. // to infer the twosided size from a given onesided size.
  42. //
  43. // cuFFT doc: http://docs.nvidia.com/cuda/cufft/index.html#multi-dimensional
  44. // MKL doc: https://software.intel.com/en-us/mkl-developer-reference-c-dfti-complex-storage-dfti-real-storage-dfti-conjugate-even-storage#CONJUGATE_EVEN_STORAGE
  // Length of the onesided output of a real-to-complex FFT taken along a
  // dimension of length `real_size`: real_size / 2 (integer division) plus one,
  // so e.g. both 10 and 11 give 6.
  // See NOTE [ Fourier Transform Conjugate Symmetry ].
  inline int64_t infer_ft_real_to_complex_onesided_size(int64_t real_size) {
    const int64_t half_size = real_size / 2;
    return half_size + 1;
  }
  48. inline int64_t infer_ft_complex_to_real_onesided_size(int64_t complex_size,
  49. int64_t expected_size=-1) {
  50. int64_t base = (complex_size - 1) * 2;
  51. if (expected_size < 0) {
  52. return base + 1;
  53. } else if (base == expected_size) {
  54. return base;
  55. } else if (base + 1 == expected_size) {
  56. return base + 1;
  57. } else {
  58. std::ostringstream ss;
  59. ss << "expected real signal size " << expected_size << " is incompatible "
  60. << "with onesided complex frequency size " << complex_size;
  61. TORCH_CHECK(false, ss.str());
  62. }
  63. }
  64. using fft_fill_with_conjugate_symmetry_fn =
  65. void (*)(ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef half_sizes,
  66. IntArrayRef in_strides, const void* in_data,
  67. IntArrayRef out_strides, void* out_data);
  68. DECLARE_DISPATCH(fft_fill_with_conjugate_symmetry_fn, fft_fill_with_conjugate_symmetry_stub)
  69. // In real-to-complex transform, cuFFT and MKL only fill half of the values
  70. // due to conjugate symmetry. This function fills in the other half of the full
  71. // fft by using the Hermitian symmetry in the signal.
  72. // self should be the shape of the full signal and dims.back() should be the
  73. // one-sided dimension.
  74. // See NOTE [ Fourier Transform Conjugate Symmetry ]
  75. TORCH_API void _fft_fill_with_conjugate_symmetry_(const Tensor& self, IntArrayRef dims);
  76. } // namespace at::native
  77. #else
  78. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  79. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)