Mercurial > forge
changeset 12592:d03ad555e14e octave-forge
[nan] fix liblinear - use train/predict from liblinear/matlab
author | schloegl |
---|---|
date | Sun, 12 Apr 2015 19:00:36 +0000 |
parents | 3aeba3530595 |
children | 0605cb0434ff |
files | extra/NaN/src/predict.c extra/NaN/src/train.c |
diffstat | 2 files changed, 472 insertions(+), 369 deletions(-) [+] |
line wrap: on
line diff
--- a/extra/NaN/src/predict.c Sun Apr 12 17:52:15 2015 +0000 +++ b/extra/NaN/src/predict.c Sun Apr 12 19:00:36 2015 +0000 @@ -24,220 +24,305 @@ */ #include <stdio.h> -#include <ctype.h> #include <stdlib.h> #include <string.h> -#include <errno.h> #include "linear.h" -struct feature_node *x; -int max_nr_attr = 64; +#include "mex.h" +#include "linear_model_matlab.h" + +#ifdef tmwtypes_h + #if (MX_API_VER<=0x07020000) + typedef int mwSize; + #endif +#endif + + +#define CMD_LEN 2048 + +#define Malloc(type,n) (type *)malloc((n)*sizeof(type)) + +int col_format_flag; -struct model* model_; -int flag_predict_probability=0; +void read_sparse_instance(const mxArray *prhs, int index, struct feature_node *x, int feature_number, double bias) +{ + int i, j, low, high; + mwIndex *ir, *jc; + double *samples; + + ir = mxGetIr(prhs); + jc = mxGetJc(prhs); + samples = mxGetPr(prhs); -void exit_input_error(int line_num) -{ - fprintf(stderr,"Wrong input format at line %d\n", line_num); - exit(1); + // each column is one instance + j = 0; + low = (int) jc[index], high = (int) jc[index+1]; + for(i=low; i<high && (int) (ir[i])<feature_number; i++) + { + x[j].index = (int) ir[i]+1; + x[j].value = samples[i]; + j++; + } + if(bias>=0) + { + x[j].index = feature_number+1; + x[j].value = bias; + j++; + } + x[j].index = -1; } -static char *line = NULL; -static int max_line_len; - -static char* readline(FILE *input) +static void fake_answer(mxArray *plhs[]) { - int len; - - if(fgets(line,max_line_len,input) == NULL) - return NULL; - - while(strrchr(line,'\n') == NULL) - { - max_line_len *= 2; - line = (char *) realloc(line,max_line_len); - len = (int) strlen(line); - if(fgets(line+len,max_line_len-len,input) == NULL) - break; - } - return line; + plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL); + plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL); + plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL); } -void do_predict(FILE *input, FILE *output, struct model* model_) +void do_predict(mxArray *plhs[], const mxArray *prhs[], struct model *model_, const int predict_probability_flag) { + int label_vector_row_num, label_vector_col_num; + int feature_number, testing_instance_number; + int instance_index; + double *ptr_instance, *ptr_label, *ptr_predict_label; + double *ptr_prob_estimates, *ptr_dec_values, *ptr; + struct feature_node *x; + mxArray *pplhs[1]; // instance sparse matrix in row format + int correct = 0; int total = 0; int nr_class=get_nr_class(model_); + int nr_w; double *prob_estimates=NULL; - int j, n; - int nr_feature=get_nr_feature(model_); - if(model_->bias>=0) - n=nr_feature+1; + + if(nr_class==2 && model_->param.solver_type!=MCSVM_CS) + nr_w=1; else - n=nr_feature; - - if(flag_predict_probability) - { - int *labels; + nr_w=nr_class; - if(!check_probability_model(model_)) - { - fprintf(stderr, "probability output is only supported for logistic regression\n"); - exit(1); - } - - labels=(int *) malloc(nr_class*sizeof(int)); - get_labels(model_,labels); - prob_estimates = (double *) malloc(nr_class*sizeof(double)); - fprintf(output,"labels"); - for(j=0;j<nr_class;j++) - fprintf(output," %d",labels[j]); - fprintf(output,"\n"); - free(labels); + // prhs[1] = testing instance matrix + feature_number = get_nr_feature(model_); + testing_instance_number = (int) mxGetM(prhs[1]); + if(col_format_flag) + { + feature_number = (int) mxGetM(prhs[1]); + testing_instance_number = (int) mxGetN(prhs[1]); } - max_line_len = 1024; - line = (char *)malloc(max_line_len*sizeof(char)); - while(readline(input) != NULL) - { - int i = 0; - int target_label, predict_label; - char *idx, *val, *label, *endptr; - int inst_max_index = 0; // strtol gives 0 if wrong format - - label = strtok(line," \t\n"); - if(label == NULL) // empty line - exit_input_error(total+1); - - target_label = (int) strtol(label,&endptr,10); - if(endptr == label || *endptr != '\0') - exit_input_error(total+1); - - while(1) - { - if(i>=max_nr_attr-2) // need one more for index = -1 - { - max_nr_attr *= 2; - x = (struct feature_node *) realloc(x,max_nr_attr*sizeof(struct feature_node)); - } - - idx = strtok(NULL,":"); - val = strtok(NULL," \t"); + label_vector_row_num = (int) mxGetM(prhs[0]); + label_vector_col_num = (int) mxGetN(prhs[0]); - if(val == NULL) - break; - errno = 0; - x[i].index = (int) strtol(idx,&endptr,10); - if(endptr == idx || errno != 0 || *endptr != '\0' || x[i].index <= inst_max_index) - exit_input_error(total+1); - else - inst_max_index = x[i].index; - - errno = 0; - x[i].value = strtod(val,&endptr); - if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr))) - exit_input_error(total+1); + if(label_vector_row_num!=testing_instance_number) + { + mexPrintf("Length of label vector does not match # of instances.\n"); + fake_answer(plhs); + return; + } + if(label_vector_col_num!=1) + { + mexPrintf("label (1st argument) should be a vector (# of column is 1).\n"); + fake_answer(plhs); + return; + } - // feature indices larger than those in training are not used - if(x[i].index <= nr_feature) - ++i; - } + ptr_instance = mxGetPr(prhs[1]); + ptr_label = mxGetPr(prhs[0]); - if(model_->bias>=0) + // transpose instance matrix + if(mxIsSparse(prhs[1])) + { + if(col_format_flag) { - x[i].index = n; - x[i].value = model_->bias; - i++; - } - x[i].index = -1; - - if(flag_predict_probability) - { - int j; - predict_label = predict_probability(model_,x,prob_estimates); - fprintf(output,"%d",predict_label); - for(j=0;j<model_->nr_class;j++) - fprintf(output," %g",prob_estimates[j]); - fprintf(output,"\n"); + pplhs[0] = (mxArray *)prhs[1]; } else { - predict_label = predict(model_,x); - fprintf(output,"%d\n",predict_label); + mxArray *pprhs[1]; + pprhs[0] = mxDuplicateArray(prhs[1]); + if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose")) + { + mexPrintf("Error: cannot transpose testing instance matrix\n"); + fake_answer(plhs); + return; + } + } + } + else + mexPrintf("Testing_instance_matrix must be sparse\n"); + + + prob_estimates = Malloc(double, nr_class); + + plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL); + if(predict_probability_flag) + plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL); + else + plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_w, mxREAL); + + ptr_predict_label = mxGetPr(plhs[0]); + ptr_prob_estimates = mxGetPr(plhs[2]); + ptr_dec_values = mxGetPr(plhs[2]); + x = Malloc(struct feature_node, feature_number+2); + for(instance_index=0;instance_index<testing_instance_number;instance_index++) + { + int i; + double target,v; + + target = ptr_label[instance_index]; + + // prhs[1] and prhs[1]^T are sparse + read_sparse_instance(pplhs[0], instance_index, x, feature_number, model_->bias); + + if(predict_probability_flag) + { + v = predict_probability(model_, x, prob_estimates); + ptr_predict_label[instance_index] = v; + for(i=0;i<nr_class;i++) + ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i]; + } + else + { + double *dec_values = Malloc(double, nr_class); + v = predict(model_, x); + ptr_predict_label[instance_index] = v; + + predict_values(model_, x, dec_values); + for(i=0;i<nr_w;i++) + ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i]; + free(dec_values); } - if(predict_label == target_label) + if(v == target) ++correct; ++total; } - printf("Accuracy = %g%% (%d/%d)\n",(double) correct/total*100,correct,total); - if(flag_predict_probability) + mexPrintf("Accuracy = %g%% (%d/%d)\n", (double) correct/total*100,correct,total); + + // return accuracy, mean squared error, squared correlation coefficient + plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL); + ptr = mxGetPr(plhs[1]); + ptr[0] = (double) correct/total*100; + + free(x); + if(prob_estimates != NULL) free(prob_estimates); } void exit_with_help() { - printf( - "Usage: predict [options] test_file model_file output_file\n" - "options:\n" - "-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0)\n" - ); - exit(1); + mexPrintf( + "Usage: [predicted_label, accuracy, decision_values/prob_estimates] = predict(testing_label_vector, testing_instance_matrix, model, 'liblinear_options','col')\n" + "liblinear_options:\n" + "-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0)\n" + "col:\n" + " if 'col' is setted testing_instance_matrix is parsed in column format, otherwise is in row format" + ); } -int main(int argc, char **argv) +void mexFunction( int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] ) { - FILE *input, *output; - int i; + int prob_estimate_flag = 0; + struct model *model_; + char cmd[CMD_LEN]; + col_format_flag = 0; - // parse options - for(i=1;i<argc;i++) + if(nrhs > 5 || nrhs < 3) { - if(argv[i][0] != '-') break; - ++i; - switch(argv[i-1][1]) - { - case 'b': - flag_predict_probability = atoi(argv[i]); - break; - - default: - fprintf(stderr,"unknown option: -%c\n", argv[i-1][1]); - exit_with_help(); - break; + exit_with_help(); + fake_answer(plhs); + return; + } + if(nrhs == 5) + { + mxGetString(prhs[4], cmd, mxGetN(prhs[4])+1); + if(strcmp(cmd, "col") == 0) + { + col_format_flag = 1; } } - if(i>=argc) - exit_with_help(); - input = fopen(argv[i],"r"); - if(input == NULL) - { - fprintf(stderr,"can't open input file %s\n",argv[i]); - exit(1); - } - - output = fopen(argv[i+2],"w"); - if(output == NULL) - { - fprintf(stderr,"can't open output file %s\n",argv[i+2]); - exit(1); + if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) { + mexPrintf("Error: label vector and instance matrix must be double\n"); + fake_answer(plhs); + return; } - if((model_=load_model(argv[i+1]))==0) + if(mxIsStruct(prhs[2])) { - fprintf(stderr,"can't open model file %s\n",argv[i+1]); - exit(1); + const char *error_msg; + + // parse options + if(nrhs>=4) + { + int i, argc = 1; + char *argv[CMD_LEN/2]; + + // put options in argv[] + mxGetString(prhs[3], cmd, mxGetN(prhs[3]) + 1); + if((argv[argc] = strtok(cmd, " ")) != NULL) + while((argv[++argc] = strtok(NULL, " ")) != NULL) + ; + + for(i=1;i<argc;i++) + { + if(argv[i][0] != '-') break; + if(++i>=argc) + { + exit_with_help(); + fake_answer(plhs); + return; + } + switch(argv[i-1][1]) + { + case 'b': + prob_estimate_flag = atoi(argv[i]); + break; + default: + mexPrintf("unknown option\n"); + exit_with_help(); + fake_answer(plhs); + return; + } + } + } + + model_ = Malloc(struct model, 1); + error_msg = matlab_matrix_to_model(model_, prhs[2]); + if(error_msg) + { + mexPrintf("Error: can't read model: %s\n", error_msg); + free_and_destroy_model(&model_); + fake_answer(plhs); + return; + } + + if(prob_estimate_flag) + { + if(!check_probability_model(model_)) + { + mexPrintf("probability output is only supported for logistic regression\n"); + prob_estimate_flag=0; + } + } + + if(mxIsSparse(prhs[1])) + do_predict(plhs, prhs, model_, prob_estimate_flag); + else + { + mexPrintf("Testing_instance_matrix must be sparse\n"); + fake_answer(plhs); + } + + // destroy model_ + free_and_destroy_model(&model_); + } + else + { + mexPrintf("model file should be a struct array\n"); + fake_answer(plhs); } - x = (struct feature_node *) malloc(max_nr_attr*sizeof(struct feature_node)); - do_predict(input, output, model_); - free_and_destroy_model(&model_); - free(line); - free(x); - fclose(input); - fclose(output); - return 0; + return; } -
--- a/extra/NaN/src/train.c Sun Apr 12 17:52:15 2015 +0000 +++ b/extra/NaN/src/train.c Sun Apr 12 19:00:36 2015 +0000 @@ -28,18 +28,29 @@ #include <stdlib.h> #include <string.h> #include <ctype.h> -#include <errno.h> #include "linear.h" + +#include "mex.h" +#include "linear_model_matlab.h" + +#ifdef tmwtypes_h + #if (MX_API_VER<=0x07020000) + typedef int mwIndex; + #endif +#endif + +#define CMD_LEN 2048 #define Malloc(type,n) (type *)malloc((n)*sizeof(type)) #define INF HUGE_VAL void print_null(const char *s) {} +void print_string_matlab(const char *s) {mexPrintf(s);} void exit_with_help() { - printf( - "Usage: train [options] training_set_file [model_file]\n" - "options:\n" + mexPrintf( + "Usage: model = train(training_label_vector, training_instance_matrix, 'liblinear_options', 'col');\n" + "liblinear_options:\n" "-s type : set type of solver (default 1)\n" " 0 -- L2-regularized logistic regression (primal)\n" " 1 -- L2-regularized L2-loss support vector classification (dual)\n" @@ -64,108 +75,47 @@ "-wi weight: weights adjust the parameter C of different classes (see README for details)\n" "-v n: n-fold cross validation mode\n" "-q : quiet mode (no outputs)\n" + "col:\n" + " if 'col' is setted, training_instance_matrix is parsed in column format, otherwise is in row format\n" ); - exit(1); -} - -void exit_input_error(int line_num) -{ - fprintf(stderr,"Wrong input format at line %d\n", line_num); - exit(1); } -static char *line = NULL; -static int max_line_len; - -static char* readline(FILE *input) -{ - int len; - - if(fgets(line,max_line_len,input) == NULL) - return NULL; - - while(strrchr(line,'\n') == NULL) - { - max_line_len *= 2; - line = (char *) realloc(line,max_line_len); - len = (int) strlen(line); - if(fgets(line+len,max_line_len-len,input) == NULL) - break; - } - return line; -} - -void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name); -void read_problem(const char *filename); -void do_cross_validation(); - +// liblinear arguments +struct parameter param; // set by parse_command_line +struct problem prob; // set by read_problem +struct model *model_; struct feature_node *x_space; -struct parameter param; -struct problem prob; -struct model* model_; -int flag_cross_validation; +int cross_validation_flag; +int col_format_flag; int nr_fold; double bias; -int main(int argc, char **argv) -{ - char input_file_name[1024]; - char model_file_name[1024]; - const char *error_msg; - - parse_command_line(argc, argv, input_file_name, model_file_name); - read_problem(input_file_name); - error_msg = check_parameter(&prob,¶m); - - if(error_msg) - { - fprintf(stderr,"Error: %s\n",error_msg); - exit(1); - } - - if(flag_cross_validation) - { - do_cross_validation(); - } - else - { - model_=train(&prob, ¶m); - if(save_model(model_file_name, model_)) - { - fprintf(stderr,"can't save model to file %s\n",model_file_name); - exit(1); - } - free_and_destroy_model(&model_); - } - destroy_param(¶m); - free(prob.y); - free(prob.x); - free(x_space); - free(line); - - return 0; -} - -void do_cross_validation() +double do_cross_validation() { int i; int total_correct = 0; - int *target = Malloc(int, prob.l); + int *target = Malloc(int,prob.l); + double retval = 0.0; cross_validation(&prob,¶m,nr_fold,target); for(i=0;i<prob.l;i++) if(target[i] == prob.y[i]) ++total_correct; - printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l); + mexPrintf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l); + retval = 100.0*total_correct/prob.l; free(target); + return retval; } -void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name) +// nrhs should be 3 +int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name) { - int i; - void (*print_func)(const char*) = NULL; // default printing to stdout + int i, argc = 1; + char cmd[CMD_LEN]; + char *argv[CMD_LEN/2]; + void (*print_func)(const char *) = print_string_matlab; // default printing to matlab display // default values param.solver_type = L2R_L2LOSS_SVC_DUAL; @@ -174,33 +124,60 @@ param.nr_weight = 0; param.weight_label = NULL; param.weight = NULL; - flag_cross_validation = 0; + cross_validation_flag = 0; + col_format_flag = 0; bias = -1; + + if(nrhs <= 1) + return 1; + + if(nrhs == 4) + { + mxGetString(prhs[3], cmd, mxGetN(prhs[3])+1); + if(strcmp(cmd, "col") == 0) + col_format_flag = 1; + } + + // put options in argv[] + if(nrhs > 2) + { + mxGetString(prhs[2], cmd, mxGetN(prhs[2]) + 1); + if((argv[argc] = strtok(cmd, " ")) != NULL) + while((argv[++argc] = strtok(NULL, " ")) != NULL) + ; + } + // parse options for(i=1;i<argc;i++) { if(argv[i][0] != '-') break; - if(++i>=argc) - exit_with_help(); + ++i; + if(i>=argc && argv[i-1][1] != 'q') // since option -q has no parameter + return 1; switch(argv[i-1][1]) { case 's': param.solver_type = atoi(argv[i]); break; - case 'c': param.C = atof(argv[i]); break; - case 'e': param.eps = atof(argv[i]); break; - case 'B': bias = atof(argv[i]); break; - + case 'v': + cross_validation_flag = 1; + nr_fold = atoi(argv[i]); + if(nr_fold < 2) + { + mexPrintf("n-fold cross validation: n must >= 2\n"); + return 1; + } + break; case 'w': ++param.nr_weight; param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight); @@ -208,49 +185,18 @@ param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]); param.weight[param.nr_weight-1] = atof(argv[i]); break; - - case 'v': - flag_cross_validation = 1; - nr_fold = atoi(argv[i]); - if(nr_fold < 2) - { - fprintf(stderr,"n-fold cross validation: n must >= 2\n"); - exit_with_help(); - } - break; - case 'q': print_func = &print_null; i--; break; - default: - fprintf(stderr,"unknown option: -%c\n", argv[i-1][1]); - exit_with_help(); - break; + mexPrintf("unknown option\n"); + return 1; } } set_print_string_function(print_func); - // determine filenames - if(i>=argc) - exit_with_help(); - - strcpy(input_file_name, argv[i]); - - if(i<argc-1) - strcpy(model_file_name,argv[i+1]); - else - { - char *p = strrchr(argv[i],'/'); - if(p==NULL) - p = argv[i]; - else - ++p; - sprintf(model_file_name,"%s.model",p); - } - if(param.eps == INF) { if(param.solver_type == L2R_LR || param.solver_type == L2R_L2LOSS_SVC) @@ -260,106 +206,178 @@ else if(param.solver_type == L1R_L2LOSS_SVC || param.solver_type == L1R_LR) param.eps = 0.01; } + return 0; +} + +static void fake_answer(mxArray *plhs[]) +{ + plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL); } -// read in a problem (in libsvm format) -void read_problem(const char *filename) +int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat) { - int max_index, inst_max_index, i; - long int elements, j; - FILE *fp = fopen(filename,"r"); - char *endptr; - char *idx, *val, *label; + int i, j, k, low, high; + mwIndex *ir, *jc; + int elements, max_index, num_samples, label_vector_row_num; + double *samples, *labels; + mxArray *instance_mat_col; // instance sparse matrix in column format + + prob.x = NULL; + prob.y = NULL; + x_space = NULL; - if(fp == NULL) + if(col_format_flag) + instance_mat_col = (mxArray *)instance_mat; + else { - fprintf(stderr,"can't open input file %s\n",filename); - exit(1); + // transpose instance matrix + mxArray *prhs[1], *plhs[1]; + prhs[0] = mxDuplicateArray(instance_mat); + if(mexCallMATLAB(1, plhs, 1, prhs, "transpose")) + { + mexPrintf("Error: cannot transpose training instance matrix\n"); + return -1; + } + instance_mat_col = plhs[0]; + mxDestroyArray(prhs[0]); } - prob.l = 0; - elements = 0; - max_line_len = 1024; - line = Malloc(char,max_line_len); - while(readline(fp)!=NULL) - { - char *p = strtok(line," \t"); // label + // the number of instance + prob.l = (int) mxGetN(instance_mat_col); + label_vector_row_num = (int) mxGetM(label_vec); - // features - while(1) - { - p = strtok(NULL," \t"); - if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature - break; - elements++; - } - elements++; // for bias term - prob.l++; + if(label_vector_row_num!=prob.l) + { + mexPrintf("Length of label vector does not match # of instances.\n"); + return -1; } - rewind(fp); + + // each column is one instance + labels = mxGetPr(label_vec); + samples = mxGetPr(instance_mat_col); + ir = mxGetIr(instance_mat_col); + jc = mxGetJc(instance_mat_col); + + num_samples = (int) mxGetNzmax(instance_mat_col); + + elements = num_samples + prob.l*2; + max_index = (int) mxGetM(instance_mat_col); + + prob.y = Malloc(int, prob.l); + prob.x = Malloc(struct feature_node*, prob.l); + x_space = Malloc(struct feature_node, elements); prob.bias=bias; - prob.y = Malloc(int,prob.l); - prob.x = Malloc(struct feature_node *,prob.l); - x_space = Malloc(struct feature_node,elements+prob.l); - - max_index = 0; - j=0; + j = 0; for(i=0;i<prob.l;i++) { - inst_max_index = 0; // strtol gives 0 if wrong format - readline(fp); prob.x[i] = &x_space[j]; - label = strtok(line," \t\n"); - if(label == NULL) // empty line - exit_input_error(i+1); - - prob.y[i] = (int) strtol(label,&endptr,10); - if(endptr == label || *endptr != '\0') - exit_input_error(i+1); - - while(1) + prob.y[i] = (int) labels[i]; + low = (int) jc[i], high = (int) jc[i+1]; + for(k=low;k<high;k++) { - idx = strtok(NULL,":"); - val = strtok(NULL," \t"); - - if(val == NULL) - break; - - errno = 0; - x_space[j].index = (int) strtol(idx,&endptr,10); - if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index) - exit_input_error(i+1); - else - inst_max_index = x_space[j].index; - - errno = 0; - x_space[j].value = strtod(val,&endptr); - if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr))) - exit_input_error(i+1); - - ++j; + x_space[j].index = (int) ir[k]+1; + x_space[j].value = samples[k]; + j++; + } + if(prob.bias>=0) + { + x_space[j].index = max_index+1; + x_space[j].value = prob.bias; + j++; } - - if(inst_max_index > max_index) - max_index = inst_max_index; - - if(prob.bias >= 0) - x_space[j++].value = prob.bias; - x_space[j++].index = -1; } - if(prob.bias >= 0) + if(prob.bias>=0) + prob.n = max_index+1; + else + prob.n = max_index; + + return 0; +} + +// Interface function of matlab +// now assume prhs[0]: label prhs[1]: features +void mexFunction( int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] ) +{ + const char *error_msg; + // fix random seed to have same results for each run + // (for cross validation) + srand(1); + + // Transform the input Matrix to libsvm format + if(nrhs > 1 && nrhs < 5) { - prob.n=max_index+1; - for(i=1;i<prob.l;i++) - (prob.x[i]-2)->index = prob.n; - x_space[j-2].index = prob.n; + int err=0; + + if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) { + mexPrintf("Error: label vector and instance matrix must be double\n"); + fake_answer(plhs); + return; + } + + if(parse_command_line(nrhs, prhs, NULL)) + { + exit_with_help(); + destroy_param(¶m); + fake_answer(plhs); + return; + } + + if(mxIsSparse(prhs[1])) + err = read_problem_sparse(prhs[0], prhs[1]); + else + { + mexPrintf("Training_instance_matrix must be sparse\n"); + destroy_param(¶m); + fake_answer(plhs); + return; + } + + // train's original code + error_msg = check_parameter(&prob, ¶m); + + if(err || error_msg) + { + if (error_msg != NULL) + mexPrintf("Error: %s\n", error_msg); + destroy_param(¶m); + free(prob.y); + free(prob.x); + free(x_space); + fake_answer(plhs); + return; + } + + if(cross_validation_flag) + { + double *ptr; + plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL); + ptr = mxGetPr(plhs[0]); + ptr[0] = do_cross_validation(); + } + else + { + const char *error_msg; + + model_ = train(&prob, ¶m); + error_msg = model_to_matlab_structure(plhs, model_); + if(error_msg) + mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg); + free_and_destroy_model(&model_); + } + destroy_param(¶m); + free(prob.y); + free(prob.x); + free(x_space); } else - prob.n=max_index; - - fclose(fp); + { + exit_with_help(); + fake_answer(plhs); + return; + } }