SVD.h 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. /*
  2. * NICE-Core - efficient algebra and computer vision methods
  3. * - libbasicvector - A simple vector library
  4. * See file License for license information.
  5. */
  6. #ifndef BASICVECTOR_SVD_H
  7. #define BASICVECTOR_SVD_H
  8. #include "core/vector/VectorT.h"
  9. #include "core/vector/MatrixT.h"
  10. #include "core/vector/RowMatrixT.h"
  11. #ifdef NICE_USELIB_LINAL
  12. #include <LinAl/algorithms.h>
  13. #endif
  14. namespace NICE {
  15. #ifdef NICE_USELIB_LINAL
  16. template<class T>
  17. class SVD {
  18. public:
  19. /**
  20. * Constructor
  21. * @param matrix Input matrix
  22. */
  23. inline SVD(const MatrixT<T>& matrix) {
  24. LinAl::MatrixCF<double> _u;
  25. LinAl::MatrixCF<double> _vt;
  26. LinAl::VectorCC<double> _s;
  27. LinAl::svdfull(matrix.linal(), _u, _s, _vt);
  28. u = _u;
  29. vt = _vt;
  30. s = _s;
  31. }
  32. /**
  33. * Get the first factor U.
  34. * @return
  35. */
  36. inline MatrixT<T> getU() {
  37. return u;
  38. }
  39. /**
  40. * Get the third factor V.
  41. * @return
  42. */
  43. inline MatrixT<T> getV() {
  44. return vt.transpose();
  45. }
  46. /**
  47. * Get the transposed third factor V^T.
  48. * @return
  49. */
  50. inline MatrixT<T> getVt() {
  51. return vt;
  52. }
  53. /**
  54. * Get the singular values as a vector.
  55. * @return
  56. */
  57. inline VectorT<T> getSingularValues() {
  58. return s;
  59. }
  60. /**
  61. * Get the singular values as a diagonal matrix.
  62. * @return
  63. */
  64. inline MatrixT<T> getS() {
  65. MatrixT<T> result(s.size(), s.size());
  66. result = 0.0;
  67. for (uint i = 0; i < s.size(); ++i) {
  68. result(i,i) = s[i];
  69. }
  70. return result;
  71. }
  72. // double norm2 ()
  73. // double cond ()
  74. // int rank ()
  75. private:
  76. MatrixT<T> u;
  77. VectorT<T> s;
  78. MatrixT<T> vt;
  79. };
  80. /** enforce predefined singular values of a square matrix */
  81. template<class T>
  82. inline void enforceSingularValues(MatrixT<T>& m, const VectorT<T>& sNew) {
  83. SVD<T> svd(m);
  84. MatrixT<T> u = svd.getU();
  85. MatrixT<T> s = svd.getS();
  86. MatrixT<T> vt = svd.getVt();
  87. for (unsigned int i = 0; i < sNew.size(); i++) {
  88. s(i, i) = sNew[i];
  89. }
  90. MatrixT<T> us;
  91. us.multiply(u, s);
  92. m.multiply(us, vt);
  93. }
  94. template<class T>
  95. inline void enforceRankDefect(MatrixT<T>& m, const uint defect) {
  96. SVD<T> svd(m);
  97. MatrixT<T> u = svd.getU();
  98. MatrixT<T> s = svd.getS();
  99. MatrixT<T> vt = svd.getVt();
  100. for (unsigned int i = 0; i < s.rows(); i++) {
  101. s(i, i) = ((int)i < (int)s.rows() - (int)defect ? s(i,i) : 0.0);
  102. }
  103. MatrixT<T> us;
  104. us.multiply(u, s);
  105. m.multiply(us, vt);
  106. }
  107. #else // no LinAl
  108. #ifndef SVDLINAL_WARNING
  109. #pragma message NICE_WARNING("SVD requires LinAl.")
  110. #define SVDLINAL_WARNING
  111. #endif
  112. #endif
  113. } // namespace
  114. #endif