redux.h 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. #ifndef IGL_REDUX_H
  2. #define IGL_REDUX_H
  3. #include <Eigen/Core>
  4. #include <Eigen/Sparse>
  5. namespace igl
  6. {
  7. // REDUX Perform reductions on the rows or columns of a SparseMatrix. This is
  8. // _similar_ to DenseBase::redux, but different in two important ways:
  9. // 1. (unstored) Zeros are **not** "visited", however if the first element
  10. // in the column/row does not appear in the first row/column then the
  11. // reduction is assumed to start with zero. In this way, "any", "all",
  12. // "count"(non-zeros) work as expected. This means it is **not** possible
  13. // to use this to count (implicit) zeros.
  14. // 2. This redux is more powerful in the sense that A and B may have
  15. // different types. This makes it possible to count the number of
  16. // non-zeros in a SparseMatrix<bool> A into a VectorXi B.
  17. //
  18. // Inputs:
  19. // A m by n sparse matrix
  20. // dim dimension along which to sum (1 or 2)
  21. // func function handle with the prototype `X(Y a, I i, J j, Z b)` where a
  22. // is the running value, b is A(i,j)
  23. // Output:
  24. // S n-long sparse vector (if dim == 1)
  25. // or
  26. // S m-long sparse vector (if dim == 2)
  27. template <typename AType, typename Func, typename DerivedB>
  28. inline void redux(
  29. const Eigen::SparseMatrix<AType> & A,
  30. const int dim,
  31. const Func & func,
  32. Eigen::PlainObjectBase<DerivedB> & B);
  33. }
  34. // Implementation
  35. #include "for_each.h"
  36. template <typename AType, typename Func, typename DerivedB>
  37. inline void igl::redux(
  38. const Eigen::SparseMatrix<AType> & A,
  39. const int dim,
  40. const Func & func,
  41. Eigen::PlainObjectBase<DerivedB> & B)
  42. {
  43. typedef typename Eigen::SparseMatrix<AType>::StorageIndex Index;
  44. assert((dim == 1 || dim == 2) && "dim must be 2 or 1");
  45. // Get size of input
  46. int m = A.rows();
  47. int n = A.cols();
  48. // resize output
  49. B = DerivedB::Zero(dim==1?n:m);
  50. const auto func_wrap = [&func,&B,&dim](const Index i, const Index j, const AType v)
  51. {
  52. if(dim == 1)
  53. {
  54. B(j) = i == 0? v : func(B(j),v);
  55. }else
  56. {
  57. B(i) = j == 0? v : func(B(i),v);
  58. }
  59. };
  60. for_each(A,func_wrap);
  61. }
  62. //#ifndef IGL_STATIC_LIBRARY
  63. //# include "redux.cpp"
  64. //#endif
  65. #endif