BatchLinearAlgebra.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  #include <cstdint>
  #include <optional>
  #include <string>
  #include <string_view>

  #include <ATen/Config.h>
  #include <ATen/native/DispatchStub.h>
  // Forward declarations. The declarations that follow only use Tensor and
  // TensorIterator through references, and TransposeType through its opaque
  // (fixed-underlying-type) enum declaration, so the full definitions are not
  // needed in this header.
  namespace at {
  class Tensor;
  struct TensorIterator;
  } // namespace at

  namespace at::native {
  enum class TransposeType;
  } // namespace at::native
  15. namespace at::native {
  16. enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss};
  17. #if AT_BUILD_WITH_LAPACK()
  18. // Define per-batch functions to be used in the implementation of batched
  19. // linear algebra operations
  20. template <class scalar_t>
  21. void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info);
  22. template <class scalar_t>
  23. void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info);
  24. template <class scalar_t, class value_t=scalar_t>
  25. void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);
  26. template <class scalar_t>
  27. void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
  28. template <class scalar_t>
  29. void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
  30. template <class scalar_t>
  31. void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info);
  32. template <class scalar_t, class value_t = scalar_t>
  33. void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info);
  34. template <class scalar_t>
  35. void lapackGels(char trans, int m, int n, int nrhs,
  36. scalar_t *a, int lda, scalar_t *b, int ldb,
  37. scalar_t *work, int lwork, int *info);
  38. template <class scalar_t, class value_t = scalar_t>
  39. void lapackGelsd(int m, int n, int nrhs,
  40. scalar_t *a, int lda, scalar_t *b, int ldb,
  41. value_t *s, value_t rcond, int *rank,
  42. scalar_t* work, int lwork,
  43. value_t *rwork, int* iwork, int *info);
  44. template <class scalar_t, class value_t = scalar_t>
  45. void lapackGelsy(int m, int n, int nrhs,
  46. scalar_t *a, int lda, scalar_t *b, int ldb,
  47. int *jpvt, value_t rcond, int *rank,
  48. scalar_t *work, int lwork, value_t* rwork, int *info);
  49. template <class scalar_t, class value_t = scalar_t>
  50. void lapackGelss(int m, int n, int nrhs,
  51. scalar_t *a, int lda, scalar_t *b, int ldb,
  52. value_t *s, value_t rcond, int *rank,
  53. scalar_t *work, int lwork,
  54. value_t *rwork, int *info);
  55. template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t>
  56. struct lapackLstsq_impl;
  57. template <class scalar_t, class value_t>
  58. struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> {
  59. static void call(
  60. char trans, int m, int n, int nrhs,
  61. scalar_t *a, int lda, scalar_t *b, int ldb,
  62. scalar_t *work, int lwork, int *info, // Gels flavor
  63. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  64. value_t *s, // Gelss flavor
  65. int *iwork // Gelsd flavor
  66. ) {
  67. lapackGels<scalar_t>(
  68. trans, m, n, nrhs,
  69. a, lda, b, ldb,
  70. work, lwork, info);
  71. }
  72. };
  73. template <class scalar_t, class value_t>
  74. struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> {
  75. static void call(
  76. char trans, int m, int n, int nrhs,
  77. scalar_t *a, int lda, scalar_t *b, int ldb,
  78. scalar_t *work, int lwork, int *info, // Gels flavor
  79. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  80. value_t *s, // Gelss flavor
  81. int *iwork // Gelsd flavor
  82. ) {
  83. lapackGelsy<scalar_t, value_t>(
  84. m, n, nrhs,
  85. a, lda, b, ldb,
  86. jpvt, rcond, rank,
  87. work, lwork, rwork, info);
  88. }
  89. };
  90. template <class scalar_t, class value_t>
  91. struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> {
  92. static void call(
  93. char trans, int m, int n, int nrhs,
  94. scalar_t *a, int lda, scalar_t *b, int ldb,
  95. scalar_t *work, int lwork, int *info, // Gels flavor
  96. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  97. value_t *s, // Gelss flavor
  98. int *iwork // Gelsd flavor
  99. ) {
  100. lapackGelsd<scalar_t, value_t>(
  101. m, n, nrhs,
  102. a, lda, b, ldb,
  103. s, rcond, rank,
  104. work, lwork,
  105. rwork, iwork, info);
  106. }
  107. };
  108. template <class scalar_t, class value_t>
  109. struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> {
  110. static void call(
  111. char trans, int m, int n, int nrhs,
  112. scalar_t *a, int lda, scalar_t *b, int ldb,
  113. scalar_t *work, int lwork, int *info, // Gels flavor
  114. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  115. value_t *s, // Gelss flavor
  116. int *iwork // Gelsd flavor
  117. ) {
  118. lapackGelss<scalar_t, value_t>(
  119. m, n, nrhs,
  120. a, lda, b, ldb,
  121. s, rcond, rank,
  122. work, lwork,
  123. rwork, info);
  124. }
  125. };
  126. template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t>
  127. void lapackLstsq(
  128. char trans, int m, int n, int nrhs,
  129. scalar_t *a, int lda, scalar_t *b, int ldb,
  130. scalar_t *work, int lwork, int *info, // Gels flavor
  131. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  132. value_t *s, // Gelss flavor
  133. int *iwork // Gelsd flavor
  134. ) {
  135. lapackLstsq_impl<driver_type, scalar_t, value_t>::call(
  136. trans, m, n, nrhs,
  137. a, lda, b, ldb,
  138. work, lwork, info,
  139. jpvt, rcond, rank, rwork,
  140. s,
  141. iwork);
  142. }
  143. template <class scalar_t>
  144. void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
  145. template <class scalar_t>
  146. void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
  147. template <class scalar_t>
  148. void lapackLdlHermitian(
  149. char uplo,
  150. int n,
  151. scalar_t* a,
  152. int lda,
  153. int* ipiv,
  154. scalar_t* work,
  155. int lwork,
  156. int* info);
  157. template <class scalar_t>
  158. void lapackLdlSymmetric(
  159. char uplo,
  160. int n,
  161. scalar_t* a,
  162. int lda,
  163. int* ipiv,
  164. scalar_t* work,
  165. int lwork,
  166. int* info);
  167. template <class scalar_t>
  168. void lapackLdlSolveHermitian(
  169. char uplo,
  170. int n,
  171. int nrhs,
  172. scalar_t* a,
  173. int lda,
  174. int* ipiv,
  175. scalar_t* b,
  176. int ldb,
  177. int* info);
  178. template <class scalar_t>
  179. void lapackLdlSolveSymmetric(
  180. char uplo,
  181. int n,
  182. int nrhs,
  183. scalar_t* a,
  184. int lda,
  185. int* ipiv,
  186. scalar_t* b,
  187. int ldb,
  188. int* info);
  189. template<class scalar_t, class value_t=scalar_t>
  190. void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info);
  191. #endif
  192. #if AT_BUILD_WITH_BLAS()
  193. template <class scalar_t>
  194. void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb);
  195. #endif
  196. using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
  197. DECLARE_DISPATCH(cholesky_fn, cholesky_stub)
  198. using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/);
  199. DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub)
  200. using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/);
  201. DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub)
  202. // Converts LAPACK's real-valued eigenvector encoding to complex eigenvectors
  203. TORCH_API void linalg_eig_make_complex_eigenvectors(
  204. const Tensor& complex_vectors,
  205. const Tensor& complex_values,
  206. const Tensor& real_vectors);
  207. DECLARE_DISPATCH(
  208. void(*)(const Tensor&, const Tensor&, const Tensor&),
  209. linalg_eig_make_complex_eigenvectors_stub)
  210. using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/);
  211. DECLARE_DISPATCH(geqrf_fn, geqrf_stub)
  212. using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/);
  213. DECLARE_DISPATCH(orgqr_fn, orgqr_stub)
  214. using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/);
  215. DECLARE_DISPATCH(ormqr_fn, ormqr_stub)
  216. using linalg_eigh_fn = void (*)(
  217. const Tensor& /*eigenvalues*/,
  218. const Tensor& /*eigenvectors*/,
  219. const Tensor& /*infos*/,
  220. bool /*upper*/,
  221. bool /*compute_eigenvectors*/);
  222. DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub)
  223. using lstsq_fn = void (*)(
  224. const Tensor& /*a*/,
  225. Tensor& /*b*/,
  226. Tensor& /*rank*/,
  227. Tensor& /*singular_values*/,
  228. Tensor& /*infos*/,
  229. double /*rcond*/,
  230. std::string /*driver_name*/);
  231. DECLARE_DISPATCH(lstsq_fn, lstsq_stub)
  232. using triangular_solve_fn = void (*)(
  233. const Tensor& /*A*/,
  234. const Tensor& /*B*/,
  235. bool /*left*/,
  236. bool /*upper*/,
  237. TransposeType /*transpose*/,
  238. bool /*unitriangular*/);
  239. DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub)
  240. using lu_factor_fn = void (*)(
  241. const Tensor& /*input*/,
  242. const Tensor& /*pivots*/,
  243. const Tensor& /*infos*/,
  244. bool /*compute_pivots*/);
  245. DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub)
  246. using unpack_pivots_fn = void(*)(
  247. TensorIterator& iter,
  248. const int64_t dim_size,
  249. const int64_t max_pivot);
  250. DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub)
  251. using lu_solve_fn = void (*)(
  252. const Tensor& /*LU*/,
  253. const Tensor& /*pivots*/,
  254. const Tensor& /*B*/,
  255. TransposeType /*trans*/);
  256. DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub)
  257. using ldl_factor_fn = void (*)(
  258. const Tensor& /*LD*/,
  259. const Tensor& /*pivots*/,
  260. const Tensor& /*info*/,
  261. bool /*upper*/,
  262. bool /*hermitian*/);
  263. DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub)
  264. using svd_fn = void (*)(
  265. const Tensor& /*A*/,
  266. const bool /*full_matrices*/,
  267. const bool /*compute_uv*/,
  268. const std::optional<std::string_view>& /*driver*/,
  269. const Tensor& /*U*/,
  270. const Tensor& /*S*/,
  271. const Tensor& /*Vh*/,
  272. const Tensor& /*info*/);
  273. DECLARE_DISPATCH(svd_fn, svd_stub)
  274. using ldl_solve_fn = void (*)(
  275. const Tensor& /*LD*/,
  276. const Tensor& /*pivots*/,
  277. const Tensor& /*result*/,
  278. bool /*upper*/,
  279. bool /*hermitian*/);
  280. DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub)
  281. } // namespace at::native
  282. #else
  283. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  284. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)