GestureRecognitionToolkit  Version: 0.1.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine-learning library for real-time gesture recognition.
DecisionTree.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 */
20 
21 #include "DecisionTree.h"
22 
23 GRT_BEGIN_NAMESPACE
24 
25 //Register the DecisionTree module with the Classifier base class
26 RegisterClassifierModule< DecisionTree > DecisionTree::registerModule("DecisionTree");
27 
28 DecisionTree::DecisionTree(const DecisionTreeNode &decisionTreeNode,const UINT minNumSamplesPerNode,const UINT maxDepth,const bool removeFeaturesAtEachSpilt,const UINT trainingMode,const UINT numSplittingSteps,const bool useScaling)
29 {
30  this->tree = NULL;
31  this->decisionTreeNode = NULL;
32  this->minNumSamplesPerNode = minNumSamplesPerNode;
33  this->maxDepth = maxDepth;
34  this->removeFeaturesAtEachSpilt = removeFeaturesAtEachSpilt;
35  this->trainingMode = trainingMode;
36  this->numSplittingSteps = numSplittingSteps;
37  this->useScaling = useScaling;
38  this->supportsNullRejection = true;
39  Classifier::classType = "DecisionTree";
40  classifierType = Classifier::classType;
41  classifierMode = STANDARD_CLASSIFIER_MODE;
42  Classifier::debugLog.setProceedingText("[DEBUG DecisionTree]");
43  Classifier::errorLog.setProceedingText("[ERROR DecisionTree]");
44  Classifier::trainingLog.setProceedingText("[TRAINING DecisionTree]");
45  Classifier::warningLog.setProceedingText("[WARNING DecisionTree]");
46 
47  this->decisionTreeNode = decisionTreeNode.deepCopy();
48 
49 }
50 
52  tree = NULL;
53  decisionTreeNode = NULL;
54  Classifier::classType = "DecisionTree";
55  classifierType = Classifier::classType;
56  classifierMode = STANDARD_CLASSIFIER_MODE;
57  Classifier:: debugLog.setProceedingText("[DEBUG DecisionTree]");
58  Classifier::errorLog.setProceedingText("[ERROR DecisionTree]");
59  Classifier::trainingLog.setProceedingText("[TRAINING DecisionTree]");
60  Classifier::warningLog.setProceedingText("[WARNING DecisionTree]");
61  *this = rhs;
62 }
63 
65 {
66  clear();
67 
68  if( decisionTreeNode != NULL ){
69  delete decisionTreeNode;
70  decisionTreeNode = NULL;
71  }
72 }
73 
75  if( this != &rhs ){
76  //Clear this tree
77  clear();
78 
79  if( rhs.getTrained() ){
80  //Deep copy the tree
81  this->tree = (DecisionTreeNode*)rhs.deepCopyTree();
82  }
83 
84  //Deep copy the main node
85  if( this->decisionTreeNode != NULL ){
86  delete decisionTreeNode;
87  decisionTreeNode = NULL;
88  }
89  this->decisionTreeNode = rhs.deepCopyDecisionTreeNode();
90 
91  this->minNumSamplesPerNode = rhs.minNumSamplesPerNode;
92  this->maxDepth = rhs.maxDepth;
93  this->removeFeaturesAtEachSpilt = rhs.removeFeaturesAtEachSpilt;
94  this->trainingMode = rhs.trainingMode;
95  this->numSplittingSteps = rhs.numSplittingSteps;
96  this->nodeClusters = rhs.nodeClusters;
97 
98  //Copy the base classifier variables
99  copyBaseVariables( (Classifier*)&rhs );
100  }
101  return *this;
102 }
103 
104 bool DecisionTree::deepCopyFrom(const Classifier *classifier){
105 
106  if( classifier == NULL ) return false;
107 
108  if( this->getClassifierType() == classifier->getClassifierType() ){
109 
110  DecisionTree *ptr = (DecisionTree*)classifier;
111 
112  //Clear this tree
113  this->clear();
114 
115  if( ptr->getTrained() ){
116  //Deep copy the tree
117  this->tree = ptr->deepCopyTree();
118  }
119 
120  //Deep copy the main node
121  if( this->decisionTreeNode != NULL ){
122  delete decisionTreeNode;
123  decisionTreeNode = NULL;
124  }
125  this->decisionTreeNode = ptr->deepCopyDecisionTreeNode();
126 
127  this->minNumSamplesPerNode = ptr->minNumSamplesPerNode;
128  this->maxDepth = ptr->maxDepth;
129  this->removeFeaturesAtEachSpilt = ptr->removeFeaturesAtEachSpilt;
130  this->trainingMode = ptr->trainingMode;
131  this->numSplittingSteps = ptr->numSplittingSteps;
132  this->nodeClusters = ptr->nodeClusters;
133 
134  //Copy the base classifier variables
135  return copyBaseVariables( classifier );
136  }
137  return false;
138 }
139 
141 
142  //Clear any previous model
143  clear();
144 
145  if( decisionTreeNode == NULL ){
146  Classifier::errorLog << "train_(ClassificationData &trainingData) - The decision tree node has not been set! You must set this first before training a model." << std::endl;
147  return false;
148  }
149 
150  const unsigned int M = trainingData.getNumSamples();
151  const unsigned int N = trainingData.getNumDimensions();
152  const unsigned int K = trainingData.getNumClasses();
153 
154  if( M == 0 ){
155  Classifier::errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << std::endl;
156  return false;
157  }
158 
159  numInputDimensions = N;
160  numClasses = K;
161  classLabels = trainingData.getClassLabels();
162  ranges = trainingData.getRanges();
163 
164  //Get the validation set if needed
165  ClassificationData validationData;
166  if( useValidationSet ){
167  validationData = trainingData.partition( validationSetSize );
168  validationSetAccuracy = 0;
169  validationSetPrecision.resize( useNullRejection ? K+1 : K, 0 );
170  validationSetRecall.resize( useNullRejection ? K+1 : K, 0 );
171  }
172 
173  //Scale the training data if needed
174  if( useScaling ){
175  //Scale the training data between 0 and 1
176  trainingData.scale(0, 1);
177  }
178 
179  //If we are using null rejection, then we need a copy of the training dataset for later
180  ClassificationData trainingDataCopy;
181  if( useNullRejection ){
182  trainingDataCopy = trainingData;
183  }
184 
185  //Setup the valid features - at this point all features can be used
186  Vector< UINT > features(N);
187  for(UINT i=0; i<N; i++){
188  features[i] = i;
189  }
190 
191  //Build the tree
192  UINT nodeID = 0;
193  tree = buildTree( trainingData, NULL, features, classLabels, nodeID );
194 
195  if( tree == NULL ){
196  clear();
197  Classifier::errorLog << "train_(ClassificationData &trainingData) - Failed to build tree!" << std::endl;
198  return false;
199  }
200 
201  //Flag that the algorithm has been trained
202  trained = true;
203 
204  //Compute the null rejection thresholds if null rejection is enabled
205  if( useNullRejection ){
206  VectorFloat classLikelihoods( numClasses );
207  Vector< UINT > predictions(M);
208  VectorFloat distances(M);
209  VectorFloat classCounter( numClasses, 0 );
210 
211  //Run over the training dataset and compute the distance between each training sample and the predicted node cluster
212  for(UINT i=0; i<M; i++){
213  //Run the prediction for this sample
214  if( !tree->predict( trainingDataCopy[i].getSample(), classLikelihoods ) ){
215  Classifier::errorLog << "predict_(VectorFloat &inputVector) - Failed to predict!" << std::endl;
216  return false;
217  }
218 
219  //Store the predicted class index and cluster distance
220  predictions[i] = Util::getMaxIndex( classLikelihoods );
221  distances[i] = getNodeDistance(trainingDataCopy[i].getSample(), tree->getPredictedNodeID() );
222 
223  classCounter[ predictions[i] ]++;
224  }
225 
226  //Compute the average distance for each class between the training data and the node clusters
227  classClusterMean.clear();
228  classClusterStdDev.clear();
229  classClusterMean.resize( numClasses, 0 );
230  classClusterStdDev.resize( numClasses, 0.01 ); //we start the std dev with a small value to ensure it is not zero
231 
232  for(UINT i=0; i<M; i++){
233  classClusterMean[ predictions[i] ] += distances[ i ];
234  }
235  for(UINT k=0; k<numClasses; k++){
236  classClusterMean[k] /= MAX( classCounter[k], 1 );
237  }
238 
239  //Compute the std deviation
240  for(UINT i=0; i<M; i++){
241  classClusterStdDev[ predictions[i] ] += MLBase::SQR( distances[ i ] - classClusterMean[ predictions[i] ] );
242  }
243  for(UINT k=0; k<numClasses; k++){
244  classClusterStdDev[k] = sqrt( classClusterStdDev[k] / MAX( classCounter[k], 1 ) );
245  }
246 
247  //Compute the null rejection thresholds using the class mean and std dev
249  }
250 
251  if( useValidationSet ){
252  const UINT numTestSamples = validationData.getNumSamples();
253  double numCorrect = 0;
254  UINT testLabel = 0;
255  VectorDouble testSample;
256  VectorDouble validationSetPrecisionCounter( validationSetPrecision.size(), 0.0 );
257  VectorDouble validationSetRecallCounter( validationSetRecall.size(), 0.0 );
258  Classifier::trainingLog << "Testing model with validation set..." << std::endl;
259  for(UINT i=0; i<numTestSamples; i++){
260  testLabel = validationData[i].getClassLabel();
261  testSample = validationData[i].getSample();
262  predict_( testSample );
263  if( predictedClassLabel == testLabel ){
264  numCorrect++;
265  validationSetPrecision[ getClassLabelIndexValue( testLabel ) ]++;
266  validationSetRecall[ getClassLabelIndexValue( testLabel ) ]++;
267  }
268  validationSetPrecisionCounter[ getClassLabelIndexValue( predictedClassLabel ) ]++;
269  validationSetRecallCounter[ getClassLabelIndexValue( testLabel ) ]++;
270  }
271 
272  validationSetAccuracy = (numCorrect / numTestSamples) * 100.0;
273  for(size_t i=0; i<validationSetPrecision.size(); i++){
274  validationSetPrecision[i] /= validationSetPrecisionCounter[i] > 0 ? validationSetPrecisionCounter[i] : 1;
275  }
276  for(size_t i=0; i<validationSetRecall.size(); i++){
277  validationSetRecall[i] /= validationSetRecallCounter[i] > 0 ? validationSetRecallCounter[i] : 1;
278  }
279 
280  Classifier::trainingLog << "Validation set accuracy: " << validationSetAccuracy << std::endl;
281 
282  Classifier::trainingLog << "Validation set precision: ";
283  for(size_t i=0; i<validationSetPrecision.size(); i++){
284  Classifier::trainingLog << validationSetPrecision[i] << " ";
285  }
286  Classifier::trainingLog << std::endl;
287 
288  Classifier::trainingLog << "Validation set recall: ";
289  for(size_t i=0; i<validationSetRecall.size(); i++){
290  Classifier::trainingLog << validationSetRecall[i] << " ";
291  }
292  Classifier::trainingLog << std::endl;
293  }
294 
295  return true;
296 }
297 
299 
300  predictedClassLabel = 0;
301  maxLikelihood = 0;
302 
303  //Validate the input is OK and the model is trained properly
304  if( !trained ){
305  Classifier::errorLog << "predict_(VectorFloat &inputVector) - Model Not Trained!" << std::endl;
306  return false;
307  }
308 
309  if( tree == NULL ){
310  Classifier::errorLog << "predict_(VectorFloat &inputVector) - DecisionTree pointer is null!" << std::endl;
311  return false;
312  }
313 
314  if( inputVector.getSize() != numInputDimensions ){
315  Classifier::errorLog << "predict_(VectorFloat &inputVector) - The size of the input Vector (" << inputVector.getSize() << ") does not match the num features in the model (" << numInputDimensions << std::endl;
316  return false;
317  }
318 
319  //Scale the input data if needed
320  if( useScaling ){
321  for(UINT n=0; n<numInputDimensions; n++){
322  inputVector[n] = grt_scale(inputVector[n], ranges[n].minValue, ranges[n].maxValue, 0.0, 1.0);
323  }
324  }
325 
326  if( classLikelihoods.size() != numClasses ) classLikelihoods.resize(numClasses,0);
327  if( classDistances.size() != numClasses ) classDistances.resize(numClasses,0);
328 
329  //Run the decision tree prediction
330  if( !tree->predict( inputVector, classLikelihoods ) ){
331  Classifier::errorLog << "predict_(VectorFloat &inputVector) - Failed to predict!" << std::endl;
332  return false;
333  }
334 
335  //Find the maximum likelihood
336  //The tree automatically returns proper class likelihoods so we don't need to do anything else
337  UINT maxIndex = 0;
338  maxLikelihood = 0;
339  for(UINT k=0; k<numClasses; k++){
340  if( classLikelihoods[k] > maxLikelihood ){
341  maxLikelihood = classLikelihoods[k];
342  maxIndex = k;
343  }
344  }
345 
346  //Run the null rejection
347  if( useNullRejection ){
348 
349  //Get the distance between the input and the leaf mean
350  Float leafDistance = getNodeDistance( inputVector, tree->getPredictedNodeID() );
351 
352  if( grt_isnan(leafDistance) ){
353  Classifier::errorLog << "predict_(VectorFloat &inputVector) - Failed to match leaf node ID to compute node distance!" << std::endl;
354  return false;
355  }
356 
357  //Set the predicted class distance as the leaf distance, all other classes will have a distance of zero
358  std::fill(classDistances.begin(),classDistances.end(),0);
359  classDistances[ maxIndex ] = leafDistance;
360 
361  //Use the distance to check if the class label should be rejected or not
362  if( leafDistance <= nullRejectionThresholds[ maxIndex ] ){
363  predictedClassLabel = classLabels[ maxIndex ];
364  }else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
365 
366  }else {
367  //Set the predicated class label
368  predictedClassLabel = classLabels[ maxIndex ];
369  }
370 
371  return true;
372 }
373 
375 
376  //Clear the Classifier variables
378 
379  //Clear the node clusters
380  nodeClusters.clear();
381 
382  //Delete the tree if it exists
383  if( tree != NULL ){
384  tree->clear();
385  delete tree;
386  tree = NULL;
387  }
388 
389  //NOTE: We do not want to clean up the decisionTreeNode here as we need to keep track of this, this is only delete in the destructor
390 
391  return true;
392 }
393 
395 
396  if( !trained ){
397  Classifier::warningLog << "recomputeNullRejectionThresholds() - Failed to recompute null rejection thresholds, the model has not been trained!" << std::endl;
398  return false;
399  }
400 
401  if( !useNullRejection ){
402  Classifier::warningLog << "recomputeNullRejectionThresholds() - Failed to recompute null rejection thresholds, null rejection is not enabled!" << std::endl;
403  return false;
404  }
405 
406  nullRejectionThresholds.resize( numClasses );
407 
408  //Compute the rejection threshold for each class using the mean and std dev
409  for(UINT k=0; k<numClasses; k++){
410  nullRejectionThresholds[k] = classClusterMean[k] + (classClusterStdDev[k]*nullRejectionCoeff);
411  }
412 
413  return true;
414 }
415 
416 bool DecisionTree::saveModelToFile( std::fstream &file ) const{
417 
418  if(!file.is_open())
419  {
420  Classifier::errorLog <<"saveModelToFile(fstream &file) - The file is not open!" << std::endl;
421  return false;
422  }
423 
424  //Write the header info
425  file << "GRT_DECISION_TREE_MODEL_FILE_V4.0\n";
426 
427  //Write the classifier settings to the file
429  Classifier::errorLog <<"saveModelToFile(fstream &file) - Failed to save classifier base settings to file!" << std::endl;
430  return false;
431  }
432 
433  if( decisionTreeNode != NULL ){
434  file << "DecisionTreeNodeType: " << decisionTreeNode->getNodeType() << std::endl;
435  if( !decisionTreeNode->saveToFile( file ) ){
436  Classifier::errorLog <<"saveModelToFile(fstream &file) - Failed to save decisionTreeNode settings to file!" << std::endl;
437  return false;
438  }
439  }else{
440  file << "DecisionTreeNodeType: " << "NULL" << std::endl;
441  }
442 
443  file << "MinNumSamplesPerNode: " << minNumSamplesPerNode << std::endl;
444  file << "MaxDepth: " << maxDepth << std::endl;
445  file << "RemoveFeaturesAtEachSpilt: " << removeFeaturesAtEachSpilt << std::endl;
446  file << "TrainingMode: " << trainingMode << std::endl;
447  file << "NumSplittingSteps: " << numSplittingSteps << std::endl;
448  file << "TreeBuilt: " << (tree != NULL ? 1 : 0) << std::endl;
449 
450  if( tree != NULL ){
451  file << "Tree:\n";
452  if( !tree->saveToFile( file ) ){
453  Classifier::errorLog << "saveModelToFile(fstream &file) - Failed to save tree to file!" << std::endl;
454  return false;
455  }
456 
457  //Save the null rejection data if needed
458  if( useNullRejection ){
459 
460  file << "ClassClusterMean:";
461  for(UINT k=0; k<numClasses; k++){
462  file << " " << classClusterMean[k];
463  }
464  file << std::endl;
465 
466  file << "ClassClusterStdDev:";
467  for(UINT k=0; k<numClasses; k++){
468  file << " " << classClusterStdDev[k];
469  }
470  file << std::endl;
471 
472  file << "NumNodes: " << nodeClusters.size() << std::endl;
473  file << "NodeClusters:\n";
474 
475  std::map< UINT, VectorFloat >::const_iterator iter = nodeClusters.begin();
476 
477  while( iter != nodeClusters.end() ){
478 
479  //Write the nodeID
480  file << iter->first;
481 
482  //Write the node cluster
483  for(UINT j=0; j<numInputDimensions; j++){
484  file << " " << iter->second[j];
485  }
486  file << std::endl;
487 
488  iter++;
489  }
490  }
491 
492  }
493 
494  return true;
495 }
496 
497 bool DecisionTree::loadModelFromFile( std::fstream &file ){
498 
499  clear();
500 
501  if( decisionTreeNode != NULL ){
502  delete decisionTreeNode;
503  decisionTreeNode = NULL;
504  }
505 
506  if( !file.is_open() )
507  {
508  Classifier::errorLog << "loadModelFromFile(string filename) - Could not open file to load model" << std::endl;
509  return false;
510  }
511 
512  std::string word;
513  file >> word;
514 
515  //Check to see if we should load a legacy file
516  if( word == "GRT_DECISION_TREE_MODEL_FILE_V1.0" ){
517  return loadLegacyModelFromFile_v1( file );
518  }
519 
520  if( word == "GRT_DECISION_TREE_MODEL_FILE_V2.0" ){
521  return loadLegacyModelFromFile_v2( file );
522  }
523 
524  if( word == "GRT_DECISION_TREE_MODEL_FILE_V3.0" ){
525  return loadLegacyModelFromFile_v3( file );
526  }
527 
528  //Find the file type header
529  if( word != "GRT_DECISION_TREE_MODEL_FILE_V4.0" ){
530  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find Model File Header" << std::endl;
531  return false;
532  }
533 
534  //Load the base settings from the file
536  Classifier::errorLog << "loadModelFromFile(string filename) - Failed to load base settings from file!" << std::endl;
537  return false;
538  }
539 
540  file >> word;
541  if(word != "DecisionTreeNodeType:"){
542  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the DecisionTreeNodeType!" << std::endl;
543  return false;
544  }
545  file >> word;
546 
547  if( word != "NULL" ){
548 
549  decisionTreeNode = dynamic_cast< DecisionTreeNode* >( DecisionTreeNode::createInstanceFromString( word ) );
550 
551  if( decisionTreeNode == NULL ){
552  Classifier::errorLog << "loadModelFromFile(string filename) - Could not create new DecisionTreeNode from type: " << word << std::endl;
553  return false;
554  }
555 
556  if( !decisionTreeNode->loadFromFile( file ) ){
557  Classifier::errorLog <<"loadModelFromFile(fstream &file) - Failed to load decisionTreeNode settings from file!" << std::endl;
558  return false;
559  }
560  }else{
561  Classifier::errorLog <<"loadModelFromFile(fstream &file) - Failed to load decisionTreeNode! DecisionTreeNodeType is NULL!" << std::endl;
562  return false;
563  }
564 
565  file >> word;
566  if(word != "MinNumSamplesPerNode:"){
567  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
568  return false;
569  }
570  file >> minNumSamplesPerNode;
571 
572  file >> word;
573  if(word != "MaxDepth:"){
574  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MaxDepth!" << std::endl;
575  return false;
576  }
577  file >> maxDepth;
578 
579  file >> word;
580  if(word != "RemoveFeaturesAtEachSpilt:"){
581  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
582  return false;
583  }
584  file >> removeFeaturesAtEachSpilt;
585 
586  file >> word;
587  if(word != "TrainingMode:"){
588  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TrainingMode!" << std::endl;
589  return false;
590  }
591  file >> trainingMode;
592 
593  file >> word;
594  if(word != "NumSplittingSteps:"){
595  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NumSplittingSteps!" << std::endl;
596  return false;
597  }
598  file >> numSplittingSteps;
599 
600  file >> word;
601  if(word != "TreeBuilt:"){
602  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TreeBuilt!" << std::endl;
603  return false;
604  }
605  file >> trained;
606 
607  if( trained ){
608  file >> word;
609  if(word != "Tree:"){
610  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the Tree!" << std::endl;
611  return false;
612  }
613 
614  //Create a new DTree
615  tree = dynamic_cast< DecisionTreeNode* >( decisionTreeNode->createNewInstance() );
616 
617  if( tree == NULL ){
618  clear();
619  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
620  return false;
621  }
622 
623  tree->setParent( NULL );
624  if( !tree->loadFromFile( file ) ){
625  clear();
626  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to load tree from file!" << std::endl;
627  return false;
628  }
629 
630  //Load the null rejection data if needed
631  if( useNullRejection ){
632 
633  UINT numNodes = 0;
634  classClusterMean.resize( numClasses );
635  classClusterStdDev.resize( numClasses );
636 
637  file >> word;
638  if(word != "ClassClusterMean:"){
639  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the ClassClusterMean header!" << std::endl;
640  return false;
641  }
642  for(UINT k=0; k<numClasses; k++){
643  file >> classClusterMean[k];
644  }
645 
646  file >> word;
647  if(word != "ClassClusterStdDev:"){
648  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the ClassClusterStdDev header!" << std::endl;
649  return false;
650  }
651  for(UINT k=0; k<numClasses; k++){
652  file >> classClusterStdDev[k];
653  }
654 
655  file >> word;
656  if(word != "NumNodes:"){
657  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NumNodes header!" << std::endl;
658  return false;
659  }
660  file >> numNodes;
661 
662  file >> word;
663  if(word != "NodeClusters:"){
664  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NodeClusters header!" << std::endl;
665  return false;
666  }
667 
668  UINT nodeID = 0;
669  VectorFloat cluster( numInputDimensions );
670  for(UINT i=0; i<numNodes; i++){
671 
672  //load the nodeID
673  file >> nodeID;
674 
675  for(UINT j=0; j<numInputDimensions; j++){
676  file >> cluster[j];
677  }
678 
679  //Add the cluster to the cluster nodes map
680  nodeClusters[ nodeID ] = cluster;
681  }
682 
683  //Recompute the null rejection thresholds
685  }
686 
687  //Resize the prediction results to make sure it is setup for realtime prediction
688  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
689  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
690  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
691  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
692  }
693 
694  return true;
695 }
696 
697 bool DecisionTree::getModel( std::ostream &stream ) const{
698 
699  if( tree != NULL )
700  return tree->getModel( stream );
701  return false;
702 
703 }
704 
706 
707  if( tree == NULL ){
708  return NULL;
709  }
710 
711  return dynamic_cast< DecisionTreeNode* >( tree->deepCopyNode() );
712 }
713 
715 
716  if( decisionTreeNode == NULL ){
717  return NULL;
718  }
719 
720  return decisionTreeNode->deepCopy();
721 }
722 
724  return dynamic_cast< DecisionTreeNode* >( tree );
725 }
726 
728 
729  if( decisionTreeNode != NULL ){
730  delete decisionTreeNode;
731  decisionTreeNode = NULL;
732  }
733  this->decisionTreeNode = node.deepCopy();
734 
735  return true;
736 }
737 
//Recursively grows the decision tree from trainingData.
//@param trainingData the samples reaching this node; cleared in place after the split to bound memory use.
//@param parent the node that will own the returned node (NULL for the root).
//@param features the feature indices still available for splitting; passed by value so each branch works on its own copy.
//@param classLabels the full set of class labels for the model.
//@param nodeID the ID to assign to this node.
//NOTE(review): nodeID is passed by value, so IDs generated in one branch are not
//visible to sibling branches and may repeat across the tree - confirm whether the
//nodeClusters keys built here are expected to be globally unique.
//@return the new (sub)tree root, or NULL on failure.
DecisionTreeNode* DecisionTree::buildTree(ClassificationData &trainingData,DecisionTreeNode *parent,Vector< UINT > features,const Vector< UINT > &classLabels, UINT nodeID){

    const UINT M = trainingData.getNumSamples();
    const UINT N = trainingData.getNumDimensions();

    //Update the nodeID
    nodeID++;

    //Get the depth
    UINT depth = 0;

    if( parent != NULL )
        depth = parent->getDepth() + 1;

    //If there are no training data then return NULL
    if( trainingData.getNumSamples() == 0 )
        return NULL;

    //Create the new node by cloning the template node
    DecisionTreeNode *node = dynamic_cast< DecisionTreeNode* >( decisionTreeNode->createNewInstance() );

    if( node == NULL )
        return NULL;

    //Get the class probabilities
    VectorFloat classProbs = trainingData.getClassProbabilities( classLabels );

    //Set the parent
    node->initNode( parent, depth, nodeID );

    //If all the training data belongs to the same class or there are no features left then create a leaf node and return
    //Leaf conditions: pure node, no remaining features, too few samples, or max depth reached
    if( trainingData.getNumClasses() == 1 || features.size() == 0 || M < minNumSamplesPerNode || depth >= maxDepth ){

        //Set the node
        node->setLeafNode( trainingData.getNumSamples(), classProbs );

        //Build the null cluster if null rejection is enabled
        if( useNullRejection ){
            nodeClusters[ nodeID ] = trainingData.getMean();
        }

        std::string info = "Reached leaf node.";
        if( trainingData.getNumClasses() == 1 ) info = "Reached pure leaf node.";
        else if( features.size() == 0 ) info = "Reached leaf node, no remaining features.";
        else if( M < minNumSamplesPerNode ) info = "Reached leaf node, hit min-samples-per-node limit.";
        else if( depth >= maxDepth ) info = "Reached leaf node, max depth reached.";

        Classifier::trainingLog << info << " Depth: " << depth << " NumSamples: " << trainingData.getNumSamples();

        Classifier::trainingLog << " Class Probabilities: ";
        for(UINT k=0; k<classProbs.getSize(); k++){
            Classifier::trainingLog << classProbs[k] << " ";
        }
        Classifier::trainingLog << std::endl;

        return node;
    }

    //Compute the best spilt point
    UINT featureIndex = 0;
    Float minError = 0;

    if( !node->computeBestSpilt( trainingMode, numSplittingSteps, trainingData, features, classLabels, featureIndex, minError ) ){
        //The node could not find a valid split - abandon this branch
        delete node;
        return NULL;
    }

    Classifier::trainingLog << "Depth: " << depth << " FeatureIndex: " << featureIndex << " MinError: " << minError;
    Classifier::trainingLog << " Class Probabilities: ";
    for(size_t k=0; k<classProbs.size(); k++){
        Classifier::trainingLog << classProbs[k] << " ";
    }
    Classifier::trainingLog << std::endl;

    //Remove the selected feature so we will not use it again
    if( removeFeaturesAtEachSpilt ){
        for(size_t i=0; i<features.size(); i++){
            if( features[i] == featureIndex ){
                features.erase( features.begin()+i );
                break;
            }
        }
    }

    //Split the data into a left and right dataset
    ClassificationData lhs(N);
    ClassificationData rhs(N);

    //Reserve the memory to speed up the allocation of large datasets
    lhs.reserve( M );
    rhs.reserve( M );

    //Route each sample left or right using the node's freshly-computed split
    for(UINT i=0; i<M; i++){
        if( node->predict( trainingData[i].getSample() ) ){
            rhs.addSample(trainingData[i].getClassLabel(), trainingData[i].getSample());
        }else lhs.addSample(trainingData[i].getClassLabel(), trainingData[i].getSample());
    }

    //Clear the parent dataset so we do not run out of memory with very large datasets (with very deep trees)
    trainingData.clear();

    //Get the new node IDs for the children
    UINT leftNodeID = ++nodeID;
    UINT rightNodeID = ++nodeID;

    //Run the recursive tree building on the children
    node->setLeftChild( buildTree( lhs, node, features, classLabels, leftNodeID ) );
    node->setRightChild( buildTree( rhs, node, features, classLabels, rightNodeID ) );

    //Build the null clusters for the rhs and lhs nodes if null rejection is enabled
    if( useNullRejection ){
        nodeClusters[ leftNodeID ] = lhs.getMean();
        nodeClusters[ rightNodeID ] = rhs.getMean();
    }

    return node;
}
855 
856 Float DecisionTree::getNodeDistance( const VectorFloat &x, const UINT nodeID ){
857 
858  //Use the node ID to find the node cluster
859  std::map< UINT,VectorFloat >::iterator iter = nodeClusters.find( nodeID );
860 
861  //If we failed to find a match, return NAN
862  if( iter == nodeClusters.end() ) return NAN;
863 
864  //Compute the distance between the input and the node cluster
865  return getNodeDistance( x, iter->second );
866 }
867 
868 Float DecisionTree::getNodeDistance( const VectorFloat &x, const VectorFloat &y ){
869 
870  Float distance = 0;
871  const size_t N = x.size();
872 
873  for(size_t i=0; i<N; i++){
874  distance += MLBase::SQR( x[i] - y[i] );
875  }
876 
877  //Return the squared Euclidean distance instead of actual Euclidean distance as this is faster and just as useful
878  return distance;
879 }
880 
//Loads a model saved in the legacy V1.0 file format.
//The V1.0 format stored the base settings inline (NumFeatures/NumClasses/...)
//rather than via loadBaseSettingsFromFile, and used a plain DecisionTreeNode
//for every tree node.
//@param file an open input file stream positioned just after the V1.0 header.
//@return true if the model was loaded, false otherwise.
bool DecisionTree::loadLegacyModelFromFile_v1( std::fstream &file ){

    std::string word;

    file >> word;
    if(word != "NumFeatures:"){
        Classifier::errorLog << "loadModelFromFile(string filename) - Could not find NumFeatures!" << std::endl;
        return false;
    }
    file >> numInputDimensions;

    file >> word;
    if(word != "NumClasses:"){
        Classifier::errorLog << "loadModelFromFile(string filename) - Could not find NumClasses!" << std::endl;
        return false;
    }
    file >> numClasses;

    file >> word;
    if(word != "UseScaling:"){
        Classifier::errorLog << "loadModelFromFile(string filename) - Could not find UseScaling!" << std::endl;
        return false;
    }
    file >> useScaling;

    file >> word;
    if(word != "UseNullRejection:"){
        Classifier::errorLog << "loadModelFromFile(string filename) - Could not find UseNullRejection!" << std::endl;
        return false;
    }
    file >> useNullRejection;

    //The ranges are only written when scaling is enabled, so only read them in that case
    if( useScaling ){
        //Resize the ranges buffer
        ranges.resize( numInputDimensions );

        file >> word;
        if(word != "Ranges:"){
            Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the Ranges!" << std::endl;
            return false;
        }
        for(UINT n=0; n<ranges.size(); n++){
            file >> ranges[n].minValue;
            file >> ranges[n].maxValue;
        }
    }

    file >> word;
    if(word != "NumSplittingSteps:"){
        Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NumSplittingSteps!" << std::endl;
        return false;
    }
    file >> numSplittingSteps;

    file >> word;
    if(word != "MinNumSamplesPerNode:"){
        Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
        return false;
    }
    file >> minNumSamplesPerNode;

    file >> word;
    if(word != "MaxDepth:"){
        Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MaxDepth!" << std::endl;
        return false;
    }
    file >> maxDepth;

    file >> word;
    if(word != "RemoveFeaturesAtEachSpilt:"){
        Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
        return false;
    }
    file >> removeFeaturesAtEachSpilt;

    file >> word;
    if(word != "TrainingMode:"){
        Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TrainingMode!" << std::endl;
        return false;
    }
    file >> trainingMode;

    file >> word;
    if(word != "TreeBuilt:"){
        Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TreeBuilt!" << std::endl;
        return false;
    }
    file >> trained;

    //Only attempt to read the tree when the file says one was saved
    if( trained ){
        file >> word;
        if(word != "Tree:"){
            Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the Tree!" << std::endl;
            return false;
        }

        //Create a new DTree
        tree = new DecisionTreeNode;

        if( tree == NULL ){
            clear();
            Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
            return false;
        }

        tree->setParent( NULL );
        if( !tree->loadFromFile( file ) ){
            clear();
            Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to load tree from file!" << std::endl;
            return false;
        }
    }

    return true;
}
997 
998 bool DecisionTree::loadLegacyModelFromFile_v2( std::fstream &file ){
999 
1000  std::string word;
1001 
1002  //Load the base settings from the file
1004  Classifier::errorLog << "loadModelFromFile(string filename) - Failed to load base settings from file!" << std::endl;
1005  return false;
1006  }
1007 
1008  file >> word;
1009  if(word != "NumSplittingSteps:"){
1010  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NumSplittingSteps!" << std::endl;
1011  return false;
1012  }
1013  file >> numSplittingSteps;
1014 
1015  file >> word;
1016  if(word != "MinNumSamplesPerNode:"){
1017  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
1018  return false;
1019  }
1020  file >> minNumSamplesPerNode;
1021 
1022  file >> word;
1023  if(word != "MaxDepth:"){
1024  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MaxDepth!" << std::endl;
1025  return false;
1026  }
1027  file >> maxDepth;
1028 
1029  file >> word;
1030  if(word != "RemoveFeaturesAtEachSpilt:"){
1031  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
1032  return false;
1033  }
1034  file >> removeFeaturesAtEachSpilt;
1035 
1036  file >> word;
1037  if(word != "TrainingMode:"){
1038  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TrainingMode!" << std::endl;
1039  return false;
1040  }
1041  file >> trainingMode;
1042 
1043  file >> word;
1044  if(word != "TreeBuilt:"){
1045  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TreeBuilt!" << std::endl;
1046  return false;
1047  }
1048  file >> trained;
1049 
1050  if( trained ){
1051  file >> word;
1052  if(word != "Tree:"){
1053  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the Tree!" << std::endl;
1054  return false;
1055  }
1056 
1057  //Create a new DTree
1058  tree = new DecisionTreeNode;
1059 
1060  if( tree == NULL ){
1061  clear();
1062  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
1063  return false;
1064  }
1065 
1066  tree->setParent( NULL );
1067  if( !tree->loadFromFile( file ) ){
1068  clear();
1069  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to load tree from file!" << std::endl;
1070  return false;
1071  }
1072 
1073  //Recompute the null rejection thresholds
1075 
1076  //Resize the prediction results to make sure it is setup for realtime prediction
1077  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
1078  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
1079  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
1080  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
1081  }
1082 
1083  return true;
1084 }
1085 
1086 bool DecisionTree::loadLegacyModelFromFile_v3( std::fstream &file ){
1087 
1088  std::string word;
1089 
1090  //Load the base settings from the file
1092  Classifier::errorLog << "loadModelFromFile(string filename) - Failed to load base settings from file!" << std::endl;
1093  return false;
1094  }
1095 
1096  file >> word;
1097  if(word != "NumSplittingSteps:"){
1098  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NumSplittingSteps!" << std::endl;
1099  return false;
1100  }
1101  file >> numSplittingSteps;
1102 
1103  file >> word;
1104  if(word != "MinNumSamplesPerNode:"){
1105  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
1106  return false;
1107  }
1108  file >> minNumSamplesPerNode;
1109 
1110  file >> word;
1111  if(word != "MaxDepth:"){
1112  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MaxDepth!" << std::endl;
1113  return false;
1114  }
1115  file >> maxDepth;
1116 
1117  file >> word;
1118  if(word != "RemoveFeaturesAtEachSpilt:"){
1119  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
1120  return false;
1121  }
1122  file >> removeFeaturesAtEachSpilt;
1123 
1124  file >> word;
1125  if(word != "TrainingMode:"){
1126  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TrainingMode!" << std::endl;
1127  return false;
1128  }
1129  file >> trainingMode;
1130 
1131  file >> word;
1132  if(word != "TreeBuilt:"){
1133  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TreeBuilt!" << std::endl;
1134  return false;
1135  }
1136  file >> trained;
1137 
1138  if( trained ){
1139  file >> word;
1140  if(word != "Tree:"){
1141  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the Tree!" << std::endl;
1142  return false;
1143  }
1144 
1145  //Create a new DTree
1146  tree = new DecisionTreeNode;
1147 
1148  if( tree == NULL ){
1149  clear();
1150  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
1151  return false;
1152  }
1153 
1154  tree->setParent( NULL );
1155  if( !tree->loadFromFile( file ) ){
1156  clear();
1157  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to load tree from file!" << std::endl;
1158  return false;
1159  }
1160 
1161  //Load the null rejection data if needed
1162  if( useNullRejection ){
1163 
1164  UINT numNodes = 0;
1165  classClusterMean.resize( numClasses );
1166  classClusterStdDev.resize( numClasses );
1167 
1168  file >> word;
1169  if(word != "ClassClusterMean:"){
1170  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the ClassClusterMean header!" << std::endl;
1171  return false;
1172  }
1173  for(UINT k=0; k<numClasses; k++){
1174  file >> classClusterMean[k];
1175  }
1176 
1177  file >> word;
1178  if(word != "ClassClusterStdDev:"){
1179  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the ClassClusterStdDev header!" << std::endl;
1180  return false;
1181  }
1182  for(UINT k=0; k<numClasses; k++){
1183  file >> classClusterStdDev[k];
1184  }
1185 
1186  file >> word;
1187  if(word != "NumNodes:"){
1188  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NumNodes header!" << std::endl;
1189  return false;
1190  }
1191  file >> numNodes;
1192 
1193  file >> word;
1194  if(word != "NodeClusters:"){
1195  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NodeClusters header!" << std::endl;
1196  return false;
1197  }
1198 
1199  UINT nodeID = 0;
1200  VectorFloat cluster( numInputDimensions );
1201  for(UINT i=0; i<numNodes; i++){
1202 
1203  //load the nodeID
1204  file >> nodeID;
1205 
1206  for(UINT j=0; j<numInputDimensions; j++){
1207  file >> cluster[j];
1208  }
1209 
1210  //Add the cluster to the cluster nodes map
1211  nodeClusters[ nodeID ] = cluster;
1212  }
1213 
1214  //Recompute the null rejection thresholds
1216  }
1217 
1218  //Resize the prediction results to make sure it is setup for realtime prediction
1219  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
1220  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
1221  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
1222  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
1223  }
1224 
1225  return true;
1226 }
1227 
1228 GRT_END_NAMESPACE
1229 
virtual bool saveModelToFile(std::fstream &file) const
bool saveBaseSettingsToFile(std::fstream &file) const
Definition: Classifier.cpp:255
This class implements a basic Decision Tree classifier. Decision Trees are conceptually simple classi...
#define DEFAULT_NULL_LIKELIHOOD_VALUE
Definition: Classifier.h:38
virtual bool computeBestSpilt(const UINT &trainingMode, const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError)
std::string getClassifierType() const
Definition: Classifier.cpp:160
const DecisionTreeNode * getTree() const
DecisionTreeNode * deepCopyDecisionTreeNode() const
bool loadLegacyModelFromFile_v1(std::fstream &file)
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
virtual bool train_(ClassificationData &trainingData)
std::string getNodeType() const
Definition: Node.cpp:303
bool getTrained() const
Definition: MLBase.cpp:254
virtual bool getModel(std::ostream &stream) const
Definition: Node.cpp:119
virtual bool predict_(VectorFloat &inputVector)
UINT getDepth() const
Definition: Node.cpp:307
UINT getClassLabelIndexValue(UINT classLabel) const
Definition: Classifier.cpp:194
Vector< UINT > getClassLabels() const
DecisionTree(const DecisionTreeNode &decisionTreeNode=DecisionTreeClusterNode(), const UINT minNumSamplesPerNode=5, const UINT maxDepth=10, const bool removeFeaturesAtEachSpilt=false, const UINT trainingMode=BEST_ITERATIVE_SPILT, const UINT numSplittingSteps=100, const bool useScaling=false)
virtual bool deepCopyFrom(const Classifier *classifier)
unsigned int getSize() const
Definition: Vector.h:193
DecisionTreeNode * deepCopyTree() const
virtual bool recomputeNullRejectionThresholds()
UINT getNumSamples() const
virtual Node * deepCopyNode() const
Definition: Node.cpp:275
virtual bool saveToFile(std::fstream &file) const
Definition: Node.cpp:139
virtual bool clear()
virtual bool loadModelFromFile(std::fstream &file)
virtual bool loadFromFile(std::fstream &file)
Definition: Node.cpp:181
bool copyBaseVariables(const Classifier *classifier)
Definition: Classifier.cpp:92
bool loadBaseSettingsFromFile(std::fstream &file)
Definition: Classifier.cpp:302
ClassificationData partition(const UINT partitionPercentage, const bool useStratifiedSampling=false)
virtual bool clear()
Definition: Node.cpp:69
UINT getNumDimensions() const
UINT getNumClasses() const
Node * createNewInstance() const
Definition: Node.cpp:38
UINT getPredictedNodeID() const
Definition: Node.cpp:315
DecisionTree & operator=(const DecisionTree &rhs)
virtual bool predict(const VectorFloat &x, VectorFloat &classLikelihoods)
bool setLeafNode(const UINT nodeSize, const VectorFloat &classProbabilities)
DecisionTreeNode * deepCopy() const
Vector< MinMax > getRanges() const
static Node * createInstanceFromString(std::string const &nodeType)
Definition: Node.cpp:28
static unsigned int getMaxIndex(const VectorFloat &x)
Definition: Util.cpp:291
bool setDecisionTreeNode(const DecisionTreeNode &node)
virtual ~DecisionTree(void)
bool scale(const Float minTarget, const Float maxTarget)
virtual bool getModel(std::ostream &stream) const
virtual bool clear()
Definition: Classifier.cpp:141
virtual bool predict(const VectorFloat &x)
Definition: Node.cpp:59
VectorFloat getMean() const