Эх сурвалжийг харах

Make igl/randperm to receive UniformRandomBitGenerator as argument

The reason behind this change is UniformRandomBitGenerator that
std::shuffle takes require max() and min() to be accessible
in compile-time. Therefore, template specializations for igl::randperm
will require max() and min() values as template variables, which is not
desirable. Instead, we declare template specializations for
pre-defined UniformRandomBitGenerator specified in c++11 standard.
Yucheol Jung 6 жил өмнө
parent
commit
0b6f488f28

+ 21 - 33
include/igl/randperm.cpp

@@ -9,49 +9,37 @@
 #include "colon.h"
 #include <algorithm>
 
-template <typename DerivedI>
+template <typename DerivedI, typename URBG>
 IGL_INLINE void igl::randperm(
   const int n,
   Eigen::PlainObjectBase<DerivedI> & I,
-  const int64_t rng_min,
-  const int64_t rng_max,
-  const std::function<int64_t()> &rng)
+  URBG urbg)
 {
   Eigen::VectorXi II;
   igl::colon(0,1,n-1,II);
   I = II;
 
-  // C++ named requirement : UniformRandomBitGenerator
-  // This signature is required for the third parameter of
-  // std::shuffle
-  struct RandPermURBG {
-  public:
-    using result_type = int64_t;
-    RandPermURBG(const result_type min,
-                 const result_type max,
-                 const std::function<result_type()> int_gen) :
-      m_min(min),
-      m_max(max),
-      m_int_gen(std::move(int_gen)) {}
-
-    result_type min() const noexcept { return m_min; }
-    result_type max() const noexcept { return m_max; }
-    result_type operator()() const { return m_int_gen(); };
-  private:
-    result_type m_min;
-    result_type m_max;
-    std::function<result_type()> m_int_gen;
-  };
-
-  const auto int_gen = (nullptr != rng) ?
-    rng :
-    []()->int64_t { return std::rand(); };
-
-  std::shuffle(I.data(),I.data()+n, RandPermURBG(rng_min, rng_max, int_gen));
+  std::shuffle(I.data(),I.data()+n, urbg);
 }
 
 #ifdef IGL_STATIC_LIBRARY
 // Explicit template instantiation
-template void igl::randperm<Eigen::Matrix<int, -1, 1, 0, -1, 1> >(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> >&, const int64_t, const int64_t, const std::function<int64_t()>&);
-template void igl::randperm<Eigen::Matrix<int, -1, -1, 0, -1, -1> >(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, -1, 0, -1, -1> >&, const int64_t, const int64_t, const std::function<int64_t()>&);
+template void igl::randperm<Eigen::Matrix<int, -1, 1, 0, -1, 1>, std::minstd_rand0>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> >&, std::minstd_rand0);
+template void igl::randperm<Eigen::Matrix<int, -1, -1, 0, -1, -1>, std::minstd_rand0>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, -1, 0, -1, -1> >&, std::minstd_rand0);
+template void igl::randperm<Eigen::Matrix<int, -1, 1, 0, -1, 1>, std::minstd_rand>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> >&, std::minstd_rand);
+template void igl::randperm<Eigen::Matrix<int, -1, -1, 0, -1, -1>, std::minstd_rand>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, -1, 0, -1, -1> >&, std::minstd_rand);
+template void igl::randperm<Eigen::Matrix<int, -1, 1, 0, -1, 1>, std::mt19937>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> >&, std::mt19937);
+template void igl::randperm<Eigen::Matrix<int, -1, -1, 0, -1, -1>, std::mt19937>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, -1, 0, -1, -1> >&, std::mt19937);
+template void igl::randperm<Eigen::Matrix<int, -1, 1, 0, -1, 1>, std::mt19937_64>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> >&, std::mt19937_64);
+template void igl::randperm<Eigen::Matrix<int, -1, -1, 0, -1, -1>, std::mt19937_64>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, -1, 0, -1, -1> >&, std::mt19937_64);
+template void igl::randperm<Eigen::Matrix<int, -1, 1, 0, -1, 1>, std::ranlux24_base>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> >&, std::ranlux24_base);
+template void igl::randperm<Eigen::Matrix<int, -1, -1, 0, -1, -1>, std::ranlux24_base>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, -1, 0, -1, -1> >&, std::ranlux24_base);
+template void igl::randperm<Eigen::Matrix<int, -1, 1, 0, -1, 1>, std::ranlux48_base>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> >&, std::ranlux48_base);
+template void igl::randperm<Eigen::Matrix<int, -1, -1, 0, -1, -1>, std::ranlux48_base>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, -1, 0, -1, -1> >&, std::ranlux48_base);
+template void igl::randperm<Eigen::Matrix<int, -1, 1, 0, -1, 1>, std::ranlux24>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> >&, std::ranlux24);
+template void igl::randperm<Eigen::Matrix<int, -1, -1, 0, -1, -1>, std::ranlux24>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, -1, 0, -1, -1> >&, std::ranlux24);
+template void igl::randperm<Eigen::Matrix<int, -1, 1, 0, -1, 1>, std::ranlux48>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> >&, std::ranlux48);
+template void igl::randperm<Eigen::Matrix<int, -1, -1, 0, -1, -1>, std::ranlux48>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, -1, 0, -1, -1> >&, std::ranlux48);
+template void igl::randperm<Eigen::Matrix<int, -1, 1, 0, -1, 1>, std::knuth_b>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, 1, 0, -1, 1> >&, std::knuth_b);
+template void igl::randperm<Eigen::Matrix<int, -1, -1, 0, -1, -1>, std::knuth_b>(int, Eigen::PlainObjectBase<Eigen::Matrix<int, -1, -1, 0, -1, -1> >&, std::knuth_b);
 #endif

