Mercurial > forge
changeset 7301:485d1594d155 octave-forge
addresses undesired side-effect off in-place sorting of data
author | schloegl |
---|---|
date | Wed, 28 Jul 2010 09:28:54 +0000 |
parents | 169f6e9d6bdd |
children | 63bc534a005b |
files | extra/NaN/inst/median.m extra/NaN/src/kth_element.cpp extra/NaN/src/sumskipnan_mex.cpp |
diffstat | 3 files changed, 101 insertions(+), 24 deletions(-) [+] |
line wrap: on
line diff
--- a/extra/NaN/inst/median.m Wed Jul 28 00:40:10 2010 +0000 +++ b/extra/NaN/inst/median.m Wed Jul 28 09:28:54 2010 +0000 @@ -63,17 +63,18 @@ t = t(~isnan(t)); n = length(t); - if flag_MexKthElement, + if n==0, + y(xo) = nan; + elseif flag_MexKthElement, + if (D1==1) t = t+0.0; end; % make sure a real copy (not just a reference to x) is used flag_KthE = 0; % fast kth_element can be used, because t does not contain any NaN and there is need to care about in-place sorting if ~rem(n,2), - y(xo) = sum( kth_element( double(t), n/2 + [0,1] ) ) / 2; + y(xo) = sum( kth_element( double(t), n/2 + [0,1], flag_KthE) ) / 2; elseif rem(n,2), - y(xo) = kth_element(double(t), (n+1)/2 ); + y(xo) = kth_element(double(t), (n+1)/2, flag_KthE); end; else t = sort(t); - if n==0, - y(xo) = nan; - elseif ~rem(n,2), + if ~rem(n,2), y(xo) = (t(n/2) + t(n/2+1)) / 2; elseif rem(n,2), y(xo) = t((n+1)/2);
--- a/extra/NaN/src/kth_element.cpp Wed Jul 28 00:40:10 2010 +0000 +++ b/extra/NaN/src/kth_element.cpp Wed Jul 28 09:28:54 2010 +0000 @@ -1,10 +1,11 @@ //------------------------------------------------------------------- // C-MEX implementation of kth element - this function is part of the NaN-toolbox. // -// usage: x = kth_element(X,k) +// usage: x = kth_element(X,k [,flag]) // returns sort(X)(k) // -// +// References: +// [1] https://secure.wikimedia.org/wikipedia/en/wiki/Selection_algorithm // // This program is free software; you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by @@ -21,9 +22,17 @@ // // // Input: -// - X data vector, must be double/real -// data might be reorded (partially sorted) in place, and NaN's are removed. -// - k which element should be selected +// X data vector, must be double/real +// k which element should be selected +// flag [optional]: +// 0: data in X might be reorded (partially sorted) in-place and +// is slightly faster because no local copy is generated +// data with NaN is not correctly handled. +// 1: data in X is never modified in-place, but a local copy is used. +// data with NaN is not correctly handled. +// 2: copies data and excludes all NaN's, the copying might be slower +// than 1, but it enables a faster selection algorithm. +// This is the save but slowest option // // Output: // x = sort(X)(k) @@ -53,7 +62,7 @@ #define SWAP(a,b) {temp = a; a=b; b=temp;} -static void findFirstK(double array[], size_t left, size_t right, size_t k) +static void findFirstK(double *array, size_t left, size_t right, size_t k) { while (right > left) { mwIndex pivotIndex = (left + right) / 2; @@ -64,7 +73,9 @@ SWAP(array[pivotIndex], array[right]); pivotIndex = left; for (mwIndex i = left; i <= right - 1; ++i ) { - if (array[i] <= pivotValue || isnan(pivotValue)) { + // if (array[i] <= pivotValue || isnan(pivotValue)) // needed if data contains NaN's + if (array[i] <= pivotValue) + { SWAP(array[i], array[pivotIndex]); ++pivotIndex; } @@ -83,23 +94,70 @@ void mexFunction(int POutputCount, mxArray* POutput[], int PInputCount, const mxArray *PInputs[]) { mwIndex k, n; // running indices - mwSize szK, szX; - double *Y,*X,*K; + mwSize szK, szX; + double *T,*X,*Y,*K; + char flag = 0; // default value // check for proper number of input and output arguments - if (PInputCount != 2) { + if ( PInputCount < 2 || PInputCount > 3 ) { mexPrintf("KTH_ELEMENT returns the K-th smallest element of vector X\n"); mexPrintf("\nusage:\tx = kth_element(X,k)\n"); - mexPrintf("\nNote, the elements in X are modified in place. Do not use kth_element directely unless you know what you do. You are warned.\n"); + mexPrintf("\nusage:\tx = kth_element(X,k,flag)\n"); + mexPrintf("\nflag=0: the elements in X can be modified in-place, and data with NaN's is not correctly handled. This can be useful for performance reasons, but it might modify data in-place and is not save for data with NaN's. You are warned.\n"); + mexPrintf("flag=1: prevents in-place modification of X using a local copy of the data, but does not handle data with NaN in the correct way.\n"); + mexPrintf("flag=2: prevents in-place modification of X using a local copy of the data and handles NaN's correctly. This is the save but slowest option.\n"); mexPrintf("\nsee also: median, quantile\n\n"); - mexErrMsgTxt("KTH_ELEMENT requires 2 input arguments\n"); - } + mexErrMsgTxt("KTH_ELEMENT requires two or three input arguments\n"); + } + else if (PInputCount == 3) { + // check value of flag + mwSize N = mxGetNumberOfElements(PInputs[2]); + if (N>1) + mexErrMsgTxt("KTH_ELEMENT: flag argument must be scalar\n"); + else if (N==1) { + switch (mxGetClassID(PInputs[2])) { + case mxLOGICAL_CLASS: + case mxCHAR_CLASS: + case mxINT8_CLASS: + case mxUINT8_CLASS: + flag = (char)*(uint8_t*)mxGetData(PInputs[2]); + break; + case mxDOUBLE_CLASS: + flag = (char)*(double*)mxGetData(PInputs[2]); + break; + case mxSINGLE_CLASS: + flag = (char)*(float*)mxGetData(PInputs[2]); + break; + case mxINT16_CLASS: + case mxUINT16_CLASS: + flag = (char)*(uint16_t*)mxGetData(PInputs[2]); + break; + case mxINT32_CLASS: + case mxUINT32_CLASS: + flag = (char)*(uint32_t*)mxGetData(PInputs[2]); + break; + case mxINT64_CLASS: + case mxUINT64_CLASS: + flag = (char)*(uint64_t*)mxGetData(PInputs[2]); + break; + case mxFUNCTION_CLASS: + case mxUNKNOWN_CLASS: + case mxCELL_CLASS: + case mxSTRUCT_CLASS: + default: + mexErrMsgTxt("KTH_ELEMENT: Type of 3rd input argument not supported."); + } + } + // else flag = default value + } + // else flag = default value + if (POutputCount > 2) - mexErrMsgTxt("KTH_ELEMENT has 1 output arguments."); + mexErrMsgTxt("KTH_ELEMENT has only one output arguments."); // get 1st argument - if (mxIsComplex(PInputs[0])) + if (mxIsComplex(PInputs[0]) || mxIsComplex(PInputs[1])) mexErrMsgTxt("complex argument not supported (yet). "); if (!mxIsDouble(PInputs[0]) || !mxIsDouble(PInputs[1])) mexErrMsgTxt("input arguments must be of type double . "); @@ -112,8 +170,23 @@ szX = mxGetNumberOfElements(PInputs[0]); X = (double*)mxGetData(PInputs[0]); + if (flag==0) + T = X; + else { + //***** create temporary copy for avoiding unintended side effects (in-place sort of input data) */ + T = (double*)mxMalloc(szX*sizeof(double)); + if (flag==1) + memcpy(T,X,szX*sizeof(double)); + else { + /* do not copy NaN's */ + for (k=0,n=0; k < szX; k++) { + if (!isnan(X[k])) T[n++]=X[k]; + } + szX = n; + } + } + /*********** create output arguments *****************/ - POutput[0] = mxCreateDoubleMatrix(mxGetM(PInputs[1]),mxGetN(PInputs[1]),mxREAL); Y = (double*) mxGetData(POutput[0]); for (k=0; k < szK; k++) { @@ -121,11 +194,13 @@ if (n >= szX || n < 0) Y[k] = 0.0/0.0; // NaN: result undefined else { - findFirstK(X, 0, szX-1, n); - Y[k] = X[n]; + findFirstK(T, 0, szX-1, n); + Y[k] = T[n]; } } + if (flag) mxFree(T); + return; }