GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
DecisionTreeNode.cpp
1 
2 #define GRT_DLL_EXPORTS
3 #include "DecisionTreeNode.h"
4 
5 using namespace GRT;
6 
7 //Register the DecisionTreeNode with the Node base class
8 RegisterNode< DecisionTreeNode > DecisionTreeNode::registerModule("DecisionTreeNode");
9 
11  nodeType = "DecisionTreeNode";
12  parent = NULL;
13  leftChild = NULL;
14  rightChild = NULL;
15  clear();
16 }
17 
19  clear();
20 }
21 
22 bool DecisionTreeNode::predict(const VectorFloat &x,VectorFloat &classLikelihoods){
23 
24  predictedNodeID = 0;
25 
26  if( isLeafNode ){
27  classLikelihoods = classProbabilities;
28  predictedNodeID = nodeID;
29  return true;
30  }
31 
32  if( leftChild == NULL && rightChild == NULL )
33  return false;
34 
35  if( predict( x ) ){
36  if( rightChild ){
37  if( rightChild->predict( x, classLikelihoods ) ){
38  predictedNodeID = rightChild->getPredictedNodeID();
39  return true;
40  }
41  warningLog << "predict(const VectorFloat &x,VectorFloat &classLikelihoods) - Right child failed prediction!" << std::endl;
42  return false;
43  }
44  }else{
45  if( leftChild ){
46  if( leftChild->predict( x, classLikelihoods ) ){
47  predictedNodeID = leftChild->getPredictedNodeID();
48  return true;
49  }
50  warningLog << "predict(const VectorFloat &x,VectorFloat &classLikelihoods) - Left child failed prediction!" << std::endl;
51  return false;
52  }
53  }
54 
55  return false;
56 }
57 
58 bool DecisionTreeNode::computeBestSpilt( const UINT &trainingMode, const UINT &numSplittingSteps,const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
59 
60  switch( trainingMode ){
61  case Tree::BEST_ITERATIVE_SPILT:
62  return computeBestSpiltBestIterativeSpilt( numSplittingSteps, trainingData, features, classLabels, featureIndex, minError );
63  break;
64  case Tree::BEST_RANDOM_SPLIT:
65  return computeBestSpiltBestRandomSpilt( numSplittingSteps, trainingData, features, classLabels, featureIndex, minError );
66  break;
67  default:
68  errorLog << "computeBestSpilt(...) - Uknown trainingMode!" << std::endl;
69  return false;
70  break;
71  }
72 
73  return false;
74 }
75 
77 
78  //Call the base class clear function
79  Node::clear();
80 
81  nodeSize = 0;
82  classProbabilities.clear();
83 
84  return true;
85 }
86 
87 bool DecisionTreeNode::getModel( std::ostream &stream ) const{
88 
89  std::string tab = "";
90  for(UINT i=0; i<depth; i++) tab += "\t";
91 
92  stream << tab << "depth: " << depth << " nodeSize: " << nodeSize << " isLeafNode: " << isLeafNode << std::endl;
93  stream << tab << "ClassProbabilities: ";
94  for(UINT i=0; i<classProbabilities.size(); i++){
95  stream << classProbabilities[i] << "\t";
96  }
97  stream << std::endl;
98 
99  if( leftChild != NULL ){
100  stream << tab << "LeftChild: " << std::endl;
101  leftChild->getModel( stream );
102  }
103 
104  if( rightChild != NULL ){
105  stream << tab << "RightChild: " << std::endl;
106  rightChild->getModel( stream );
107  }
108 
109  return true;
110 }
111 
113 
114  DecisionTreeNode *node = dynamic_cast< DecisionTreeNode* >( DecisionTreeNode::createInstanceFromString( nodeType ) );
115 
116  if( node == NULL ){
117  return NULL;
118  }
119 
120  //Copy this node into the node
121  node->depth = depth;
122  node->isLeafNode = isLeafNode;
123  node->nodeID = nodeID;
124  node->predictedNodeID = predictedNodeID;
125  node->nodeSize = nodeSize;
126  node->classProbabilities = classProbabilities;
127 
128  //Recursively deep copy the left child
129  if( leftChild ){
130  node->leftChild = leftChild->deepCopyNode();
131  node->leftChild->setParent( node );
132  }
133 
134  //Recursively deep copy the right child
135  if( rightChild ){
136  node->rightChild = rightChild->deepCopyNode();
137  node->rightChild->setParent( node );
138  }
139 
140  return node;
141 }
142 
144  return dynamic_cast< DecisionTreeNode* >( deepCopyNode() );
145 }
146 
148  return nodeSize;
149 }
150 
152  return (UINT)classProbabilities.size();
153 }
154 
156  return classProbabilities;
157 }
158 
159 bool DecisionTreeNode::setLeafNode( const UINT nodeSize, const VectorFloat &classProbabilities ){
160  this->nodeSize = nodeSize;
161  this->classProbabilities = classProbabilities;
162  this->isLeafNode = true;
163  return true;
164 }
165 
166 bool DecisionTreeNode::setNodeSize(const UINT nodeSize){
167  this->nodeSize = nodeSize;
168  return true;
169 }
170 
171 bool DecisionTreeNode::setClassProbabilities(const VectorFloat &classProbabilities){
172  this->classProbabilities = classProbabilities;
173  return true;
174 }
175 
176 UINT DecisionTreeNode::getClassLabelIndexValue(UINT classLabel,const Vector< UINT > &classLabels){
177  const UINT N = classLabels.getSize();
178  for(UINT i=0; i<N; i++){
179  if( classLabel == classLabels[i] )
180  return i;
181  }
182  return 0;
183 }
virtual bool computeBestSpilt(const UINT &trainingMode, const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError)
virtual bool clear()
Definition: Node.h:37
Definition: DebugLog.cpp:24
bool setClassProbabilities(const VectorFloat &classProbabilities)
UINT getSize() const
Definition: Vector.h:191
UINT getNumClasses() const
virtual Node * deepCopyNode() const
Definition: Node.cpp:276
This file implements a DecisionTreeNode, which is a specific base node used for a DecisionTree...
virtual bool clear()
Definition: Node.cpp:70
virtual bool predict(const VectorFloat &x, VectorFloat &classLikelihoods)
bool setLeafNode(const UINT nodeSize, const VectorFloat &classProbabilities)
DecisionTreeNode * deepCopy() const
static Node * createInstanceFromString(std::string const &nodeType)
Definition: Node.cpp:29
UINT getNodeSize() const
virtual Node * deepCopyNode() const
virtual ~DecisionTreeNode()
virtual bool getModel(std::ostream &stream) const
VectorFloat getClassProbabilities() const
bool setNodeSize(const UINT nodeSize)