GestureRecognitionToolkit Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
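Before the source listing, here is a minimal usage sketch for the DecisionTree class defined in this file. It is a sketch under stated assumptions: the dataset file names are illustrative, and GRT is assumed to be installed with its headers on the include path.

#include <GRT/GRT.h>
#include <iostream>
using namespace GRT;

int main(){
    //Load a labelled training dataset (file name is illustrative)
    ClassificationData trainingData;
    if( !trainingData.load( "gesture_training_data.grt" ) ) return 1;

    //Create the tree: cluster-node splits, min 5 samples per node, max depth 10
    DecisionTree dt( DecisionTreeClusterNode(), 5, 10 );

    //Train the model; train() wraps the train_ function shown below
    if( !dt.train( trainingData ) ) return 1;

    //Classify a new sample and read back the prediction
    VectorFloat sample( trainingData.getNumDimensions(), 0.0 );
    if( dt.predict( sample ) ){
        std::cout << "Predicted class label: " << dt.getPredictedClassLabel() << std::endl;
    }

    //Save the trained model so it can be reloaded later with load()
    dt.save( "decision_tree_model.grt" );
    return 0;
}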
DecisionTree.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 */
20 
21 #define GRT_DLL_EXPORTS
22 #include "DecisionTree.h"
23 
24 GRT_BEGIN_NAMESPACE
25 
26 //Define the string that will be used to identify the object
27 const std::string DecisionTree::id = "DecisionTree";
28 std::string DecisionTree::getId() { return DecisionTree::id; }
29 
30 //Register the DecisionTree module with the Classifier base class
31 RegisterClassifierModule< DecisionTree > DecisionTree::registerModule( DecisionTree::getId() );
32 
33 DecisionTree::DecisionTree(const DecisionTreeNode &decisionTreeNode,const UINT minNumSamplesPerNode,const UINT maxDepth,const bool removeFeaturesAtEachSplit,const Tree::TrainingMode trainingMode,const UINT numSplittingSteps,const bool useScaling) : Classifier( DecisionTree::getId() )
34 {
35  this->tree = NULL;
36  this->decisionTreeNode = dynamic_cast< DecisionTreeNode* >( decisionTreeNode.deepCopy() );
37  this->minNumSamplesPerNode = minNumSamplesPerNode;
38  this->maxDepth = maxDepth;
39  this->removeFeaturesAtEachSplit = removeFeaturesAtEachSplit;
40  this->trainingMode = trainingMode;
41  this->numSplittingSteps = numSplittingSteps;
42  this->useScaling = useScaling;
43  this->supportsNullRejection = true;
44  this->numTrainingIterationsToConverge = 20; //Retrain the model 20 times and pick the best one
45  classifierMode = STANDARD_CLASSIFIER_MODE;
46 }
47 
48 DecisionTree::DecisionTree(const DecisionTree &rhs) : Classifier( DecisionTree::getId() )
49 {
50  tree = NULL;
51  decisionTreeNode = NULL;
52  classifierMode = STANDARD_CLASSIFIER_MODE;
53  *this = rhs;
54 }
55 
56 DecisionTree::~DecisionTree(void)
57 {
58  clear();
59 
60  if( decisionTreeNode != NULL ){
61  delete decisionTreeNode;
62  decisionTreeNode = NULL;
63  }
64 }
65 
66 DecisionTree& DecisionTree::operator=(const DecisionTree &rhs){
67  if( this != &rhs ){
68  //Clear this tree
69  clear();
70 
71  if( rhs.getTrained() ){
72  //Deep copy the tree
73  this->tree = (DecisionTreeNode*)rhs.deepCopyTree();
74  }
75 
76  //Deep copy the main node
77  if( this->decisionTreeNode != NULL ){
78  delete decisionTreeNode;
79  decisionTreeNode = NULL;
80  }
81  this->decisionTreeNode = rhs.deepCopyDecisionTreeNode();
82 
83  this->minNumSamplesPerNode = rhs.minNumSamplesPerNode;
84  this->maxDepth = rhs.maxDepth;
85  this->removeFeaturesAtEachSplit = rhs.removeFeaturesAtEachSplit;
86  this->trainingMode = rhs.trainingMode;
87  this->numSplittingSteps = rhs.numSplittingSteps;
88  this->nodeClusters = rhs.nodeClusters;
89 
90  //Copy the base classifier variables
91  copyBaseVariables( (Classifier*)&rhs );
92  }
93  return *this;
94 }
95 
96 bool DecisionTree::deepCopyFrom(const Classifier *classifier){
97 
98  if( classifier == NULL ) return false;
99 
100  if( this->getClassifierType() == classifier->getClassifierType() ){
101 
102  const DecisionTree *ptr = dynamic_cast<const DecisionTree*>( classifier );
103 
104  //Clear this tree
105  this->clear();
106 
107  if( ptr->getTrained() ){
108  //Deep copy the tree
109  this->tree = ptr->deepCopyTree();
110  }
111 
112  //Deep copy the main node
113  if( this->decisionTreeNode != NULL ){
114  delete decisionTreeNode;
115  decisionTreeNode = NULL;
116  }
117  this->decisionTreeNode = ptr->deepCopyDecisionTreeNode();
118 
119  this->minNumSamplesPerNode = ptr->minNumSamplesPerNode;
120  this->maxDepth = ptr->maxDepth;
121  this->removeFeaturesAtEachSplit = ptr->removeFeaturesAtEachSplit;
122  this->trainingMode = ptr->trainingMode;
123  this->numSplittingSteps = ptr->numSplittingSteps;
124  this->nodeClusters = ptr->nodeClusters;
125 
126  //Copy the base classifier variables
127  return copyBaseVariables( classifier );
128  }
129  return false;
130 }
131 
132 bool DecisionTree::train_(ClassificationData &trainingData){
133 
134  //Clear any previous model
135  clear();
136 
137  if( decisionTreeNode == NULL ){
138  errorLog << __GRT_LOG__ << " The decision tree node has not been set! You must set this first before training a model." << std::endl;
139  return false;
140  }
141 
142  const unsigned int N = trainingData.getNumDimensions();
143  const unsigned int K = trainingData.getNumClasses();
144 
145  if( trainingData.getNumSamples() == 0 ){
146  errorLog << __GRT_LOG__ << " Training data has zero samples!" << std::endl;
147  return false;
148  }
149 
150  numInputDimensions = N;
151  numOutputDimensions = K;
152  numClasses = K;
153  classLabels = trainingData.getClassLabels();
154  ranges = trainingData.getRanges();
155 
156  //Get the validation set if needed
157  ClassificationData validationData;
158  if( useValidationSet ){
159  validationData = trainingData.split( validationSetSize );
160  validationSetAccuracy = 0;
161  validationSetPrecision.resize( useNullRejection ? K+1 : K, 0 );
162  validationSetRecall.resize( useNullRejection ? K+1 : K, 0 );
163  }
164 
165  const unsigned int M = trainingData.getNumSamples();
166 
167  //Scale the training data if needed
168  if( useScaling ){
169  //Scale the training data between 0 and 1
170  trainingData.scale(0, 1);
171  }
172 
173  //If we are using null rejection, then we need a copy of the training dataset for later
174  ClassificationData trainingDataCopy;
175  if( useNullRejection ){
176  trainingDataCopy = trainingData;
177  }
178 
179  //Setup the valid features - at this point all features can be used
180  Vector< UINT > features(N);
181  for(UINT i=0; i<N; i++){
182  features[i] = i;
183  }
184 
185  numTrainingIterationsToConverge = 1;
186  trainingLog << "numTrainingIterationsToConverge " << numTrainingIterationsToConverge << " useValidationSet: " << useValidationSet << std::endl;
187 
188  if( useValidationSet ){
189 
190  //If we get here, then we are going to train the tree several times and pick the best model
191  DecisionTreeNode *bestTree = NULL;
192  Float bestValidationSetAccuracy = 0;
193  UINT bestTreeIndex = 0;
194  for(UINT i=0; i<numTrainingIterationsToConverge; i++){
195 
196  trainingLog << "Training tree iteration: " << i+1 << "/" << numTrainingIterationsToConverge << std::endl;
197 
198  if( !trainTree( trainingData, trainingDataCopy, validationData, features ) ){
199  errorLog << __GRT_LOG__ << " Failed to build tree!" << std::endl;
200  //Delete the best tree if it exists
201  if( bestTree != NULL ){
202  delete bestTree;
203  bestTree = NULL;
204  }
205  return false;
206  }
207 
208  //Keep track of the best tree
209  if( bestTree == NULL ){
210  //Grab the tree pointer
211  bestTree = tree;
212  bestValidationSetAccuracy = validationSetAccuracy;
213  bestTreeIndex = i;
214  }else{
215  if( validationSetAccuracy > bestValidationSetAccuracy ){
216  //Grab the tree pointer
217  bestTree = tree;
218  bestValidationSetAccuracy = validationSetAccuracy;
219  bestTreeIndex = i;
220  }else{
221  //If we get here then the current tree is not the best so far, so free it
222  if( tree != NULL ){
223  delete tree;
224  tree = NULL;
225  }
226  }
227  }
228  }
229 
230  //Use the best tree
231  trainingLog << "Best tree index: " << bestTreeIndex+1 << " validation set accuracy: " << bestValidationSetAccuracy << std::endl;
232 
233  if( bestTree != tree ){
234  //Delete the tree if it exists
235  if( tree != NULL ){
236  delete tree;
237  tree = NULL;
238  }
239 
240  //Swap the pointers
241  tree = bestTree;
242  }
243 
244  }else{
245  //If we get here, then we are going to train the tree once
246  if( !trainTree( trainingData, trainingDataCopy, validationData, features ) ){
247  return false;
248  }
249  }
250 
251  //If we get this far, then a model has been trained
252  converged = true;
253 
254  //Setup the data for prediction
255  predictedClassLabel = 0;
256  maxLikelihood = 0;
257  classLikelihoods.resize(numClasses);
258  classDistances.resize(numClasses);
259 
260  //Compute the final training stats
261  trainingSetAccuracy = 0;
262  validationSetAccuracy = 0;
263 
264  //If scaling was on, then the data will already be scaled, so turn it off temporarily
265  bool scalingState = useScaling;
266  useScaling = false;
267  for(UINT i=0; i<M; i++){
268  if( !predict_( trainingData[i].getSample() ) ){
269  trained = false;
270  errorLog << __GRT_LOG__ << " Failed to run prediction for training sample: " << i << "! Failed to fully train model!" << std::endl;
271  return false;
272  }
273 
274  if( predictedClassLabel == trainingData[i].getClassLabel() ){
275  trainingSetAccuracy++;
276  }
277  }
278 
279  if( useValidationSet ){
280  for(UINT i=0; i<validationData.getNumSamples(); i++){
281  if( !predict_( validationData[i].getSample() ) ){
282  trained = false;
283  errorLog << __GRT_LOG__ << " Failed to run prediction for validation sample: " << i << "! Failed to fully train model!" << std::endl;
284  return false;
285  }
286 
287  if( predictedClassLabel == validationData[i].getClassLabel() ){
288  validationSetAccuracy++;
289  }
290  }
291  }
292 
293  trainingSetAccuracy = trainingSetAccuracy / M * 100.0;
294 
295  trainingLog << "Training set accuracy: " << trainingSetAccuracy << std::endl;
296 
297  if( useValidationSet ){
298  validationSetAccuracy = validationSetAccuracy / validationData.getNumSamples() * 100.0;
299  trainingLog << "Validation set accuracy: " << validationSetAccuracy << std::endl;
300  }
301 
302  //Reset the scaling state for future prediction
303  useScaling = scalingState;
304 
305  return true;
306 }
307 
308 bool DecisionTree::trainTree( ClassificationData trainingData, const ClassificationData &trainingDataCopy, const ClassificationData &validationData, Vector< UINT > features ){
309 
310  //Note, this function is only called internally by the decision tree, users should call train_ instead.
311 
312  const unsigned int M = trainingData.getNumSamples();
313 
314  //Build the tree
315  UINT nodeID = 0;
316  tree = buildTree( trainingData, NULL, features, classLabels, nodeID );
317 
318  if( tree == NULL ){
319  clear();
320  errorLog << __GRT_LOG__ << " Failed to build tree!" << std::endl;
321  return false;
322  }
323 
324  //Flag that the algorithm has been trained
325  trained = true;
326 
327  //Compute the null rejection thresholds if null rejection is enabled
328  if( useNullRejection ){
329  VectorFloat classLikelihoods( numClasses );
330  Vector< UINT > predictions(M);
331  VectorFloat distances(M);
332  VectorFloat classCounter( numClasses, 0 );
333 
334  //Run over the training dataset and compute the distance between each training sample and the predicted node cluster
335  VectorFloat sample;
336  for(UINT i=0; i<M; i++){
337  //Run the prediction for this sample
338  sample = trainingDataCopy[i].getSample();
339  if( !tree->predict_( sample, classLikelihoods ) ){
340  errorLog << __GRT_LOG__ << " Failed to predict training sample while building null rejection model!" << std::endl;
341  return false;
342  }
343 
344  //Store the predicted class index and cluster distance
345  predictions[i] = Util::getMaxIndex( classLikelihoods );
346  distances[i] = getNodeDistance(sample, tree->getPredictedNodeID() );
347 
348  classCounter[ predictions[i] ]++;
349  }
350 
351  //Compute the average distance for each class between the training data and the node clusters
352  classClusterMean.clear();
353  classClusterStdDev.clear();
354  classClusterMean.resize( numClasses, 0 );
355  classClusterStdDev.resize( numClasses, 0.01 ); //we start the std dev with a small value to ensure it is not zero
356 
357  for(UINT i=0; i<M; i++){
358  classClusterMean[ predictions[i] ] += distances[ i ];
359  }
360  for(UINT k=0; k<numClasses; k++){
361  classClusterMean[k] /= grt_max( classCounter[k], 1 );
362  }
363 
364  //Compute the std deviation
365  for(UINT i=0; i<M; i++){
366  classClusterStdDev[ predictions[i] ] += MLBase::SQR( distances[ i ] - classClusterMean[ predictions[i] ] );
367  }
368  for(UINT k=0; k<numClasses; k++){
369  classClusterStdDev[k] = sqrt( classClusterStdDev[k] / grt_max( classCounter[k], 1 ) );
370  }
371 
372  //Compute the null rejection thresholds using the class mean and std dev
373  recomputeNullRejectionThresholds();
374  }
375 
376  if( useValidationSet ){
377  const UINT numTestSamples = validationData.getNumSamples();
378  double numCorrect = 0;
379  UINT testLabel = 0;
380  VectorFloat testSample;
381  VectorFloat validationSetPrecisionCounter( validationSetPrecision.size(), 0.0 );
382  VectorFloat validationSetRecallCounter( validationSetRecall.size(), 0.0 );
383  trainingLog << "Testing model with validation set..." << std::endl;
384  for(UINT i=0; i<numTestSamples; i++){
385  testLabel = validationData[i].getClassLabel();
386  testSample = validationData[i].getSample();
387  predict_( testSample );
388  if( predictedClassLabel == testLabel ){
389  numCorrect++;
390  validationSetPrecision[ getClassLabelIndexValue( testLabel ) ]++;
391  validationSetRecall[ getClassLabelIndexValue( testLabel ) ]++;
392  }
393  validationSetPrecisionCounter[ getClassLabelIndexValue( predictedClassLabel ) ]++;
394  validationSetRecallCounter[ getClassLabelIndexValue( testLabel ) ]++;
395  }
396 
397  validationSetAccuracy = (numCorrect / numTestSamples) * 100.0;
398  for(UINT i=0; i<validationSetPrecision.getSize(); i++){
399  validationSetPrecision[i] /= validationSetPrecisionCounter[i] > 0 ? validationSetPrecisionCounter[i] : 1;
400  }
401  for(UINT i=0; i<validationSetRecall.getSize(); i++){
402  validationSetRecall[i] /= validationSetRecallCounter[i] > 0 ? validationSetRecallCounter[i] : 1;
403  }
404 
405  Classifier::trainingLog << "Validation set accuracy: " << validationSetAccuracy << std::endl;
406 
407  Classifier::trainingLog << "Validation set precision: ";
408  for(UINT i=0; i<validationSetPrecision.getSize(); i++){
409  Classifier::trainingLog << validationSetPrecision[i] << " ";
410  }
411  Classifier::trainingLog << std::endl;
412 
413  Classifier::trainingLog << "Validation set recall: ";
414  for(UINT i=0; i<validationSetRecall.getSize(); i++){
415  Classifier::trainingLog << validationSetRecall[i] << " ";
416  }
417  Classifier::trainingLog << std::endl;
418  }
419 
420  return true;
421 }
422 
423 bool DecisionTree::predict_(VectorFloat &inputVector){
424 
425  predictedClassLabel = 0;
426  maxLikelihood = 0;
427 
428  //Validate the input is OK and the model is trained properly
429  if( !trained ){
430  errorLog << __GRT_LOG__ << " Model Not Trained!" << std::endl;
431  return false;
432  }
433 
434  if( tree == NULL ){
435  errorLog << __GRT_LOG__ << " DecisionTree pointer is null!" << std::endl;
436  return false;
437  }
438 
439  if( inputVector.getSize() != numInputDimensions ){
440  errorLog << __GRT_LOG__ << " The size of the input Vector (" << inputVector.getSize() << ") does not match the num features in the model (" << numInputDimensions << ")" << std::endl;
441  return false;
442  }
443 
444  //Scale the input data if needed
445  if( useScaling ){
446  for(UINT n=0; n<numInputDimensions; n++){
447  inputVector[n] = grt_scale(inputVector[n], ranges[n].minValue, ranges[n].maxValue, 0.0, 1.0);
448  }
449  }
450 
451  if( classLikelihoods.size() != numClasses ) classLikelihoods.resize(numClasses,0);
452  if( classDistances.size() != numClasses ) classDistances.resize(numClasses,0);
453 
454  //Run the decision tree prediction
455  if( !tree->predict_( inputVector, classLikelihoods ) ){
456  errorLog << __GRT_LOG__ << " Failed to predict!" << std::endl;
457  return false;
458  }
459 
460  //Find the maximum likelihood
461  //The tree automatically returns proper class likelihoods so we don't need to do anything else
462  UINT maxIndex = 0;
463  maxLikelihood = 0;
464  for(UINT k=0; k<numClasses; k++){
465  if( classLikelihoods[k] > maxLikelihood ){
466  maxLikelihood = classLikelihoods[k];
467  maxIndex = k;
468  }
469  }
470 
471  //Run the null rejection
472  if( useNullRejection ){
473 
474  //Get the distance between the input and the leaf mean
475  Float leafDistance = getNodeDistance( inputVector, tree->getPredictedNodeID() );
476 
477  if( grt_isnan(leafDistance) ){
478  errorLog << __GRT_LOG__ << " Failed to match leaf node ID to compute node distance!" << std::endl;
479  return false;
480  }
481 
482  //Set the predicted class distance as the leaf distance, all other classes will have a distance of zero
483  classDistances.setAll(0.0);
484  classDistances[ maxIndex ] = leafDistance;
485 
486  //Use the distance to check if the class label should be rejected or not
487  if( leafDistance <= nullRejectionThresholds[ maxIndex ] ){
488  predictedClassLabel = classLabels[ maxIndex ];
489  }else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
490 
491  }else {
492  //Set the predicted class label
493  predictedClassLabel = classLabels[ maxIndex ];
494  }
495 
496  return true;
497 }
498 
499 bool DecisionTree::clear(){
500 
501  //Clear the Classifier variables
502  Classifier::clear();
503 
504  //Clear the node clusters
505  nodeClusters.clear();
506 
507  //Delete the tree if it exists
508  if( tree != NULL ){
509  tree->clear();
510  delete tree;
511  tree = NULL;
512  }
513 
514  //NOTE: We do not want to clean up the decisionTreeNode here as we need to keep track of it; it is only deleted in the destructor
515 
516  return true;
517 }
518 
519 bool DecisionTree::recomputeNullRejectionThresholds(){
520 
521  if( !trained ){
522  Classifier::warningLog << __GRT_LOG__ << " Failed to recompute null rejection thresholds, the model has not been trained!" << std::endl;
523  return false;
524  }
525 
526  if( !useNullRejection ){
527  Classifier::warningLog << __GRT_LOG__ << " Failed to recompute null rejection thresholds, null rejection is not enabled!" << std::endl;
528  return false;
529  }
530 
531  nullRejectionThresholds.resize( numClasses );
532 
533  //Compute the rejection threshold for each class using the mean and std dev
534  for(UINT k=0; k<numClasses; k++){
535  nullRejectionThresholds[k] = classClusterMean[k] + (classClusterStdDev[k]*nullRejectionCoeff);
536  }
537 
538  return true;
539 }
540 
541 bool DecisionTree::save( std::fstream &file ) const{
542 
543  if(!file.is_open())
544  {
545  errorLog << __GRT_LOG__ << " The file is not open!" << std::endl;
546  return false;
547  }
548 
549  //Write the header info
550  file << "GRT_DECISION_TREE_MODEL_FILE_V4.0\n";
551 
552  //Write the classifier settings to the file
553  if( !Classifier::saveBaseSettingsToFile(file) ){
554  errorLog << __GRT_LOG__ << " Failed to save classifier base settings to file!" << std::endl;
555  return false;
556  }
557 
558  if( decisionTreeNode != NULL ){
559  file << "DecisionTreeNodeType: " << decisionTreeNode->getNodeType() << std::endl;
560  if( !decisionTreeNode->save( file ) ){
561  errorLog << __GRT_LOG__ << " Failed to save decisionTreeNode settings to file!" << std::endl;
562  return false;
563  }
564  }else{
565  file << "DecisionTreeNodeType: " << "NULL" << std::endl;
566  }
567 
568  file << "MinNumSamplesPerNode: " << minNumSamplesPerNode << std::endl;
569  file << "MaxDepth: " << maxDepth << std::endl;
570  file << "RemoveFeaturesAtEachSpilt: " << removeFeaturesAtEachSplit << std::endl; //Note: the misspelled 'Spilt' key is kept for model-file compatibility
571  file << "TrainingMode: " << trainingMode << std::endl;
572  file << "NumSplittingSteps: " << numSplittingSteps << std::endl;
573  file << "TreeBuilt: " << (tree != NULL ? 1 : 0) << std::endl;
574 
575  if( tree != NULL ){
576  file << "Tree:\n";
577  if( !tree->save( file ) ){
578  errorLog << __GRT_LOG__ << " Failed to save tree to file!" << std::endl;
579  return false;
580  }
581 
582  //Save the null rejection data if needed
583  if( useNullRejection ){
584 
585  file << "ClassClusterMean:";
586  for(UINT k=0; k<numClasses; k++){
587  file << " " << classClusterMean[k];
588  }
589  file << std::endl;
590 
591  file << "ClassClusterStdDev:";
592  for(UINT k=0; k<numClasses; k++){
593  file << " " << classClusterStdDev[k];
594  }
595  file << std::endl;
596 
597  file << "NumNodes: " << nodeClusters.size() << std::endl;
598  file << "NodeClusters:\n";
599 
600  std::map< UINT, VectorFloat >::const_iterator iter = nodeClusters.begin();
601 
602  while( iter != nodeClusters.end() ){
603 
604  //Write the nodeID
605  file << iter->first;
606 
607  //Write the node cluster
608  for(UINT j=0; j<numInputDimensions; j++){
609  file << " " << iter->second[j];
610  }
611  file << std::endl;
612 
613  iter++;
614  }
615  }
616 
617  }
618 
619  return true;
620 }
621 
622 bool DecisionTree::load( std::fstream &file ){
623 
624  clear();
625 
626  UINT tempTrainingMode = 0;
627 
628  if( decisionTreeNode != NULL ){
629  delete decisionTreeNode;
630  decisionTreeNode = NULL;
631  }
632 
633  if( !file.is_open() )
634  {
635  errorLog << __GRT_LOG__ << " Could not open file to load model" << std::endl;
636  return false;
637  }
638 
639  std::string word;
640  file >> word;
641 
642  //Check to see if we should load a legacy file
643  if( word == "GRT_DECISION_TREE_MODEL_FILE_V1.0" ){
644  return loadLegacyModelFromFile_v1( file );
645  }
646 
647  if( word == "GRT_DECISION_TREE_MODEL_FILE_V2.0" ){
648  return loadLegacyModelFromFile_v2( file );
649  }
650 
651  if( word == "GRT_DECISION_TREE_MODEL_FILE_V3.0" ){
652  return loadLegacyModelFromFile_v3( file );
653  }
654 
655  //Find the file type header
656  if( word != "GRT_DECISION_TREE_MODEL_FILE_V4.0" ){
657  errorLog << __GRT_LOG__ << " Could not find Model File Header" << std::endl;
658  return false;
659  }
660 
661  //Load the base settings from the file
662  if( !Classifier::loadBaseSettingsFromFile(file) ){
663  errorLog << __GRT_LOG__ << " Failed to load base settings from file!" << std::endl;
664  return false;
665  }
666 
667  file >> word;
668  if(word != "DecisionTreeNodeType:"){
669  errorLog << __GRT_LOG__ << " Could not find the DecisionTreeNodeType!" << std::endl;
670  return false;
671  }
672  file >> word;
673 
674  if( word != "NULL" ){
675 
676  decisionTreeNode = dynamic_cast< DecisionTreeNode* >( DecisionTreeNode::createInstanceFromString( word ) );
677 
678  if( decisionTreeNode == NULL ){
679  errorLog << __GRT_LOG__ << " Could not create new DecisionTreeNode from type: " << word << std::endl;
680  return false;
681  }
682 
683  if( !decisionTreeNode->load( file ) ){
684  errorLog << __GRT_LOG__ << " Failed to load decisionTreeNode settings from file!" << std::endl;
685  return false;
686  }
687  }else{
688  errorLog << __GRT_LOG__ << " Failed to load decisionTreeNode! DecisionTreeNodeType is NULL!" << std::endl;
689  return false;
690  }
691 
692  file >> word;
693  if(word != "MinNumSamplesPerNode:"){
694  errorLog << __GRT_LOG__ << " Could not find the MinNumSamplesPerNode!" << std::endl;
695  return false;
696  }
697  file >> minNumSamplesPerNode;
698 
699  file >> word;
700  if(word != "MaxDepth:"){
701  errorLog << __GRT_LOG__ << " Could not find the MaxDepth!" << std::endl;
702  return false;
703  }
704  file >> maxDepth;
705 
706  file >> word;
707  if(word != "RemoveFeaturesAtEachSpilt:"){
708  errorLog << __GRT_LOG__ << " Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
709  return false;
710  }
711  file >> removeFeaturesAtEachSplit;
712 
713  file >> word;
714  if(word != "TrainingMode:"){
715  errorLog << __GRT_LOG__ << " Could not find the TrainingMode!" << std::endl;
716  return false;
717  }
718  file >> tempTrainingMode;
719  trainingMode = static_cast<Tree::TrainingMode>(tempTrainingMode);
720 
721  file >> word;
722  if(word != "NumSplittingSteps:"){
723  errorLog << __GRT_LOG__ << " Could not find the NumSplittingSteps!" << std::endl;
724  return false;
725  }
726  file >> numSplittingSteps;
727 
728  file >> word;
729  if(word != "TreeBuilt:"){
730  errorLog << __GRT_LOG__ << " Could not find the TreeBuilt!" << std::endl;
731  return false;
732  }
733  file >> trained;
734 
735  if( trained ){
736  file >> word;
737  if(word != "Tree:"){
738  errorLog << __GRT_LOG__ << " Could not find the Tree!" << std::endl;
739  return false;
740  }
741 
742  //Create a new DTree
743  tree = dynamic_cast< DecisionTreeNode* >( decisionTreeNode->createNewInstance() );
744 
745  if( tree == NULL ){
746  clear();
747  errorLog << __GRT_LOG__ << " Failed to create new DecisionTreeNode!" << std::endl;
748  return false;
749  }
750 
751  tree->setParent( NULL );
752  if( !tree->load( file ) ){
753  clear();
754  errorLog << __GRT_LOG__ << " Failed to load tree from file!" << std::endl;
755  return false;
756  }
757 
758  //Load the null rejection data if needed
759  if( useNullRejection ){
760 
761  UINT numNodes = 0;
762  classClusterMean.resize( numClasses );
763  classClusterStdDev.resize( numClasses );
764 
765  file >> word;
766  if(word != "ClassClusterMean:"){
767  errorLog << __GRT_LOG__ << " Could not find the ClassClusterMean header!" << std::endl;
768  return false;
769  }
770  for(UINT k=0; k<numClasses; k++){
771  file >> classClusterMean[k];
772  }
773 
774  file >> word;
775  if(word != "ClassClusterStdDev:"){
776  errorLog << __GRT_LOG__ << " Could not find the ClassClusterStdDev header!" << std::endl;
777  return false;
778  }
779  for(UINT k=0; k<numClasses; k++){
780  file >> classClusterStdDev[k];
781  }
782 
783  file >> word;
784  if(word != "NumNodes:"){
785  errorLog << __GRT_LOG__ << " Could not find the NumNodes header!" << std::endl;
786  return false;
787  }
788  file >> numNodes;
789 
790  file >> word;
791  if(word != "NodeClusters:"){
792  errorLog << __GRT_LOG__ << " Could not find the NodeClusters header!" << std::endl;
793  return false;
794  }
795 
796  UINT nodeID = 0;
797  VectorFloat cluster( numInputDimensions );
798  for(UINT i=0; i<numNodes; i++){
799 
800  //load the nodeID
801  file >> nodeID;
802 
803  for(UINT j=0; j<numInputDimensions; j++){
804  file >> cluster[j];
805  }
806 
807  //Add the cluster to the cluster nodes map
808  nodeClusters[ nodeID ] = cluster;
809  }
810 
811  //Recompute the null rejection thresholds
812  recomputeNullRejectionThresholds();
813  }
814 
815  //Resize the prediction results to make sure it is set up for real-time prediction
816  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
817  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
818  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
819  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
820  }
821 
822  return true;
823 }
824 
825 bool DecisionTree::getModel( std::ostream &stream ) const{
826 
827  if( tree != NULL )
828  return tree->getModel( stream );
829  return false;
830 
831 }
832 
833 DecisionTreeNode* DecisionTree::deepCopyTree() const{
834 
835  if( tree == NULL ){
836  return NULL;
837  }
838 
839  return dynamic_cast< DecisionTreeNode* >( tree->deepCopy() );
840 }
841 
842 DecisionTreeNode* DecisionTree::deepCopyDecisionTreeNode() const{
843 
844  if( decisionTreeNode == NULL ){
845  return NULL;
846  }
847 
848  return dynamic_cast< DecisionTreeNode* >(decisionTreeNode->deepCopy());
849 }
850 
851 const DecisionTreeNode* DecisionTree::getTree() const{
852  return tree;
853 }
854 
855 bool DecisionTree::setDecisionTreeNode( const DecisionTreeNode &node ){
856 
857  if( decisionTreeNode != NULL ){
858  decisionTreeNode->clear();
859  delete decisionTreeNode;
860  decisionTreeNode = NULL;
861  }
862  this->decisionTreeNode = dynamic_cast< DecisionTreeNode* >(node.deepCopy());
863 
864  return true;
865 }
866 
867 DecisionTreeNode* DecisionTree::buildTree(ClassificationData &trainingData,DecisionTreeNode *parent,Vector< UINT > features,const Vector< UINT > &classLabels, UINT nodeID){
868 
869  const UINT M = trainingData.getNumSamples();
870  const UINT N = trainingData.getNumDimensions();
871 
872  //Update the nodeID
873  nodeID++;
874 
875  //Get the depth
876  UINT depth = 0;
877 
878  if( parent != NULL )
879  depth = parent->getDepth() + 1;
880 
881  //If there are no training data then return NULL
882  if( trainingData.getNumSamples() == 0 )
883  return NULL;
884 
885  //Create the new node
886  DecisionTreeNode *node = dynamic_cast< DecisionTreeNode* >( decisionTreeNode->createNewInstance() );
887 
888  if( node == NULL )
889  return NULL;
890 
891  //Get the class probabilities
892  VectorFloat classProbs = trainingData.getClassProbabilities( classLabels );
893 
894  //Set the parent
895  node->initNode( parent, depth, nodeID );
896 
897  //If all the training data belongs to the same class or there are no features left then create a leaf node and return
898  if( trainingData.getNumClasses() == 1 || features.size() == 0 || M < minNumSamplesPerNode || depth >= maxDepth ){
899 
900  //Set the node
901  node->setLeafNode( trainingData.getNumSamples(), classProbs );
902 
903  //Build the null cluster if null rejection is enabled
904  if( useNullRejection ){
905  nodeClusters[ nodeID ] = trainingData.getMean();
906  }
907 
908  std::string info = "Reached leaf node.";
909  if( trainingData.getNumClasses() == 1 ) info = "Reached pure leaf node.";
910  else if( features.size() == 0 ) info = "Reached leaf node, no remaining features.";
911  else if( M < minNumSamplesPerNode ) info = "Reached leaf node, hit min-samples-per-node limit.";
912  else if( depth >= maxDepth ) info = "Reached leaf node, max depth reached.";
913 
914  trainingLog << info << " Depth: " << depth << " NumSamples: " << trainingData.getNumSamples();
915 
916  trainingLog << " Class Probabilities: ";
917  for(UINT k=0; k<classProbs.getSize(); k++){
918  trainingLog << classProbs[k] << " ";
919  }
920  trainingLog << std::endl;
921 
922  return node;
923  }
924 
925  //Compute the best split point
926  UINT featureIndex = 0;
927  Float minError = 0;
928 
929  if( !node->computeBestSplit( trainingMode, numSplittingSteps, trainingData, features, classLabels, featureIndex, minError ) ){
930  delete node;
931  return NULL;
932  }
933 
934  trainingLog << "Depth: " << depth << " FeatureIndex: " << featureIndex << " MinError: " << minError;
935  trainingLog << " Class Probabilities: ";
936  for(UINT k=0; k<classProbs.getSize(); k++){
937  trainingLog << classProbs[k] << " ";
938  }
939  trainingLog << std::endl;
940 
941  //Remove the selected feature so we will not use it again
942  if( removeFeaturesAtEachSplit ){
943  for(UINT i=0; i<features.getSize(); i++){
944  if( features[i] == featureIndex ){
945  features.erase( features.begin()+i );
946  break;
947  }
948  }
949  }
950 
951  //Split the data into a left and right dataset
952  ClassificationData lhs(N);
953  ClassificationData rhs(N);
954 
955  //Reserve the memory to speed up the allocation of large datasets
956  lhs.reserve( M );
957  rhs.reserve( M );
958 
959  for(UINT i=0; i<M; i++){
960  if( node->predict_( trainingData[i].getSample() ) ){
961  rhs.addSample(trainingData[i].getClassLabel(), trainingData[i].getSample());
962  }else lhs.addSample(trainingData[i].getClassLabel(), trainingData[i].getSample());
963  }
964 
965  //Clear the parent dataset so we do not run out of memory with very large datasets (with very deep trees)
966  trainingData.clear();
967 
968  //Get the new node IDs for the children
969  UINT leftNodeID = ++nodeID;
970  UINT rightNodeID = ++nodeID;
971 
972  //Run the recursive tree building on the children
973  node->setLeftChild( buildTree( lhs, node, features, classLabels, leftNodeID ) );
974  node->setRightChild( buildTree( rhs, node, features, classLabels, rightNodeID ) );
975 
976  //Build the null clusters for the rhs and lhs nodes if null rejection is enabled
977  if( useNullRejection ){
978  nodeClusters[ leftNodeID ] = lhs.getMean();
979  nodeClusters[ rightNodeID ] = rhs.getMean();
980  }
981 
982  return node;
983 }
984 
985 Float DecisionTree::getNodeDistance( const VectorFloat &x, const UINT nodeID ){
986 
987  //Use the node ID to find the node cluster
988  std::map< UINT,VectorFloat >::iterator iter = nodeClusters.find( nodeID );
989 
990  //If we failed to find a match, return NAN
991  if( iter == nodeClusters.end() ) return NAN;
992 
993  //Compute the distance between the input and the node cluster
994  return getNodeDistance( x, iter->second );
995 }
996 
997 Float DecisionTree::getNodeDistance( const VectorFloat &x, const VectorFloat &y ){
998 
999  Float distance = 0;
1000  const UINT N = x.getSize();
1001 
1002  for(UINT i=0; i<N; i++){
1003  distance += MLBase::SQR( x[i] - y[i] );
1004  }
1005 
1006  //Return the squared Euclidean distance instead of actual Euclidean distance as this is faster and just as useful
1007  return distance;
1008 }
1009 
1010 Tree::TrainingMode DecisionTree::getTrainingMode() const{
1011  return trainingMode;
1012 }
1013 
1014 UINT DecisionTree::getNumSplittingSteps() const{
1015  return numSplittingSteps;
1016 }
1017 
1018 UINT DecisionTree::getMinNumSamplesPerNode() const{
1019  return minNumSamplesPerNode;
1020 }
1021 
1022 UINT DecisionTree::getMaxDepth() const{
1023  return maxDepth;
1024 }
1025 
1026 UINT DecisionTree::getPredictedNodeID() const{
1027 
1028  if( tree == NULL ){
1029  return 0;
1030  }
1031 
1032  return tree->getPredictedNodeID();
1033 }
1034 
1035 bool DecisionTree::getRemoveFeaturesAtEachSplit() const{
1036  return removeFeaturesAtEachSplit;
1037 }
1038 
1039 bool DecisionTree::setTrainingMode(const Tree::TrainingMode trainingMode){
1040  if( trainingMode >= Tree::BEST_ITERATIVE_SPILT && trainingMode < Tree::NUM_TRAINING_MODES ){
1041  this->trainingMode = trainingMode;
1042  return true;
1043  }
1044  warningLog << __GRT_LOG__ << " Unknown trainingMode: " << trainingMode << std::endl;
1045  return false;
1046 }
1047 
1048 bool DecisionTree::setNumSplittingSteps(const UINT numSplittingSteps){
1049  if( numSplittingSteps > 0 ){
1050  this->numSplittingSteps = numSplittingSteps;
1051  return true;
1052  }
1053  warningLog << __GRT_LOG__ << " The number of splitting steps must be greater than zero!" << std::endl;
1054  return false;
1055 }
1056 
1057 bool DecisionTree::setMinNumSamplesPerNode(const UINT minNumSamplesPerNode){
1058  if( minNumSamplesPerNode > 0 ){
1059  this->minNumSamplesPerNode = minNumSamplesPerNode;
1060  return true;
1061  }
1062  warningLog << __GRT_LOG__ << " The minimum number of samples per node must be greater than zero!" << std::endl;
1063  return false;
1064 }
1065 
1066 bool DecisionTree::setMaxDepth(const UINT maxDepth){
1067  if( maxDepth > 0 ){
1068  this->maxDepth = maxDepth;
1069  return true;
1070  }
1071  warningLog << __GRT_LOG__ << " The maximum depth must be greater than zero!" << std::endl;
1072  return false;
1073 }
1074 
1075 bool DecisionTree::setRemoveFeaturesAtEachSplit(const bool removeFeaturesAtEachSplit){
1076  this->removeFeaturesAtEachSplit = removeFeaturesAtEachSplit;
1077  return true;
1078 }
1079 
1080 bool DecisionTree::setRemoveFeaturesAtEachSpilt(const bool removeFeaturesAtEachSpilt){ //Legacy misspelled alias kept for backwards compatibility
1081  return setRemoveFeaturesAtEachSplit(removeFeaturesAtEachSpilt);
1082 }
1083 
1084 bool DecisionTree::loadLegacyModelFromFile_v1( std::fstream &file ){
1085 
1086  std::string word;
1087  UINT tempTrainingMode = 0;
1088 
1089  file >> word;
1090  if(word != "NumFeatures:"){
1091  errorLog << __GRT_LOG__ << " Could not find NumFeatures!" << std::endl;
1092  return false;
1093  }
1094  file >> numInputDimensions;
1095 
1096  file >> word;
1097  if(word != "NumClasses:"){
1098  errorLog << __GRT_LOG__ << " Could not find NumClasses!" << std::endl;
1099  return false;
1100  }
1101  file >> numClasses;
1102 
1103  file >> word;
1104  if(word != "UseScaling:"){
1105  errorLog << __GRT_LOG__ << " Could not find UseScaling!" << std::endl;
1106  return false;
1107  }
1108  file >> useScaling;
1109 
1110  file >> word;
1111  if(word != "UseNullRejection:"){
1112  errorLog << __GRT_LOG__ << " Could not find UseNullRejection!" << std::endl;
1113  return false;
1114  }
1115  file >> useNullRejection;
1116 
1117  //If we are using scaling then we need to load the ranges
1116 
1118  if( useScaling ){
1119  //Resize the ranges buffer
1120  ranges.resize( numInputDimensions );
1121 
1122  file >> word;
1123  if(word != "Ranges:"){
1124  errorLog << __GRT_LOG__ << " Could not find the Ranges!" << std::endl;
1125  return false;
1126  }
1127  for(UINT n=0; n<ranges.size(); n++){
1128  file >> ranges[n].minValue;
1129  file >> ranges[n].maxValue;
1130  }
1131  }
1132 
1133  file >> word;
1134  if(word != "NumSplittingSteps:"){
1135  errorLog << __GRT_LOG__ << " Could not find the NumSplittingSteps!" << std::endl;
1136  return false;
1137  }
1138  file >> numSplittingSteps;
1139 
1140  file >> word;
1141  if(word != "MinNumSamplesPerNode:"){
1142  errorLog << __GRT_LOG__ << " Could not find the MinNumSamplesPerNode!" << std::endl;
1143  return false;
1144  }
1145  file >> minNumSamplesPerNode;
1146 
1147  file >> word;
1148  if(word != "MaxDepth:"){
1149  errorLog << __GRT_LOG__ << " Could not find the MaxDepth!" << std::endl;
1150  return false;
1151  }
1152  file >> maxDepth;
1153 
1154  file >> word;
1155  if(word != "RemoveFeaturesAtEachSpilt:"){
1156  errorLog << __GRT_LOG__ << " Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
1157  return false;
1158  }
1159  file >> removeFeaturesAtEachSplit;
1160 
1161  file >> word;
1162  if(word != "TrainingMode:"){
1163  errorLog << __GRT_LOG__ << " Could not find the TrainingMode!" << std::endl;
1164  return false;
1165  }
1166  file >> tempTrainingMode;
1167  trainingMode = static_cast<Tree::TrainingMode>(tempTrainingMode);
1168 
1169  file >> word;
1170  if(word != "TreeBuilt:"){
1171  errorLog << __GRT_LOG__ << " Could not find the TreeBuilt!" << std::endl;
1172  return false;
1173  }
1174  file >> trained;
1175 
1176  if( trained ){
1177  file >> word;
1178  if(word != "Tree:"){
1179  errorLog << __GRT_LOG__ << " Could not find the Tree!" << std::endl;
1180  return false;
1181  }
1182 
1183  //Create a new DTree
1184  tree = new DecisionTreeNode;
1185 
1186  if( tree == NULL ){
1187  clear();
1188  errorLog << __GRT_LOG__ << " Failed to create new DecisionTreeNode!" << std::endl;
1189  return false;
1190  }
1191 
1192  tree->setParent( NULL );
1193  if( !tree->load( file ) ){
1194  clear();
1195  errorLog << __GRT_LOG__ << " Failed to load tree from file!" << std::endl;
1196  return false;
1197  }
1198  }
1199 
1200  return true;
1201 }
1202 
1203 bool DecisionTree::loadLegacyModelFromFile_v2( std::fstream &file ){
1204 
1205  std::string word;
1206  UINT tempTrainingMode = 0;
1207 
1208  //Load the base settings from the file
1209  if( !Classifier::loadBaseSettingsFromFile(file) ){
1210  errorLog << "load(string filename) - Failed to load base settings from file!" << std::endl;
1211  return false;
1212  }
1213 
1214  file >> word;
1215  if(word != "NumSplittingSteps:"){
1216  errorLog << "load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
1217  return false;
1218  }
1219  file >> numSplittingSteps;
1220 
1221  file >> word;
1222  if(word != "MinNumSamplesPerNode:"){
1223  errorLog << "load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
1224  return false;
1225  }
1226  file >> minNumSamplesPerNode;
1227 
1228  file >> word;
1229  if(word != "MaxDepth:"){
1230  errorLog << "load(string filename) - Could not find the MaxDepth!" << std::endl;
1231  return false;
1232  }
1233  file >> maxDepth;
1234 
1235  file >> word;
1236  if(word != "RemoveFeaturesAtEachSpilt:"){
1237  errorLog << "load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
1238  return false;
1239  }
1240  file >> removeFeaturesAtEachSplit;
1241 
1242  file >> word;
1243  if(word != "TrainingMode:"){
1244  errorLog << "load(string filename) - Could not find the TrainingMode!" << std::endl;
1245  return false;
1246  }
1247  file >> tempTrainingMode;
1248  trainingMode = static_cast<Tree::TrainingMode>(tempTrainingMode);
1249 
1250  file >> word;
1251  if(word != "TreeBuilt:"){
1252  errorLog << "load(string filename) - Could not find the TreeBuilt!" << std::endl;
1253  return false;
1254  }
1255  file >> trained;
1256 
1257  if( trained ){
1258  file >> word;
1259  if(word != "Tree:"){
1260  errorLog << "load(string filename) - Could not find the Tree!" << std::endl;
1261  return false;
1262  }
1263 
1264  //Create a new DTree
1265  tree = dynamic_cast<DecisionTreeNode*>( new DecisionTreeNode );
1266 
1267  if( tree == NULL ){
1268  clear();
1269  errorLog << "load(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
1270  return false;
1271  }
1272 
1273  tree->setParent( NULL );
1274  if( !tree->load( file ) ){
1275  clear();
1276  errorLog << "load(fstream &file) - Failed to load tree from file!" << std::endl;
1277  return false;
1278  }
1279 
1280  //Recompute the null rejection thresholds
1281  recomputeNullRejectionThresholds();
1282 
1283  //Resize the prediction results to make sure it is set up for real-time prediction
1284  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
1285  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
1286  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
1287  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
1288  }
1289 
1290  return true;
1291 }
1292 
1293 bool DecisionTree::loadLegacyModelFromFile_v3( std::fstream &file ){
1294 
1295  std::string word;
1296  UINT tempTrainingMode = 0;
1297 
1298  //Load the base settings from the file
1299  if( !Classifier::loadBaseSettingsFromFile(file) ){
1300  errorLog << "load(string filename) - Failed to load base settings from file!" << std::endl;
1301  return false;
1302  }
1303 
1304  file >> word;
1305  if(word != "NumSplittingSteps:"){
1306  errorLog << "load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
1307  return false;
1308  }
1309  file >> numSplittingSteps;
1310 
1311  file >> word;
1312  if(word != "MinNumSamplesPerNode:"){
1313  errorLog << "load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
1314  return false;
1315  }
1316  file >> minNumSamplesPerNode;
1317 
1318  file >> word;
1319  if(word != "MaxDepth:"){
1320  errorLog << "load(string filename) - Could not find the MaxDepth!" << std::endl;
1321  return false;
1322  }
1323  file >> maxDepth;
1324 
1325  file >> word;
1326  if(word != "RemoveFeaturesAtEachSpilt:"){
1327  errorLog << "load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
1328  return false;
1329  }
1330  file >> removeFeaturesAtEachSplit;
1331 
1332  file >> word;
1333  if(word != "TrainingMode:"){
1334  errorLog << "load(string filename) - Could not find the TrainingMode!" << std::endl;
1335  return false;
1336  }
1337  file >> tempTrainingMode;
1338  trainingMode = static_cast<Tree::TrainingMode>(tempTrainingMode);
1339 
1340  file >> word;
1341  if(word != "TreeBuilt:"){
1342  errorLog << "load(string filename) - Could not find the TreeBuilt!" << std::endl;
1343  return false;
1344  }
1345  file >> trained;
1346 
1347  if( trained ){
1348  file >> word;
1349  if(word != "Tree:"){
1350  errorLog << "load(string filename) - Could not find the Tree!" << std::endl;
1351  return false;
1352  }
1353 
1354  //Create a new DTree
1355  tree = new DecisionTreeNode;
1356 
1357  if( tree == NULL ){
1358  clear();
1359  errorLog << "load(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
1360  return false;
1361  }
1362 
1363  tree->setParent( NULL );
1364  if( !tree->load( file ) ){
1365  clear();
1366  errorLog << "load(fstream &file) - Failed to load tree from file!" << std::endl;
1367  return false;
1368  }
1369 
1370  //Load the null rejection data if needed
1371  if( useNullRejection ){
1372 
1373  UINT numNodes = 0;
1374  classClusterMean.resize( numClasses );
1375  classClusterStdDev.resize( numClasses );
1376 
1377  file >> word;
1378  if(word != "ClassClusterMean:"){
1379  errorLog << "load(string filename) - Could not find the ClassClusterMean header!" << std::endl;
1380  return false;
1381  }
1382  for(UINT k=0; k<numClasses; k++){
1383  file >> classClusterMean[k];
1384  }
1385 
1386  file >> word;
1387  if(word != "ClassClusterStdDev:"){
1388  errorLog << "load(string filename) - Could not find the ClassClusterStdDev header!" << std::endl;
1389  return false;
1390  }
1391  for(UINT k=0; k<numClasses; k++){
1392  file >> classClusterStdDev[k];
1393  }
1394 
1395  file >> word;
1396  if(word != "NumNodes:"){
1397  errorLog << "load(string filename) - Could not find the NumNodes header!" << std::endl;
1398  return false;
1399  }
1400  file >> numNodes;
1401 
1402  file >> word;
1403  if(word != "NodeClusters:"){
1404  errorLog << "load(string filename) - Could not find the NodeClusters header!" << std::endl;
1405  return false;
1406  }
1407 
1408  UINT nodeID = 0;
1409  VectorFloat cluster( numInputDimensions );
1410  for(UINT i=0; i<numNodes; i++){
1411 
1412  //load the nodeID
1413  file >> nodeID;
1414 
1415  for(UINT j=0; j<numInputDimensions; j++){
1416  file >> cluster[j];
1417  }
1418 
1419  //Add the cluster to the cluster nodes map
1420  nodeClusters[ nodeID ] = cluster;
1421  }
1422 
1423  //Recompute the null rejection thresholds
1424  recomputeNullRejectionThresholds();
1425  }
1426 
1427  //Resize the prediction results to make sure it is set up for real-time prediction
1428  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
1429  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
1430  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
1431  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
1432  }
1433 
1434  return true;
1435 }
1436 
1437 GRT_END_NAMESPACE
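For completeness, a short sketch of the null-rejection path implemented above. This assumes trainingData is a labelled ClassificationData set and the coefficient values are illustrative; enableNullRejection and setNullRejectionCoeff are inherited from the Classifier base class. Note that getNodeDistance returns squared Euclidean distances, so the per-class thresholds classClusterMean[k] + nullRejectionCoeff * classClusterStdDev[k] also operate on squared distances.

DecisionTree dt;
dt.enableNullRejection( true );  //samples too far from the predicted leaf's cluster map to the null class
dt.setNullRejectionCoeff( 3.0 ); //threshold_k = classClusterMean[k] + 3.0 * classClusterStdDev[k]
dt.train( trainingData );        //trainTree() also builds the per-node clusters used for the distances

//The coefficient can be tightened later and the thresholds recomputed without retraining:
dt.setNullRejectionCoeff( 2.0 );
dt.recomputeNullRejectionThresholds();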