Mercurial > octave
changeset 29412:9b6bf68ea663
Add additional svd driver GEJSV for accurate SVD (bug #55727).
* liboctave/numeric/svd.h (svd::gejsv): Add new member function.
(svd::Driver): Add enum GEJSV.
* liboctave/numeric/svd.cc (svd::gejsv): Add new member function.
(svd<T>::svd): Call gejsv if it is the selected svd driver. Fix a bug in the
size of the right singular vector matrix when the input is an empty matrix.
(gejsv_lwork): Add new class to compute workspace size of gejsv.
* liboctave/numeric/lo-lapack-proto.h: Declare LAPACK functions: GEJSV for
singular value decomposition, and GELQF, ORMLQ, ORMQR for workspace size
query.
* libinterp/corefcn/svd.cc (svd_driver): Add driver GEJSV.
(Fsvd): Document GEJSV. Add a test for GEJSV.
(Fsvd_driver): Document GEJSV. Add driver name "gejsv". Add a test for GEJSV.
author | Eddy Xiao <bewantbe@gmail.com> |
---|---|
date | Wed, 20 Mar 2019 00:58:11 +0800 |
parents | f26f9dbfb2c5 |
children | 0006d00aa097 |
files | libinterp/corefcn/svd.cc liboctave/numeric/lo-lapack-proto.h liboctave/numeric/svd.cc liboctave/numeric/svd.h |
diffstat | 4 files changed, 723 insertions(+), 21 deletions(-) [+] |
line wrap: on
line diff
--- a/libinterp/corefcn/svd.cc Sun Mar 07 10:55:54 2021 +0100 +++ b/libinterp/corefcn/svd.cc Wed Mar 20 00:58:11 2019 +0800 @@ -62,9 +62,12 @@ static typename octave::math::svd<T>::Driver svd_driver (void) { - return (Vsvd_driver == "gesvd" - ? octave::math::svd<T>::Driver::GESVD - : octave::math::svd<T>::Driver::GESDD); + if (Vsvd_driver == "gejsv") + return octave::math::svd<T>::Driver::GEJSV; + else if (Vsvd_driver == "gesdd") + return octave::math::svd<T>::Driver::GESDD; + else + return octave::math::svd<T>::Driver::GESVD; // default } DEFUN (svd, args, nargout, @@ -164,8 +167,9 @@ singular matrices in addition to singular values) there is a choice of two routines in @sc{lapack}. The default routine used by Octave is @code{gesvd}. The alternative is @code{gesdd} which is 5X faster, but may use more memory -and may be inaccurate for some input matrices. See the documentation for -@code{svd_driver} for more information on choosing a driver. +and may be inaccurate for some input matrices. There is a third routine +@code{gejsv}, suitable for better accuracy at extreme scale. See the +documentation for @code{svd_driver} for more information on choosing a driver. @seealso{svd_driver, svds, eig, lu, chol, hess, qr, qz} @end deftypefn */) { @@ -377,6 +381,15 @@ %!assert <*55710> (1 / svd (-0), Inf) +%!test +%! old_driver = svd_driver ("gejsv"); +%! s0 = [1e-20; 1e-10; 1]; # only gejsv can pass +%! q = sqrt (0.5); +%! a = s0 .* [q, 0, -q; -0.5, q, -0.5; 0.5, q, 0.5]; +%! s1 = svd (a); +%! svd_driver (old_driver); +%! assert (sort (s1), s0, -10 * eps); + %!error svd () %!error svd ([1, 2; 4, 5], 2, 3) */ @@ -388,17 +401,17 @@ @deftypefnx {} {} svd_driver (@var{new_val}, "local") Query or set the underlying @sc{lapack} driver used by @code{svd}. -Currently recognized values are @qcode{"gesdd"} and @qcode{"gesvd"}. -The default is @qcode{"gesvd"}. +Currently recognized values are @qcode{"gesdd"}, @qcode{"gesvd"}, and +@qcode{"gejsv"}. The default is @qcode{"gesvd"}. 
When called from inside a function with the @qcode{"local"} option, the variable is changed locally for the function and any subroutines it calls. The original variable value is restored when exiting the function. -Algorithm Notes: The @sc{lapack} library provides two routines for calculating -the full singular value decomposition (left and right singular matrices as -well as singular values). When calculating just the singular values the -following discussion is not relevant. +Algorithm Notes: The @sc{lapack} library routines @code{gesvd} and @code{gesdd} +are different only when calculating the full singular value decomposition (left +and right singular matrices as well as singular values). When calculating just +the singular values the following discussion is not relevant. The newer @code{gesdd} routine is based on a Divide-and-Conquer algorithm that is 5X faster than the alternative @code{gesvd}, which is based on QR @@ -406,6 +419,12 @@ For an @nospell{MxN} input matrix the memory usage is of order O(min(M,N) ^ 2), whereas the alternative is of order O(max(M,N)). +The routine @code{gejsv} uses a preconditioned Jacobi SVD algorithm. Unlike +@code{gesvd} and @code{gesdd}, in @code{gejsv}, there is no bidiagonalization +step that could contaminate accuracy in some extreme case. Also, @code{gejsv} +is shown to be optimally accurate in some sense. However, the speed is slower +(single threaded at its core) and uses more memory (O(min(M,N) ^ 2 + M + N)). + Beyond speed and memory issues, there have been instances where some input matrices were not accurately decomposed by @code{gesdd}. See currently active bug @url{https://savannah.gnu.org/bugs/?55564}. 
Until these accuracy issues @@ -415,7 +434,7 @@ @seealso{svd} @end deftypefn */) { - static const char *driver_names[] = { "gesvd", "gesdd", nullptr }; + static const char *driver_names[] = { "gesvd", "gesdd", "gejsv", nullptr }; return SET_INTERNAL_VARIABLE_CHOICES (svd_driver, driver_names); } @@ -427,8 +446,15 @@ %! [U1, S1, V1] = svd (A); %! svd_driver ("gesdd"); %! [U2, S2, V2] = svd (A); +%! svd_driver ("gejsv"); +%! [U3, S3, V3] = svd (A); +%! assert (svd_driver (), "gejsv"); %! svd_driver (old_driver); -%! assert (U1, U2, 5*eps); -%! assert (S1, S2, 5*eps); -%! assert (V1, V2, 5*eps); +%! assert (U1, U2, 6*eps); +%! assert (S1, S2, 6*eps); +%! assert (V1, V2, 6*eps); +%! z = U1(1,:) ./ U3(1,:); +%! assert (U1, U3 .* z, 100*eps); +%! assert (S1, S3, 6*eps); +%! assert (V1, V3 .* z, 100*eps); */
--- a/liboctave/numeric/lo-lapack-proto.h Sun Mar 07 10:55:54 2021 +0100 +++ b/liboctave/numeric/lo-lapack-proto.h Wed Mar 20 00:58:11 2019 +0800 @@ -266,6 +266,124 @@ F77_DBLE_CMPLX*, F77_DBLE_CMPLX*, const F77_INT&, F77_INT&); + // GELQF + + F77_RET_T + F77_FUNC (cgelqf, CGELQF) (const F77_INT&, const F77_INT&, + F77_CMPLX*, const F77_INT&, + F77_CMPLX*, F77_CMPLX*, + const F77_INT&, F77_INT&); + + F77_RET_T + F77_FUNC (dgelqf, DGELQF) (const F77_INT&, const F77_INT&, + F77_DBLE*, const F77_INT&, + F77_DBLE*, F77_DBLE*, + const F77_INT&, F77_INT&); + + F77_RET_T + F77_FUNC (sgelqf, SGELQF) (const F77_INT&, const F77_INT&, + F77_REAL*, const F77_INT&, + F77_REAL*, F77_REAL*, + const F77_INT&, F77_INT&); + + F77_RET_T + F77_FUNC (zgelqf, ZGELQF) (const F77_INT&, const F77_INT&, + F77_DBLE_CMPLX*, const F77_INT&, + F77_DBLE_CMPLX*, F77_DBLE_CMPLX*, + const F77_INT&, F77_INT&); + + // ORMLQ + + F77_RET_T + F77_FUNC (cormlq, CORMLQ) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const F77_INT&, const F77_INT&, const F77_INT&, + F77_CMPLX*, const F77_INT&, + F77_CMPLX*, F77_CMPLX*, + const F77_INT&, F77_CMPLX*, + const F77_INT&, F77_INT& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + F77_RET_T + F77_FUNC (dormlq, DORMLQ) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const F77_INT&, const F77_INT&, const F77_INT&, + F77_DBLE*, const F77_INT&, + F77_DBLE*, F77_DBLE*, + const F77_INT&, F77_DBLE*, + const F77_INT&, F77_INT& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + F77_RET_T + F77_FUNC (sormlq, SORMLQ) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const F77_INT&, const F77_INT&, const F77_INT&, + F77_REAL*, const F77_INT&, + F77_REAL*, F77_REAL*, + const F77_INT&, F77_REAL*, + const F77_INT&, F77_INT& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + F77_RET_T + F77_FUNC (zormlq, ZORMLQ) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const F77_INT&, const F77_INT&, const F77_INT&, + F77_DBLE_CMPLX*, const F77_INT&, + 
F77_DBLE_CMPLX*, F77_DBLE_CMPLX*, + const F77_INT&, F77_DBLE_CMPLX*, + const F77_INT&, F77_INT& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + // ORMQR + + F77_RET_T + F77_FUNC (cormqr, CORMQR) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const F77_INT&, const F77_INT&, const F77_INT&, + F77_CMPLX*, const F77_INT&, + F77_CMPLX*, F77_CMPLX*, + const F77_INT&, F77_CMPLX*, + const F77_INT&, F77_INT& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + F77_RET_T + F77_FUNC (dormqr, DORMQR) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const F77_INT&, const F77_INT&, const F77_INT&, + F77_DBLE*, const F77_INT&, + F77_DBLE*, F77_DBLE*, + const F77_INT&, F77_DBLE*, + const F77_INT&, F77_INT& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + F77_RET_T + F77_FUNC (sormqr, SORMQR) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const F77_INT&, const F77_INT&, const F77_INT&, + F77_REAL*, const F77_INT&, + F77_REAL*, F77_REAL*, + const F77_INT&, F77_REAL*, + const F77_INT&, F77_INT& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + F77_RET_T + F77_FUNC (zormqr, ZORMQR) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const F77_INT&, const F77_INT&, const F77_INT&, + F77_DBLE_CMPLX*, const F77_INT&, + F77_DBLE_CMPLX*, F77_DBLE_CMPLX*, + const F77_INT&, F77_DBLE_CMPLX*, + const F77_INT&, F77_INT& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + // GESDD F77_RET_T @@ -354,6 +472,90 @@ F77_CHAR_ARG_LEN_DECL F77_CHAR_ARG_LEN_DECL); + // GEJSV + + F77_RET_T + F77_FUNC (cgejsv, CGEJSV) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const F77_INT&, const F77_INT&, + F77_CMPLX*, const F77_INT&, F77_REAL*, + F77_CMPLX*, const F77_INT&, + F77_CMPLX*, const F77_INT&, + F77_CMPLX*, const F77_INT&, + F77_REAL*, const F77_INT&, + F77_INT *, F77_INT& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + 
F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + F77_RET_T + F77_FUNC (dgejsv, DGEJSV) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const F77_INT&, const F77_INT&, + F77_DBLE*, const F77_INT&, F77_DBLE*, + F77_DBLE*, const F77_INT&, + F77_DBLE*, const F77_INT&, + F77_DBLE*, const F77_INT&, + F77_INT *, F77_INT& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + F77_RET_T + F77_FUNC (sgejsv, SGEJSV) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const F77_INT&, const F77_INT&, + F77_REAL*, const F77_INT&, F77_REAL*, + F77_REAL*, const F77_INT&, + F77_REAL*, const F77_INT&, + F77_REAL*, const F77_INT&, + F77_INT *, F77_INT& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + F77_RET_T + F77_FUNC (zgejsv, ZGEJSV) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const F77_INT&, const F77_INT&, + F77_DBLE_CMPLX*, const F77_INT&, F77_DBLE*, + F77_DBLE_CMPLX*, const F77_INT&, + F77_DBLE_CMPLX*, const F77_INT&, + F77_DBLE_CMPLX*, const F77_INT&, + F77_DBLE*, const F77_INT&, + F77_INT *, F77_INT& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + // GEESX typedef F77_INT (*double_selector) (const F77_DBLE&, const F77_DBLE&);
--- a/liboctave/numeric/svd.cc Sun Mar 07 10:55:54 2021 +0100 +++ b/liboctave/numeric/svd.cc Wed Mar 20 00:58:11 2019 +0800 @@ -41,6 +41,267 @@ #include "lo-lapack-proto.h" #include "svd.h" +// class to compute optimal work space size (lwork) for DGEJSV and SGEJSV +template<typename T> +class +gejsv_lwork +{ +public: + gejsv_lwork () = delete; + + // Unfortunately, dgejsv and sgejsv do not provide estimation of 'lwork'. + // Thus, we have to estimate it according to corresponding LAPACK + // documentation and related source codes (e.g. cgejsv). + // In LAPACKE (C interface to LAPACK), the memory handling code in + // LAPACKE_dgejsv() (lapacke_dgejsv.c, last visit 2019-02-17) uses + // the minimum required working space. In contrast, here the optimal + // working space size is computed, at the cost of much longer code. + + static F77_INT optimal (char& joba, char& jobu, char& jobv, + F77_INT m, F77_INT n); + +private: + typedef typename T::element_type P; + + // functions could be called from GEJSV + static F77_INT geqp3_lwork (F77_INT m, F77_INT n, + P *a, F77_INT lda, + F77_INT* jpvt, P *tau, P *work, + F77_INT lwork, F77_INT& info); + + static F77_INT geqrf_lwork (F77_INT m, F77_INT n, + P *a, F77_INT lda, + P *tau, P *work, + F77_INT lwork, F77_INT& info); + + static F77_INT gelqf_lwork (F77_INT m, F77_INT n, + P *a, F77_INT lda, + P *tau, P *work, + F77_INT lwork, F77_INT& info); + + static F77_INT ormlq_lwork (char& side, char& trans, + F77_INT m, F77_INT n, F77_INT k, + P *a, F77_INT lda, + P *tau, P *c, F77_INT ldc, + P *work, F77_INT lwork, F77_INT& info); + + static F77_INT ormqr_lwork (char& side, char& trans, + F77_INT m, F77_INT n, F77_INT k, + P *a, F77_INT lda, + P *tau, P *c, F77_INT ldc, + P *work, F77_INT lwork, F77_INT& info); +}; + +#define GEJSV_REAL_QP3_LWORK(f, F) \ + F77_XFCN (f, F, (m, n, a, lda, jpvt, \ + tau, work, lwork, info)) + +#define GEJSV_REAL_QR_LWORK(f, F) \ + F77_XFCN (f, F, (m, n, a, lda, \ + tau, work, lwork, info)) + +#define 
GEJSV_REAL_ORM_LWORK(f, F) \ + F77_XFCN (f, F, (F77_CONST_CHAR_ARG2 (&side, 1), \ + F77_CONST_CHAR_ARG2 (&trans, 1), \ + m, n, k, a, lda, tau, \ + c, ldc, work, lwork, info \ + F77_CHAR_ARG_LEN (1) \ + F77_CHAR_ARG_LEN (1))) + +// For Matrix +template<> +F77_INT +gejsv_lwork<Matrix>::geqp3_lwork (F77_INT m, F77_INT n, + P *a, F77_INT lda, + F77_INT* jpvt, P *tau, P *work, + F77_INT lwork, F77_INT& info) +{ + GEJSV_REAL_QP3_LWORK (dgeqp3, DGEQP3); + return static_cast<F77_INT> (work[0]); +} + +template<> +F77_INT +gejsv_lwork<Matrix>::geqrf_lwork (F77_INT m, F77_INT n, + P *a, F77_INT lda, + P *tau, P *work, + F77_INT lwork, F77_INT& info) +{ + GEJSV_REAL_QR_LWORK (dgeqrf, DGEQRF); + return static_cast<F77_INT> (work[0]); +} + +template<> +F77_INT +gejsv_lwork<Matrix>::gelqf_lwork (F77_INT m, F77_INT n, + P *a, F77_INT lda, + P *tau, P *work, + F77_INT lwork, F77_INT& info) +{ + GEJSV_REAL_QR_LWORK (dgelqf, DGELQF); + return static_cast<F77_INT> (work[0]); +} + +template<> +F77_INT +gejsv_lwork<Matrix>::ormlq_lwork (char& side, char& trans, + F77_INT m, F77_INT n, F77_INT k, + P *a, F77_INT lda, + P *tau, P *c, F77_INT ldc, + P *work, F77_INT lwork, F77_INT& info) +{ + GEJSV_REAL_ORM_LWORK (dormlq, DORMLQ); + return static_cast<F77_INT> (work[0]); +} + +template<> +F77_INT +gejsv_lwork<Matrix>::ormqr_lwork (char& side, char& trans, + F77_INT m, F77_INT n, F77_INT k, + P *a, F77_INT lda, + P *tau, P *c, F77_INT ldc, + P *work, F77_INT lwork, F77_INT& info) +{ + GEJSV_REAL_ORM_LWORK (dormqr, DORMQR); + return static_cast<F77_INT> (work[0]); +} + +// For FloatMatrix +template<> +F77_INT +gejsv_lwork<FloatMatrix>::geqp3_lwork (F77_INT m, F77_INT n, + P *a, F77_INT lda, + F77_INT* jpvt, P *tau, P *work, + F77_INT lwork, F77_INT& info) +{ + GEJSV_REAL_QP3_LWORK (sgeqp3, SGEQP3); + return static_cast<F77_INT> (work[0]); +} + +template<> +F77_INT +gejsv_lwork<FloatMatrix>::geqrf_lwork (F77_INT m, F77_INT n, + P *a, F77_INT lda, + P *tau, P *work, + F77_INT lwork, F77_INT& 
info) +{ + GEJSV_REAL_QR_LWORK (sgeqrf, SGEQRF); + return static_cast<F77_INT> (work[0]); +} + +template<> +F77_INT +gejsv_lwork<FloatMatrix>::gelqf_lwork (F77_INT m, F77_INT n, + P *a, F77_INT lda, + P *tau, P *work, + F77_INT lwork, F77_INT& info) +{ + GEJSV_REAL_QR_LWORK (sgelqf, SGELQF); + return static_cast<F77_INT> (work[0]); +} + +template<> +F77_INT +gejsv_lwork<FloatMatrix>::ormlq_lwork (char& side, char& trans, + F77_INT m, F77_INT n, F77_INT k, + P *a, F77_INT lda, + P *tau, P *c, F77_INT ldc, + P *work, F77_INT lwork, F77_INT& info) +{ + GEJSV_REAL_ORM_LWORK (sormlq, SORMLQ); + return static_cast<F77_INT> (work[0]); +} + +template<> +F77_INT +gejsv_lwork<FloatMatrix>::ormqr_lwork (char& side, char& trans, + F77_INT m, F77_INT n, F77_INT k, + P *a, F77_INT lda, + P *tau, P *c, F77_INT ldc, + P *work, F77_INT lwork, F77_INT& info) +{ + GEJSV_REAL_ORM_LWORK (sormqr, SORMQR); + return static_cast<F77_INT> (work[0]); +} + +#undef GEJSV_REAL_QP3_LWORK +#undef GEJSV_REAL_QR_LWORK +#undef GEJSV_REAL_ORM_LWORK + +template<typename T> +F77_INT +gejsv_lwork<T>::optimal (char& joba, char& jobu, char& jobv, + F77_INT m, F77_INT n) +{ + F77_INT lwork = -1; + std::vector<P> work (2); // dummy work space + + // variables that mimic running environment of gejsv + F77_INT lda = std::max (m, 1); + F77_INT ierr = 0; + char side = 'L'; + char trans = 'N'; + std::vector<P> mat_a (1); + P *a = mat_a.data (); // dummy input matrix + std::vector<F77_INT> vec_jpvt = {0}; + P *tau = work.data (); + P *u = work.data (); + P *v = work.data (); + + bool need_lsvec = jobu == 'U' || jobu == 'F'; + bool need_rsvec = jobv == 'V' || jobv == 'J'; + + F77_INT lw_pocon = 3 * n; // for [s,d]pocon + F77_INT lw_geqp3 = geqp3_lwork (m, n, a, lda, vec_jpvt.data (), + tau, work.data (), -1, ierr); + F77_INT lw_geqrf = geqrf_lwork (m, n, a, lda, + tau, work.data (), -1, ierr); + + if (! (need_lsvec || need_rsvec) ) + { + // only SIGMA is needed + if (! 
(joba == 'E' || joba == 'G') ) + lwork = std::max<int> ({2*m + n, n + lw_geqp3, n + lw_geqrf, 7}); + else + lwork = std::max<int> ({2*m + n, n + lw_geqp3, n + lw_geqrf, + n + n*n + lw_pocon, 7}); + } + else if (need_rsvec && ! need_lsvec) + { + // SIGMA and the right singular vectors are needed + F77_INT lw_gelqf = gelqf_lwork (n, n, a, lda, + tau, work.data (), -1, ierr); + trans = 'T'; + F77_INT lw_ormlq = ormlq_lwork (side, trans, n, n, n, a, lda, + tau, v, n, work.data (), -1, ierr); + lwork = std::max<int> ({2*m + n, n + lw_geqp3, n + lw_pocon, + n + lw_gelqf, 2*n + lw_geqrf, n + lw_ormlq}); + } + else if (need_lsvec && ! need_rsvec) + { + // SIGMA and the left singular vectors are needed + F77_INT n1 = (jobu == 'U') ? n : m; // size of U is m x n1 + F77_INT lw_ormqr = ormqr_lwork (side, trans, m, n1, n, a, lda, + tau, u, m, work.data (), -1, ierr); + lwork = std::max<int> ({2*m + n, n + lw_geqp3, n + lw_pocon, + 2*n + lw_geqrf, n + lw_ormqr}); + } + else // full SVD is needed + { + if (jobv == 'V') + lwork = std::max (2*m + n, 6*n + 2*n*n); + else if (jobv == 'J') + lwork = std::max<int> ({2*m + n, 4*n + n*n, 2*n + n*n + 6}); + + F77_INT n1 = (jobu == 'U') ? 
n : m; // size of U is m x n1 + F77_INT lw_ormqr = ormqr_lwork (side, trans, m, n1, n, a, lda, + tau, u, m, work.data (), -1, ierr); + lwork = std::max (lwork, n + lw_ormqr); + } + + return lwork; +} + namespace octave { namespace math @@ -274,6 +535,138 @@ #undef GESDD_REAL_STEP #undef GESDD_COMPLEX_STEP + // GEJSV specializations + +#define GEJSV_REAL_STEP(f, F) \ + F77_XFCN (f, F, (F77_CONST_CHAR_ARG2 (&joba, 1), \ + F77_CONST_CHAR_ARG2 (&jobu, 1), \ + F77_CONST_CHAR_ARG2 (&jobv, 1), \ + F77_CONST_CHAR_ARG2 (&jobr, 1), \ + F77_CONST_CHAR_ARG2 (&jobt, 1), \ + F77_CONST_CHAR_ARG2 (&jobp, 1), \ + m, n, tmp_data, m1, s_vec, u, m1, v, nrow_v1, \ + work.data (), lwork, iwork.data (), info \ + F77_CHAR_ARG_LEN (1) \ + F77_CHAR_ARG_LEN (1) \ + F77_CHAR_ARG_LEN (1) \ + F77_CHAR_ARG_LEN (1) \ + F77_CHAR_ARG_LEN (1) \ + F77_CHAR_ARG_LEN (1))) + +#define GEJSV_COMPLEX_STEP(f, F, CMPLX_ARG) \ + F77_XFCN (f, F, (F77_CONST_CHAR_ARG2 (&joba, 1), \ + F77_CONST_CHAR_ARG2 (&jobu, 1), \ + F77_CONST_CHAR_ARG2 (&jobv, 1), \ + F77_CONST_CHAR_ARG2 (&jobr, 1), \ + F77_CONST_CHAR_ARG2 (&jobt, 1), \ + F77_CONST_CHAR_ARG2 (&jobp, 1), \ + m, n, CMPLX_ARG (tmp_data), m1, \ + s_vec, CMPLX_ARG (u), m1, \ + CMPLX_ARG (v), nrow_v1, \ + CMPLX_ARG (work.data ()), lwork, \ + rwork.data (), lrwork, iwork.data (), info \ + F77_CHAR_ARG_LEN (1) \ + F77_CHAR_ARG_LEN (1) \ + F77_CHAR_ARG_LEN (1) \ + F77_CHAR_ARG_LEN (1) \ + F77_CHAR_ARG_LEN (1) \ + F77_CHAR_ARG_LEN (1))) + + // DGEJSV + template<> + void + svd<Matrix>::gejsv (char& joba, char& jobu, char& jobv, + char& jobr, char& jobt, char& jobp, + F77_INT m, F77_INT n, + P *tmp_data, F77_INT m1, DM_P *s_vec, P *u, + P *v, F77_INT nrow_v1, std::vector<P>& work, + F77_INT& lwork, std::vector<F77_INT>& iwork, + F77_INT& info) + { + lwork = gejsv_lwork<Matrix>::optimal(joba, jobu, jobv, m, n); + work.reserve (lwork); + + GEJSV_REAL_STEP (dgejsv, DGEJSV); + } + + // SGEJSV + template<> + void + svd<FloatMatrix>::gejsv (char& joba, char& jobu, char& jobv, 
+ char& jobr, char& jobt, char& jobp, + F77_INT m, F77_INT n, + P *tmp_data, F77_INT m1, DM_P *s_vec, P *u, + P *v, F77_INT nrow_v1, std::vector<P>& work, + F77_INT& lwork, std::vector<F77_INT>& iwork, + F77_INT& info) + { + lwork = gejsv_lwork<FloatMatrix>::optimal(joba, jobu, jobv, m, n); + work.reserve (lwork); + + GEJSV_REAL_STEP (sgejsv, SGEJSV); + } + + // ZGEJSV + template<> + void + svd<ComplexMatrix>::gejsv (char& joba, char& jobu, char& jobv, + char& jobr, char& jobt, char& jobp, + F77_INT m, F77_INT n, + P *tmp_data, F77_INT m1, DM_P *s_vec, P *u, + P *v, F77_INT nrow_v1, std::vector<P>& work, + F77_INT& lwork, std::vector<F77_INT>& iwork, + F77_INT& info) + { + F77_INT lrwork = -1; // work space size query + std::vector<double> rwork (1); + work.reserve (2); + + GEJSV_COMPLEX_STEP (zgejsv, ZGEJSV, F77_DBLE_CMPLX_ARG); + + lwork = static_cast<F77_INT> (work[0].real ()); + work.reserve (lwork); + + lrwork = static_cast<F77_INT> (rwork[0]); + rwork.reserve (lrwork); + + F77_INT liwork = static_cast<F77_INT> (iwork[0]); + iwork.reserve (liwork); + + GEJSV_COMPLEX_STEP (zgejsv, ZGEJSV, F77_DBLE_CMPLX_ARG); + } + + // CGEJSV + template<> + void + svd<FloatComplexMatrix>::gejsv (char& joba, char& jobu, char& jobv, + char& jobr, char& jobt, char& jobp, + F77_INT m, F77_INT n, P *tmp_data, + F77_INT m1, DM_P *s_vec, P *u, P *v, + F77_INT nrow_v1, std::vector<P>& work, + F77_INT& lwork, + std::vector<F77_INT>& iwork, F77_INT& info) + { + F77_INT lrwork = -1; // work space size query + std::vector<float> rwork (1); + work.reserve (2); + + GEJSV_COMPLEX_STEP (cgejsv, CGEJSV, F77_CMPLX_ARG); + + lwork = static_cast<F77_INT> (work[0].real ()); + work.reserve (lwork); + + lrwork = static_cast<F77_INT> (rwork[0]); + rwork.reserve (lrwork); + + F77_INT liwork = static_cast<F77_INT> (iwork[0]); + iwork.reserve (liwork); + + GEJSV_COMPLEX_STEP (cgejsv, CGEJSV, F77_CMPLX_ARG); + } + +#undef GEJSV_REAL_STEP +#undef GEJSV_COMPLEX_STEP + template<typename T> svd<T>::svd 
(const T& a, svd::Type type, svd::Driver driver) @@ -301,7 +694,7 @@ case svd::Type::economy: left_sm = T (m, 0, 0); sigma = DM_T (0, 0); - right_sm = T (0, n, 0); + right_sm = T (n, 0, 0); break; case svd::Type::sigma_only: @@ -358,7 +751,12 @@ DM_P *s_vec = sigma.fortran_vec (); if (! (jobv == 'N' || jobv == 'O')) - right_sm.resize (nrow_vt, n); + { + if (m_driver == svd::Driver::GEJSV) + right_sm.resize (n, nrow_vt); + else + right_sm.resize (nrow_vt, n); + } P *vt = right_sm.fortran_vec (); @@ -368,8 +766,9 @@ std::vector<P> work (1); - F77_INT m1 = std::max (m, static_cast<F77_INT> (1)); - F77_INT nrow_vt1 = std::max (nrow_vt, static_cast<F77_INT> (1)); + const F77_INT f77_int_one = static_cast<F77_INT> (1); + F77_INT m1 = std::max (m, f77_int_one); + F77_INT nrow_vt1 = std::max (nrow_vt, f77_int_one); if (m_driver == svd::Driver::GESVD) gesvd (jobu, jobv, m, n, tmp_data, m1, s_vec, u, vt, nrow_vt1, @@ -384,6 +783,71 @@ gesdd (jobz, m, n, tmp_data, m1, s_vec, u, vt, nrow_vt1, work, lwork, iwork.data (), info); } + else if (m_driver == svd::Driver::GEJSV) + { + bool transposed = false; + if (n > m) + { + // GEJSV only accepts m >= n, thus we need to transpose here + transposed = true; + + std::swap (m, n); + m1 = std::max (m, f77_int_one); + nrow_vt1 = std::max (n, f77_int_one); // we have m > n + if (m_type == svd::Type::sigma_only) + nrow_vt1 = 1; + std::swap (jobu, jobv); + + atmp = atmp.hermitian (); + tmp_data = atmp.fortran_vec (); + + // Swap pointers of U and V. + u = right_sm.fortran_vec (); + vt = left_sm.fortran_vec (); + } + + // translate jobu and jobv from gesvd to gejsv. 
+ assert ('A' <= jobu && jobu <= 'S' && 'A' <= jobv && jobv <= 'S'); + char job_svd2jsv[1 + 'S' - 'A'][2] = {0}; + job_svd2jsv['A' - 'A'][0] = 'F'; + job_svd2jsv['A' - 'A'][1] = 'J'; + job_svd2jsv['S' - 'A'][0] = 'U'; + job_svd2jsv['S' - 'A'][1] = 'V'; + job_svd2jsv['O' - 'A'][0] = 'W'; + job_svd2jsv['O' - 'A'][1] = 'W'; + job_svd2jsv['N' - 'A'][0] = 'N'; + job_svd2jsv['N' - 'A'][1] = 'N'; + jobu = job_svd2jsv[jobu - 'A'][0]; + jobv = job_svd2jsv[jobv - 'A'][1]; + + char joba = 'F'; // 'F': most conservative + char jobr = 'R'; // 'R' is recommended. + char jobt = 'N'; // or 'T', but that requires U and V appear together + char jobp = 'N'; // use 'P' if denormal is poorly implemented. + + std::vector<F77_INT> iwork (std::max (m + 3*n, 1)); + + gejsv (joba, jobu, jobv, jobr, jobt, jobp, m, n, tmp_data, m1, + s_vec, u, vt, nrow_vt1, work, lwork, iwork, info); + + if (iwork[2] == 1) + (*current_liboctave_warning_with_id_handler) + ("Octave:convergence", "svd: (driver: GEJSV) " + "Denormal occured, possible loss of accuracy."); + + if (info < 0) + (*current_liboctave_error_handler) + ("svd: (driver: GEJSV) Illegal argument at #%d", + static_cast<int> (-info)); + else if (info > 0) + (*current_liboctave_warning_with_id_handler) + ("Octave:convergence", "svd: (driver: GEJSV) " + "Fail to converge within max sweeps, " + "possible inaccurate result."); + + if (transposed) // put things that need to transpose back here + std::swap (m, n); + } else (*current_liboctave_error_handler) ("svd: unknown driver"); @@ -394,7 +858,8 @@ sigma.dgxelem (i) = DM_P (0); } - if (! (jobv == 'N' || jobv == 'O')) + // GESVD and GESDD return VT instead of V, GEJSV return V. + if (! (jobv == 'N' || jobv == 'O') && (m_driver != svd::Driver::GEJSV)) right_sm = right_sm.hermitian (); }
--- a/liboctave/numeric/svd.h Sun Mar 07 10:55:54 2021 +0100 +++ b/liboctave/numeric/svd.h Wed Mar 20 00:58:11 2019 +0800 @@ -53,7 +53,8 @@ enum class Driver { GESVD, - GESDD + GESDD, + GEJSV }; svd (void) @@ -113,6 +114,14 @@ P *vt, octave_f77_int_type nrow_vt1, std::vector<P>& work, octave_f77_int_type& lwork, octave_f77_int_type *iwork, octave_f77_int_type& info); + + void gejsv (char& joba, char& jobu, char& jobv, char& jobr, char& jobt, + char& jobp, octave_f77_int_type m, octave_f77_int_type n, + P *tmp_data, octave_f77_int_type m1, DM_P *s_vec, P *u, + P *v, octave_f77_int_type nrow_v1, std::vector<P>& work, + octave_f77_int_type& lwork, + std::vector<octave_f77_int_type>& iwork, + octave_f77_int_type& info); }; } }