sparse_AtA_fast.cpp 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. // This file is part of libigl, a simple c++ geometry processing library.
  2. //
  3. // Copyright (C) 2017 Daniele Panozzo <daniele.panozzo@gmail.com>
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public License
  6. // v. 2.0. If a copy of the MPL was not distributed with this file, You can
  7. // obtain one at http://mozilla.org/MPL/2.0/.
  8. #include "sparse_AtA_fast.h"
  9. #include <iostream>
  10. #include <vector>
  11. #include <unordered_map>
  12. #include <map>
  13. #include <utility>
  14. IGL_INLINE void igl::sparse_AtA_fast_precompute(
  15. const Eigen::SparseMatrix<double>& A,
  16. Eigen::SparseMatrix<double>& AtA,
  17. igl::sparse_AtA_fast_data& data)
  18. {
  19. // 1 Compute At (this could be avoided, but performance-wise it will not make a difference)
  20. std::vector<std::vector<int> > Col_RowPtr;
  21. std::vector<std::vector<int> > Col_IndexPtr;
  22. Col_RowPtr.resize(A.cols());
  23. Col_IndexPtr.resize(A.cols());
  24. for (unsigned k=0; k<A.outerSize(); ++k)
  25. {
  26. unsigned outer_index = *(A.outerIndexPtr()+k);
  27. unsigned next_outer_index = (k+1 == A.outerSize()) ? A.nonZeros() : *(A.outerIndexPtr()+k+1);
  28. for (unsigned l=outer_index; l<next_outer_index; ++l)
  29. {
  30. int col = k;
  31. int row = *(A.innerIndexPtr()+l);
  32. int value_index = l;
  33. assert(col < A.cols());
  34. assert(col >= 0);
  35. assert(row < A.rows());
  36. assert(row >= 0);
  37. assert(value_index >= 0);
  38. assert(value_index < A.nonZeros());
  39. Col_RowPtr[col].push_back(row);
  40. Col_IndexPtr[col].push_back(value_index);
  41. }
  42. }
  43. Eigen::SparseMatrix<double> At = A.transpose();
  44. At.makeCompressed();
  45. AtA = At * A;
  46. AtA.makeCompressed();
  47. assert(AtA.isCompressed());
  48. // If weights are not provided, use 1
  49. if (data.W.size() == 0)
  50. data.W = Eigen::VectorXd::Ones(A.rows());
  51. assert(data.W.size() == A.rows());
  52. data.I_outer.reserve(AtA.outerSize());
  53. data.I_row.reserve(2*AtA.nonZeros());
  54. data.I_col.reserve(2*AtA.nonZeros());
  55. data.I_w.reserve(2*AtA.nonZeros());
  56. // 2 Construct the rules
  57. for (unsigned k=0; k<AtA.outerSize(); ++k)
  58. {
  59. unsigned outer_index = *(AtA.outerIndexPtr()+k);
  60. unsigned next_outer_index = (k+1 == AtA.outerSize()) ? AtA.nonZeros() : *(AtA.outerIndexPtr()+k+1);
  61. for (unsigned l=outer_index; l<next_outer_index; ++l)
  62. {
  63. int col = k;
  64. int row = *(AtA.innerIndexPtr()+l);
  65. int value_index = l;
  66. assert(col < AtA.cols());
  67. assert(col >= 0);
  68. assert(row < AtA.rows());
  69. assert(row >= 0);
  70. assert(value_index >= 0);
  71. assert(value_index < AtA.nonZeros());
  72. data.I_outer.push_back(data.I_row.size());
  73. // Find correspondences
  74. unsigned i=0;
  75. unsigned j=0;
  76. while (i<Col_RowPtr[row].size() && j<Col_RowPtr[col].size())
  77. {
  78. if (Col_RowPtr[row][i] == Col_RowPtr[col][j])
  79. {
  80. data.I_row.push_back(Col_IndexPtr[row][i]);
  81. data.I_col.push_back(Col_IndexPtr[col][j]);
  82. data.I_w.push_back(data.W[Col_RowPtr[col][j]]);
  83. ++i;
  84. ++j;
  85. } else
  86. if (Col_RowPtr[row][i] > Col_RowPtr[col][j])
  87. ++j;
  88. else
  89. ++i;
  90. }
  91. }
  92. }
  93. data.I_outer.push_back(data.I_row.size()); // makes it more efficient to iterate later on
  94. igl::sparse_AtA_fast(A,AtA,data);
  95. }
  96. IGL_INLINE void igl::sparse_AtA_fast(
  97. const Eigen::SparseMatrix<double>& A,
  98. Eigen::SparseMatrix<double>& AtA,
  99. const igl::sparse_AtA_fast_data& data)
  100. {
  101. for (unsigned i=0; i<data.I_outer.size()-1; ++i)
  102. {
  103. *(AtA.valuePtr() + i) = 0;
  104. for (unsigned j=data.I_outer[i]; j<data.I_outer[i+1]; ++j)
  105. *(AtA.valuePtr() + i) += *(A.valuePtr() + data.I_row[j]) * data.I_w[j] * *(A.valuePtr() + data.I_col[j]);
  106. }
  107. }
  108. #ifdef IGL_STATIC_LIBRARY
  109. #endif