Forráskód Böngészése

Make igl/randperm accept random number generator as argument

Yucheol Jung 6 éve
szülő
commit
e686604f4f
3 módosított fájl, 75 hozzáadás és 9 törlés
  1. 35 8
      include/igl/randperm.cpp
  2. 9 1
      include/igl/randperm.h
  3. 31 0
      tests/include/igl/randperm.cpp

+ 35 - 8
include/igl/randperm.cpp

@@ -7,24 +7,51 @@
 // obtain one at http://mozilla.org/MPL/2.0/.
 #include "randperm.h"
 #include "colon.h"
-#include <algorithm> 
-#include <random>
+#include <algorithm>
 
 template <typename DerivedI>
 IGL_INLINE void igl::randperm(
   const int n,
-  Eigen::PlainObjectBase<DerivedI> & I)
+  Eigen::PlainObjectBase<DerivedI> & I,
+  const int64_t rng_min,
+  const int64_t rng_max,
+  const std::function<int64_t()> &rng)
 {
   Eigen::VectorXi II;
   igl::colon(0,1,n-1,II);
   I = II;
-  std::random_device rd;
-  std::mt19937 mt(rd());
-  std::shuffle(I.data(),I.data()+n, mt);
+
+  // C++ named requirement : UniformRandomBitGenerator
+  // This signature is required for the third parameter of
+  // std::shuffle
+  struct RandPermURBG {
+  public:
+    using result_type = int64_t;
+    RandPermURBG(const result_type min,
+                 const result_type max,
+                 const std::function<result_type()> int_gen) :
+      m_min(min),
+      m_max(max),
+      m_int_gen(std::move(int_gen)) {}
+
+    result_type min() const noexcept { return m_min; }
+    result_type max() const noexcept { return m_max; }
+    result_type operator()() const { return m_int_gen(); };
+  private:
+    result_type m_min;
+    result_type m_max;
+    std::function<result_type()> m_int_gen;
+  };
+
+  const auto int_gen = (nullptr != rng) ?
+    rng :
+    []()->int64_t { return std::rand(); };
+
+  std::shuffle(I.data(),I.data()+n, RandPermURBG(rng_min, rng_max, int_gen));
 }
 
 #ifdef IGL_STATIC_LIBRARY
 // Explicit template instantiation
-template void igl::randperm<Eigen::Matrix<int, -1, 1, 0, -1, 1> >(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> >&);
-template void igl::randperm<Eigen::Matrix<int, -1, -1, 0, -1, -1> >(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, -1, 0, -1, -1> >&);
+template void igl::randperm<Eigen::Matrix<int, -1, 1, 0, -1, 1> >(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> >&, const int64_t, const int64_t, const std::function<int64_t()>&);
+template void igl::randperm<Eigen::Matrix<int, -1, -1, 0, -1, -1> >(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, -1, 0, -1, -1> >&, const int64_t, const int64_t, const std::function<int64_t()>&);
 #endif

+ 9 - 1
include/igl/randperm.h

@@ -9,18 +9,26 @@
 #define IGL_RANDPERM_H
 #include "igl_inline.h"
 #include <Eigen/Core>
+#include <functional>
 namespace igl
 {
   // Like matlab's randperm(n) but minus 1
   //
   // Inputs:
   //   n  number of elements
+  //   rng_min  the minimum value of rng()
+  //   rng_max  the maximum value of rng()
+  //   rng random number generator. When not given the value,
+  //       randperm will use default random number generator std::rand()
   // Outputs:
   //   I  n list of rand permutation of 0:n-1
   template <typename DerivedI>
   IGL_INLINE void randperm(
     const int n,
-    Eigen::PlainObjectBase<DerivedI> & I);
+    Eigen::PlainObjectBase<DerivedI> & I,
+    const int64_t rng_min=0,
+    const int64_t rng_max=RAND_MAX,
+    const std::function<int64_t()> &rng=nullptr);
 }
 #ifndef IGL_STATIC_LIBRARY
 #  include "randperm.cpp"

+ 31 - 0
tests/include/igl/randperm.cpp

@@ -0,0 +1,31 @@
+#include <test_common.h>
+#include <igl/randperm.h>
+#include <random>
+
+TEST(randperm, default_rng_reproduce_identity)
+{
+  int n = 100;
+  Eigen::VectorXi I1, I2;
+
+  std::srand(6);
+  igl::randperm(100, I1);
+  std::srand(6);
+  igl::randperm(100, I2);
+
+  test_common::assert_eq(I1, I2);
+}
+
+TEST(randperm, custom_rng_reproduce_identity)
+{
+  int n = 100;
+  Eigen::VectorXi I1, I2;
+  std::minstd_rand rng1(6);
+  std::minstd_rand rng2(6);
+
+  igl::randperm(100, I1, rng1.min(), rng1.max(),
+                [&rng1]()->int64_t { return rng1(); });
+  igl::randperm(100, I2, rng2.min(), rng2.max(),
+                [&rng2]()->int64_t { return rng2(); });
+
+  test_common::assert_eq(I1, I2);
+}