changeset 29086:d0fe364977c1

zeros, ones, NaN, Inf, NA: Add "like" syntax (bug #50865). * data.cc (fill_matrix): Change several overloads to accept "like" and a variable as the last two arguments. (fzeros, fones, fNaN, fInf, fNA, ffalse, ftrue): Add BISTs for new syntax.
author Markus Mützel <markus.muetzel@gmx.de>
date Fri, 24 Apr 2020 20:03:50 +0200
parents 3b29d72645a9
children 944fd6fca864
files libinterp/corefcn/data.cc
diffstat 1 files changed, 170 insertions(+), 28 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/corefcn/data.cc	Fri Nov 20 23:25:42 2020 +0100
+++ b/libinterp/corefcn/data.cc	Fri Apr 24 20:03:50 2020 +0200
@@ -3944,6 +3944,8 @@
   oct_data_conv::data_type dt = oct_data_conv::dt_double;
 
   dim_vector dims (1, 1);
+  bool issparse = false;
+  bool iscomplex = false;
 
   if (nargin > 0 && args(nargin-1).is_string ())
     {
@@ -3953,6 +3955,16 @@
       dt = oct_data_conv::string_to_data_type (nm);
     }
 
+  if (nargin > 1 && args(nargin-2).is_string ()
+      && args(nargin-2).string_value () == "like")
+    {
+      std::string nm = args(nargin-1).class_name ();
+      issparse = args(nargin-1).issparse ();
+      iscomplex = args(nargin-1).iscomplex ();
+      nargin -= 2;
+      dt = oct_data_conv::string_to_data_type (nm);
+    }
+
   switch (nargin)
     {
     case 0:
@@ -3982,6 +3994,33 @@
   // Note that automatic narrowing will handle conversion from
   // NDArray to scalar.
 
+  if (issparse)
+    {
+      if (dims.ndims () > 2)
+        error ("%s: sparse ND arrays not supported.", fcn);
+
+      switch (dt)
+        {
+        case oct_data_conv::dt_double:
+          if (iscomplex)
+            retval = SparseComplexMatrix (dims(0), dims(1), Complex (val, 0));
+          else
+            retval = SparseMatrix (dims(0), dims(1), static_cast<double> (val));
+          break;
+
+        case oct_data_conv::dt_logical:
+          retval = SparseBoolMatrix (dims(0), dims(1), static_cast<bool> (val));
+          break;
+
+        default:
+          // FIXME: It shouldn't be possible to ever reach this.
+          error ("%s: invalid class name for sparse", fcn);
+          break;
+        }
+
+      return retval;
+    }
+
   switch (dt)
     {
     case oct_data_conv::dt_int8:
@@ -4017,22 +4056,25 @@
       break;
 
     case oct_data_conv::dt_single:
-      retval = FloatNDArray (dims, val);
+      if (iscomplex)
+        retval = FloatComplexNDArray (dims, val);
+      else
+        retval = FloatNDArray (dims, val);
       break;
 
     case oct_data_conv::dt_double:
-      {
-        if (dims.ndims () == 2 && dims(0) == 1)
-          {
-            // FIXME: If this optimization provides a significant
-            // benefit, then maybe there should be a special storage
-            // type for constant value arrays.
-            double dval = static_cast<double> (val);
-            retval = octave::range<double>::make_constant (dval, dims(1));
-          }
-        else
-          retval = NDArray (dims, val);
-      }
+      if (iscomplex)
+        retval = ComplexNDArray (dims, Complex (val, 0));
+      else if (dims.ndims () == 2 && dims(0) == 1)
+        {
+          // FIXME: If this optimization provides a significant
+          // benefit, then maybe there should be a special storage
+          // type for constant value arrays.
+          double dval = static_cast<double> (val);
+          retval = octave::range<double>::make_constant (dval, dims(1));
+        }
+      else
+        retval = NDArray (dims, val);
       break;
 
     case oct_data_conv::dt_logical:
@@ -4058,6 +4100,8 @@
   oct_data_conv::data_type dt = oct_data_conv::dt_double;
 
   dim_vector dims (1, 1);
+  bool issparse = false;
+  bool iscomplex = false;
 
   if (nargin > 0 && args(nargin-1).is_string ())
     {
@@ -4067,6 +4111,20 @@
       dt = oct_data_conv::string_to_data_type (nm);
     }
 
+  if (nargin > 1 && args(nargin-2).is_string ()
+      && args(nargin-2).string_value () == "like"
+      && (std::string(fcn) ==  "Inf"
+          || std::string(fcn) == "NaN" || std::string(fcn) == "NA"))
+    {
+      if (! args(nargin-1).isfloat ())
+        error ("%s: input followed by 'like' must be floating point", fcn);
+      std::string nm = args(nargin-1).class_name ();
+      issparse = args(nargin-1).issparse ();
+      iscomplex = args(nargin-1).iscomplex ();
+      nargin -= 2;
+      dt = oct_data_conv::string_to_data_type (nm);
+    }
+
   switch (nargin)
     {
     case 0:
@@ -4093,14 +4151,33 @@
   // Note that automatic narrowing will handle conversion from
   // NDArray to scalar.
 
+  if (issparse)
+    {
+      if (dims.ndims () > 2)
+        error ("%s: sparse ND arrays not supported", fcn);
+
+      if (iscomplex)
+        retval = SparseComplexMatrix (dims(0), dims(1), Complex (val, 0));
+      else
+        retval = SparseMatrix (dims(0), dims(1), static_cast<double> (val));
+
+      return retval;
+    }
+
   switch (dt)
     {
     case oct_data_conv::dt_single:
-      retval = FloatNDArray (dims, fval);
+      if (iscomplex)
+        retval = FloatComplexNDArray (dims, fval);
+      else
+        retval = FloatNDArray (dims, fval);
       break;
 
     case oct_data_conv::dt_double:
-      if (dims.ndims () == 2 && dims(0) == 1 && octave::math::isfinite (val))
+      if (iscomplex)
+        retval = ComplexNDArray (dims, Complex (val, 0));
+      else if (dims.ndims () == 2 && dims(0) == 1
+               && octave::math::isfinite (val))
         // FIXME: If this optimization provides a significant benefit,
         // then maybe there should be a special storage type for
         // constant value arrays.
@@ -4272,6 +4349,18 @@
         error ("%s: invalid data type '%s'", fcn, nm.c_str ());
     }
 
+  bool issparse = false;
+
+  if (nargin > 1 && args(nargin-2).is_string ()
+      && args(nargin-2).string_value () == "like")
+  {
+    if (! args(nargin-1).islogical ())
+      error (R"(%s: input followed by "like" must be logical)", fcn);
+
+    issparse = args(nargin-1).issparse ();
+    nargin -= 2;
+  }
+
   switch (nargin)
     {
     case 0:
@@ -4298,7 +4387,15 @@
   // Note that automatic narrowing will handle conversion from
   // NDArray to scalar.
 
-  retval = boolNDArray (dims, val);
+  if (issparse)
+    {
+      if (dims.ndims () > 2)
+        error ("%s: sparse ND arrays not supported", fcn);
+
+      retval = SparseBoolMatrix (dims(0), dims(1), val);
+    }
+  else
+    retval = boolNDArray (dims, val);
 
   return retval;
 }
@@ -4353,12 +4450,21 @@
 %!assert (ones (3, 2, "int8"), int8 ([1, 1; 1, 1; 1, 1]))
 %!assert (size (ones (3, 4, 5, "int8")), [3, 4, 5])
 
+%!assert (ones (2, 2, "like", double (1)), double ([1, 1; 1, 1]))
+%!assert (ones (2, 2, "like", complex (ones (2, 2))), [1, 1; 1, 1])
+%!assert (ones (1, 2, "like", single (1)), single ([1, 1]))
+%!assert (ones (1, "like", single (1i)), single (1))
+%!assert (ones (2, 2, "like", uint8 (8)), uint8 ([1, 1; 1, 1]))
+%!assert (ones (2, "like", speye (2)), sparse ([1, 1; 1, 1]))
+%!assert (ones (2, "like", sparse (1i)), sparse (complex ([1, 1; 1, 1])))
+
 %!assert (size (ones (1, -2, 2)), [1, 0, 2])
 
 ## Test input validation
 %!error <conversion of 1.1 .*failed> ones (1.1)
 %!error <conversion of 1.1 .*failed> ones (1, 1.1)
 %!error <conversion of 1.1 .*failed> ones ([1, 1.1])
+%!error <sparse ND .* not supported> ones (3, 3, 3, "like", speye (1))
 */
 
 /*
@@ -4411,6 +4517,13 @@
 %!assert (zeros (3, 2), [0, 0; 0, 0; 0, 0])
 %!assert (size (zeros (3, 4, 5)), [3, 4, 5])
 
+%!assert (zeros (2, 2, "like", double (1)), double ([0, 0; 0, 0]))
+%!assert (zeros (2, 2, "like", complex (ones (2, 2))), [0, 0; 0, 0])
+%!assert (zeros (1, 2, "like", single (1)), single ([0, 0]))
+%!assert (zeros (1, 2, "like", single (1i)), single ([0, 0]))
+%!assert (zeros (2, 2, "like", uint8 (8)), uint8 ([0, 0; 0, 0]))
+%!assert (zeros (2, "like", speye (2)), sparse ([0, 0; 0, 0]))
+
 %!assert (zeros (3, "single"), single ([0, 0, 0; 0, 0, 0; 0, 0, 0]))
 %!assert (zeros (2, 3, "single"), single ([0, 0, 0; 0, 0, 0]))
 %!assert (zeros (3, 2, "single"), single ([0, 0; 0, 0; 0, 0]))
@@ -4428,6 +4541,7 @@
 %!error <conversion of 1.1 .*failed> zeros ([1, 1.1])
 %!error <conversion of 1.1 .*failed> zeros (1, 1.1, 2)
 %!error <conversion of 1.1 .*failed> zeros ([1, 1.1, 2])
+%!error <sparse ND .* not supported> zeros (3, 3, 3, "like", speye (1))
 */
 
 DEFUN (Inf, args, ,
@@ -4475,20 +4589,30 @@
 DEFALIAS (inf, Inf);
 
 /*
-%!assert (inf (3), [Inf, Inf, Inf; Inf, Inf, Inf; Inf, Inf, Inf])
-%!assert (inf (2, 3), [Inf, Inf, Inf; Inf, Inf, Inf])
-%!assert (inf (3, 2), [Inf, Inf; Inf, Inf; Inf, Inf])
-%!assert (size (inf (3, 4, 5)), [3, 4, 5])
-
-%!assert (inf (3, "single"), single ([Inf, Inf, Inf; Inf, Inf, Inf; Inf, Inf, Inf]))
-%!assert (inf (2, 3, "single"), single ([Inf, Inf, Inf; Inf, Inf, Inf]))
-%!assert (inf (3, 2, "single"), single ([Inf, Inf; Inf, Inf; Inf, Inf]))
+%!assert (Inf (3), [Inf, Inf, Inf; Inf, Inf, Inf; Inf, Inf, Inf])
+%!assert (Inf (2, 3), [Inf, Inf, Inf; Inf, Inf, Inf])
+%!assert (Inf (3, 2), [Inf, Inf; Inf, Inf; Inf, Inf])
+%!assert (size (Inf (3, 4, 5)), [3, 4, 5])
+
+%!assert (Inf (3, "single"), single ([Inf, Inf, Inf; Inf, Inf, Inf; Inf, Inf, Inf]))
+%!assert (Inf (2, 3, "single"), single ([Inf, Inf, Inf; Inf, Inf, Inf]))
+%!assert (Inf (3, 2, "single"), single ([Inf, Inf; Inf, Inf; Inf, Inf]))
 %!assert (size (inf (3, 4, 5, "single")), [3, 4, 5])
 
-%!error inf (3, "int8")
-%!error inf (2, 3, "int8")
-%!error inf (3, 2, "int8")
-%!error inf (3, 4, 5, "int8")
+%!assert (Inf (2, 2, "like", speye (2)), sparse ([Inf, Inf; Inf, Inf]))
+%!assert (Inf (2, 2, "like", complex (ones (2, 2))), [Inf, Inf; Inf, Inf])
+%!assert (Inf (2, 2, "like", double (1)), double ([Inf, Inf; Inf, Inf]))
+%!assert (Inf (3, 3, "like", single (1)), single ([Inf, Inf, Inf; Inf, Inf, Inf; Inf, Inf, Inf]))
+%!assert (Inf (2, "like", single (1i)), single ([Inf, Inf; Inf, Inf]))
+
+%!error Inf (3, "like", int8 (1))
+
+%!error Inf (3, "int8")
+%!error Inf (2, 3, "int8")
+%!error Inf (3, 2, "int8")
+%!error Inf (3, 4, 5, "int8")
+%!error <input .* floating> Inf (3, 3, "like", true)
+%!error <input .* floating> Inf (2, "like", uint8 (1))
 */
 
 DEFUN (NaN, args, ,
@@ -4550,10 +4674,20 @@
 %!assert (NaN (3, 2, "single"), single ([NaN, NaN; NaN, NaN; NaN, NaN]))
 %!assert (size (NaN (3, 4, 5, "single")), [3, 4, 5])
 
+%!assert (NaN (2, 2, "like", double (1)), double ([NaN, NaN; NaN, NaN]))
+%!assert (NaN (2, 2, "like", complex (ones(2, 2))), [NaN, NaN; NaN, NaN])
+%!assert (NaN (3, 3, "like", single (1)), single ([NaN, NaN, NaN; NaN, NaN, NaN; NaN, NaN, NaN]))
+%!assert (NaN (2, "like", single (1i)), single ([NaN, NaN; NaN, NaN]))
+%!assert (NaN (2, 2, "like", speye (2)), sparse ([NaN, NaN; NaN, NaN]))
+
+%!error NaN (3, 'like', int8 (1))
+
 %!error NaN (3, "int8")
 %!error NaN (2, 3, "int8")
 %!error NaN (3, 2, "int8")
 %!error NaN (3, 4, 5, "int8")
+%!error <input .* floating> NaN (3, 3, "like", true)
+%!error <input .* floating> NaN (2, "like", uint8 (1))
 */
 
 DEFUN (e, args, ,
@@ -4935,7 +5069,11 @@
 /*
 %!assert (false (2, 3), logical (zeros (2, 3)))
 %!assert (false (2, 3, "logical"), logical (zeros (2, 3)))
+%!assert (false (2, 1, "like", true), [false; false])
+%!assert (false (2, 1, "like", sparse (true)), sparse ([false; false]))
+
 %!error false (2, 3, "double")
+%!error <input .* logical> false (2, 1, "like", sparse (1))
 */
 
 DEFUN (true, args, ,
@@ -4959,7 +5097,11 @@
 /*
 %!assert (true (2, 3), logical (ones (2, 3)))
 %!assert (true (2, 3, "logical"), logical (ones (2, 3)))
+%!assert (true (2, 1, "like", false), [true; true])
+%!assert (true (2, 1, "like", sparse (true)), sparse ([true; true]))
+
 %!error true (2, 3, "double")
+%!error <input .* logical> true (2, 1, "like", double (1))
 */
 
 template <typename MT>