GestureRecognitionToolkit Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
DecisionTreeNode.h
Go to the documentation of this file.
1 
27 #ifndef GRT_DECISION_TREE_NODE_HEADER
28 #define GRT_DECISION_TREE_NODE_HEADER
29 
30 #include "../../CoreAlgorithms/Tree/Node.h"
31 #include "../../CoreAlgorithms/Tree/Tree.h"
32 #include "../../DataStructures/ClassificationData.h"
33 
34 GRT_BEGIN_NAMESPACE
35 
41 class GRT_API DecisionTreeNode : public Node{
42 public:
46  DecisionTreeNode( const std::string id = "DecisionTreeNode" );
47 
51  DecisionTreeNode(const DecisionTreeNode &rhs) = delete;
52 
56  virtual ~DecisionTreeNode();
57 
61  DecisionTreeNode& operator=(const DecisionTreeNode &rhs) = delete;
62 
72  virtual bool predict_(VectorFloat &x,VectorFloat &classLikelihoods) override;
73 
88  virtual bool computeBestSplit( const UINT &trainingMode, const UINT &numSplittingSteps,const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError );
89 
96  virtual bool clear() override;
97 
105  virtual bool getModel( std::ostream &stream ) const override;
106 
113  virtual Node* deepCopy() const override;
114 
120  UINT getNodeSize() const;
121 
127  UINT getNumClasses() const;
128 
134  VectorFloat getClassProbabilities() const;
135 
143  bool setLeafNode( const UINT nodeSize, const VectorFloat &classProbabilities );
144 
151  bool setNodeSize(const UINT nodeSize);
152 
159  bool setClassProbabilities(const VectorFloat &classProbabilities);
160 
161  static UINT getClassLabelIndexValue(UINT classLabel,const Vector< UINT > &classLabels);
162 
163  using Node::predict;
164  using Node::predict_;
165 
166 protected:
167  virtual bool computeBestSplitBestIterativeSplit( const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
168 
169  errorLog << __GRT_LOG__ << " Base class not overwritten!" << std::endl;
170 
171  return false;
172  }
173 
174  virtual bool computeBestSplitBestRandomSplit( const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
175 
176  errorLog << __GRT_LOG__ << " Base class not overwritten!" << std::endl;
177 
178  return false;
179  }
180 
188  virtual bool saveParametersToFile( std::fstream &file ) const override{
189 
190  if( !file.is_open() )
191  {
192  errorLog << __GRT_LOG__ << " File is not open!" << std::endl;
193  return false;
194  }
195 
196  //Save the custom DecisionTreeNode parameters
197  file << "NodeSize: " << nodeSize << std::endl;
198  file << "NumClasses: " << classProbabilities.size() << std::endl;
199  file << "ClassProbabilities: ";
200  if( classProbabilities.size() > 0 ){
201  for(UINT i=0; i<classProbabilities.size(); i++){
202  file << classProbabilities[i];
203  if( i < classProbabilities.size()-1 ) file << "\t";
204  else file << std::endl;
205  }
206  }
207 
208  return true;
209  }
210 
217  virtual bool loadParametersFromFile( std::fstream &file ) override{
218 
219  if( !file.is_open() )
220  {
221  errorLog << __GRT_LOG__ << " File is not open!" << std::endl;
222  return false;
223  }
224 
225  classProbabilities.clear();
226 
227  std::string word;
228  UINT numClasses;
229 
230  //Load the custom DecisionTreeNode Parameters
231  file >> word;
232  if( word != "NodeSize:" ){
233  errorLog << __GRT_LOG__ << " Failed to find NodeSize header!" << std::endl;
234  return false;
235  }
236  file >> nodeSize;
237 
238  file >> word;
239  if( word != "NumClasses:" ){
240  errorLog << __GRT_LOG__ << " Failed to find NumClasses header!" << std::endl;
241  return false;
242  }
243  file >> numClasses;
244  if( numClasses > 0 )
245  classProbabilities.resize( numClasses );
246 
247  file >> word;
248  if( word != "ClassProbabilities:" ){
249  errorLog << __GRT_LOG__ << " Failed to find ClassProbabilities header!" << std::endl;
250  return false;
251  }
252  if( numClasses > 0 ){
253  for(UINT i=0; i<numClasses; i++){
254  file >> classProbabilities[i];
255  }
256  }
257 
258  return true;
259  }
260 
261  UINT nodeSize;
262  VectorFloat classProbabilities;
263 
264  static RegisterNode< DecisionTreeNode > registerModule;
265 };
266 
267 GRT_END_NAMESPACE
268 
269 #endif //GRT_DECISION_TREE_NODE_HEADER
270 
virtual bool predict(VectorFloat inputVector)
Definition: MLBase.cpp:135
virtual bool getModel(std::ostream &stream) const override
Definition: Node.cpp:116
Definition: Node.h:37
virtual bool predict_(VectorFloat &x) override
Definition: Node.cpp:56
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
virtual bool loadParametersFromFile(std::fstream &file) override
virtual bool saveParametersToFile(std::fstream &file) const override
virtual bool clear() override
Definition: Node.cpp:66
virtual Node * deepCopy() const
Definition: Node.cpp:272