Mercurial > forge
view extra/NaN/src/train.c @ 12685:f26b1170ea90 octave-forge
resulting values should be really converted to output data type
author | schloegl |
---|---|
date | Sat, 12 Sep 2015 07:15:01 +0000 |
parents | 0605cb0434ff |
children |
line wrap: on
line source
/* $Id$ Copyright (c) 2007-2009 The LIBLINEAR Project. Copyright (c) 2010 Alois Schloegl <alois.schloegl@gmail.com> This function is part of the NaN-toolbox http://pub.ist.ac.at/~schloegl/matlab/NaN/ This code was extracted from liblinear-1.51 in Jan 2010 and modified for the use with Octave This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, see <http://www.gnu.org/licenses/>. */ #include <stdio.h> #include <math.h> #include <stdlib.h> #include <string.h> #include <ctype.h> #include "linear.h" #include "mex.h" #include "linear_model_matlab.h" #ifdef tmwtypes_h #if (MX_API_VER<=0x07020000) typedef int mwSize; typedef int mwIndex; #endif #endif #define CMD_LEN 2048 #define Malloc(type,n) (type *)malloc((n)*sizeof(type)) #define INF HUGE_VAL void print_null(const char *s){} void (*liblinear_default_print_string) (const char *); void exit_with_help() { mexPrintf( "Usage: model = train(weight_vector, training_label_vector, training_instance_matrix, 'liblinear_options', 'col');\n" "liblinear_options:\n" "-s type : set type of solver (default 1)\n" " 0 -- L2-regularized logistic regression\n" " 1 -- L2-regularized L2-loss support vector classification (dual)\n" " 2 -- L2-regularized L2-loss support vector classification (primal)\n" " 3 -- L2-regularized L1-loss support vector classification (dual)\n" " 4 -- multi-class support vector classification by Crammer and Singer\n" " 5 -- L1-regularized L2-loss support vector classification\n" " 6 -- L1-regularized logistic regression\n" "-c cost : set the parameter C (default 1)\n" "-e epsilon : set tolerance of termination criterion\n" " -s 0 and 2\n" " |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n" " where f is the primal function, (default 0.01)\n" " -s 1, 3, and 4\n" " Dual maximal violation <= eps; similar to libsvm (default 0.1)\n" " -s 5 and 6\n" " |f'(w)|_inf <= eps*min(pos,neg)/l*|f'(w0)|_inf,\n" " where f is the primal function (default 0.01)\n" "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)\n" "-wi weight: weights adjust the parameter C of different classes (see README for details)\n" "-v n: n-fold cross validation mode\n" "-q : quiet mode (no outputs)\n" "col:\n" " if 'col' is setted, training_instance_matrix is parsed in column format, otherwise is in row format\n" ); } // liblinear arguments struct parameter param; // set by parse_command_line struct problem prob; // set by read_problem struct model *model_; struct feature_node *x_space; int cross_validation_flag; int col_format_flag; int nr_fold; double bias; double do_cross_validation() { int i; int total_correct = 0; int *target = Malloc(int,prob.l); double retval = 0.0; cross_validation(&prob,¶m,nr_fold,target); for(i=0;i<prob.l;i++) if(target[i] == prob.y[i]) ++total_correct; mexPrintf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l); retval = 100.0*total_correct/prob.l; free(target); return retval; } // nrhs should be 4 int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name) { int i, argc = 1; char cmd[CMD_LEN]; char *argv[CMD_LEN/2]; // default values param.solver_type = L2R_L2LOSS_SVC_DUAL; param.C = 1; param.eps = INF; // see setting below param.nr_weight = 0; param.weight_label = NULL; param.weight = NULL; cross_validation_flag = 0; col_format_flag = 0; bias = -1; // train loaded only once under matlab if(liblinear_default_print_string == NULL) liblinear_default_print_string = liblinear_print_string; else liblinear_print_string = liblinear_default_print_string; if(nrhs <= 2) return 1; if(nrhs == 5) { mxGetString(prhs[4], cmd, mxGetN(prhs[4])+1); if(strcmp(cmd, "col") == 0) col_format_flag = 1; } // put options in argv[] if(nrhs > 3) { mxGetString(prhs[3], cmd, mxGetN(prhs[3]) + 1); if((argv[argc] = strtok(cmd, " ")) != NULL) while((argv[++argc] = strtok(NULL, " ")) != NULL) ; } // parse options for(i=1;i<argc;i++) { if(argv[i][0] != '-') break; ++i; if(i>=argc && argv[i-1][1] != 'q') // since option -q has no parameter return 1; switch(argv[i-1][1]) { case 's': param.solver_type = atoi(argv[i]); break; case 'c': param.C = atof(argv[i]); break; case 'e': param.eps = atof(argv[i]); break; case 'B': bias = atof(argv[i]); break; case 'v': cross_validation_flag = 1; nr_fold = atoi(argv[i]); if(nr_fold < 2) { mexPrintf("n-fold cross validation: n must >= 2\n"); return 1; } break; case 'w': ++param.nr_weight; param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight); param.weight = (double *) realloc(param.weight,sizeof(double)*param.nr_weight); param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]); param.weight[param.nr_weight-1] = atof(argv[i]); break; case 'q': liblinear_print_string = &print_null; i--; break; default: mexPrintf("unknown option\n"); return 1; } } if(param.eps == INF) { if(param.solver_type == L2R_LR || param.solver_type == L2R_L2LOSS_SVC) param.eps = 0.01; else if(param.solver_type == L2R_L2LOSS_SVC_DUAL || param.solver_type == L2R_L1LOSS_SVC_DUAL || param.solver_type == MCSVM_CS) param.eps = 0.1; else if(param.solver_type == L1R_L2LOSS_SVC || param.solver_type == L1R_LR) param.eps = 0.01; } return 0; } static void fake_answer(mxArray *plhs[]) { plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL); } int read_problem_sparse(const mxArray *weight_vec, const mxArray *label_vec, const mxArray *instance_mat) { int i, j, k, low, high; mwIndex *ir, *jc; int elements, max_index, num_samples, label_vector_row_num, weight_vector_row_num; double *samples, *labels, *weights; mxArray *instance_mat_col; // instance sparse matrix in column format prob.x = NULL; prob.y = NULL; prob.W = NULL; x_space = NULL; if(col_format_flag) instance_mat_col = (mxArray *)instance_mat; else { // transpose instance matrix mxArray *prhs[1], *plhs[1]; prhs[0] = mxDuplicateArray(instance_mat); if(mexCallMATLAB(1, plhs, 1, prhs, "transpose")) { mexPrintf("Error: cannot transpose training instance matrix\n"); return -1; } instance_mat_col = plhs[0]; mxDestroyArray(prhs[0]); } // the number of instance prob.l = (int) mxGetN(instance_mat_col); weight_vector_row_num = (int) mxGetM(weight_vec); label_vector_row_num = (int) mxGetM(label_vec); if(weight_vector_row_num == 0) ;//mexPrintf("Warning: treat each instance with weight 1.0\n"); else if(weight_vector_row_num!=prob.l) { mexPrintf("Length of weight vector does not match # of instances.\n"); return -1; } if(label_vector_row_num!=prob.l) { mexPrintf("Length of label vector does not match # of instances.\n"); return -1; } // each column is one instance weights = mxGetPr(weight_vec); labels = mxGetPr(label_vec); samples = mxGetPr(instance_mat_col); ir = mxGetIr(instance_mat_col); jc = mxGetJc(instance_mat_col); num_samples = (int) mxGetNzmax(instance_mat_col); elements = num_samples + prob.l*2; max_index = (int) mxGetM(instance_mat_col); prob.y = Malloc(int, prob.l); prob.W = Malloc(double,prob.l); prob.x = Malloc(struct feature_node*, prob.l); x_space = Malloc(struct feature_node, elements); prob.bias=bias; j = 0; for(i=0;i<prob.l;i++) { prob.x[i] = &x_space[j]; prob.y[i] = (int) labels[i]; prob.W[i] = 1; if(weight_vector_row_num > 0) prob.W[i] *= (double) weights[i]; low = (int) jc[i], high = (int) jc[i+1]; for(k=low;k<high;k++) { x_space[j].index = (int) ir[k]+1; x_space[j].value = samples[k]; j++; } if(prob.bias>=0) { x_space[j].index = max_index+1; x_space[j].value = prob.bias; j++; } x_space[j++].index = -1; } if(prob.bias>=0) prob.n = max_index+1; else prob.n = max_index; return 0; } // Interface function of matlab // now assume prhs[0]: label prhs[1]: features void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] ) { const char *error_msg; // fix random seed to have same results for each run // (for cross validation) srand(1); // Transform the input Matrix to libsvm format if(nrhs > 2 && nrhs < 6) { int err=0; if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1]) || !mxIsDouble(prhs[2])) { mexPrintf("Error: weight vector, label vector and instance matrix must be double\n"); fake_answer(plhs); return; } if(parse_command_line(nrhs, prhs, NULL)) { exit_with_help(); destroy_param(¶m); fake_answer(plhs); return; } if(mxIsSparse(prhs[2])) err = read_problem_sparse(prhs[0], prhs[1], prhs[2]); else { mexPrintf("Training_instance_matrix must be sparse\n"); destroy_param(¶m); fake_answer(plhs); return; } // train's original code error_msg = check_parameter(¶m); if(err || error_msg) { if (error_msg != NULL) mexPrintf("Error: %s\n", error_msg); destroy_param(¶m); free(prob.y); free(prob.x); free(x_space); fake_answer(plhs); return; } if(cross_validation_flag) { double *ptr; plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL); ptr = mxGetPr(plhs[0]); ptr[0] = do_cross_validation(); } else { const char *error_msg; model_ = train(&prob, ¶m); error_msg = model_to_matlab_structure(plhs, model_); if(error_msg) mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg); destroy_model(model_); } destroy_param(¶m); free(prob.y); free(prob.x); free(prob.W); free(x_space); } else { exit_with_help(); fake_answer(plhs); return; } }