/*
 GRT DecisionTreeThresholdNode.cpp
 Part of the Gesture Recognition Toolkit (GRT), version 0.1.0 — a cross-platform,
 open-source C++ machine learning library for real-time gesture recognition.
*/
4 GRT_BEGIN_NAMESPACE
5 
6 //Register the DecisionTreeThresholdNode module with the Node base class
7 RegisterNode< DecisionTreeThresholdNode > DecisionTreeThresholdNode::registerModule("DecisionTreeThresholdNode");
8 
10  nodeType = "DecisionTreeThresholdNode";
11  parent = NULL;
12  leftChild = NULL;
13  rightChild = NULL;
14  clear();
15 }
16 
18  clear();
19 }
20 
22 
23  if( x[ featureIndex ] >= threshold ) return true;
24 
25  return false;
26 }
27 
29 
30  //Call the base class clear function
32 
33  featureIndex = 0;
34  threshold = 0;
35 
36  return true;
37 }
38 
40 
41  std::ostringstream stream;
42 
43  if( getModel( stream ) ){
44  std::cout << stream.str();
45  return true;
46  }
47 
48  return false;
49 }
50 
51 bool DecisionTreeThresholdNode::getModel( std::ostream &stream ) const{
52 
53  std::string tab = "";
54  for(UINT i=0; i<depth; i++) tab += "\t";
55 
56  stream << tab << "depth: " << depth << " nodeSize: " << nodeSize << " featureIndex: " << featureIndex << " threshold " << threshold << " isLeafNode: " << isLeafNode << std::endl;
57  stream << tab << "ClassProbabilities: ";
58  for(UINT i=0; i<classProbabilities.size(); i++){
59  stream << classProbabilities[i] << "\t";
60  }
61  stream << std::endl;
62 
63  if( leftChild != NULL ){
64  stream << tab << "LeftChild: " << std::endl;
65  leftChild->getModel( stream );
66  }
67 
68  if( rightChild != NULL ){
69  stream << tab << "RightChild: " << std::endl;
70  rightChild->getModel( stream );
71  }
72 
73  return true;
74 }
75 
77 
79 
80  if( node == NULL ){
81  return NULL;
82  }
83 
84  //Copy this node into the node
85  node->depth = depth;
86  node->isLeafNode = isLeafNode;
87  node->nodeID = nodeID;
88  node->predictedNodeID = predictedNodeID;
89  node->nodeSize = nodeSize;
90  node->featureIndex = featureIndex;
91  node->threshold = threshold;
92  node->classProbabilities = classProbabilities;
93 
94  //Recursively deep copy the left child
95  if( leftChild ){
96  node->leftChild = leftChild->deepCopyNode();
97  node->leftChild->setParent( node );
98  }
99 
100  //Recursively deep copy the right child
101  if( rightChild ){
102  node->rightChild = rightChild->deepCopyNode();
103  node->rightChild->setParent( node );
104  }
105 
106  return dynamic_cast< Node* >( node );
107 }
108 
110  return dynamic_cast< DecisionTreeThresholdNode* >( deepCopyNode() );
111 }
112 
114  return featureIndex;
115 }
116 
118  return threshold;
119 }
120 
121 bool DecisionTreeThresholdNode::set(const UINT nodeSize,const UINT featureIndex,const Float threshold,const VectorFloat &classProbabilities){
122  this->nodeSize = nodeSize;
123  this->featureIndex = featureIndex;
124  this->threshold = threshold;
125  this->classProbabilities = classProbabilities;
126  return true;
127 }
128 
129 bool DecisionTreeThresholdNode::computeBestSpiltBestIterativeSpilt( const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
130 
131  const UINT M = trainingData.getNumSamples();
132  const UINT N = features.getSize();
133  const UINT K = classLabels.getSize();
134 
135  if( N == 0 ) return false;
136 
138  UINT bestFeatureIndex = 0;
139  Float bestThreshold = 0;
140  Float error = 0;
141  Float minRange = 0;
142  Float maxRange = 0;
143  Float step = 0;
144  Float giniIndexL = 0;
145  Float giniIndexR = 0;
146  Float weightL = 0;
147  Float weightR = 0;
148  Vector< UINT > groupIndex(M);
149  VectorFloat groupCounter(2,0);
150  Vector< MinMax > ranges = trainingData.getRanges();
151 
152  MatrixFloat classProbabilities(K,2);
153 
154  //Loop over each feature and try and find the best split point
155  for(UINT n=0; n<N; n++){
156  minRange = ranges[n].minValue;
157  maxRange = ranges[n].maxValue;
158  step = (maxRange-minRange)/Float(numSplittingSteps);
159  threshold = minRange;
160  featureIndex = features[n];
161  while( threshold <= maxRange ){
162 
163  //Iterate over each sample and work out if it should be in the lhs (0) or rhs (1) group
164  groupCounter[0] = groupCounter[1] = 0;
165  classProbabilities.setAllValues(0);
166  for(UINT i=0; i<M; i++){
167  groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
168  groupCounter[ groupIndex[i] ]++;
169  classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
170  }
171 
172  //Compute the class probabilities for the lhs group and rhs group
173  for(UINT k=0; k<K; k++){
174  classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
175  classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
176  }
177 
178  //Compute the Gini index for the lhs and rhs groups
179  giniIndexL = giniIndexR = 0;
180  for(UINT k=0; k<K; k++){
181  giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
182  giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
183  }
184  weightL = groupCounter[0]/M;
185  weightR = groupCounter[1]/M;
186  error = (giniIndexL*weightL) + (giniIndexR*weightR);
187 
188  //Store the best threshold and feature index
189  if( error < minError ){
190  minError = error;
191  bestThreshold = threshold;
192  bestFeatureIndex = featureIndex;
193  }
194 
195  //Update the threshold
196  threshold += step;
197  }
198  }
199 
200  //Set the best feature index that will be returned to the DecisionTree that called this function
201  featureIndex = bestFeatureIndex;
202 
203  //Store the node size, feature index, best threshold and class probabilities for this node
204  set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));
205 
206  return true;
207 }
208 
209 bool DecisionTreeThresholdNode::computeBestSpiltBestRandomSpilt( const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
210 
211  const UINT M = trainingData.getNumSamples();
212  const UINT N = (UINT)features.size();
213  const UINT K = (UINT)classLabels.size();
214 
215  if( N == 0 ) return false;
216 
218  UINT bestFeatureIndex = 0;
219  Float bestThreshold = 0;
220  Float error = 0;
221  Float giniIndexL = 0;
222  Float giniIndexR = 0;
223  Float weightL = 0;
224  Float weightR = 0;
225  Random random;
226  Vector< UINT > groupIndex(M);
227  VectorFloat groupCounter(2,0);
228  Vector< MinMax > ranges = trainingData.getRanges();
229 
230  MatrixFloat classProbabilities(K,2);
231 
232  //Loop over each feature and try and find the best split point
233  for(UINT n=0; n<N; n++){
234  featureIndex = features[n];
235  for(UINT m=0; m<numSplittingSteps; m++){
236  //Randomly choose the threshold
237  threshold = random.getRandomNumberUniform(ranges[n].minValue,ranges[n].maxValue);
238 
239  //Iterate over each sample and work out if it should be in the lhs (0) or rhs (1) group
240  groupCounter[0] = groupCounter[1] = 0;
241  classProbabilities.setAllValues(0);
242  for(UINT i=0; i<M; i++){
243  groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
244  groupCounter[ groupIndex[i] ]++;
245  classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
246  }
247 
248  //Compute the class probabilities for the lhs group and rhs group
249  for(UINT k=0; k<K; k++){
250  classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
251  classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
252  }
253 
254  //Compute the Gini index for the lhs and rhs groups
255  giniIndexL = giniIndexR = 0;
256  for(UINT k=0; k<K; k++){
257  giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
258  giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
259  }
260  weightL = groupCounter[0]/M;
261  weightR = groupCounter[1]/M;
262  error = (giniIndexL*weightL) + (giniIndexR*weightR);
263 
264  //Store the best threshold and feature index
265  if( error < minError ){
266  minError = error;
267  bestThreshold = threshold;
268  bestFeatureIndex = featureIndex;
269  }
270  }
271  }
272 
273  //Set the best feature index that will be returned to the DecisionTree that called this function
274  featureIndex = bestFeatureIndex;
275 
276  //Store the node size, feature index, best threshold and class probabilities for this node
277  set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));
278 
279  return true;
280 }
281 
282 bool DecisionTreeThresholdNode::saveParametersToFile( std::fstream &file ) const{
283 
284  if(!file.is_open())
285  {
286  errorLog << "saveParametersToFile(fstream &file) - File is not open!" << std::endl;
287  return false;
288  }
289 
290  //Save the DecisionTreeNode parameters
292  errorLog << "saveParametersToFile(fstream &file) - Failed to save DecisionTreeNode parameters to file!" << std::endl;
293  return false;
294  }
295 
296  //Save the custom DecisionTreeThresholdNode parameters
297  file << "FeatureIndex: " << featureIndex << std::endl;
298  file << "Threshold: " << threshold << std::endl;
299 
300  return true;
301 }
302 
304 
305  if(!file.is_open())
306  {
307  errorLog << "loadParametersFromFile(fstream &file) - File is not open!" << std::endl;
308  return false;
309  }
310 
311  //Load the DecisionTreeNode parameters
313  errorLog << "loadParametersFromFile(fstream &file) - Failed to load DecisionTreeNode parameters from file!" << std::endl;
314  return false;
315  }
316 
317  std::string word;
318 
319  //Load the custom DecisionTreeThresholdNode Parameters
320  file >> word;
321  if( word != "FeatureIndex:" ){
322  errorLog << "loadParametersFromFile(fstream &file) - Failed to find FeatureIndex header!" << std::endl;
323  return false;
324  }
325  file >> featureIndex;
326 
327  file >> word;
328  if( word != "Threshold:" ){
329  errorLog << "loadParametersFromFile(fstream &file) - Failed to find Threshold header!" << std::endl;
330  return false;
331  }
332  file >> threshold;
333 
334  return true;
335 }
336 
337 GRT_END_NAMESPACE
338 
virtual bool clear()
Definition: Node.h:37
DecisionTreeThresholdNode * deepCopy() const
Definition: Random.h:40
virtual bool getModel(std::ostream &stream) const
Definition: Node.cpp:119
unsigned int getSize() const
Definition: Vector.h:193
virtual Node * deepCopyNode() const
virtual bool saveParametersToFile(std::fstream &file) const
UINT getNumSamples() const
virtual Node * deepCopyNode() const
Definition: Node.cpp:275
virtual bool loadParametersFromFile(std::fstream &file)
virtual bool getModel(std::ostream &stream) const
virtual bool saveParametersToFile(std::fstream &file) const
virtual bool predict(const VectorFloat &x)
Vector< MinMax > getRanges() const
Float getRandomNumberUniform(Float minRange=0.0, Float maxRange=1.0)
Definition: Random.h:198
bool set(const UINT nodeSize, const UINT featureIndex, const Float threshold, const VectorFloat &classProbabilities)
virtual bool loadParametersFromFile(std::fstream &file)
This file implements a DecisionTreeThresholdNode, which is a specific type of node used for a Decisio...