support for cross validation, consistent naming

Alexander Freytag, 8 years ago
commit 86b76dbb45
4 changed files with 155 additions and 9 deletions
  1. binary/do_binary_cross_validation.m (+27 -0)
  2. binary/do_binary_predict.m (+6 -0)
  3. binary/validation_function.m (+97 -0)
  4. liblinear_train.m (+25 -9)

+ 27 - 0
binary/do_binary_cross_validation.m

@@ -0,0 +1,27 @@
+function f_accuracy = do_binary_cross_validation(y, x, param, nr_fold)
+    len        = length(y);
+    rand_ind   = randperm(len);
+    dec_values = [];
+    labels     = [];
+    
+    % cross validation: train on k-1 folds, collect predictions on the held-out fold
+    for i = 1:nr_fold 
+      test_ind = rand_ind([floor((i-1)*len/nr_fold)+1:floor(i*len/nr_fold)]');
+      train_ind = [1:len]';
+      train_ind(test_ind) = [];
+      model = train(y(train_ind),x(train_ind,:),param);
+      [pred, acc, dec] = predict(y(test_ind),x(test_ind,:),model,'-q');
+      % flip decision values if liblinear's first class is the negative one
+      if model.Label(1) < 0
+        dec = dec * -1;
+      end
+      
+      dec_values = vertcat(dec_values, dec);
+      labels     = vertcat(labels, y(test_ind));
+    end
+    
+    % final evaluation over the pooled predictions of all folds
+    f_accuracy = validation_function(dec_values, labels);
+    disp(sprintf('Cross Validation (AUC): %f', f_accuracy));
+
+end
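
A minimal usage sketch of the new routine (the data below is made up; train and predict are liblinear's MATLAB bindings, which expect a sparse feature matrix):

  % hypothetical toy data: 100 samples, 2 features, labels in {-1, +1}
  y = [ones(50,1); -ones(50,1)];
  X = sparse([randn(50,2) + 1; randn(50,2) - 1]);
  % 5-fold cross validation; returns the AUC pooled over all folds
  f_auc = do_binary_cross_validation(y, X, '-s 1 -c 1 -q', 5);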

+ 6 - 0
binary/do_binary_predict.m

@@ -0,0 +1,6 @@
+function [pred, ret, dec] = do_binary_predict(y, x, model)
+[pred, acc, dec] = predict(y, x, model);
+if model.Label(1) < 0
+  dec = dec * -1;
+end
+ret = validation_function(dec, y);
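
A corresponding sketch for prediction (y_train, X_train, y_test, X_test are placeholders): the wrapper calls liblinear's predict, aligns the sign of the decision values, and returns the validation measure computed on the given labels.

  model = train(y_train, X_train, '-s 1 -c 1 -q');                % liblinear training
  [pred, f_auc, dec] = do_binary_predict(y_test, X_test, model);  % AUC on the test set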

+ 97 - 0
binary/validation_function.m

@@ -0,0 +1,97 @@
+function ret = validation_function(dec, labels)
+labels = (labels >= 0) - (labels < 0);   % map arbitrary labels to {+1, -1}
+valid_function = @(dec, labels) auc(dec, labels);
+ret = valid_function(dec, labels);
+%precision(dec, labels);
+%recall(dec, labels);
+%fscore(dec, labels);
+%bac(dec, labels);
+%auc(dec, labels);
+%accuracy(dec, labels);
+
+function ret = precision(dec, label)
+tp = sum(label == 1 & dec >= 0);
+tp_fp = sum(dec >= 0);
+if tp_fp == 0
+  disp(sprintf('warning: no positive predictions.'));
+  ret = 0;
+else
+  ret = tp / tp_fp;
+end
+disp(sprintf('Precision = %g%% (%d/%d)', 100.0 * ret, tp, tp_fp));
+
+function ret = recall(dec, label)
+tp = sum(label == 1 & dec >= 0);
+tp_fn = sum(label == 1);
+if tp_fn == 0
+  disp(sprintf('warning: no positive true labels.'));
+  ret = 0;
+else
+  ret = tp / tp_fn;
+end
+disp(sprintf('Recall = %g%% (%d/%d)', 100.0 * ret, tp, tp_fn));
+
+function ret = fscore(dec, label)
+tp = sum(label == 1 & dec >= 0);
+tp_fp = sum(dec >= 0);
+tp_fn = sum(label == 1);
+if tp_fp == 0
+  disp(sprintf('warning: no positive predictions.'));
+  precision = 0;
+else
+  precision = tp / tp_fp;
+end
+if tp_fn == 0
+  disp(sprintf('warning: no positive true labels.'));
+  recall = 0;
+else
+  recall = tp / tp_fn;
+end
+if precision + recall == 0
+  disp(sprintf('warning: precision + recall = 0.'));
+  ret = 0;
+else
+  ret = 2 * precision * recall / (precision + recall);
+end
+disp(sprintf('F-score = %g', ret));
+
+function ret = bac(dec, label)
+tp = sum(label == 1 & dec >= 0);
+tn = sum(label == -1 & dec < 0);
+tp_fn = sum(label == 1);
+tn_fp = sum(label == -1);
+if tp_fn == 0
+  disp(sprintf('warning: no positive true labels.'));
+  sensitivity = 0;
+else
+  sensitivity = tp / tp_fn;
+end
+if tn_fp == 0
+  disp(sprintf('warning: no negative true labels.'));
+  specificity = 0;
+else
+  specificity = tn / tn_fp;
+end
+ret = (sensitivity + specificity) / 2;
+disp(sprintf('BAC = %g', ret));
+
+function ret = auc(dec, label)
+[dec, idx] = sort(dec, 'descend');
+label = label(idx);
+tp = cumsum(label == 1);
+fp = sum(label == -1);
+ret = sum(tp(label == -1));   % for each negative, count positives ranked above it
+if tp(end) == 0 || fp == 0
+  disp(sprintf('warning: too few positive or negative true labels'));
+  ret = 0;
+else
+  ret = ret / tp(end) / fp;
+end
+%disp(sprintf('AUC = %g', ret));
+
+function ret = accuracy(dec, label)
+correct = sum(dec .* label >= 0);
+total = length(dec);
+ret = correct / total;
+disp(sprintf('Accuracy = %g%% (%d/%d)', 100.0 * ret, correct, total));
+
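
The default measure is AUC via the valid_function handle; the commented-out calls at the top of the file list the alternatives. A small sketch with made-up values:

  dec    = [2.3; 1.1; -0.4; -1.7];             % decision values, e.g. from predict()
  labels = [1; -1; 1; -1];                     % ground-truth labels
  ret    = validation_function(dec, labels);   % AUC by default: here 3/4 = 0.75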

+ 25 - 9
liblinear_train.m

@@ -22,20 +22,20 @@ function svmmodel = liblinear_train ( labels, feat, settings )
     end
     
     
-    libsvm_options = '';
+    liblinear_options = '';
     
     % outputs for training
     if ( ~ getFieldWithDefault ( settings, 'b_verbose', false ) )
-        libsvm_options = sprintf('%s -q', libsvm_options);
+        liblinear_options = sprintf('%s -q', liblinear_options);
     end
     
     % cost parameter
     f_svm_C = getFieldWithDefault ( settings, 'f_svm_C', 1);
-    libsvm_options = sprintf('%s -c %f', libsvm_options, f_svm_C);    
+    liblinear_options = sprintf('%s -c %f', liblinear_options, f_svm_C);    
     
     % do we want to use an offset for the hyperplane?
     if ( getFieldWithDefault ( settings, 'b_addOffset', false) )
-        libsvm_options = sprintf('%s -B 1', libsvm_options);    
+        liblinear_options = sprintf('%s -B 1', liblinear_options);    
     end
     
     % which solver to use
@@ -50,7 +50,7 @@ function svmmodel = liblinear_train ( labels, feat, settings )
 %          6 -- L1-regularized logistic regression
 %          7 -- L2-regularized logistic regression (dual)    
     i_svmSolver = getFieldWithDefault ( settings, 'i_svmSolver', 1);
-    libsvm_options = sprintf('%s -s %d', libsvm_options, i_svmSolver);    
+    liblinear_options = sprintf('%s -s %d', liblinear_options, i_svmSolver);    
 
     
     % increase penalty for positive samples according to invers ratio of
@@ -59,6 +59,16 @@ function svmmodel = liblinear_train ( labels, feat, settings )
     % 
     b_weightBalancing = getFieldWithDefault ( settings, 'b_weightBalancing', false);
     
+    % optionally run k-fold cross validation instead of standard training;
+    % binary problems use our own AUC-based routine further below, all other
+    % label sets fall back to liblinear's built-in -v option
+    % 
+    b_cross_val = getFieldWithDefault ( settings, 'b_cross_val', false);   
+    if ( b_cross_val && (length(unique(labels)) ~= 2 ) )
+        i_num_folds = getFieldWithDefault ( settings, 'i_num_folds', 10);  
+        liblinear_options = sprintf('%s -v %d', liblinear_options, i_num_folds ); 
+    end
+    
     
   
     uniqueLabels = unique ( labels );
@@ -68,19 +78,25 @@ function svmmodel = liblinear_train ( labels, feat, settings )
     %# train one-against-all models
     
     if ( ~b_weightBalancing)    
-        svmmodel = train( labels, feat, libsvm_options );
+        if ( b_cross_val && (length(unique(labels)) == 2 ) )
+            % the cross-validation performance measure is AUC (see
+            % validation_function.m); the return value is not a trained model
+            svmmodel = do_binary_cross_validation( labels, feat, liblinear_options, getFieldWithDefault ( settings, 'i_num_folds', 10) );
+        else
+            svmmodel = train( labels, feat, liblinear_options );
+        end
     else
         svmmodel = cell( i_numClasses,1);
         for k=1:length(i_classesToRun)
             yBin        = 2*double( labels == uniqueLabels( k ) )-1;
             
             fraction = double(sum(yBin==1))/double(numel(yBin));
-            libsvm_optionsLocal = sprintf('%s -w1 %f', libsvm_options, 1.0/fraction);
-            svmmodel{ k } = train( yBin, feat, libsvm_optionsLocal );
+            liblinear_optionsLocal = sprintf('%s -w1 %f', liblinear_options, 1.0/fraction);
+            svmmodel{ k } = train( yBin, feat, liblinear_optionsLocal );
             
             %store the unique class label for later evaluations.
             svmmodel{ k }.uniqueLabel = uniqueLabels( k );
         end         
     end
-    
+       
 end
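
Finally, a sketch of how the new option is reached through liblinear_train (labels and feat as in the function signature; the defaults mirror the getFieldWithDefault calls above). With these settings the assembled option string is ' -q -c 1.000000 -s 1', and for binary labels the function returns the cross-validation AUC instead of a trained model:

  settings             = struct();
  settings.f_svm_C     = 1.0;    % cost parameter, appended as '-c'
  settings.i_svmSolver = 1;      % default solver, appended as '-s'
  settings.b_cross_val = true;   % enable cross validation
  settings.i_num_folds = 10;     % default number of folds
  cv_auc = liblinear_train( labels, feat, settings );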