+ 6 - 9
include/igl/randperm.h

@@ -9,26 +9,23 @@
 #define IGL_RANDPERM_H
 #include "igl_inline.h"
 #include <Eigen/Core>
-#include <functional>
+#include <random>
 namespace igl
 {
   // Like matlab's randperm(n) but minus 1
   //
   // Inputs:
   //   n  number of elements
-  //   rng_min  the minimum value of rng()
-  //   rng_max  the maximum value of rng()
-  //   rng random number generator. When not given the value,
-  //       randperm will use default random number generator std::rand()
+  //   urbg An instance of UnformRandomBitGenerator. When not given,
+  //        randperm will use default random bit generator std::minstd_rand
+  //        initialized with random seed generated by std::rand()
   // Outputs:
   //   I  n list of rand permutation of 0:n-1
-  template <typename DerivedI>
+  template <typename DerivedI, typename URBG=std::minstd_rand>
   IGL_INLINE void randperm(
     const int n,
     Eigen::PlainObjectBase<DerivedI> & I,
-    const int64_t rng_min=0,
-    const int64_t rng_max=RAND_MAX,
-    const std::function<int64_t()> &rng=nullptr);
+    URBG urbg=std::minstd_rand(std::rand()));
 }
 #ifndef IGL_STATIC_LIBRARY
 #  include "randperm.cpp"

+ 49 - 10
tests/include/igl/randperm.cpp

@@ -15,17 +15,56 @@ TEST(randperm, default_rng_reproduce_identity)
   test_common::assert_eq(I1, I2);
 }
 
-TEST(randperm, custom_rng_reproduce_identity)
+namespace randperm
 {
-  int n = 100;
-  Eigen::VectorXi I1, I2;
-  std::minstd_rand rng1(6);
-  std::minstd_rand rng2(6);
+  template<typename URBG>
+  void test_reproduce()
+  {
+    int n = 100;
+    Eigen::VectorXi I1, I2;
+    URBG rng1(6);
+    URBG rng2(6);
 
-  igl::randperm(100, I1, rng1.min(), rng1.max(),
-                [&rng1]()->int64_t { return rng1(); });
-  igl::randperm(100, I2, rng2.min(), rng2.max(),
-                [&rng2]()->int64_t { return rng2(); });
+    igl::randperm(100, I1, rng1);
+    igl::randperm(100, I2, rng2);
 
-  test_common::assert_eq(I1, I2);
+    test_common::assert_eq(I1, I2);
+  }
+}
+
+TEST(randperm, minstd_rand0_reproduce_identity)
+{
+  randperm::test_reproduce<std::minstd_rand0>();
+}
+TEST(randperm, minstd_rand_reproduce_identity)
+{
+  randperm::test_reproduce<std::minstd_rand>();
+}
+TEST(randperm, mt19937_reproduce_identity)
+{
+  randperm::test_reproduce<std::mt19937>();
+}
+TEST(randperm, mt19937_64_reproduce_identity)
+{
+  randperm::test_reproduce<std::mt19937_64>();
+}
+TEST(randperm, ranlux24_base_reproduce_identity)
+{
+  randperm::test_reproduce<std::ranlux24_base>();
+}
+TEST(randperm, ranlux48_base_reproduce_identity)
+{
+  randperm::test_reproduce<std::ranlux48_base>();
+}
+TEST(randperm, ranlux24_reproduce_identity)
+{
+  randperm::test_reproduce<std::ranlux24>();
+}
+TEST(randperm, ranlux48_reproduce_identity)
+{
+  randperm::test_reproduce<std::ranlux48>();
+}
+TEST(randperm, knuth_b_reproduce_identity)
+{
+  randperm::test_reproduce<std::knuth_b>();
 }