Mercurial > forge
comparison extra/NaN/src/svm.cpp @ 12589:06a805605e9a octave-forge
[nan] upgrade libsvm to v3.12
author | schloegl |
---|---|
date | Sun, 12 Apr 2015 14:37:46 +0000 |
parents | 6a419bec96bb |
children | |
comparison legend:
equal
deleted
inserted
replaced
12588:3f24658504ab | 12589:06a805605e9a |
---|---|
1 /* | 1 /* |
2 | 2 |
3 | 3 Copyright (c) 2000-2012 Chih-Chung Chang and Chih-Jen Lin |
4 $Id$ | 4 Copyright (c) 2010,2011,2015 Alois Schloegl <alois.schloegl@ist.ac.at> |
5 Copyright (c) 2000-2009 Chih-Chung Chang and Chih-Jen Lin | |
6 Copyright (c) 2010 Alois Schloegl <alois.schloegl@gmail.com> | |
7 This function is part of the NaN-toolbox | 5 This function is part of the NaN-toolbox |
8 http://pub.ist.ac.at/~schloegl/matlab/NaN/ | 6 http://pub.ist.ac.at/~schloegl/matlab/NaN/ |
9 | 7 |
10 This code was extracted from libsvm-mat-2.9-1 in Jan 2010 and | 8 This code was extracted from libsvm-3.12 in Apr 2015 and |
11 modified for the use with Octave | 9 modified for the use with Octave |
12 | 10 |
13 This program is free software; you can redistribute it and/or modify | 11 This program is free software; you can redistribute it and/or modify |
14 it under the terms of the GNU General Public License as published by | 12 it under the terms of the GNU General Public License as published by |
15 the Free Software Foundation; either version 3 of the License, or | 13 the Free Software Foundation; either version 3 of the License, or |
30 #include <stdlib.h> | 28 #include <stdlib.h> |
31 #include <ctype.h> | 29 #include <ctype.h> |
32 #include <float.h> | 30 #include <float.h> |
33 #include <string.h> | 31 #include <string.h> |
34 #include <stdarg.h> | 32 #include <stdarg.h> |
33 #include <limits.h> | |
34 #include <locale.h> | |
35 #include "svm.h" | 35 #include "svm.h" |
36 | 36 |
37 int libsvm_version = LIBSVM_VERSION; | 37 int libsvm_version = LIBSVM_VERSION; |
38 typedef float Qfloat; | 38 typedef float Qfloat; |
39 typedef signed char schar; | 39 typedef signed char schar; |
67 static void print_string_stdout(const char *s) | 67 static void print_string_stdout(const char *s) |
68 { | 68 { |
69 fputs(s,stdout); | 69 fputs(s,stdout); |
70 fflush(stdout); | 70 fflush(stdout); |
71 } | 71 } |
72 void (*svm_print_string) (const char *) = &print_string_stdout; | 72 static void (*svm_print_string) (const char *) = &print_string_stdout; |
73 #if 1 | 73 #if 1 |
74 static void info(const char *fmt,...) | 74 static void info(const char *fmt,...) |
75 { | 75 { |
76 char buf[BUFSIZ]; | 76 char buf[BUFSIZ]; |
77 va_list ap; | 77 va_list ap; |
218 // the member function get_Q is for getting one column from the Q Matrix | 218 // the member function get_Q is for getting one column from the Q Matrix |
219 // | 219 // |
220 class QMatrix { | 220 class QMatrix { |
221 public: | 221 public: |
222 virtual Qfloat *get_Q(int column, int len) const = 0; | 222 virtual Qfloat *get_Q(int column, int len) const = 0; |
223 virtual Qfloat *get_QD() const = 0; | 223 virtual double *get_QD() const = 0; |
224 virtual void swap_index(int i, int j) const = 0; | 224 virtual void swap_index(int i, int j) const = 0; |
225 virtual ~QMatrix() {} | 225 virtual ~QMatrix() {} |
226 }; | 226 }; |
227 | 227 |
228 class Kernel: public QMatrix { | 228 class Kernel: public QMatrix { |
231 virtual ~Kernel(); | 231 virtual ~Kernel(); |
232 | 232 |
233 static double k_function(const svm_node *x, const svm_node *y, | 233 static double k_function(const svm_node *x, const svm_node *y, |
234 const svm_parameter& param); | 234 const svm_parameter& param); |
235 virtual Qfloat *get_Q(int column, int len) const = 0; | 235 virtual Qfloat *get_Q(int column, int len) const = 0; |
236 virtual Qfloat *get_QD() const = 0; | 236 virtual double *get_QD() const = 0; |
237 virtual void swap_index(int i, int j) const // no so const... | 237 virtual void swap_index(int i, int j) const // no so const... |
238 { | 238 { |
239 swap(x[i],x[j]); | 239 swap(x[i],x[j]); |
240 if(x_square) swap(x_square[i],x_square[j]); | 240 if(x_square) swap(x_square[i],x_square[j]); |
241 } | 241 } |
438 double *G; // gradient of objective function | 438 double *G; // gradient of objective function |
439 enum { LOWER_BOUND, UPPER_BOUND, FREE }; | 439 enum { LOWER_BOUND, UPPER_BOUND, FREE }; |
440 char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE | 440 char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE |
441 double *alpha; | 441 double *alpha; |
442 const QMatrix *Q; | 442 const QMatrix *Q; |
443 const Qfloat *QD; | 443 const double *QD; |
444 double eps; | 444 double eps; |
445 double Cp,Cn; | 445 double Cp,Cn; |
446 double *p; | 446 double *p; |
447 int *active_set; | 447 int *active_set; |
448 double *G_bar; // gradient, if we treat free variables as 0 | 448 double *G_bar; // gradient, if we treat free variables as 0 |
500 for(j=0;j<active_size;j++) | 500 for(j=0;j<active_size;j++) |
501 if(is_free(j)) | 501 if(is_free(j)) |
502 nr_free++; | 502 nr_free++; |
503 | 503 |
504 if(2*nr_free < active_size) | 504 if(2*nr_free < active_size) |
505 info("\nWarning: using -h 0 may be faster\n"); | 505 info("\nWARNING: using -h 0 may be faster\n"); |
506 | 506 |
507 if (nr_free*l > 2*active_size*(l-active_size)) | 507 if (nr_free*l > 2*active_size*(l-active_size)) |
508 { | 508 { |
509 for(i=active_size;i<l;i++) | 509 for(i=active_size;i<l;i++) |
510 { | 510 { |
582 } | 582 } |
583 | 583 |
584 // optimization step | 584 // optimization step |
585 | 585 |
586 int iter = 0; | 586 int iter = 0; |
587 int max_iter = max(10000000, l>INT_MAX/100 ? INT_MAX : 100*l); | |
587 int counter = min(l,1000)+1; | 588 int counter = min(l,1000)+1; |
588 | 589 |
589 while(1) | 590 while(iter < max_iter) |
590 { | 591 { |
591 // show progress and do shrinking | 592 // show progress and do shrinking |
592 | 593 |
593 if(--counter == 0) | 594 if(--counter == 0) |
594 { | 595 { |
624 double old_alpha_i = alpha[i]; | 625 double old_alpha_i = alpha[i]; |
625 double old_alpha_j = alpha[j]; | 626 double old_alpha_j = alpha[j]; |
626 | 627 |
627 if(y[i]!=y[j]) | 628 if(y[i]!=y[j]) |
628 { | 629 { |
629 double quad_coef = Q_i[i]+Q_j[j]+2*Q_i[j]; | 630 double quad_coef = QD[i]+QD[j]+2*Q_i[j]; |
630 if (quad_coef <= 0) | 631 if (quad_coef <= 0) |
631 quad_coef = TAU; | 632 quad_coef = TAU; |
632 double delta = (-G[i]-G[j])/quad_coef; | 633 double delta = (-G[i]-G[j])/quad_coef; |
633 double diff = alpha[i] - alpha[j]; | 634 double diff = alpha[i] - alpha[j]; |
634 alpha[i] += delta; | 635 alpha[i] += delta; |
667 } | 668 } |
668 } | 669 } |
669 } | 670 } |
670 else | 671 else |
671 { | 672 { |
672 double quad_coef = Q_i[i]+Q_j[j]-2*Q_i[j]; | 673 double quad_coef = QD[i]+QD[j]-2*Q_i[j]; |
673 if (quad_coef <= 0) | 674 if (quad_coef <= 0) |
674 quad_coef = TAU; | 675 quad_coef = TAU; |
675 double delta = (G[i]-G[j])/quad_coef; | 676 double delta = (G[i]-G[j])/quad_coef; |
676 double sum = alpha[i] + alpha[j]; | 677 double sum = alpha[i] + alpha[j]; |
677 alpha[i] -= delta; | 678 alpha[i] -= delta; |
749 else | 750 else |
750 for(k=0;k<l;k++) | 751 for(k=0;k<l;k++) |
751 G_bar[k] += C_j * Q_j[k]; | 752 G_bar[k] += C_j * Q_j[k]; |
752 } | 753 } |
753 } | 754 } |
755 } | |
756 | |
757 if(iter >= max_iter) | |
758 { | |
759 if(active_size < l) | |
760 { | |
761 // reconstruct the whole gradient to calculate objective value | |
762 reconstruct_gradient(); | |
763 active_size = l; | |
764 info("*"); | |
765 } | |
766 info("\nWARNING: reaching max number of iterations"); | |
754 } | 767 } |
755 | 768 |
756 // calculate rho | 769 // calculate rho |
757 | 770 |
758 si->rho = calculate_rho(); | 771 si->rho = calculate_rho(); |
845 if (G[j] >= Gmax2) | 858 if (G[j] >= Gmax2) |
846 Gmax2 = G[j]; | 859 Gmax2 = G[j]; |
847 if (grad_diff > 0) | 860 if (grad_diff > 0) |
848 { | 861 { |
849 double obj_diff; | 862 double obj_diff; |
850 double quad_coef=Q_i[i]+QD[j]-2.0*y[i]*Q_i[j]; | 863 double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j]; |
851 if (quad_coef > 0) | 864 if (quad_coef > 0) |
852 obj_diff = -(grad_diff*grad_diff)/quad_coef; | 865 obj_diff = -(grad_diff*grad_diff)/quad_coef; |
853 else | 866 else |
854 obj_diff = -(grad_diff*grad_diff)/TAU; | 867 obj_diff = -(grad_diff*grad_diff)/TAU; |
855 | 868 |
869 if (-G[j] >= Gmax2) | 882 if (-G[j] >= Gmax2) |
870 Gmax2 = -G[j]; | 883 Gmax2 = -G[j]; |
871 if (grad_diff > 0) | 884 if (grad_diff > 0) |
872 { | 885 { |
873 double obj_diff; | 886 double obj_diff; |
874 double quad_coef=Q_i[i]+QD[j]+2.0*y[i]*Q_i[j]; | 887 double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j]; |
875 if (quad_coef > 0) | 888 if (quad_coef > 0) |
876 obj_diff = -(grad_diff*grad_diff)/quad_coef; | 889 obj_diff = -(grad_diff*grad_diff)/quad_coef; |
877 else | 890 else |
878 obj_diff = -(grad_diff*grad_diff)/TAU; | 891 obj_diff = -(grad_diff*grad_diff)/TAU; |
879 | 892 |
1097 if (G[j] >= Gmaxp2) | 1110 if (G[j] >= Gmaxp2) |
1098 Gmaxp2 = G[j]; | 1111 Gmaxp2 = G[j]; |
1099 if (grad_diff > 0) | 1112 if (grad_diff > 0) |
1100 { | 1113 { |
1101 double obj_diff; | 1114 double obj_diff; |
1102 double quad_coef = Q_ip[ip]+QD[j]-2*Q_ip[j]; | 1115 double quad_coef = QD[ip]+QD[j]-2*Q_ip[j]; |
1103 if (quad_coef > 0) | 1116 if (quad_coef > 0) |
1104 obj_diff = -(grad_diff*grad_diff)/quad_coef; | 1117 obj_diff = -(grad_diff*grad_diff)/quad_coef; |
1105 else | 1118 else |
1106 obj_diff = -(grad_diff*grad_diff)/TAU; | 1119 obj_diff = -(grad_diff*grad_diff)/TAU; |
1107 | 1120 |
1121 if (-G[j] >= Gmaxn2) | 1134 if (-G[j] >= Gmaxn2) |
1122 Gmaxn2 = -G[j]; | 1135 Gmaxn2 = -G[j]; |
1123 if (grad_diff > 0) | 1136 if (grad_diff > 0) |
1124 { | 1137 { |
1125 double obj_diff; | 1138 double obj_diff; |
1126 double quad_coef = Q_in[in]+QD[j]-2*Q_in[j]; | 1139 double quad_coef = QD[in]+QD[j]-2*Q_in[j]; |
1127 if (quad_coef > 0) | 1140 if (quad_coef > 0) |
1128 obj_diff = -(grad_diff*grad_diff)/quad_coef; | 1141 obj_diff = -(grad_diff*grad_diff)/quad_coef; |
1129 else | 1142 else |
1130 obj_diff = -(grad_diff*grad_diff)/TAU; | 1143 obj_diff = -(grad_diff*grad_diff)/TAU; |
1131 | 1144 |
1282 SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_) | 1295 SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_) |
1283 :Kernel(prob.l, prob.x, param) | 1296 :Kernel(prob.l, prob.x, param) |
1284 { | 1297 { |
1285 clone(y,y_,prob.l); | 1298 clone(y,y_,prob.l); |
1286 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20))); | 1299 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20))); |
1287 QD = new Qfloat[prob.l]; | 1300 QD = new double[prob.l]; |
1288 for(int i=0;i<prob.l;i++) | 1301 for(int i=0;i<prob.l;i++) |
1289 QD[i]= (Qfloat)(this->*kernel_function)(i,i); | 1302 QD[i] = (this->*kernel_function)(i,i); |
1290 } | 1303 } |
1291 | 1304 |
1292 Qfloat *get_Q(int i, int len) const | 1305 Qfloat *get_Q(int i, int len) const |
1293 { | 1306 { |
1294 Qfloat *data; | 1307 Qfloat *data; |
1299 data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j)); | 1312 data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j)); |
1300 } | 1313 } |
1301 return data; | 1314 return data; |
1302 } | 1315 } |
1303 | 1316 |
1304 Qfloat *get_QD() const | 1317 double *get_QD() const |
1305 { | 1318 { |
1306 return QD; | 1319 return QD; |
1307 } | 1320 } |
1308 | 1321 |
1309 void swap_index(int i, int j) const | 1322 void swap_index(int i, int j) const |
1321 delete[] QD; | 1334 delete[] QD; |
1322 } | 1335 } |
1323 private: | 1336 private: |
1324 schar *y; | 1337 schar *y; |
1325 Cache *cache; | 1338 Cache *cache; |
1326 Qfloat *QD; | 1339 double *QD; |
1327 }; | 1340 }; |
1328 | 1341 |
1329 class ONE_CLASS_Q: public Kernel | 1342 class ONE_CLASS_Q: public Kernel |
1330 { | 1343 { |
1331 public: | 1344 public: |
1332 ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param) | 1345 ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param) |
1333 :Kernel(prob.l, prob.x, param) | 1346 :Kernel(prob.l, prob.x, param) |
1334 { | 1347 { |
1335 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20))); | 1348 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20))); |
1336 QD = new Qfloat[prob.l]; | 1349 QD = new double[prob.l]; |
1337 for(int i=0;i<prob.l;i++) | 1350 for(int i=0;i<prob.l;i++) |
1338 QD[i]= (Qfloat)(this->*kernel_function)(i,i); | 1351 QD[i] = (this->*kernel_function)(i,i); |
1339 } | 1352 } |
1340 | 1353 |
1341 Qfloat *get_Q(int i, int len) const | 1354 Qfloat *get_Q(int i, int len) const |
1342 { | 1355 { |
1343 Qfloat *data; | 1356 Qfloat *data; |
1348 data[j] = (Qfloat)(this->*kernel_function)(i,j); | 1361 data[j] = (Qfloat)(this->*kernel_function)(i,j); |
1349 } | 1362 } |
1350 return data; | 1363 return data; |
1351 } | 1364 } |
1352 | 1365 |
1353 Qfloat *get_QD() const | 1366 double *get_QD() const |
1354 { | 1367 { |
1355 return QD; | 1368 return QD; |
1356 } | 1369 } |
1357 | 1370 |
1358 void swap_index(int i, int j) const | 1371 void swap_index(int i, int j) const |
1367 delete cache; | 1380 delete cache; |
1368 delete[] QD; | 1381 delete[] QD; |
1369 } | 1382 } |
1370 private: | 1383 private: |
1371 Cache *cache; | 1384 Cache *cache; |
1372 Qfloat *QD; | 1385 double *QD; |
1373 }; | 1386 }; |
1374 | 1387 |
1375 class SVR_Q: public Kernel | 1388 class SVR_Q: public Kernel |
1376 { | 1389 { |
1377 public: | 1390 public: |
1378 SVR_Q(const svm_problem& prob, const svm_parameter& param) | 1391 SVR_Q(const svm_problem& prob, const svm_parameter& param) |
1379 :Kernel(prob.l, prob.x, param) | 1392 :Kernel(prob.l, prob.x, param) |
1380 { | 1393 { |
1381 l = prob.l; | 1394 l = prob.l; |
1382 cache = new Cache(l,(long int)(param.cache_size*(1<<20))); | 1395 cache = new Cache(l,(long int)(param.cache_size*(1<<20))); |
1383 QD = new Qfloat[2*l]; | 1396 QD = new double[2*l]; |
1384 sign = new schar[2*l]; | 1397 sign = new schar[2*l]; |
1385 index = new int[2*l]; | 1398 index = new int[2*l]; |
1386 for(int k=0;k<l;k++) | 1399 for(int k=0;k<l;k++) |
1387 { | 1400 { |
1388 sign[k] = 1; | 1401 sign[k] = 1; |
1389 sign[k+l] = -1; | 1402 sign[k+l] = -1; |
1390 index[k] = k; | 1403 index[k] = k; |
1391 index[k+l] = k; | 1404 index[k+l] = k; |
1392 QD[k]= (Qfloat)(this->*kernel_function)(k,k); | 1405 QD[k] = (this->*kernel_function)(k,k); |
1393 QD[k+l]=QD[k]; | 1406 QD[k+l] = QD[k]; |
1394 } | 1407 } |
1395 buffer[0] = new Qfloat[2*l]; | 1408 buffer[0] = new Qfloat[2*l]; |
1396 buffer[1] = new Qfloat[2*l]; | 1409 buffer[1] = new Qfloat[2*l]; |
1397 next_buffer = 0; | 1410 next_buffer = 0; |
1398 } | 1411 } |
1421 for(j=0;j<len;j++) | 1434 for(j=0;j<len;j++) |
1422 buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]]; | 1435 buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]]; |
1423 return buf; | 1436 return buf; |
1424 } | 1437 } |
1425 | 1438 |
1426 Qfloat *get_QD() const | 1439 double *get_QD() const |
1427 { | 1440 { |
1428 return QD; | 1441 return QD; |
1429 } | 1442 } |
1430 | 1443 |
1431 ~SVR_Q() | 1444 ~SVR_Q() |
1442 Cache *cache; | 1455 Cache *cache; |
1443 schar *sign; | 1456 schar *sign; |
1444 int *index; | 1457 int *index; |
1445 mutable int next_buffer; | 1458 mutable int next_buffer; |
1446 Qfloat *buffer[2]; | 1459 Qfloat *buffer[2]; |
1447 Qfloat *QD; | 1460 double *QD; |
1448 }; | 1461 }; |
1449 | 1462 |
1450 // | 1463 // |
1451 // construct and solve various formulations | 1464 // construct and solve various formulations |
1452 // | 1465 // |
1462 | 1475 |
1463 for(i=0;i<l;i++) | 1476 for(i=0;i<l;i++) |
1464 { | 1477 { |
1465 alpha[i] = 0; | 1478 alpha[i] = 0; |
1466 minus_ones[i] = -1; | 1479 minus_ones[i] = -1; |
1467 if(prob->y[i] > 0) y[i] = +1; else y[i]=-1; | 1480 if(prob->y[i] > 0) y[i] = +1; else y[i] = -1; |
1468 } | 1481 } |
1469 | 1482 |
1470 Solver s; | 1483 Solver s; |
1471 s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y, | 1484 s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y, |
1472 alpha, Cp, Cn, param->eps, si, param->shrinking); | 1485 alpha, Cp, Cn, param->eps, si, param->shrinking); |
1712 f.alpha = alpha; | 1725 f.alpha = alpha; |
1713 f.rho = si.rho; | 1726 f.rho = si.rho; |
1714 return f; | 1727 return f; |
1715 } | 1728 } |
1716 | 1729 |
1717 // | |
1718 // svm_model | |
1719 // | |
1720 /* | |
1721 struct svm_model | |
1722 { | |
1723 svm_parameter param; // parameter | |
1724 int nr_class; // number of classes, = 2 in regression/one class svm | |
1725 int l; // total #SV | |
1726 svm_node **SV; // SVs (SV[l]) | |
1727 double **sv_coef; // coefficients for SVs in decision functions (sv_coef[k-1][l]) | |
1728 double *rho; // constants in decision functions (rho[k*(k-1)/2]) | |
1729 double *probA; // pariwise probability information | |
1730 double *probB; | |
1731 | |
1732 // for classification only | |
1733 | |
1734 int *label; // label of each class (label[k]) | |
1735 int *nSV; // number of SVs for each class (nSV[k]) | |
1736 // nSV[0] + nSV[1] + ... + nSV[k-1] = l | |
1737 // XXX | |
1738 int free_sv; // 1 if svm_model is created by svm_load_model | |
1739 // 0 if svm_model is created by svm_train | |
1740 }; | |
1741 */ | |
1742 // Platt's binary SVM Probablistic Output: an improvement from Lin et al. | 1730 // Platt's binary SVM Probablistic Output: an improvement from Lin et al. |
1743 static void sigmoid_train( | 1731 static void sigmoid_train( |
1744 int l, const double *dec_values, const double *labels, | 1732 int l, const double *dec_values, const double *labels, |
1745 double& A, double& B) | 1733 double& A, double& B) |
1746 { | 1734 { |
1854 } | 1842 } |
1855 | 1843 |
1856 static double sigmoid_predict(double decision_value, double A, double B) | 1844 static double sigmoid_predict(double decision_value, double A, double B) |
1857 { | 1845 { |
1858 double fApB = decision_value*A+B; | 1846 double fApB = decision_value*A+B; |
1847 // 1-p used later; avoid catastrophic cancellation | |
1859 if (fApB >= 0) | 1848 if (fApB >= 0) |
1860 return exp(-fApB)/(1.0+exp(-fApB)); | 1849 return exp(-fApB)/(1.0+exp(-fApB)); |
1861 else | 1850 else |
1862 return 1.0/(1+exp(fApB)) ; | 1851 return 1.0/(1+exp(fApB)) ; |
1863 } | 1852 } |
2000 { | 1989 { |
2001 svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]])); | 1990 svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]])); |
2002 // ensure +1 -1 order; reason not using CV subroutine | 1991 // ensure +1 -1 order; reason not using CV subroutine |
2003 dec_values[perm[j]] *= submodel->label[0]; | 1992 dec_values[perm[j]] *= submodel->label[0]; |
2004 } | 1993 } |
2005 svm_destroy_model(submodel); | 1994 svm_free_and_destroy_model(&submodel); |
2006 svm_destroy_param(&subparam); | 1995 svm_destroy_param(&subparam); |
2007 } | 1996 } |
2008 free(subprob.x); | 1997 free(subprob.x); |
2009 free(subprob.y); | 1998 free(subprob.y); |
2010 } | 1999 } |
2164 int *start = NULL; | 2153 int *start = NULL; |
2165 int *count = NULL; | 2154 int *count = NULL; |
2166 int *perm = Malloc(int,l); | 2155 int *perm = Malloc(int,l); |
2167 | 2156 |
2168 // group training data of the same class | 2157 // group training data of the same class |
2169 svm_group_classes(prob,&nr_class,&label,&start,&count,perm); | 2158 svm_group_classes(prob,&nr_class,&label,&start,&count,perm); |
2159 if(nr_class == 1) | |
2160 info("WARNING: training data in only one class. See README for details.\n"); | |
2161 | |
2170 svm_node **x = Malloc(svm_node *,l); | 2162 svm_node **x = Malloc(svm_node *,l); |
2171 int i; | 2163 int i; |
2172 for(i=0;i<l;i++) | 2164 for(i=0;i<l;i++) |
2173 x[i] = prob->x[perm[i]]; | 2165 x[i] = prob->x[perm[i]]; |
2174 | 2166 |
2182 int j; | 2174 int j; |
2183 for(j=0;j<nr_class;j++) | 2175 for(j=0;j<nr_class;j++) |
2184 if(param->weight_label[i] == label[j]) | 2176 if(param->weight_label[i] == label[j]) |
2185 break; | 2177 break; |
2186 if(j == nr_class) | 2178 if(j == nr_class) |
2187 fprintf(stderr,"warning: class label %d specified in weight is not found\n", param->weight_label[i]); | 2179 fprintf(stderr,"WARNING: class label %d specified in weight is not found\n", param->weight_label[i]); |
2188 else | 2180 else |
2189 weighted_C[j] *= param->weight[i]; | 2181 weighted_C[j] *= param->weight[i]; |
2190 } | 2182 } |
2191 | 2183 |
2192 // train k*(k-1)/2 models | 2184 // train k*(k-1)/2 models |
2450 free(prob_estimates); | 2442 free(prob_estimates); |
2451 } | 2443 } |
2452 else | 2444 else |
2453 for(j=begin;j<end;j++) | 2445 for(j=begin;j<end;j++) |
2454 target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]); | 2446 target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]); |
2455 svm_destroy_model(submodel); | 2447 svm_free_and_destroy_model(&submodel); |
2456 free(subprob.x); | 2448 free(subprob.x); |
2457 free(subprob.y); | 2449 free(subprob.y); |
2458 } | 2450 } |
2459 free(fold_start); | 2451 free(fold_start); |
2460 free(perm); | 2452 free(perm); |
2488 fprintf(stderr,"Model doesn't contain information for SVR probability inference\n"); | 2480 fprintf(stderr,"Model doesn't contain information for SVR probability inference\n"); |
2489 return 0; | 2481 return 0; |
2490 } | 2482 } |
2491 } | 2483 } |
2492 | 2484 |
2493 void svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values) | 2485 double svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values) |
2494 { | 2486 { |
2487 int i; | |
2495 if(model->param.svm_type == ONE_CLASS || | 2488 if(model->param.svm_type == ONE_CLASS || |
2496 model->param.svm_type == EPSILON_SVR || | 2489 model->param.svm_type == EPSILON_SVR || |
2497 model->param.svm_type == NU_SVR) | 2490 model->param.svm_type == NU_SVR) |
2498 { | 2491 { |
2499 double *sv_coef = model->sv_coef[0]; | 2492 double *sv_coef = model->sv_coef[0]; |
2500 double sum = 0; | 2493 double sum = 0; |
2501 for(int i=0;i<model->l;i++) | 2494 for(i=0;i<model->l;i++) |
2502 sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param); | 2495 sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param); |
2503 sum -= model->rho[0]; | 2496 sum -= model->rho[0]; |
2504 *dec_values = sum; | 2497 *dec_values = sum; |
2498 | |
2499 if(model->param.svm_type == ONE_CLASS) | |
2500 return (sum>0)?1:-1; | |
2501 else | |
2502 return sum; | |
2505 } | 2503 } |
2506 else | 2504 else |
2507 { | 2505 { |
2508 int i; | |
2509 int nr_class = model->nr_class; | 2506 int nr_class = model->nr_class; |
2510 int l = model->l; | 2507 int l = model->l; |
2511 | 2508 |
2512 double *kvalue = Malloc(double,l); | 2509 double *kvalue = Malloc(double,l); |
2513 for(i=0;i<l;i++) | 2510 for(i=0;i<l;i++) |
2515 | 2512 |
2516 int *start = Malloc(int,nr_class); | 2513 int *start = Malloc(int,nr_class); |
2517 start[0] = 0; | 2514 start[0] = 0; |
2518 for(i=1;i<nr_class;i++) | 2515 for(i=1;i<nr_class;i++) |
2519 start[i] = start[i-1]+model->nSV[i-1]; | 2516 start[i] = start[i-1]+model->nSV[i-1]; |
2517 | |
2518 int *vote = Malloc(int,nr_class); | |
2519 for(i=0;i<nr_class;i++) | |
2520 vote[i] = 0; | |
2520 | 2521 |
2521 int p=0; | 2522 int p=0; |
2522 for(i=0;i<nr_class;i++) | 2523 for(i=0;i<nr_class;i++) |
2523 for(int j=i+1;j<nr_class;j++) | 2524 for(int j=i+1;j<nr_class;j++) |
2524 { | 2525 { |
2535 sum += coef1[si+k] * kvalue[si+k]; | 2536 sum += coef1[si+k] * kvalue[si+k]; |
2536 for(k=0;k<cj;k++) | 2537 for(k=0;k<cj;k++) |
2537 sum += coef2[sj+k] * kvalue[sj+k]; | 2538 sum += coef2[sj+k] * kvalue[sj+k]; |
2538 sum -= model->rho[p]; | 2539 sum -= model->rho[p]; |
2539 dec_values[p] = sum; | 2540 dec_values[p] = sum; |
2540 p++; | 2541 |
2541 } | 2542 if(dec_values[p] > 0) |
2542 | |
2543 free(kvalue); | |
2544 free(start); | |
2545 } | |
2546 } | |
2547 | |
2548 double svm_predict(const svm_model *model, const svm_node *x) | |
2549 { | |
2550 if(model->param.svm_type == ONE_CLASS || | |
2551 model->param.svm_type == EPSILON_SVR || | |
2552 model->param.svm_type == NU_SVR) | |
2553 { | |
2554 double res; | |
2555 svm_predict_values(model, x, &res); | |
2556 | |
2557 if(model->param.svm_type == ONE_CLASS) | |
2558 return (res>0)?1:-1; | |
2559 else | |
2560 return res; | |
2561 } | |
2562 else | |
2563 { | |
2564 int i; | |
2565 int nr_class = model->nr_class; | |
2566 double *dec_values = Malloc(double, nr_class*(nr_class-1)/2); | |
2567 svm_predict_values(model, x, dec_values); | |
2568 | |
2569 int *vote = Malloc(int,nr_class); | |
2570 for(i=0;i<nr_class;i++) | |
2571 vote[i] = 0; | |
2572 int pos=0; | |
2573 for(i=0;i<nr_class;i++) | |
2574 for(int j=i+1;j<nr_class;j++) | |
2575 { | |
2576 if(dec_values[pos++] > 0) | |
2577 ++vote[i]; | 2543 ++vote[i]; |
2578 else | 2544 else |
2579 ++vote[j]; | 2545 ++vote[j]; |
2546 p++; | |
2580 } | 2547 } |
2581 | 2548 |
2582 int vote_max_idx = 0; | 2549 int vote_max_idx = 0; |
2583 for(i=1;i<nr_class;i++) | 2550 for(i=1;i<nr_class;i++) |
2584 if(vote[i] > vote[vote_max_idx]) | 2551 if(vote[i] > vote[vote_max_idx]) |
2585 vote_max_idx = i; | 2552 vote_max_idx = i; |
2553 | |
2554 free(kvalue); | |
2555 free(start); | |
2586 free(vote); | 2556 free(vote); |
2587 free(dec_values); | |
2588 return model->label[vote_max_idx]; | 2557 return model->label[vote_max_idx]; |
2589 } | 2558 } |
2559 } | |
2560 | |
2561 double svm_predict(const svm_model *model, const svm_node *x) | |
2562 { | |
2563 int nr_class = model->nr_class; | |
2564 double *dec_values; | |
2565 if(model->param.svm_type == ONE_CLASS || | |
2566 model->param.svm_type == EPSILON_SVR || | |
2567 model->param.svm_type == NU_SVR) | |
2568 dec_values = Malloc(double, 1); | |
2569 else | |
2570 dec_values = Malloc(double, nr_class*(nr_class-1)/2); | |
2571 double pred_result = svm_predict_values(model, x, dec_values); | |
2572 free(dec_values); | |
2573 return pred_result; | |
2590 } | 2574 } |
2591 | 2575 |
2592 double svm_predict_probability( | 2576 double svm_predict_probability( |
2593 const svm_model *model, const svm_node *x, double *prob_estimates) | 2577 const svm_model *model, const svm_node *x, double *prob_estimates) |
2594 { | 2578 { |
2641 int svm_save_model(const char *model_file_name, const svm_model *model) | 2625 int svm_save_model(const char *model_file_name, const svm_model *model) |
2642 { | 2626 { |
2643 FILE *fp = fopen(model_file_name,"w"); | 2627 FILE *fp = fopen(model_file_name,"w"); |
2644 if(fp==NULL) return -1; | 2628 if(fp==NULL) return -1; |
2645 | 2629 |
2630 char *old_locale = strdup(setlocale(LC_ALL, NULL)); | |
2631 setlocale(LC_ALL, "C"); | |
2632 | |
2646 const svm_parameter& param = model->param; | 2633 const svm_parameter& param = model->param; |
2647 | 2634 |
2648 fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]); | 2635 fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]); |
2649 fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]); | 2636 fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]); |
2650 | 2637 |
2719 fprintf(fp,"%d:%.8g ",p->index,p->value); | 2706 fprintf(fp,"%d:%.8g ",p->index,p->value); |
2720 p++; | 2707 p++; |
2721 } | 2708 } |
2722 fprintf(fp, "\n"); | 2709 fprintf(fp, "\n"); |
2723 } | 2710 } |
2711 | |
2712 setlocale(LC_ALL, old_locale); | |
2713 free(old_locale); | |
2714 | |
2724 if (ferror(fp) != 0 || fclose(fp) != 0) return -1; | 2715 if (ferror(fp) != 0 || fclose(fp) != 0) return -1; |
2725 else return 0; | 2716 else return 0; |
2726 } | 2717 } |
2727 | 2718 |
2728 static char *line = NULL; | 2719 static char *line = NULL; |
2748 | 2739 |
2749 svm_model *svm_load_model(const char *model_file_name) | 2740 svm_model *svm_load_model(const char *model_file_name) |
2750 { | 2741 { |
2751 FILE *fp = fopen(model_file_name,"rb"); | 2742 FILE *fp = fopen(model_file_name,"rb"); |
2752 if(fp==NULL) return NULL; | 2743 if(fp==NULL) return NULL; |
2753 | 2744 |
2745 char *old_locale = strdup(setlocale(LC_ALL, NULL)); | |
2746 setlocale(LC_ALL, "C"); | |
2747 | |
2754 // read parameters | 2748 // read parameters |
2755 | 2749 |
2756 svm_model *model = Malloc(svm_model,1); | 2750 svm_model *model = Malloc(svm_model,1); |
2757 svm_parameter& param = model->param; | 2751 svm_parameter& param = model->param; |
2758 model->rho = NULL; | 2752 model->rho = NULL; |
2779 } | 2773 } |
2780 } | 2774 } |
2781 if(svm_type_table[i] == NULL) | 2775 if(svm_type_table[i] == NULL) |
2782 { | 2776 { |
2783 fprintf(stderr,"unknown svm type.\n"); | 2777 fprintf(stderr,"unknown svm type.\n"); |
2778 | |
2779 setlocale(LC_ALL, old_locale); | |
2780 free(old_locale); | |
2784 free(model->rho); | 2781 free(model->rho); |
2785 free(model->label); | 2782 free(model->label); |
2786 free(model->nSV); | 2783 free(model->nSV); |
2787 free(model); | 2784 free(model); |
2788 return NULL; | 2785 return NULL; |
2801 } | 2798 } |
2802 } | 2799 } |
2803 if(kernel_type_table[i] == NULL) | 2800 if(kernel_type_table[i] == NULL) |
2804 { | 2801 { |
2805 fprintf(stderr,"unknown kernel function.\n"); | 2802 fprintf(stderr,"unknown kernel function.\n"); |
2803 | |
2804 setlocale(LC_ALL, old_locale); | |
2805 free(old_locale); | |
2806 free(model->rho); | 2806 free(model->rho); |
2807 free(model->label); | 2807 free(model->label); |
2808 free(model->nSV); | 2808 free(model->nSV); |
2809 free(model); | 2809 free(model); |
2810 return NULL; | 2810 return NULL; |
2865 break; | 2865 break; |
2866 } | 2866 } |
2867 else | 2867 else |
2868 { | 2868 { |
2869 fprintf(stderr,"unknown text in model file: [%s]\n",cmd); | 2869 fprintf(stderr,"unknown text in model file: [%s]\n",cmd); |
2870 | |
2871 setlocale(LC_ALL, old_locale); | |
2872 free(old_locale); | |
2870 free(model->rho); | 2873 free(model->rho); |
2871 free(model->label); | 2874 free(model->label); |
2872 free(model->nSV); | 2875 free(model->nSV); |
2873 free(model); | 2876 free(model); |
2874 return NULL; | 2877 return NULL; |
2937 } | 2940 } |
2938 x_space[j++].index = -1; | 2941 x_space[j++].index = -1; |
2939 } | 2942 } |
2940 free(line); | 2943 free(line); |
2941 | 2944 |
2945 setlocale(LC_ALL, old_locale); | |
2946 free(old_locale); | |
2947 | |
2942 if (ferror(fp) != 0 || fclose(fp) != 0) | 2948 if (ferror(fp) != 0 || fclose(fp) != 0) |
2943 return NULL; | 2949 return NULL; |
2944 | 2950 |
2945 model->free_sv = 1; // XXX | 2951 model->free_sv = 1; // XXX |
2946 return model; | 2952 return model; |
2947 } | 2953 } |
2948 | 2954 |
2949 void svm_destroy_model(svm_model* model) | 2955 void svm_free_model_content(svm_model* model_ptr) |
2950 { | 2956 { |
2951 if(model->free_sv && model->l > 0) | 2957 if(model_ptr->free_sv && model_ptr->l > 0 && model_ptr->SV != NULL) |
2952 free((void *)(model->SV[0])); | 2958 free((void *)(model_ptr->SV[0])); |
2953 for(int i=0;i<model->nr_class-1;i++) | 2959 if(model_ptr->sv_coef) |
2954 free(model->sv_coef[i]); | 2960 { |
2955 free(model->SV); | 2961 for(int i=0;i<model_ptr->nr_class-1;i++) |
2956 free(model->sv_coef); | 2962 free(model_ptr->sv_coef[i]); |
2957 free(model->rho); | 2963 } |
2958 free(model->label); | 2964 |
2959 free(model->probA); | 2965 free(model_ptr->SV); |
2960 free(model->probB); | 2966 model_ptr->SV = NULL; |
2961 free(model->nSV); | 2967 |
2962 free(model); | 2968 free(model_ptr->sv_coef); |
2969 model_ptr->sv_coef = NULL; | |
2970 | |
2971 free(model_ptr->rho); | |
2972 model_ptr->rho = NULL; | |
2973 | |
2974 free(model_ptr->label); | |
2975 model_ptr->label= NULL; | |
2976 | |
2977 free(model_ptr->probA); | |
2978 model_ptr->probA = NULL; | |
2979 | |
2980 free(model_ptr->probB); | |
2981 model_ptr->probB= NULL; | |
2982 | |
2983 free(model_ptr->nSV); | |
2984 model_ptr->nSV = NULL; | |
2985 } | |
2986 | |
2987 void svm_free_and_destroy_model(svm_model** model_ptr_ptr) | |
2988 { | |
2989 if(model_ptr_ptr != NULL && *model_ptr_ptr != NULL) | |
2990 { | |
2991 svm_free_model_content(*model_ptr_ptr); | |
2992 free(*model_ptr_ptr); | |
2993 *model_ptr_ptr = NULL; | |
2994 } | |
2963 } | 2995 } |
2964 | 2996 |
2965 void svm_destroy_param(svm_parameter* param) | 2997 void svm_destroy_param(svm_parameter* param) |
2966 { | 2998 { |
2967 free(param->weight_label); | 2999 free(param->weight_label); |
3094 return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && | 3126 return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && |
3095 model->probA!=NULL && model->probB!=NULL) || | 3127 model->probA!=NULL && model->probB!=NULL) || |
3096 ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) && | 3128 ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) && |
3097 model->probA!=NULL); | 3129 model->probA!=NULL); |
3098 } | 3130 } |
3131 | |
3132 void svm_set_print_string_function(void (*print_func)(const char *)) | |
3133 { | |
3134 if(print_func == NULL) | |
3135 svm_print_string = &print_string_stdout; | |
3136 else | |
3137 svm_print_string = print_func; | |
3138 } |