diff src/mex.cc @ 5903:11bb9bf343a0

[project @ 2006-07-26 03:36:33 by jwe]
author jwe
date Wed, 26 Jul 2006 03:36:33 +0000
parents 6af4cea82cc7
children e5c0831a48bd
line wrap: on
line diff
--- a/src/mex.cc	Tue Jul 25 19:56:00 2006 +0000
+++ b/src/mex.cc	Wed Jul 26 03:36:33 2006 +0000
@@ -372,6 +372,13 @@
       id = mxCHAR_CLASS;
     else if (cn == "double")
       id = mxDOUBLE_CLASS;
+    else if (cn == "sparse")
+      {
+	if (val.is_bool_type ())
+	  id = mxLOGICAL_CLASS;
+	else
+	  id = mxDOUBLE_CLASS;
+      }
     else if (cn == "single")
       id = mxSINGLE_CLASS;
     else if (cn == "int8")
@@ -1312,35 +1319,120 @@
 
 // Matlab-style sparse arrays.
 
-class mxArray_sparse : public mxArray_number
+class mxArray_sparse : public mxArray_matlab
 {
 public:
 
   mxArray_sparse (mxClassID id_arg, int m, int n, int nzmax_arg,
 		  mxComplexity flag = mxREAL)
-    : mxArray_number (id_arg, m, n, flag), nzmax (nzmax_arg)
+    : mxArray_matlab (id_arg, m, n), nzmax (nzmax_arg)
   {
+    pr = (calloc (nzmax, get_element_size ()));
+    pi = (flag == mxCOMPLEX ? calloc (nzmax, get_element_size ()) : 0);
     ir = static_cast<int *> (calloc (nzmax, sizeof (int)));
-    jc = static_cast<int *> (calloc (nzmax, sizeof (int)));
+    jc = static_cast<int *> (calloc (n + 1, sizeof (int)));
   }
 
   mxArray_sparse *clone (void) const { return new mxArray_sparse (*this); }
 
   ~mxArray_sparse (void)
   {
+    mxFree (pr);
+    mxFree (pi);
     mxFree (ir);
     mxFree (jc);
   }
 
   octave_value as_octave_value (void) const
   {
-    // FIXME
-    abort ();
-    return octave_value ();
+    octave_value retval;
+
+    dim_vector dv = dims_to_dim_vector ();
+
+    switch (get_class_id ())
+      {
+      case mxLOGICAL_CLASS:
+	{
+	  bool *ppr = static_cast<bool *> (pr);
+
+	  SparseBoolMatrix val (get_m (), get_n (), nzmax);
+
+	  for (int i = 0; i < nzmax; i++)
+	    {
+	      val.xdata(i) = ppr[i];
+	      val.xridx(i) = ir[i];
+	    }
+
+	  for (int i = 0; i < get_n () + 1; i++)
+	    val.xcidx(i) = jc[i];
+
+	  retval = val;
+	}
+	break;
+
+      case mxSINGLE_CLASS:
+	error ("single precision data type not supported");
+	break;
+
+      case mxDOUBLE_CLASS:
+	{
+	  if (pi)
+	    {
+	      double *ppr = static_cast<double *> (pr);
+	      double *ppi = static_cast<double *> (pi);
+
+	      SparseComplexMatrix val (get_m (), get_n (), nzmax);
+
+	      for (int i = 0; i < nzmax; i++)
+		{
+		  val.xdata(i) = Complex (ppr[i], ppi[i]);
+		  val.xridx(i) = ir[i];
+		}
+
+	      for (int i = 0; i < get_n () + 1; i++)
+		val.xcidx(i) = jc[i];
+
+	      retval = val;
+	    }
+	  else
+	    {
+	      double *ppr = static_cast<double *> (pr);
+
+	      SparseMatrix val (get_m (), get_n (), nzmax);
+
+	      for (int i = 0; i < nzmax; i++)
+		{
+		  val.xdata(i) = ppr[i];
+		  val.xridx(i) = ir[i];
+		}
+
+	      for (int i = 0; i < get_n () + 1; i++)
+		val.xcidx(i) = jc[i];
+
+	      retval = val;
+	    }
+	}
+	break;
+
+      default:
+	panic_impossible ();
+      }
+
+    return retval;
   }
 
+  int is_complex (void) const { return pi != 0; }
+
   int is_sparse (void) const { return 1; }
 
+  void *get_data (void) const { return pr; }
+
+  void *get_imag_data (void) const { return pi; }
+
+  void set_data (void *pr_arg) { pr = pr_arg; }
+
+  void set_imag_data (void *pi_arg) { pi = pi_arg; }
+
   int *get_ir (void) const { return ir; }
 
   int *get_jc (void) const { return jc; }
@@ -1357,19 +1449,23 @@
 
   int nzmax;
 
+  void *pr;
+  void *pi;
   int *ir;
   int *jc;
 
   mxArray_sparse (const mxArray_sparse& val)
-    : mxArray_number (val), nzmax (val.nzmax),
+    : mxArray_matlab (val), nzmax (val.nzmax),
       ir (static_cast<int *> (malloc (nzmax * sizeof (int)))),
       jc (static_cast<int *> (malloc (nzmax * sizeof (int))))
   {
-    for (int i = 0; i < nzmax; i++)
-      {
-	ir[i] = val.ir[i];
-	jc[i] = val.jc[i];
-      }
+    int ntot = nzmax * get_element_size ();
+
+    memcpy (pr, val.pr, ntot);
+    memcpy (ir, val.ir, nzmax * sizeof(int));
+    memcpy (jc, val.jc, (val.get_n () + 1) * sizeof (int));
+    if (pi)
+      memcpy (pi, val.pi, ntot);
   }
 };