comparison extra/NaN/src/svm.cpp @ 12589:06a805605e9a octave-forge

[nan] upgrade libsvm to v3.12
author schloegl
date Sun, 12 Apr 2015 14:37:46 +0000
parents 6a419bec96bb
children
comparison
equal deleted inserted replaced
12588:3f24658504ab 12589:06a805605e9a
1 /* 1 /*
2 2
3 3 Copyright (c) 2000-2012 Chih-Chung Chang and Chih-Jen Lin
4 $Id$ 4 Copyright (c) 2010,2011,2015 Alois Schloegl <alois.schloegl@ist.ac.at>
5 Copyright (c) 2000-2009 Chih-Chung Chang and Chih-Jen Lin
6 Copyright (c) 2010 Alois Schloegl <alois.schloegl@gmail.com>
7 This function is part of the NaN-toolbox 5 This function is part of the NaN-toolbox
8 http://pub.ist.ac.at/~schloegl/matlab/NaN/ 6 http://pub.ist.ac.at/~schloegl/matlab/NaN/
9 7
10 This code was extracted from libsvm-mat-2.9-1 in Jan 2010 and 8 This code was extracted from libsvm-3.12 in Apr 2015 and
11 modified for the use with Octave 9 modified for the use with Octave
12 10
13 This program is free software; you can redistribute it and/or modify 11 This program is free software; you can redistribute it and/or modify
14 it under the terms of the GNU General Public License as published by 12 it under the terms of the GNU General Public License as published by
15 the Free Software Foundation; either version 3 of the License, or 13 the Free Software Foundation; either version 3 of the License, or
30 #include <stdlib.h> 28 #include <stdlib.h>
31 #include <ctype.h> 29 #include <ctype.h>
32 #include <float.h> 30 #include <float.h>
33 #include <string.h> 31 #include <string.h>
34 #include <stdarg.h> 32 #include <stdarg.h>
33 #include <limits.h>
34 #include <locale.h>
35 #include "svm.h" 35 #include "svm.h"
36 36
37 int libsvm_version = LIBSVM_VERSION; 37 int libsvm_version = LIBSVM_VERSION;
38 typedef float Qfloat; 38 typedef float Qfloat;
39 typedef signed char schar; 39 typedef signed char schar;
67 static void print_string_stdout(const char *s) 67 static void print_string_stdout(const char *s)
68 { 68 {
69 fputs(s,stdout); 69 fputs(s,stdout);
70 fflush(stdout); 70 fflush(stdout);
71 } 71 }
72 void (*svm_print_string) (const char *) = &print_string_stdout; 72 static void (*svm_print_string) (const char *) = &print_string_stdout;
73 #if 1 73 #if 1
74 static void info(const char *fmt,...) 74 static void info(const char *fmt,...)
75 { 75 {
76 char buf[BUFSIZ]; 76 char buf[BUFSIZ];
77 va_list ap; 77 va_list ap;
218 // the member function get_Q is for getting one column from the Q Matrix 218 // the member function get_Q is for getting one column from the Q Matrix
219 // 219 //
220 class QMatrix { 220 class QMatrix {
221 public: 221 public:
222 virtual Qfloat *get_Q(int column, int len) const = 0; 222 virtual Qfloat *get_Q(int column, int len) const = 0;
223 virtual Qfloat *get_QD() const = 0; 223 virtual double *get_QD() const = 0;
224 virtual void swap_index(int i, int j) const = 0; 224 virtual void swap_index(int i, int j) const = 0;
225 virtual ~QMatrix() {} 225 virtual ~QMatrix() {}
226 }; 226 };
227 227
228 class Kernel: public QMatrix { 228 class Kernel: public QMatrix {
231 virtual ~Kernel(); 231 virtual ~Kernel();
232 232
233 static double k_function(const svm_node *x, const svm_node *y, 233 static double k_function(const svm_node *x, const svm_node *y,
234 const svm_parameter& param); 234 const svm_parameter& param);
235 virtual Qfloat *get_Q(int column, int len) const = 0; 235 virtual Qfloat *get_Q(int column, int len) const = 0;
236 virtual Qfloat *get_QD() const = 0; 236 virtual double *get_QD() const = 0;
237 virtual void swap_index(int i, int j) const // no so const... 237 virtual void swap_index(int i, int j) const // no so const...
238 { 238 {
239 swap(x[i],x[j]); 239 swap(x[i],x[j]);
240 if(x_square) swap(x_square[i],x_square[j]); 240 if(x_square) swap(x_square[i],x_square[j]);
241 } 241 }
438 double *G; // gradient of objective function 438 double *G; // gradient of objective function
439 enum { LOWER_BOUND, UPPER_BOUND, FREE }; 439 enum { LOWER_BOUND, UPPER_BOUND, FREE };
440 char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE 440 char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE
441 double *alpha; 441 double *alpha;
442 const QMatrix *Q; 442 const QMatrix *Q;
443 const Qfloat *QD; 443 const double *QD;
444 double eps; 444 double eps;
445 double Cp,Cn; 445 double Cp,Cn;
446 double *p; 446 double *p;
447 int *active_set; 447 int *active_set;
448 double *G_bar; // gradient, if we treat free variables as 0 448 double *G_bar; // gradient, if we treat free variables as 0
500 for(j=0;j<active_size;j++) 500 for(j=0;j<active_size;j++)
501 if(is_free(j)) 501 if(is_free(j))
502 nr_free++; 502 nr_free++;
503 503
504 if(2*nr_free < active_size) 504 if(2*nr_free < active_size)
505 info("\nWarning: using -h 0 may be faster\n"); 505 info("\nWARNING: using -h 0 may be faster\n");
506 506
507 if (nr_free*l > 2*active_size*(l-active_size)) 507 if (nr_free*l > 2*active_size*(l-active_size))
508 { 508 {
509 for(i=active_size;i<l;i++) 509 for(i=active_size;i<l;i++)
510 { 510 {
582 } 582 }
583 583
584 // optimization step 584 // optimization step
585 585
586 int iter = 0; 586 int iter = 0;
587 int max_iter = max(10000000, l>INT_MAX/100 ? INT_MAX : 100*l);
587 int counter = min(l,1000)+1; 588 int counter = min(l,1000)+1;
588 589
589 while(1) 590 while(iter < max_iter)
590 { 591 {
591 // show progress and do shrinking 592 // show progress and do shrinking
592 593
593 if(--counter == 0) 594 if(--counter == 0)
594 { 595 {
624 double old_alpha_i = alpha[i]; 625 double old_alpha_i = alpha[i];
625 double old_alpha_j = alpha[j]; 626 double old_alpha_j = alpha[j];
626 627
627 if(y[i]!=y[j]) 628 if(y[i]!=y[j])
628 { 629 {
629 double quad_coef = Q_i[i]+Q_j[j]+2*Q_i[j]; 630 double quad_coef = QD[i]+QD[j]+2*Q_i[j];
630 if (quad_coef <= 0) 631 if (quad_coef <= 0)
631 quad_coef = TAU; 632 quad_coef = TAU;
632 double delta = (-G[i]-G[j])/quad_coef; 633 double delta = (-G[i]-G[j])/quad_coef;
633 double diff = alpha[i] - alpha[j]; 634 double diff = alpha[i] - alpha[j];
634 alpha[i] += delta; 635 alpha[i] += delta;
667 } 668 }
668 } 669 }
669 } 670 }
670 else 671 else
671 { 672 {
672 double quad_coef = Q_i[i]+Q_j[j]-2*Q_i[j]; 673 double quad_coef = QD[i]+QD[j]-2*Q_i[j];
673 if (quad_coef <= 0) 674 if (quad_coef <= 0)
674 quad_coef = TAU; 675 quad_coef = TAU;
675 double delta = (G[i]-G[j])/quad_coef; 676 double delta = (G[i]-G[j])/quad_coef;
676 double sum = alpha[i] + alpha[j]; 677 double sum = alpha[i] + alpha[j];
677 alpha[i] -= delta; 678 alpha[i] -= delta;
749 else 750 else
750 for(k=0;k<l;k++) 751 for(k=0;k<l;k++)
751 G_bar[k] += C_j * Q_j[k]; 752 G_bar[k] += C_j * Q_j[k];
752 } 753 }
753 } 754 }
755 }
756
757 if(iter >= max_iter)
758 {
759 if(active_size < l)
760 {
761 // reconstruct the whole gradient to calculate objective value
762 reconstruct_gradient();
763 active_size = l;
764 info("*");
765 }
766 info("\nWARNING: reaching max number of iterations");
754 } 767 }
755 768
756 // calculate rho 769 // calculate rho
757 770
758 si->rho = calculate_rho(); 771 si->rho = calculate_rho();
845 if (G[j] >= Gmax2) 858 if (G[j] >= Gmax2)
846 Gmax2 = G[j]; 859 Gmax2 = G[j];
847 if (grad_diff > 0) 860 if (grad_diff > 0)
848 { 861 {
849 double obj_diff; 862 double obj_diff;
850 double quad_coef=Q_i[i]+QD[j]-2.0*y[i]*Q_i[j]; 863 double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
851 if (quad_coef > 0) 864 if (quad_coef > 0)
852 obj_diff = -(grad_diff*grad_diff)/quad_coef; 865 obj_diff = -(grad_diff*grad_diff)/quad_coef;
853 else 866 else
854 obj_diff = -(grad_diff*grad_diff)/TAU; 867 obj_diff = -(grad_diff*grad_diff)/TAU;
855 868
869 if (-G[j] >= Gmax2) 882 if (-G[j] >= Gmax2)
870 Gmax2 = -G[j]; 883 Gmax2 = -G[j];
871 if (grad_diff > 0) 884 if (grad_diff > 0)
872 { 885 {
873 double obj_diff; 886 double obj_diff;
874 double quad_coef=Q_i[i]+QD[j]+2.0*y[i]*Q_i[j]; 887 double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
875 if (quad_coef > 0) 888 if (quad_coef > 0)
876 obj_diff = -(grad_diff*grad_diff)/quad_coef; 889 obj_diff = -(grad_diff*grad_diff)/quad_coef;
877 else 890 else
878 obj_diff = -(grad_diff*grad_diff)/TAU; 891 obj_diff = -(grad_diff*grad_diff)/TAU;
879 892
1097 if (G[j] >= Gmaxp2) 1110 if (G[j] >= Gmaxp2)
1098 Gmaxp2 = G[j]; 1111 Gmaxp2 = G[j];
1099 if (grad_diff > 0) 1112 if (grad_diff > 0)
1100 { 1113 {
1101 double obj_diff; 1114 double obj_diff;
1102 double quad_coef = Q_ip[ip]+QD[j]-2*Q_ip[j]; 1115 double quad_coef = QD[ip]+QD[j]-2*Q_ip[j];
1103 if (quad_coef > 0) 1116 if (quad_coef > 0)
1104 obj_diff = -(grad_diff*grad_diff)/quad_coef; 1117 obj_diff = -(grad_diff*grad_diff)/quad_coef;
1105 else 1118 else
1106 obj_diff = -(grad_diff*grad_diff)/TAU; 1119 obj_diff = -(grad_diff*grad_diff)/TAU;
1107 1120
1121 if (-G[j] >= Gmaxn2) 1134 if (-G[j] >= Gmaxn2)
1122 Gmaxn2 = -G[j]; 1135 Gmaxn2 = -G[j];
1123 if (grad_diff > 0) 1136 if (grad_diff > 0)
1124 { 1137 {
1125 double obj_diff; 1138 double obj_diff;
1126 double quad_coef = Q_in[in]+QD[j]-2*Q_in[j]; 1139 double quad_coef = QD[in]+QD[j]-2*Q_in[j];
1127 if (quad_coef > 0) 1140 if (quad_coef > 0)
1128 obj_diff = -(grad_diff*grad_diff)/quad_coef; 1141 obj_diff = -(grad_diff*grad_diff)/quad_coef;
1129 else 1142 else
1130 obj_diff = -(grad_diff*grad_diff)/TAU; 1143 obj_diff = -(grad_diff*grad_diff)/TAU;
1131 1144
1282 SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_) 1295 SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
1283 :Kernel(prob.l, prob.x, param) 1296 :Kernel(prob.l, prob.x, param)
1284 { 1297 {
1285 clone(y,y_,prob.l); 1298 clone(y,y_,prob.l);
1286 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20))); 1299 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1287 QD = new Qfloat[prob.l]; 1300 QD = new double[prob.l];
1288 for(int i=0;i<prob.l;i++) 1301 for(int i=0;i<prob.l;i++)
1289 QD[i]= (Qfloat)(this->*kernel_function)(i,i); 1302 QD[i] = (this->*kernel_function)(i,i);
1290 } 1303 }
1291 1304
1292 Qfloat *get_Q(int i, int len) const 1305 Qfloat *get_Q(int i, int len) const
1293 { 1306 {
1294 Qfloat *data; 1307 Qfloat *data;
1299 data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j)); 1312 data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
1300 } 1313 }
1301 return data; 1314 return data;
1302 } 1315 }
1303 1316
1304 Qfloat *get_QD() const 1317 double *get_QD() const
1305 { 1318 {
1306 return QD; 1319 return QD;
1307 } 1320 }
1308 1321
1309 void swap_index(int i, int j) const 1322 void swap_index(int i, int j) const
1321 delete[] QD; 1334 delete[] QD;
1322 } 1335 }
1323 private: 1336 private:
1324 schar *y; 1337 schar *y;
1325 Cache *cache; 1338 Cache *cache;
1326 Qfloat *QD; 1339 double *QD;
1327 }; 1340 };
1328 1341
1329 class ONE_CLASS_Q: public Kernel 1342 class ONE_CLASS_Q: public Kernel
1330 { 1343 {
1331 public: 1344 public:
1332 ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param) 1345 ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
1333 :Kernel(prob.l, prob.x, param) 1346 :Kernel(prob.l, prob.x, param)
1334 { 1347 {
1335 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20))); 1348 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1336 QD = new Qfloat[prob.l]; 1349 QD = new double[prob.l];
1337 for(int i=0;i<prob.l;i++) 1350 for(int i=0;i<prob.l;i++)
1338 QD[i]= (Qfloat)(this->*kernel_function)(i,i); 1351 QD[i] = (this->*kernel_function)(i,i);
1339 } 1352 }
1340 1353
1341 Qfloat *get_Q(int i, int len) const 1354 Qfloat *get_Q(int i, int len) const
1342 { 1355 {
1343 Qfloat *data; 1356 Qfloat *data;
1348 data[j] = (Qfloat)(this->*kernel_function)(i,j); 1361 data[j] = (Qfloat)(this->*kernel_function)(i,j);
1349 } 1362 }
1350 return data; 1363 return data;
1351 } 1364 }
1352 1365
1353 Qfloat *get_QD() const 1366 double *get_QD() const
1354 { 1367 {
1355 return QD; 1368 return QD;
1356 } 1369 }
1357 1370
1358 void swap_index(int i, int j) const 1371 void swap_index(int i, int j) const
1367 delete cache; 1380 delete cache;
1368 delete[] QD; 1381 delete[] QD;
1369 } 1382 }
1370 private: 1383 private:
1371 Cache *cache; 1384 Cache *cache;
1372 Qfloat *QD; 1385 double *QD;
1373 }; 1386 };
1374 1387
1375 class SVR_Q: public Kernel 1388 class SVR_Q: public Kernel
1376 { 1389 {
1377 public: 1390 public:
1378 SVR_Q(const svm_problem& prob, const svm_parameter& param) 1391 SVR_Q(const svm_problem& prob, const svm_parameter& param)
1379 :Kernel(prob.l, prob.x, param) 1392 :Kernel(prob.l, prob.x, param)
1380 { 1393 {
1381 l = prob.l; 1394 l = prob.l;
1382 cache = new Cache(l,(long int)(param.cache_size*(1<<20))); 1395 cache = new Cache(l,(long int)(param.cache_size*(1<<20)));
1383 QD = new Qfloat[2*l]; 1396 QD = new double[2*l];
1384 sign = new schar[2*l]; 1397 sign = new schar[2*l];
1385 index = new int[2*l]; 1398 index = new int[2*l];
1386 for(int k=0;k<l;k++) 1399 for(int k=0;k<l;k++)
1387 { 1400 {
1388 sign[k] = 1; 1401 sign[k] = 1;
1389 sign[k+l] = -1; 1402 sign[k+l] = -1;
1390 index[k] = k; 1403 index[k] = k;
1391 index[k+l] = k; 1404 index[k+l] = k;
1392 QD[k]= (Qfloat)(this->*kernel_function)(k,k); 1405 QD[k] = (this->*kernel_function)(k,k);
1393 QD[k+l]=QD[k]; 1406 QD[k+l] = QD[k];
1394 } 1407 }
1395 buffer[0] = new Qfloat[2*l]; 1408 buffer[0] = new Qfloat[2*l];
1396 buffer[1] = new Qfloat[2*l]; 1409 buffer[1] = new Qfloat[2*l];
1397 next_buffer = 0; 1410 next_buffer = 0;
1398 } 1411 }
1421 for(j=0;j<len;j++) 1434 for(j=0;j<len;j++)
1422 buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]]; 1435 buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]];
1423 return buf; 1436 return buf;
1424 } 1437 }
1425 1438
1426 Qfloat *get_QD() const 1439 double *get_QD() const
1427 { 1440 {
1428 return QD; 1441 return QD;
1429 } 1442 }
1430 1443
1431 ~SVR_Q() 1444 ~SVR_Q()
1442 Cache *cache; 1455 Cache *cache;
1443 schar *sign; 1456 schar *sign;
1444 int *index; 1457 int *index;
1445 mutable int next_buffer; 1458 mutable int next_buffer;
1446 Qfloat *buffer[2]; 1459 Qfloat *buffer[2];
1447 Qfloat *QD; 1460 double *QD;
1448 }; 1461 };
1449 1462
1450 // 1463 //
1451 // construct and solve various formulations 1464 // construct and solve various formulations
1452 // 1465 //
1462 1475
1463 for(i=0;i<l;i++) 1476 for(i=0;i<l;i++)
1464 { 1477 {
1465 alpha[i] = 0; 1478 alpha[i] = 0;
1466 minus_ones[i] = -1; 1479 minus_ones[i] = -1;
1467 if(prob->y[i] > 0) y[i] = +1; else y[i]=-1; 1480 if(prob->y[i] > 0) y[i] = +1; else y[i] = -1;
1468 } 1481 }
1469 1482
1470 Solver s; 1483 Solver s;
1471 s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y, 1484 s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
1472 alpha, Cp, Cn, param->eps, si, param->shrinking); 1485 alpha, Cp, Cn, param->eps, si, param->shrinking);
1712 f.alpha = alpha; 1725 f.alpha = alpha;
1713 f.rho = si.rho; 1726 f.rho = si.rho;
1714 return f; 1727 return f;
1715 } 1728 }
1716 1729
1717 //
1718 // svm_model
1719 //
1720 /*
1721 struct svm_model
1722 {
1723 svm_parameter param; // parameter
1724 int nr_class; // number of classes, = 2 in regression/one class svm
1725 int l; // total #SV
1726 svm_node **SV; // SVs (SV[l])
1727 double **sv_coef; // coefficients for SVs in decision functions (sv_coef[k-1][l])
1728 double *rho; // constants in decision functions (rho[k*(k-1)/2])
1729 double *probA; // pariwise probability information
1730 double *probB;
1731
1732 // for classification only
1733
1734 int *label; // label of each class (label[k])
1735 int *nSV; // number of SVs for each class (nSV[k])
1736 // nSV[0] + nSV[1] + ... + nSV[k-1] = l
1737 // XXX
1738 int free_sv; // 1 if svm_model is created by svm_load_model
1739 // 0 if svm_model is created by svm_train
1740 };
1741 */
1742 // Platt's binary SVM Probablistic Output: an improvement from Lin et al. 1730 // Platt's binary SVM Probablistic Output: an improvement from Lin et al.
1743 static void sigmoid_train( 1731 static void sigmoid_train(
1744 int l, const double *dec_values, const double *labels, 1732 int l, const double *dec_values, const double *labels,
1745 double& A, double& B) 1733 double& A, double& B)
1746 { 1734 {
1854 } 1842 }
1855 1843
1856 static double sigmoid_predict(double decision_value, double A, double B) 1844 static double sigmoid_predict(double decision_value, double A, double B)
1857 { 1845 {
1858 double fApB = decision_value*A+B; 1846 double fApB = decision_value*A+B;
1847 // 1-p used later; avoid catastrophic cancellation
1859 if (fApB >= 0) 1848 if (fApB >= 0)
1860 return exp(-fApB)/(1.0+exp(-fApB)); 1849 return exp(-fApB)/(1.0+exp(-fApB));
1861 else 1850 else
1862 return 1.0/(1+exp(fApB)) ; 1851 return 1.0/(1+exp(fApB)) ;
1863 } 1852 }
2000 { 1989 {
2001 svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]])); 1990 svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]]));
2002 // ensure +1 -1 order; reason not using CV subroutine 1991 // ensure +1 -1 order; reason not using CV subroutine
2003 dec_values[perm[j]] *= submodel->label[0]; 1992 dec_values[perm[j]] *= submodel->label[0];
2004 } 1993 }
2005 svm_destroy_model(submodel); 1994 svm_free_and_destroy_model(&submodel);
2006 svm_destroy_param(&subparam); 1995 svm_destroy_param(&subparam);
2007 } 1996 }
2008 free(subprob.x); 1997 free(subprob.x);
2009 free(subprob.y); 1998 free(subprob.y);
2010 } 1999 }
2164 int *start = NULL; 2153 int *start = NULL;
2165 int *count = NULL; 2154 int *count = NULL;
2166 int *perm = Malloc(int,l); 2155 int *perm = Malloc(int,l);
2167 2156
2168 // group training data of the same class 2157 // group training data of the same class
2169 svm_group_classes(prob,&nr_class,&label,&start,&count,perm); 2158 svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
2159 if(nr_class == 1)
2160 info("WARNING: training data in only one class. See README for details.\n");
2161
2170 svm_node **x = Malloc(svm_node *,l); 2162 svm_node **x = Malloc(svm_node *,l);
2171 int i; 2163 int i;
2172 for(i=0;i<l;i++) 2164 for(i=0;i<l;i++)
2173 x[i] = prob->x[perm[i]]; 2165 x[i] = prob->x[perm[i]];
2174 2166
2182 int j; 2174 int j;
2183 for(j=0;j<nr_class;j++) 2175 for(j=0;j<nr_class;j++)
2184 if(param->weight_label[i] == label[j]) 2176 if(param->weight_label[i] == label[j])
2185 break; 2177 break;
2186 if(j == nr_class) 2178 if(j == nr_class)
2187 fprintf(stderr,"warning: class label %d specified in weight is not found\n", param->weight_label[i]); 2179 fprintf(stderr,"WARNING: class label %d specified in weight is not found\n", param->weight_label[i]);
2188 else 2180 else
2189 weighted_C[j] *= param->weight[i]; 2181 weighted_C[j] *= param->weight[i];
2190 } 2182 }
2191 2183
2192 // train k*(k-1)/2 models 2184 // train k*(k-1)/2 models
2450 free(prob_estimates); 2442 free(prob_estimates);
2451 } 2443 }
2452 else 2444 else
2453 for(j=begin;j<end;j++) 2445 for(j=begin;j<end;j++)
2454 target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]); 2446 target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]);
2455 svm_destroy_model(submodel); 2447 svm_free_and_destroy_model(&submodel);
2456 free(subprob.x); 2448 free(subprob.x);
2457 free(subprob.y); 2449 free(subprob.y);
2458 } 2450 }
2459 free(fold_start); 2451 free(fold_start);
2460 free(perm); 2452 free(perm);
2488 fprintf(stderr,"Model doesn't contain information for SVR probability inference\n"); 2480 fprintf(stderr,"Model doesn't contain information for SVR probability inference\n");
2489 return 0; 2481 return 0;
2490 } 2482 }
2491 } 2483 }
2492 2484
2493 void svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values) 2485 double svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values)
2494 { 2486 {
2487 int i;
2495 if(model->param.svm_type == ONE_CLASS || 2488 if(model->param.svm_type == ONE_CLASS ||
2496 model->param.svm_type == EPSILON_SVR || 2489 model->param.svm_type == EPSILON_SVR ||
2497 model->param.svm_type == NU_SVR) 2490 model->param.svm_type == NU_SVR)
2498 { 2491 {
2499 double *sv_coef = model->sv_coef[0]; 2492 double *sv_coef = model->sv_coef[0];
2500 double sum = 0; 2493 double sum = 0;
2501 for(int i=0;i<model->l;i++) 2494 for(i=0;i<model->l;i++)
2502 sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param); 2495 sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param);
2503 sum -= model->rho[0]; 2496 sum -= model->rho[0];
2504 *dec_values = sum; 2497 *dec_values = sum;
2498
2499 if(model->param.svm_type == ONE_CLASS)
2500 return (sum>0)?1:-1;
2501 else
2502 return sum;
2505 } 2503 }
2506 else 2504 else
2507 { 2505 {
2508 int i;
2509 int nr_class = model->nr_class; 2506 int nr_class = model->nr_class;
2510 int l = model->l; 2507 int l = model->l;
2511 2508
2512 double *kvalue = Malloc(double,l); 2509 double *kvalue = Malloc(double,l);
2513 for(i=0;i<l;i++) 2510 for(i=0;i<l;i++)
2515 2512
2516 int *start = Malloc(int,nr_class); 2513 int *start = Malloc(int,nr_class);
2517 start[0] = 0; 2514 start[0] = 0;
2518 for(i=1;i<nr_class;i++) 2515 for(i=1;i<nr_class;i++)
2519 start[i] = start[i-1]+model->nSV[i-1]; 2516 start[i] = start[i-1]+model->nSV[i-1];
2517
2518 int *vote = Malloc(int,nr_class);
2519 for(i=0;i<nr_class;i++)
2520 vote[i] = 0;
2520 2521
2521 int p=0; 2522 int p=0;
2522 for(i=0;i<nr_class;i++) 2523 for(i=0;i<nr_class;i++)
2523 for(int j=i+1;j<nr_class;j++) 2524 for(int j=i+1;j<nr_class;j++)
2524 { 2525 {
2535 sum += coef1[si+k] * kvalue[si+k]; 2536 sum += coef1[si+k] * kvalue[si+k];
2536 for(k=0;k<cj;k++) 2537 for(k=0;k<cj;k++)
2537 sum += coef2[sj+k] * kvalue[sj+k]; 2538 sum += coef2[sj+k] * kvalue[sj+k];
2538 sum -= model->rho[p]; 2539 sum -= model->rho[p];
2539 dec_values[p] = sum; 2540 dec_values[p] = sum;
2540 p++; 2541
2541 } 2542 if(dec_values[p] > 0)
2542
2543 free(kvalue);
2544 free(start);
2545 }
2546 }
2547
2548 double svm_predict(const svm_model *model, const svm_node *x)
2549 {
2550 if(model->param.svm_type == ONE_CLASS ||
2551 model->param.svm_type == EPSILON_SVR ||
2552 model->param.svm_type == NU_SVR)
2553 {
2554 double res;
2555 svm_predict_values(model, x, &res);
2556
2557 if(model->param.svm_type == ONE_CLASS)
2558 return (res>0)?1:-1;
2559 else
2560 return res;
2561 }
2562 else
2563 {
2564 int i;
2565 int nr_class = model->nr_class;
2566 double *dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2567 svm_predict_values(model, x, dec_values);
2568
2569 int *vote = Malloc(int,nr_class);
2570 for(i=0;i<nr_class;i++)
2571 vote[i] = 0;
2572 int pos=0;
2573 for(i=0;i<nr_class;i++)
2574 for(int j=i+1;j<nr_class;j++)
2575 {
2576 if(dec_values[pos++] > 0)
2577 ++vote[i]; 2543 ++vote[i];
2578 else 2544 else
2579 ++vote[j]; 2545 ++vote[j];
2546 p++;
2580 } 2547 }
2581 2548
2582 int vote_max_idx = 0; 2549 int vote_max_idx = 0;
2583 for(i=1;i<nr_class;i++) 2550 for(i=1;i<nr_class;i++)
2584 if(vote[i] > vote[vote_max_idx]) 2551 if(vote[i] > vote[vote_max_idx])
2585 vote_max_idx = i; 2552 vote_max_idx = i;
2553
2554 free(kvalue);
2555 free(start);
2586 free(vote); 2556 free(vote);
2587 free(dec_values);
2588 return model->label[vote_max_idx]; 2557 return model->label[vote_max_idx];
2589 } 2558 }
2559 }
2560
2561 double svm_predict(const svm_model *model, const svm_node *x)
2562 {
2563 int nr_class = model->nr_class;
2564 double *dec_values;
2565 if(model->param.svm_type == ONE_CLASS ||
2566 model->param.svm_type == EPSILON_SVR ||
2567 model->param.svm_type == NU_SVR)
2568 dec_values = Malloc(double, 1);
2569 else
2570 dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2571 double pred_result = svm_predict_values(model, x, dec_values);
2572 free(dec_values);
2573 return pred_result;
2590 } 2574 }
2591 2575
2592 double svm_predict_probability( 2576 double svm_predict_probability(
2593 const svm_model *model, const svm_node *x, double *prob_estimates) 2577 const svm_model *model, const svm_node *x, double *prob_estimates)
2594 { 2578 {
2641 int svm_save_model(const char *model_file_name, const svm_model *model) 2625 int svm_save_model(const char *model_file_name, const svm_model *model)
2642 { 2626 {
2643 FILE *fp = fopen(model_file_name,"w"); 2627 FILE *fp = fopen(model_file_name,"w");
2644 if(fp==NULL) return -1; 2628 if(fp==NULL) return -1;
2645 2629
2630 char *old_locale = strdup(setlocale(LC_ALL, NULL));
2631 setlocale(LC_ALL, "C");
2632
2646 const svm_parameter& param = model->param; 2633 const svm_parameter& param = model->param;
2647 2634
2648 fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]); 2635 fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]);
2649 fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]); 2636 fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]);
2650 2637
2719 fprintf(fp,"%d:%.8g ",p->index,p->value); 2706 fprintf(fp,"%d:%.8g ",p->index,p->value);
2720 p++; 2707 p++;
2721 } 2708 }
2722 fprintf(fp, "\n"); 2709 fprintf(fp, "\n");
2723 } 2710 }
2711
2712 setlocale(LC_ALL, old_locale);
2713 free(old_locale);
2714
2724 if (ferror(fp) != 0 || fclose(fp) != 0) return -1; 2715 if (ferror(fp) != 0 || fclose(fp) != 0) return -1;
2725 else return 0; 2716 else return 0;
2726 } 2717 }
2727 2718
2728 static char *line = NULL; 2719 static char *line = NULL;
2748 2739
2749 svm_model *svm_load_model(const char *model_file_name) 2740 svm_model *svm_load_model(const char *model_file_name)
2750 { 2741 {
2751 FILE *fp = fopen(model_file_name,"rb"); 2742 FILE *fp = fopen(model_file_name,"rb");
2752 if(fp==NULL) return NULL; 2743 if(fp==NULL) return NULL;
2753 2744
2745 char *old_locale = strdup(setlocale(LC_ALL, NULL));
2746 setlocale(LC_ALL, "C");
2747
2754 // read parameters 2748 // read parameters
2755 2749
2756 svm_model *model = Malloc(svm_model,1); 2750 svm_model *model = Malloc(svm_model,1);
2757 svm_parameter& param = model->param; 2751 svm_parameter& param = model->param;
2758 model->rho = NULL; 2752 model->rho = NULL;
2779 } 2773 }
2780 } 2774 }
2781 if(svm_type_table[i] == NULL) 2775 if(svm_type_table[i] == NULL)
2782 { 2776 {
2783 fprintf(stderr,"unknown svm type.\n"); 2777 fprintf(stderr,"unknown svm type.\n");
2778
2779 setlocale(LC_ALL, old_locale);
2780 free(old_locale);
2784 free(model->rho); 2781 free(model->rho);
2785 free(model->label); 2782 free(model->label);
2786 free(model->nSV); 2783 free(model->nSV);
2787 free(model); 2784 free(model);
2788 return NULL; 2785 return NULL;
2801 } 2798 }
2802 } 2799 }
2803 if(kernel_type_table[i] == NULL) 2800 if(kernel_type_table[i] == NULL)
2804 { 2801 {
2805 fprintf(stderr,"unknown kernel function.\n"); 2802 fprintf(stderr,"unknown kernel function.\n");
2803
2804 setlocale(LC_ALL, old_locale);
2805 free(old_locale);
2806 free(model->rho); 2806 free(model->rho);
2807 free(model->label); 2807 free(model->label);
2808 free(model->nSV); 2808 free(model->nSV);
2809 free(model); 2809 free(model);
2810 return NULL; 2810 return NULL;
2865 break; 2865 break;
2866 } 2866 }
2867 else 2867 else
2868 { 2868 {
2869 fprintf(stderr,"unknown text in model file: [%s]\n",cmd); 2869 fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
2870
2871 setlocale(LC_ALL, old_locale);
2872 free(old_locale);
2870 free(model->rho); 2873 free(model->rho);
2871 free(model->label); 2874 free(model->label);
2872 free(model->nSV); 2875 free(model->nSV);
2873 free(model); 2876 free(model);
2874 return NULL; 2877 return NULL;
2937 } 2940 }
2938 x_space[j++].index = -1; 2941 x_space[j++].index = -1;
2939 } 2942 }
2940 free(line); 2943 free(line);
2941 2944
2945 setlocale(LC_ALL, old_locale);
2946 free(old_locale);
2947
2942 if (ferror(fp) != 0 || fclose(fp) != 0) 2948 if (ferror(fp) != 0 || fclose(fp) != 0)
2943 return NULL; 2949 return NULL;
2944 2950
2945 model->free_sv = 1; // XXX 2951 model->free_sv = 1; // XXX
2946 return model; 2952 return model;
2947 } 2953 }
2948 2954
2949 void svm_destroy_model(svm_model* model) 2955 void svm_free_model_content(svm_model* model_ptr)
2950 { 2956 {
2951 if(model->free_sv && model->l > 0) 2957 if(model_ptr->free_sv && model_ptr->l > 0 && model_ptr->SV != NULL)
2952 free((void *)(model->SV[0])); 2958 free((void *)(model_ptr->SV[0]));
2953 for(int i=0;i<model->nr_class-1;i++) 2959 if(model_ptr->sv_coef)
2954 free(model->sv_coef[i]); 2960 {
2955 free(model->SV); 2961 for(int i=0;i<model_ptr->nr_class-1;i++)
2956 free(model->sv_coef); 2962 free(model_ptr->sv_coef[i]);
2957 free(model->rho); 2963 }
2958 free(model->label); 2964
2959 free(model->probA); 2965 free(model_ptr->SV);
2960 free(model->probB); 2966 model_ptr->SV = NULL;
2961 free(model->nSV); 2967
2962 free(model); 2968 free(model_ptr->sv_coef);
2969 model_ptr->sv_coef = NULL;
2970
2971 free(model_ptr->rho);
2972 model_ptr->rho = NULL;
2973
2974 free(model_ptr->label);
2975 model_ptr->label= NULL;
2976
2977 free(model_ptr->probA);
2978 model_ptr->probA = NULL;
2979
2980 free(model_ptr->probB);
2981 model_ptr->probB= NULL;
2982
2983 free(model_ptr->nSV);
2984 model_ptr->nSV = NULL;
2985 }
2986
2987 void svm_free_and_destroy_model(svm_model** model_ptr_ptr)
2988 {
2989 if(model_ptr_ptr != NULL && *model_ptr_ptr != NULL)
2990 {
2991 svm_free_model_content(*model_ptr_ptr);
2992 free(*model_ptr_ptr);
2993 *model_ptr_ptr = NULL;
2994 }
2963 } 2995 }
2964 2996
2965 void svm_destroy_param(svm_parameter* param) 2997 void svm_destroy_param(svm_parameter* param)
2966 { 2998 {
2967 free(param->weight_label); 2999 free(param->weight_label);
3094 return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && 3126 return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
3095 model->probA!=NULL && model->probB!=NULL) || 3127 model->probA!=NULL && model->probB!=NULL) ||
3096 ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) && 3128 ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
3097 model->probA!=NULL); 3129 model->probA!=NULL);
3098 } 3130 }
3131
3132 void svm_set_print_string_function(void (*print_func)(const char *))
3133 {
3134 if(print_func == NULL)
3135 svm_print_string = &print_string_stdout;
3136 else
3137 svm_print_string = print_func;
3138 }