changeset 8382:9b20a4847056

implement scalar powers of diag matrices
author Jaroslav Hajek <highegg@gmail.com>
date Mon, 08 Dec 2008 14:12:08 +0100
parents ad896677a2e2
children a762d9daa700
files src/ChangeLog src/OPERATORS/op-dms-template.cc src/xpow.cc src/xpow.h
diffstat 4 files changed, 180 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog	Mon Dec 08 12:31:57 2008 +0100
+++ b/src/ChangeLog	Mon Dec 08 14:12:08 2008 +0100
@@ -1,3 +1,17 @@
+2008-12-08  Jaroslav Hajek  <highegg@gmail.com>
+	
+	* xpow.cc ( xpow (const DiagMatrix& a, double b), 
+	xpow (const DiagMatrix& a, const Complex& b), 
+	xpow (const ComplexDiagMatrix& a, double b), 
+	xpow (const ComplexDiagMatrix& a, const Complex& b), 
+	xpow (const FloatDiagMatrix& a, float b), 
+	xpow (const FloatDiagMatrix& a, const FloatComplex& b), 
+	xpow (const FloatComplexDiagMatrix& a, float b), 
+	xpow (const FloatComplexDiagMatrix& a, const FloatComplex& b)):
+	New methods.
+	* xpow.h: Declare them.
+	* OPERATORS/op-dms-template.cc: Support diagonal matrix ^ scalar.
+
 2008-12-08  Jaroslav Hajek  <highegg@gmail.com>
 
 	* ov-re-diag.cc (octave_diag_matrix::save_binary,
--- a/src/OPERATORS/op-dms-template.cc	Mon Dec 08 12:31:57 2008 +0100
+++ b/src/OPERATORS/op-dms-template.cc	Mon Dec 08 14:12:08 2008 +0100
@@ -25,6 +25,7 @@
 #endif
 
 #include "ops.h"
+#include "xpow.h"
 #include SINCLUDE
 #include MINCLUDE
 
@@ -58,6 +59,13 @@
   return v2.MATRIX_VALUE () / v1.SCALAR_VALUE ();
 }
 
+DEFBINOP (dmspow, MATRIX, SCALAR)
+{
+  CAST_BINOP_ARGS (const OCTAVE_MATRIX&, const OCTAVE_SCALAR&);
+
+  return xpow (v1.MATRIX_VALUE (), v2.SCALAR_VALUE ());
+}
+
 #define SHORT_NAME CONCAT3(MSHORT, _, SSHORT)
 #define INST_NAME CONCAT3(install_, SHORT_NAME, _ops)
 
@@ -72,4 +80,5 @@
   INSTALL_BINOP (op_sub, OCTAVE_SCALAR, OCTAVE_MATRIX, sdmsub);
   INSTALL_BINOP (op_mul, OCTAVE_SCALAR, OCTAVE_MATRIX, sdmmul);
   INSTALL_BINOP (op_ldiv, OCTAVE_SCALAR, OCTAVE_MATRIX, sdmldiv);
+  INSTALL_BINOP (op_pow, OCTAVE_MATRIX, OCTAVE_SCALAR, dmspow);
 }
--- a/src/xpow.cc	Mon Dec 08 12:31:57 2008 +0100
+++ b/src/xpow.cc	Mon Dec 08 14:12:08 2008 +0100
@@ -31,10 +31,12 @@
 #include "Array-util.h"
 #include "CColVector.h"
 #include "CDiagMatrix.h"
+#include "fCDiagMatrix.h"
 #include "CMatrix.h"
 #include "EIG.h"
 #include "fEIG.h"
 #include "dDiagMatrix.h"
+#include "fDiagMatrix.h"
 #include "dMatrix.h"
 #include "mx-cm-cdm.h"
 #include "oct-cmplx.h"
@@ -262,6 +264,38 @@
   return retval;
 }
 
+// -*- 5d -*-
+octave_value
+xpow (const DiagMatrix& a, double b)
+{
+  octave_value retval;
+
+  octave_idx_type nr = a.rows ();
+  octave_idx_type nc = a.cols ();
+
+  if (nr == 0 || nc == 0 || nr != nc)
+    error ("for A^b, A must be square");
+  else
+    {
+      if (static_cast<int> (b) == b)
+	{
+          DiagMatrix r (nr, nc);
+          for (octave_idx_type i = 0; i < nc; i++)
+            r(i, i) = std::pow (a(i, i), b);
+          retval = r;
+        }
+      else
+	{
+          ComplexDiagMatrix r (nr, nc);
+          for (octave_idx_type i = 0; i < nc; i++)
+            r(i, i) = std::pow (static_cast<Complex> (a(i, i)), b);
+          retval = r;
+	}
+    }
+
+  return retval;
+}
+
 // -*- 6 -*-
 octave_value
 xpow (const Matrix& a, const Complex& b)
@@ -517,6 +551,42 @@
   return retval;
 }
 
+// -*- 12d -*-
+octave_value
+xpow (const ComplexDiagMatrix& a, const Complex& b)
+{
+  octave_value retval;
+
+  octave_idx_type nr = a.rows ();
+  octave_idx_type nc = a.cols ();
+
+  if (nr == 0 || nc == 0 || nr != nc)
+    error ("for A^b, A must be square");
+  else
+    {
+      ComplexDiagMatrix r (nr, nc);
+      for (octave_idx_type i = 0; i < nc; i++)
+        r(i, i) = std::pow (a(i, i), b);
+      retval = r;
+    }
+
+  return retval;
+}
+
+// mixed
+octave_value
+xpow (const ComplexDiagMatrix& a, double b)
+{
+  return xpow (a, static_cast<Complex> (b));
+}
+
+octave_value
+xpow (const DiagMatrix& a, const Complex& b)
+{
+  return xpow (ComplexDiagMatrix (a), b);
+}
+
+
 // Safer pow functions that work elementwise for matrices.
 //
 //       op2 \ op1:   s   m   cs   cm
