GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
ClusterTree.h
Go to the documentation of this file.
1 
31 #ifndef GRT_CLUSTER_TREE_HEADER
32 #define GRT_CLUSTER_TREE_HEADER
33 
34 #include "../../CoreModules/Clusterer.h"
35 #include "../../CoreAlgorithms/Tree/Tree.h"
36 #include "ClusterTreeNode.h"
37 
38 GRT_BEGIN_NAMESPACE
39 
/**
 * ClusterTree: a Clusterer (see base class) whose model is a tree of ClusterTreeNode objects.
 * Judging by the computeBestSplit* out-parameters (featureIndex, threshold, minError), each
 * node presumably splits the data on one feature at a threshold; the exact splitting and
 * stopping rules live in ClusterTree.cpp — confirm there before relying on them.
 */
40 class GRT_API ClusterTree : public Clusterer
41 {
42 public:
    /**
     * Constructor.
     * @param numSplittingSteps          default 100; presumably the number of candidate thresholds
     *                                   evaluated per feature (TODO confirm)
     * @param minNumSamplesPerNode       default 5
     * @param maxDepth                   default 10
     * @param removeFeaturesAtEachSplit  default false
     * @param trainingMode               default Tree::BEST_ITERATIVE_SPILT (enumerator spelling as
     *                                   declared in Tree.h)
     * @param useScaling                 default false; no matching member is declared in this class,
     *                                   so presumably it is forwarded to a base class (TODO confirm)
     * @param minRMSErrorPerNode         default 0.01
     */
54  ClusterTree(const UINT numSplittingSteps=100,const UINT minNumSamplesPerNode=5,const UINT maxDepth=10,const bool removeFeaturesAtEachSplit = false,const Tree::TrainingMode trainingMode = Tree::BEST_ITERATIVE_SPILT,const bool useScaling=false,const Float minRMSErrorPerNode = 0.01);
55 
    /** Copy constructor. @param rhs the ClusterTree to copy from. */
61  ClusterTree(const ClusterTree &rhs);
62 
    /** Virtual destructor (class is used polymorphically through Clusterer). */
66  virtual ~ClusterTree(void);
67 
    /**
     * Copy assignment. @param rhs the ClusterTree to copy from. @return *this.
     * NOTE(review): because a copy constructor, copy assignment and destructor are user-declared,
     * the compiler does not implicitly declare move operations; moves fall back to these copies.
     */
74  ClusterTree &operator=(const ClusterTree &rhs);
75 
    /**
     * Overrides Clusterer::deepCopyFrom.
     * @param cluster source clusterer; presumably must actually be a ClusterTree (TODO confirm in .cpp)
     * @return true on success
     */
83  virtual bool deepCopyFrom(const Clusterer *cluster) override;
84 
    /** Overrides the Clusterer training hook. @param trainingData training samples. @return true on success. */
92  virtual bool train_(MatrixFloat &trainingData) override;
93 
    /** Overrides the prediction hook. @param inputVector sample to cluster. @return true on success. */
101  virtual bool predict_(VectorFloat &inputVector) override;
102 
    /** Overrides Clusterer::clear. @return true on success. */
109  virtual bool clear() override;
110 
    /** Overrides the base print(). @return true on success. */
116  virtual bool print() const override;
117 
    /** Serializes the model to an already-open stream. @param file output stream. @return true on success. */
125  virtual bool saveModelToFile( std::fstream &file ) const override;
126 
    /** Loads a model from an already-open stream. @param file input stream. @return true on success. */
134  virtual bool loadModelFromFile( std::fstream &file ) override;
135 
    /**
     * @return a deep copy of the tree.
     * NOTE(review): returns a raw pointer; presumably the caller takes ownership and must delete it — confirm.
     */
143  ClusterTreeNode* deepCopyTree() const;
144 
    /** @return read-only pointer to the root of the tree (non-owning). Presumably NULL when untrained — confirm. */
150  const ClusterTreeNode* getTree() const;
151 
    /**
     * @return the label of the most recently predicted cluster.
     * NOTE(review): a getPredictedClusterLabel() const is also defined in Clusterer; this non-override
     * redeclaration hides it for calls made through a ClusterTree — check whether the base is virtual.
     */
158  UINT getPredictedClusterLabel() const;
159 
    /** Accessor for the minRMSErrorPerNode setting. */
166  Float getMinRMSErrorPerNode() const;
167 
    /** Accessor for the trainingMode setting. */
173  Tree::TrainingMode getTrainingMode() const;
174 
    /** Accessor for the numSplittingSteps setting. */
183  UINT getNumSplittingSteps() const;
184 
    /** Accessor for the minNumSamplesPerNode setting. */
191  UINT getMinNumSamplesPerNode() const;
192 
    /** Accessor for the maxDepth setting. */
198  UINT getMaxDepth() const;
199 
    /** @return the ID of the node reached by the most recent prediction (presumably; TODO confirm). */
205  UINT getPredictedNodeID() const;
206 
    /** Accessor for the removeFeaturesAtEachSplit setting. */
212  bool getRemoveFeaturesAtEachSplit() const;
213 
    /** Sets the training mode. @return true if the value was accepted (rejection rules are in the .cpp). */
220  bool setTrainingMode(const Tree::TrainingMode trainingMode);
221 
    /** Sets numSplittingSteps. @return true if the value was accepted (rejection rules are in the .cpp). */
234  bool setNumSplittingSteps(const UINT numSplittingSteps);
235 
    /** Sets minNumSamplesPerNode. @return true if the value was accepted (rejection rules are in the .cpp). */
244  bool setMinNumSamplesPerNode(const UINT minNumSamplesPerNode);
245 
    /** Sets maxDepth. @return true if the value was accepted (rejection rules are in the .cpp). */
253  bool setMaxDepth(const UINT maxDepth);
254 
    /** Sets removeFeaturesAtEachSplit. @return true if the value was accepted. */
263  bool setRemoveFeaturesAtEachSplit(const bool removeFeaturesAtEachSplit);
264 
    /** Sets minRMSErrorPerNode. @return true if the value was accepted (rejection rules are in the .cpp). */
270  bool setMinRMSErrorPerNode(const Float minRMSErrorPerNode);
271 
272  //Redeclaring saveModelToFile/loadModelFromFile above hides every other base-class overload of those
     //names (C++ name hiding); these using-declarations re-expose the MLBase overloads and silence
     //"hidden virtual function" warnings. train/predict are re-exposed the same way, presumably because
     //an intermediate base redeclares some of their overloads (TODO confirm in Clusterer.h).
273  using MLBase::saveModelToFile;
274  using MLBase::loadModelFromFile;
275  using MLBase::train;
276  using MLBase::predict;
277 
    /**
     * @return this class's identifier string (presumably the value of the private static `id`).
     * NOTE(review): this static function hides the non-static GRTBase::getId() const when called
     * on a ClusterTree; both return std::string, so confirm they yield the same value.
     */
283  static std::string getId();
284 
285 protected:
     // Root of the tree. Declared as Node* although deepCopyTree/getTree expose ClusterTreeNode*,
     // so the .cpp presumably casts. Ownership appears to be managed by the copy ctor, assignment
     // and destructor declared above — confirm it is deleted exactly once.
286  Node *tree;
287  UINT minNumSamplesPerNode;
288  UINT maxDepth;
289  UINT numSplittingSteps;
290  bool removeFeaturesAtEachSplit;
291  Tree::TrainingMode trainingMode;
292  Float minRMSErrorPerNode;
293 
     // Recursive tree construction. `features` is taken by value, so each call gets its own copy —
     // presumably so a split can drop a feature (removeFeaturesAtEachSplit) without affecting siblings.
     // clusterLabel is an in/out counter shared across the recursion (passed by reference).
294  ClusterTreeNode* buildTree( const MatrixFloat &trainingData, ClusterTreeNode *parent, Vector< UINT > features, UINT &clusterLabel, UINT nodeID );
     // Finds a split over `features`, reporting it via featureIndex/threshold/minError. Presumably
     // dispatches on trainingMode to one of the two strategy functions below (TODO confirm).
295  bool computeBestSplit( const MatrixFloat &trainingData, const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError );
296  bool computeBestSplitBestIterativeSplit( const MatrixFloat &trainingData, const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError );
297  bool computeBestSplitBestRandomSplit( const MatrixFloat &trainingData, const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError );
298 
299 private:
     // Static registration object; presumably registers ClusterTree with the clusterer factory at startup.
300  static RegisterClustererModule< ClusterTree > registerModule;
     // Identifier string for this class (defined in the .cpp).
301  static const std::string id;
302 };
303 
304 GRT_END_NAMESPACE
305 
306 #endif //GRT_CLUSTER_TREE_HEADER
307 
std::string getId() const
Definition: GRTBase.cpp:85
virtual bool predict(VectorFloat inputVector)
Definition: MLBase.cpp:135
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:137
Definition: Node.h:37
virtual bool clear() override
Definition: Clusterer.cpp:144
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:107
virtual bool deepCopyFrom(const Clusterer *clusterer)
Definition: Clusterer.h:59
virtual bool train_(MatrixFloat &trainingData) override
Definition: Clusterer.cpp:116
This file implements a ClusterTreeNode, which is a specific type of node used for a ClusterTree...
virtual bool print() const
Definition: MLBase.cpp:165
UINT getPredictedClusterLabel() const
Definition: Clusterer.cpp:236