diff liboctave/CmplxSVD.cc @ 10601:3ce0c530a9c9

implement svd_driver
author Jaroslav Hajek <highegg@gmail.com>
date Mon, 03 May 2010 13:21:35 +0200
parents 12884915a8e4
children 8a5e980da6aa
line wrap: on
line diff
--- a/liboctave/CmplxSVD.cc	Sun May 02 22:05:41 2010 -0700
+++ b/liboctave/CmplxSVD.cc	Mon May 03 13:21:35 2010 +0200
@@ -28,6 +28,7 @@
 #include "CmplxSVD.h"
 #include "f77-fcn.h"
 #include "lo-error.h"
+#include "oct-locbuf.h"
 
 extern "C"
 {
@@ -40,6 +41,14 @@
                              double*, octave_idx_type&
                              F77_CHAR_ARG_LEN_DECL
                              F77_CHAR_ARG_LEN_DECL);
+
+  F77_RET_T
+  F77_FUNC (zgesdd, ZGESDD) (F77_CONST_CHAR_ARG_DECL,
+                             const octave_idx_type&, const octave_idx_type&, Complex*,
+                             const octave_idx_type&, double*, Complex*, const octave_idx_type&,
+                             Complex*, const octave_idx_type&, Complex*, const octave_idx_type&,
+                             double*, octave_idx_type *, octave_idx_type&
+                             F77_CHAR_ARG_LEN_DECL);
 }
 
 ComplexMatrix
@@ -69,7 +78,7 @@
 }
 
 octave_idx_type
-ComplexSVD::init (const ComplexMatrix& a, SVD::type svd_type)
+ComplexSVD::init (const ComplexMatrix& a, SVD::type svd_type, SVD::driver svd_driver)
 {
   octave_idx_type info;
 
@@ -144,24 +153,50 @@
   octave_idx_type one = 1;
   octave_idx_type m1 = std::max (m, one), nrow_vt1 = std::max (nrow_vt, one);
 
-  F77_XFCN (zgesvd, ZGESVD, (F77_CONST_CHAR_ARG2 (&jobu, 1),
-                             F77_CONST_CHAR_ARG2 (&jobv, 1),
-                             m, n, tmp_data, m1, s_vec, u, m1, vt,
-                             nrow_vt1, work.fortran_vec (), lwork,
-                             rwork.fortran_vec (), info
-                             F77_CHAR_ARG_LEN (1)
-                             F77_CHAR_ARG_LEN (1)));
+  if (svd_driver == SVD::GESVD)
+    {
+      F77_XFCN (zgesvd, ZGESVD, (F77_CONST_CHAR_ARG2 (&jobu, 1),
+                                 F77_CONST_CHAR_ARG2 (&jobv, 1),
+                                 m, n, tmp_data, m1, s_vec, u, m1, vt,
+                                 nrow_vt1, work.fortran_vec (), lwork,
+                                 rwork.fortran_vec (), info
+                                 F77_CHAR_ARG_LEN (1)
+                                 F77_CHAR_ARG_LEN (1)));
+
+      lwork = static_cast<octave_idx_type> (work(0).real ());
+      work.resize (lwork, 1);
 
-  lwork = static_cast<octave_idx_type> (work(0).real ());
-  work.resize (lwork, 1);
+      F77_XFCN (zgesvd, ZGESVD, (F77_CONST_CHAR_ARG2 (&jobu, 1),
+                                 F77_CONST_CHAR_ARG2 (&jobv, 1),
+                                 m, n, tmp_data, m1, s_vec, u, m1, vt,
+                                 nrow_vt1, work.fortran_vec (), lwork,
+                                 rwork.fortran_vec (), info
+                                 F77_CHAR_ARG_LEN (1)
+                                 F77_CHAR_ARG_LEN (1)));
+    }
+  else if (svd_driver == SVD::GESDD)
+    {
+      assert (jobu == jobv);
+      char jobz = jobu;
+      OCTAVE_LOCAL_BUFFER (octave_idx_type, iwork, 8*min_mn);
 
-  F77_XFCN (zgesvd, ZGESVD, (F77_CONST_CHAR_ARG2 (&jobu, 1),
-                             F77_CONST_CHAR_ARG2 (&jobv, 1),
-                             m, n, tmp_data, m1, s_vec, u, m1, vt,
-                             nrow_vt1, work.fortran_vec (), lwork,
-                             rwork.fortran_vec (), info
-                             F77_CHAR_ARG_LEN (1)
-                             F77_CHAR_ARG_LEN (1)));
+      F77_XFCN (zgesdd, ZGESDD, (F77_CONST_CHAR_ARG2 (&jobz, 1),
+                                 m, n, tmp_data, m1, s_vec, u, m1, vt,
+                                 nrow_vt1, work.fortran_vec (), lwork,
+                                 rwork.fortran_vec (), iwork, info
+                                 F77_CHAR_ARG_LEN (1)));
+
+      lwork = static_cast<octave_idx_type> (work(0).real ());
+      work.resize (lwork, 1);
+
+      F77_XFCN (zgesdd, ZGESDD, (F77_CONST_CHAR_ARG2 (&jobz, 1),
+                                 m, n, tmp_data, m1, s_vec, u, m1, vt,
+                                 nrow_vt1, work.fortran_vec (), lwork,
+                                 rwork.fortran_vec (), iwork, info
+                                 F77_CHAR_ARG_LEN (1)));
+    }
+  else
+    assert (0); // impossible
 
   if (! (jobv == 'N' || jobv == 'O'))
     right_sm = right_sm.hermitian ();