GestureRecognitionToolkit Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
DecisionTree.h
Go to the documentation of this file.
1 
26 #ifndef GRT_DECISION_TREE_HEADER
27 #define GRT_DECISION_TREE_HEADER
28 
29 #include "../../CoreModules/Classifier.h"
30 #include "../../CoreAlgorithms/Tree/Tree.h"
31 #include "DecisionTreeNode.h"
35 
36 GRT_BEGIN_NAMESPACE
37 
47 class GRT_API DecisionTree : public Classifier
48 {
49 public:
61  DecisionTree(const DecisionTreeNode &decisionTreeNode = DecisionTreeClusterNode(),const UINT minNumSamplesPerNode=5,const UINT maxDepth=10,const bool removeFeaturesAtEachSplit = false,const Tree::TrainingMode trainingMode = Tree::TrainingMode::BEST_ITERATIVE_SPILT,const UINT numSplittingSteps=100,const bool useScaling=false );
62 
68  DecisionTree(const DecisionTree &rhs);
69 
73  virtual ~DecisionTree(void);
74 
81  DecisionTree &operator=(const DecisionTree &rhs);
82 
90  virtual bool deepCopyFrom(const Classifier *classifier) override;
91 
99  virtual bool train_(ClassificationData &trainingData) override;
100 
108  virtual bool predict_(VectorFloat &inputVector) override;
109 
116  virtual bool clear() override;
117 
124  virtual bool recomputeNullRejectionThresholds() override;
125 
133  virtual bool save( std::fstream &file ) const override;
134 
142  virtual bool load( std::fstream &file ) override;
143 
151  virtual bool getModel( std::ostream &stream ) const override;
152 
160  DecisionTreeNode* deepCopyTree() const;
161 
167  DecisionTreeNode* deepCopyDecisionTreeNode() const;
168 
174  const DecisionTreeNode* getTree() const;
175 
181  Tree::TrainingMode getTrainingMode() const;
182 
191  UINT getNumSplittingSteps() const;
192 
199  UINT getMinNumSamplesPerNode() const;
200 
206  UINT getMaxDepth() const;
207 
213  UINT getPredictedNodeID() const;
214 
220  bool getRemoveFeaturesAtEachSplit() const;
221 
228  bool setTrainingMode(const Tree::TrainingMode trainingMode);
229 
242  bool setNumSplittingSteps(const UINT numSplittingSteps);
243 
252  bool setMinNumSamplesPerNode(const UINT minNumSamplesPerNode);
253 
261  bool setMaxDepth(const UINT maxDepth);
262 
271  bool setRemoveFeaturesAtEachSplit(const bool removeFeaturesAtEachSplit);
272 
276  GRT_DEPRECATED_MSG( "setRemoveFeaturesAtEachSpilt(const bool removeFeaturesAtEachSpilt) is deprecated, use setRemoveFeaturesAtEachSplit(const bool removeFeaturesAtEachSplit) instead", bool setRemoveFeaturesAtEachSpilt(const bool removeFeaturesAtEachSpilt) );
277 
283  bool setDecisionTreeNode( const DecisionTreeNode &node );
284 
290  static std::string getId();
291 
292  //Tell the compiler we are using the base class train method to stop hidden virtual function warnings
293  using MLBase::save;
294  using MLBase::load;
295  using MLBase::train_;
296  using MLBase::predict_;
297  using MLBase::print;
298 
299 protected:
300  bool loadLegacyModelFromFile_v1( std::fstream &file );
301  bool loadLegacyModelFromFile_v2( std::fstream &file );
302  bool loadLegacyModelFromFile_v3( std::fstream &file );
303 
304  bool trainTree( ClassificationData trainingData, const ClassificationData &trainingDataCopy, const ClassificationData &validationData, Vector< UINT > features );
305  DecisionTreeNode* buildTree(ClassificationData &trainingData, DecisionTreeNode *parent, Vector< UINT > features, const Vector< UINT > &classLabels, UINT nodeID );
306  Float getNodeDistance( const VectorFloat &x, const UINT nodeID );
307  Float getNodeDistance( const VectorFloat &x, const VectorFloat &y );
308 
309  DecisionTreeNode* decisionTreeNode;
310  std::map< UINT, VectorFloat > nodeClusters;
311  VectorFloat classClusterMean;
312  VectorFloat classClusterStdDev;
313 
314  DecisionTreeNode *tree;
315  UINT minNumSamplesPerNode;
316  UINT maxDepth;
317  UINT numSplittingSteps;
318  bool removeFeaturesAtEachSplit;
319  Tree::TrainingMode trainingMode;
320 
321 private:
322  static RegisterClassifierModule< DecisionTree > registerModule;
323  static const std::string id;
324 };
325 
326 GRT_END_NAMESPACE
327 
328 #endif //GRT_DECISION_TREE_HEADER
329 
std::string getId() const
Definition: GRTBase.cpp:85
virtual bool recomputeNullRejectionThresholds()
Definition: Classifier.h:255
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:137
virtual bool getModel(std::ostream &stream) const
Definition: MLBase.cpp:213
virtual bool save(const std::string &filename) const
Definition: MLBase.cpp:167
virtual bool deepCopyFrom(const Classifier *classifier)
Definition: Classifier.h:64
virtual bool print() const
Definition: MLBase.cpp:165
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:109
virtual bool load(const std::string &filename)
Definition: MLBase.cpp:190
virtual bool clear()
Definition: Classifier.cpp:151
This is the main base class that all GRT Classification algorithms should inherit from...
Definition: Classifier.h:41