LinRegression.h 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. /**
  2. * @file LinRegression.h
  3. * @brief Algorithm for linear regression
  4. * @author Frank Prüfer
  5. * @date 08/13/2013
  6. */
  7. #ifndef LINREGRESSIONINCLUDE
  8. #define LINREGRESSIONINCLUDE
  9. #include "vislearning/regression/regressionbase/RegressionAlgorithm.h"
  10. #include <vector>
  11. #include "core/vector/VectorT.h"
  12. #include "core/vector/MatrixT.h"
  13. namespace OBJREC
  14. {
  15. class LinRegression : public RegressionAlgorithm
  16. {
  17. protected:
  18. /** vector containing all model parameters */
  19. std::vector<double> modelParams;
  20. /** dimensionality of the model (i.e. number of model parameters) */
  21. uint dim;
  22. public:
  23. /** simple constructor */
  24. LinRegression();
  25. /** constructor, specifying the dimensionality of the model*/
  26. LinRegression(uint dimension);
  27. /** copy constructor */
  28. LinRegression ( const LinRegression & src );
  29. /** simple destructor */
  30. virtual ~LinRegression();
  31. /** clone function */
  32. LinRegression* clone (void) const;
  33. /** method to learn model parameters */
  34. void teach ( const NICE::VVector & x, const NICE::Vector & y );
  35. /** returns model parameters as a vector */
  36. std::vector<double> getModelParams();
  37. /** method to predict function value */
  38. double predict ( const NICE::Vector & x );
  39. };
  40. } //namespace
  41. #endif