Mercurial > octave
comparison src/DLD-FUNCTIONS/bsxfun.cc @ 6869:f9c893831e68
[project @ 2007-09-06 16:38:44 by dbateman]
author | dbateman |
---|---|
date | Thu, 06 Sep 2007 16:38:44 +0000 |
parents | |
children | cd2c6a69a70d |
comparison
equal
deleted
inserted
replaced
6868:975fcdfb0d2d | 6869:f9c893831e68 |
---|---|
1 /* | |
2 | |
3 Copyright (C) 2007 David Bateman | |
4 | |
5 This file is part of Octave. | |
6 | |
7 Octave is free software; you can redistribute it and/or modify it | |
8 under the terms of the GNU General Public License as published by the | |
9 Free Software Foundation; either version 2, or (at your option) any | |
10 later version. | |
11 | |
12 Octave is distributed in the hope that it will be useful, but WITHOUT | |
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or | |
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License | |
15 for more details. | |
16 | |
17 You should have received a copy of the GNU General Public License | |
18 along with Octave; see the file COPYING. If not, write to the Free | |
19 Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA | |
20 02110-1301, USA. | |
21 | |
22 */ | |
23 | |
24 #ifdef HAVE_CONFIG_H | |
25 #include <config.h> | |
26 #endif | |
27 | |
28 #include <string> | |
29 #include <vector> | |
30 #include <list> | |
31 | |
32 #include "lo-mappers.h" | |
33 | |
34 #include "oct-map.h" | |
35 #include "defun-dld.h" | |
36 #include "parse.h" | |
37 #include "variables.h" | |
38 #include "ov-colon.h" | |
39 #include "unwind-prot.h" | |
40 | |
41 static bool | |
42 maybe_update_column (octave_value& Ac, const octave_value& A, | |
43 const dim_vector& dva, const dim_vector& dvc, | |
44 octave_idx_type i, octave_value_list &idx) | |
45 { | |
46 octave_idx_type nd = dva.length (); | |
47 | |
48 if (i == 0) | |
49 { | |
50 idx(0) = octave_value (':'); | |
51 for (octave_idx_type j = 1; j < nd; j++) | |
52 { | |
53 if (dva (j) == 1) | |
54 idx (j) = octave_value (1); | |
55 else | |
56 idx (j) = octave_value ((i % dvc(j)) + 1); | |
57 | |
58 i = i / dvc (j); | |
59 } | |
60 | |
61 Ac = A; | |
62 Ac = Ac.single_subsref ("(", idx); | |
63 return true; | |
64 } | |
65 else | |
66 { | |
67 bool is_changed = false; | |
68 octave_idx_type k = i; | |
69 octave_idx_type k1 = i - 1; | |
70 for (octave_idx_type j = 1; j < nd; j++) | |
71 { | |
72 if (dva(j) != 1 && k % dvc (j) != k1 % dvc (j)) | |
73 { | |
74 idx (j) = octave_value ((k % dvc(j)) + 1); | |
75 is_changed = true; | |
76 } | |
77 | |
78 k = k / dvc (j); | |
79 k1 = k1 / dvc (j); | |
80 } | |
81 | |
82 if (is_changed) | |
83 { | |
84 Ac = A; | |
85 Ac = Ac.single_subsref ("(", idx); | |
86 return true; | |
87 } | |
88 else | |
89 return false; | |
90 } | |
91 } | |
92 | |
93 static void | |
94 update_index (octave_value_list& idx, const dim_vector& dv, octave_idx_type i) | |
95 { | |
96 octave_idx_type nd = dv.length (); | |
97 | |
98 if (i == 0) | |
99 { | |
100 for (octave_idx_type j = nd - 1; j > 0; j--) | |
101 idx(j) = octave_value (static_cast<double>(1)); | |
102 idx(0) = octave_value (':'); | |
103 } | |
104 else | |
105 { | |
106 for (octave_idx_type j = 1; j < nd; j++) | |
107 { | |
108 idx (j) = octave_value (i % dv (j) + 1); | |
109 i = i / dv (j); | |
110 } | |
111 } | |
112 } | |
113 | |
114 static void | |
115 update_index (Array<int>& idx, const dim_vector& dv, octave_idx_type i) | |
116 { | |
117 octave_idx_type nd = dv.length (); | |
118 | |
119 idx(0) = 0; | |
120 for (octave_idx_type j = 1; j < nd; j++) | |
121 { | |
122 idx (j) = i % dv (j); | |
123 i = i / dv (j); | |
124 } | |
125 } | |
126 | |
127 DEFUN_DLD (bsxfun, args, nargout, | |
128 " -*- texinfo -*-\n\ | |
129 @deftypefn {Lodable Function} {} bsxfun (@var{f}, @var{a}, @var{b})\n\ | |
130 Applies a binary function @var{f} element-wise to two matrix arguments\n\ | |
131 @var{a} and @var{b}. The function @var{f} must be capable of accepting\n\ | |
132 two column vector arguments of equal length, or one column vector\n\ | |
133 argument and a scalar.\n\ | |
134 \n\ | |
135 The dimensions of @var{a} and @var{b} must be equal or singleton. The\n\ | |
136 singleton dimensions a the matirces will be expanded to the same\n\ | |
137 dimensioanlity as the other matrix.\n\ | |
138 \n\ | |
139 @seealso{arrayfun, cellfun}\n\ | |
140 @end deftypefn") | |
141 { | |
142 int nargin = args.length (); | |
143 octave_value_list retval; | |
144 | |
145 if (nargin != 3) | |
146 print_usage (); | |
147 else | |
148 { | |
149 octave_function *func = 0; | |
150 std::string name; | |
151 std::string fcn_name; | |
152 | |
153 if (args(0).is_function_handle () || args(0).is_inline_function ()) | |
154 func = args(0).function_value (); | |
155 else if (args(0).is_string ()) | |
156 { | |
157 name = args(0).string_value (); | |
158 fcn_name = unique_symbol_name ("__bsxfun_fcn_"); | |
159 std::string fname = "function y = "; | |
160 fname.append (fcn_name); | |
161 fname.append ("(x) y = "); | |
162 func = extract_function (args(0), "bsxfun", fcn_name, fname, | |
163 "; endfunction"); | |
164 } | |
165 else | |
166 error ("bsxfun: first argument must be a string or function handle"); | |
167 | |
168 if (! error_state) | |
169 { | |
170 const octave_value A = args (1); | |
171 dim_vector dva = A.dims (); | |
172 octave_idx_type nda = dva.length (); | |
173 const octave_value B = args (2); | |
174 dim_vector dvb = B.dims (); | |
175 octave_idx_type ndb = dvb.length (); | |
176 octave_idx_type nd = nda; | |
177 | |
178 if (nda > ndb) | |
179 dvb.resize (nda, 1); | |
180 else if (nda < ndb) | |
181 { | |
182 dva.resize (ndb, 1); | |
183 nd = ndb; | |
184 } | |
185 | |
186 for (octave_idx_type i = 0; i < nd; i++) | |
187 if (dva (i) != dvb (i) && dva (i) != 1 && dvb (i) != 1) | |
188 { | |
189 error ("bsxfun: dimensions don't match"); | |
190 break; | |
191 } | |
192 | |
193 if (!error_state) | |
194 { | |
195 // Find the size of the output | |
196 dim_vector dvc; | |
197 dvc.resize (nd); | |
198 | |
199 for (octave_idx_type i = 0; i < nd; i++) | |
200 dvc (i) = (dva (i) < 1 ? dva (i) : (dvb (i) < 1 ? dvb (i) : | |
201 (dva (i) > dvb (i) ? dva (i) : dvb (i)))); | |
202 | |
203 if (dva == dvb || dva.numel () == 1 || dvb.numel () == 1) | |
204 { | |
205 octave_value_list inputs; | |
206 inputs (0) = A; | |
207 inputs (1) = B; | |
208 retval = feval (func, inputs, 1); | |
209 } | |
210 else if (dvc.numel () < 1) | |
211 { | |
212 octave_value_list inputs; | |
213 inputs (0) = A.resize (dvc); | |
214 inputs (1) = B.resize (dvc); | |
215 retval = feval (func, inputs, 1); | |
216 } | |
217 else | |
218 { | |
219 octave_idx_type ncount = 1; | |
220 for (octave_idx_type i = 1; i < nd; i++) | |
221 ncount *= dvc (i); | |
222 | |
223 #define BSXDEF(T) \ | |
224 T result_ ## T; \ | |
225 bool have_ ## T = false; | |
226 | |
227 BSXDEF(NDArray); | |
228 BSXDEF(ComplexNDArray); | |
229 BSXDEF(boolNDArray); | |
230 BSXDEF(int8NDArray); | |
231 BSXDEF(int16NDArray); | |
232 BSXDEF(int32NDArray); | |
233 BSXDEF(int64NDArray); | |
234 BSXDEF(uint8NDArray); | |
235 BSXDEF(uint16NDArray); | |
236 BSXDEF(uint32NDArray); | |
237 BSXDEF(uint64NDArray); | |
238 | |
239 octave_value Ac ; | |
240 octave_value_list idxA; | |
241 octave_value Bc; | |
242 octave_value_list idxB; | |
243 octave_value C; | |
244 octave_value_list inputs; | |
245 Array<int> ra_idx (dvc.length(), 0); | |
246 | |
247 | |
248 for (octave_idx_type i = 0; i < ncount; i++) | |
249 { | |
250 if (maybe_update_column (Ac, A, dva, dvc, i, idxA)) | |
251 inputs (0) = Ac; | |
252 | |
253 if (maybe_update_column (Bc, B, dvb, dvc, i, idxB)) | |
254 inputs (1) = Bc; | |
255 | |
256 octave_value_list tmp = feval (func, inputs, 1); | |
257 | |
258 if (error_state) | |
259 break; | |
260 | |
261 #define BSXINIT(T, CLS, EXTRACTOR) \ | |
262 (result_type == CLS) \ | |
263 { \ | |
264 have_ ## T = true; \ | |
265 result_ ## T = \ | |
266 tmp (0). EXTRACTOR ## _array_value (); \ | |
267 result_ ## T .resize (dvc); \ | |
268 } | |
269 | |
270 if (i == 0) | |
271 { | |
272 if (! tmp(0).is_sparse_type ()) | |
273 { | |
274 std::string result_type = tmp(0).class_name (); | |
275 if (result_type == "double") | |
276 { | |
277 if (tmp(0).is_real_type ()) | |
278 { | |
279 have_NDArray = true; | |
280 result_NDArray = tmp(0).array_value (); | |
281 result_NDArray.resize (dvc); | |
282 } | |
283 else | |
284 { | |
285 have_ComplexNDArray = true; | |
286 result_ComplexNDArray = | |
287 tmp(0).complex_array_value (); | |
288 result_ComplexNDArray.resize (dvc); | |
289 } | |
290 } | |
291 else if BSXINIT(boolNDArray, "logical", bool) | |
292 else if BSXINIT(int8NDArray, "int8", int8) | |
293 else if BSXINIT(int16NDArray, "int16", int16) | |
294 else if BSXINIT(int32NDArray, "int32", int32) | |
295 else if BSXINIT(int64NDArray, "int64", int64) | |
296 else if BSXINIT(uint8NDArray, "uint8", uint8) | |
297 else if BSXINIT(uint16NDArray, "uint16", uint16) | |
298 else if BSXINIT(uint32NDArray, "uint32", uint32) | |
299 else if BSXINIT(uint64NDArray, "uint64", uint64) | |
300 else | |
301 { | |
302 C = tmp (0); | |
303 C = C.resize (dvc); | |
304 } | |
305 } | |
306 } | |
307 else | |
308 { | |
309 update_index (ra_idx, dvc, i); | |
310 | |
311 if (have_NDArray) | |
312 { | |
313 if (tmp(0).class_name () != "double") | |
314 { | |
315 have_NDArray = false; | |
316 C = result_NDArray; | |
317 C = do_cat_op (C, tmp(0), ra_idx); | |
318 } | |
319 else if (tmp(0).is_real_type ()) | |
320 result_NDArray.insert (tmp(0).array_value(), | |
321 ra_idx); | |
322 else | |
323 { | |
324 result_ComplexNDArray = | |
325 ComplexNDArray (result_NDArray); | |
326 result_ComplexNDArray.insert | |
327 (tmp(0).complex_array_value(), ra_idx); | |
328 have_NDArray = false; | |
329 have_ComplexNDArray = true; | |
330 } | |
331 } | |
332 | |
333 #define BSXLOOP(T, CLS, EXTRACTOR) \ | |
334 (have_ ## T) \ | |
335 { \ | |
336 if (tmp (0).class_name () != CLS) \ | |
337 { \ | |
338 have_ ## T = false; \ | |
339 C = result_ ## T; \ | |
340 C = do_cat_op (C, tmp (0), ra_idx); \ | |
341 } \ | |
342 else \ | |
343 result_ ## T .insert \ | |
344 (tmp(0). EXTRACTOR ## _array_value (), \ | |
345 ra_idx); \ | |
346 } | |
347 | |
348 else if BSXLOOP(ComplexNDArray, "double", complex) | |
349 else if BSXLOOP(boolNDArray, "logical", bool) | |
350 else if BSXLOOP(int8NDArray, "int8", int8) | |
351 else if BSXLOOP(int16NDArray, "int16", int16) | |
352 else if BSXLOOP(int32NDArray, "int32", int32) | |
353 else if BSXLOOP(int64NDArray, "int64", int64) | |
354 else if BSXLOOP(uint8NDArray, "uint8", uint8) | |
355 else if BSXLOOP(uint16NDArray, "uint16", uint16) | |
356 else if BSXLOOP(uint32NDArray, "uint32", uint32) | |
357 else if BSXLOOP(uint64NDArray, "uint64", uint64) | |
358 else | |
359 C = do_cat_op (C, tmp(0), ra_idx); | |
360 } | |
361 } | |
362 | |
363 #define BSXEND(T) \ | |
364 (have_ ## T) \ | |
365 retval (0) = result_ ## T; | |
366 | |
367 if BSXEND(NDArray) | |
368 else if BSXEND(ComplexNDArray) | |
369 else if BSXEND(boolNDArray) | |
370 else if BSXEND(int8NDArray) | |
371 else if BSXEND(int16NDArray) | |
372 else if BSXEND(int32NDArray) | |
373 else if BSXEND(int64NDArray) | |
374 else if BSXEND(uint8NDArray) | |
375 else if BSXEND(uint16NDArray) | |
376 else if BSXEND(uint32NDArray) | |
377 else if BSXEND(uint64NDArray) | |
378 else | |
379 retval(0) = C; | |
380 } | |
381 } | |
382 } | |
383 | |
384 if (! fcn_name.empty ()) | |
385 clear_function (fcn_name); | |
386 } | |
387 | |
388 return retval; | |
389 } | |
390 | |
391 /* | |
392 | |
393 %!shared a, b, c, f | |
394 %! a = randn (4, 4); | |
395 %! b = mean (a, 1); | |
396 %! c = mean (a, 2); | |
397 %! f = @minus; | |
398 %!error(bsxfun (f)); | |
399 %!error(bsxfun (f, a)); | |
400 %!error(bsxfun (a, b)); | |
401 %!error(bsxfun (a, b, c)); | |
402 %!error(bsxfun (f, a, b, c)); | |
403 %!error(bsxfun (f, ones(4, 0), ones(4, 4))) | |
404 %!assert(bsxfun (f, ones(4, 0), ones(4, 1)), zeros(4, 0)); | |
405 %!assert(bsxfun (f, ones(1, 4), ones(4, 1)), zeros(4, 4)); | |
406 %!assert(bsxfun (f, a, b), a - repmat(b, 4, 1)); | |
407 %!assert(bsxfun (f, a, c), a - repmat(c, 1, 4)); | |
408 %!assert(bsxfun ("minus", ones(1, 4), ones(4, 1)), zeros(4, 4)); | |
409 | |
410 %!shared a, b, c, f | |
411 %! a = randn (4, 4); | |
412 %! a(1) *= 1i; | |
413 %! b = mean (a, 1); | |
414 %! c = mean (a, 2); | |
415 %! f = @minus; | |
416 %!error(bsxfun (f)); | |
417 %!error(bsxfun (f, a)); | |
418 %!error(bsxfun (a, b)); | |
419 %!error(bsxfun (a, b, c)); | |
420 %!error(bsxfun (f, a, b, c)); | |
421 %!error(bsxfun (f, ones(4, 0), ones(4, 4))) | |
422 %!assert(bsxfun (f, ones(4, 0), ones(4, 1)), zeros(4, 0)); | |
423 %!assert(bsxfun (f, ones(1, 4), ones(4, 1)), zeros(4, 4)); | |
424 %!assert(bsxfun (f, a, b), a - repmat(b, 4, 1)); | |
425 %!assert(bsxfun (f, a, c), a - repmat(c, 1, 4)); | |
426 %!assert(bsxfun ("minus", ones(1, 4), ones(4, 1)), zeros(4, 4)); | |
427 | |
428 %!shared a, b, c, f | |
429 %! a = randn (4, 4); | |
430 %! a(end) *= 1i; | |
431 %! b = mean (a, 1); | |
432 %! c = mean (a, 2); | |
433 %! f = @minus; | |
434 %!error(bsxfun (f)); | |
435 %!error(bsxfun (f, a)); | |
436 %!error(bsxfun (a, b)); | |
437 %!error(bsxfun (a, b, c)); | |
438 %!error(bsxfun (f, a, b, c)); | |
439 %!error(bsxfun (f, ones(4, 0), ones(4, 4))) | |
440 %!assert(bsxfun (f, ones(4, 0), ones(4, 1)), zeros(4, 0)); | |
441 %!assert(bsxfun (f, ones(1, 4), ones(4, 1)), zeros(4, 4)); | |
442 %!assert(bsxfun (f, a, b), a - repmat(b, 4, 1)); | |
443 %!assert(bsxfun (f, a, c), a - repmat(c, 1, 4)); | |
444 %!assert(bsxfun ("minus", ones(1, 4), ones(4, 1)), zeros(4, 4)); | |
445 | |
446 %!shared a, b, c, f | |
447 %! a = randn (4, 4); | |
448 %! b = a (1, :); | |
449 %! c = a (:, 1); | |
450 %! f = @(x, y) x == y; | |
451 %!error(bsxfun (f)); | |
452 %!error(bsxfun (f, a)); | |
453 %!error(bsxfun (a, b)); | |
454 %!error(bsxfun (a, b, c)); | |
455 %!error(bsxfun (f, a, b, c)); | |
456 %!error(bsxfun (f, ones(4, 0), ones(4, 4))) | |
457 %!assert(bsxfun (f, ones(4, 0), ones(4, 1)), zeros(4, 0, "logical")); | |
458 %!assert(bsxfun (f, ones(1, 4), ones(4, 1)), ones(4, 4, "logical")); | |
459 %!assert(bsxfun (f, a, b), a == repmat(b, 4, 1)); | |
460 %!assert(bsxfun (f, a, c), a == repmat(c, 1, 4)); | |
461 | |
462 %!shared a, b, c, d, f | |
463 %! a = randn (4, 4, 4); | |
464 %! b = mean (a, 1); | |
465 %! c = mean (a, 2); | |
466 %! d = mean (a, 3); | |
467 %! f = @minus; | |
468 %!error(bsxfun (f, ones([4, 0, 4]), ones([4, 4, 4]))); | |
469 %!assert(bsxfun (f, ones([4, 0, 4]), ones([4, 1, 4])), zeros([4, 0, 4])); | |
470 %!assert(bsxfun (f, ones([4, 4, 0]), ones([4, 1, 1])), zeros([4, 4, 0])); | |
471 %!assert(bsxfun (f, ones([1, 4, 4]), ones([4, 1, 4])), zeros([4, 4, 4])); | |
472 %!assert(bsxfun (f, ones([4, 4, 1]), ones([4, 1, 4])), zeros([4, 4, 4])); | |
473 %!assert(bsxfun (f, ones([4, 1, 4]), ones([1, 4, 4])), zeros([4, 4, 4])); | |
474 %!assert(bsxfun (f, ones([4, 1, 4]), ones([1, 4, 1])), zeros([4, 4, 4])); | |
475 %!assert(bsxfun (f, a, b), a - repmat(b, [4, 1, 1])); | |
476 %!assert(bsxfun (f, a, c), a - repmat(c, [1, 4, 1])); | |
477 %!assert(bsxfun (f, a, d), a - repmat(d, [1, 1, 4])); | |
478 %!assert(bsxfun ("minus", ones([4, 0, 4]), ones([4, 1, 4])), zeros([4, 0, 4])); | |
479 | |
480 %% The below is a very hard case to treat | |
481 %!assert(bsxfun (f, ones([4, 1, 4, 1]), ones([1, 4, 1, 4])), zeros([4, 4, 4, 4])); | |
482 | |
483 */ |