diff liboctave/numeric/CmplxCHOL.cc @ 20497:5ce959c55cc0

Propagate 'lower' in chol(a, 'lower') to underlying library function. * chol.cc (chol): Send 'L' parameter correctly when chol is called with 'lower'. * floatCHOL.cc (init): Propagate 'lower' to underlying library function. * floatCHOL.h: Modify the prototype of methods. * fMatrix.cc (inverse): Invoke chol with additional parameter. * dbleCHOL.cc (init): Propagate 'lower' to underlying library function. * dbleCHOL.h: Modify the prototype of methods. * dMatrix.cc (inverse): Invoke chol with additional parameter. * CmplxCHOL.cc (init): Propagate 'lower' to underlying library function. * CmplxCHOL.h: Modify the prototype of methods. * CMatrix.cc (inverse): Invoke chol with additional parameter.
author PrasannaKumar Muralidharan <prasannatsmkumar@gmail.com>
date Sun, 24 Aug 2014 19:35:06 +0530
parents a9574e3c6e9e
children dcfbf4c1c3c8
line wrap: on
line diff
--- a/liboctave/numeric/CmplxCHOL.cc	Thu Aug 20 14:37:57 2015 -0400
+++ b/liboctave/numeric/CmplxCHOL.cc	Sun Aug 24 19:35:06 2014 +0530
@@ -86,7 +86,7 @@
 }
 
 octave_idx_type
-ComplexCHOL::init (const ComplexMatrix& a, bool calc_cond)
+ComplexCHOL::init (const ComplexMatrix& a, bool upper, bool calc_cond)
 {
   octave_idx_type a_nr = a.rows ();
   octave_idx_type a_nc = a.cols ();
@@ -101,13 +101,28 @@
   octave_idx_type n = a_nc;
   octave_idx_type info;
 
+  is_upper = upper;
+
   chol_mat.clear (n, n);
-  for (octave_idx_type j = 0; j < n; j++)
+  if (is_upper)
     {
-      for (octave_idx_type i = 0; i <= j; i++)
-        chol_mat.xelem (i, j) = a(i, j);
-      for (octave_idx_type i = j+1; i < n; i++)
-        chol_mat.xelem (i, j) = 0.0;
+      for (octave_idx_type j = 0; j < n; j++)
+        {
+          for (octave_idx_type i = 0; i <= j; i++)
+            chol_mat.xelem (i, j) = a (i, j);
+          for (octave_idx_type i = j + 1; i < n; i++)
+            chol_mat.xelem (i, j) = 0.0;
+        }
+     }
+  else
+    {
+      for (octave_idx_type j = 0; j < n; j++)
+        {
+          for (octave_idx_type i = 0; i < j; i++)
+            chol_mat.xelem (i, j) = 0.0;
+       	  for (octave_idx_type i = j; i < n; i++)
+            chol_mat.xelem (i, j) = a (i, j);
+        }
     }
   Complex *h = chol_mat.fortran_vec ();
 
@@ -116,8 +131,18 @@
   if (calc_cond)
     anorm = xnorm (a, 1);
 
-  F77_XFCN (zpotrf, ZPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, n, info
-                             F77_CHAR_ARG_LEN (1)));
+  if (is_upper)
+    {
+      F77_XFCN (zpotrf, ZPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1),
+                                 n, h, n, info
+                                 F77_CHAR_ARG_LEN (1)));
+    }
+  else
+    {
+      F77_XFCN (zpotrf, ZPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1),
+                                 n, h, n, info
+                                 F77_CHAR_ARG_LEN (1)));
+    }
 
   xrcond = 0.0;
   if (info > 0)
@@ -143,7 +168,7 @@
 }
 
 static ComplexMatrix
-chol2inv_internal (const ComplexMatrix& r)
+chol2inv_internal (const ComplexMatrix& r, bool is_upper = true)
 {
   ComplexMatrix retval;
 
@@ -157,17 +182,37 @@
 
       ComplexMatrix tmp = r;
 
-      F77_XFCN (zpotri, ZPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
-                                 tmp.fortran_vec (), n, info
-                                 F77_CHAR_ARG_LEN (1)));
+      if (is_upper)
+        {
+          F77_XFCN (zpotri, ZPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
+                                     tmp.fortran_vec (), n, info
+                                     F77_CHAR_ARG_LEN (1)));
+        }
+      else
+        {
+          F77_XFCN (zpotri, ZPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
+                                     tmp.fortran_vec (), n, info
+                                     F77_CHAR_ARG_LEN (1)));
+        }
 
       // If someone thinks of a more graceful way of doing this (or
       // faster for that matter :-)), please let me know!
 
       if (n > 1)
-        for (octave_idx_type j = 0; j < r_nc; j++)
-          for (octave_idx_type i = j+1; i < r_nr; i++)
-            tmp.xelem (i, j) = std::conj (tmp.xelem (j, i));
+        {
+          if (is_upper)
+            {
+              for (octave_idx_type j = 0; j < r_nc; j++)
+                for (octave_idx_type i = j+1; i < r_nr; i++)
+                  tmp.xelem (i, j) = tmp.xelem (j, i);
+            }
+          else
+            {
+              for (octave_idx_type j = 0; j < r_nc; j++)
+                for (octave_idx_type i = j+1; i < r_nr; i++)
+                  tmp.xelem (j, i) = tmp.xelem (i, j);
+            }
+        }
 
       retval = tmp;
     }
@@ -181,7 +226,7 @@
 ComplexMatrix
 ComplexCHOL::inverse (void) const
 {
-  return chol2inv_internal (chol_mat);
+  return chol2inv_internal (chol_mat, is_upper);
 }
 
 void