@@ -1474,6 +1544,38 @@
   return retval;
 }
 
+// -*- 5d -*-
+octave_value
+xpow (const FloatDiagMatrix& a, float b)
+{
+  octave_value retval;
+
+  octave_idx_type nr = a.rows ();
+  octave_idx_type nc = a.cols ();
+
+  if (nr == 0 || nc == 0 || nr != nc)
+    error ("for A^b, A must be square");
+  else
+    {
+      if (static_cast<int> (b) == b)
+	{
+          FloatDiagMatrix r (nr, nc);
+          for (octave_idx_type i = 0; i < nc; i++)
+            r(i, i) = std::pow (a(i, i), b);
+          retval = r;
+        }
+      else
+	{
+          FloatComplexDiagMatrix r (nr, nc);
+          for (octave_idx_type i = 0; i < nc; i++)
+            r(i, i) = std::pow (static_cast<FloatComplex> (a(i, i)), b);
+          retval = r;
+	}
+    }
+
+  return retval;
+}
+
 // -*- 6 -*-
 octave_value
 xpow (const FloatMatrix& a, const FloatComplex& b)
@@ -1729,6 +1831,41 @@
   return retval;
 }
 
+// -*- 12d -*-
+octave_value
+xpow (const FloatComplexDiagMatrix& a, const FloatComplex& b)
+{
+  octave_value retval;
+
+  octave_idx_type nr = a.rows ();
+  octave_idx_type nc = a.cols ();
+
+  if (nr == 0 || nc == 0 || nr != nc)
+    error ("for A^b, A must be square");
+  else
+    {
+      FloatComplexDiagMatrix r (nr, nc);
+      for (octave_idx_type i = 0; i < nc; i++)
+        r(i, i) = std::pow (a(i, i), b);
+      retval = r;
+    }
+
+  return retval;
+}
+
+// mixed
+octave_value
+xpow (const FloatComplexDiagMatrix& a, float b)
+{
+  return xpow (a, static_cast<FloatComplex> (b));
+}
+
+octave_value
+xpow (const FloatDiagMatrix& a, const FloatComplex& b)
+{
+  return xpow (FloatComplexDiagMatrix (a), b);
+}
+
 // Safer pow functions that work elementwise for matrices.
 //
 //       op2 \ op1:   s   m   cs   cm
--- a/src/xpow.h	Mon Dec 08 12:31:57 2008 +0100
+++ b/src/xpow.h	Mon Dec 08 14:12:08 2008 +0100
@@ -30,6 +30,14 @@
 class ComplexMatrix;
 class FloatMatrix;
 class FloatComplexMatrix;
+class DiagMatrix;
+class ComplexDiagMatrix;
+class FloatDiagMatrix;
+class FloatComplexDiagMatrix;
+class NDArray;
+class FloatNDArray;
+class ComplexNDArray;
+class FloatComplexNDArray;
 class octave_value;
 
 extern octave_value xpow (double a, double b);
@@ -40,6 +48,9 @@
 extern octave_value xpow (const Matrix& a, double b);
 extern octave_value xpow (const Matrix& a, const Complex& b);
 
+extern octave_value xpow (const DiagMatrix& a, double b);
+extern octave_value xpow (const DiagMatrix& a, const Complex& b);
+
 extern octave_value xpow (const Complex& a, double b);
 extern octave_value xpow (const Complex& a, const Matrix& b);
 extern octave_value xpow (const Complex& a, const Complex& b);
@@ -48,6 +59,9 @@
 extern octave_value xpow (const ComplexMatrix& a, double b);
 extern octave_value xpow (const ComplexMatrix& a, const Complex& b);
 
+extern octave_value xpow (const ComplexDiagMatrix& a, double b);
+extern octave_value xpow (const ComplexDiagMatrix& a, const Complex& b);
+
 extern octave_value elem_xpow (double a, const Matrix& b);
 extern octave_value elem_xpow (double a, const ComplexMatrix& b);
 
@@ -89,6 +103,9 @@
 extern octave_value xpow (const FloatMatrix& a, float b);
 extern octave_value xpow (const FloatMatrix& a, const FloatComplex& b);
 
+extern octave_value xpow (const FloatDiagMatrix& a, float b);
+extern octave_value xpow (const FloatDiagMatrix& a, const FloatComplex& b);
+
 extern octave_value xpow (const FloatComplex& a, float b);
 extern octave_value xpow (const FloatComplex& a, const FloatMatrix& b);
 extern octave_value xpow (const FloatComplex& a, const FloatComplex& b);
@@ -97,6 +114,9 @@
 extern octave_value xpow (const FloatComplexMatrix& a, float b);
 extern octave_value xpow (const FloatComplexMatrix& a, const FloatComplex& b);
 
+extern octave_value xpow (const FloatComplexDiagMatrix& a, float b);
+extern octave_value xpow (const FloatComplexDiagMatrix& a, const FloatComplex& b);
+
 extern octave_value elem_xpow (float a, const FloatMatrix& b);
 extern octave_value elem_xpow (float a, const FloatComplexMatrix& b);