浏览代码

special case for 3 cols

Former-commit-id: ae7db8663e8b8cb8cfa872afa61a739cd77f8273
Alec Jacobson 9 年之前
父节点
当前提交
8ff85292c2
共有 1 个文件被更改,包括 115 次插入4 次删除
  1. 115 4
      include/igl/sort.cpp

+ 115 - 4
include/igl/sort.cpp

@@ -15,6 +15,8 @@
 #include <cassert>
 #include <algorithm>
 #include <iostream>
+#include <thread>
+#include <functional>
 
 template <typename DerivedX, typename DerivedY, typename DerivedIX>
 IGL_INLINE void igl::sort(
@@ -27,9 +29,14 @@ IGL_INLINE void igl::sort(
   // get number of rows (or columns)
   int num_inner = (dim == 1 ? X.rows() : X.cols() );
   // Special case for swapping
-  if(num_inner == 2)
+  switch(num_inner)
   {
-    return igl::sort2(X,dim,ascending,Y,IX);
+    default:
+      break;
+    case 2:
+      return igl::sort2(X,dim,ascending,Y,IX);
+    case 3:
+      return igl::sort3(X,dim,ascending,Y,IX);
   }
   using namespace Eigen;
   // get number of columns (or rows)
@@ -85,9 +92,14 @@ IGL_INLINE void igl::sort_new(
   // get number of rows (or columns)
   int num_inner = (dim == 1 ? X.rows() : X.cols() );
   // Special case for swapping
-  if(num_inner == 2)
+  switch(num_inner)
   {
-    return igl::sort2(X,dim,ascending,Y,IX);
+    default:
+      break;
+    case 2:
+      return igl::sort2(X,dim,ascending,Y,IX);
+    case 3:
+      return igl::sort3(X,dim,ascending,Y,IX);
   }
   using namespace Eigen;
   // get number of columns (or rows)
@@ -180,6 +192,105 @@ IGL_INLINE void igl::sort2(
   }
 }
 
+// Special-case sort for exactly three entries along the sorted dimension.
+//
+// Copies X into Y (cast to Y's scalar type) and then sorts the three entries
+// of every column (dim==1) or of every row (any other dim) in place, using a
+// fixed sequence of at most three compare-and-swap steps. IX is resized to the
+// size of X and receives, for every sorted entry, the position (0, 1 or 2) it
+// had before sorting.
+//
+// Inputs:
+//   X  input matrix; must have 3 rows if dim==1, otherwise 3 columns
+//     (checked with assert only)
+//   dim  dimension along which to sort: 1 sorts each column, anything else
+//     sorts each row
+//   ascending  true for increasing order, false for decreasing order
+// Outputs:
+//   Y  sorted copy of X
+//   IX  original positions of the sorted entries, same size as X
+//
+// Columns/rows are independent, so for large inputs (>= 8000 of them) the
+// work is split into contiguous ranges handled by separate std::threads.
+template <typename DerivedX, typename DerivedY, typename DerivedIX>
+IGL_INLINE void igl::sort3(
+  const Eigen::PlainObjectBase<DerivedX>& X,
+  const int dim,
+  const bool ascending,
+  Eigen::PlainObjectBase<DerivedY>& Y,
+  Eigen::PlainObjectBase<DerivedIX>& IX)
+{
+  typedef typename Eigen::PlainObjectBase<DerivedY>::Scalar YScalar;
+  typedef typename Eigen::PlainObjectBase<DerivedIX>::Scalar IXScalar;
+  Y = X.template cast<YScalar>();
+  // get number of columns (or rows)
+  const int num_outer = (dim == 1 ? X.cols() : X.rows() );
+  // get number of rows (or columns)
+  const int num_inner = (dim == 1 ? X.rows() : X.cols() );
+  assert(num_inner == 3);(void)num_inner;
+
+  // Start from the identity permutation; swaps below keep IX in sync with Y.
+  IX.resize(X.rows(),X.cols());
+  if(dim==1)
+  {
+    IX.row(0).setConstant(0);
+    IX.row(1).setConstant(1);
+    IX.row(2).setConstant(2);
+  }else
+  {
+    IX.col(0).setConstant(0);
+    IX.col(1).setConstant(1);
+    IX.col(2).setConstant(2);
+  }
+
+  // Swap the pair (p,q) together with its indices if it violates the
+  // requested order. Returns true iff a swap happened.
+  const auto fix_pair = [ascending](
+    YScalar & p, YScalar & q, IXScalar & ip, IXScalar & iq)->bool
+  {
+    const bool out_of_order = ascending ? (p > q) : (p < q);
+    if(out_of_order)
+    {
+      std::swap(p,q);
+      std::swap(ip,iq);
+    }
+    return out_of_order;
+  };
+
+  // Sort the columns (or rows) with outer index in [begin,end).
+  const auto sort_range = [&Y,&IX,&fix_pair,dim](const int begin, const int end)
+  {
+    for(int i = begin;i<end;i++)
+    {
+      YScalar & a = (dim==1 ? Y(0,i) : Y(i,0));
+      YScalar & b = (dim==1 ? Y(1,i) : Y(i,1));
+      YScalar & c = (dim==1 ? Y(2,i) : Y(i,2));
+      IXScalar & ia = (dim==1 ? IX(0,i) : IX(i,0));
+      IXScalar & ib = (dim==1 ? IX(1,i) : IX(i,1));
+      IXScalar & ic = (dim==1 ? IX(2,i) : IX(i,2));
+      // Order (a,b), then (b,c). Only if the second step moved something can
+      // (a,b) be out of order again, so the third step is conditional.
+      fix_pair(a,b,ia,ib);
+      if(fix_pair(b,c,ib,ic))
+      {
+        fix_pair(a,b,ia,ib);
+      }
+    }
+  };
+
+  const int n = num_outer;
+  // Only go parallel for large inputs. hardware_concurrency() is allowed to
+  // return 0 when the value is unknown; treat that as "one thread" so the
+  // data still gets sorted.
+  int nthreads = 1;
+  if(n >= 8000)
+  {
+    const unsigned int hw = std::thread::hardware_concurrency();
+    if(hw > 1)
+    {
+      nthreads = static_cast<int>(hw);
+    }
+  }
+  if(nthreads == 1)
+  {
+    // No point paying for a thread launch; run on the calling thread.
+    sort_range(0,n);
+    return;
+  }
+
+  // Split [0,n) into nthreads contiguous ranges whose sizes differ by at most
+  // one. Using quotient/remainder avoids the t*n product, which can overflow
+  // int for very large n.
+  const int chunk = n / nthreads;
+  const int rem = n % nthreads;
+  std::vector<std::thread> threads;
+  threads.reserve(nthreads);
+  for(int t = 0;t<nthreads;t++)
+  {
+    const int begin = t*chunk + std::min(t,rem);
+    const int end = begin + chunk + (t<rem ? 1 : 0);
+    threads.emplace_back(sort_range,begin,end);
+  }
+  // All workers must finish before the captured references go out of scope.
+  for(std::thread & worker : threads)
+  {
+    worker.join();
+  }
+}
+
 template <class T>
 IGL_INLINE void igl::sort(
   const std::vector<T> & unsorted,