// knn_octree.cpp
  #include "knn_octree.h"
  #include "parallel_for.h"
  #include <algorithm>
  #include <cmath>
  #include <functional>
  #include <queue>
  #include <utility>
  #include <vector>
  4. namespace igl {
  5. template <typename DerivedP, typename KType, typename IndexType,
  6. typename CentersType, typename WidthsType, typename DerivedI>
  7. IGL_INLINE void knn_octree(
  8. const Eigen::MatrixBase<DerivedP>& P,
  9. const KType & k,
  10. const std::vector<std::vector<IndexType> > & point_indices,
  11. const std::vector<Eigen::Matrix<IndexType,8,1>,
  12. Eigen::aligned_allocator<Eigen::Matrix<IndexType,8,1> > > & children,
  13. const std::vector<Eigen::Matrix<CentersType,1,3>,
  14. Eigen::aligned_allocator<Eigen::Matrix<CentersType,1,3> > > & centers,
  15. const std::vector<WidthsType> & widths,
  16. Eigen::PlainObjectBase<DerivedI> & I)
  17. {
  18. typedef Eigen::Matrix<typename DerivedP::Scalar, 1, 3> RowVector3PType;
  19. const int n = P.rows();
  20. const KType real_k = std::min(n,k);
  21. auto distance_to_width_one_cube = [](RowVector3PType point){
  22. return std::sqrt(std::pow(std::max(std::abs(point(0))-1,0.0),2)
  23. + std::pow(std::max(std::abs(point(1))-1,0.0),2)
  24. + std::pow(std::max(std::abs(point(2))-1,0.0),2));
  25. };
  26. auto distance_to_cube = [&distance_to_width_one_cube]
  27. (RowVector3PType point,
  28. Eigen::Matrix<CentersType,1,3> cube_center,
  29. WidthsType cube_width){
  30. RowVector3PType transformed_point = (point-cube_center)/cube_width;
  31. return cube_width*distance_to_width_one_cube(transformed_point);
  32. };
  33. I.resize(n,real_k);
  34. igl::parallel_for(n,[&](int i)
  35. {
  36. int points_found = 0;
  37. RowVector3PType point_of_interest = P.row(i);
  38. //To make my priority queue take both points and octree cells,
  39. //I use the indices 0 to n-1 for the n points,
  40. // and the indices n to n+m-1 for the m octree cells
  41. // Using lambda to compare elements.
  42. auto cmp = [&point_of_interest, &P, &centers, &widths,
  43. &n, &distance_to_cube](int left, int right) {
  44. double leftdistance, rightdistance;
  45. if(left < n){ //left is a point index
  46. leftdistance = (P.row(left) - point_of_interest).norm();
  47. } else { //left is an octree cell
  48. leftdistance = distance_to_cube(point_of_interest,
  49. centers.at(left-n),
  50. widths.at(left-n));
  51. }
  52. if(right < n){ //left is a point index
  53. rightdistance = (P.row(right) - point_of_interest).norm();
  54. } else { //left is an octree cell
  55. rightdistance = distance_to_cube(point_of_interest,
  56. centers.at(right-n),
  57. widths.at(right-n));
  58. }
  59. return leftdistance >= rightdistance;
  60. };
  61. std::priority_queue<IndexType, std::vector<IndexType>,
  62. decltype(cmp)> queue(cmp);
  63. queue.push(n); //This is the 0th octree cell (ie the root)
  64. while(points_found < real_k){
  65. IndexType curr_cell_or_point = queue.top();
  66. queue.pop();
  67. if(curr_cell_or_point < n){ //current index is for is a point
  68. I(i,points_found) = curr_cell_or_point;
  69. points_found++;
  70. } else {
  71. IndexType curr_cell = curr_cell_or_point - n;
  72. if(children.at(curr_cell)(0) == -1){ //In the case of a leaf
  73. if(point_indices.at(curr_cell).size() > 0){
  74. //Assumption: Leaves either have one point, or none
  75. queue.push(point_indices.at(curr_cell).at(0));
  76. }
  77. } else { //Not a leaf
  78. for(int j = 0; j < 8; j++){
  79. //+n to adjust for the octree cells
  80. queue.push(children.at(curr_cell)(j)+n);
  81. }
  82. }
  83. }
  84. }
  85. },1000);
  86. }
  87. }