GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, C++ machine learning library for real-time gesture recognition.
DecisionTreeThresholdNode.cpp
1 
2 #define GRT_DLL_EXPORTS
4 
5 GRT_BEGIN_NAMESPACE
6 
7 //Register the DecisionTreeThresholdNode module with the Node base class
8 RegisterNode< DecisionTreeThresholdNode > DecisionTreeThresholdNode::registerModule("DecisionTreeThresholdNode");
9 
11  nodeType = "DecisionTreeThresholdNode";
12  parent = NULL;
13  leftChild = NULL;
14  rightChild = NULL;
15  clear();
16 }
17 
19  clear();
20 }
21 
23 
24  if( x[ featureIndex ] >= threshold ) return true;
25 
26  return false;
27 }
28 
30 
31  //Call the base class clear function
33 
34  featureIndex = 0;
35  threshold = 0;
36 
37  return true;
38 }
39 
41 
42  std::ostringstream stream;
43 
44  if( getModel( stream ) ){
45  std::cout << stream.str();
46  return true;
47  }
48 
49  return false;
50 }
51 
52 bool DecisionTreeThresholdNode::getModel( std::ostream &stream ) const{
53 
54  std::string tab = "";
55  for(UINT i=0; i<depth; i++) tab += "\t";
56 
57  stream << tab << "depth: " << depth << " nodeSize: " << nodeSize << " featureIndex: " << featureIndex << " threshold " << threshold << " isLeafNode: " << isLeafNode << std::endl;
58  stream << tab << "ClassProbabilities: ";
59  for(UINT i=0; i<classProbabilities.size(); i++){
60  stream << classProbabilities[i] << "\t";
61  }
62  stream << std::endl;
63 
64  if( leftChild != NULL ){
65  stream << tab << "LeftChild: " << std::endl;
66  leftChild->getModel( stream );
67  }
68 
69  if( rightChild != NULL ){
70  stream << tab << "RightChild: " << std::endl;
71  rightChild->getModel( stream );
72  }
73 
74  return true;
75 }
76 
78 
80 
81  if( node == NULL ){
82  return NULL;
83  }
84 
85  //Copy this node into the node
86  node->depth = depth;
87  node->isLeafNode = isLeafNode;
88  node->nodeID = nodeID;
89  node->predictedNodeID = predictedNodeID;
90  node->nodeSize = nodeSize;
91  node->featureIndex = featureIndex;
92  node->threshold = threshold;
93  node->classProbabilities = classProbabilities;
94 
95  //Recursively deep copy the left child
96  if( leftChild ){
97  node->leftChild = leftChild->deepCopyNode();
98  node->leftChild->setParent( node );
99  }
100 
101  //Recursively deep copy the right child
102  if( rightChild ){
103  node->rightChild = rightChild->deepCopyNode();
104  node->rightChild->setParent( node );
105  }
106 
107  return dynamic_cast< Node* >( node );
108 }
109 
111  return dynamic_cast< DecisionTreeThresholdNode* >( deepCopyNode() );
112 }
113 
115  return featureIndex;
116 }
117 
119  return threshold;
120 }
121 
122 bool DecisionTreeThresholdNode::set(const UINT nodeSize,const UINT featureIndex,const Float threshold,const VectorFloat &classProbabilities){
123  this->nodeSize = nodeSize;
124  this->featureIndex = featureIndex;
125  this->threshold = threshold;
126  this->classProbabilities = classProbabilities;
127  return true;
128 }
129 
130 bool DecisionTreeThresholdNode::computeBestSpiltBestIterativeSpilt( const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
131 
132  const UINT M = trainingData.getNumSamples();
133  const UINT N = features.getSize();
134  const UINT K = classLabels.getSize();
135 
136  if( N == 0 ) return false;
137 
139  UINT bestFeatureIndex = 0;
140  Float bestThreshold = 0;
141  Float error = 0;
142  Float minRange = 0;
143  Float maxRange = 0;
144  Float step = 0;
145  Float giniIndexL = 0;
146  Float giniIndexR = 0;
147  Float weightL = 0;
148  Float weightR = 0;
149  Vector< UINT > groupIndex(M);
150  VectorFloat groupCounter(2,0);
151  Vector< MinMax > ranges = trainingData.getRanges();
152 
153  MatrixFloat classProbabilities(K,2);
154 
155  //Loop over each feature and try and find the best split point
156  for(UINT n=0; n<N; n++){
157  minRange = ranges[n].minValue;
158  maxRange = ranges[n].maxValue;
159  step = (maxRange-minRange)/Float(numSplittingSteps);
160  threshold = minRange;
161  featureIndex = features[n];
162  while( threshold <= maxRange ){
163 
164  //Iterate over each sample and work out if it should be in the lhs (0) or rhs (1) group
165  groupCounter[0] = groupCounter[1] = 0;
166  classProbabilities.setAllValues(0);
167  for(UINT i=0; i<M; i++){
168  groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
169  groupCounter[ groupIndex[i] ]++;
170  classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
171  }
172 
173  //Compute the class probabilities for the lhs group and rhs group
174  for(UINT k=0; k<K; k++){
175  classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
176  classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
177  }
178 
179  //Compute the Gini index for the lhs and rhs groups
180  giniIndexL = giniIndexR = 0;
181  for(UINT k=0; k<K; k++){
182  giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
183  giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
184  }
185  weightL = groupCounter[0]/M;
186  weightR = groupCounter[1]/M;
187  error = (giniIndexL*weightL) + (giniIndexR*weightR);
188 
189  //Store the best threshold and feature index
190  if( error < minError ){
191  minError = error;
192  bestThreshold = threshold;
193  bestFeatureIndex = featureIndex;
194  }
195 
196  //Update the threshold
197  threshold += step;
198  }
199  }
200 
201  //Set the best feature index that will be returned to the DecisionTree that called this function
202  featureIndex = bestFeatureIndex;
203 
204  //Store the node size, feature index, best threshold and class probabilities for this node
205  set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));
206 
207  return true;
208 }
209 
210 bool DecisionTreeThresholdNode::computeBestSpiltBestRandomSpilt( const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
211 
212  const UINT M = trainingData.getNumSamples();
213  const UINT N = (UINT)features.size();
214  const UINT K = (UINT)classLabels.size();
215 
216  if( N == 0 ) return false;
217 
219  UINT bestFeatureIndex = 0;
220  Float bestThreshold = 0;
221  Float error = 0;
222  Float giniIndexL = 0;
223  Float giniIndexR = 0;
224  Float weightL = 0;
225  Float weightR = 0;
226  Random random;
227  Vector< UINT > groupIndex(M);
228  VectorFloat groupCounter(2,0);
229  Vector< MinMax > ranges = trainingData.getRanges();
230 
231  MatrixFloat classProbabilities(K,2);
232 
233  //Loop over each feature and try and find the best split point
234  for(UINT n=0; n<N; n++){
235  featureIndex = features[n];
236  for(UINT m=0; m<numSplittingSteps; m++){
237  //Randomly choose the threshold
238  threshold = random.getRandomNumberUniform(ranges[n].minValue,ranges[n].maxValue);
239 
240  //Iterate over each sample and work out if it should be in the lhs (0) or rhs (1) group
241  groupCounter[0] = groupCounter[1] = 0;
242  classProbabilities.setAllValues(0);
243  for(UINT i=0; i<M; i++){
244  groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
245  groupCounter[ groupIndex[i] ]++;
246  classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
247  }
248 
249  //Compute the class probabilities for the lhs group and rhs group
250  for(UINT k=0; k<K; k++){
251  classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
252  classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
253  }
254 
255  //Compute the Gini index for the lhs and rhs groups
256  giniIndexL = giniIndexR = 0;
257  for(UINT k=0; k<K; k++){
258  giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
259  giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
260  }
261  weightL = groupCounter[0]/M;
262  weightR = groupCounter[1]/M;
263  error = (giniIndexL*weightL) + (giniIndexR*weightR);
264 
265  //Store the best threshold and feature index
266  if( error < minError ){
267  minError = error;
268  bestThreshold = threshold;
269  bestFeatureIndex = featureIndex;
270  }
271  }
272  }
273 
274  //Set the best feature index that will be returned to the DecisionTree that called this function
275  featureIndex = bestFeatureIndex;
276 
277  //Store the node size, feature index, best threshold and class probabilities for this node
278  set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));
279 
280  return true;
281 }
282 
283 bool DecisionTreeThresholdNode::saveParametersToFile( std::fstream &file ) const{
284 
285  if(!file.is_open())
286  {
287  errorLog << "saveParametersToFile(fstream &file) - File is not open!" << std::endl;
288  return false;
289  }
290 
291  //Save the DecisionTreeNode parameters
293  errorLog << "saveParametersToFile(fstream &file) - Failed to save DecisionTreeNode parameters to file!" << std::endl;
294  return false;
295  }
296 
297  //Save the custom DecisionTreeThresholdNode parameters
298  file << "FeatureIndex: " << featureIndex << std::endl;
299  file << "Threshold: " << threshold << std::endl;
300 
301  return true;
302 }
303 
305 
306  if(!file.is_open())
307  {
308  errorLog << "loadParametersFromFile(fstream &file) - File is not open!" << std::endl;
309  return false;
310  }
311 
312  //Load the DecisionTreeNode parameters
314  errorLog << "loadParametersFromFile(fstream &file) - Failed to load DecisionTreeNode parameters from file!" << std::endl;
315  return false;
316  }
317 
318  std::string word;
319 
320  //Load the custom DecisionTreeThresholdNode Parameters
321  file >> word;
322  if( word != "FeatureIndex:" ){
323  errorLog << "loadParametersFromFile(fstream &file) - Failed to find FeatureIndex header!" << std::endl;
324  return false;
325  }
326  file >> featureIndex;
327 
328  file >> word;
329  if( word != "Threshold:" ){
330  errorLog << "loadParametersFromFile(fstream &file) - Failed to find Threshold header!" << std::endl;
331  return false;
332  }
333  file >> threshold;
334 
335  return true;
336 }
337 
338 GRT_END_NAMESPACE
339 
virtual bool clear()
Definition: Node.h:37
DecisionTreeThresholdNode * deepCopy() const
Definition: Random.h:40
virtual bool getModel(std::ostream &stream) const
Definition: Node.cpp:120
UINT getSize() const
Definition: Vector.h:191
virtual Node * deepCopyNode() const
virtual bool saveParametersToFile(std::fstream &file) const
UINT getNumSamples() const
virtual Node * deepCopyNode() const
Definition: Node.cpp:276
virtual bool loadParametersFromFile(std::fstream &file)
virtual bool getModel(std::ostream &stream) const
virtual bool saveParametersToFile(std::fstream &file) const
virtual bool predict(const VectorFloat &x)
Vector< MinMax > getRanges() const
Float getRandomNumberUniform(Float minRange=0.0, Float maxRange=1.0)
Definition: Random.h:198
bool set(const UINT nodeSize, const UINT featureIndex, const Float threshold, const VectorFloat &classProbabilities)
virtual bool loadParametersFromFile(std::fstream &file)
This file implements a DecisionTreeThresholdNode, which is a specific type of node used for a DecisionTree.