active_set.cpp 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. #include "active_set.h"
  2. #include "min_quad_with_fixed.h"
  3. #include "slice.h"
  4. #include "cat.h"
  5. #include "matlab_format.h"
  6. #include <iostream>
  7. #include <limits>
  8. #include <algorithm>
  9. template <
  10. typename AT,
  11. typename DerivedB,
  12. typename Derivedknown,
  13. typename DerivedY,
  14. typename AeqT,
  15. typename DerivedBeq,
  16. typename AieqT,
  17. typename DerivedBieq,
  18. typename Derivedlx,
  19. typename Derivedux,
  20. typename DerivedZ
  21. >
  22. IGL_INLINE igl::SolverStatus igl::active_set(
  23. const Eigen::SparseMatrix<AT>& A,
  24. const Eigen::PlainObjectBase<DerivedB> & B,
  25. const Eigen::PlainObjectBase<Derivedknown> & known,
  26. const Eigen::PlainObjectBase<DerivedY> & Y,
  27. const Eigen::SparseMatrix<AeqT>& Aeq,
  28. const Eigen::PlainObjectBase<DerivedBeq> & Beq,
  29. const Eigen::SparseMatrix<AieqT>& Aieq,
  30. const Eigen::PlainObjectBase<DerivedBieq> & Bieq,
  31. const Eigen::PlainObjectBase<Derivedlx> & lx,
  32. const Eigen::PlainObjectBase<Derivedux> & ux,
  33. const igl::active_set_params & params,
  34. Eigen::PlainObjectBase<DerivedZ> & Z
  35. )
  36. {
  37. using namespace igl;
  38. using namespace Eigen;
  39. using namespace std;
  40. SolverStatus ret = SOLVER_STATUS_ERROR;
  41. const int n = A.rows();
  42. assert(n == A.cols());
  43. // Discard const qualifiers
  44. //if(B.size() == 0)
  45. //{
  46. // B = Eigen::PlainObjectBase<DerivedB>::Zero(n,1);
  47. //}
  48. assert(n == B.rows());
  49. assert(B.cols() == 1);
  50. assert(Y.cols() == 1);
  51. assert((Aeq.size() == 0 && Beq.size() == 0) || Aeq.cols() == n);
  52. assert((Aeq.size() == 0 && Beq.size() == 0) || Aeq.rows() == Beq.rows());
  53. assert((Aeq.size() == 0 && Beq.size() == 0) || Beq.cols() == 1);
  54. assert((Aieq.size() == 0 && Bieq.size() == 0) || Aieq.cols() == n);
  55. assert((Aieq.size() == 0 && Bieq.size() == 0) || Aieq.rows() == Bieq.rows());
  56. assert((Aieq.size() == 0 && Bieq.size() == 0) || Bieq.cols() == 1);
  57. // Discard const qualifiers
  58. //if(lx.size() == 0)
  59. //{
  60. // lx = Eigen::PlainObjectBase<Derivedlx>::Constant(
  61. // n,1,numeric_limits<typename Derivedlx::Scalar>::min());
  62. //}
  63. //if(ux.size() == 0)
  64. //{
  65. // ux = Eigen::PlainObjectBase<Derivedux>::Constant(
  66. // n,1,numeric_limits<typename Derivedux::Scalar>::max());
  67. //}
  68. assert(lx.rows() == n);
  69. assert(ux.rows() == n);
  70. assert(ux.cols() == 1);
  71. assert(lx.cols() == 1);
  72. assert((ux.array()-lx.array()).minCoeff() > 0);
  73. if(Z.size() != 0)
  74. {
  75. // Initial guess should have correct size
  76. assert(Z.rows() == n);
  77. assert(Z.cols() == 1);
  78. }
  79. assert(known.cols() == 1);
  80. // Number of knowns
  81. const int nk = known.size();
  82. // Initialize active sets
  83. typedef bool BOOL;
  84. #define TRUE true
  85. #define FALSE false
  86. Matrix<BOOL,Dynamic,1> as_lx = Matrix<BOOL,Dynamic,1>::Constant(n,1,FALSE);
  87. Matrix<BOOL,Dynamic,1> as_ux = Matrix<BOOL,Dynamic,1>::Constant(n,1,FALSE);
  88. Matrix<BOOL,Dynamic,1> as_ieq(Aieq.rows(),1);
  89. // Keep track of previous Z for comparison
  90. PlainObjectBase<DerivedZ> old_Z;
  91. old_Z = PlainObjectBase<DerivedZ>::Constant(
  92. n,1,numeric_limits<typename DerivedZ::Scalar>::max());
  93. int iter = 0;
  94. while(true)
  95. {
  96. // FIND BREACHES OF CONSTRAINTS
  97. int new_as_lx = 0;
  98. int new_as_ux = 0;
  99. int new_as_ieq = 0;
  100. if(Z.size() > 0)
  101. {
  102. for(int z = 0;z < n;z++)
  103. {
  104. if(Z(z) < lx(z))
  105. {
  106. new_as_lx += (as_lx(z)?0:1);
  107. //new_as_lx++;
  108. as_lx(z) = TRUE;
  109. }
  110. if(Z(z) > ux(z))
  111. {
  112. new_as_ux += (as_ux(z)?0:1);
  113. //new_as_ux++;
  114. as_ux(z) = TRUE;
  115. }
  116. }
  117. PlainObjectBase<DerivedZ> AieqZ;
  118. AieqZ = Aieq*Z;
  119. for(int a = 0;a<Aieq.rows();a++)
  120. {
  121. if(AieqZ(a) > Bieq(a))
  122. {
  123. new_as_ieq += (as_ieq(a)?0:1);
  124. as_ieq(a) = TRUE;
  125. }
  126. }
  127. //cout<<"new_as_lx: "<<new_as_lx<<endl;
  128. //cout<<"new_as_ux: "<<new_as_ux<<endl;
  129. const double diff = (Z-old_Z).squaredNorm();
  130. //cout<<"diff: "<<diff<<endl;
  131. if(diff < params.solution_diff_threshold)
  132. {
  133. ret = SOLVER_STATUS_CONVERGED;
  134. break;
  135. }
  136. old_Z = Z;
  137. }
  138. const int as_lx_count = count(as_lx.data(),as_lx.data()+n,TRUE);
  139. const int as_ux_count = count(as_ux.data(),as_ux.data()+n,TRUE);
  140. const int as_ieq_count = count(as_ieq.data(),as_ieq.data()+n,TRUE);
  141. // PREPARE FIXED VALUES
  142. PlainObjectBase<Derivedknown> known_i;
  143. known_i.resize(nk + as_lx_count + as_ux_count,1);
  144. PlainObjectBase<DerivedY> Y_i;
  145. Y_i.resize(nk + as_lx_count + as_ux_count,1);
  146. {
  147. known_i.block(0,0,known.rows(),known.cols()) = known;
  148. Y_i.block(0,0,Y.rows(),Y.cols()) = Y;
  149. int k = nk;
  150. // Then all lx
  151. for(int z = 0;z < n;z++)
  152. {
  153. if(as_lx(z))
  154. {
  155. known_i(k) = z;
  156. Y_i(k) = lx(z);
  157. k++;
  158. }
  159. }
  160. // Finally all ux
  161. for(int z = 0;z < n;z++)
  162. {
  163. if(as_ux(z))
  164. {
  165. known_i(k) = z;
  166. Y_i(k) = ux(z);
  167. k++;
  168. }
  169. }
  170. assert(k==Y_i.size());
  171. assert(k==known_i.size());
  172. }
  173. //cout<<matlab_format((known_i.array()+1).eval(),"known_i")<<endl;
  174. // PREPARE EQUALITY CONSTRAINTS
  175. VectorXi as_ieq_list(as_ieq_count,1);
  176. // Gather active constraints and resp. rhss
  177. PlainObjectBase<DerivedBeq> Beq_i;
  178. Beq_i.resize(Beq.rows()+as_ieq_count,1);
  179. {
  180. int k =0;
  181. for(int a=0;a<as_ieq.size();a++)
  182. {
  183. if(a)
  184. {
  185. as_ieq_list(k)=a;
  186. Beq_i(Beq.rows()+k,1) = Bieq(k,1);
  187. k++;
  188. }
  189. }
  190. assert(k == as_ieq_count);
  191. }
  192. // extract active constraint rows
  193. SparseMatrix<AeqT> Aeq_i,Aieq_i;
  194. slice(Aieq,as_ieq_list,1,Aieq_i);
  195. // Append to equality constraints
  196. cat(1,Aeq,Aieq_i,Aeq_i);
  197. min_quad_with_fixed_data<AT> data;
  198. if(!min_quad_with_fixed_precompute(A,known_i,Aeq_i,params.Auu_pd,data))
  199. {
  200. cerr<<"Error: min_quad_with_fixed precomputation failed."<<endl;
  201. ret = SOLVER_STATUS_ERROR;
  202. break;
  203. }
  204. Eigen::PlainObjectBase<DerivedZ> sol;
  205. if(!min_quad_with_fixed_solve(data,B,Y_i,Beq_i,Z,sol))
  206. {
  207. cerr<<"Error: min_quad_with_fixed solve failed."<<endl;
  208. ret = SOLVER_STATUS_ERROR;
  209. break;
  210. }
  211. // Compute Lagrange multiplier values for known_i
  212. // This needs to be adjusted slightly if A is not symmetric
  213. assert(data.Auu_sym);
  214. SparseMatrix<AT> Ak;
  215. // Slow
  216. slice(A,known_i,1,Ak);
  217. Eigen::PlainObjectBase<DerivedB> Bk;
  218. slice(B,known_i,Bk);
  219. MatrixXd Lambda_known_i = -(Ak*Z + 0.5*Bk);
  220. // reverse the lambda values for lx
  221. Lambda_known_i.block(nk,0,as_lx_count,1) =
  222. (-1*Lambda_known_i.block(nk,0,as_lx_count,1)).eval();
  223. // Extract Lagrange multipliers for Aieq_i (always at back of sol)
  224. VectorXd Lambda_Aieq_i(Aieq_i.rows(),1);
  225. for(int l = 0;l<Aieq_i.rows();l++)
  226. {
  227. Lambda_Aieq_i(Aieq_i.rows()-1-l) = sol(sol.rows()-1-l);
  228. }
  229. // Remove from active set
  230. for(int l = 0;l<as_lx_count;l++)
  231. {
  232. if(Lambda_known_i(nk + l) < params.inactive_threshold)
  233. {
  234. as_lx(known_i(nk + l)) = FALSE;
  235. }
  236. }
  237. for(int u = 0;u<as_ux_count;u++)
  238. {
  239. if(Lambda_known_i(nk + as_lx_count + u) <
  240. params.inactive_threshold)
  241. {
  242. as_ux(known_i(nk + as_lx_count + u)) = FALSE;
  243. }
  244. }
  245. for(int a = 0;a<as_ieq_count;a++)
  246. {
  247. if(Lambda_Aieq_i(a) < params.inactive_threshold)
  248. {
  249. as_ieq(as_ieq_list(a)) = FALSE;
  250. }
  251. }
  252. iter++;
  253. //cout<<iter<<endl;
  254. if(params.max_iter>0 && iter>=params.max_iter)
  255. {
  256. ret = SOLVER_STATUS_MAX_ITER;
  257. break;
  258. }
  259. }
  260. return ret;
  261. }
  262. #ifndef IGL_HEADER_ONLY
  263. // Explicit template specialization
  264. template igl::SolverStatus igl::active_set<double, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::Matrix<int, -1, 1, 0, -1, 1>, Eigen::Matrix<double, -1, 1, 0, -1, 1>, double, Eigen::Matrix<double, -1, 1, 0, -1, 1>, double, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::Matrix<double, -1, 1, 0, -1, 1> >(Eigen::SparseMatrix<double, 0, int> const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> > const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::SparseMatrix<double, 0, int> const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::SparseMatrix<double, 0, int> const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> > const&, igl::active_set_params const&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> >&);
  265. #endif