GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
ClusterTreeNode.h
Go to the documentation of this file.
1 
29 #ifndef GRT_CLUSTER_TREE_NODE_HEADER
30 #define GRT_CLUSTER_TREE_NODE_HEADER
31 
32 #include "../../CoreAlgorithms/Tree/Node.h"
33 
34 GRT_BEGIN_NAMESPACE
35 
36 class ClusterTreeNode : public Node{
37 public:
42  nodeType = "ClusterTreeNode";
43  parent = NULL;
44  leftChild = NULL;
45  rightChild = NULL;
46  clear();
47  }
48 
52  virtual ~ClusterTreeNode(){
53  clear();
54  }
55 
66  virtual bool predict_(VectorFloat &x) override{
67  if( x[ featureIndex ] >= threshold ) return true;
68  return false;
69  }
70 
84  virtual bool predict_(VectorFloat &x,VectorFloat &y) override{
85 
86  if( isLeafNode ){
87  if( y.size() != 1 ) y.resize( 1 );
88  y[0] = clusterLabel;
89  return true;
90  }
91 
92  if( leftChild == NULL && rightChild == NULL )
93  return false;
94 
95  if( predict_( x ) ){
96  if( rightChild )
97  return rightChild->predict_( x, y );
98  }else{
99  if( leftChild )
100  return leftChild->predict_( x, y );
101  }
102 
103  return false;
104  }
105 
112  virtual bool clear() override{
113 
114  //Call the base class clear function
115  Node::clear();
116 
117  nodeSize = 0;
118  featureIndex = 0;
119  threshold = 0;
120  clusterLabel = 0;
121 
122  return true;
123  }
124 
131  virtual bool print() const override{
132 
133  std::string tab = "";
134  for(UINT i=0; i<depth; i++) tab += "\t";
135 
136  std::cout << tab << "depth: " << depth << " nodeSize: " << nodeSize << " featureIndex: " << featureIndex << " threshold " << threshold << " isLeafNode: " << isLeafNode << std::endl;
137  std::cout << tab << "ClusterLabel: " << clusterLabel << std::endl;
138 
139  if( leftChild != NULL ){
140  std::cout << tab << "LeftChild: " << std::endl;
141  leftChild->print();
142  }
143 
144  if( rightChild != NULL ){
145  std::cout << tab << "RightChild: " << std::endl;
146  rightChild->print();
147  }
148 
149  return true;
150  }
151 
158  virtual Node* deepCopy() const override{
159 
160  ClusterTreeNode *node = new ClusterTreeNode;
161 
162  if( node == NULL ){
163  return NULL;
164  }
165 
166  //Copy this node into the node
167  node->depth = depth;
168  node->isLeafNode = isLeafNode;
169  node->nodeSize = nodeSize;
170  node->featureIndex = featureIndex;
171  node->threshold = threshold;
172  node->clusterLabel = clusterLabel;
173 
174  //Recursively deep copy the left child
175  if( leftChild ){
176  node->leftChild = leftChild->deepCopy();
177  node->leftChild->setParent( node );
178  }
179 
180  //Recursively deep copy the right child
181  if( rightChild ){
182  node->rightChild = rightChild->deepCopy();
183  node->rightChild->setParent( node );
184  }
185 
186  return dynamic_cast<Node*>(node);
187  }
188 
189  ClusterTreeNode* deepCopyTree() const{
190  ClusterTreeNode *node = dynamic_cast<ClusterTreeNode*>(deepCopy());
191  return node;
192  }
193 
199  UINT getNodeSize() const{
200  return nodeSize;
201  }
202 
208  UINT getFeatureIndex() const{
209  return featureIndex;
210  }
211 
217  Float getThreshold() const{
218  return threshold;
219  }
220 
226  UINT getClusterLabel() const{
227  return clusterLabel;
228  }
229 
239  bool set(const UINT nodeSize,const UINT featureIndex,const Float threshold,const UINT clusterLabel){
240  this->nodeSize = nodeSize;
241  this->featureIndex = featureIndex;
242  this->threshold = threshold;
243  this->clusterLabel = clusterLabel;
244  return true;
245  }
246 
247 protected:
255  virtual bool saveParametersToFile(std::fstream &file) const override{
256 
257  if(!file.is_open())
258  {
259  errorLog << "saveParametersToFile(fstream &file) - File is not open!" << std::endl;
260  return false;
261  }
262 
263  //Save the custom ClusterTreeNode parameters
264  file << "NodeSize: " << nodeSize << std::endl;
265  file << "FeatureIndex: " << featureIndex << std::endl;
266  file << "Threshold: " << threshold << std::endl;
267  file << "ClusterLabel: " << clusterLabel << std::endl;
268 
269  return true;
270  }
271 
278  virtual bool loadParametersFromFile(std::fstream &file) override{
279 
280  if(!file.is_open())
281  {
282  errorLog << "loadParametersFromFile(fstream &file) - File is not open!" << std::endl;
283  return false;
284  }
285 
286  std::string word;
287 
288  //Load the custom ClusterTreeNode Parameters
289  file >> word;
290  if( word != "NodeSize:" ){
291  errorLog << "loadParametersFromFile(fstream &file) - Failed to find NodeSize header!" << std::endl;
292  return false;
293  }
294  file >> nodeSize;
295 
296  file >> word;
297  if( word != "FeatureIndex:" ){
298  errorLog << "loadParametersFromFile(fstream &file) - Failed to find FeatureIndex header!" << std::endl;
299  return false;
300  }
301  file >> featureIndex;
302 
303  file >> word;
304  if( word != "Threshold:" ){
305  errorLog << "loadParametersFromFile(fstream &file) - Failed to find Threshold header!" << std::endl;
306  return false;
307  }
308  file >> threshold;
309 
310  file >> word;
311  if( word != "ClusterLabel:" ){
312  errorLog << "loadParametersFromFile(fstream &file) - Failed to find ClusterLabel header!" << std::endl;
313  return false;
314  }
315  file >> clusterLabel;
316 
317  return true;
318  }
319 
320  UINT clusterLabel;
321  UINT nodeSize;
322  UINT featureIndex;
323  Float threshold;
324 
325  static RegisterNode< ClusterTreeNode > registerModule;
326 };
327 
328 GRT_END_NAMESPACE
329 
330 #endif //GRT_CLUSTER_TREE_NODE_HEADER
331 
virtual bool clear() override
UINT getClusterLabel() const
UINT getFeatureIndex() const
Definition: Node.h:37
virtual bool predict_(VectorFloat &x) override
Definition: Node.cpp:56
virtual bool predict_(VectorFloat &x) override
Float getThreshold() const
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
virtual bool saveParametersToFile(std::fstream &file) const override
virtual bool loadParametersFromFile(std::fstream &file) override
virtual ~ClusterTreeNode()
virtual bool predict_(VectorFloat &x, VectorFloat &y) override
virtual Node * deepCopy() const override
virtual bool print() const override
Definition: Node.cpp:105
UINT getNodeSize() const
virtual bool clear() override
Definition: Node.cpp:66
virtual Node * deepCopy() const
Definition: Node.cpp:272
virtual bool print() const override