comparison src/DLD-FUNCTIONS/bsxfun.cc @ 6869:f9c893831e68

[project @ 2007-09-06 16:38:44 by dbateman]
author dbateman
date Thu, 06 Sep 2007 16:38:44 +0000
parents
children cd2c6a69a70d
comparison
equal deleted inserted replaced
6868:975fcdfb0d2d 6869:f9c893831e68
1 /*
2
3 Copyright (C) 2007 David Bateman
4
5 This file is part of Octave.
6
7 Octave is free software; you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by the
9 Free Software Foundation; either version 2, or (at your option) any
10 later version.
11
12 Octave is distributed in the hope that it will be useful, but WITHOUT
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with Octave; see the file COPYING. If not, write to the Free
19 Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
20 02110-1301, USA.
21
22 */
23
24 #ifdef HAVE_CONFIG_H
25 #include <config.h>
26 #endif
27
28 #include <string>
29 #include <vector>
30 #include <list>
31
32 #include "lo-mappers.h"
33
34 #include "oct-map.h"
35 #include "defun-dld.h"
36 #include "parse.h"
37 #include "variables.h"
38 #include "ov-colon.h"
39 #include "unwind-prot.h"
40
41 static bool
42 maybe_update_column (octave_value& Ac, const octave_value& A,
43 const dim_vector& dva, const dim_vector& dvc,
44 octave_idx_type i, octave_value_list &idx)
45 {
46 octave_idx_type nd = dva.length ();
47
48 if (i == 0)
49 {
50 idx(0) = octave_value (':');
51 for (octave_idx_type j = 1; j < nd; j++)
52 {
53 if (dva (j) == 1)
54 idx (j) = octave_value (1);
55 else
56 idx (j) = octave_value ((i % dvc(j)) + 1);
57
58 i = i / dvc (j);
59 }
60
61 Ac = A;
62 Ac = Ac.single_subsref ("(", idx);
63 return true;
64 }
65 else
66 {
67 bool is_changed = false;
68 octave_idx_type k = i;
69 octave_idx_type k1 = i - 1;
70 for (octave_idx_type j = 1; j < nd; j++)
71 {
72 if (dva(j) != 1 && k % dvc (j) != k1 % dvc (j))
73 {
74 idx (j) = octave_value ((k % dvc(j)) + 1);
75 is_changed = true;
76 }
77
78 k = k / dvc (j);
79 k1 = k1 / dvc (j);
80 }
81
82 if (is_changed)
83 {
84 Ac = A;
85 Ac = Ac.single_subsref ("(", idx);
86 return true;
87 }
88 else
89 return false;
90 }
91 }
92
93 static void
94 update_index (octave_value_list& idx, const dim_vector& dv, octave_idx_type i)
95 {
96 octave_idx_type nd = dv.length ();
97
98 if (i == 0)
99 {
100 for (octave_idx_type j = nd - 1; j > 0; j--)
101 idx(j) = octave_value (static_cast<double>(1));
102 idx(0) = octave_value (':');
103 }
104 else
105 {
106 for (octave_idx_type j = 1; j < nd; j++)
107 {
108 idx (j) = octave_value (i % dv (j) + 1);
109 i = i / dv (j);
110 }
111 }
112 }
113
114 static void
115 update_index (Array<int>& idx, const dim_vector& dv, octave_idx_type i)
116 {
117 octave_idx_type nd = dv.length ();
118
119 idx(0) = 0;
120 for (octave_idx_type j = 1; j < nd; j++)
121 {
122 idx (j) = i % dv (j);
123 i = i / dv (j);
124 }
125 }
126
127 DEFUN_DLD (bsxfun, args, nargout,
128 " -*- texinfo -*-\n\
129 @deftypefn {Lodable Function} {} bsxfun (@var{f}, @var{a}, @var{b})\n\
130 Applies a binary function @var{f} element-wise to two matrix arguments\n\
131 @var{a} and @var{b}. The function @var{f} must be capable of accepting\n\
132 two column vector arguments of equal length, or one column vector\n\
133 argument and a scalar.\n\
134 \n\
135 The dimensions of @var{a} and @var{b} must be equal or singleton. The\n\
136 singleton dimensions a the matirces will be expanded to the same\n\
137 dimensioanlity as the other matrix.\n\
138 \n\
139 @seealso{arrayfun, cellfun}\n\
140 @end deftypefn")
141 {
142 int nargin = args.length ();
143 octave_value_list retval;
144
145 if (nargin != 3)
146 print_usage ();
147 else
148 {
149 octave_function *func = 0;
150 std::string name;
151 std::string fcn_name;
152
153 if (args(0).is_function_handle () || args(0).is_inline_function ())
154 func = args(0).function_value ();
155 else if (args(0).is_string ())
156 {
157 name = args(0).string_value ();
158 fcn_name = unique_symbol_name ("__bsxfun_fcn_");
159 std::string fname = "function y = ";
160 fname.append (fcn_name);
161 fname.append ("(x) y = ");
162 func = extract_function (args(0), "bsxfun", fcn_name, fname,
163 "; endfunction");
164 }
165 else
166 error ("bsxfun: first argument must be a string or function handle");
167
168 if (! error_state)
169 {
170 const octave_value A = args (1);
171 dim_vector dva = A.dims ();
172 octave_idx_type nda = dva.length ();
173 const octave_value B = args (2);
174 dim_vector dvb = B.dims ();
175 octave_idx_type ndb = dvb.length ();
176 octave_idx_type nd = nda;
177
178 if (nda > ndb)
179 dvb.resize (nda, 1);
180 else if (nda < ndb)
181 {
182 dva.resize (ndb, 1);
183 nd = ndb;
184 }
185
186 for (octave_idx_type i = 0; i < nd; i++)
187 if (dva (i) != dvb (i) && dva (i) != 1 && dvb (i) != 1)
188 {
189 error ("bsxfun: dimensions don't match");
190 break;
191 }
192
193 if (!error_state)
194 {
195 // Find the size of the output
196 dim_vector dvc;
197 dvc.resize (nd);
198
199 for (octave_idx_type i = 0; i < nd; i++)
200 dvc (i) = (dva (i) < 1 ? dva (i) : (dvb (i) < 1 ? dvb (i) :
201 (dva (i) > dvb (i) ? dva (i) : dvb (i))));
202
203 if (dva == dvb || dva.numel () == 1 || dvb.numel () == 1)
204 {
205 octave_value_list inputs;
206 inputs (0) = A;
207 inputs (1) = B;
208 retval = feval (func, inputs, 1);
209 }
210 else if (dvc.numel () < 1)
211 {
212 octave_value_list inputs;
213 inputs (0) = A.resize (dvc);
214 inputs (1) = B.resize (dvc);
215 retval = feval (func, inputs, 1);
216 }
217 else
218 {
219 octave_idx_type ncount = 1;
220 for (octave_idx_type i = 1; i < nd; i++)
221 ncount *= dvc (i);
222
223 #define BSXDEF(T) \
224 T result_ ## T; \
225 bool have_ ## T = false;
226
227 BSXDEF(NDArray);
228 BSXDEF(ComplexNDArray);
229 BSXDEF(boolNDArray);
230 BSXDEF(int8NDArray);
231 BSXDEF(int16NDArray);
232 BSXDEF(int32NDArray);
233 BSXDEF(int64NDArray);
234 BSXDEF(uint8NDArray);
235 BSXDEF(uint16NDArray);
236 BSXDEF(uint32NDArray);
237 BSXDEF(uint64NDArray);
238
239 octave_value Ac ;
240 octave_value_list idxA;
241 octave_value Bc;
242 octave_value_list idxB;
243 octave_value C;
244 octave_value_list inputs;
245 Array<int> ra_idx (dvc.length(), 0);
246
247
248 for (octave_idx_type i = 0; i < ncount; i++)
249 {
250 if (maybe_update_column (Ac, A, dva, dvc, i, idxA))
251 inputs (0) = Ac;
252
253 if (maybe_update_column (Bc, B, dvb, dvc, i, idxB))
254 inputs (1) = Bc;
255
256 octave_value_list tmp = feval (func, inputs, 1);
257
258 if (error_state)
259 break;
260
261 #define BSXINIT(T, CLS, EXTRACTOR) \
262 (result_type == CLS) \
263 { \
264 have_ ## T = true; \
265 result_ ## T = \
266 tmp (0). EXTRACTOR ## _array_value (); \
267 result_ ## T .resize (dvc); \
268 }
269
270 if (i == 0)
271 {
272 if (! tmp(0).is_sparse_type ())
273 {
274 std::string result_type = tmp(0).class_name ();
275 if (result_type == "double")
276 {
277 if (tmp(0).is_real_type ())
278 {
279 have_NDArray = true;
280 result_NDArray = tmp(0).array_value ();
281 result_NDArray.resize (dvc);
282 }
283 else
284 {
285 have_ComplexNDArray = true;
286 result_ComplexNDArray =
287 tmp(0).complex_array_value ();
288 result_ComplexNDArray.resize (dvc);
289 }
290 }
291 else if BSXINIT(boolNDArray, "logical", bool)
292 else if BSXINIT(int8NDArray, "int8", int8)
293 else if BSXINIT(int16NDArray, "int16", int16)
294 else if BSXINIT(int32NDArray, "int32", int32)
295 else if BSXINIT(int64NDArray, "int64", int64)
296 else if BSXINIT(uint8NDArray, "uint8", uint8)
297 else if BSXINIT(uint16NDArray, "uint16", uint16)
298 else if BSXINIT(uint32NDArray, "uint32", uint32)
299 else if BSXINIT(uint64NDArray, "uint64", uint64)
300 else
301 {
302 C = tmp (0);
303 C = C.resize (dvc);
304 }
305 }
306 }
307 else
308 {
309 update_index (ra_idx, dvc, i);
310
311 if (have_NDArray)
312 {
313 if (tmp(0).class_name () != "double")
314 {
315 have_NDArray = false;
316 C = result_NDArray;
317 C = do_cat_op (C, tmp(0), ra_idx);
318 }
319 else if (tmp(0).is_real_type ())
320 result_NDArray.insert (tmp(0).array_value(),
321 ra_idx);
322 else
323 {
324 result_ComplexNDArray =
325 ComplexNDArray (result_NDArray);
326 result_ComplexNDArray.insert
327 (tmp(0).complex_array_value(), ra_idx);
328 have_NDArray = false;
329 have_ComplexNDArray = true;
330 }
331 }
332
333 #define BSXLOOP(T, CLS, EXTRACTOR) \
334 (have_ ## T) \
335 { \
336 if (tmp (0).class_name () != CLS) \
337 { \
338 have_ ## T = false; \
339 C = result_ ## T; \
340 C = do_cat_op (C, tmp (0), ra_idx); \
341 } \
342 else \
343 result_ ## T .insert \
344 (tmp(0). EXTRACTOR ## _array_value (), \
345 ra_idx); \
346 }
347
348 else if BSXLOOP(ComplexNDArray, "double", complex)
349 else if BSXLOOP(boolNDArray, "logical", bool)
350 else if BSXLOOP(int8NDArray, "int8", int8)
351 else if BSXLOOP(int16NDArray, "int16", int16)
352 else if BSXLOOP(int32NDArray, "int32", int32)
353 else if BSXLOOP(int64NDArray, "int64", int64)
354 else if BSXLOOP(uint8NDArray, "uint8", uint8)
355 else if BSXLOOP(uint16NDArray, "uint16", uint16)
356 else if BSXLOOP(uint32NDArray, "uint32", uint32)
357 else if BSXLOOP(uint64NDArray, "uint64", uint64)
358 else
359 C = do_cat_op (C, tmp(0), ra_idx);
360 }
361 }
362
363 #define BSXEND(T) \
364 (have_ ## T) \
365 retval (0) = result_ ## T;
366
367 if BSXEND(NDArray)
368 else if BSXEND(ComplexNDArray)
369 else if BSXEND(boolNDArray)
370 else if BSXEND(int8NDArray)
371 else if BSXEND(int16NDArray)
372 else if BSXEND(int32NDArray)
373 else if BSXEND(int64NDArray)
374 else if BSXEND(uint8NDArray)
375 else if BSXEND(uint16NDArray)
376 else if BSXEND(uint32NDArray)
377 else if BSXEND(uint64NDArray)
378 else
379 retval(0) = C;
380 }
381 }
382 }
383
384 if (! fcn_name.empty ())
385 clear_function (fcn_name);
386 }
387
388 return retval;
389 }
390
391 /*
392
393 %!shared a, b, c, f
394 %! a = randn (4, 4);
395 %! b = mean (a, 1);
396 %! c = mean (a, 2);
397 %! f = @minus;
398 %!error(bsxfun (f));
399 %!error(bsxfun (f, a));
400 %!error(bsxfun (a, b));
401 %!error(bsxfun (a, b, c));
402 %!error(bsxfun (f, a, b, c));
403 %!error(bsxfun (f, ones(4, 0), ones(4, 4)))
404 %!assert(bsxfun (f, ones(4, 0), ones(4, 1)), zeros(4, 0));
405 %!assert(bsxfun (f, ones(1, 4), ones(4, 1)), zeros(4, 4));
406 %!assert(bsxfun (f, a, b), a - repmat(b, 4, 1));
407 %!assert(bsxfun (f, a, c), a - repmat(c, 1, 4));
408 %!assert(bsxfun ("minus", ones(1, 4), ones(4, 1)), zeros(4, 4));
409
410 %!shared a, b, c, f
411 %! a = randn (4, 4);
412 %! a(1) *= 1i;
413 %! b = mean (a, 1);
414 %! c = mean (a, 2);
415 %! f = @minus;
416 %!error(bsxfun (f));
417 %!error(bsxfun (f, a));
418 %!error(bsxfun (a, b));
419 %!error(bsxfun (a, b, c));
420 %!error(bsxfun (f, a, b, c));
421 %!error(bsxfun (f, ones(4, 0), ones(4, 4)))
422 %!assert(bsxfun (f, ones(4, 0), ones(4, 1)), zeros(4, 0));
423 %!assert(bsxfun (f, ones(1, 4), ones(4, 1)), zeros(4, 4));
424 %!assert(bsxfun (f, a, b), a - repmat(b, 4, 1));
425 %!assert(bsxfun (f, a, c), a - repmat(c, 1, 4));
426 %!assert(bsxfun ("minus", ones(1, 4), ones(4, 1)), zeros(4, 4));
427
428 %!shared a, b, c, f
429 %! a = randn (4, 4);
430 %! a(end) *= 1i;
431 %! b = mean (a, 1);
432 %! c = mean (a, 2);
433 %! f = @minus;
434 %!error(bsxfun (f));
435 %!error(bsxfun (f, a));
436 %!error(bsxfun (a, b));
437 %!error(bsxfun (a, b, c));
438 %!error(bsxfun (f, a, b, c));
439 %!error(bsxfun (f, ones(4, 0), ones(4, 4)))
440 %!assert(bsxfun (f, ones(4, 0), ones(4, 1)), zeros(4, 0));
441 %!assert(bsxfun (f, ones(1, 4), ones(4, 1)), zeros(4, 4));
442 %!assert(bsxfun (f, a, b), a - repmat(b, 4, 1));
443 %!assert(bsxfun (f, a, c), a - repmat(c, 1, 4));
444 %!assert(bsxfun ("minus", ones(1, 4), ones(4, 1)), zeros(4, 4));
445
446 %!shared a, b, c, f
447 %! a = randn (4, 4);
448 %! b = a (1, :);
449 %! c = a (:, 1);
450 %! f = @(x, y) x == y;
451 %!error(bsxfun (f));
452 %!error(bsxfun (f, a));
453 %!error(bsxfun (a, b));
454 %!error(bsxfun (a, b, c));
455 %!error(bsxfun (f, a, b, c));
456 %!error(bsxfun (f, ones(4, 0), ones(4, 4)))
457 %!assert(bsxfun (f, ones(4, 0), ones(4, 1)), zeros(4, 0, "logical"));
458 %!assert(bsxfun (f, ones(1, 4), ones(4, 1)), ones(4, 4, "logical"));
459 %!assert(bsxfun (f, a, b), a == repmat(b, 4, 1));
460 %!assert(bsxfun (f, a, c), a == repmat(c, 1, 4));
461
462 %!shared a, b, c, d, f
463 %! a = randn (4, 4, 4);
464 %! b = mean (a, 1);
465 %! c = mean (a, 2);
466 %! d = mean (a, 3);
467 %! f = @minus;
468 %!error(bsxfun (f, ones([4, 0, 4]), ones([4, 4, 4])));
469 %!assert(bsxfun (f, ones([4, 0, 4]), ones([4, 1, 4])), zeros([4, 0, 4]));
470 %!assert(bsxfun (f, ones([4, 4, 0]), ones([4, 1, 1])), zeros([4, 4, 0]));
471 %!assert(bsxfun (f, ones([1, 4, 4]), ones([4, 1, 4])), zeros([4, 4, 4]));
472 %!assert(bsxfun (f, ones([4, 4, 1]), ones([4, 1, 4])), zeros([4, 4, 4]));
473 %!assert(bsxfun (f, ones([4, 1, 4]), ones([1, 4, 4])), zeros([4, 4, 4]));
474 %!assert(bsxfun (f, ones([4, 1, 4]), ones([1, 4, 1])), zeros([4, 4, 4]));
475 %!assert(bsxfun (f, a, b), a - repmat(b, [4, 1, 1]));
476 %!assert(bsxfun (f, a, c), a - repmat(c, [1, 4, 1]));
477 %!assert(bsxfun (f, a, d), a - repmat(d, [1, 1, 4]));
478 %!assert(bsxfun ("minus", ones([4, 0, 4]), ones([4, 1, 4])), zeros([4, 0, 4]));
479
480 %% The case below is especially hard to handle
481 %!assert(bsxfun (f, ones([4, 1, 4, 1]), ones([1, 4, 1, 4])), zeros([4, 4, 4, 4]));
482
483 */