changeset 7301:485d1594d155 octave-forge

addresses undesired side-effect off in-place sorting of data
author schloegl
date Wed, 28 Jul 2010 09:28:54 +0000
parents 169f6e9d6bdd
children 63bc534a005b
files extra/NaN/inst/median.m extra/NaN/src/kth_element.cpp extra/NaN/src/sumskipnan_mex.cpp
diffstat 3 files changed, 101 insertions(+), 24 deletions(-) [+]
line wrap: on
line diff
--- a/extra/NaN/inst/median.m	Wed Jul 28 00:40:10 2010 +0000
+++ b/extra/NaN/inst/median.m	Wed Jul 28 09:28:54 2010 +0000
@@ -63,17 +63,18 @@
         t = t(~isnan(t));
         n = length(t);
 
-        if flag_MexKthElement, 
+        if n==0,
+                y(xo) = nan;
+        elseif flag_MexKthElement, 
+	        if (D1==1) t = t+0.0; end;	% make sure a real copy (not just a reference to x) is used
        	flag_KthE = 0; % fast kth_element can be used, because t does not contain any NaN and there is need to care about in-place sorting
                 if ~rem(n,2),
-                        y(xo) = sum( kth_element( double(t), n/2 + [0,1] ) ) / 2;
+                        y(xo) = sum( kth_element( double(t), n/2 + [0,1], flag_KthE) ) / 2;
                 elseif rem(n,2),
-                        y(xo) = kth_element(double(t), (n+1)/2 );
+                        y(xo) = kth_element(double(t), (n+1)/2, flag_KthE);
                 end;
         else
                 t = sort(t);
-                if n==0,
-                        y(xo) = nan;
-                elseif ~rem(n,2),
+                if ~rem(n,2),
                         y(xo) = (t(n/2) + t(n/2+1)) / 2;
                 elseif rem(n,2),
                         y(xo) = t((n+1)/2);
--- a/extra/NaN/src/kth_element.cpp	Wed Jul 28 00:40:10 2010 +0000
+++ b/extra/NaN/src/kth_element.cpp	Wed Jul 28 09:28:54 2010 +0000
@@ -1,10 +1,11 @@
 //-------------------------------------------------------------------
 //   C-MEX implementation of kth element - this function is part of the NaN-toolbox. 
 //
-//   usage: x = kth_element(X,k)
+//   usage: x = kth_element(X,k [,flag])
 //          returns sort(X)(k)
 //
-//   
+//   References: 
+//   [1] https://secure.wikimedia.org/wikipedia/en/wiki/Selection_algorithm
 //
 //   This program is free software; you can redistribute it and/or modify
 //   it under the terms of the GNU General Public License as published by
@@ -21,9 +22,17 @@
 //
 //
 // Input:
-// - X  data vector, must be double/real 
-//        data might be reorded (partially sorted) in place, and NaN's are removed. 
-// - k  which element should be selected
+//   X    data vector, must be double/real 
+//   k    which element should be selected
+//   flag [optional]: 
+//	 0: data in X might be reorded (partially sorted) in-place and 
+//	    is slightly faster because no local copy is generated
+//	    data with NaN is not correctly handled. 	
+//       1: data in X is never modified in-place, but a local copy is used.    
+//	    data with NaN is not correctly handled. 	
+//	 2: copies data and excludes all NaN's, the copying might be slower 
+//	    than 1, but it enables a faster selection algorithm. 
+//	    This is the save but slowest option	
 //
 // Output:
 //      x = sort(X)(k)  
@@ -53,7 +62,7 @@
 
 #define SWAP(a,b) {temp = a; a=b; b=temp;} 
  
