GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
DecisionTreeNode.h
Go to the documentation of this file.
1 
31 #ifndef GRT_DECISION_TREE_NODE_HEADER
32 #define GRT_DECISION_TREE_NODE_HEADER
33 
34 #include "../../CoreAlgorithms/Tree/Node.h"
35 #include "../../CoreAlgorithms/Tree/Tree.h"
36 #include "../../DataStructures/ClassificationData.h"
37 
38 GRT_BEGIN_NAMESPACE
39 
40 class GRT_API DecisionTreeNode : public Node{
41 public:
46 
50  virtual ~DecisionTreeNode();
51 
65  virtual bool predict(const VectorFloat &x,VectorFloat &classLikelihoods);
66 
81  virtual bool computeBestSpilt( const UINT &trainingMode, const UINT &numSplittingSteps,const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError );
82 
89  virtual bool clear();
90 
98  virtual bool getModel( std::ostream &stream ) const;
99 
106  virtual Node* deepCopyNode() const;
107 
114  DecisionTreeNode* deepCopy() const;
115 
121  UINT getNodeSize() const;
122 
128  UINT getNumClasses() const;
129 
135  VectorFloat getClassProbabilities() const;
136 
144  bool setLeafNode( const UINT nodeSize, const VectorFloat &classProbabilities );
145 
152  bool setNodeSize(const UINT nodeSize);
153 
160  bool setClassProbabilities(const VectorFloat &classProbabilities);
161 
162  static UINT getClassLabelIndexValue(UINT classLabel,const Vector< UINT > &classLabels);
163 
164  using Node::predict;
165 
166 protected:
167  virtual bool computeBestSpiltBestIterativeSpilt( const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
168 
169  errorLog << "computeBestSpiltBestIterativeSpilt(...) - Base class not overwritten!" << std::endl;
170 
171  return false;
172  }
173 
174  virtual bool computeBestSpiltBestRandomSpilt( const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
175 
176  errorLog << "computeBestSpiltBestRandomSpilt(...) - Base class not overwritten!" << std::endl;
177 
178  return false;
179  }
180 
188  virtual bool saveParametersToFile( std::fstream &file ) const{
189 
190  if( !file.is_open() )
191  {
192  errorLog << "saveParametersToFile(fstream &file) - File is not open!" << std::endl;
193  return false;
194  }
195 
196  //Save the custom DecisionTreeNode parameters
197  file << "NodeSize: " << nodeSize << std::endl;
198  file << "NumClasses: " << classProbabilities.size() << std::endl;
199  file << "ClassProbabilities: ";
200  if( classProbabilities.size() > 0 ){
201  for(UINT i=0; i<classProbabilities.size(); i++){
202  file << classProbabilities[i];
203  if( i < classProbabilities.size()-1 ) file << "\t";
204  else file << std::endl;
205  }
206  }
207 
208  return true;
209  }
210 
217  virtual bool loadParametersFromFile( std::fstream &file ){
218 
219  if( !file.is_open() )
220  {
221  errorLog << "loadParametersFromFile(fstream &file) - File is not open!" << std::endl;
222  return false;
223  }
224 
225  classProbabilities.clear();
226 
227  std::string word;
228  UINT numClasses;
229 
230  //Load the custom DecisionTreeNode Parameters
231  file >> word;
232  if( word != "NodeSize:" ){
233  errorLog << "loadParametersFromFile(fstream &file) - Failed to find NodeSize header!" << std::endl;
234  return false;
235  }
236  file >> nodeSize;
237 
238  file >> word;
239  if( word != "NumClasses:" ){
240  errorLog << "loadParametersFromFile(fstream &file) - Failed to find NumClasses header!" << std::endl;
241  return false;
242  }
243  file >> numClasses;
244  if( numClasses > 0 )
245  classProbabilities.resize( numClasses );
246 
247  file >> word;
248  if( word != "ClassProbabilities:" ){
249  errorLog << "loadParametersFromFile(fstream &file) - Failed to find ClassProbabilities header!" << std::endl;
250  return false;
251  }
252  if( numClasses > 0 ){
253  for(UINT i=0; i<numClasses; i++){
254  file >> classProbabilities[i];
255  }
256  }
257 
258  return true;
259  }
260 
261  UINT nodeSize;
262  VectorFloat classProbabilities;
263 
264  static RegisterNode< DecisionTreeNode > registerModule;
265 };
266 
267 GRT_END_NAMESPACE
268 
269 #endif //GRT_DECISION_TREE_NODE_HEADER
270 
Definition: Node.h:37
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
virtual bool getModel(std::ostream &stream) const
Definition: Node.cpp:120
virtual Node * deepCopyNode() const
Definition: Node.cpp:276
virtual bool clear()
Definition: Node.cpp:70
virtual bool saveParametersToFile(std::fstream &file) const
virtual bool loadParametersFromFile(std::fstream &file)
virtual bool predict(const VectorFloat &x)
Definition: Node.cpp:60