diff src/data.cc @ 10533:f094ac9bc93e

reuse Array<T>::cat and Sparse<T>::cat in cat/horzcat/vertcat
author Jaroslav Hajek <highegg@gmail.com>
date Mon, 19 Apr 2010 15:31:49 +0200
parents 8615b55b5caf
children 26673015caec
line wrap: on
line diff
--- a/src/data.cc	Mon Apr 19 07:22:30 2010 -0400
+++ b/src/data.cc	Mon Apr 19 15:31:49 2010 +0200
@@ -1345,50 +1345,58 @@
 
  */
 
-template<class TYPE>
+template <class TYPE, class T>
 static void 
-single_type_concat (TYPE& result,
+single_type_concat (Array<T>& result,
+                    const octave_value_list& args,
+                    int dim)
+{
+  int n_args = args.length ();
+  OCTAVE_LOCAL_BUFFER (Array<T>, array_list, n_args - 1);
+
+  for (int j = 1; j < n_args && ! error_state; j++)
+    {
+      octave_quit ();
+
+      array_list[j-1] = octave_value_extract<TYPE> (args(j));
+    }
+
+  if (! error_state)
+    result = Array<T>::cat (dim, n_args-1, array_list);
+}
+
+template <class TYPE, class T>
+static void 
+single_type_concat (Sparse<T>& result,
                     const octave_value_list& args,
                     int dim)
 {
-  int dv_len = result.ndims (), n_args = args.length ();
-  Array<octave_idx_type> ra_idx (dv_len, 1, 0);
-
-  for (int j = 1; j < n_args; j++)
+  int n_args = args.length ();
+  OCTAVE_LOCAL_BUFFER (Sparse<T>, sparse_list, n_args-1);
+
+  for (int j = 1; j < n_args && ! error_state; j++)
     {
       octave_quit ();
 
-      TYPE ra = octave_value_extract<TYPE> (args(j));
-      dim_vector dvra = ra.dims ();
-      if (error_state)
-        break;
-
-      if (dvra.zero_by_zero ())
-        continue;
-
-      result.insert (ra, ra_idx);
-
-      if (error_state)
-        break;
-
-      ra_idx (dim) += (dim < dvra.length () ? dvra(dim) : 1);
+      sparse_list[j-1] = octave_value_extract<TYPE> (args(j));
     }
+
+  if (! error_state)
+    result = Sparse<T>::cat (dim, n_args-1, sparse_list);
 }
 
+// Dispatcher.
 template<class TYPE>
-static octave_value 
-do_single_type_concat (const dim_vector& dv,
-                       const octave_value_list& args,
-                       int dim)
+static TYPE 
+do_single_type_concat (const octave_value_list& args, int dim)
 {
-  TYPE result (dv);
-
-  single_type_concat (result, args, dim);
+  TYPE result;
+
+  single_type_concat<TYPE, typename TYPE::element_type> (result, args, dim);
 
   return result;
 }
 
-
 static octave_value
 do_cat (const octave_value_list& args, std::string fname)
 {
@@ -1413,7 +1421,6 @@
       if (dim >= 0)
         {
           
-          dim_vector  dv = args(1).dims ();
           std::string result_type = args(1).class_name ();
           
           bool all_sq_strings_p = args(1).is_sq_string ();
@@ -1423,16 +1430,6 @@
 
           for (int i = 2; i < args.length (); i++)
             {
-              // add_dims constructs a dimension vector which holds the
-              // dimensions of the final array after concatenation.
-
-              if (! dv.concat (args(i).dims (), dim))
-                {
-                  // Dimensions do not match. 
-                  error ("cat: dimension mismatch");
-                  return retval;
-                }
-              
               result_type = 
                 get_concat_class (result_type, args(i).class_name ());
 
@@ -1451,24 +1448,24 @@
               if (any_sparse_p)
                 {           
                   if (all_real_p)
-                    retval = do_single_type_concat<SparseMatrix> (dv, args, dim);
+                    retval = do_single_type_concat<SparseMatrix> (args, dim);
                   else
-                    retval = do_single_type_concat<SparseComplexMatrix> (dv, args, dim);
+                    retval = do_single_type_concat<SparseComplexMatrix> (args, dim);
                 }
               else
                 {
                   if (all_real_p)
-                    retval = do_single_type_concat<NDArray> (dv, args, dim);
+                    retval = do_single_type_concat<NDArray> (args, dim);
                   else
-                    retval = do_single_type_concat<ComplexNDArray> (dv, args, dim);
+                    retval = do_single_type_concat<ComplexNDArray> (args, dim);
                 }
             }
           else if (result_type == "single")
             {
               if (all_real_p)
-                retval = do_single_type_concat<FloatNDArray> (dv, args, dim);
+                retval = do_single_type_concat<FloatNDArray> (args, dim);
               else
-                retval = do_single_type_concat<FloatComplexNDArray> (dv, args, dim);
+                retval = do_single_type_concat<FloatComplexNDArray> (args, dim);
             }
           else if (result_type == "char")
             {
@@ -1476,37 +1473,47 @@
 
               maybe_warn_string_concat (all_dq_strings_p, all_sq_strings_p);
 
-              charNDArray result (dv, Vstring_fill_char);
-
-              single_type_concat<charNDArray> (result, args, dim);
+              charNDArray result =  do_single_type_concat<charNDArray> (args, dim);
 
               retval = octave_value (result, type);
             }
           else if (result_type == "logical")
             {
               if (any_sparse_p)
-                retval = do_single_type_concat<SparseBoolMatrix> (dv, args, dim);
+                retval = do_single_type_concat<SparseBoolMatrix> (args, dim);
               else
-                retval = do_single_type_concat<boolNDArray> (dv, args, dim);
+                retval = do_single_type_concat<boolNDArray> (args, dim);
             }
           else if (result_type == "int8")
-            retval = do_single_type_concat<int8NDArray> (dv, args, dim);
+            retval = do_single_type_concat<int8NDArray> (args, dim);
           else if (result_type == "int16")
-            retval = do_single_type_concat<int16NDArray> (dv, args, dim);
+            retval = do_single_type_concat<int16NDArray> (args, dim);
           else if (result_type == "int32")
-            retval = do_single_type_concat<int32NDArray> (dv, args, dim);
+            retval = do_single_type_concat<int32NDArray> (args, dim);
           else if (result_type == "int64")
-            retval = do_single_type_concat<int64NDArray> (dv, args, dim);
+            retval = do_single_type_concat<int64NDArray> (args, dim);
           else if (result_type == "uint8")
-            retval = do_single_type_concat<uint8NDArray> (dv, args, dim);
+            retval = do_single_type_concat<uint8NDArray> (args, dim);
           else if (result_type == "uint16")
-            retval = do_single_type_concat<uint16NDArray> (dv, args, dim);
+            retval = do_single_type_concat<uint16NDArray> (args, dim);
           else if (result_type == "uint32")
-            retval = do_single_type_concat<uint32NDArray> (dv, args, dim);
+            retval = do_single_type_concat<uint32NDArray> (args, dim);
           else if (result_type == "uint64")
-            retval = do_single_type_concat<uint64NDArray> (dv, args, dim);
+            retval = do_single_type_concat<uint64NDArray> (args, dim);
           else
             {
+              dim_vector  dv = args(1).dims ();
+
+              for (int i = 2; i < args.length (); i++)
+                {
+                  if (! dv.concat (args(i).dims (), dim))
+                    {
+                      // Dimensions do not match. 
+                      error ("cat: dimension mismatch");
+                      return retval;
+                    }
+                }
+              
               // The lines below might seem crazy, since we take a copy
               // of the first argument, resize it to be empty and then resize
               // it to be full. This is done since it means that there is no