GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
RegressionTreeNode.h
Go to the documentation of this file.
1 
30 #ifndef GRT_REGRESSION_TREE_NODE_HEADER
31 #define GRT_REGRESSION_TREE_NODE_HEADER
32 
#include <limits>
#include "../../CoreAlgorithms/Tree/Node.h"
34 
35 GRT_BEGIN_NAMESPACE
36 
37 class RegressionTreeNode : public Node{
38 public:
43  nodeType = "RegressionTreeNode";
44  parent = NULL;
45  leftChild = NULL;
46  rightChild = NULL;
47  clear();
48  }
49 
54  clear();
55  }
56 
67  virtual bool predict_(VectorFloat &x) override{
68  if( x[ featureIndex ] >= threshold ) return true;
69  return false;
70  }
71 
85  virtual bool predict_(VectorFloat &x,VectorFloat &y) override{
86 
87  if( isLeafNode ){
88  y = this->regressionData;
89  return true;
90  }
91 
92  if( leftChild == NULL && rightChild == NULL )
93  return false;
94 
95  if( predict_( x ) ){
96  if( rightChild )
97  return rightChild->predict_( x, y );
98  }else{
99  if( leftChild )
100  return leftChild->predict_( x, y );
101  }
102 
103  return false;
104  }
105 
112  virtual bool clear() override{
113 
114  //Call the base class clear function
115  Node::clear();
116 
117  nodeSize = 0;
118  featureIndex = 0;
119  threshold = 0;
120  regressionData.clear();
121 
122  return true;
123  }
124 
131  virtual bool print() const override{
132 
133  std::string tab = "";
134  for(UINT i=0; i<depth; i++) tab += "\t";
135 
136  std::cout << tab << "depth: " << depth << " nodeSize: " << nodeSize << " featureIndex: " << featureIndex << " threshold " << threshold << " isLeafNode: " << isLeafNode << std::endl;
137  std::cout << tab << "RegressionData: ";
138  for(UINT i=0; i<regressionData.size(); i++){
139  std::cout << regressionData[i] << "\t";
140  }
141  std::cout << std::endl;
142 
143  if( leftChild != NULL ){
144  std::cout << tab << "LeftChild: " << std::endl;
145  leftChild->print();
146  }
147 
148  if( rightChild != NULL ){
149  std::cout << tab << "RightChild: " << std::endl;
150  rightChild->print();
151  }
152 
153  return true;
154  }
155 
162  virtual Node* deepCopy() const override{
163 
165 
166  if( node == NULL ){
167  return NULL;
168  }
169 
170  //Copy this node into the node
171  node->depth = this->depth;
172  node->isLeafNode = this->isLeafNode;
173  node->nodeSize = this->nodeSize;
174  node->featureIndex = this->featureIndex;
175  node->threshold = this->threshold;
176  node->regressionData = this->regressionData;
177 
178  //Recursively deep copy the left child
179  if( this->leftChild ){
180  node->leftChild = this->leftChild->deepCopy();
181  node->leftChild->setParent( node );
182  }
183 
184  //Recursively deep copy the right child
185  if( this->rightChild ){
186  node->rightChild = this->rightChild->deepCopy();
187  node->rightChild->setParent( node );
188  }
189 
190  return dynamic_cast< Node* >( node );
191  }
192 
193  RegressionTreeNode* deepCopyTree() const{
194  RegressionTreeNode *node = dynamic_cast< RegressionTreeNode* >( deepCopy() );
195  return node;
196  }
197 
207  bool set(const UINT nodeSize,const UINT featureIndex,const Float threshold,const VectorFloat &regressionData){
208  this->nodeSize = nodeSize;
209  this->featureIndex = featureIndex;
210  this->threshold = threshold;
211  this->regressionData = regressionData;
212  return true;
213  }
214 
215 protected:
223  virtual bool saveParametersToFile( std::fstream &file ) const override{
224 
225  if(!file.is_open())
226  {
227  errorLog << "saveParametersToFile(fstream &file) - File is not open!" << std::endl;
228  return false;
229  }
230 
231  //Save the custom ClusterTreeNode parameters
232  file << "NodeSize: " << nodeSize << std::endl;
233  file << "FeatureIndex: " << featureIndex << std::endl;
234  file << "Threshold: " << threshold << std::endl;
235  file << "RegressionDataSize: " << regressionData.getSize() << std::endl;
236  file << "RegressionData: ";
237  for(unsigned int i=0; i<regressionData.getSize(); i++){
238  file << regressionData[i] << " ";
239  }
240  file << std::endl;
241 
242  return true;
243  }
244 
251  virtual bool loadParametersFromFile( std::fstream &file ) override{
252 
253  if(!file.is_open())
254  {
255  errorLog << "load(fstream &file) - File is not open!" << std::endl;
256  return false;
257  }
258 
259  std::string word;
260  UINT regressionDataSize = 0;
261 
262  //Load the custom ClusterTreeNode Parameters
263  file >> word;
264  if( word != "NodeSize:" ){
265  errorLog << "loadParametersFromFile(fstream &file) - Failed to find NodeSize header!" << std::endl;
266  return false;
267  }
268  file >> nodeSize;
269 
270  file >> word;
271  if( word != "FeatureIndex:" ){
272  errorLog << "loadParametersFromFile(fstream &file) - Failed to find FeatureIndex header!" << std::endl;
273  return false;
274  }
275  file >> featureIndex;
276 
277  file >> word;
278  if( word != "Threshold:" ){
279  errorLog << "loadParametersFromFile(fstream &file) - Failed to find Threshold header!" << std::endl;
280  return false;
281  }
282  file >> threshold;
283 
284  file >> word;
285  if( word != "RegressionDataSize:" ){
286  errorLog << "loadParametersFromFile(fstream &file) - Failed to find RegressionDataSize header!" << std::endl;
287  return false;
288  }
289  file >> regressionDataSize;
290  regressionData.resize(regressionDataSize);
291 
292  file >> word;
293  if( word != "RegressionData:" ){
294  errorLog << "loadParametersFromFile(fstream &file) - Failed to find RegressionData header!" << std::endl;
295  return false;
296  }
297  for(unsigned int i=0; i<regressionData.getSize(); i++){
298  file >> regressionData[i];
299  }
300 
301  return true;
302  }
303 
304  UINT nodeSize;
305  UINT featureIndex;
306  Float threshold;
307  VectorFloat regressionData;
308 
309 private:
310  static RegisterNode< RegressionTreeNode > registerModule;
311 };
312 
313 GRT_END_NAMESPACE
314 
315 #endif //GRT_REGRESSION_TREE_NODE_HEADER
316 
virtual Node * deepCopy() const override
virtual bool predict_(VectorFloat &x) override
Definition: Node.h:37
virtual bool loadParametersFromFile(std::fstream &file) override
virtual bool predict_(VectorFloat &x) override
Definition: Node.cpp:56
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
virtual bool print() const override
UINT getSize() const
Definition: Vector.h:201
virtual bool saveParametersToFile(std::fstream &file) const override
virtual bool print() const override
Definition: Node.cpp:105
virtual bool predict_(VectorFloat &x, VectorFloat &y) override
virtual bool clear() override
virtual bool clear() override
Definition: Node.cpp:66
virtual Node * deepCopy() const
Definition: Node.cpp:272