/*
GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, c++ machine learning library for real-time gesture recognition.
DecisionTreeThresholdNode.cpp
*/
#define GRT_DLL_EXPORTS
#include "DecisionTreeThresholdNode.h"
#include <limits>
5 GRT_BEGIN_NAMESPACE
6 
7 //Register the DecisionTreeThresholdNode module with the Node base class
8 RegisterNode< DecisionTreeThresholdNode > DecisionTreeThresholdNode::registerModule("DecisionTreeThresholdNode");
9 
11  clear();
12 }
13 
15  clear();
16 }
17 
19 
20  if( x[ featureIndex ] >= threshold ) return true;
21 
22  return false;
23 }
24 
26 
27  //Call the base class clear function
29 
30  featureIndex = 0;
31  threshold = 0;
32 
33  return true;
34 }
35 
37 
38  std::ostringstream stream;
39 
40  if( getModel( stream ) ){
41  std::cout << stream.str();
42  return true;
43  }
44 
45  return false;
46 }
47 
48 bool DecisionTreeThresholdNode::getModel( std::ostream &stream ) const{
49 
50  std::string tab = "";
51  for(UINT i=0; i<depth; i++) tab += "\t";
52 
53  stream << tab << "depth: " << depth << " nodeSize: " << nodeSize << " featureIndex: " << featureIndex << " threshold " << threshold << " isLeafNode: " << isLeafNode << std::endl;
54  stream << tab << "ClassProbabilities: ";
55  for(UINT i=0; i<classProbabilities.size(); i++){
56  stream << classProbabilities[i] << "\t";
57  }
58  stream << std::endl;
59 
60  if( leftChild != NULL ){
61  stream << tab << "LeftChild: " << std::endl;
62  leftChild->getModel( stream );
63  }
64 
65  if( rightChild != NULL ){
66  stream << tab << "RightChild: " << std::endl;
67  rightChild->getModel( stream );
68  }
69 
70  return true;
71 }
72 
74 
76 
77  if( node == NULL ){
78  return NULL;
79  }
80 
81  //Copy this node into the node
82  node->depth = depth;
83  node->isLeafNode = isLeafNode;
84  node->nodeID = nodeID;
85  node->predictedNodeID = predictedNodeID;
86  node->nodeSize = nodeSize;
87  node->featureIndex = featureIndex;
88  node->threshold = threshold;
89  node->classProbabilities = classProbabilities;
90 
91  //Recursively deep copy the left child
92  if( leftChild ){
93  node->leftChild = leftChild->deepCopy();
94  node->leftChild->setParent( node );
95  }
96 
97  //Recursively deep copy the right child
98  if( rightChild ){
99  node->rightChild = rightChild->deepCopy();
100  node->rightChild->setParent( node );
101  }
102 
103  return dynamic_cast< Node* >( node );
104 }
105 
107  return featureIndex;
108 }
109 
111  return threshold;
112 }
113 
114 bool DecisionTreeThresholdNode::set(const UINT nodeSize,const UINT featureIndex,const Float threshold,const VectorFloat &classProbabilities){
115  this->nodeSize = nodeSize;
116  this->featureIndex = featureIndex;
117  this->threshold = threshold;
118  this->classProbabilities = classProbabilities;
119  return true;
120 }
121 
122 bool DecisionTreeThresholdNode::computeBestSplitBestIterativeSplit( const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
123 
124  const UINT M = trainingData.getNumSamples();
125  const UINT N = features.getSize();
126  const UINT K = classLabels.getSize();
127 
128  if( N == 0 ) return false;
129 
131  UINT bestFeatureIndex = 0;
132  Float bestThreshold = 0;
133  Float error = 0;
134  Float minRange = 0;
135  Float maxRange = 0;
136  Float step = 0;
137  Float giniIndexL = 0;
138  Float giniIndexR = 0;
139  Float weightL = 0;
140  Float weightR = 0;
141  Vector< UINT > groupIndex(M);
142  VectorFloat groupCounter(2,0);
143  Vector< MinMax > ranges = trainingData.getRanges();
144 
145  MatrixFloat classProbabilities(K,2);
146 
147  //Loop over each feature and try and find the best split point
148  for(UINT n=0; n<N; n++){
149  minRange = ranges[n].minValue;
150  maxRange = ranges[n].maxValue;
151  step = (maxRange-minRange)/Float(numSplittingSteps);
152  threshold = minRange;
153  featureIndex = features[n];
154  while( threshold <= maxRange ){
155 
156  //Iterate over each sample and work out if it should be in the lhs (0) or rhs (1) group
157  groupCounter[0] = groupCounter[1] = 0;
158  classProbabilities.setAllValues(0);
159  for(UINT i=0; i<M; i++){
160  groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
161  groupCounter[ groupIndex[i] ]++;
162  classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
163  }
164 
165  //Compute the class probabilities for the lhs group and rhs group
166  for(UINT k=0; k<K; k++){
167  classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
168  classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
169  }
170 
171  //Compute the Gini index for the lhs and rhs groups
172  giniIndexL = giniIndexR = 0;
173  for(UINT k=0; k<K; k++){
174  giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
175  giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
176  }
177  weightL = groupCounter[0]/M;
178  weightR = groupCounter[1]/M;
179  error = (giniIndexL*weightL) + (giniIndexR*weightR);
180 
181  //Store the best threshold and feature index
182  if( error < minError ){
183  minError = error;
184  bestThreshold = threshold;
185  bestFeatureIndex = featureIndex;
186  }
187 
188  //Update the threshold
189  threshold += step;
190  }
191  }
192 
193  //Set the best feature index that will be returned to the DecisionTree that called this function
194  featureIndex = bestFeatureIndex;
195 
196  //Store the node size, feature index, best threshold and class probabilities for this node
197  set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));
198 
199  return true;
200 }
201 
202 bool DecisionTreeThresholdNode::computeBestSplitBestRandomSplit( const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
203 
204  const UINT M = trainingData.getNumSamples();
205  const UINT N = (UINT)features.size();
206  const UINT K = (UINT)classLabels.size();
207 
208  if( N == 0 ) return false;
209 
211  UINT bestFeatureIndex = 0;
212  Float bestThreshold = 0;
213  Float error = 0;
214  Float giniIndexL = 0;
215  Float giniIndexR = 0;
216  Float weightL = 0;
217  Float weightR = 0;
218  Random random;
219  Vector< UINT > groupIndex(M);
220  VectorFloat groupCounter(2,0);
221 
222  MatrixFloat classProbabilities(K,2);
223 
224  //Loop over each feature and try and find the best split point
225  UINT m,n;
226  const UINT numFeatures = features.getSize();
227  for(m=0; m<numSplittingSteps; m++){
228  //Chose a random feature
229  n = random.getRandomNumberInt(0,numFeatures);
230  featureIndex = features[n];
231 
232  //Randomly choose the threshold, the threshold is based on a randomly selected sample with some random scaling
233  threshold = trainingData[ random.getRandomNumberInt(0,M) ][ featureIndex ] * random.getRandomNumberUniform(0.8,1.2);
234 
235  //Iterate over each sample and work out if it should be in the lhs (0) or rhs (1) group
236  groupCounter[0] = groupCounter[1] = 0;
237  classProbabilities.setAllValues(0);
238  for(UINT i=0; i<M; i++){
239  groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
240  groupCounter[ groupIndex[i] ]++;
241  classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
242  }
243 
244  //Compute the class probabilities for the lhs group and rhs group
245  for(UINT k=0; k<K; k++){
246  classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
247  classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
248  }
249 
250  //Compute the Gini index for the lhs and rhs groups
251  giniIndexL = giniIndexR = 0;
252  for(UINT k=0; k<K; k++){
253  giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
254  giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
255  }
256  weightL = groupCounter[0]/M;
257  weightR = groupCounter[1]/M;
258  error = (giniIndexL*weightL) + (giniIndexR*weightR);
259 
260  //Store the best threshold and feature index
261  if( error < minError ){
262  minError = error;
263  bestThreshold = threshold;
264  bestFeatureIndex = featureIndex;
265  }
266  }
267 
268  //Set the best feature index that will be returned to the DecisionTree that called this function
269  featureIndex = bestFeatureIndex;
270 
271  //Store the node size, feature index, best threshold and class probabilities for this node
272  set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));
273 
274  return true;
275 }
276 
277 bool DecisionTreeThresholdNode::saveParametersToFile( std::fstream &file ) const{
278 
279  if(!file.is_open())
280  {
281  errorLog << "saveParametersToFile(fstream &file) - File is not open!" << std::endl;
282  return false;
283  }
284 
285  //Save the DecisionTreeNode parameters
287  errorLog << "saveParametersToFile(fstream &file) - Failed to save DecisionTreeNode parameters to file!" << std::endl;
288  return false;
289  }
290 
291  //Save the custom DecisionTreeThresholdNode parameters
292  file << "FeatureIndex: " << featureIndex << std::endl;
293  file << "Threshold: " << threshold << std::endl;
294 
295  return true;
296 }
297 
299 
300  if(!file.is_open())
301  {
302  errorLog << "loadParametersFromFile(fstream &file) - File is not open!" << std::endl;
303  return false;
304  }
305 
306  //Load the DecisionTreeNode parameters
308  errorLog << "loadParametersFromFile(fstream &file) - Failed to load DecisionTreeNode parameters from file!" << std::endl;
309  return false;
310  }
311 
312  std::string word;
313 
314  //Load the custom DecisionTreeThresholdNode Parameters
315  file >> word;
316  if( word != "FeatureIndex:" ){
317  errorLog << "loadParametersFromFile(fstream &file) - Failed to find FeatureIndex header!" << std::endl;
318  return false;
319  }
320  file >> featureIndex;
321 
322  file >> word;
323  if( word != "Threshold:" ){
324  errorLog << "loadParametersFromFile(fstream &file) - Failed to find Threshold header!" << std::endl;
325  return false;
326  }
327  file >> threshold;
328 
329  return true;
330 }
331 
332 GRT_END_NAMESPACE
333 
/* ---- Doxygen cross-reference residue (extraction artifact, not part of the translation unit) ----
virtual bool clear() override
virtual bool getModel(std::ostream &stream) const override
Definition: Node.cpp:116
Definition: Node.h:37
This file contains the Random class, a useful wrapper for generating cross platform random functions...
Definition: Random.h:46
virtual bool getModel(std::ostream &stream) const override
UINT getSize() const
Definition: Vector.h:201
virtual bool print() const override
virtual bool predict_(VectorFloat &x) override
virtual bool loadParametersFromFile(std::fstream &file) override
virtual bool saveParametersToFile(std::fstream &file) const override
bool setAllValues(const T &value)
Definition: Matrix.h:366
UINT getNumSamples() const
virtual Node * deepCopy() const override
Vector< MinMax > getRanges() const
Float getRandomNumberUniform(Float minRange=0.0, Float maxRange=1.0)
Definition: Random.cpp:129
bool set(const UINT nodeSize, const UINT featureIndex, const Float threshold, const VectorFloat &classProbabilities)
int getRandomNumberInt(int minRange, int maxRange)
Definition: Random.cpp:59
virtual bool loadParametersFromFile(std::fstream &file) override
virtual bool saveParametersToFile(std::fstream &file) const override
virtual Node * deepCopy() const
Definition: Node.cpp:272
---- end residue ---- */