GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, C++ machine learning library for real-time gesture recognition.
RandomForests.h
1 
35 #ifndef GRT_RANDOM_FORESTS_HEADER
36 #define GRT_RANDOM_FORESTS_HEADER
37 
38 #include "../DecisionTree/DecisionTree.h"
39 
40 GRT_BEGIN_NAMESPACE
41 
// RandomForests: a Classifier subclass configured by a forest size, a per-node random-split
// count, a minimum samples-per-node limit, a maximum depth, a DecisionTree training mode and a
// bootstrapped-dataset weight, built around a prototype DecisionTreeNode.
// NOTE(review): the embedded listing numbers jump between declarations (e.g. 44 -> 57), so the
// original per-member documentation comments appear to have been stripped from this listing.
42 class GRT_API RandomForests : public Classifier
43 {
44  public:
    // Constructor. Every argument has a default, so this also serves as the default constructor.
    // Defaults visible here: a DecisionTreeClusterNode prototype, 10 trees, 100 random splits,
    // at least 5 samples per node, depth limit 10, DecisionTree::BEST_RANDOM_SPLIT training mode,
    // feature removal at each split enabled, scaling disabled, bootstrap weight 0.8.
    // NOTE(review): "Spilt" is a misspelling of "Split", but it is part of the public
    // getter/setter names below, so it cannot be corrected without breaking callers.
57  RandomForests(const DecisionTreeNode &decisionTreeNode = DecisionTreeClusterNode(),
58  const UINT forestSize=10,
59  const UINT numRandomSplits=100,
60  const UINT minNumSamplesPerNode=5,
61  const UINT maxDepth=10,
62  const UINT trainingMode = DecisionTree::BEST_RANDOM_SPLIT,
63  const bool removeFeaturesAtEachSpilt = true,
64  const bool useScaling=false,
65  const Float bootstrappedDatasetWeight = 0.8);
66 
    // Copy constructor.
72  RandomForests(const RandomForests &rhs);
73 
    // Virtual destructor.
77  virtual ~RandomForests(void);
78 
    // Copy assignment; returns a reference to a RandomForests (presumably *this).
85  RandomForests &operator=(const RandomForests &rhs);
86 
    // Copies from a generic Classifier pointer; the boolean presumably reports success.
    // NOTE(review): how a null or non-RandomForests pointer is handled is not visible here.
94  virtual bool deepCopyFrom(const Classifier *classifier);
95 
    // Training entry point taking labelled ClassificationData by non-const reference.
    // NOTE(review): whether trainingData is modified during training cannot be told from here.
103  virtual bool train_(ClassificationData &trainingData);
104 
    // Prediction entry point for a single input vector.
    // NOTE(review): the base-class virtual is listed with a VectorFloat& parameter; presumably
    // VectorDouble is an alias of VectorFloat, otherwise this would not override it — confirm.
112  virtual bool predict_(VectorDouble &inputVector);
113 
    // Virtual clear(); presumably resets the trained model — confirm in the implementation.
119  virtual bool clear();
120 
    // Virtual print(); const, so it does not modify the object.
126  virtual bool print() const;
127 
    // Saves to an already-constructed std::fstream (stream overload).
135  virtual bool save( std::fstream &file ) const;
136 
    // Loads from an already-constructed std::fstream (stream overload).
144  virtual bool load( std::fstream &file );
145 
    // Combines another RandomForests instance with this one; the argument is taken by const
    // reference. NOTE(review): the preconditions for a successful merge are not visible here.
154  bool combineModels( const RandomForests &forest );
155 
    // --- Getters: each is const; the scalar ones presumably return the same-named protected member. ---
161  UINT getForestSize() const;
162 
168  UINT getNumRandomSplits() const;
169 
176  UINT getMinNumSamplesPerNode() const;
177 
183  UINT getMaxDepth() const;
184 
190  UINT getTrainingMode() const;
191 
    // Returns a const reference to a vector of DecisionTreeNode pointers (the vector cannot be
    // modified through it, though the pointed-to nodes are non-const).
197  const Vector< DecisionTreeNode* > &getForest() const;
198 
206  bool getRemoveFeaturesAtEachSpilt() const;
207 
214  Float getBootstrappedDatasetWeight() const;
215 
    // Returns a raw pointer for the given index.
    // NOTE(review): ownership of the returned pointer and the result for an out-of-range
    // index are not visible here — confirm before deleting or dereferencing it.
222  DecisionTreeNode* getTree( const UINT index ) const;
223 
    // Returns a raw DecisionTreeNode pointer; going by the name, presumably a newly allocated
    // deep copy of the prototype node that the caller must free — TODO confirm.
229  DecisionTreeNode* deepCopyDecisionTreeNode() const;
230 
    // Returns feature weights as a vector; normWeights defaults to true.
243  VectorDouble getFeatureWeights( const bool normWeights = true ) const;
244 
    // Returns leaf-node feature weights as a matrix; normWeights defaults to true.
    // NOTE(review): the row/column meaning of the matrix is not visible here.
257  MatrixDouble getLeafNodeFeatureWeights( const bool normWeights = true ) const;
258 
    // --- Setters: each returns bool, presumably true when the value was accepted (validation
    // rules live in the implementation and are not visible here). ---
265  bool setForestSize(const UINT forestSize);
266 
    // NOTE(review): the parameter is named numSplittingSteps here, while the constructor and
    // getter call the same quantity numRandomSplits.
276  bool setNumRandomSplits(const UINT numSplittingSteps);
277 
286  bool setMinNumSamplesPerNode(const UINT minNumSamplesPerNode);
287 
295  bool setMaxDepth(const UINT maxDepth);
296 
305  bool setRemoveFeaturesAtEachSpilt(const bool removeFeaturesAtEachSpilt);
306 
313  bool setTrainingMode(const UINT trainingMode);
314 
    // Takes the prototype node by const reference.
    // NOTE(review): presumably the node is copied rather than referenced — confirm.
320  bool setDecisionTreeNode( const DecisionTreeNode &node );
321 
329  bool setBootstrappedDatasetWeight( const Float bootstrappedDatasetWeight );
330 
    // Static accessor returning a std::string; presumably the value of the private static `id`.
336  static std::string getId();
337 
338  //Tell the compiler we are using the base class save and load methods to stop hidden virtual function warnings
    // Without these using-declarations, the std::fstream overloads declared above would hide
    // the filename-based MLBase::save / MLBase::load overloads.
339  using MLBase::save;
340  using MLBase::load;
341 
342 protected:
343 
    // Configuration state; the names mirror the constructor parameters.
344  UINT forestSize;
345  UINT numRandomSplits;
346  UINT minNumSamplesPerNode;
347  UINT maxDepth;
348  UINT trainingMode;
349  bool removeFeaturesAtEachSpilt;
350  Float bootstrappedDatasetWeight;
    // Raw pointer to a DecisionTreeNode.
    // NOTE(review): ownership is not visible here; the user-declared copy constructor, copy
    // assignment and destructor suggest the class manages it manually — confirm.
351  DecisionTreeNode* decisionTreeNode;
    // NOTE(review): the listing skips line 352. getForest() returns a reference to a
    // Vector< DecisionTreeNode* >, so presumably the member holding the forest is declared
    // on the missing line — confirm against the original header.
353 
354 private:
    // Static registration object for this classifier type.
    // NOTE(review): presumably used by a classifier factory/registry — confirm.
355  static RegisterClassifierModule< RandomForests > registerModule;
    // Static identifier string for this class.
356  static std::string id;
357 
358 };
359 
360 GRT_END_NAMESPACE
361 
362 #endif //GRT_RANDOM_FORESTS_HEADER
363 
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:115
virtual bool save(const std::string filename) const
Definition: MLBase.cpp:143
virtual bool load(const std::string filename)
Definition: MLBase.cpp:167
virtual bool deepCopyFrom(const Classifier *classifier)
Definition: Classifier.h:63
virtual bool print() const
Definition: MLBase.cpp:141
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:91
virtual bool clear()
Definition: Classifier.cpp:142