active_set.cpp 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. #include "active_set.h"
  2. #include "min_quad_with_fixed.h"
  3. #include "slice.h"
  4. #include "cat.h"
  5. #include "matlab_format.h"
  6. #include <iostream>
  7. #include <limits>
  8. #include <algorithm>
  9. template <
  10. typename AT,
  11. typename DerivedB,
  12. typename Derivedknown,
  13. typename DerivedY,
  14. typename AeqT,
  15. typename DerivedBeq,
  16. typename AieqT,
  17. typename DerivedBieq,
  18. typename Derivedlx,
  19. typename Derivedux,
  20. typename DerivedZ
  21. >
  22. IGL_INLINE igl::SolverStatus igl::active_set(
  23. const Eigen::SparseMatrix<AT>& A,
  24. const Eigen::PlainObjectBase<DerivedB> & B,
  25. const Eigen::PlainObjectBase<Derivedknown> & known,
  26. const Eigen::PlainObjectBase<DerivedY> & Y,
  27. const Eigen::SparseMatrix<AeqT>& Aeq,
  28. const Eigen::PlainObjectBase<DerivedBeq> & Beq,
  29. const Eigen::SparseMatrix<AieqT>& Aieq,
  30. const Eigen::PlainObjectBase<DerivedBieq> & Bieq,
  31. const Eigen::PlainObjectBase<Derivedlx> & p_lx,
  32. const Eigen::PlainObjectBase<Derivedux> & p_ux,
  33. const igl::active_set_params & params,
  34. Eigen::PlainObjectBase<DerivedZ> & Z
  35. )
  36. {
  37. using namespace igl;
  38. using namespace Eigen;
  39. using namespace std;
  40. SolverStatus ret = SOLVER_STATUS_ERROR;
  41. const int n = A.rows();
  42. assert(n == A.cols() && "A must be square");
  43. // Discard const qualifiers
  44. //if(B.size() == 0)
  45. //{
  46. // B = Eigen::PlainObjectBase<DerivedB>::Zero(n,1);
  47. //}
  48. assert(n == B.rows() && "B.rows() must match A.rows()");
  49. assert(B.cols() == 1 && "B must be a column vector");
  50. assert(Y.cols() == 1 && "Y must be a column vector");
  51. assert((Aeq.size() == 0 && Beq.size() == 0) || Aeq.cols() == n);
  52. assert((Aeq.size() == 0 && Beq.size() == 0) || Aeq.rows() == Beq.rows());
  53. assert((Aeq.size() == 0 && Beq.size() == 0) || Beq.cols() == 1);
  54. assert((Aieq.size() == 0 && Bieq.size() == 0) || Aieq.cols() == n);
  55. assert((Aieq.size() == 0 && Bieq.size() == 0) || Aieq.rows() == Bieq.rows());
  56. assert((Aieq.size() == 0 && Bieq.size() == 0) || Bieq.cols() == 1);
  57. Eigen::PlainObjectBase<Derivedlx> lx;
  58. Eigen::PlainObjectBase<Derivedux> ux;
  59. if(p_lx.size() == 0)
  60. {
  61. lx = Eigen::PlainObjectBase<Derivedlx>::Constant(
  62. n,1,-numeric_limits<typename Derivedlx::Scalar>::max());
  63. }else
  64. {
  65. lx = p_lx;
  66. }
  67. if(ux.size() == 0)
  68. {
  69. ux = Eigen::PlainObjectBase<Derivedux>::Constant(
  70. n,1,numeric_limits<typename Derivedux::Scalar>::max());
  71. }else
  72. {
  73. ux = p_ux;
  74. }
  75. assert(lx.rows() == n && "lx must have n rows");
  76. assert(ux.rows() == n && "ux must have n rows");
  77. assert(ux.cols() == 1 && "lx must be a column vector");
  78. assert(lx.cols() == 1 && "ux must be a column vector");
  79. assert((ux.array()-lx.array()).minCoeff() > 0 && "ux(i) must be > lx(i)");
  80. if(Z.size() != 0)
  81. {
  82. // Initial guess should have correct size
  83. assert(Z.rows() == n && "Z must have n rows");
  84. assert(Z.cols() == 1 && "Z must be a column vector");
  85. }
  86. assert(known.cols() == 1 && "known must be a column vector");
  87. // Number of knowns
  88. const int nk = known.size();
  89. // Initialize active sets
  90. typedef int BOOL;
  91. #define TRUE 1
  92. #define FALSE 0
  93. Matrix<BOOL,Dynamic,1> as_lx = Matrix<BOOL,Dynamic,1>::Constant(n,1,FALSE);
  94. Matrix<BOOL,Dynamic,1> as_ux = Matrix<BOOL,Dynamic,1>::Constant(n,1,FALSE);
  95. Matrix<BOOL,Dynamic,1> as_ieq = Matrix<BOOL,Dynamic,1>::Constant(Aieq.rows(),1,FALSE);
  96. // Keep track of previous Z for comparison
  97. PlainObjectBase<DerivedZ> old_Z;
  98. old_Z = PlainObjectBase<DerivedZ>::Constant(
  99. n,1,numeric_limits<typename DerivedZ::Scalar>::max());
  100. int iter = 0;
  101. while(true)
  102. {
  103. //cout<<iter<<":"<<endl;
  104. //cout<<" pre"<<endl;
  105. // FIND BREACHES OF CONSTRAINTS
  106. int new_as_lx = 0;
  107. int new_as_ux = 0;
  108. int new_as_ieq = 0;
  109. if(Z.size() > 0)
  110. {
  111. for(int z = 0;z < n;z++)
  112. {
  113. if(Z(z) < lx(z))
  114. {
  115. new_as_lx += (as_lx(z)?0:1);
  116. //new_as_lx++;
  117. as_lx(z) = TRUE;
  118. }
  119. if(Z(z) > ux(z))
  120. {
  121. new_as_ux += (as_ux(z)?0:1);
  122. //new_as_ux++;
  123. as_ux(z) = TRUE;
  124. }
  125. }
  126. PlainObjectBase<DerivedZ> AieqZ;
  127. AieqZ = Aieq*Z;
  128. for(int a = 0;a<Aieq.rows();a++)
  129. {
  130. if(AieqZ(a) > Bieq(a))
  131. {
  132. new_as_ieq += (as_ieq(a)?0:1);
  133. as_ieq(a) = TRUE;
  134. }
  135. }
  136. //cout<<"new_as_lx: "<<new_as_lx<<endl;
  137. //cout<<"new_as_ux: "<<new_as_ux<<endl;
  138. const double diff = (Z-old_Z).squaredNorm();
  139. //cout<<"diff: "<<diff<<endl;
  140. if(diff < params.solution_diff_threshold)
  141. {
  142. ret = SOLVER_STATUS_CONVERGED;
  143. break;
  144. }
  145. old_Z = Z;
  146. }
  147. const int as_lx_count = count(as_lx.data(),as_lx.data()+n,TRUE);
  148. const int as_ux_count = count(as_ux.data(),as_ux.data()+n,TRUE);
  149. const int as_ieq_count =
  150. count(as_ieq.data(),as_ieq.data()+as_ieq.size(),TRUE);
  151. #ifndef NDEBUG
  152. {
  153. int count = 0;
  154. for(int a = 0;a<as_ieq.size();a++)
  155. {
  156. if(as_ieq(a))
  157. {
  158. assert(as_ieq(a) == TRUE);
  159. count++;
  160. }
  161. }
  162. assert(as_ieq_count == count);
  163. }
  164. #endif
  165. // PREPARE FIXED VALUES
  166. PlainObjectBase<Derivedknown> known_i;
  167. known_i.resize(nk + as_lx_count + as_ux_count,1);
  168. PlainObjectBase<DerivedY> Y_i;
  169. Y_i.resize(nk + as_lx_count + as_ux_count,1);
  170. {
  171. known_i.block(0,0,known.rows(),known.cols()) = known;
  172. Y_i.block(0,0,Y.rows(),Y.cols()) = Y;
  173. int k = nk;
  174. // Then all lx
  175. for(int z = 0;z < n;z++)
  176. {
  177. if(as_lx(z))
  178. {
  179. known_i(k) = z;
  180. Y_i(k) = lx(z);
  181. k++;
  182. }
  183. }
  184. // Finally all ux
  185. for(int z = 0;z < n;z++)
  186. {
  187. if(as_ux(z))
  188. {
  189. known_i(k) = z;
  190. Y_i(k) = ux(z);
  191. k++;
  192. }
  193. }
  194. assert(k==Y_i.size());
  195. assert(k==known_i.size());
  196. }
  197. //cout<<matlab_format((known_i.array()+1).eval(),"known_i")<<endl;
  198. // PREPARE EQUALITY CONSTRAINTS
  199. VectorXi as_ieq_list(as_ieq_count,1);
  200. // Gather active constraints and resp. rhss
  201. PlainObjectBase<DerivedBeq> Beq_i;
  202. Beq_i.resize(Beq.rows()+as_ieq_count,1);
  203. {
  204. int k =0;
  205. for(int a=0;a<as_ieq.size();a++)
  206. {
  207. if(as_ieq(a))
  208. {
  209. assert(k<as_ieq_list.size());
  210. as_ieq_list(k)=a;
  211. Beq_i(Beq.rows()+k,0) = Bieq(k,0);
  212. k++;
  213. }
  214. }
  215. assert(k == as_ieq_count);
  216. }
  217. // extract active constraint rows
  218. SparseMatrix<AeqT> Aeq_i,Aieq_i;
  219. slice(Aieq,as_ieq_list,1,Aieq_i);
  220. // Append to equality constraints
  221. cat(1,Aeq,Aieq_i,Aeq_i);
  222. min_quad_with_fixed_data<AT> data;
  223. #ifndef NDEBUG
  224. {
  225. // NO DUPES!
  226. Matrix<BOOL,Dynamic,1> fixed = Matrix<BOOL,Dynamic,1>::Constant(n,1,FALSE);
  227. for(int k = 0;k<known_i.size();k++)
  228. {
  229. assert(!fixed[known_i(k)]);
  230. fixed[known_i(k)] = TRUE;
  231. }
  232. }
  233. #endif
  234. //cout<<" min_quad_with_fixed_precompute"<<endl;
  235. if(!min_quad_with_fixed_precompute(A,known_i,Aeq_i,params.Auu_pd,data))
  236. {
  237. cerr<<"Error: min_quad_with_fixed precomputation failed."<<endl;
  238. if(iter > 0 && Aeq_i.rows() > Aeq.rows())
  239. {
  240. cerr<<" *Are you sure rows of [Aeq;Aieq] are linearly independent?*"<<
  241. endl;
  242. }
  243. ret = SOLVER_STATUS_ERROR;
  244. break;
  245. }
  246. //cout<<" min_quad_with_fixed_solve"<<endl;
  247. Eigen::PlainObjectBase<DerivedZ> sol;
  248. if(!min_quad_with_fixed_solve(data,B,Y_i,Beq_i,Z,sol))
  249. {
  250. cerr<<"Error: min_quad_with_fixed solve failed."<<endl;
  251. ret = SOLVER_STATUS_ERROR;
  252. break;
  253. }
  254. //cout<<" post"<<endl;
  255. // Compute Lagrange multiplier values for known_i
  256. // This needs to be adjusted slightly if A is not symmetric
  257. assert(data.Auu_sym);
  258. SparseMatrix<AT> Ak;
  259. // Slow
  260. slice(A,known_i,1,Ak);
  261. Eigen::PlainObjectBase<DerivedB> Bk;
  262. slice(B,known_i,Bk);
  263. MatrixXd Lambda_known_i = -(Ak*Z + 0.5*Bk);
  264. // reverse the lambda values for lx
  265. Lambda_known_i.block(nk,0,as_lx_count,1) =
  266. (-1*Lambda_known_i.block(nk,0,as_lx_count,1)).eval();
  267. // Extract Lagrange multipliers for Aieq_i (always at back of sol)
  268. VectorXd Lambda_Aieq_i(Aieq_i.rows(),1);
  269. for(int l = 0;l<Aieq_i.rows();l++)
  270. {
  271. Lambda_Aieq_i(Aieq_i.rows()-1-l) = sol(sol.rows()-1-l);
  272. }
  273. // Remove from active set
  274. for(int l = 0;l<as_lx_count;l++)
  275. {
  276. if(Lambda_known_i(nk + l) < params.inactive_threshold)
  277. {
  278. as_lx(known_i(nk + l)) = FALSE;
  279. }
  280. }
  281. for(int u = 0;u<as_ux_count;u++)
  282. {
  283. if(Lambda_known_i(nk + as_lx_count + u) <
  284. params.inactive_threshold)
  285. {
  286. as_ux(known_i(nk + as_lx_count + u)) = FALSE;
  287. }
  288. }
  289. for(int a = 0;a<as_ieq_count;a++)
  290. {
  291. if(Lambda_Aieq_i(a) < params.inactive_threshold)
  292. {
  293. as_ieq(as_ieq_list(a)) = FALSE;
  294. }
  295. }
  296. iter++;
  297. //cout<<iter<<endl;
  298. if(params.max_iter>0 && iter>=params.max_iter)
  299. {
  300. ret = SOLVER_STATUS_MAX_ITER;
  301. break;
  302. }
  303. }
  304. return ret;
  305. }
  306. #ifndef IGL_HEADER_ONLY
  307. // Explicit template specialization
  308. template igl::SolverStatus igl::active_set<double, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::Matrix<int, -1, 1, 0, -1, 1>, Eigen::Matrix<double, -1, 1, 0, -1, 1>, double, Eigen::Matrix<double, -1, 1, 0, -1, 1>, double, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::Matrix<double, -1, 1, 0, -1, 1> >(Eigen::SparseMatrix<double, 0, int> const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> > const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::SparseMatrix<double, 0, int> const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::SparseMatrix<double, 0, int> const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, igl::active_set_params const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> >&);
  309. #endif