changeset 13264:11c8b60f1b68

Eliminate duplicate code for op+= and op-= for MSparse
author Jordi Gutiérrez Hermoso <jordigh@octave.org>
date Mon, 03 Oct 2011 00:15:00 -0500
parents a156263b5509
children 89789bc755a1
files liboctave/MSparse.cc
diffstat 1 files changed, 19 insertions(+), 75 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/MSparse.cc	Sun Oct 02 20:13:33 2011 +0200
+++ b/liboctave/MSparse.cc	Mon Oct 03 00:15:00 2011 -0500
@@ -25,6 +25,8 @@
 #include <config.h>
 #endif
 
+#include <functional>
+
 #include "quit.h"
 #include "lo-error.h"
 #include "MArray.h"
@@ -37,9 +39,9 @@
 
 // Element by element MSparse by MSparse ops.
 
-template <class T>
+template <class T, class OP>
 MSparse<T>&
-operator += (MSparse<T>& a, const MSparse<T>& b)
+plus_or_minus (MSparse<T>& a, const MSparse<T>& b, OP op, const char* op_name)
 {
     MSparse<T> r;
 
@@ -50,7 +52,7 @@
     octave_idx_type b_nc = b.cols ();
 
     if (a_nr != b_nr || a_nc != b_nc)
-      gripe_nonconformant ("operator +=" , a_nr, a_nc, b_nr, b_nc);
+      gripe_nonconformant (op_name , a_nr, a_nc, b_nr, b_nc);
     else
       {
         r = MSparse<T> (a_nr, a_nc, (a.nnz () + b.nnz ()));
@@ -73,7 +75,7 @@
                       (ja_lt_max && (a.ridx(ja) < b.ridx(jb))))
                   {
                     r.ridx(jx) = a.ridx(ja);
-                    r.data(jx) = a.data(ja) + 0.;
+                    r.data(jx) = op (a.data(ja), 0.);
                     jx++;
                     ja++;
                     ja_lt_max= ja < ja_max;
@@ -82,16 +84,16 @@
                      (jb_lt_max && (b.ridx(jb) < a.ridx(ja)) ) )
                   {
                     r.ridx(jx) = b.ridx(jb);
-                    r.data(jx) = 0. + b.data(jb);
+                    r.data(jx) = op (0., b.data(jb));
                     jx++;
                     jb++;
                     jb_lt_max= jb < jb_max;
                   }
                 else
                   {
-                     if ((a.data(ja) + b.data(jb)) != 0.)
+                     if (op (a.data(ja), b.data(jb)) != 0.)
                        {
-                          r.data(jx) = a.data(ja) + b.data(jb);
+                          r.data(jx) = op (a.data(ja), b.data(jb));
                           r.ridx(jx) = a.ridx(ja);
                           jx++;
                        }
@@ -110,78 +112,20 @@
     return a;
 }
 
-template <class T>
+template <typename T>
+MSparse<T>&
+operator += (MSparse<T>& a, const MSparse<T>& b)
+{
+  return plus_or_minus (a, b, std::plus<T> (), "operator +=");
+}
+
+template <typename T>
 MSparse<T>&
 operator -= (MSparse<T>& a, const MSparse<T>& b)
 {
-    MSparse<T> r;
-
-    octave_idx_type a_nr = a.rows ();
-    octave_idx_type a_nc = a.cols ();
-
-    octave_idx_type b_nr = b.rows ();
-    octave_idx_type b_nc = b.cols ();
-
-    if (a_nr != b_nr || a_nc != b_nc)
-      gripe_nonconformant ("operator -=" , a_nr, a_nc, b_nr, b_nc);
-    else
-      {
-        r = MSparse<T> (a_nr, a_nc, (a.nnz () + b.nnz ()));
-
-        octave_idx_type jx = 0;
-        for (octave_idx_type i = 0 ; i < a_nc ; i++)
-          {
-            octave_idx_type  ja = a.cidx(i);
-            octave_idx_type  ja_max = a.cidx(i+1);
-            bool ja_lt_max= ja < ja_max;
-
-            octave_idx_type  jb = b.cidx(i);
-            octave_idx_type  jb_max = b.cidx(i+1);
-            bool jb_lt_max = jb < jb_max;
+  return plus_or_minus (a, b, std::minus<T> (), "operator -=");
+}
 
-            while (ja_lt_max || jb_lt_max )
-              {
-                octave_quit ();
-                if ((! jb_lt_max) ||
-                      (ja_lt_max && (a.ridx(ja) < b.ridx(jb))))
-                  {
-                    r.ridx(jx) = a.ridx(ja);
-                    r.data(jx) = a.data(ja) - 0.;
-                    jx++;
-                    ja++;
-                    ja_lt_max= ja < ja_max;
-                  }
-                else if (( !ja_lt_max ) ||
-                     (jb_lt_max && (b.ridx(jb) < a.ridx(ja)) ) )
-                  {
-                    r.ridx(jx) = b.ridx(jb);
-                    r.data(jx) = 0. - b.data(jb);
-                    jx++;
-                    jb++;
-                    jb_lt_max= jb < jb_max;
-                  }
-                else
-                  {
-                     if ((a.data(ja) - b.data(jb)) != 0.)
-                       {
-                          r.data(jx) = a.data(ja) - b.data(jb);
-                          r.ridx(jx) = a.ridx(ja);
-                          jx++;
-                       }
-                     ja++;
-                     ja_lt_max= ja < ja_max;
-                     jb++;
-                     jb_lt_max= jb < jb_max;
-                  }
-              }
-            r.cidx(i+1) = jx;
-          }
-
-        a = r.maybe_compress ();
-      }
-
-    return a;
-}
 
 // Element by element MSparse by scalar ops.