// testILSConjugateGradients.cpp

/**
 * @file testILSConjugateGradients.cpp
 * @author Paul Bodesheim
 * @date 23/01/2012
 * @brief Test routine for the iterative linear solver: Conjugate Gradients Method (CGM)
 */
#include <stdio.h>

#include <cmath>
#include <ctime>
#include "iostream"
#include <string>

#include "core/algebra/GMStandard.h"
#include "core/algebra/ILSConjugateGradients.h"
#include "core/basics/Exception.h"
#include "core/vector/Algorithms.h"
#include "core/vector/MatrixT.h"
#include "core/vector/VectorT.h"

using namespace std;
using namespace NICE;
  18. int main(int argc, char* argv[])
  19. {
  20. int mySize = 20; // number of equations
  21. FILE * logfile;
  22. std::string logfilename;
  23. if ( argc < 2 )
  24. logfilename = "/home/bodesheim/testILS-CGM.log";
  25. else
  26. logfilename = argv[1];
  27. logfile = fopen(logfilename.c_str(), "w");
  28. // generate matrix A
  29. Matrix A(mySize,mySize,0.0);
  30. fprintf(logfile, "A:\n");
  31. for (uint i = 0; i < A.rows(); i++)
  32. {
  33. for (uint j = 0; j < A.cols(); j++)
  34. {
  35. if ( j == i ) A(i,j) = (i+1)+(j+1);
  36. else {
  37. A(i,j) = sqrt((i+1)*(j+1));
  38. }
  39. fprintf(logfile, "%f ",A(i,j));
  40. }
  41. fprintf(logfile, "\n");
  42. }
  43. // generate vector b (RHS of LS)
  44. Vector b(mySize,0.0);
  45. fprintf(logfile, "b:\n");
  46. for (uint i = 0; i < b.size(); i++)
  47. {
  48. b(i) = (i+1)*sqrt(i+1);
  49. fprintf(logfile, "%f ",b(i));
  50. }
  51. fprintf(logfile, "\n");
  52. // solve Ax = b
  53. Vector x(mySize,0.0);
  54. ILSConjugateGradients cgm(true,mySize);
  55. //tic
  56. time_t start = clock();
  57. cgm.solveLin(GMStandard(A),b,x);
  58. //toc
  59. float duration = (float) (clock() - start);
  60. std::cerr << "Time for solveLin: " << duration/CLOCKS_PER_SEC << std::endl;
  61. fprintf(logfile, "x:\n");
  62. for (uint i = 0; i < x.size(); i++)
  63. {
  64. fprintf(logfile, "%f ",x(i));
  65. }
  66. fprintf(logfile, "\n");
  67. // check result
  68. Vector Ax(mySize,0.0);
  69. Ax = A*x;
  70. fprintf(logfile, "A*x:\n");
  71. for (uint i = 0; i < Ax.size(); i++)
  72. {
  73. fprintf(logfile, "%f ",Ax(i));
  74. }
  75. fprintf(logfile, "\n");
  76. fclose(logfile);
  77. return 0;
  78. }