GBCDSolver.cpp 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. /**
  2. * @file GBCDSolver.cpp
  3. * @brief Greedy Block Coordinate Descent Algorithm
  4. * @author Erik Rodner
  5. * @date 01/26/2012
  6. */
  #include <cstdlib>
  #include <iostream>
  #include <limits>
  #include <set>
  #include <vector>
  #include <core/basics/Timer.h>
  #include "GBCDSolver.h"
  10. using namespace NICE;
  11. using namespace std;
  12. GBCDSolver::GBCDSolver( uint randomSetSize, uint stepComponents, bool verbose, uint maxIterations, double minDelta )
  13. {
  14. this->verbose = verbose;
  15. this->maxIterations = maxIterations;
  16. this->minDelta = minDelta;
  17. this->stepComponents = stepComponents;
  18. this->randomSetSize = randomSetSize;
  19. this->timeAnalysis = false;
  20. }
  21. void GBCDSolver::setTimeAnalysis(bool timeAnalysis)
  22. {
  23. this->timeAnalysis = timeAnalysis;
  24. }
  25. GBCDSolver::~GBCDSolver()
  26. {
  27. }
  28. void GBCDSolver::greedyApproximation ( const PartialGenericMatrix & gm, const Vector & b, const Vector & grad,
  29. PartialGenericMatrix::SetType & B, Vector & deltaAlpha )
  30. {
  31. uint t = 0;
  32. uint n = b.size();
  33. Vector e ( grad );
  34. // start with an empty set
  35. B.clear();
  36. //if ( verbose )
  37. // cerr << "GBCDSolver::greedyApproximation: size of the problem is " << n << endl;
  38. PartialGenericMatrix::SetType O;
  39. uint elementsN = n;
  40. bool *N = new bool [ n ];
  41. for ( uint i = 0 ; i < n ; i++ )
  42. {
  43. O.push_back(i);
  44. N[i] = true; //N.insert(i);
  45. }
  46. Matrix R ( stepComponents, stepComponents, 0.0 );
  47. deltaAlpha.resize ( stepComponents );
  48. do {
  49. // step (3) of Algorithm 2 in the paper
  50. // determine the index s
  51. int s = *(O.begin());
  52. double min_expr = numeric_limits<double>::max();
  53. for ( PartialGenericMatrix::SetType::const_iterator i = O.begin(); i != O.end(); i++ )
  54. {
  55. double evalue = e(*i);
  56. double expr = - evalue*evalue / ( 2 * gm.getDiagonalElement(*i) );
  57. if ( expr < min_expr )
  58. {
  59. min_expr = expr;
  60. s = *i;
  61. }
  62. }
  63. //if ( verbose )
  64. // cerr << "GBCDSolver: greedy selection of element " << s << endl;
  65. Vector deltaAlphaTmp ( t + 1 );
  66. // step (4) of Algorithm 2 in the paper
  67. if ( t == 0 ) {
  68. R(0,0) = 1 / ( gm.getDiagonalElement(s) );
  69. deltaAlpha[0] = - e(s) / ( gm.getDiagonalElement(s) );
  70. deltaAlphaTmp[0] = deltaAlpha[0];
  71. } else {
  72. Vector beta ( t );
  73. Vector rvector ( t );
  74. Vector tmpScalar ( 1 );
  75. PartialGenericMatrix::SetType sset;
  76. sset.push_back(s);
  77. // ---- calculation of beta
  78. // beta = R * A(B,s)
  79. // beta^T = A(s,B) * R^T
  80. // beta^T = A(B,s) * R^T (because we assume symmetry)
  81. for ( uint i = 0 ; i < beta.size() ; i++ )
  82. {
  83. for ( uint j = 0 ; j < rvector.size(); j++ )
  84. rvector[j] = R( i , j );
  85. gm.multiply ( sset, B, tmpScalar, rvector );
  86. beta[i] = tmpScalar[0];
  87. }
  88. // ---- calculation of nu
  89. gm.multiply ( sset, B, tmpScalar, beta );
  90. double nu = 1.0 / ( gm.getDiagonalElement ( s ) - tmpScalar[0] );
  91. // ---- update our R
  92. for ( uint i = 0 ; i < t ; i++ ) {
  93. for ( uint j = 0 ; j < t ; j++ ) {
  94. R(i,j) += nu * beta(i) * beta(j);
  95. }
  96. R(i,t) = nu * beta(i) * (-1);
  97. R(t,i) = nu * beta(i) * (-1);
  98. }
  99. R(t, t) = nu;
  100. // ---- compute our deltaAlpha update
  101. Vector gradSub ( t+1 );
  102. uint ii = 0;
  103. for ( PartialGenericMatrix::SetType::const_iterator i = B.begin(); i != B.end(); i++,ii++ )
  104. gradSub[ii] = grad[*i];
  105. gradSub[t] = grad[s];
  106. // this statement uses a copy of the R sub-matrix, which
  107. // might be not beneficial for performance!!!
  108. deltaAlphaTmp = (-1.0) * R(0,0,t,t) * gradSub;
  109. for ( uint i = 0 ; i < t+1; i++ )
  110. deltaAlpha[i] = deltaAlphaTmp[i];
  111. }
  112. // step 5 of algorithm 2
  113. B.push_back(s);
  114. // N.erase ( s );
  115. N[s] = false;
  116. elementsN--;
  117. if ( elementsN == 0 ) {
  118. cerr << "Unable to select more elements! Adjust your parameters!" << endl;
  119. break;
  120. }
  121. // step 6 of algorithm 2
  122. // choose a subset O of size kappa = randomSetSize
  123. O.clear();
  124. set<int> selectedElements;
  125. for ( uint i = 0 ; i < randomSetSize ; i++ )
  126. {
  127. int selectedElement;
  128. do {
  129. selectedElement = rand() % n;
  130. } while ( // I have selected this element as the optimal element in a previous step
  131. !N[selectedElement] ||
  132. // I selected this element already for the set O
  133. (selectedElements.find(selectedElement) != selectedElements.end()) );
  134. //if ( verbose )
  135. // cerr << "GBCDSolver: selecting " << selectedElement << " for the set O" << endl;
  136. selectedElements.insert( selectedElement );
  137. O.push_back ( selectedElement );
  138. }
  139. Vector eSub;
  140. gm.multiply ( O, B, eSub, deltaAlphaTmp );
  141. if ( eSub.size() != O.size() )
  142. fthrow(Exception, "The matrix interface did not return a vector of a proper size!");
  143. uint ii = 0;
  144. for ( PartialGenericMatrix::SetType::const_iterator i = O.begin(); i != O.end(); i++,ii++ )
  145. e[ *i ] = eSub[ ii ] + grad[ *i ];
  146. // increment our iteration counter
  147. t++;
  148. } while ( t < this->stepComponents );
  149. delete [] N;
  150. }
  151. int GBCDSolver::solveLin ( const PartialGenericMatrix & gm, const Vector & b, Vector & x )
  152. {
  153. // FIXME: check for quadratic matrix
  154. uint iteration = 0;
  155. Vector grad;
  156. Vector Ax;
  157. if ( x.size() != gm.cols() ) {
  158. // use a simple initial solution, x = 0
  159. x.resize( gm.cols() );
  160. x.set(0.0);
  161. grad = (-1.0) * b;
  162. } else {
  163. gm.multiply ( Ax, x );
  164. grad = Ax - b;
  165. }
  166. PartialGenericMatrix::SetType wholeSet;
  167. for ( uint i = 0 ; i < b.size() ; i++ )
  168. wholeSet.push_back(i);
  169. Timer t;
  170. if ( timeAnalysis )
  171. t.start();
  172. // start with our iterations
  173. do {
  174. // Although the objective of the corresponding quadratic program decreases,
  175. // this is not necessarily true for the residual. We know that at the bottom we get a zero
  176. // gradient (and a residual) but we can not prove anything about the development of it during
  177. // optimization.
  178. double residualNorm = grad.normInf();
  179. if ( verbose )
  180. cerr << "GBCDSolver: [ " << iteration << " / " << maxIterations << " ] " << residualNorm << endl;
  181. Vector deltaAlpha;
  182. PartialGenericMatrix::SetType B;
  183. // -------- the main part: solve the sub-problem of finding a good search direction
  184. greedyApproximation ( gm, b, grad, B, deltaAlpha );
  185. // --------
  186. if ( verbose && b.size() <= 10 )
  187. cerr << "GBCDSolver: " << deltaAlpha << endl;
  188. uint ii = 0;
  189. for ( PartialGenericMatrix::SetType::const_iterator i = B.begin(); i != B.end(); i++, ii++)
  190. {
  191. // update our current estimate, but only at certain positions
  192. x[ *i ] += deltaAlpha[ ii ];
  193. }
  194. double deltaNorm = deltaAlpha.normL2();
  195. if ( verbose )
  196. cerr << "GBCDSolver: delta = " << deltaNorm << endl;
  197. if ( deltaNorm < minDelta ) {
  198. if ( verbose )
  199. cerr << "GBCDSolver: minimum delta reached" << endl;
  200. return iteration;
  201. }
  202. Vector A_deltaAlpha;
  203. gm.multiply ( wholeSet, B, A_deltaAlpha, deltaAlpha );
  204. grad += A_deltaAlpha;
  205. if ( timeAnalysis )
  206. {
  207. t.stop();
  208. cerr << "GBCDSolver: TIME " << t.getSum() << " " << grad.normL2() << " " << grad.normInf() << endl;
  209. t.start();
  210. }
  211. iteration++;
  212. } while ( iteration < maxIterations );
  213. return iteration;
  214. }