GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
RandomForests.h
Go to the documentation of this file.
1 
26 #ifndef GRT_RANDOM_FORESTS_HEADER
27 #define GRT_RANDOM_FORESTS_HEADER
28 
29 #include "../DecisionTree/DecisionTree.h"
30 
31 GRT_BEGIN_NAMESPACE
32 
/// RandomForests: a GRT Classifier built from a collection ("forest") of decision
/// trees, each held as a DecisionTreeNode* (see getForest() / getTree()).
/// How the individual trees' outputs are combined at prediction time is
/// implemented in RandomForests.cpp and is not visible in this header.
///
/// NOTE(review): declaration order below is load-bearing. The order of virtual
/// functions affects the vtable layout. The order of data members fixes the object
/// layout and the order in which constructor initializers run. Do not reorder
/// members casually.
43 class GRT_API RandomForests : public Classifier
44 {
45 public:
 /// Constructs a forest with the given training configuration.
 /// @param decisionTreeNode  node type used to build the trees; defaults to a
 ///        DecisionTreeClusterNode. Presumably copied and used as the prototype
 ///        for every tree — confirm in RandomForests.cpp.
 /// @param forestSize  number of trees to build (default 10).
 /// @param numRandomSplits  number of random split candidates tried when growing
 ///        a node (default 100) — TODO confirm exact meaning per trainingMode.
 /// @param minNumSamplesPerNode  minimum number of training samples per node
 ///        (default 5). Presumably a node with fewer samples becomes a leaf — confirm.
 /// @param maxDepth  maximum tree depth (default 10).
 /// @param trainingMode  split-search strategy (default Tree::BEST_RANDOM_SPLIT).
 /// @param removeFeaturesAtEachSplit  default true. Presumably removes a feature
 ///        from consideration further down a branch once it has been used for a
 ///        split — confirm.
 /// @param useScaling  default false. Presumably enables input scaling in the
 ///        Classifier base.
 /// @param bootstrappedDatasetWeight  default 0.8. Presumably the fraction of the
 ///        training set sampled for each tree's bootstrapped dataset — confirm.
58  RandomForests(const DecisionTreeNode &decisionTreeNode = DecisionTreeClusterNode(),
59  const UINT forestSize=10,
60  const UINT numRandomSplits=100,
61  const UINT minNumSamplesPerNode=5,
62  const UINT maxDepth=10,
63  const Tree::TrainingMode trainingMode = Tree::BEST_RANDOM_SPLIT,
64  const bool removeFeaturesAtEachSplit = true,
65  const bool useScaling=false,
66  const Float bootstrappedDatasetWeight = 0.8);
67 
 /// Copy constructor: copies the configuration and model from rhs.
 /// NOTE(review): whether the trees are deep-copied is decided in the .cpp — confirm.
 /// NOTE(review): because a copy constructor and destructor are user-declared, no
 /// move constructor or move assignment is implicitly declared. "Moves" of this
 /// type therefore fall back to copying.
73  RandomForests(const RandomForests &rhs);
74 
 /// Destructor. Presumably releases the trees and the decisionTreeNode prototype
 /// (both held by raw pointer) — confirm in RandomForests.cpp.
78  virtual ~RandomForests(void);
79 
 /// Copy-assigns the configuration and model of rhs to this instance and returns
 /// a reference to this instance.
86  RandomForests &operator=(const RandomForests &rhs);
87 
 /// Copies the state of another Classifier into this one (override of the
 /// Classifier virtual). Presumably returns false if classifier is null or is
 /// not a RandomForests — confirm.
95  virtual bool deepCopyFrom(const Classifier *classifier);
96 
 /// Trains a new model from the labelled trainingData (override of
 /// MLBase::train_). Returns a bool status.
104  virtual bool train_(ClassificationData &trainingData);
105 
 /// Runs prediction on inputVector and returns a bool status.
 /// NOTE(review): MLBase declares `predict_(VectorFloat &inputVector)`, but this
 /// declaration takes a VectorDouble. If VectorDouble and VectorFloat are distinct
 /// types in this GRT version, this does NOT override the base virtual, and calls
 /// made through an MLBase/Classifier pointer would bypass it. Confirm this, and
 /// consider matching the base signature (the header and the .cpp must change
 /// together).
113  virtual bool predict_(VectorDouble &inputVector);
114 
 /// Clears the trained model (override of Classifier::clear). Returns a bool status.
120  virtual bool clear();
121 
 /// Prints information about the model (override of MLBase::print).
127  virtual bool print() const;
128 
 /// Writes the model to an already-open std::fstream. Returns a bool status.
136  virtual bool save( std::fstream &file ) const;
137 
 /// Reads a model from an already-open std::fstream. Returns a bool status.
145  virtual bool load( std::fstream &file );
146 
 /// Merges another trained forest into this model. Presumably appends forest's
 /// trees to this one; the compatibility requirements (e.g. matching input
 /// dimensions or class labels) are enforced in the .cpp — confirm.
155  bool combineModels( const RandomForests &forest );
156 
 /// Accessors for the training configuration set through the constructor/setters.
 /// NOTE(review): getTrainingMode() returns UINT, although the stored member is a
 /// Tree::TrainingMode and setTrainingMode() takes a Tree::TrainingMode.
162  UINT getForestSize() const;
163 
169  UINT getNumRandomSplits() const;
170 
177  UINT getMinNumSamplesPerNode() const;
178 
184  UINT getMaxDepth() const;
185 
191  UINT getTrainingMode() const;
192 
 /// Returns a read-only reference to the forest's tree pointers.
 /// NOTE(review): no Vector< DecisionTreeNode* > data member is visible in this
 /// listing. The backing member may sit on a line omitted from it (numbered line
 /// 358 is missing in the protected section).
198  const Vector< DecisionTreeNode* > &getForest() const;
199 
 /// Returns the removeFeaturesAtEachSplit setting.
207  bool getRemoveFeaturesAtEachSplit() const;
208 
 /// Returns the bootstrappedDatasetWeight setting.
215  Float getBootstrappedDatasetWeight() const;
216 
 /// Returns the tree at position index.
 /// NOTE(review): out-of-range handling and ownership of the returned pointer are
 /// not visible here. Presumably the forest keeps ownership — do not delete it.
223  DecisionTreeNode* getTree( const UINT index ) const;
224 
 /// Returns a copy of the decisionTreeNode prototype. The caller presumably owns
 /// (and must delete) the returned pointer — confirm in the .cpp.
230  DecisionTreeNode* deepCopyDecisionTreeNode() const;
231 
 /// Returns per-feature weights derived from the trained forest. normWeights
 /// (default true) presumably normalizes the result.
244  VectorDouble getFeatureWeights( const bool normWeights = true ) const;
245 
 /// Matrix form of the feature weights, computed from the trees' leaf nodes.
 /// Presumably one row per class — confirm in the .cpp. normWeights defaults to true.
258  MatrixDouble getLeafNodeFeatureWeights( const bool normWeights = true ) const;
259 
 /// Setters for the training configuration. Each returns a bool status, presumably
 /// false when the value is rejected. New values presumably take effect at the
 /// next training call — confirm.
266  bool setForestSize(const UINT forestSize);
267 
277  bool setNumRandomSplits(const UINT numSplittingSteps);
278 
287  bool setMinNumSamplesPerNode(const UINT minNumSamplesPerNode);
288 
296  bool setMaxDepth(const UINT maxDepth);
297 
306  bool setRemoveFeaturesAtEachSplit(const bool removeFeaturesAtEachSplit);
307 
 /// Deprecated, misspelled ("Spilt") alias of setRemoveFeaturesAtEachSplit(),
 /// kept for backward compatibility. GRT_DEPRECATED_MSG attaches the message below.
311  GRT_DEPRECATED_MSG( "setRemoveFeaturesAtEachSpilt(const bool removeFeaturesAtEachSpilt) is deprecated, use setRemoveFeaturesAtEachSplit(const bool removeFeaturesAtEachSplit) instead", bool setRemoveFeaturesAtEachSpilt(const bool removeFeaturesAtEachSpilt) );
312 
319  bool setTrainingMode(const Tree::TrainingMode trainingMode);
320 
 /// Sets the node type used to build trees. Presumably a copy of node is stored.
326  bool setDecisionTreeNode( const DecisionTreeNode &node );
327 
335  bool setBootstrappedDatasetWeight( const Float bootstrappedDatasetWeight );
336 
 /// Returns this classifier's identifier string. It is presumably the value of the
 /// private static id member below.
342  static std::string getId();
343 
344  // Bring MLBase's other save/load overloads (e.g. save(const std::string &filename))
 // into scope. Declaring save/load(std::fstream&) above would otherwise hide every
 // base-class overload with the same name and trigger hidden-virtual-function warnings.
345  using MLBase::save;
346  using MLBase::load;
347 
348 protected:
349 
 // Training configuration (see the constructor parameters).
350  UINT forestSize;
351  UINT numRandomSplits;
352  UINT minNumSamplesPerNode;
353  UINT maxDepth;
354  Tree::TrainingMode trainingMode;
355  bool removeFeaturesAtEachSplit;
356  Float bootstrappedDatasetWeight;
 // Prototype node type used to build trees. Held by raw pointer and presumably
 // owned by this object (see the destructor / deepCopyDecisionTreeNode()).
357  DecisionTreeNode* decisionTreeNode;
 // NOTE(review): numbered line 358 is absent from this listing. The storage that
 // backs getForest() (a Vector< DecisionTreeNode* >) is not visible here.
359 
360 private:
 // Static object that presumably registers RandomForests with GRT's classifier
 // module registry at static-initialization time. It is defined outside this header.
361  static RegisterClassifierModule< RandomForests > registerModule;
 // Identifier string for this classifier (see getId()).
362  static const std::string id;
363 
364 };
365 
366 GRT_END_NAMESPACE
367 
368 #endif //GRT_RANDOM_FORESTS_HEADER
369 
std::string getId() const
Definition: GRTBase.cpp:85
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:137
virtual bool save(const std::string &filename) const
Definition: MLBase.cpp:167
virtual bool deepCopyFrom(const Classifier *classifier)
Definition: Classifier.h:64
virtual bool print() const
Definition: MLBase.cpp:165
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:109
virtual bool load(const std::string &filename)
Definition: MLBase.cpp:190
virtual bool clear()
Definition: Classifier.cpp:151
This is the main base class that all GRT classification algorithms should inherit from...
Definition: Classifier.h:41