GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, C++ machine learning library for real-time gesture recognition.
RegressionTree.h
Go to the documentation of this file.
1 
26 #ifndef GRT_REGRESSION_TREE_HEADER
27 #define GRT_REGRESSION_TREE_HEADER
28 
29 #include "../../CoreModules/Regressifier.h"
30 #include "../../CoreAlgorithms/Tree/Tree.h"
31 #include "RegressionTreeNode.h"
32 
33 GRT_BEGIN_NAMESPACE
34 
/**
 * This class implements a basic Regression Tree. It derives from Regressifier and overrides the
 * inherited deepCopyFrom, train_, predict_, clear, print, save and load functions.
 *
 * NOTE: "Spilt" (rather than "Split") is the spelling used by this API's identifiers
 * (e.g. removeFeaturesAtEachSpilt, Tree::BEST_ITERATIVE_SPILT), so it is kept as-is below.
 */
39 class GRT_API RegressionTree : public Regressifier
40 {
41 public:
 /**
  * Constructor. Every argument has a default value, so this also acts as the default constructor.
  * The parameter meanings below are inferred from their names only - TODO confirm against RegressionTree.cpp.
  *
  * @param numSplittingSteps: presumably the number of candidate splits evaluated per node (default 100)
  * @param minNumSamplesPerNode: presumably the minimum number of samples needed before a node is split further (default 5)
  * @param maxDepth: presumably the maximum depth the tree may reach (default 10)
  * @param removeFeaturesAtEachSpilt: presumably whether a feature is removed from the candidate set after it is used for a split (default false)
  * @param trainingMode: the Tree::TrainingMode to train with (default Tree::BEST_ITERATIVE_SPILT)
  * @param useScaling: presumably whether the data should be scaled (default false)
  * @param minRMSErrorPerNode: presumably the RMS error threshold below which a node is not split further (default 0.01)
  */
53  RegressionTree(const UINT numSplittingSteps=100,const UINT minNumSamplesPerNode=5,const UINT maxDepth=10,const bool removeFeaturesAtEachSpilt = false,const Tree::TrainingMode trainingMode = Tree::BEST_ITERATIVE_SPILT,const bool useScaling=false,const Float minRMSErrorPerNode = 0.01);
54 
 /**
  * Copy constructor.
  *
  * @param rhs: the RegressionTree instance to copy from
  */
60  RegressionTree(const RegressionTree &rhs);
61 
 /**
  * Virtual destructor.
  */
65  virtual ~RegressionTree(void);
66 
 /**
  * Copy-assignment operator.
  *
  * @param rhs: the RegressionTree instance to copy from
  * @return a reference to a RegressionTree (presumably *this - confirm in the .cpp)
  */
73  RegressionTree &operator=(const RegressionTree &rhs);
74 
 /**
  * Overrides the inherited deepCopyFrom function.
  *
  * @param regressifier: pointer to the Regressifier to copy from (presumably expected to be a RegressionTree - TODO confirm)
  * @return a bool status flag (presumably true on success)
  */
82  virtual bool deepCopyFrom(const Regressifier *regressifier) override;
83 
 /**
  * Overrides the inherited train_ function for RegressionData.
  *
  * @param trainingData: the regression data to train on (passed by non-const reference)
  * @return a bool status flag (presumably true if the model was trained)
  */
91  virtual bool train_(RegressionData &trainingData) override;
92 
 /**
  * Overrides the inherited predict_ function.
  *
  * @param inputVector: the input vector to run the prediction on (passed by non-const reference)
  * @return a bool status flag (presumably true if the prediction succeeded)
  */
100  virtual bool predict_(VectorFloat &inputVector) override;
101 
 /**
  * Overrides the inherited clear function.
  *
  * @return a bool status flag (presumably true if the instance was cleared)
  */
108  virtual bool clear() override;
109 
 /**
  * Overrides the inherited print function. Declared const, so it does not modify the instance.
  *
  * @return a bool status flag
  */
115  virtual bool print() const override;
116 
 /**
  * Overrides the inherited save function that takes an fstream.
  *
  * @param file: the file stream to save to (presumably must already be open - TODO confirm)
  * @return a bool status flag (presumably true if the model was saved)
  */
124  virtual bool save( std::fstream &file ) const override;
125 
 /**
  * Overrides the inherited load function that takes an fstream.
  *
  * @param file: the file stream to load from (presumably must already be open - TODO confirm)
  * @return a bool status flag (presumably true if the model was loaded)
  */
133  virtual bool load( std::fstream &file ) override;
134 
 /**
  * Presumably returns a deep copy of the tree - inferred from the name.
  * NOTE(review): a raw, non-const pointer is returned; confirm who owns (and must delete) the copy.
  *
  * @return a pointer to a RegressionTreeNode
  */
142  RegressionTreeNode* deepCopyTree() const;
143 
 /**
  * Read-only access to the tree.
  *
  * @return a const pointer to a RegressionTreeNode (presumably the root node; may be null before training - TODO confirm)
  */
149  const RegressionTreeNode* getTree() const;
150 
 /**
  * @return a Float, presumably the current value of minRMSErrorPerNode
  */
157  Float getMinRMSErrorPerNode() const;
158 
 /**
  * @return a Tree::TrainingMode, presumably the current value of trainingMode
  */
164  Tree::TrainingMode getTrainingMode() const;
165 
 /**
  * @return a UINT, presumably the current value of numSplittingSteps
  */
174  UINT getNumSplittingSteps() const;
175 
 /**
  * @return a UINT, presumably the current value of minNumSamplesPerNode
  */
182  UINT getMinNumSamplesPerNode() const;
183 
 /**
  * @return a UINT, presumably the current value of maxDepth
  */
189  UINT getMaxDepth() const;
190 
 /**
  * @return a UINT, presumably the ID of the node reached by the most recent prediction - TODO confirm
  */
196  UINT getPredictedNodeID() const;
197 
 /**
  * @return a bool, presumably the current value of removeFeaturesAtEachSpilt
  */
203  bool getRemoveFeaturesAtEachSpilt() const;
204 
 /**
  * Setter for the training mode.
  *
  * @param trainingMode: the new Tree::TrainingMode
  * @return a bool status flag (presumably false if the value is rejected - TODO confirm)
  */
211  bool setTrainingMode(const Tree::TrainingMode trainingMode);
212 
 /**
  * Setter for the number of splitting steps.
  *
  * @param numSplittingSteps: the new value
  * @return a bool status flag (presumably false if the value is rejected - TODO confirm)
  */
225  bool setNumSplittingSteps(const UINT numSplittingSteps);
226 
 /**
  * Setter for the minimum number of samples per node.
  *
  * @param minNumSamplesPerNode: the new value
  * @return a bool status flag (presumably false if the value is rejected - TODO confirm)
  */
235  bool setMinNumSamplesPerNode(const UINT minNumSamplesPerNode);
236 
 /**
  * Setter for the maximum tree depth.
  *
  * @param maxDepth: the new value
  * @return a bool status flag (presumably false if the value is rejected - TODO confirm)
  */
244  bool setMaxDepth(const UINT maxDepth);
245 
 /**
  * Setter for the removeFeaturesAtEachSpilt flag.
  *
  * @param removeFeaturesAtEachSpilt: the new value
  * @return a bool status flag
  */
254  bool setRemoveFeaturesAtEachSpilt(const bool removeFeaturesAtEachSpilt);
255 
 /**
  * Setter for the minimum RMS error per node.
  *
  * @param minRMSErrorPerNode: the new value
  * @return a bool status flag
  */
262  bool setMinRMSErrorPerNode(const Float minRMSErrorPerNode);
263 
 /**
  * Static accessor for this module's string identifier (presumably returns the private static member id).
  *
  * @return a std::string
  */
269  static std::string getId();
270 
271  //Bring the MLBase overloads of save, load, train_ and predict_ into scope so the overrides declared above do not hide them (this stops hidden virtual function warnings)
272  using MLBase::save;
273  using MLBase::load;
274  using MLBase::train_;
275  using MLBase::predict_;
276 
277 protected:
 //Training parameters; each has a matching getter/setter pair in the public section.
 //NOTE: the declaration order here differs from the constructor's argument order.
279  UINT minNumSamplesPerNode;
280  UINT maxDepth;
281  UINT numSplittingSteps;
282  bool removeFeaturesAtEachSpilt;
283  Tree::TrainingMode trainingMode;
284  Float minRMSErrorPerNode;
285 
 //Internal training helpers.
 //buildTree takes the features vector by value (each call gets its own copy) and returns a RegressionTreeNode pointer;
 //presumably it is called recursively with parent/nodeID describing the position in the tree - TODO confirm.
286  RegressionTreeNode* buildTree( const RegressionData &trainingData, RegressionTreeNode *parent, Vector< UINT > features, UINT nodeID );
 //featureIndex, threshold and minError are non-const references, so presumably they are outputs describing the best split found.
287  bool computeBestSpilt( const RegressionData &trainingData, const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError );
 //Same signature as computeBestSpilt; presumably the implementation used for Tree::BEST_ITERATIVE_SPILT - TODO confirm.
288  bool computeBestSpiltBestIterativeSpilt( const RegressionData &trainingData, const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError );
289  //bool computeBestSpiltBestRandomSpilt( const RegressionData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &threshold, Float &minError );
 //regressionData is a non-const reference, so presumably it receives the regression values computed for a node from trainingData.
290  bool computeNodeRegressionData( const RegressionData &trainingData, VectorFloat &regressionData );
291 
292 private:
 //Static registration object, presumably used to register this class with the Regressifier module registry at static-initialization time - TODO confirm.
293  static RegisterRegressifierModule< RegressionTree > registerModule;
 //Static string identifier for this class (see getId).
294  static const std::string id;
295 
296 };
297 
298 GRT_END_NAMESPACE
299 
300 #endif //GRT_REGRESSION_TREE_HEADER
301 
std::string getId() const
Definition: GRTBase.cpp:85
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:137
Definition: Node.h:37
virtual bool clear() override
This file implements a RegressionTreeNode, which is a specific type of node used for a RegressionTree...
virtual bool save(const std::string &filename) const
Definition: MLBase.cpp:167
This class implements a basic Regression Tree.
virtual bool deepCopyFrom(const Regressifier *regressifier)
Definition: Regressifier.h:64
virtual bool print() const
Definition: MLBase.cpp:165
Node * tree
Tell the compiler we are using the base class predict method to stop hidden virtual function warning...
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:109
virtual bool load(const std::string &filename)
Definition: MLBase.cpp:190