-static void findFirstK(double array[], size_t left, size_t right, size_t k)
+static void findFirstK(double *array, size_t left, size_t right, size_t k)
 {
         while (right > left) {
                 mwIndex pivotIndex = (left + right) / 2;
@@ -64,7 +73,9 @@
         	SWAP(array[pivotIndex], array[right]);
         	pivotIndex = left;
 	        for (mwIndex i = left; i <= right - 1; ++i ) {
-        	        if (array[i] <= pivotValue || isnan(pivotValue)) {
+        	        // if (array[i] <= pivotValue || isnan(pivotValue)) // needed if data contains NaN's 
+        	        if (array[i] <= pivotValue) 
+        	        {
         	        	SWAP(array[i], array[pivotIndex]);
         	                ++pivotIndex;
                 	}
@@ -83,23 +94,70 @@
 void mexFunction(int POutputCount,  mxArray* POutput[], int PInputCount, const mxArray *PInputs[]) 
 {
     	mwIndex k, n;	// running indices 
-    	mwSize  szK, szX; 
-    	double 	*Y,*X,*K; 
+    	mwSize  szK, szX;
+    	double 	*T,*X,*Y,*K; 
+    	char flag = 0; // default value 
 
 	// check for proper number of input and output arguments
-	if (PInputCount != 2) {
+	if ( PInputCount < 2 || PInputCount > 3 ) {
 		mexPrintf("KTH_ELEMENT returns the K-th smallest element of vector X\n");
 		mexPrintf("\nusage:\tx = kth_element(X,k)\n");
-		mexPrintf("\nNote, the elements in X are modified in place. Do not use kth_element directely unless you know what you do. You are warned.\n");
+		mexPrintf("\nusage:\tx = kth_element(X,k,flag)\n");
+		mexPrintf("\nflag=0: the elements in X can be modified in-place, and data with NaN's is not correctly handled. This can be useful for performance reasons, but it might modify data in-place and is not save for data with NaN's. You are warned.\n");
+		mexPrintf("flag=1: prevents in-place modification of X using a local copy of the data, but does not handle data with NaN in the correct way.\n");
+		mexPrintf("flag=2: prevents in-place modification of X using a local copy of the data and handles NaN's correctly. This is the save but slowest option.\n");
 		
 	    	mexPrintf("\nsee also: median, quantile\n\n");
-	        mexErrMsgTxt("KTH_ELEMENT requires 2 input arguments\n");
-	}        
+	        mexErrMsgTxt("KTH_ELEMENT requires two or three input arguments\n");
+	}  
+	else if (PInputCount == 3) {
+		// check value of flag
+		mwSize N = mxGetNumberOfElements(PInputs[2]); 
+		if (N>1)
+		        mexErrMsgTxt("KTH_ELEMENT: flag argument must be scalar\n");
+		else if (N==1) {
+	    		switch (mxGetClassID(PInputs[2])) {
+    			case mxLOGICAL_CLASS:
+    			case mxCHAR_CLASS:
+	    		case mxINT8_CLASS:
+    			case mxUINT8_CLASS:
+    				flag = (char)*(uint8_t*)mxGetData(PInputs[2]);
+    				break; 
+	    		case mxDOUBLE_CLASS:
+    				flag = (char)*(double*)mxGetData(PInputs[2]);
+    				break; 
+	    		case mxSINGLE_CLASS:
+    				flag = (char)*(float*)mxGetData(PInputs[2]);
+    				break; 
+	    		case mxINT16_CLASS:
+    			case mxUINT16_CLASS:
+    				flag = (char)*(uint16_t*)mxGetData(PInputs[2]);
+    				break; 
+	    		case mxINT32_CLASS:
+    			case mxUINT32_CLASS:
+    				flag = (char)*(uint32_t*)mxGetData(PInputs[2]);
+    				break; 
+	    		case mxINT64_CLASS:
+    			case mxUINT64_CLASS:
+    				flag = (char)*(uint64_t*)mxGetData(PInputs[2]);
+    				break; 
+	    		case mxFUNCTION_CLASS:
+    			case mxUNKNOWN_CLASS:
+    			case mxCELL_CLASS:
+	    		case mxSTRUCT_CLASS:
+    			default: 
+    				mexErrMsgTxt("KTH_ELEMENT: Type of 3rd input argument not supported.");
+			}
+		}
+		// else flag = default value 
+	}      
+	// else flag = default value 
+
 	if (POutputCount > 2)
-	        mexErrMsgTxt("KTH_ELEMENT has 1 output arguments.");
+	        mexErrMsgTxt("KTH_ELEMENT has only one output arguments.");
 
 	// get 1st argument
-	if (mxIsComplex(PInputs[0]))
+	if (mxIsComplex(PInputs[0]) || mxIsComplex(PInputs[1]))
 		mexErrMsgTxt("complex argument not supported (yet). ");
 	if (!mxIsDouble(PInputs[0]) || !mxIsDouble(PInputs[1]))
 		mexErrMsgTxt("input arguments must be of type double . ");
@@ -112,8 +170,23 @@
 	szX = mxGetNumberOfElements(PInputs[0]);
 	X = (double*)mxGetData(PInputs[0]);
 
+	if (flag==0)
+		T = X;
+	else {
+		//***** create temporary copy for avoiding unintended side effects (in-place sort of input data) */ 
+		T = (double*)mxMalloc(szX*sizeof(double)); 
+		if (flag==1)
+			memcpy(T,X,szX*sizeof(double));
+		else  {
+			/* do not copy NaN's */
+			for (k=0,n=0; k < szX; k++) {
+				if (!isnan(X[k])) T[n++]=X[k]; 
+			}
+			szX = n; 
+		}	
+	}
+	
 	/*********** create output arguments *****************/
-
 	POutput[0] = mxCreateDoubleMatrix(mxGetM(PInputs[1]),mxGetN(PInputs[1]),mxREAL);
 	Y = (double*) mxGetData(POutput[0]);
 	for (k=0; k < szK; k++) {
@@ -121,11 +194,13 @@
 		if (n >= szX || n < 0)
 			Y[k] = 0.0/0.0;	// NaN: result undefined
 		else {
-        		findFirstK(X, 0, szX-1, n);
-        		Y[k] = X[n];
+        		findFirstK(T, 0, szX-1, n);
+        		Y[k] = T[n];
 		}	
 	}
 
+	if (flag) mxFree(T);
+	
 	return; 
 }
 
--- a/extra/NaN/src/sumskipnan_mex.cpp	Wed Jul 28 00:40:10 2010 +0000
+++ b/extra/NaN/src/sumskipnan_mex.cpp	Wed Jul 28 09:28:54 2010 +0000
@@ -448,6 +448,7 @@
     		case mxUNKNOWN_CLASS:
     		case mxCELL_CLASS:
     		case mxSTRUCT_CLASS:
+    		default: 
     			mexPrintf("Type of 3rd input argument not supported.");
 		}
 	}