changeset 12592:d03ad555e14e octave-forge

[nan] fix liblinear - use train/predict from liblinear/matlab
author schloegl
date Sun, 12 Apr 2015 19:00:36 +0000
parents 3aeba3530595
children 0605cb0434ff
files extra/NaN/src/predict.c extra/NaN/src/train.c
diffstat 2 files changed, 472 insertions(+), 369 deletions(-) [+]
line wrap: on
line diff
--- a/extra/NaN/src/predict.c	Sun Apr 12 17:52:15 2015 +0000
+++ b/extra/NaN/src/predict.c	Sun Apr 12 19:00:36 2015 +0000
@@ -24,220 +24,305 @@
 */
 
 #include <stdio.h>
-#include <ctype.h>
 #include <stdlib.h>
 #include <string.h>
-#include <errno.h>
 #include "linear.h"
 
-struct feature_node *x;
-int max_nr_attr = 64;
+#include "mex.h"
+#include "linear_model_matlab.h"
+
+#ifdef tmwtypes_h
+  #if (MX_API_VER<=0x07020000)
+    typedef int mwSize;
+  #endif 
+#endif 
+
+
+#define CMD_LEN 2048
+
+#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
+
+int col_format_flag;
 
-struct model* model_;
-int flag_predict_probability=0;
+void read_sparse_instance(const mxArray *prhs, int index, struct feature_node *x, int feature_number, double bias)
+{
+	int i, j, low, high;
+	mwIndex *ir, *jc;
+	double *samples;
+
+	ir = mxGetIr(prhs);
+	jc = mxGetJc(prhs);
+	samples = mxGetPr(prhs);
 
-void exit_input_error(int line_num)
-{
-	fprintf(stderr,"Wrong input format at line %d\n", line_num);
-	exit(1);
+	// each column is one instance
+	j = 0;
+	low = (int) jc[index], high = (int) jc[index+1];
+	for(i=low; i<high && (int) (ir[i])<feature_number; i++)
+	{
+		x[j].index = (int) ir[i]+1;
+		x[j].value = samples[i];
+		j++;
+	}
+	if(bias>=0)
+	{
+		x[j].index = feature_number+1;
+		x[j].value = bias;
+		j++;
+	}
+	x[j].index = -1;
 }
 
-static char *line = NULL;
-static int max_line_len;
-
-static char* readline(FILE *input)
+static void fake_answer(mxArray *plhs[])
 {
-	int len;
-	
-	if(fgets(line,max_line_len,input) == NULL)
-		return NULL;
-
-	while(strrchr(line,'\n') == NULL)
-	{
-		max_line_len *= 2;
-		line = (char *) realloc(line,max_line_len);
-		len = (int) strlen(line);
-		if(fgets(line+len,max_line_len-len,input) == NULL)
-			break;
-	}
-	return line;
+	plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
+	plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
+	plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
 }
 
-void do_predict(FILE *input, FILE *output, struct model* model_)
+void do_predict(mxArray *plhs[], const mxArray *prhs[], struct model *model_, const int predict_probability_flag)
 {
+	int label_vector_row_num, label_vector_col_num;
+	int feature_number, testing_instance_number;
+	int instance_index;
+	double *ptr_instance, *ptr_label, *ptr_predict_label;
+	double *ptr_prob_estimates, *ptr_dec_values, *ptr;
+	struct feature_node *x;
+	mxArray *pplhs[1]; // instance sparse matrix in row format
+
 	int correct = 0;
 	int total = 0;
 
 	int nr_class=get_nr_class(model_);
+	int nr_w;
 	double *prob_estimates=NULL;
-	int j, n;
-	int nr_feature=get_nr_feature(model_);
-	if(model_->bias>=0)
-		n=nr_feature+1;
+
+	if(nr_class==2 && model_->param.solver_type!=MCSVM_CS)
+		nr_w=1;
 	else
-		n=nr_feature;
-
-	if(flag_predict_probability)
-	{
-		int *labels;
+		nr_w=nr_class;
 
-		if(!check_probability_model(model_))
-		{
-			fprintf(stderr, "probability output is only supported for logistic regression\n");
-			exit(1);
-		}
-
-		labels=(int *) malloc(nr_class*sizeof(int));
-		get_labels(model_,labels);
-		prob_estimates = (double *) malloc(nr_class*sizeof(double));
-		fprintf(output,"labels");		
-		for(j=0;j<nr_class;j++)
-			fprintf(output," %d",labels[j]);
-		fprintf(output,"\n");
-		free(labels);
+	// prhs[1] = testing instance matrix
+	feature_number = get_nr_feature(model_);
+	testing_instance_number = (int) mxGetM(prhs[1]);
+	if(col_format_flag)
+	{
+		feature_number = (int) mxGetM(prhs[1]);
+		testing_instance_number = (int) mxGetN(prhs[1]);
 	}
 
-	max_line_len = 1024;
-	line = (char *)malloc(max_line_len*sizeof(char));
-	while(readline(input) != NULL)
-	{
-		int i = 0;
-		int target_label, predict_label;
-		char *idx, *val, *label, *endptr;
-		int inst_max_index = 0; // strtol gives 0 if wrong format
-
-		label = strtok(line," \t\n");
-		if(label == NULL) // empty line
-			exit_input_error(total+1);
-
-		target_label = (int) strtol(label,&endptr,10);
-		if(endptr == label || *endptr != '\0')
-			exit_input_error(total+1);
-
-		while(1)
-		{
-			if(i>=max_nr_attr-2)	// need one more for index = -1
-			{
-				max_nr_attr *= 2;
-				x = (struct feature_node *) realloc(x,max_nr_attr*sizeof(struct feature_node));
-			}
-
-			idx = strtok(NULL,":");
-			val = strtok(NULL," \t");
+	label_vector_row_num = (int) mxGetM(prhs[0]);
+	label_vector_col_num = (int) mxGetN(prhs[0]);
 
-			if(val == NULL)
-				break;
-			errno = 0;
-			x[i].index = (int) strtol(idx,&endptr,10);
-			if(endptr == idx || errno != 0 || *endptr != '\0' || x[i].index <= inst_max_index)
-				exit_input_error(total+1);
-			else
-				inst_max_index = x[i].index;
-
-			errno = 0;
-			x[i].value = strtod(val,&endptr);
-			if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
-				exit_input_error(total+1);
+	if(label_vector_row_num!=testing_instance_number)
+	{
+		mexPrintf("Length of label vector does not match # of instances.\n");
+		fake_answer(plhs);
+		return;
+	}
+	if(label_vector_col_num!=1)
+	{
+		mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
+		fake_answer(plhs);
+		return;
+	}
 
-			// feature indices larger than those in training are not used
-			if(x[i].index <= nr_feature)
-				++i;
-		}
+	ptr_instance = mxGetPr(prhs[1]);
+	ptr_label    = mxGetPr(prhs[0]);
 
-		if(model_->bias>=0)
+	// transpose instance matrix
+	if(mxIsSparse(prhs[1]))
+	{
+		if(col_format_flag)
 		{
-			x[i].index = n;
-			x[i].value = model_->bias;
-			i++;
-		}
-		x[i].index = -1;
-
-		if(flag_predict_probability)
-		{
-			int j;
-			predict_label = predict_probability(model_,x,prob_estimates);
-			fprintf(output,"%d",predict_label);
-			for(j=0;j<model_->nr_class;j++)
-				fprintf(output," %g",prob_estimates[j]);
-			fprintf(output,"\n");
+			pplhs[0] = (mxArray *)prhs[1];
 		}
 		else
 		{
-			predict_label = predict(model_,x);
-			fprintf(output,"%d\n",predict_label);
+			mxArray *pprhs[1];
+			pprhs[0] = mxDuplicateArray(prhs[1]);
+			if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
+			{
+				mexPrintf("Error: cannot transpose testing instance matrix\n");
+				fake_answer(plhs);
+				return;
+			}
+		}
+	}
+	else
+		mexPrintf("Testing_instance_matrix must be sparse\n");
+
+
+	prob_estimates = Malloc(double, nr_class);
+
+	plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
+	if(predict_probability_flag)
+		plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
+	else
+		plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_w, mxREAL);
+
+	ptr_predict_label = mxGetPr(plhs[0]);
+	ptr_prob_estimates = mxGetPr(plhs[2]);
+	ptr_dec_values = mxGetPr(plhs[2]);
+	x = Malloc(struct feature_node, feature_number+2);
+	for(instance_index=0;instance_index<testing_instance_number;instance_index++)
+	{
+		int i;
+		double target,v;
+
+		target = ptr_label[instance_index];
+
+		// prhs[1] and prhs[1]^T are sparse
+		read_sparse_instance(pplhs[0], instance_index, x, feature_number, model_->bias);
+
+		if(predict_probability_flag)
+		{
+			v = predict_probability(model_, x, prob_estimates);
+			ptr_predict_label[instance_index] = v;
+			for(i=0;i<nr_class;i++)
+				ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
+		}
+		else
+		{
+			double *dec_values = Malloc(double, nr_class);
+			v = predict(model_, x);
+			ptr_predict_label[instance_index] = v;
+
+			predict_values(model_, x, dec_values);
+			for(i=0;i<nr_w;i++)
+				ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
+			free(dec_values);
 		}
 
-		if(predict_label == target_label)
+		if(v == target)
 			++correct;
 		++total;
 	}
-	printf("Accuracy = %g%% (%d/%d)\n",(double) correct/total*100,correct,total);
-	if(flag_predict_probability)
+	mexPrintf("Accuracy = %g%% (%d/%d)\n", (double) correct/total*100,correct,total);
+
+	// return accuracy, mean squared error, squared correlation coefficient
+	plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL);
+	ptr = mxGetPr(plhs[1]);
+	ptr[0] = (double) correct/total*100;
+
+	free(x);
+	if(prob_estimates != NULL)
 		free(prob_estimates);
 }
 
 void exit_with_help()
 {
-	printf(
-	"Usage: predict [options] test_file model_file output_file\n"
-	"options:\n"
-	"-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0)\n"
-	);
-	exit(1);
+	mexPrintf(
+			"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = predict(testing_label_vector, testing_instance_matrix, model, 'liblinear_options','col')\n"
+			"liblinear_options:\n"
+			"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0)\n"
+			"col:\n"
+			"	if 'col' is setted testing_instance_matrix is parsed in column format, otherwise is in row format"
+			);
 }
 
-int main(int argc, char **argv)
+void mexFunction( int nlhs, mxArray *plhs[],
+		int nrhs, const mxArray *prhs[] )
 {
-	FILE *input, *output;
-	int i;
+	int prob_estimate_flag = 0;
+	struct model *model_;
+	char cmd[CMD_LEN];
+	col_format_flag = 0;
 
-	// parse options
-	for(i=1;i<argc;i++)
+	if(nrhs > 5 || nrhs < 3)
 	{
-		if(argv[i][0] != '-') break;
-		++i;
-		switch(argv[i-1][1])
-		{
-			case 'b':
-				flag_predict_probability = atoi(argv[i]);
-				break;
-
-			default:
-				fprintf(stderr,"unknown option: -%c\n", argv[i-1][1]);
-				exit_with_help();
-				break;
+		exit_with_help();
+		fake_answer(plhs);
+		return;
+	}
+	if(nrhs == 5)
+	{
+		mxGetString(prhs[4], cmd, mxGetN(prhs[4])+1);
+		if(strcmp(cmd, "col") == 0)
+		{			
+			col_format_flag = 1;
 		}
 	}
-	if(i>=argc)
-		exit_with_help();
 
-	input = fopen(argv[i],"r");
-	if(input == NULL)
-	{
-		fprintf(stderr,"can't open input file %s\n",argv[i]);
-		exit(1);
-	}
-
-	output = fopen(argv[i+2],"w");
-	if(output == NULL)
-	{
-		fprintf(stderr,"can't open output file %s\n",argv[i+2]);
-		exit(1);
+	if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
+		mexPrintf("Error: label vector and instance matrix must be double\n");
+		fake_answer(plhs);
+		return;
 	}
 
-	if((model_=load_model(argv[i+1]))==0)
+	if(mxIsStruct(prhs[2]))
 	{
-		fprintf(stderr,"can't open model file %s\n",argv[i+1]);
-		exit(1);
+		const char *error_msg;
+
+		// parse options
+		if(nrhs>=4)
+		{
+			int i, argc = 1;
+			char *argv[CMD_LEN/2];
+
+			// put options in argv[]
+			mxGetString(prhs[3], cmd,  mxGetN(prhs[3]) + 1);
+			if((argv[argc] = strtok(cmd, " ")) != NULL)
+				while((argv[++argc] = strtok(NULL, " ")) != NULL)
+					;
+
+			for(i=1;i<argc;i++)
+			{
+				if(argv[i][0] != '-') break;
+				if(++i>=argc)
+				{
+					exit_with_help();
+					fake_answer(plhs);
+					return;
+				}
+				switch(argv[i-1][1])
+				{
+					case 'b':
+						prob_estimate_flag = atoi(argv[i]);
+						break;
+					default:
+						mexPrintf("unknown option\n");
+						exit_with_help();
+						fake_answer(plhs);
+						return;
+				}
+			}
+		}
+
+		model_ = Malloc(struct model, 1);
+		error_msg = matlab_matrix_to_model(model_, prhs[2]);
+		if(error_msg)
+		{
+			mexPrintf("Error: can't read model: %s\n", error_msg);
+			free_and_destroy_model(&model_);
+			fake_answer(plhs);
+			return;
+		}
+
+		if(prob_estimate_flag)
+		{
+			if(!check_probability_model(model_))
+			{
+				mexPrintf("probability output is only supported for logistic regression\n");
+				prob_estimate_flag=0;
+			}
+		}
+
+		if(mxIsSparse(prhs[1]))
+			do_predict(plhs, prhs, model_, prob_estimate_flag);
+		else
+		{
+			mexPrintf("Testing_instance_matrix must be sparse\n");
+			fake_answer(plhs);
+		}
+
+		// destroy model_
+		free_and_destroy_model(&model_);
+	}
+	else
+	{
+		mexPrintf("model file should be a struct array\n");
+		fake_answer(plhs);
 	}
 
-	x = (struct feature_node *) malloc(max_nr_attr*sizeof(struct feature_node));
-	do_predict(input, output, model_);
-	free_and_destroy_model(&model_);
-	free(line);
-	free(x);
-	fclose(input);
-	fclose(output);
-	return 0;
+	return;
 }
-
--- a/extra/NaN/src/train.c	Sun Apr 12 17:52:15 2015 +0000
+++ b/extra/NaN/src/train.c	Sun Apr 12 19:00:36 2015 +0000
@@ -28,18 +28,29 @@
 #include <stdlib.h>
 #include <string.h>
 #include <ctype.h>
-#include <errno.h>
 #include "linear.h"
+
+#include "mex.h"
+#include "linear_model_matlab.h"
+
+#ifdef tmwtypes_h
+  #if (MX_API_VER<=0x07020000)
+    typedef int mwIndex;
+  #endif 
+#endif 
+
+#define CMD_LEN 2048
 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
 #define INF HUGE_VAL
 
 void print_null(const char *s) {}
+void print_string_matlab(const char *s) {mexPrintf(s);}
 
 void exit_with_help()
 {
-	printf(
-	"Usage: train [options] training_set_file [model_file]\n"
-	"options:\n"
+	mexPrintf(
+	"Usage: model = train(training_label_vector, training_instance_matrix, 'liblinear_options', 'col');\n"
+	"liblinear_options:\n"
 	"-s type : set type of solver (default 1)\n"
 	"	0 -- L2-regularized logistic regression (primal)\n"
 	"	1 -- L2-regularized L2-loss support vector classification (dual)\n"	
@@ -64,108 +75,47 @@
 	"-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
 	"-v n: n-fold cross validation mode\n"
 	"-q : quiet mode (no outputs)\n"
+	"col:\n"
+	"	if 'col' is setted, training_instance_matrix is parsed in column format, otherwise is in row format\n"
 	);
-	exit(1);
-}
-
-void exit_input_error(int line_num)
-{
-	fprintf(stderr,"Wrong input format at line %d\n", line_num);
-	exit(1);
 }
 
-static char *line = NULL;
-static int max_line_len;
-
-static char* readline(FILE *input)
-{
-	int len;
-	
-	if(fgets(line,max_line_len,input) == NULL)
-		return NULL;
-
-	while(strrchr(line,'\n') == NULL)
-	{
-		max_line_len *= 2;
-		line = (char *) realloc(line,max_line_len);
-		len = (int) strlen(line);
-		if(fgets(line+len,max_line_len-len,input) == NULL)
-			break;
-	}
-	return line;
-}
-
-void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
-void read_problem(const char *filename);
-void do_cross_validation();
-
+// liblinear arguments
+struct parameter param;		// set by parse_command_line
+struct problem prob;		// set by read_problem
+struct model *model_;
 struct feature_node *x_space;
-struct parameter param;
-struct problem prob;
-struct model* model_;
-int flag_cross_validation;
+int cross_validation_flag;
+int col_format_flag;
 int nr_fold;
 double bias;
 
-int main(int argc, char **argv)
-{
-	char input_file_name[1024];
-	char model_file_name[1024];
-	const char *error_msg;
-
-	parse_command_line(argc, argv, input_file_name, model_file_name);
-	read_problem(input_file_name);
-	error_msg = check_parameter(&prob,&param);
-
-	if(error_msg)
-	{
-		fprintf(stderr,"Error: %s\n",error_msg);
-		exit(1);
-	}
-
-	if(flag_cross_validation)
-	{
-		do_cross_validation();
-	}
-	else
-	{
-		model_=train(&prob, &param);
-		if(save_model(model_file_name, model_))
-		{
-			fprintf(stderr,"can't save model to file %s\n",model_file_name);
-			exit(1);
-		}
-		free_and_destroy_model(&model_);
-	}
-	destroy_param(&param);
-	free(prob.y);
-	free(prob.x);
-	free(x_space);
-	free(line);
-
-	return 0;
-}
-
-void do_cross_validation()
+double do_cross_validation()
 {
 	int i;
 	int total_correct = 0;
-	int *target = Malloc(int, prob.l);
+	int *target = Malloc(int,prob.l);
+	double retval = 0.0;
 
 	cross_validation(&prob,&param,nr_fold,target);
 
 	for(i=0;i<prob.l;i++)
 		if(target[i] == prob.y[i])
 			++total_correct;
-	printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
+	mexPrintf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
+	retval = 100.0*total_correct/prob.l;
 
 	free(target);
+	return retval;
 }
 
-void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
+// nrhs should be 3
+int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
 {
-	int i;
-	void (*print_func)(const char*) = NULL;	// default printing to stdout
+	int i, argc = 1;
+	char cmd[CMD_LEN];
+	char *argv[CMD_LEN/2];
+	void (*print_func)(const char *) = print_string_matlab;	// default printing to matlab display
 
 	// default values
 	param.solver_type = L2R_L2LOSS_SVC_DUAL;
@@ -174,33 +124,60 @@
 	param.nr_weight = 0;
 	param.weight_label = NULL;
 	param.weight = NULL;
-	flag_cross_validation = 0;
+	cross_validation_flag = 0;
+	col_format_flag = 0;
 	bias = -1;
 
+
+	if(nrhs <= 1)
+		return 1;
+
+	if(nrhs == 4)
+	{
+		mxGetString(prhs[3], cmd, mxGetN(prhs[3])+1);
+		if(strcmp(cmd, "col") == 0)
+			col_format_flag = 1;
+	}
+
+	// put options in argv[]
+	if(nrhs > 2)
+	{
+		mxGetString(prhs[2], cmd,  mxGetN(prhs[2]) + 1);
+		if((argv[argc] = strtok(cmd, " ")) != NULL)
+			while((argv[++argc] = strtok(NULL, " ")) != NULL)
+				;
+	}
+
 	// parse options
 	for(i=1;i<argc;i++)
 	{
 		if(argv[i][0] != '-') break;
-		if(++i>=argc)
-			exit_with_help();
+		++i;
+		if(i>=argc && argv[i-1][1] != 'q') // since option -q has no parameter
+			return 1;
 		switch(argv[i-1][1])
 		{
 			case 's':
 				param.solver_type = atoi(argv[i]);
 				break;
-
 			case 'c':
 				param.C = atof(argv[i]);
 				break;
-
 			case 'e':
 				param.eps = atof(argv[i]);
 				break;
-
 			case 'B':
 				bias = atof(argv[i]);
 				break;
-
+			case 'v':
+				cross_validation_flag = 1;
+				nr_fold = atoi(argv[i]);
+				if(nr_fold < 2)
+				{
+					mexPrintf("n-fold cross validation: n must >= 2\n");
+					return 1;
+				}
+				break;
 			case 'w':
 				++param.nr_weight;
 				param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight);
@@ -208,49 +185,18 @@
 				param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
 				param.weight[param.nr_weight-1] = atof(argv[i]);
 				break;
-
-			case 'v':
-				flag_cross_validation = 1;
-				nr_fold = atoi(argv[i]);
-				if(nr_fold < 2)
-				{
-					fprintf(stderr,"n-fold cross validation: n must >= 2\n");
-					exit_with_help();
-				}
-				break;
-
 			case 'q':
 				print_func = &print_null;
 				i--;
 				break;
-
 			default:
-				fprintf(stderr,"unknown option: -%c\n", argv[i-1][1]);
-				exit_with_help();
-				break;
+				mexPrintf("unknown option\n");
+				return 1;
 		}
 	}
 
 	set_print_string_function(print_func);
 
-	// determine filenames
-	if(i>=argc)
-		exit_with_help();
-
-	strcpy(input_file_name, argv[i]);
-
-	if(i<argc-1)
-		strcpy(model_file_name,argv[i+1]);
-	else
-	{
-		char *p = strrchr(argv[i],'/');
-		if(p==NULL)
-			p = argv[i];
-		else
-			++p;
-		sprintf(model_file_name,"%s.model",p);
-	}
-
 	if(param.eps == INF)
 	{
 		if(param.solver_type == L2R_LR || param.solver_type == L2R_L2LOSS_SVC)
@@ -260,106 +206,178 @@
 		else if(param.solver_type == L1R_L2LOSS_SVC || param.solver_type == L1R_LR)
 			param.eps = 0.01;
 	}
+	return 0;
+}
+
+static void fake_answer(mxArray *plhs[])
+{
+	plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
 }
 
-// read in a problem (in libsvm format)
-void read_problem(const char *filename)
+int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat)
 {
-	int max_index, inst_max_index, i;
-	long int elements, j;
-	FILE *fp = fopen(filename,"r");
-	char *endptr;
-	char *idx, *val, *label;
+	int i, j, k, low, high;
+	mwIndex *ir, *jc;
+	int elements, max_index, num_samples, label_vector_row_num;
+	double *samples, *labels;
+	mxArray *instance_mat_col; // instance sparse matrix in column format
+
+	prob.x = NULL;
+	prob.y = NULL;
+	x_space = NULL;
 
-	if(fp == NULL)
+	if(col_format_flag)
+		instance_mat_col = (mxArray *)instance_mat;
+	else
 	{
-		fprintf(stderr,"can't open input file %s\n",filename);
-		exit(1);
+		// transpose instance matrix
+		mxArray *prhs[1], *plhs[1];
+		prhs[0] = mxDuplicateArray(instance_mat);
+		if(mexCallMATLAB(1, plhs, 1, prhs, "transpose"))
+		{
+			mexPrintf("Error: cannot transpose training instance matrix\n");
+			return -1;
+		}
+		instance_mat_col = plhs[0];
+		mxDestroyArray(prhs[0]);
 	}
 
-	prob.l = 0;
-	elements = 0;
-	max_line_len = 1024;
-	line = Malloc(char,max_line_len);
-	while(readline(fp)!=NULL)
-	{
-		char *p = strtok(line," \t"); // label
+	// the number of instance
+	prob.l = (int) mxGetN(instance_mat_col);
+	label_vector_row_num = (int) mxGetM(label_vec);
 
-		// features
-		while(1)
-		{
-			p = strtok(NULL," \t");
-			if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
-				break;
-			elements++;
-		}
-		elements++; // for bias term
-		prob.l++;
+	if(label_vector_row_num!=prob.l)
+	{
+		mexPrintf("Length of label vector does not match # of instances.\n");
+		return -1;
 	}
-	rewind(fp);
+	
+	// each column is one instance
+	labels = mxGetPr(label_vec);
+	samples = mxGetPr(instance_mat_col);
+	ir = mxGetIr(instance_mat_col);
+	jc = mxGetJc(instance_mat_col);
+
+	num_samples = (int) mxGetNzmax(instance_mat_col);
+
+	elements = num_samples + prob.l*2;
+	max_index = (int) mxGetM(instance_mat_col);
+
+	prob.y = Malloc(int, prob.l);
+	prob.x = Malloc(struct feature_node*, prob.l);
+	x_space = Malloc(struct feature_node, elements);
 
 	prob.bias=bias;
 
-	prob.y = Malloc(int,prob.l);
-	prob.x = Malloc(struct feature_node *,prob.l);
-	x_space = Malloc(struct feature_node,elements+prob.l);
-
-	max_index = 0;
-	j=0;
+	j = 0;
 	for(i=0;i<prob.l;i++)
 	{
-		inst_max_index = 0; // strtol gives 0 if wrong format
-		readline(fp);
 		prob.x[i] = &x_space[j];
-		label = strtok(line," \t\n");
-		if(label == NULL) // empty line
-			exit_input_error(i+1);
-
-		prob.y[i] = (int) strtol(label,&endptr,10);
-		if(endptr == label || *endptr != '\0')
-			exit_input_error(i+1);
-
-		while(1)
+		prob.y[i] = (int) labels[i];
+		low = (int) jc[i], high = (int) jc[i+1];
+		for(k=low;k<high;k++)
 		{
-			idx = strtok(NULL,":");
-			val = strtok(NULL," \t");
-
-			if(val == NULL)
-				break;
-
-			errno = 0;
-			x_space[j].index = (int) strtol(idx,&endptr,10);
-			if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
-				exit_input_error(i+1);
-			else
-				inst_max_index = x_space[j].index;
-
-			errno = 0;
-			x_space[j].value = strtod(val,&endptr);
-			if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
-				exit_input_error(i+1);
-
-			++j;
+			x_space[j].index = (int) ir[k]+1;
+			x_space[j].value = samples[k];
+			j++;
+	 	}
+		if(prob.bias>=0)
+		{
+			x_space[j].index = max_index+1;
+			x_space[j].value = prob.bias;
+			j++;
 		}
-
-		if(inst_max_index > max_index)
-			max_index = inst_max_index;
-
-		if(prob.bias >= 0)
-			x_space[j++].value = prob.bias;
-
 		x_space[j++].index = -1;
 	}
 
-	if(prob.bias >= 0)
+	if(prob.bias>=0)
+		prob.n = max_index+1;
+	else
+		prob.n = max_index;
+
+	return 0;
+}
+
+// Interface function of matlab
+// now assume prhs[0]: label prhs[1]: features
+void mexFunction( int nlhs, mxArray *plhs[],
+		int nrhs, const mxArray *prhs[] )
+{
+	const char *error_msg;
+	// fix random seed to have same results for each run
+	// (for cross validation)
+	srand(1);
+
+	// Transform the input Matrix to libsvm format
+	if(nrhs > 1 && nrhs < 5)
 	{
-		prob.n=max_index+1;
-		for(i=1;i<prob.l;i++)
-			(prob.x[i]-2)->index = prob.n; 
-		x_space[j-2].index = prob.n;
+		int err=0;
+
+		if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
+			mexPrintf("Error: label vector and instance matrix must be double\n");
+			fake_answer(plhs);
+			return;
+		}
+
+		if(parse_command_line(nrhs, prhs, NULL))
+		{
+			exit_with_help();
+			destroy_param(&param);
+			fake_answer(plhs);
+			return;
+		}
+
+		if(mxIsSparse(prhs[1]))
+			err = read_problem_sparse(prhs[0], prhs[1]);
+		else
+		{
+			mexPrintf("Training_instance_matrix must be sparse\n");
+			destroy_param(&param);
+			fake_answer(plhs);
+			return;
+		}
+
+		// train's original code
+		error_msg = check_parameter(&prob, &param);
+
+		if(err || error_msg)
+		{
+			if (error_msg != NULL)
+				mexPrintf("Error: %s\n", error_msg);
+			destroy_param(&param);
+			free(prob.y);
+			free(prob.x);
+			free(x_space);
+			fake_answer(plhs);
+			return;
+		}
+
+		if(cross_validation_flag)
+		{
+			double *ptr;
+			plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL);
+			ptr = mxGetPr(plhs[0]);
+			ptr[0] = do_cross_validation();
+		}
+		else
+		{
+			const char *error_msg;
+
+			model_ = train(&prob, &param);
+			error_msg = model_to_matlab_structure(plhs, model_);
+			if(error_msg)
+				mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg);
+			free_and_destroy_model(&model_);
+		}
+		destroy_param(&param);
+		free(prob.y);
+		free(prob.x);
+		free(x_space);
 	}
 	else
-		prob.n=max_index;
-
-	fclose(fp);
+	{
+		exit_with_help();
+		fake_answer(plhs);
+		return;
+	}
 }