GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, c++ machine learning library for real-time gesture recognition.
DecisionTree.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 */
20 
21 #define GRT_DLL_EXPORTS
22 #include "DecisionTree.h"
23 
24 GRT_BEGIN_NAMESPACE
25 
//Define the string that will be used to identify the object
std::string DecisionTree::id = "DecisionTree";
std::string DecisionTree::getId() { return DecisionTree::id; }

//Register the DecisionTree module with the Classifier base class so it can be created dynamically by name
RegisterClassifierModule< DecisionTree > DecisionTree::registerModule( DecisionTree::getId() );
32 
/**
 @brief Main constructor.
 @param decisionTreeNode the node type used as the prototype for every node in the tree (deep copied, caller keeps ownership of its argument)
 @param minNumSamplesPerNode the minimum number of training samples allowed in a node before it becomes a leaf
 @param maxDepth the maximum depth of the tree
 @param removeFeaturesAtEachSpilt if true, a feature is removed from the candidate set once it has been used for a split (NOTE: "Spilt" spelling is part of the established API)
 @param trainingMode controls how candidate split thresholds are generated
 @param numSplittingSteps the number of candidate thresholds evaluated per feature
 @param useScaling if true, input data will be scaled to [0 1] before training/prediction
*/
DecisionTree::DecisionTree(const DecisionTreeNode &decisionTreeNode,const UINT minNumSamplesPerNode,const UINT maxDepth,const bool removeFeaturesAtEachSpilt,const UINT trainingMode,const UINT numSplittingSteps,const bool useScaling){

    //Pointers start NULL so clear()/delete are always safe
    this->tree = NULL;
    this->decisionTreeNode = NULL;
    this->minNumSamplesPerNode = minNumSamplesPerNode;
    this->maxDepth = maxDepth;
    this->removeFeaturesAtEachSpilt = removeFeaturesAtEachSpilt;
    this->trainingMode = trainingMode;
    this->numSplittingSteps = numSplittingSteps;
    this->useScaling = useScaling;
    this->supportsNullRejection = true;
    Classifier::classType = DecisionTree::getId();
    classifierType = Classifier::classType;
    classifierMode = STANDARD_CLASSIFIER_MODE;
    Classifier::debugLog.setProceedingText("[DEBUG " + DecisionTree::getId() + "]");
    Classifier::errorLog.setProceedingText("[ERROR " + DecisionTree::getId() + "]");
    Classifier::trainingLog.setProceedingText("[TRAINING " + DecisionTree::getId() + "]");
    Classifier::warningLog.setProceedingText("[WARNING " + DecisionTree::getId() + "]");

    //Store our own deep copy of the prototype node
    this->decisionTreeNode = decisionTreeNode.deepCopy();

}
55 
57  tree = NULL;
58  decisionTreeNode = NULL;
59  Classifier::classType = DecisionTree::getId();
60  classifierType = Classifier::classType;
61  classifierMode = STANDARD_CLASSIFIER_MODE;
62  Classifier:: debugLog.setProceedingText("[DEBUG " + DecisionTree::getId() + "]");
63  Classifier::errorLog.setProceedingText("[ERROR " + DecisionTree::getId() + "]");
64  Classifier::trainingLog.setProceedingText("[TRAINING " + DecisionTree::getId() + "]");
65  Classifier::warningLog.setProceedingText("[WARNING " + DecisionTree::getId() + "]");
66  *this = rhs;
67 }
68 
70 {
71  clear();
72 
73  if( decisionTreeNode != NULL ){
74  delete decisionTreeNode;
75  decisionTreeNode = NULL;
76  }
77 }
78 
80  if( this != &rhs ){
81  //Clear this tree
82  clear();
83 
84  if( rhs.getTrained() ){
85  //Deep copy the tree
86  this->tree = (DecisionTreeNode*)rhs.deepCopyTree();
87  }
88 
89  //Deep copy the main node
90  if( this->decisionTreeNode != NULL ){
91  delete decisionTreeNode;
92  decisionTreeNode = NULL;
93  }
94  this->decisionTreeNode = rhs.deepCopyDecisionTreeNode();
95 
96  this->minNumSamplesPerNode = rhs.minNumSamplesPerNode;
97  this->maxDepth = rhs.maxDepth;
98  this->removeFeaturesAtEachSpilt = rhs.removeFeaturesAtEachSpilt;
99  this->trainingMode = rhs.trainingMode;
100  this->numSplittingSteps = rhs.numSplittingSteps;
101  this->nodeClusters = rhs.nodeClusters;
102 
103  //Copy the base classifier variables
104  copyBaseVariables( (Classifier*)&rhs );
105  }
106  return *this;
107 }
108 
109 bool DecisionTree::deepCopyFrom(const Classifier *classifier){
110 
111  if( classifier == NULL ) return false;
112 
113  if( this->getClassifierType() == classifier->getClassifierType() ){
114 
115  DecisionTree *ptr = (DecisionTree*)classifier;
116 
117  //Clear this tree
118  this->clear();
119 
120  if( ptr->getTrained() ){
121  //Deep copy the tree
122  this->tree = ptr->deepCopyTree();
123  }
124 
125  //Deep copy the main node
126  if( this->decisionTreeNode != NULL ){
127  delete decisionTreeNode;
128  decisionTreeNode = NULL;
129  }
130  this->decisionTreeNode = ptr->deepCopyDecisionTreeNode();
131 
132  this->minNumSamplesPerNode = ptr->minNumSamplesPerNode;
133  this->maxDepth = ptr->maxDepth;
134  this->removeFeaturesAtEachSpilt = ptr->removeFeaturesAtEachSpilt;
135  this->trainingMode = ptr->trainingMode;
136  this->numSplittingSteps = ptr->numSplittingSteps;
137  this->nodeClusters = ptr->nodeClusters;
138 
139  //Copy the base classifier variables
140  return copyBaseVariables( classifier );
141  }
142  return false;
143 }
144 
146 
147  //Clear any previous model
148  clear();
149 
150  if( decisionTreeNode == NULL ){
151  Classifier::errorLog << "train_(ClassificationData &trainingData) - The decision tree node has not been set! You must set this first before training a model." << std::endl;
152  return false;
153  }
154 
155  const unsigned int M = trainingData.getNumSamples();
156  const unsigned int N = trainingData.getNumDimensions();
157  const unsigned int K = trainingData.getNumClasses();
158 
159  if( M == 0 ){
160  Classifier::errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << std::endl;
161  return false;
162  }
163 
164  numInputDimensions = N;
165  numClasses = K;
166  classLabels = trainingData.getClassLabels();
167  ranges = trainingData.getRanges();
168 
169  //Get the validation set if needed
170  ClassificationData validationData;
171  if( useValidationSet ){
172  validationData = trainingData.split( validationSetSize );
173  validationSetAccuracy = 0;
174  validationSetPrecision.resize( useNullRejection ? K+1 : K, 0 );
175  validationSetRecall.resize( useNullRejection ? K+1 : K, 0 );
176  }
177 
178  //Scale the training data if needed
179  if( useScaling ){
180  //Scale the training data between 0 and 1
181  trainingData.scale(0, 1);
182  }
183 
184  //If we are using null rejection, then we need a copy of the training dataset for later
185  ClassificationData trainingDataCopy;
186  if( useNullRejection ){
187  trainingDataCopy = trainingData;
188  }
189 
190  //Setup the valid features - at this point all features can be used
191  Vector< UINT > features(N);
192  for(UINT i=0; i<N; i++){
193  features[i] = i;
194  }
195 
196  //Build the tree
197  UINT nodeID = 0;
198  tree = buildTree( trainingData, NULL, features, classLabels, nodeID );
199 
200  if( tree == NULL ){
201  clear();
202  Classifier::errorLog << "train_(ClassificationData &trainingData) - Failed to build tree!" << std::endl;
203  return false;
204  }
205 
206  //Flag that the algorithm has been trained
207  trained = true;
208 
209  //Compute the null rejection thresholds if null rejection is enabled
210  if( useNullRejection ){
211  VectorFloat classLikelihoods( numClasses );
212  Vector< UINT > predictions(M);
213  VectorFloat distances(M);
214  VectorFloat classCounter( numClasses, 0 );
215 
216  //Run over the training dataset and compute the distance between each training sample and the predicted node cluster
217  for(UINT i=0; i<M; i++){
218  //Run the prediction for this sample
219  if( !tree->predict( trainingDataCopy[i].getSample(), classLikelihoods ) ){
220  Classifier::errorLog << "train_(ClassificationData &trainingData) - Failed to predict training sample while building null rejection model!" << std::endl;
221  return false;
222  }
223 
224  //Store the predicted class index and cluster distance
225  predictions[i] = Util::getMaxIndex( classLikelihoods );
226  distances[i] = getNodeDistance(trainingDataCopy[i].getSample(), tree->getPredictedNodeID() );
227 
228  classCounter[ predictions[i] ]++;
229  }
230 
231  //Compute the average distance for each class between the training data and the node clusters
232  classClusterMean.clear();
233  classClusterStdDev.clear();
234  classClusterMean.resize( numClasses, 0 );
235  classClusterStdDev.resize( numClasses, 0.01 ); //we start the std dev with a small value to ensure it is not zero
236 
237  for(UINT i=0; i<M; i++){
238  classClusterMean[ predictions[i] ] += distances[ i ];
239  }
240  for(UINT k=0; k<numClasses; k++){
241  classClusterMean[k] /= grt_max( classCounter[k], 1 );
242  }
243 
244  //Compute the std deviation
245  for(UINT i=0; i<M; i++){
246  classClusterStdDev[ predictions[i] ] += MLBase::SQR( distances[ i ] - classClusterMean[ predictions[i] ] );
247  }
248  for(UINT k=0; k<numClasses; k++){
249  classClusterStdDev[k] = sqrt( classClusterStdDev[k] / grt_max( classCounter[k], 1 ) );
250  }
251 
252  //Compute the null rejection thresholds using the class mean and std dev
254  }
255 
256  if( useValidationSet ){
257  const UINT numTestSamples = validationData.getNumSamples();
258  double numCorrect = 0;
259  UINT testLabel = 0;
260  VectorFloat testSample;
261  VectorFloat validationSetPrecisionCounter( validationSetPrecision.size(), 0.0 );
262  VectorFloat validationSetRecallCounter( validationSetRecall.size(), 0.0 );
263  Classifier::trainingLog << "Testing model with validation set..." << std::endl;
264  for(UINT i=0; i<numTestSamples; i++){
265  testLabel = validationData[i].getClassLabel();
266  testSample = validationData[i].getSample();
267  predict_( testSample );
268  if( predictedClassLabel == testLabel ){
269  numCorrect++;
270  validationSetPrecision[ getClassLabelIndexValue( testLabel ) ]++;
271  validationSetRecall[ getClassLabelIndexValue( testLabel ) ]++;
272  }
273  validationSetPrecisionCounter[ getClassLabelIndexValue( predictedClassLabel ) ]++;
274  validationSetRecallCounter[ getClassLabelIndexValue( testLabel ) ]++;
275  }
276 
277  validationSetAccuracy = (numCorrect / numTestSamples) * 100.0;
278  for(UINT i=0; i<validationSetPrecision.getSize(); i++){
279  validationSetPrecision[i] /= validationSetPrecisionCounter[i] > 0 ? validationSetPrecisionCounter[i] : 1;
280  }
281  for(UINT i=0; i<validationSetRecall.getSize(); i++){
282  validationSetRecall[i] /= validationSetRecallCounter[i] > 0 ? validationSetRecallCounter[i] : 1;
283  }
284 
285  Classifier::trainingLog << "Validation set accuracy: " << validationSetAccuracy << std::endl;
286 
287  Classifier::trainingLog << "Validation set precision: ";
288  for(UINT i=0; i<validationSetPrecision.getSize(); i++){
289  Classifier::trainingLog << validationSetPrecision[i] << " ";
290  }
291  Classifier::trainingLog << std::endl;
292 
293  Classifier::trainingLog << "Validation set recall: ";
294  for(UINT i=0; i<validationSetRecall.getSize(); i++){
295  Classifier::trainingLog << validationSetRecall[i] << " ";
296  }
297  Classifier::trainingLog << std::endl;
298  }
299 
300  return true;
301 }
302 
304 
305  predictedClassLabel = 0;
306  maxLikelihood = 0;
307 
308  //Validate the input is OK and the model is trained properly
309  if( !trained ){
310  Classifier::errorLog << "predict_(VectorFloat &inputVector) - Model Not Trained!" << std::endl;
311  return false;
312  }
313 
314  if( tree == NULL ){
315  Classifier::errorLog << "predict_(VectorFloat &inputVector) - DecisionTree pointer is null!" << std::endl;
316  return false;
317  }
318 
319  if( inputVector.getSize() != numInputDimensions ){
320  Classifier::errorLog << "predict_(VectorFloat &inputVector) - The size of the input Vector (" << inputVector.getSize() << ") does not match the num features in the model (" << numInputDimensions << std::endl;
321  return false;
322  }
323 
324  //Scale the input data if needed
325  if( useScaling ){
326  for(UINT n=0; n<numInputDimensions; n++){
327  inputVector[n] = grt_scale(inputVector[n], ranges[n].minValue, ranges[n].maxValue, 0.0, 1.0);
328  }
329  }
330 
331  if( classLikelihoods.size() != numClasses ) classLikelihoods.resize(numClasses,0);
332  if( classDistances.size() != numClasses ) classDistances.resize(numClasses,0);
333 
334  //Run the decision tree prediction
335  if( !tree->predict( inputVector, classLikelihoods ) ){
336  Classifier::errorLog << "predict_(VectorFloat &inputVector) - Failed to predict!" << std::endl;
337  return false;
338  }
339 
340  //Find the maximum likelihood
341  //The tree automatically returns proper class likelihoods so we don't need to do anything else
342  UINT maxIndex = 0;
343  maxLikelihood = 0;
344  for(UINT k=0; k<numClasses; k++){
345  if( classLikelihoods[k] > maxLikelihood ){
346  maxLikelihood = classLikelihoods[k];
347  maxIndex = k;
348  }
349  }
350 
351  //Run the null rejection
352  if( useNullRejection ){
353 
354  //Get the distance between the input and the leaf mean
355  Float leafDistance = getNodeDistance( inputVector, tree->getPredictedNodeID() );
356 
357  if( grt_isnan(leafDistance) ){
358  Classifier::errorLog << "predict_(VectorFloat &inputVector) - Failed to match leaf node ID to compute node distance!" << std::endl;
359  return false;
360  }
361 
362  //Set the predicted class distance as the leaf distance, all other classes will have a distance of zero
363  classDistances.setAll(0.0);
364  classDistances[ maxIndex ] = leafDistance;
365 
366  //Use the distance to check if the class label should be rejected or not
367  if( leafDistance <= nullRejectionThresholds[ maxIndex ] ){
368  predictedClassLabel = classLabels[ maxIndex ];
369  }else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
370 
371  }else {
372  //Set the predicated class label
373  predictedClassLabel = classLabels[ maxIndex ];
374  }
375 
376  return true;
377 }
378 
380 
381  //Clear the Classifier variables
383 
384  //Clear the node clusters
385  nodeClusters.clear();
386 
387  //Delete the tree if it exists
388  if( tree != NULL ){
389  tree->clear();
390  delete tree;
391  tree = NULL;
392  }
393 
394  //NOTE: We do not want to clean up the decisionTreeNode here as we need to keep track of this, this is only delete in the destructor
395 
396  return true;
397 }
398 
400 
401  if( !trained ){
402  Classifier::warningLog << "recomputeNullRejectionThresholds() - Failed to recompute null rejection thresholds, the model has not been trained!" << std::endl;
403  return false;
404  }
405 
406  if( !useNullRejection ){
407  Classifier::warningLog << "recomputeNullRejectionThresholds() - Failed to recompute null rejection thresholds, null rejection is not enabled!" << std::endl;
408  return false;
409  }
410 
411  nullRejectionThresholds.resize( numClasses );
412 
413  //Compute the rejection threshold for each class using the mean and std dev
414  for(UINT k=0; k<numClasses; k++){
415  nullRejectionThresholds[k] = classClusterMean[k] + (classClusterStdDev[k]*nullRejectionCoeff);
416  }
417 
418  return true;
419 }
420 
421 bool DecisionTree::save( std::fstream &file ) const{
422 
423  if(!file.is_open())
424  {
425  Classifier::errorLog <<"save(fstream &file) - The file is not open!" << std::endl;
426  return false;
427  }
428 
429  //Write the header info
430  file << "GRT_DECISION_TREE_MODEL_FILE_V4.0\n";
431 
432  //Write the classifier settings to the file
434  Classifier::errorLog <<"save(fstream &file) - Failed to save classifier base settings to file!" << std::endl;
435  return false;
436  }
437 
438  if( decisionTreeNode != NULL ){
439  file << "DecisionTreeNodeType: " << decisionTreeNode->getNodeType() << std::endl;
440  if( !decisionTreeNode->save( file ) ){
441  Classifier::errorLog <<"save(fstream &file) - Failed to save decisionTreeNode settings to file!" << std::endl;
442  return false;
443  }
444  }else{
445  file << "DecisionTreeNodeType: " << "NULL" << std::endl;
446  }
447 
448  file << "MinNumSamplesPerNode: " << minNumSamplesPerNode << std::endl;
449  file << "MaxDepth: " << maxDepth << std::endl;
450  file << "RemoveFeaturesAtEachSpilt: " << removeFeaturesAtEachSpilt << std::endl;
451  file << "TrainingMode: " << trainingMode << std::endl;
452  file << "NumSplittingSteps: " << numSplittingSteps << std::endl;
453  file << "TreeBuilt: " << (tree != NULL ? 1 : 0) << std::endl;
454 
455  if( tree != NULL ){
456  file << "Tree:\n";
457  if( !tree->save( file ) ){
458  Classifier::errorLog << "save(fstream &file) - Failed to save tree to file!" << std::endl;
459  return false;
460  }
461 
462  //Save the null rejection data if needed
463  if( useNullRejection ){
464 
465  file << "ClassClusterMean:";
466  for(UINT k=0; k<numClasses; k++){
467  file << " " << classClusterMean[k];
468  }
469  file << std::endl;
470 
471  file << "ClassClusterStdDev:";
472  for(UINT k=0; k<numClasses; k++){
473  file << " " << classClusterStdDev[k];
474  }
475  file << std::endl;
476 
477  file << "NumNodes: " << nodeClusters.size() << std::endl;
478  file << "NodeClusters:\n";
479 
480  std::map< UINT, VectorFloat >::const_iterator iter = nodeClusters.begin();
481 
482  while( iter != nodeClusters.end() ){
483 
484  //Write the nodeID
485  file << iter->first;
486 
487  //Write the node cluster
488  for(UINT j=0; j<numInputDimensions; j++){
489  file << " " << iter->second[j];
490  }
491  file << std::endl;
492 
493  iter++;
494  }
495  }
496 
497  }
498 
499  return true;
500 }
501 
502 bool DecisionTree::load( std::fstream &file ){
503 
504  clear();
505 
506  if( decisionTreeNode != NULL ){
507  delete decisionTreeNode;
508  decisionTreeNode = NULL;
509  }
510 
511  if( !file.is_open() )
512  {
513  Classifier::errorLog << "load(std::fstream &file) - Could not open file to load model" << std::endl;
514  return false;
515  }
516 
517  std::string word;
518  file >> word;
519 
520  //Check to see if we should load a legacy file
521  if( word == "GRT_DECISION_TREE_MODEL_FILE_V1.0" ){
522  return loadLegacyModelFromFile_v1( file );
523  }
524 
525  if( word == "GRT_DECISION_TREE_MODEL_FILE_V2.0" ){
526  return loadLegacyModelFromFile_v2( file );
527  }
528 
529  if( word == "GRT_DECISION_TREE_MODEL_FILE_V3.0" ){
530  return loadLegacyModelFromFile_v3( file );
531  }
532 
533  //Find the file type header
534  if( word != "GRT_DECISION_TREE_MODEL_FILE_V4.0" ){
535  Classifier::errorLog << "load(string filename) - Could not find Model File Header" << std::endl;
536  return false;
537  }
538 
539  //Load the base settings from the file
541  Classifier::errorLog << "load(string filename) - Failed to load base settings from file!" << std::endl;
542  return false;
543  }
544 
545  file >> word;
546  if(word != "DecisionTreeNodeType:"){
547  Classifier::errorLog << "load(string filename) - Could not find the DecisionTreeNodeType!" << std::endl;
548  return false;
549  }
550  file >> word;
551 
552  if( word != "NULL" ){
553 
554  decisionTreeNode = dynamic_cast< DecisionTreeNode* >( DecisionTreeNode::createInstanceFromString( word ) );
555 
556  if( decisionTreeNode == NULL ){
557  Classifier::errorLog << "load(string filename) - Could not create new DecisionTreeNode from type: " << word << std::endl;
558  return false;
559  }
560 
561  if( !decisionTreeNode->load( file ) ){
562  Classifier::errorLog <<"load(fstream &file) - Failed to load decisionTreeNode settings from file!" << std::endl;
563  return false;
564  }
565  }else{
566  Classifier::errorLog <<"load(fstream &file) - Failed to load decisionTreeNode! DecisionTreeNodeType is NULL!" << std::endl;
567  return false;
568  }
569 
570  file >> word;
571  if(word != "MinNumSamplesPerNode:"){
572  Classifier::errorLog << "load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
573  return false;
574  }
575  file >> minNumSamplesPerNode;
576 
577  file >> word;
578  if(word != "MaxDepth:"){
579  Classifier::errorLog << "load(string filename) - Could not find the MaxDepth!" << std::endl;
580  return false;
581  }
582  file >> maxDepth;
583 
584  file >> word;
585  if(word != "RemoveFeaturesAtEachSpilt:"){
586  Classifier::errorLog << "load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
587  return false;
588  }
589  file >> removeFeaturesAtEachSpilt;
590 
591  file >> word;
592  if(word != "TrainingMode:"){
593  Classifier::errorLog << "load(string filename) - Could not find the TrainingMode!" << std::endl;
594  return false;
595  }
596  file >> trainingMode;
597 
598  file >> word;
599  if(word != "NumSplittingSteps:"){
600  Classifier::errorLog << "load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
601  return false;
602  }
603  file >> numSplittingSteps;
604 
605  file >> word;
606  if(word != "TreeBuilt:"){
607  Classifier::errorLog << "load(string filename) - Could not find the TreeBuilt!" << std::endl;
608  return false;
609  }
610  file >> trained;
611 
612  if( trained ){
613  file >> word;
614  if(word != "Tree:"){
615  Classifier::errorLog << "load(string filename) - Could not find the Tree!" << std::endl;
616  return false;
617  }
618 
619  //Create a new DTree
620  tree = dynamic_cast< DecisionTreeNode* >( decisionTreeNode->createNewInstance() );
621 
622  if( tree == NULL ){
623  clear();
624  Classifier::errorLog << "load(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
625  return false;
626  }
627 
628  tree->setParent( NULL );
629  if( !tree->load( file ) ){
630  clear();
631  Classifier::errorLog << "load(fstream &file) - Failed to load tree from file!" << std::endl;
632  return false;
633  }
634 
635  //Load the null rejection data if needed
636  if( useNullRejection ){
637 
638  UINT numNodes = 0;
639  classClusterMean.resize( numClasses );
640  classClusterStdDev.resize( numClasses );
641 
642  file >> word;
643  if(word != "ClassClusterMean:"){
644  Classifier::errorLog << "load(string filename) - Could not find the ClassClusterMean header!" << std::endl;
645  return false;
646  }
647  for(UINT k=0; k<numClasses; k++){
648  file >> classClusterMean[k];
649  }
650 
651  file >> word;
652  if(word != "ClassClusterStdDev:"){
653  Classifier::errorLog << "load(string filename) - Could not find the ClassClusterStdDev header!" << std::endl;
654  return false;
655  }
656  for(UINT k=0; k<numClasses; k++){
657  file >> classClusterStdDev[k];
658  }
659 
660  file >> word;
661  if(word != "NumNodes:"){
662  Classifier::errorLog << "load(string filename) - Could not find the NumNodes header!" << std::endl;
663  return false;
664  }
665  file >> numNodes;
666 
667  file >> word;
668  if(word != "NodeClusters:"){
669  Classifier::errorLog << "load(string filename) - Could not find the NodeClusters header!" << std::endl;
670  return false;
671  }
672 
673  UINT nodeID = 0;
674  VectorFloat cluster( numInputDimensions );
675  for(UINT i=0; i<numNodes; i++){
676 
677  //load the nodeID
678  file >> nodeID;
679 
680  for(UINT j=0; j<numInputDimensions; j++){
681  file >> cluster[j];
682  }
683 
684  //Add the cluster to the cluster nodes map
685  nodeClusters[ nodeID ] = cluster;
686  }
687 
688  //Recompute the null rejection thresholds
690  }
691 
692  //Resize the prediction results to make sure it is setup for realtime prediction
693  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
694  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
695  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
696  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
697  }
698 
699  return true;
700 }
701 
702 bool DecisionTree::getModel( std::ostream &stream ) const{
703 
704  if( tree != NULL )
705  return tree->getModel( stream );
706  return false;
707 
708 }
709 
711 
712  if( tree == NULL ){
713  return NULL;
714  }
715 
716  return dynamic_cast< DecisionTreeNode* >( tree->deepCopyNode() );
717 }
718 
720 
721  if( decisionTreeNode == NULL ){
722  return NULL;
723  }
724 
725  return decisionTreeNode->deepCopy();
726 }
727 
729  return dynamic_cast< DecisionTreeNode* >( tree );
730 }
731 
733 
734  if( decisionTreeNode != NULL ){
735  delete decisionTreeNode;
736  decisionTreeNode = NULL;
737  }
738  this->decisionTreeNode = node.deepCopy();
739 
740  return true;
741 }
742 
/**
 @brief Recursively builds the decision tree.

 Creates a node (cloned from the prototype decisionTreeNode); if a stopping
 condition is hit (pure node, no features left, too few samples, max depth)
 the node becomes a leaf, otherwise the best split is computed and the data
 is partitioned into left/right child datasets which are built recursively.

 @param trainingData the samples reaching this node (cleared before recursing to save memory)
 @param parent the parent node, NULL for the root
 @param features the feature indices still available for splitting (passed by value, so siblings are unaffected by removals)
 @param classLabels the full set of class labels in the original dataset
 @param nodeID the ID assigned to this node (incremented locally; passed by value)
 @return the new subtree, or NULL on failure

 NOTE(review): nodeID is passed by value and each recursive call increments
 its own copy, so IDs produced in one subtree can collide with IDs assigned
 in a sibling subtree (affects the nodeClusters map when null rejection is
 enabled) — confirm against how getPredictedNodeID() assigns IDs before
 relying on uniqueness.
*/
DecisionTreeNode* DecisionTree::buildTree(ClassificationData &trainingData,DecisionTreeNode *parent,Vector< UINT > features,const Vector< UINT > &classLabels, UINT nodeID){

    const UINT M = trainingData.getNumSamples();
    const UINT N = trainingData.getNumDimensions();

    //Update the nodeID
    nodeID++;

    //Get the depth
    UINT depth = 0;

    if( parent != NULL )
        depth = parent->getDepth() + 1;

    //If there are no training data then return NULL
    if( trainingData.getNumSamples() == 0 )
        return NULL;

    //Create the new node, cloned from the prototype so custom node types are supported
    DecisionTreeNode *node = dynamic_cast< DecisionTreeNode* >( decisionTreeNode->createNewInstance() );

    if( node == NULL )
        return NULL;

    //Get the class probabilities
    VectorFloat classProbs = trainingData.getClassProbabilities( classLabels );

    //Set the parent
    node->initNode( parent, depth, nodeID );

    //If all the training data belongs to the same class or there are no features left then create a leaf node and return
    if( trainingData.getNumClasses() == 1 || features.size() == 0 || M < minNumSamplesPerNode || depth >= maxDepth ){

        //Set the node
        node->setLeafNode( trainingData.getNumSamples(), classProbs );

        //Build the null cluster if null rejection is enabled
        if( useNullRejection ){
            nodeClusters[ nodeID ] = trainingData.getMean();
        }

        std::string info = "Reached leaf node.";
        if( trainingData.getNumClasses() == 1 ) info = "Reached pure leaf node.";
        else if( features.size() == 0 ) info = "Reached leaf node, no remaining features.";
        else if( M < minNumSamplesPerNode ) info = "Reached leaf node, hit min-samples-per-node limit.";
        else if( depth >= maxDepth ) info = "Reached leaf node, max depth reached.";

        Classifier::trainingLog << info << " Depth: " << depth << " NumSamples: " << trainingData.getNumSamples();

        Classifier::trainingLog << " Class Probabilities: ";
        for(UINT k=0; k<classProbs.getSize(); k++){
            Classifier::trainingLog << classProbs[k] << " ";
        }
        Classifier::trainingLog << std::endl;

        return node;
    }

    //Compute the best spilt point
    UINT featureIndex = 0;
    Float minError = 0;

    if( !node->computeBestSpilt( trainingMode, numSplittingSteps, trainingData, features, classLabels, featureIndex, minError ) ){
        //Could not find a valid split, clean up and signal failure to the caller
        delete node;
        return NULL;
    }

    Classifier::trainingLog << "Depth: " << depth << " FeatureIndex: " << featureIndex << " MinError: " << minError;
    Classifier::trainingLog << " Class Probabilities: ";
    for(size_t k=0; k<classProbs.size(); k++){
        Classifier::trainingLog << classProbs[k] << " ";
    }
    Classifier::trainingLog << std::endl;

    //Remove the selected feature so we will not use it again
    if( removeFeaturesAtEachSpilt ){
        for(size_t i=0; i<features.size(); i++){
            if( features[i] == featureIndex ){
                features.erase( features.begin()+i );
                break;
            }
        }
    }

    //Split the data into a left and right dataset
    ClassificationData lhs(N);
    ClassificationData rhs(N);

    //Reserve the memory to speed up the allocation of large datasets
    lhs.reserve( M );
    rhs.reserve( M );

    //node->predict returns true for samples routed to the right child
    for(UINT i=0; i<M; i++){
        if( node->predict( trainingData[i].getSample() ) ){
            rhs.addSample(trainingData[i].getClassLabel(), trainingData[i].getSample());
        }else lhs.addSample(trainingData[i].getClassLabel(), trainingData[i].getSample());
    }

    //Clear the parent dataset so we do not run out of memory with very large datasets (with very deep trees)
    trainingData.clear();

    //Get the new node IDs for the children
    UINT leftNodeID = ++nodeID;
    UINT rightNodeID = ++nodeID;

    //Run the recursive tree building on the children
    node->setLeftChild( buildTree( lhs, node, features, classLabels, leftNodeID ) );
    node->setRightChild( buildTree( rhs, node, features, classLabels, rightNodeID ) );

    //Build the null clusters for the rhs and lhs nodes if null rejection is enabled
    if( useNullRejection ){
        nodeClusters[ leftNodeID ] = lhs.getMean();
        nodeClusters[ rightNodeID ] = rhs.getMean();
    }

    return node;
}
860 
861 Float DecisionTree::getNodeDistance( const VectorFloat &x, const UINT nodeID ){
862 
863  //Use the node ID to find the node cluster
864  std::map< UINT,VectorFloat >::iterator iter = nodeClusters.find( nodeID );
865 
866  //If we failed to find a match, return NAN
867  if( iter == nodeClusters.end() ) return NAN;
868 
869  //Compute the distance between the input and the node cluster
870  return getNodeDistance( x, iter->second );
871 }
872 
873 Float DecisionTree::getNodeDistance( const VectorFloat &x, const VectorFloat &y ){
874 
875  Float distance = 0;
876  const size_t N = x.size();
877 
878  for(size_t i=0; i<N; i++){
879  distance += MLBase::SQR( x[i] - y[i] );
880  }
881 
882  //Return the squared Euclidean distance instead of actual Euclidean distance as this is faster and just as useful
883  return distance;
884 }
885 
/**
 @brief Loads a model saved in the legacy V1.0 file format. The V1.0 format
 stored its settings inline (NumFeatures, NumClasses, scaling ranges, tree
 parameters) rather than via the shared Classifier base settings block, and
 always used a plain DecisionTreeNode for the tree.
 @param file an open input file stream positioned just after the V1.0 header
 @return true if the model was loaded, false otherwise
*/
bool DecisionTree::loadLegacyModelFromFile_v1( std::fstream &file ){

    std::string word;

    file >> word;
    if(word != "NumFeatures:"){
        Classifier::errorLog << "load(string filename) - Could not find NumFeatures!" << std::endl;
        return false;
    }
    file >> numInputDimensions;

    file >> word;
    if(word != "NumClasses:"){
        Classifier::errorLog << "load(string filename) - Could not find NumClasses!" << std::endl;
        return false;
    }
    file >> numClasses;

    file >> word;
    if(word != "UseScaling:"){
        Classifier::errorLog << "load(string filename) - Could not find UseScaling!" << std::endl;
        return false;
    }
    file >> useScaling;

    file >> word;
    if(word != "UseNullRejection:"){
        Classifier::errorLog << "load(string filename) - Could not find UseNullRejection!" << std::endl;
        return false;
    }
    file >> useNullRejection;

    //Load the ranges if needed
    if( useScaling ){
        //Resize the ranges buffer
        ranges.resize( numInputDimensions );

        file >> word;
        if(word != "Ranges:"){
            Classifier::errorLog << "load(string filename) - Could not find the Ranges!" << std::endl;
            return false;
        }
        for(UINT n=0; n<ranges.size(); n++){
            file >> ranges[n].minValue;
            file >> ranges[n].maxValue;
        }
    }

    file >> word;
    if(word != "NumSplittingSteps:"){
        Classifier::errorLog << "load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
        return false;
    }
    file >> numSplittingSteps;

    file >> word;
    if(word != "MinNumSamplesPerNode:"){
        Classifier::errorLog << "load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
        return false;
    }
    file >> minNumSamplesPerNode;

    file >> word;
    if(word != "MaxDepth:"){
        Classifier::errorLog << "load(string filename) - Could not find the MaxDepth!" << std::endl;
        return false;
    }
    file >> maxDepth;

    file >> word;
    if(word != "RemoveFeaturesAtEachSpilt:"){
        Classifier::errorLog << "load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
        return false;
    }
    file >> removeFeaturesAtEachSpilt;

    file >> word;
    if(word != "TrainingMode:"){
        Classifier::errorLog << "load(string filename) - Could not find the TrainingMode!" << std::endl;
        return false;
    }
    file >> trainingMode;

    file >> word;
    if(word != "TreeBuilt:"){
        Classifier::errorLog << "load(string filename) - Could not find the TreeBuilt!" << std::endl;
        return false;
    }
    file >> trained;

    if( trained ){
        file >> word;
        if(word != "Tree:"){
            Classifier::errorLog << "load(string filename) - Could not find the Tree!" << std::endl;
            return false;
        }

        //Create a new DTree (V1.0 files always used the base DecisionTreeNode type)
        tree = new DecisionTreeNode;

        if( tree == NULL ){
            clear();
            Classifier::errorLog << "load(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
            return false;
        }

        tree->setParent( NULL );
        if( !tree->load( file ) ){
            clear();
            Classifier::errorLog << "load(fstream &file) - Failed to load tree from file!" << std::endl;
            return false;
        }
    }

    return true;
}
1002 
1003 bool DecisionTree::loadLegacyModelFromFile_v2( std::fstream &file ){
1004 
1005  std::string word;
1006 
1007  //Load the base settings from the file
1009  Classifier::errorLog << "load(string filename) - Failed to load base settings from file!" << std::endl;
1010  return false;
1011  }
1012 
1013  file >> word;
1014  if(word != "NumSplittingSteps:"){
1015  Classifier::errorLog << "load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
1016  return false;
1017  }
1018  file >> numSplittingSteps;
1019 
1020  file >> word;
1021  if(word != "MinNumSamplesPerNode:"){
1022  Classifier::errorLog << "load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
1023  return false;
1024  }
1025  file >> minNumSamplesPerNode;
1026 
1027  file >> word;
1028  if(word != "MaxDepth:"){
1029  Classifier::errorLog << "load(string filename) - Could not find the MaxDepth!" << std::endl;
1030  return false;
1031  }
1032  file >> maxDepth;
1033 
1034  file >> word;
1035  if(word != "RemoveFeaturesAtEachSpilt:"){
1036  Classifier::errorLog << "load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
1037  return false;
1038  }
1039  file >> removeFeaturesAtEachSpilt;
1040 
1041  file >> word;
1042  if(word != "TrainingMode:"){
1043  Classifier::errorLog << "load(string filename) - Could not find the TrainingMode!" << std::endl;
1044  return false;
1045  }
1046  file >> trainingMode;
1047 
1048  file >> word;
1049  if(word != "TreeBuilt:"){
1050  Classifier::errorLog << "load(string filename) - Could not find the TreeBuilt!" << std::endl;
1051  return false;
1052  }
1053  file >> trained;
1054 
1055  if( trained ){
1056  file >> word;
1057  if(word != "Tree:"){
1058  Classifier::errorLog << "load(string filename) - Could not find the Tree!" << std::endl;
1059  return false;
1060  }
1061 
1062  //Create a new DTree
1063  tree = new DecisionTreeNode;
1064 
1065  if( tree == NULL ){
1066  clear();
1067  Classifier::errorLog << "load(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
1068  return false;
1069  }
1070 
1071  tree->setParent( NULL );
1072  if( !tree->load( file ) ){
1073  clear();
1074  Classifier::errorLog << "load(fstream &file) - Failed to load tree from file!" << std::endl;
1075  return false;
1076  }
1077 
1078  //Recompute the null rejection thresholds
1080 
1081  //Resize the prediction results to make sure it is setup for realtime prediction
1082  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
1083  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
1084  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
1085  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
1086  }
1087 
1088  return true;
1089 }
1090 
1091 bool DecisionTree::loadLegacyModelFromFile_v3( std::fstream &file ){
1092 
1093  std::string word;
1094 
1095  //Load the base settings from the file
1097  Classifier::errorLog << "load(string filename) - Failed to load base settings from file!" << std::endl;
1098  return false;
1099  }
1100 
1101  file >> word;
1102  if(word != "NumSplittingSteps:"){
1103  Classifier::errorLog << "load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
1104  return false;
1105  }
1106  file >> numSplittingSteps;
1107 
1108  file >> word;
1109  if(word != "MinNumSamplesPerNode:"){
1110  Classifier::errorLog << "load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
1111  return false;
1112  }
1113  file >> minNumSamplesPerNode;
1114 
1115  file >> word;
1116  if(word != "MaxDepth:"){
1117  Classifier::errorLog << "load(string filename) - Could not find the MaxDepth!" << std::endl;
1118  return false;
1119  }
1120  file >> maxDepth;
1121 
1122  file >> word;
1123  if(word != "RemoveFeaturesAtEachSpilt:"){
1124  Classifier::errorLog << "load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
1125  return false;
1126  }
1127  file >> removeFeaturesAtEachSpilt;
1128 
1129  file >> word;
1130  if(word != "TrainingMode:"){
1131  Classifier::errorLog << "load(string filename) - Could not find the TrainingMode!" << std::endl;
1132  return false;
1133  }
1134  file >> trainingMode;
1135 
1136  file >> word;
1137  if(word != "TreeBuilt:"){
1138  Classifier::errorLog << "load(string filename) - Could not find the TreeBuilt!" << std::endl;
1139  return false;
1140  }
1141  file >> trained;
1142 
1143  if( trained ){
1144  file >> word;
1145  if(word != "Tree:"){
1146  Classifier::errorLog << "load(string filename) - Could not find the Tree!" << std::endl;
1147  return false;
1148  }
1149 
1150  //Create a new DTree
1151  tree = new DecisionTreeNode;
1152 
1153  if( tree == NULL ){
1154  clear();
1155  Classifier::errorLog << "load(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
1156  return false;
1157  }
1158 
1159  tree->setParent( NULL );
1160  if( !tree->load( file ) ){
1161  clear();
1162  Classifier::errorLog << "load(fstream &file) - Failed to load tree from file!" << std::endl;
1163  return false;
1164  }
1165 
1166  //Load the null rejection data if needed
1167  if( useNullRejection ){
1168 
1169  UINT numNodes = 0;
1170  classClusterMean.resize( numClasses );
1171  classClusterStdDev.resize( numClasses );
1172 
1173  file >> word;
1174  if(word != "ClassClusterMean:"){
1175  Classifier::errorLog << "load(string filename) - Could not find the ClassClusterMean header!" << std::endl;
1176  return false;
1177  }
1178  for(UINT k=0; k<numClasses; k++){
1179  file >> classClusterMean[k];
1180  }
1181 
1182  file >> word;
1183  if(word != "ClassClusterStdDev:"){
1184  Classifier::errorLog << "load(string filename) - Could not find the ClassClusterStdDev header!" << std::endl;
1185  return false;
1186  }
1187  for(UINT k=0; k<numClasses; k++){
1188  file >> classClusterStdDev[k];
1189  }
1190 
1191  file >> word;
1192  if(word != "NumNodes:"){
1193  Classifier::errorLog << "load(string filename) - Could not find the NumNodes header!" << std::endl;
1194  return false;
1195  }
1196  file >> numNodes;
1197 
1198  file >> word;
1199  if(word != "NodeClusters:"){
1200  Classifier::errorLog << "load(string filename) - Could not find the NodeClusters header!" << std::endl;
1201  return false;
1202  }
1203 
1204  UINT nodeID = 0;
1205  VectorFloat cluster( numInputDimensions );
1206  for(UINT i=0; i<numNodes; i++){
1207 
1208  //load the nodeID
1209  file >> nodeID;
1210 
1211  for(UINT j=0; j<numInputDimensions; j++){
1212  file >> cluster[j];
1213  }
1214 
1215  //Add the cluster to the cluster nodes map
1216  nodeClusters[ nodeID ] = cluster;
1217  }
1218 
1219  //Recompute the null rejection thresholds
1221  }
1222 
1223  //Resize the prediction results to make sure it is setup for realtime prediction
1224  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
1225  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
1226  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
1227  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
1228  }
1229 
1230  return true;
1231 }
1232 
1233 GRT_END_NAMESPACE
bool saveBaseSettingsToFile(std::fstream &file) const
Definition: Classifier.cpp:256
This class implements a basic Decision Tree classifier. Decision Trees are conceptually simple classi...
virtual bool load(std::fstream &file)
#define DEFAULT_NULL_LIKELIHOOD_VALUE
Definition: Classifier.h:38
virtual bool computeBestSpilt(const UINT &trainingMode, const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError)
bool setAll(const T &value)
Definition: Vector.h:175
std::string getClassifierType() const
Definition: Classifier.cpp:161
const DecisionTreeNode * getTree() const
DecisionTreeNode * deepCopyDecisionTreeNode() const
bool loadLegacyModelFromFile_v1(std::fstream &file)
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
virtual bool train_(ClassificationData &trainingData)
std::string getNodeType() const
Definition: Node.cpp:304
bool getTrained() const
Definition: MLBase.cpp:259
virtual bool getModel(std::ostream &stream) const
Definition: Node.cpp:120
virtual bool predict_(VectorFloat &inputVector)
UINT getDepth() const
Definition: Node.cpp:308
UINT getSize() const
Definition: Vector.h:191
UINT getClassLabelIndexValue(UINT classLabel) const
Definition: Classifier.cpp:195
Vector< UINT > getClassLabels() const
DecisionTree(const DecisionTreeNode &decisionTreeNode=DecisionTreeClusterNode(), const UINT minNumSamplesPerNode=5, const UINT maxDepth=10, const bool removeFeaturesAtEachSpilt=false, const UINT trainingMode=BEST_ITERATIVE_SPILT, const UINT numSplittingSteps=100, const bool useScaling=false)
virtual bool deepCopyFrom(const Classifier *classifier)
static std::string getId()
DecisionTreeNode * deepCopyTree() const
virtual bool save(std::fstream &file) const
virtual bool recomputeNullRejectionThresholds()
virtual bool save(std::fstream &file) const
Definition: Node.cpp:140
UINT getNumSamples() const
virtual Node * deepCopyNode() const
Definition: Node.cpp:276
virtual bool clear()
bool copyBaseVariables(const Classifier *classifier)
Definition: Classifier.cpp:93
bool loadBaseSettingsFromFile(std::fstream &file)
Definition: Classifier.cpp:303
virtual bool clear()
Definition: Node.cpp:70
UINT getNumDimensions() const
UINT getNumClasses() const
Node * createNewInstance() const
Definition: Node.cpp:39
UINT getPredictedNodeID() const
Definition: Node.cpp:316
DecisionTree & operator=(const DecisionTree &rhs)
virtual bool predict(const VectorFloat &x, VectorFloat &classLikelihoods)
bool setLeafNode(const UINT nodeSize, const VectorFloat &classProbabilities)
DecisionTreeNode * deepCopy() const
Vector< MinMax > getRanges() const
ClassificationData split(const UINT splitPercentage, const bool useStratifiedSampling=false)
static Node * createInstanceFromString(std::string const &nodeType)
Definition: Node.cpp:29
static unsigned int getMaxIndex(const VectorFloat &x)
Definition: Util.cpp:292
bool setDecisionTreeNode(const DecisionTreeNode &node)
virtual ~DecisionTree(void)
bool scale(const Float minTarget, const Float maxTarget)
virtual bool getModel(std::ostream &stream) const
virtual bool clear()
Definition: Classifier.cpp:142
virtual bool load(std::fstream &file)
Definition: Node.cpp:182
virtual bool predict(const VectorFloat &x)
Definition: Node.cpp:60
VectorFloat getMean() const