GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
DecisionTreeNode.cpp
1 
2 #define GRT_DLL_EXPORTS
3 #include "DecisionTreeNode.h"
4 
5 using namespace GRT;
6 
7 //Register the DecisionTreeNode with the Node base class
8 RegisterNode< DecisionTreeNode > DecisionTreeNode::registerModule("DecisionTreeNode");
9 
10 DecisionTreeNode::DecisionTreeNode( const std::string id ) : Node(id){
11  clear();
12 }
13 
15  clear();
16 }
17 
19 
20  predictedNodeID = 0;
21 
22  if( isLeafNode ){
23  classLikelihoods = classProbabilities;
24  predictedNodeID = nodeID;
25  return true;
26  }
27 
28  if( leftChild == NULL && rightChild == NULL ){
29  classLikelihoods = classProbabilities;
30  predictedNodeID = nodeID;
31  warningLog << __GRT_LOG__ << " Left and right children are NULL but node not marked as leaf node!" << std::endl;
32  return false;
33  }
34 
35  if( predict_( x ) ){
36  if( rightChild ){
37  if( rightChild->predict_( x, classLikelihoods ) ){
38  predictedNodeID = rightChild->getPredictedNodeID();
39  return true;
40  }
41  warningLog << __GRT_LOG__ << " Right child failed prediction!" << std::endl;
42  return false;
43  }
44  classLikelihoods = classProbabilities;
45  predictedNodeID = nodeID;
46  return true;
47  }else{
48  if( leftChild ){
49  if( leftChild->predict_( x, classLikelihoods ) ){
50  predictedNodeID = leftChild->getPredictedNodeID();
51  return true;
52  }
53  warningLog << __GRT_LOG__ << " Left child failed prediction!" << std::endl;
54  return false;
55  }
56  classLikelihoods = classProbabilities;
57  predictedNodeID = nodeID;
58  return true;
59  }
60 
61  return false;
62 }
63 
64 bool DecisionTreeNode::computeBestSplit( const UINT &trainingMode, const UINT &numSplittingSteps,const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
65 
66  switch( trainingMode ){
67  case Tree::BEST_ITERATIVE_SPILT:
68  return computeBestSplitBestIterativeSplit( numSplittingSteps, trainingData, features, classLabels, featureIndex, minError );
69  break;
70  case Tree::BEST_RANDOM_SPLIT:
71  return computeBestSplitBestRandomSplit( numSplittingSteps, trainingData, features, classLabels, featureIndex, minError );
72  break;
73  default:
74  errorLog << __GRT_LOG__ << " Uknown trainingMode!" << std::endl;
75  return false;
76  break;
77  }
78 
79  return false;
80 }
81 
83 
84  //Call the base class clear function
85  Node::clear();
86 
87  nodeSize = 0;
88  classProbabilities.clear();
89 
90  return true;
91 }
92 
93 bool DecisionTreeNode::getModel( std::ostream &stream ) const{
94 
95  std::string tab = "";
96  for(UINT i=0; i<depth; i++) tab += "\t";
97 
98  stream << tab << "depth: " << depth << " nodeSize: " << nodeSize << " isLeafNode: " << isLeafNode << std::endl;
99  stream << tab << "ClassProbabilities: ";
100  for(UINT i=0; i<classProbabilities.size(); i++){
101  stream << classProbabilities[i] << "\t";
102  }
103  stream << std::endl;
104 
105  if( leftChild != NULL ){
106  stream << tab << "LeftChild: " << std::endl;
107  leftChild->getModel( stream );
108  }
109 
110  if( rightChild != NULL ){
111  stream << tab << "RightChild: " << std::endl;
112  rightChild->getModel( stream );
113  }
114 
115  return true;
116 }
117 
119 
120  DecisionTreeNode *node = dynamic_cast< DecisionTreeNode* >( DecisionTreeNode::createInstanceFromString( nodeType ) );
121 
122  if( node == NULL ){
123  return NULL;
124  }
125 
126  //Copy this node into the node
127  node->depth = depth;
128  node->isLeafNode = isLeafNode;
129  node->nodeID = nodeID;
130  node->predictedNodeID = predictedNodeID;
131  node->nodeSize = nodeSize;
132  node->classProbabilities = classProbabilities;
133 
134  //Recursively deep copy the left child
135  if( leftChild ){
136  node->leftChild = leftChild->deepCopy();
137  node->leftChild->setParent( node );
138  }
139 
140  //Recursively deep copy the right child
141  if( rightChild ){
142  node->rightChild = rightChild->deepCopy();
143  node->rightChild->setParent( node );
144  }
145 
146  return dynamic_cast< Node* >(node);
147 }
148 
150  return nodeSize;
151 }
152 
154  return classProbabilities.getSize();
155 }
156 
158  return classProbabilities;
159 }
160 
161 bool DecisionTreeNode::setLeafNode( const UINT nodeSize, const VectorFloat &classProbabilities ){
162  this->nodeSize = nodeSize;
163  this->classProbabilities = classProbabilities;
164  this->isLeafNode = true;
165  return true;
166 }
167 
168 bool DecisionTreeNode::setNodeSize(const UINT nodeSize){
169  this->nodeSize = nodeSize;
170  return true;
171 }
172 
173 bool DecisionTreeNode::setClassProbabilities(const VectorFloat &classProbabilities){
174  this->classProbabilities = classProbabilities;
175  return true;
176 }
177 
178 UINT DecisionTreeNode::getClassLabelIndexValue(UINT classLabel,const Vector< UINT > &classLabels){
179  const UINT N = classLabels.getSize();
180  for(UINT i=0; i<N; i++){
181  if( classLabel == classLabels[i] )
182  return i;
183  }
184  return 0;
185 }
virtual bool computeBestSplit(const UINT &trainingMode, const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError)
virtual bool clear() override
virtual bool getModel(std::ostream &stream) const override
Definition: Node.cpp:116
Definition: Node.h:37
DecisionTreeNode(const std::string id="DecisionTreeNode")
virtual bool predict_(VectorFloat &x) override
Definition: Node.cpp:56
bool setClassProbabilities(const VectorFloat &classProbabilities)
virtual bool predict_(VectorFloat &x, VectorFloat &classLikelihoods) override
UINT getSize() const
Definition: Vector.h:201
virtual Node * deepCopy() const override
UINT getNumClasses() const
virtual bool getModel(std::ostream &stream) const override
UINT getPredictedNodeID() const
Definition: Node.cpp:312
bool setLeafNode(const UINT nodeSize, const VectorFloat &classProbabilities)
static Node * createInstanceFromString(std::string const &nodeType)
Definition: Node.cpp:29
UINT getNodeSize() const
virtual ~DecisionTreeNode()
virtual bool clear() override
Definition: Node.cpp:66
VectorFloat getClassProbabilities() const
bool setNodeSize(const UINT nodeSize)
virtual Node * deepCopy() const
Definition: Node.cpp:272