redux.h 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. #ifndef IGL_REDUX_H
  2. #define IGL_REDUX_H
  3. #include <Eigen/Core>
  4. #include <Eigen/Sparse>
  5. namespace igl
  6. {
  7. // REDUX Perform reductions on the rows or columns of a SparseMatrix. This is
  8. // _similar_ to DenseBase::redux, but different in two important ways:
  9. // 1. (unstored) Zeros are **not** "visited", however if the first element
  10. // in the column/row does not appear in the first row/column then the
  11. // reduction is assumed to start with zero. In this way, "any", "all",
  12. // "count"(non-zeros) work as expected. This means it is **not** possible
  13. // to use this to count (implicit) zeros.
  14. // 2. This redux is more powerful in the sense that A and B may have
  15. // different types. This makes it possible to count the number of
  16. // non-zeros in a SparseMatrix<bool> A into a VectorXi B.
  17. //
  18. // Inputs:
  19. // A m by n sparse matrix
  20. // dim dimension along which to sum (1 or 2)
  21. // func function handle with the prototype `X(Y a, I i, J j, Z b)` where a
  22. // is the running value, b is A(i,j)
  23. // Output:
  24. // S n-long sparse vector (if dim == 1)
  25. // or
  26. // S m-long sparse vector (if dim == 2)
  27. template <typename AType, typename Func, typename DerivedB>
  28. inline void redux(
  29. const Eigen::SparseMatrix<AType> & A,
  30. const int dim,
  31. const Func & func,
  32. Eigen::PlainObjectBase<DerivedB> & B);
  33. }
  34. // Implementation
  35. #include "redux.h"
  36. #include "for_each.h"
  37. template <typename AType, typename Func, typename DerivedB>
  38. inline void igl::redux(
  39. const Eigen::SparseMatrix<AType> & A,
  40. const int dim,
  41. const Func & func,
  42. Eigen::PlainObjectBase<DerivedB> & B)
  43. {
  44. typedef typename Eigen::SparseMatrix<AType>::StorageIndex Index;
  45. assert((dim == 1 || dim == 2) && "dim must be 2 or 1");
  46. // Get size of input
  47. int m = A.rows();
  48. int n = A.cols();
  49. // resize output
  50. B = DerivedB::Zero(dim==1?n:m);
  51. const auto func_wrap = [&func,&B,&dim](const Index i, const Index j, const AType v)
  52. {
  53. if(dim == 1)
  54. {
  55. B(j) = i == 0? v : func(B(j),v);
  56. }else
  57. {
  58. B(i) = j == 0? v : func(B(i),v);
  59. }
  60. };
  61. for_each(A,func_wrap);
  62. }
  63. //#ifndef IGL_STATIC_LIBRARY
  64. //# include "redux.cpp"
  65. //#endif
  66. #endif