diff liboctave/fCDiagMatrix.cc @ 8366:8b1a2555c4e2

implement diagonal matrix objects * * *
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 03 Dec 2008 13:32:57 +0100
parents 82be108cc558
children c3f7e2549abb
line wrap: on
line diff
--- a/liboctave/fCDiagMatrix.cc	Wed Dec 03 20:57:27 2008 -0500
+++ b/liboctave/fCDiagMatrix.cc	Wed Dec 03 13:32:57 2008 +0100
@@ -232,6 +232,15 @@
   return *this;
 }
 
+FloatDiagMatrix
+FloatComplexDiagMatrix::abs (void) const
+{
+  FloatDiagMatrix retval (rows (), cols ());
+  for (octave_idx_type i = 0; i < rows (); i++)
+    retval(i, i) = std::abs (elem (i, i));
+  return retval;
+}
+
 FloatComplexDiagMatrix
 conj (const FloatComplexDiagMatrix& a)
 {
@@ -378,6 +387,21 @@
   return retval;
 }
 
+bool
+FloatComplexDiagMatrix::all_elements_are_real (void) const
+{
+  octave_idx_type len = length ();
+  for (octave_idx_type i = 0; i < len; i++)
+    {
+      float ip = std::imag (elem (i, i));
+
+      if (ip != 0.0 || lo_ieee_signbit (ip))
+        return false;
+    }
+
+  return true;
+}
+
 // diagonal matrix by diagonal matrix -> diagonal matrix operations
 
 FloatComplexDiagMatrix&
@@ -484,6 +508,46 @@
   return c;
 }
 
+FloatComplexDiagMatrix
+operator * (const FloatComplexDiagMatrix& a, const FloatComplexDiagMatrix& b)
+{
+  octave_idx_type a_nr = a.rows ();
+  octave_idx_type a_nc = a.cols ();
+
+  octave_idx_type b_nr = b.rows ();
+  octave_idx_type b_nc = b.cols ();
+
+  if (a_nc != b_nr)
+    {
+      gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc);
+      return FloatComplexDiagMatrix ();
+    }
+
+  if (a_nr == 0 || a_nc == 0 || b_nc == 0)
+    return FloatComplexDiagMatrix (a_nr, a_nc, 0.0);
+
+  FloatComplexDiagMatrix c (a_nr, b_nc);
+
+  octave_idx_type len = a_nr < b_nc ? a_nr : b_nc;
+
+  for (octave_idx_type i = 0; i < len; i++)
+    {
+      FloatComplex a_element = a.elem (i, i);
+      FloatComplex b_element = b.elem (i, i);
+
+      if (a_element == static_cast<float> (0.0) || b_element == static_cast<float> (0.0))
+        c.elem (i, i) = 0;
+      else if (a_element == static_cast<float> (1.0))
+        c.elem (i, i) = b_element;
+      else if (b_element == static_cast<float> (1.0))
+        c.elem (i, i) = a_element;
+      else
+        c.elem (i, i) = a_element * b_element;
+    }
+
+  return c;
+}
+
 // other operations
 
 FloatComplexColumnVector