Mercurial > octave
comparison libinterp/corefcn/perms.cc @ 33005:349b4adf686a
Simplify code complexity of perms.cc (bug #65244)
* perms.cc: #include <numeric> at top of file.
* perms.cc (is_equal_T): New template function.
* perms.cc (is_equal_T<octave_value>): New template specialization function.
* perms.cc (GetPerms): Remove third input "do_sort". Merge code from
GetPermsNoSort(). Use std::iota rather than hand-rolled for loop.
* perms.cc (GetPermsNoSort): Delete function.
* perms.cc (Fperms): Rename function calls from GetPermsNoSort() to GetPerms().
Rewrite input validation error message to refer to input "V".
author | Hendrik Koerner <koerhen@web.de> |
---|---|
date | Sun, 11 Feb 2024 10:06:09 +0800 |
parents | 9c9f4df5e4c3 |
children | 8adbe07a6835 |
comparison
equal
deleted
inserted
replaced
33004:9c9f4df5e4c3 | 33005:349b4adf686a |
---|---|
26 #if defined(HAVE_CONFIG_H) | 26 #if defined(HAVE_CONFIG_H) |
27 # include "config.h" | 27 # include "config.h" |
28 #endif | 28 #endif |
29 | 29 |
30 #include <algorithm> | 30 #include <algorithm> |
31 #include <numeric> | |
31 | 32 |
32 #include "defun.h" | 33 #include "defun.h" |
33 #include "error.h" | 34 #include "error.h" |
34 #include "errwarn.h" | 35 #include "errwarn.h" |
35 #include "ovl.h" | 36 #include "ovl.h" |
46 } | 47 } |
47 | 48 |
48 // | 49 // |
49 // Use C++ template to cater for the different octave array classes. | 50 // Use C++ template to cater for the different octave array classes. |
50 // | 51 // |
52 | |
53 // FIXME: To allow comparison between all supported template types, we need | |
54 // to use either "if constexpr" (supported in C++17) or template specialization | |
55 // (supported in C++11). Currently (2024), Octave stipulates the usage of | |
56 // C++11, so the (slightly more complex) template specialization is used. | |
57 // Once Octave moves to C++17 or beyond, the following code snippet is | |
58 // preferable and the comparison templates can be removed: | |
59 // bool isEqual; | |
60 // if constexpr (std::is_same<T, octave_value>::value) | |
61 // isEqual = Ar[i].is_equal (Ar[j]); | |
62 // else | |
63 // isEqual = (Ar[i] == Ar[j]); | |
64 | |
65 template <typename T> | |
66 bool is_equal_T (T a, T b) | |
67 { | |
68 return a == b; | |
69 } | |
70 | |
71 template <> | |
72 bool is_equal_T<octave_value> (octave_value a, octave_value b) | |
73 { | |
74 return a.is_equal (b); | |
75 } | |
76 | |
51 template <typename T> | 77 template <typename T> |
52 static inline Array<T> | 78 static inline Array<T> |
53 GetPerms (const Array<T>& ar_in, bool uniq_v, bool do_sort = false) | 79 GetPerms (const Array<T>& ar_in, bool uniq_v = false) |
54 { | 80 { |
55 octave_idx_type m = ar_in.numel (); | 81 octave_idx_type m = ar_in.numel (); |
56 double nr = Factorial (m); | 82 double nr; // number of rows of resulting array |
57 | 83 |
58 // Setup index vector filled from 0..m-1 | 84 // Setup index vector filled from 0..m-1 |
59 OCTAVE_LOCAL_BUFFER (int, myvidx, m); | 85 OCTAVE_LOCAL_BUFFER (octave_idx_type, myvidx, m); |
60 for (int i = 0; i < m; i++) | 86 std::iota (&myvidx[0], &myvidx[m], 0); |
61 myvidx[i] = i; | 87 |
62 | 88 const T *Ar = ar_in.data (); |
63 // Interim array to sort ar_in for octave sort order and to implement | |
64 // "unique". | |
65 Array<T> ar (ar_in); | |
66 | 89 |
67 if (uniq_v) | 90 if (uniq_v) |
68 { | 91 { |
69 ar = ar.sort (ar.dims () (1) > ar.dims () (0) ? 1 : 0, ASCENDING); | 92 // Mutual Comparison is used to detect duplicated values. |
70 const T *Ar = ar.data (); | 93 // Using sort would be possible for numerical values and be of |
71 int ctr = 0; | 94 // O(n*log (n)) complexity instead of O(n*(n -1) / 2). But sort |
72 int N_el = 1; | 95 // is not supported for the octave_value container (structs/cells). |
73 | 96 // Moreover, sort requires overhead for creating, filling, and sorting |
74 // Number of same elements where we need to remove permutations | 97 // the intermediate array which would need to be post-processed. |
75 // Number of unique permutations is n! / (n_el1! * n_el2! * ...) | 98 // In practice, and because n must be very small, mutual comparison is |
76 for (octave_idx_type i = 0; i < m - 1; i++) | 99 // typically faster and consumes less memory. |
77 { | 100 |
78 myvidx[i] = ctr; | 101 octave_idx_type N_el = 1; |
79 if (Ar[i + 1] != Ar[i]) | 102 double denom = 1.0; |
80 { | |
81 nr /= Factorial (N_el); | |
82 ctr = i + 1; // index of next different element | |
83 N_el = 1; | |
84 } | |
85 else | |
86 N_el++; | |
87 } | |
88 myvidx[m - 1] = ctr; | |
89 nr /= Factorial (N_el); | |
90 } | |
91 else if (do_sort) | |
92 { | |
93 ar = ar.sort (ar.dims () (1) > ar.dims () (0) ? 1 : 0, ASCENDING); | |
94 } | |
95 | |
96 // Sort vector indices for inverse lexicographic order later. | |
97 std::sort (myvidx, myvidx + m, std::greater<int> ()); | |
98 | |
99 const T *Ar = ar.data (); | |
100 | |
101 // Set up result array | |
102 octave_idx_type n = static_cast<octave_idx_type> (nr); | |
103 Array<T> res (dim_vector (n, m)); | |
104 T *Res = res.rwdata (); | |
105 | |
106 // Do the actual job | |
107 octave_idx_type i = 0; | |
108 std::sort (myvidx, myvidx + m, std::greater<int> ()); | |
109 do | |
110 { | |
111 for (octave_idx_type j = 0; j < m; j++) | |
112 Res[i + j * n] = Ar[myvidx[j]]; | |
113 i++; | |
114 } | |
115 while (std::next_permutation (myvidx, myvidx + m, std::greater<int> ())); | |
116 | |
117 return res; | |
118 } | |
119 | |
120 // Template for non-numerical types (e.g. Cell) without sorting. | |
121 // The C++ compiler complains as the provided type octave_value does not | |
122 // support the test of equality via '==' in the above template. | |
123 | |
124 template <typename T> | |
125 static inline Array<T> | |
126 GetPermsNoSort (const Array<T>& ar_in, bool uniq_v = false) | |
127 { | |
128 octave_idx_type m = ar_in.numel (); | |
129 double nr = Factorial (m); | |
130 | |
131 // Setup index vector filled from 0..m-1 | |
132 OCTAVE_LOCAL_BUFFER (int, myvidx, m); | |
133 for (int i = 0; i < m; i++) | |
134 myvidx[i] = i; | |
135 | |
136 const T *Ar = ar_in.data (); | |
137 | |
138 if (uniq_v) | |
139 { | |
140 // Mutual Comparison using is_equal to detect duplicated values | |
141 int N_el = 1; | |
142 // Number of unique permutations is n! / (n_el1! * n_el2! * ...) | 103 // Number of unique permutations is n! / (n_el1! * n_el2! * ...) |
143 for (octave_idx_type i = 0; i < m - 1; i++) | 104 for (octave_idx_type i = 0; i < m - 1; i++) |
144 { | 105 { |
145 for (octave_idx_type j = i + 1; j < m; j++) | 106 for (octave_idx_type j = i + 1; j < m; j++) |
146 { | 107 { |
147 if (myvidx[j] > myvidx[i] && Ar[i].is_equal (Ar[j])) | 108 bool isequal = is_equal_T<T>(Ar[i], Ar[j]); |
109 if (myvidx[j] > myvidx[i] && isequal) | |
148 { | 110 { |
149 myvidx[j] = myvidx[i]; // not yet processed... | 111 myvidx[j] = myvidx[i]; // not yet processed... |
150 N_el++; | 112 N_el++; |
151 } | 113 } |
152 else | 114 else |
153 { | 115 { |
154 nr /= Factorial (N_el); | 116 denom *= Factorial (N_el); |
155 N_el = 1; | 117 N_el = 1; |
156 } | 118 } |
157 } | 119 } |
158 } | 120 } |
159 nr /= Factorial (N_el); | 121 denom *= Factorial (N_el); |
160 } | 122 nr = Factorial (m) / denom; |
123 } | |
124 else | |
125 nr = Factorial (m); | |
161 | 126 |
162 // Sort vector indices for inverse lexicographic order later. | 127 // Sort vector indices for inverse lexicographic order later. |
163 std::sort (myvidx, myvidx + m, std::greater<int> ()); | 128 std::sort (myvidx, myvidx + m, std::greater<octave_idx_type> ()); |
164 | 129 |
165 // Set up result array | 130 // Set up result array |
166 octave_idx_type n = static_cast<octave_idx_type> (nr); | 131 octave_idx_type n = static_cast<octave_idx_type> (nr); |
167 Array<T> res (dim_vector (n, m)); | 132 Array<T> res (dim_vector (n, m)); |
168 T *Res = res.rwdata (); | 133 T *Res = res.rwdata (); |
257 | 222 |
258 if (! (args (0).is_matrix_type () || args (0).is_range () | 223 if (! (args (0).is_matrix_type () || args (0).is_range () |
259 || args (0).iscell () || args (0).is_scalar_type () | 224 || args (0).iscell () || args (0).is_scalar_type () |
260 || args (0).isstruct ())) | 225 || args (0).isstruct ())) |
261 { | 226 { |
262 error ("perms: INPUT must be a matrix, a range, a cell array, " | 227 error ("perms: V must be a matrix, range, cell array, " |
263 "a struct or a scalar."); | 228 "struct, or a scalar."); |
264 } | 229 } |
265 | 230 |
266 std::string clname = args (0).class_name (); | 231 std::string clname = args (0).class_name (); |
267 | 232 |
268 // Execute main permutation code for the different classes | 233 // Execute main permutation code for the different classes |
289 else if (clname == "uint32") | 254 else if (clname == "uint32") |
290 retval = GetPerms<octave_uint32> (args (0).uint32_array_value (), uniq_v); | 255 retval = GetPerms<octave_uint32> (args (0).uint32_array_value (), uniq_v); |
291 else if (clname == "uint64") | 256 else if (clname == "uint64") |
292 retval = GetPerms<octave_uint64> (args (0).uint64_array_value (), uniq_v); | 257 retval = GetPerms<octave_uint64> (args (0).uint64_array_value (), uniq_v); |
293 else if (clname == "cell") | 258 else if (clname == "cell") |
294 retval = GetPermsNoSort<octave_value> (args (0).cell_value (), uniq_v); | 259 retval = GetPerms<octave_value> (args (0).cell_value (), uniq_v); |
295 else if (clname == "struct") | 260 else if (clname == "struct") |
296 { | 261 { |
297 const octave_map map_in (args (0).map_value ()); | 262 const octave_map map_in (args (0).map_value ()); |
298 string_vector fn = map_in.fieldnames (); | 263 string_vector fn = map_in.fieldnames (); |
299 if (fn.numel () == 0 && map_in.numel () != 0) | 264 if (fn.numel () == 0 && map_in.numel () != 0) |
310 } | 275 } |
311 else | 276 else |
312 { | 277 { |
313 for (octave_idx_type i = 0; i < fn.numel (); i++) | 278 for (octave_idx_type i = 0; i < fn.numel (); i++) |
314 { | 279 { |
315 out.assign (fn (i), GetPermsNoSort<octave_value> | 280 out.assign (fn (i), GetPerms<octave_value> |
316 (map_in.contents (fn (i)), uniq_v)); | 281 (map_in.contents (fn (i)), uniq_v)); |
317 } | 282 } |
318 } | 283 } |
319 retval = out; | 284 retval = out; |
320 } | 285 } |