comparison liboctave/CmplxSVD.cc @ 15887:8ced82e96b48 stable

Fix segfaults with gesdd driver for svd (bug #37998).

* liboctave/CmplxSVD.cc(init): Correctly size rwork array for gesdd driver.
* liboctave/fCmplxSVD.cc(init): Correctly size rwork array for gesdd driver.
* liboctave/dbleSVD.cc(init): Tweak coding style to match CmplxSVD.cc.
* liboctave/floatSVD.cc(init): Tweak coding style to match fCmplxSVD.cc.
* src/DLD-FUNCTIONS/svd.cc: Add %!test for gesdd driver and complex matrices.
author Rik <rik@octave.org>
date Thu, 03 Jan 2013 10:05:03 -0800
parents 72c96de7a403
children
comparison
equal deleted inserted replaced
15885:065bc7944335 15887:8ced82e96b48
116 // the singular values of [eye(3), eye(3)]. The result is 116 // the singular values of [eye(3), eye(3)]. The result is
117 // [-sqrt(2), -sqrt(2), -sqrt(2)]. 117 // [-sqrt(2), -sqrt(2), -sqrt(2)].
118 // 118 //
119 // For Lapack 3.0, this problem seems to be fixed. 119 // For Lapack 3.0, this problem seems to be fixed.
120 120
121 jobu = 'N'; 121 jobu = jobv = 'N';
122 jobv = 'N';
123 ncol_u = nrow_vt = 1; 122 ncol_u = nrow_vt = 1;
124 break; 123 break;
125 124
126 default: 125 default:
127 break; 126 break;
140 if (! (jobv == 'N' || jobv == 'O')) 139 if (! (jobv == 'N' || jobv == 'O'))
141 right_sm.resize (nrow_vt, n); 140 right_sm.resize (nrow_vt, n);
142 141
143 Complex *vt = right_sm.fortran_vec (); 142 Complex *vt = right_sm.fortran_vec ();
144 143
145 octave_idx_type lrwork = 5*max_mn; 144 // Query ZGESVD for the correct dimension of WORK.
146
147 Array<double> rwork (dim_vector (lrwork, 1));
148
149 // Ask ZGESVD what the dimension of WORK should be.
150 145
151 octave_idx_type lwork = -1; 146 octave_idx_type lwork = -1;
152 147
153 Array<Complex> work (dim_vector (1, 1)); 148 Array<Complex> work (dim_vector (1, 1));
154 149
155 octave_idx_type one = 1; 150 octave_idx_type one = 1;
156 octave_idx_type m1 = std::max (m, one), nrow_vt1 = std::max (nrow_vt, one); 151 octave_idx_type m1 = std::max (m, one);
152 octave_idx_type nrow_vt1 = std::max (nrow_vt, one);
157 153
158 if (svd_driver == SVD::GESVD) 154 if (svd_driver == SVD::GESVD)
159 { 155 {
156 octave_idx_type lrwork = 5*max_mn;
157 Array<double> rwork (dim_vector (lrwork, 1));
158
160 F77_XFCN (zgesvd, ZGESVD, (F77_CONST_CHAR_ARG2 (&jobu, 1), 159 F77_XFCN (zgesvd, ZGESVD, (F77_CONST_CHAR_ARG2 (&jobu, 1),
161 F77_CONST_CHAR_ARG2 (&jobv, 1), 160 F77_CONST_CHAR_ARG2 (&jobv, 1),
162 m, n, tmp_data, m1, s_vec, u, m1, vt, 161 m, n, tmp_data, m1, s_vec, u, m1, vt,
163 nrow_vt1, work.fortran_vec (), lwork, 162 nrow_vt1, work.fortran_vec (), lwork,
164 rwork.fortran_vec (), info 163 rwork.fortran_vec (), info
178 } 177 }
179 else if (svd_driver == SVD::GESDD) 178 else if (svd_driver == SVD::GESDD)
180 { 179 {
181 assert (jobu == jobv); 180 assert (jobu == jobv);
182 char jobz = jobu; 181 char jobz = jobu;
182
183 octave_idx_type lrwork;
184 if (jobz == 'N')
185 lrwork = 7*min_mn;
186 else
187 lrwork = 5*min_mn*min_mn + 5*min_mn;
188 Array<double> rwork (dim_vector (lrwork, 1));
189
183 OCTAVE_LOCAL_BUFFER (octave_idx_type, iwork, 8*min_mn); 190 OCTAVE_LOCAL_BUFFER (octave_idx_type, iwork, 8*min_mn);
184 191
185 F77_XFCN (zgesdd, ZGESDD, (F77_CONST_CHAR_ARG2 (&jobz, 1), 192 F77_XFCN (zgesdd, ZGESDD, (F77_CONST_CHAR_ARG2 (&jobz, 1),
186 m, n, tmp_data, m1, s_vec, u, m1, vt, 193 m, n, tmp_data, m1, s_vec, u, m1, vt,
187 nrow_vt1, work.fortran_vec (), lwork, 194 nrow_vt1, work.fortran_vec (), lwork,