active_set.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. #include "active_set.h"
  2. #include "min_quad_with_fixed.h"
  3. #include "slice.h"
  4. #include "cat.h"
  5. #include "matlab_format.h"
  6. #include <iostream>
  7. #include <limits>
  8. #include <algorithm>
  9. template <
  10. typename AT,
  11. typename DerivedB,
  12. typename Derivedknown,
  13. typename DerivedY,
  14. typename AeqT,
  15. typename DerivedBeq,
  16. typename AieqT,
  17. typename DerivedBieq,
  18. typename Derivedlx,
  19. typename Derivedux,
  20. typename DerivedZ
  21. >
  22. IGL_INLINE igl::SolverStatus igl::active_set(
  23. const Eigen::SparseMatrix<AT>& A,
  24. const Eigen::PlainObjectBase<DerivedB> & B,
  25. const Eigen::PlainObjectBase<Derivedknown> & known,
  26. const Eigen::PlainObjectBase<DerivedY> & Y,
  27. const Eigen::SparseMatrix<AeqT>& Aeq,
  28. const Eigen::PlainObjectBase<DerivedBeq> & Beq,
  29. const Eigen::SparseMatrix<AieqT>& Aieq,
  30. const Eigen::PlainObjectBase<DerivedBieq> & Bieq,
  31. const Eigen::PlainObjectBase<Derivedlx> & p_lx,
  32. const Eigen::PlainObjectBase<Derivedux> & p_ux,
  33. const igl::active_set_params & params,
  34. Eigen::PlainObjectBase<DerivedZ> & Z
  35. )
  36. {
  37. #define ACTIVE_SET_CPP_DEBUG
  38. using namespace igl;
  39. using namespace Eigen;
  40. using namespace std;
  41. SolverStatus ret = SOLVER_STATUS_ERROR;
  42. const int n = A.rows();
  43. assert(n == A.cols() && "A must be square");
  44. // Discard const qualifiers
  45. //if(B.size() == 0)
  46. //{
  47. // B = Eigen::PlainObjectBase<DerivedB>::Zero(n,1);
  48. //}
  49. assert(n == B.rows() && "B.rows() must match A.rows()");
  50. assert(B.cols() == 1 && "B must be a column vector");
  51. assert(Y.cols() == 1 && "Y must be a column vector");
  52. assert((Aeq.size() == 0 && Beq.size() == 0) || Aeq.cols() == n);
  53. assert((Aeq.size() == 0 && Beq.size() == 0) || Aeq.rows() == Beq.rows());
  54. assert((Aeq.size() == 0 && Beq.size() == 0) || Beq.cols() == 1);
  55. assert((Aieq.size() == 0 && Bieq.size() == 0) || Aieq.cols() == n);
  56. assert((Aieq.size() == 0 && Bieq.size() == 0) || Aieq.rows() == Bieq.rows());
  57. assert((Aieq.size() == 0 && Bieq.size() == 0) || Bieq.cols() == 1);
  58. Eigen::PlainObjectBase<Derivedlx> lx;
  59. Eigen::PlainObjectBase<Derivedux> ux;
  60. if(p_lx.size() == 0)
  61. {
  62. lx = Eigen::PlainObjectBase<Derivedlx>::Constant(
  63. n,1,-numeric_limits<typename Derivedlx::Scalar>::max());
  64. }else
  65. {
  66. lx = p_lx;
  67. }
  68. if(ux.size() == 0)
  69. {
  70. ux = Eigen::PlainObjectBase<Derivedux>::Constant(
  71. n,1,numeric_limits<typename Derivedux::Scalar>::max());
  72. }else
  73. {
  74. ux = p_ux;
  75. }
  76. assert(lx.rows() == n && "lx must have n rows");
  77. assert(ux.rows() == n && "ux must have n rows");
  78. assert(ux.cols() == 1 && "lx must be a column vector");
  79. assert(lx.cols() == 1 && "ux must be a column vector");
  80. assert((ux.array()-lx.array()).minCoeff() > 0 && "ux(i) must be > lx(i)");
  81. if(Z.size() != 0)
  82. {
  83. // Initial guess should have correct size
  84. assert(Z.rows() == n && "Z must have n rows");
  85. assert(Z.cols() == 1 && "Z must be a column vector");
  86. }
  87. assert(known.cols() == 1 && "known must be a column vector");
  88. // Number of knowns
  89. const int nk = known.size();
  90. // Initialize active sets
  91. typedef int BOOL;
  92. #define TRUE 1
  93. #define FALSE 0
  94. Matrix<BOOL,Dynamic,1> as_lx = Matrix<BOOL,Dynamic,1>::Constant(n,1,FALSE);
  95. Matrix<BOOL,Dynamic,1> as_ux = Matrix<BOOL,Dynamic,1>::Constant(n,1,FALSE);
  96. Matrix<BOOL,Dynamic,1> as_ieq = Matrix<BOOL,Dynamic,1>::Constant(Aieq.rows(),1,FALSE);
  97. // Keep track of previous Z for comparison
  98. PlainObjectBase<DerivedZ> old_Z;
  99. old_Z = PlainObjectBase<DerivedZ>::Constant(
  100. n,1,numeric_limits<typename DerivedZ::Scalar>::max());
  101. int iter = 0;
  102. while(true)
  103. {
  104. #ifdef ACTIVE_SET_CPP_DEBUG
  105. cout<<iter<<":"<<endl;
  106. cout<<" pre"<<endl;
  107. #endif
  108. // FIND BREACHES OF CONSTRAINTS
  109. int new_as_lx = 0;
  110. int new_as_ux = 0;
  111. int new_as_ieq = 0;
  112. if(Z.size() > 0)
  113. {
  114. for(int z = 0;z < n;z++)
  115. {
  116. if(Z(z) < lx(z))
  117. {
  118. new_as_lx += (as_lx(z)?0:1);
  119. //new_as_lx++;
  120. as_lx(z) = TRUE;
  121. }
  122. if(Z(z) > ux(z))
  123. {
  124. new_as_ux += (as_ux(z)?0:1);
  125. //new_as_ux++;
  126. as_ux(z) = TRUE;
  127. }
  128. }
  129. if(Aieq.rows() > 0)
  130. {
  131. PlainObjectBase<DerivedZ> AieqZ;
  132. AieqZ = Aieq*Z;
  133. for(int a = 0;a<Aieq.rows();a++)
  134. {
  135. if(AieqZ(a) > Bieq(a))
  136. {
  137. new_as_ieq += (as_ieq(a)?0:1);
  138. as_ieq(a) = TRUE;
  139. }
  140. }
  141. }
  142. #ifdef ACTIVE_SET_CPP_DEBUG
  143. cout<<" new_as_lx: "<<new_as_lx<<endl;
  144. cout<<" new_as_ux: "<<new_as_ux<<endl;
  145. #endif
  146. const double diff = (Z-old_Z).squaredNorm();
  147. #ifdef ACTIVE_SET_CPP_DEBUG
  148. cout<<"diff: "<<diff<<endl;
  149. #endif
  150. if(diff < params.solution_diff_threshold)
  151. {
  152. ret = SOLVER_STATUS_CONVERGED;
  153. break;
  154. }
  155. old_Z = Z;
  156. }
  157. const int as_lx_count = count(as_lx.data(),as_lx.data()+n,TRUE);
  158. const int as_ux_count = count(as_ux.data(),as_ux.data()+n,TRUE);
  159. const int as_ieq_count =
  160. count(as_ieq.data(),as_ieq.data()+as_ieq.size(),TRUE);
  161. #ifndef NDEBUG
  162. {
  163. int count = 0;
  164. for(int a = 0;a<as_ieq.size();a++)
  165. {
  166. if(as_ieq(a))
  167. {
  168. assert(as_ieq(a) == TRUE);
  169. count++;
  170. }
  171. }
  172. assert(as_ieq_count == count);
  173. }
  174. #endif
  175. // PREPARE FIXED VALUES
  176. PlainObjectBase<Derivedknown> known_i;
  177. known_i.resize(nk + as_lx_count + as_ux_count,1);
  178. PlainObjectBase<DerivedY> Y_i;
  179. Y_i.resize(nk + as_lx_count + as_ux_count,1);
  180. {
  181. known_i.block(0,0,known.rows(),known.cols()) = known;
  182. Y_i.block(0,0,Y.rows(),Y.cols()) = Y;
  183. int k = nk;
  184. // Then all lx
  185. for(int z = 0;z < n;z++)
  186. {
  187. if(as_lx(z))
  188. {
  189. known_i(k) = z;
  190. Y_i(k) = lx(z);
  191. k++;
  192. }
  193. }
  194. // Finally all ux
  195. for(int z = 0;z < n;z++)
  196. {
  197. if(as_ux(z))
  198. {
  199. known_i(k) = z;
  200. Y_i(k) = ux(z);
  201. k++;
  202. }
  203. }
  204. assert(k==Y_i.size());
  205. assert(k==known_i.size());
  206. }
  207. //cout<<matlab_format((known_i.array()+1).eval(),"known_i")<<endl;
  208. // PREPARE EQUALITY CONSTRAINTS
  209. VectorXi as_ieq_list(as_ieq_count,1);
  210. // Gather active constraints and resp. rhss
  211. PlainObjectBase<DerivedBeq> Beq_i;
  212. Beq_i.resize(Beq.rows()+as_ieq_count,1);
  213. {
  214. int k =0;
  215. for(int a=0;a<as_ieq.size();a++)
  216. {
  217. if(as_ieq(a))
  218. {
  219. assert(k<as_ieq_list.size());
  220. as_ieq_list(k)=a;
  221. Beq_i(Beq.rows()+k,0) = Bieq(k,0);
  222. k++;
  223. }
  224. }
  225. assert(k == as_ieq_count);
  226. }
  227. // extract active constraint rows
  228. SparseMatrix<AeqT> Aeq_i,Aieq_i;
  229. slice(Aieq,as_ieq_list,1,Aieq_i);
  230. // Append to equality constraints
  231. cat(1,Aeq,Aieq_i,Aeq_i);
  232. min_quad_with_fixed_data<AT> data;
  233. #ifndef NDEBUG
  234. {
  235. // NO DUPES!
  236. Matrix<BOOL,Dynamic,1> fixed = Matrix<BOOL,Dynamic,1>::Constant(n,1,FALSE);
  237. for(int k = 0;k<known_i.size();k++)
  238. {
  239. assert(!fixed[known_i(k)]);
  240. fixed[known_i(k)] = TRUE;
  241. }
  242. }
  243. #endif
  244. #ifdef ACTIVE_SET_CPP_DEBUG
  245. cout<<" min_quad_with_fixed_precompute"<<endl;
  246. #endif
  247. if(!min_quad_with_fixed_precompute(A,known_i,Aeq_i,params.Auu_pd,data))
  248. {
  249. cerr<<"Error: min_quad_with_fixed precomputation failed."<<endl;
  250. if(iter > 0 && Aeq_i.rows() > Aeq.rows())
  251. {
  252. cerr<<" *Are you sure rows of [Aeq;Aieq] are linearly independent?*"<<
  253. endl;
  254. }
  255. ret = SOLVER_STATUS_ERROR;
  256. break;
  257. }
  258. #ifdef ACTIVE_SET_CPP_DEBUG
  259. cout<<" min_quad_with_fixed_solve"<<endl;
  260. #endif
  261. Eigen::PlainObjectBase<DerivedZ> sol;
  262. if(!min_quad_with_fixed_solve(data,B,Y_i,Beq_i,Z,sol))
  263. {
  264. cerr<<"Error: min_quad_with_fixed solve failed."<<endl;
  265. ret = SOLVER_STATUS_ERROR;
  266. break;
  267. }
  268. //cout<<matlab_format(Z,"Z")<<endl;
  269. #ifdef ACTIVE_SET_CPP_DEBUG
  270. cout<<" post"<<endl;
  271. #endif
  272. // Compute Lagrange multiplier values for known_i
  273. // This needs to be adjusted slightly if A is not symmetric
  274. assert(data.Auu_sym);
  275. SparseMatrix<AT> Ak;
  276. // Slow
  277. slice(A,known_i,1,Ak);
  278. Eigen::PlainObjectBase<DerivedB> Bk;
  279. slice(B,known_i,Bk);
  280. MatrixXd Lambda_known_i = -(Ak*Z + 0.5*Bk);
  281. // reverse the lambda values for lx
  282. Lambda_known_i.block(nk,0,as_lx_count,1) =
  283. (-1*Lambda_known_i.block(nk,0,as_lx_count,1)).eval();
  284. // Extract Lagrange multipliers for Aieq_i (always at back of sol)
  285. VectorXd Lambda_Aieq_i(Aieq_i.rows(),1);
  286. for(int l = 0;l<Aieq_i.rows();l++)
  287. {
  288. Lambda_Aieq_i(Aieq_i.rows()-1-l) = sol(sol.rows()-1-l);
  289. }
  290. // Remove from active set
  291. for(int l = 0;l<as_lx_count;l++)
  292. {
  293. if(Lambda_known_i(nk + l) < params.inactive_threshold)
  294. {
  295. as_lx(known_i(nk + l)) = FALSE;
  296. }
  297. }
  298. for(int u = 0;u<as_ux_count;u++)
  299. {
  300. if(Lambda_known_i(nk + as_lx_count + u) <
  301. params.inactive_threshold)
  302. {
  303. as_ux(known_i(nk + as_lx_count + u)) = FALSE;
  304. }
  305. }
  306. for(int a = 0;a<as_ieq_count;a++)
  307. {
  308. if(Lambda_Aieq_i(a) < params.inactive_threshold)
  309. {
  310. as_ieq(as_ieq_list(a)) = FALSE;
  311. }
  312. }
  313. iter++;
  314. //cout<<iter<<endl;
  315. if(params.max_iter>0 && iter>=params.max_iter)
  316. {
  317. ret = SOLVER_STATUS_MAX_ITER;
  318. break;
  319. }
  320. }
  321. return ret;
  322. }
  323. #ifndef IGL_HEADER_ONLY
  324. // Explicit template specialization
  325. template igl::SolverStatus igl::active_set<double, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::Matrix<int, -1, 1, 0, -1, 1>, Eigen::Matrix<double, -1, 1, 0, -1, 1>, double, Eigen::Matrix<double, -1, 1, 0, -1, 1>, double, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::Matrix<double, -1, 1, 0, -1, 1> >(Eigen::SparseMatrix<double, 0, int> const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> > const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::SparseMatrix<double, 0, int> const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::SparseMatrix<double, 0, int> const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, igl::active_set_params const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> >&);
  326. #endif