GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine-learning library for real-time gesture recognition.
RegressionTree.h
Go to the documentation of this file.
1 
31 #ifndef GRT_REGRESSION_TREE_HEADER
32 #define GRT_REGRESSION_TREE_HEADER
33 
34 #include "../../CoreModules/Regressifier.h"
35 #include "../../CoreAlgorithms/Tree/Tree.h"
36 #include "RegressionTreeNode.h"
37 
38 GRT_BEGIN_NAMESPACE
39 
40 class GRT_API RegressionTree : public Tree, public Regressifier
41 {
42 public:
54  RegressionTree(const UINT numSplittingSteps=100,const UINT minNumSamplesPerNode=5,const UINT maxDepth=10,const bool removeFeaturesAtEachSpilt = false,const UINT trainingMode = BEST_ITERATIVE_SPILT,const bool useScaling=false,const Float minRMSErrorPerNode = 0.01);
55 
61  RegressionTree(const RegressionTree &rhs);
62 
66  virtual ~RegressionTree(void);
67 
74  RegressionTree &operator=(const RegressionTree &rhs);
75 
83  virtual bool deepCopyFrom(const Regressifier *regressifier);
84 
92  virtual bool train_(RegressionData &trainingData);
93 
101  virtual bool predict_(VectorFloat &inputVector);
102 
109  virtual bool clear();
110 
116  virtual bool print() const;
117 
125  virtual bool save( std::fstream &file ) const;
126 
134  virtual bool load( std::fstream &file );
135 
144 
150  const RegressionTreeNode* getTree() const;
151 
158  Float getMinRMSErrorPerNode() const;
159 
166  bool setMinRMSErrorPerNode(const Float minRMSErrorPerNode);
167 
168  //Tell the compiler we are using the base class train method to stop hidden virtual function warnings
169  using MLBase::save;
170  using MLBase::load;
171  using MLBase::train_;
172  using MLBase::predict_;
173 
174 protected:
176 
177  RegressionTreeNode* buildTree( const RegressionData &trainingData, RegressionTreeNode *parent, Vector< UINT > features, UINT nodeID );
178  bool computeBestSpilt( const RegressionData &trainingData, const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError );
179  bool computeBestSpiltBestIterativeSpilt( const RegressionData &trainingData, const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError );
180  //bool computeBestSpiltBestRandomSpilt( const RegressionData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &threshold, Float &minError );
181  bool computeNodeRegressionData( const RegressionData &trainingData, VectorFloat &regressionData );
182 
183  static RegisterRegressifierModule< RegressionTree > registerModule;
184 
185 };
186 
187 GRT_END_NAMESPACE
188 
189 #endif //GRT_REGRESSION_TREE_HEADER
190 
Float minRMSErrorPerNode
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:115
This file implements a RegressionTreeNode, which is a specific type of node used for a RegressionTree...
virtual bool save(const std::string filename) const
Definition: MLBase.cpp:143
virtual bool deepCopyFrom(const Regressifier *regressifier)
Definition: Regressifier.h:63
virtual bool load(const std::string filename)
Definition: MLBase.cpp:167
Definition: Tree.h:38
virtual bool print() const
Definition: MLBase.cpp:141
const Node * getTree() const
Definition: Tree.cpp:88
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:91
virtual Node * deepCopyTree() const
Definition: Tree.cpp:79
virtual bool clear()