comparison libinterp/corefcn/perms.cc @ 33005:349b4adf686a

Simplify code complexity of perms.cc (bug #65244) * perms.cc: #include <numeric> at top of file. * perms.cc (is_equal_T): New template function. * perms.cc (is_equal_T<octave_value>): New template specialization function. * perms.cc (GetPerms): Remove third input "do_sort". Merge code from GetPermsNoSort(). Use std::iota rather than hand-rolled for loop. * perms.cc (GetPermsNoSort): Delete function. * perms.cc (Fperms): Rename function calls from GetPermsNoSort() to GetPerms(). Rewrite input validation error message to refer to input "V".
author Hendrik Koerner <koerhen@web.de>
date Sun, 11 Feb 2024 10:06:09 +0800
parents 9c9f4df5e4c3
children 8adbe07a6835
comparison
equal deleted inserted replaced
33004:9c9f4df5e4c3 33005:349b4adf686a
26 #if defined(HAVE_CONFIG_H) 26 #if defined(HAVE_CONFIG_H)
27 # include "config.h" 27 # include "config.h"
28 #endif 28 #endif
29 29
30 #include <algorithm> 30 #include <algorithm>
31 #include <numeric>
31 32
32 #include "defun.h" 33 #include "defun.h"
33 #include "error.h" 34 #include "error.h"
34 #include "errwarn.h" 35 #include "errwarn.h"
35 #include "ovl.h" 36 #include "ovl.h"
46 } 47 }
47 48
48 // 49 //
49 // Use C++ template to cater for the different octave array classes. 50 // Use C++ template to cater for the different octave array classes.
50 // 51 //
52
53 // FIXME: To allow comparison between all supported template types, we need
54 // to use either "if constexpr" (supported in C++17) or template specialization
55 // (supported in C++11). Currently (2024), Octave stipulates the usage of
56 // C++11, so the (slightly more complex) template specialization is used.
57 // Once Octave moves to C++17 or beyond, the following code snippet is
58 // preferable and the comparison templates can be removed:
59 // bool isequal;
60 // if constexpr (std::is_same<T, octave_value>::value)
61 // isequal = Ar[i].is_equal (Ar[j]);
62 // else
63 // isequal = (Ar[i] == Ar[j]);
64
65 template <typename T>
66 bool is_equal_T (T a, T b)
67 {
68 return a == b;
69 }
70
71 template <>
72 bool is_equal_T<octave_value> (octave_value a, octave_value b)
73 {
74 return a.is_equal (b);
75 }
76
51 template <typename T> 77 template <typename T>
52 static inline Array<T> 78 static inline Array<T>
53 GetPerms (const Array<T>& ar_in, bool uniq_v, bool do_sort = false) 79 GetPerms (const Array<T>& ar_in, bool uniq_v = false)
54 { 80 {
55 octave_idx_type m = ar_in.numel (); 81 octave_idx_type m = ar_in.numel ();
56 double nr = Factorial (m); 82 double nr; // number of rows of resulting array
57 83
58 // Setup index vector filled from 0..m-1 84 // Setup index vector filled from 0..m-1
59 OCTAVE_LOCAL_BUFFER (int, myvidx, m); 85 OCTAVE_LOCAL_BUFFER (octave_idx_type, myvidx, m);
60 for (int i = 0; i < m; i++) 86 std::iota (&myvidx[0], &myvidx[m], 0);
61 myvidx[i] = i; 87
62 88 const T *Ar = ar_in.data ();
63 // Interim array to sort ar_in for octave sort order and to implement
64 // "unique".
65 Array<T> ar (ar_in);
66 89
67 if (uniq_v) 90 if (uniq_v)
68 { 91 {
69 ar = ar.sort (ar.dims () (1) > ar.dims () (0) ? 1 : 0, ASCENDING); 92 // Mutual Comparison is used to detect duplicated values.
70 const T *Ar = ar.data (); 93 // Using sort would be possible for numerical values and be of
71 int ctr = 0; 94 // O(n*log (n)) complexity instead of O(n*(n -1) / 2). But sort
72 int N_el = 1; 95 // is not supported for the octave_value container (structs/cells).
73 96 // Moreover, sort requires overhead for creating, filling, and sorting
74 // Number of same elements where we need to remove permutations 97 // the intermediate array which would need to be post-processed.
75 // Number of unique permutations is n! / (n_el1! * n_el2! * ...) 98 // In practice, and because n must be very small, mutual comparison is
76 for (octave_idx_type i = 0; i < m - 1; i++) 99 // typically faster and consumes less memory.
77 { 100
78 myvidx[i] = ctr; 101 octave_idx_type N_el = 1;
79 if (Ar[i + 1] != Ar[i]) 102 double denom = 1.0;
80 {
81 nr /= Factorial (N_el);
82 ctr = i + 1; // index of next different element
83 N_el = 1;
84 }
85 else
86 N_el++;
87 }
88 myvidx[m - 1] = ctr;
89 nr /= Factorial (N_el);
90 }
91 else if (do_sort)
92 {
93 ar = ar.sort (ar.dims () (1) > ar.dims () (0) ? 1 : 0, ASCENDING);
94 }
95
96 // Sort vector indices for inverse lexicographic order later.
97 std::sort (myvidx, myvidx + m, std::greater<int> ());
98
99 const T *Ar = ar.data ();
100
101 // Set up result array
102 octave_idx_type n = static_cast<octave_idx_type> (nr);
103 Array<T> res (dim_vector (n, m));
104 T *Res = res.rwdata ();
105
106 // Do the actual job
107 octave_idx_type i = 0;
108 std::sort (myvidx, myvidx + m, std::greater<int> ());
109 do
110 {
111 for (octave_idx_type j = 0; j < m; j++)
112 Res[i + j * n] = Ar[myvidx[j]];
113 i++;
114 }
115 while (std::next_permutation (myvidx, myvidx + m, std::greater<int> ()));
116
117 return res;
118 }
119
120 // Template for non-numerical types (e.g. Cell) without sorting.
121 // The C++ compiler complains as the provided type octave_value does not
122 // support the test of equality via '==' in the above template.
123
124 template <typename T>
125 static inline Array<T>
126 GetPermsNoSort (const Array<T>& ar_in, bool uniq_v = false)
127 {
128 octave_idx_type m = ar_in.numel ();
129 double nr = Factorial (m);
130
131 // Setup index vector filled from 0..m-1
132 OCTAVE_LOCAL_BUFFER (int, myvidx, m);
133 for (int i = 0; i < m; i++)
134 myvidx[i] = i;
135
136 const T *Ar = ar_in.data ();
137
138 if (uniq_v)
139 {
140 // Mutual Comparison using is_equal to detect duplicated values
141 int N_el = 1;
142 // Number of unique permutations is n! / (n_el1! * n_el2! * ...) 103 // Number of unique permutations is n! / (n_el1! * n_el2! * ...)
143 for (octave_idx_type i = 0; i < m - 1; i++) 104 for (octave_idx_type i = 0; i < m - 1; i++)
144 { 105 {
145 for (octave_idx_type j = i + 1; j < m; j++) 106 for (octave_idx_type j = i + 1; j < m; j++)
146 { 107 {
147 if (myvidx[j] > myvidx[i] && Ar[i].is_equal (Ar[j])) 108 bool isequal = is_equal_T<T>(Ar[i], Ar[j]);
109 if (myvidx[j] > myvidx[i] && isequal)
148 { 110 {
149 myvidx[j] = myvidx[i]; // not yet processed... 111 myvidx[j] = myvidx[i]; // not yet processed...
150 N_el++; 112 N_el++;
151 } 113 }
152 else 114 else
153 { 115 {
154 nr /= Factorial (N_el); 116 denom *= Factorial (N_el);
155 N_el = 1; 117 N_el = 1;
156 } 118 }
157 } 119 }
158 } 120 }
159 nr /= Factorial (N_el); 121 denom *= Factorial (N_el);
160 } 122 nr = Factorial (m) / denom;
123 }
124 else
125 nr = Factorial (m);
161 126
162 // Sort vector indices for inverse lexicographic order later. 127 // Sort vector indices for inverse lexicographic order later.
163 std::sort (myvidx, myvidx + m, std::greater<int> ()); 128 std::sort (myvidx, myvidx + m, std::greater<octave_idx_type> ());
164 129
165 // Set up result array 130 // Set up result array
166 octave_idx_type n = static_cast<octave_idx_type> (nr); 131 octave_idx_type n = static_cast<octave_idx_type> (nr);
167 Array<T> res (dim_vector (n, m)); 132 Array<T> res (dim_vector (n, m));
168 T *Res = res.rwdata (); 133 T *Res = res.rwdata ();
257 222
258 if (! (args (0).is_matrix_type () || args (0).is_range () 223 if (! (args (0).is_matrix_type () || args (0).is_range ()
259 || args (0).iscell () || args (0).is_scalar_type () 224 || args (0).iscell () || args (0).is_scalar_type ()
260 || args (0).isstruct ())) 225 || args (0).isstruct ()))
261 { 226 {
262 error ("perms: INPUT must be a matrix, a range, a cell array, " 227 error ("perms: V must be a matrix, range, cell array, "
263 "a struct or a scalar."); 228 "struct, or a scalar.");
264 } 229 }
265 230
266 std::string clname = args (0).class_name (); 231 std::string clname = args (0).class_name ();
267 232
268 // Execute main permutation code for the different classes 233 // Execute main permutation code for the different classes
289 else if (clname == "uint32") 254 else if (clname == "uint32")
290 retval = GetPerms<octave_uint32> (args (0).uint32_array_value (), uniq_v); 255 retval = GetPerms<octave_uint32> (args (0).uint32_array_value (), uniq_v);
291 else if (clname == "uint64") 256 else if (clname == "uint64")
292 retval = GetPerms<octave_uint64> (args (0).uint64_array_value (), uniq_v); 257 retval = GetPerms<octave_uint64> (args (0).uint64_array_value (), uniq_v);
293 else if (clname == "cell") 258 else if (clname == "cell")
294 retval = GetPermsNoSort<octave_value> (args (0).cell_value (), uniq_v); 259 retval = GetPerms<octave_value> (args (0).cell_value (), uniq_v);
295 else if (clname == "struct") 260 else if (clname == "struct")
296 { 261 {
297 const octave_map map_in (args (0).map_value ()); 262 const octave_map map_in (args (0).map_value ());
298 string_vector fn = map_in.fieldnames (); 263 string_vector fn = map_in.fieldnames ();
299 if (fn.numel () == 0 && map_in.numel () != 0) 264 if (fn.numel () == 0 && map_in.numel () != 0)
310 } 275 }
311 else 276 else
312 { 277 {
313 for (octave_idx_type i = 0; i < fn.numel (); i++) 278 for (octave_idx_type i = 0; i < fn.numel (); i++)
314 { 279 {
315 out.assign (fn (i), GetPermsNoSort<octave_value> 280 out.assign (fn (i), GetPerms<octave_value>
316 (map_in.contents (fn (i)), uniq_v)); 281 (map_in.contents (fn (i)), uniq_v));
317 } 282 }
318 } 283 }
319 retval = out; 284 retval = out;
320 } 285 }