comparison liboctave/fDiagMatrix.cc @ 7789:82be108cc558

First attempt at single precision types * * * corrections to qrupdate single precision routines * * * prefer demotion to single over promotion to double * * * Add single precision support to log2 function * * * Trivial PROJECT file update * * * Cache optimized hermitian/transpose methods * * * Add tests for transpose/hermitian and ChangeLog entry for new transpose code
author David Bateman <dbateman@free.fr>
date Sun, 27 Apr 2008 22:34:17 +0200
parents
children 4976f66d469b
comparison
equal deleted inserted replaced
7788:45f5faba05a2 7789:82be108cc558
1 // FloatDiagMatrix manipulations.
2 /*
3
4 Copyright (C) 1994, 1995, 1996, 1997, 2000, 2001, 2002, 2003, 2004,
5 2005, 2007 John W. Eaton
6
7 This file is part of Octave.
8
9 Octave is free software; you can redistribute it and/or modify it
10 under the terms of the GNU General Public License as published by the
11 Free Software Foundation; either version 3 of the License, or (at your
12 option) any later version.
13
14 Octave is distributed in the hope that it will be useful, but WITHOUT
15 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
16 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
17 for more details.
18
19 You should have received a copy of the GNU General Public License
20 along with Octave; see the file COPYING. If not, see
21 <http://www.gnu.org/licenses/>.
22
23 */
24
25 #ifdef HAVE_CONFIG_H
26 #include <config.h>
27 #endif
28
29 #include <iostream>
30
31 #include "Array-util.h"
32 #include "lo-error.h"
33 #include "mx-base.h"
34 #include "mx-inlines.cc"
35 #include "oct-cmplx.h"
36
37 // Diagonal Matrix class.
38
39 bool
40 FloatDiagMatrix::operator == (const FloatDiagMatrix& a) const
41 {
42 if (rows () != a.rows () || cols () != a.cols ())
43 return 0;
44
45 return mx_inline_equal (data (), a.data (), length ());
46 }
47
48 bool
49 FloatDiagMatrix::operator != (const FloatDiagMatrix& a) const
50 {
51 return !(*this == a);
52 }
53
54 FloatDiagMatrix&
55 FloatDiagMatrix::fill (float val)
56 {
57 for (octave_idx_type i = 0; i < length (); i++)
58 elem (i, i) = val;
59 return *this;
60 }
61
62 FloatDiagMatrix&
63 FloatDiagMatrix::fill (float val, octave_idx_type beg, octave_idx_type end)
64 {
65 if (beg < 0 || end >= length () || end < beg)
66 {
67 (*current_liboctave_error_handler) ("range error for fill");
68 return *this;
69 }
70
71 for (octave_idx_type i = beg; i <= end; i++)
72 elem (i, i) = val;
73
74 return *this;
75 }
76
77 FloatDiagMatrix&
78 FloatDiagMatrix::fill (const FloatColumnVector& a)
79 {
80 octave_idx_type len = length ();
81 if (a.length () != len)
82 {
83 (*current_liboctave_error_handler) ("range error for fill");
84 return *this;
85 }
86
87 for (octave_idx_type i = 0; i < len; i++)
88 elem (i, i) = a.elem (i);
89
90 return *this;
91 }
92
93 FloatDiagMatrix&
94 FloatDiagMatrix::fill (const FloatRowVector& a)
95 {
96 octave_idx_type len = length ();
97 if (a.length () != len)
98 {
99 (*current_liboctave_error_handler) ("range error for fill");
100 return *this;
101 }
102
103 for (octave_idx_type i = 0; i < len; i++)
104 elem (i, i) = a.elem (i);
105
106 return *this;
107 }
108
109 FloatDiagMatrix&
110 FloatDiagMatrix::fill (const FloatColumnVector& a, octave_idx_type beg)
111 {
112 octave_idx_type a_len = a.length ();
113 if (beg < 0 || beg + a_len >= length ())
114 {
115 (*current_liboctave_error_handler) ("range error for fill");
116 return *this;
117 }
118
119 for (octave_idx_type i = 0; i < a_len; i++)
120 elem (i+beg, i+beg) = a.elem (i);
121
122 return *this;
123 }
124
125 FloatDiagMatrix&
126 FloatDiagMatrix::fill (const FloatRowVector& a, octave_idx_type beg)
127 {
128 octave_idx_type a_len = a.length ();
129 if (beg < 0 || beg + a_len >= length ())
130 {
131 (*current_liboctave_error_handler) ("range error for fill");
132 return *this;
133 }
134
135 for (octave_idx_type i = 0; i < a_len; i++)
136 elem (i+beg, i+beg) = a.elem (i);
137
138 return *this;
139 }
140
141 FloatDiagMatrix
142 real (const FloatComplexDiagMatrix& a)
143 {
144 FloatDiagMatrix retval;
145 octave_idx_type a_len = a.length ();
146 if (a_len > 0)
147 retval = FloatDiagMatrix (mx_inline_real_dup (a.data (), a_len), a.rows (),
148 a.cols ());
149 return retval;
150 }
151
152 FloatDiagMatrix
153 imag (const FloatComplexDiagMatrix& a)
154 {
155 FloatDiagMatrix retval;
156 octave_idx_type a_len = a.length ();
157 if (a_len > 0)
158 retval = FloatDiagMatrix (mx_inline_imag_dup (a.data (), a_len), a.rows (),
159 a.cols ());
160 return retval;
161 }
162
163 FloatMatrix
164 FloatDiagMatrix::extract (octave_idx_type r1, octave_idx_type c1, octave_idx_type r2, octave_idx_type c2) const
165 {
166 if (r1 > r2) { octave_idx_type tmp = r1; r1 = r2; r2 = tmp; }
167 if (c1 > c2) { octave_idx_type tmp = c1; c1 = c2; c2 = tmp; }
168
169 octave_idx_type new_r = r2 - r1 + 1;
170 octave_idx_type new_c = c2 - c1 + 1;
171
172 FloatMatrix result (new_r, new_c);
173
174 for (octave_idx_type j = 0; j < new_c; j++)
175 for (octave_idx_type i = 0; i < new_r; i++)
176 result.elem (i, j) = elem (r1+i, c1+j);
177
178 return result;
179 }
180
181 // extract row or column i.
182
183 FloatRowVector
184 FloatDiagMatrix::row (octave_idx_type i) const
185 {
186 octave_idx_type r = rows ();
187 octave_idx_type c = cols ();
188 if (i < 0 || i >= r)
189 {
190 (*current_liboctave_error_handler) ("invalid row selection");
191 return FloatRowVector ();
192 }
193
194 FloatRowVector retval (c, 0.0);
195 if (r <= c || (r > c && i < c))
196 retval.elem (i) = elem (i, i);
197
198 return retval;
199 }
200
201 FloatRowVector
202 FloatDiagMatrix::row (char *s) const
203 {
204 if (! s)
205 {
206 (*current_liboctave_error_handler) ("invalid row selection");
207 return FloatRowVector ();
208 }
209
210 char c = *s;
211 if (c == 'f' || c == 'F')
212 return row (static_cast<octave_idx_type>(0));
213 else if (c == 'l' || c == 'L')
214 return row (rows () - 1);
215 else
216 {
217 (*current_liboctave_error_handler) ("invalid row selection");
218 return FloatRowVector ();
219 }
220 }
221
222 FloatColumnVector
223 FloatDiagMatrix::column (octave_idx_type i) const
224 {
225 octave_idx_type r = rows ();
226 octave_idx_type c = cols ();
227 if (i < 0 || i >= c)
228 {
229 (*current_liboctave_error_handler) ("invalid column selection");
230 return FloatColumnVector ();
231 }
232
233 FloatColumnVector retval (r, 0.0);
234 if (r >= c || (r < c && i < r))
235 retval.elem (i) = elem (i, i);
236
237 return retval;
238 }
239
240 FloatColumnVector
241 FloatDiagMatrix::column (char *s) const
242 {
243 if (! s)
244 {
245 (*current_liboctave_error_handler) ("invalid column selection");
246 return FloatColumnVector ();
247 }
248
249 char c = *s;
250 if (c == 'f' || c == 'F')
251 return column (static_cast<octave_idx_type>(0));
252 else if (c == 'l' || c == 'L')
253 return column (cols () - 1);
254 else
255 {
256 (*current_liboctave_error_handler) ("invalid column selection");
257 return FloatColumnVector ();
258 }
259 }
260
261 FloatDiagMatrix
262 FloatDiagMatrix::inverse (void) const
263 {
264 int info;
265 return inverse (info);
266 }
267
268 FloatDiagMatrix
269 FloatDiagMatrix::inverse (int &info) const
270 {
271 octave_idx_type r = rows ();
272 octave_idx_type c = cols ();
273 octave_idx_type len = length ();
274 if (r != c)
275 {
276 (*current_liboctave_error_handler) ("inverse requires square matrix");
277 return FloatDiagMatrix ();
278 }
279
280 FloatDiagMatrix retval (r, c);
281
282 info = 0;
283 for (octave_idx_type i = 0; i < len; i++)
284 {
285 if (elem (i, i) == 0.0)
286 {
287 info = -1;
288 return *this;
289 }
290 else
291 retval.elem (i, i) = 1.0 / elem (i, i);
292 }
293
294 return retval;
295 }
296
// diagonal matrix by diagonal matrix -> diagonal matrix operations
300
301 FloatDiagMatrix
302 operator * (const FloatDiagMatrix& a, const FloatDiagMatrix& b)
303 {
304 octave_idx_type a_nr = a.rows ();
305 octave_idx_type a_nc = a.cols ();
306
307 octave_idx_type b_nr = b.rows ();
308 octave_idx_type b_nc = b.cols ();
309
310 if (a_nc != b_nr)
311 {
312 gripe_nonconformant ("operaotr *", a_nr, a_nc, b_nr, b_nc);
313 return FloatDiagMatrix ();
314 }
315
316 if (a_nr == 0 || a_nc == 0 || b_nc == 0)
317 return FloatDiagMatrix (a_nr, a_nc, 0.0);
318
319 FloatDiagMatrix c (a_nr, b_nc);
320
321 octave_idx_type len = a_nr < b_nc ? a_nr : b_nc;
322
323 for (octave_idx_type i = 0; i < len; i++)
324 {
325 float a_element = a.elem (i, i);
326 float b_element = b.elem (i, i);
327
328 if (a_element == 0.0 || b_element == 0.0)
329 c.elem (i, i) = 0.0;
330 else if (a_element == 1.0)
331 c.elem (i, i) = b_element;
332 else if (b_element == 1.0)
333 c.elem (i, i) = a_element;
334 else
335 c.elem (i, i) = a_element * b_element;
336 }
337
338 return c;
339 }
340
341 // other operations
342
343 FloatColumnVector
344 FloatDiagMatrix::diag (octave_idx_type k) const
345 {
346 octave_idx_type nnr = rows ();
347 octave_idx_type nnc = cols ();
348
349 if (nnr == 0 || nnc == 0)
350
351 if (k > 0)
352 nnc -= k;
353 else if (k < 0)
354 nnr += k;
355
356 FloatColumnVector d;
357
358 if (nnr > 0 && nnc > 0)
359 {
360 octave_idx_type ndiag = (nnr < nnc) ? nnr : nnc;
361
362 d.resize (ndiag);
363
364 if (k > 0)
365 {
366 for (octave_idx_type i = 0; i < ndiag; i++)
367 d.elem (i) = elem (i, i+k);
368 }
369 else if ( k < 0)
370 {
371 for (octave_idx_type i = 0; i < ndiag; i++)
372 d.elem (i) = elem (i-k, i);
373 }
374 else
375 {
376 for (octave_idx_type i = 0; i < ndiag; i++)
377 d.elem (i) = elem (i, i);
378 }
379 }
380 else
381 (*current_liboctave_error_handler)
382 ("diag: requested diagonal out of range");
383
384 return d;
385 }
386
387 std::ostream&
388 operator << (std::ostream& os, const FloatDiagMatrix& a)
389 {
390 // int field_width = os.precision () + 7;
391
392 for (octave_idx_type i = 0; i < a.rows (); i++)
393 {
394 for (octave_idx_type j = 0; j < a.cols (); j++)
395 {
396 if (i == j)
397 os << " " /* setw (field_width) */ << a.elem (i, i);
398 else
399 os << " " /* setw (field_width) */ << 0.0;
400 }
401 os << "\n";
402 }
403 return os;
404 }
405
406 /*
407 ;;; Local Variables: ***
408 ;;; mode: C++ ***
409 ;;; End: ***
410 */