21 #define GRT_DLL_EXPORTS 27 const std::string DecisionTree::id =
"DecisionTree";
37 this->minNumSamplesPerNode = minNumSamplesPerNode;
38 this->maxDepth = maxDepth;
39 this->removeFeaturesAtEachSplit = removeFeaturesAtEachSplit;
40 this->trainingMode = trainingMode;
41 this->numSplittingSteps = numSplittingSteps;
42 this->useScaling = useScaling;
43 this->supportsNullRejection =
true;
44 this->numTrainingIterationsToConverge = 20;
45 classifierMode = STANDARD_CLASSIFIER_MODE;
51 decisionTreeNode = NULL;
52 classifierMode = STANDARD_CLASSIFIER_MODE;
60 if( decisionTreeNode != NULL ){
61 delete decisionTreeNode;
62 decisionTreeNode = NULL;
77 if( this->decisionTreeNode != NULL ){
78 delete decisionTreeNode;
79 decisionTreeNode = NULL;
83 this->minNumSamplesPerNode = rhs.minNumSamplesPerNode;
84 this->maxDepth = rhs.maxDepth;
85 this->removeFeaturesAtEachSplit = rhs.removeFeaturesAtEachSplit;
86 this->trainingMode = rhs.trainingMode;
87 this->numSplittingSteps = rhs.numSplittingSteps;
88 this->nodeClusters = rhs.nodeClusters;
98 if( classifier == NULL )
return false;
113 if( this->decisionTreeNode != NULL ){
114 delete decisionTreeNode;
115 decisionTreeNode = NULL;
119 this->minNumSamplesPerNode = ptr->minNumSamplesPerNode;
120 this->maxDepth = ptr->maxDepth;
121 this->removeFeaturesAtEachSplit = ptr->removeFeaturesAtEachSplit;
122 this->trainingMode = ptr->trainingMode;
123 this->numSplittingSteps = ptr->numSplittingSteps;
124 this->nodeClusters = ptr->nodeClusters;
137 if( decisionTreeNode == NULL ){
138 errorLog << __GRT_LOG__ <<
" The decision tree node has not been set! You must set this first before training a model." << std::endl;
146 errorLog << __GRT_LOG__ <<
" Training data has zero samples!" << std::endl;
150 numInputDimensions = N;
151 numOutputDimensions = K;
158 if( useValidationSet ){
159 validationData = trainingData.
split( validationSetSize );
160 validationSetAccuracy = 0;
161 validationSetPrecision.
resize( useNullRejection ? K+1 : K, 0 );
162 validationSetRecall.
resize( useNullRejection ? K+1 : K, 0 );
170 trainingData.
scale(0, 1);
175 if( useNullRejection ){
176 trainingDataCopy = trainingData;
181 for(UINT i=0; i<N; i++){
185 numTrainingIterationsToConverge = 1;
186 trainingLog <<
"numTrainingIterationsToConverge " << numTrainingIterationsToConverge <<
" useValidationSet: " << useValidationSet << std::endl;
188 if( useValidationSet ){
192 Float bestValidationSetAccuracy = 0;
193 UINT bestTreeIndex = 0;
194 for(UINT i=0; i<numTrainingIterationsToConverge; i++){
196 trainingLog <<
"Training tree iteration: " << i+1 <<
"/" << numTrainingIterationsToConverge << std::endl;
198 if( !trainTree( trainingData, trainingDataCopy, validationData, features ) ){
199 errorLog << __GRT_LOG__ <<
" Failed to build tree!" << std::endl;
201 if( bestTree != NULL ){
209 if( bestTree == NULL ){
212 bestValidationSetAccuracy = validationSetAccuracy;
215 if( validationSetAccuracy > bestValidationSetAccuracy ){
218 bestValidationSetAccuracy = validationSetAccuracy;
231 trainingLog <<
"Best tree index: " << bestTreeIndex+1 <<
" validation set accuracy: " << bestValidationSetAccuracy << std::endl;
233 if( bestTree != tree ){
246 if( !trainTree( trainingData, trainingDataCopy, validationData, features ) ){
255 predictedClassLabel = 0;
257 classLikelihoods.
resize(numClasses);
258 classDistances.
resize(numClasses);
261 trainingSetAccuracy = 0;
262 validationSetAccuracy = 0;
265 bool scalingState = useScaling;
267 for(UINT i=0; i<M; i++){
268 if( !
predict_( trainingData[i].getSample() ) ){
270 errorLog << __GRT_LOG__ <<
" Failed to run prediction for training sample: " << i <<
"! Failed to fully train model!" << std::endl;
274 if( predictedClassLabel == trainingData[i].getClassLabel() ){
275 trainingSetAccuracy++;
279 if( useValidationSet ){
281 if( !
predict_( validationData[i].getSample() ) ){
283 errorLog << __GRT_LOG__ <<
" Failed to run prediction for validation sample: " << i <<
"! Failed to fully train model!" << std::endl;
287 if( predictedClassLabel == validationData[i].getClassLabel() ){
288 validationSetAccuracy++;
293 trainingSetAccuracy = trainingSetAccuracy / M * 100.0;
295 trainingLog <<
"Training set accuracy: " << trainingSetAccuracy << std::endl;
297 if( useValidationSet ){
298 validationSetAccuracy = validationSetAccuracy / validationData.
getNumSamples() * 100.0;
299 trainingLog <<
"Validation set accuracy: " << validationSetAccuracy << std::endl;
303 useScaling = scalingState;
316 tree = buildTree( trainingData, NULL, features, classLabels, nodeID );
320 errorLog << __GRT_LOG__ <<
" Failed to build tree!" << std::endl;
328 if( useNullRejection ){
336 for(UINT i=0; i<M; i++){
338 sample = trainingDataCopy[i].getSample();
339 if( !tree->
predict_( sample, classLikelihoods ) ){
340 errorLog << __GRT_LOG__ <<
" Failed to predict training sample while building null rejection model!" << std::endl;
348 classCounter[ predictions[i] ]++;
352 classClusterMean.clear();
353 classClusterStdDev.clear();
354 classClusterMean.
resize( numClasses, 0 );
355 classClusterStdDev.
resize( numClasses, 0.01 );
357 for(UINT i=0; i<M; i++){
358 classClusterMean[ predictions[i] ] += distances[ i ];
360 for(UINT k=0; k<numClasses; k++){
361 classClusterMean[k] /= grt_max( classCounter[k], 1 );
365 for(UINT i=0; i<M; i++){
366 classClusterStdDev[ predictions[i] ] += MLBase::SQR( distances[ i ] - classClusterMean[ predictions[i] ] );
368 for(UINT k=0; k<numClasses; k++){
369 classClusterStdDev[k] = sqrt( classClusterStdDev[k] / grt_max( classCounter[k], 1 ) );
376 if( useValidationSet ){
378 double numCorrect = 0;
381 VectorFloat validationSetPrecisionCounter( validationSetPrecision.size(), 0.0 );
382 VectorFloat validationSetRecallCounter( validationSetRecall.size(), 0.0 );
383 trainingLog <<
"Testing model with validation set..." << std::endl;
384 for(UINT i=0; i<numTestSamples; i++){
385 testLabel = validationData[i].getClassLabel();
386 testSample = validationData[i].getSample();
388 if( predictedClassLabel == testLabel ){
397 validationSetAccuracy = (numCorrect / numTestSamples) * 100.0;
398 for(UINT i=0; i<validationSetPrecision.
getSize(); i++){
399 validationSetPrecision[i] /= validationSetPrecisionCounter[i] > 0 ? validationSetPrecisionCounter[i] : 1;
401 for(UINT i=0; i<validationSetRecall.
getSize(); i++){
402 validationSetRecall[i] /= validationSetRecallCounter[i] > 0 ? validationSetRecallCounter[i] : 1;
405 Classifier::trainingLog <<
"Validation set accuracy: " << validationSetAccuracy << std::endl;
407 Classifier::trainingLog <<
"Validation set precision: ";
408 for(UINT i=0; i<validationSetPrecision.
getSize(); i++){
409 Classifier::trainingLog << validationSetPrecision[i] <<
" ";
411 Classifier::trainingLog << std::endl;
413 Classifier::trainingLog <<
"Validation set recall: ";
414 for(UINT i=0; i<validationSetRecall.
getSize(); i++){
415 Classifier::trainingLog << validationSetRecall[i] <<
" ";
417 Classifier::trainingLog << std::endl;
425 predictedClassLabel = 0;
430 errorLog << __GRT_LOG__ <<
" Model Not Trained!" << std::endl;
435 errorLog << __GRT_LOG__ <<
" DecisionTree pointer is null!" << std::endl;
439 if( inputVector.
getSize() != numInputDimensions ){
440 errorLog << __GRT_LOG__ <<
" The size of the input Vector (" << inputVector.
getSize() <<
") does not match the num features in the model (" << numInputDimensions << std::endl;
446 for(UINT n=0; n<numInputDimensions; n++){
447 inputVector[n] = grt_scale(inputVector[n], ranges[n].minValue, ranges[n].maxValue, 0.0, 1.0);
451 if( classLikelihoods.size() != numClasses ) classLikelihoods.
resize(numClasses,0);
452 if( classDistances.size() != numClasses ) classDistances.
resize(numClasses,0);
455 if( !tree->
predict_( inputVector, classLikelihoods ) ){
456 errorLog << __GRT_LOG__ <<
" Failed to predict!" << std::endl;
464 for(UINT k=0; k<numClasses; k++){
465 if( classLikelihoods[k] > maxLikelihood ){
466 maxLikelihood = classLikelihoods[k];
472 if( useNullRejection ){
477 if( grt_isnan(leafDistance) ){
478 errorLog << __GRT_LOG__ <<
" Failed to match leaf node ID to compute node distance!" << std::endl;
483 classDistances.
setAll(0.0);
484 classDistances[ maxIndex ] = leafDistance;
487 if( leafDistance <= nullRejectionThresholds[ maxIndex ] ){
488 predictedClassLabel = classLabels[ maxIndex ];
489 }
else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
493 predictedClassLabel = classLabels[ maxIndex ];
505 nodeClusters.clear();
522 Classifier::warningLog << __GRT_LOG__ <<
" Failed to recompute null rejection thresholds, the model has not been trained!" << std::endl;
526 if( !useNullRejection ){
527 Classifier::warningLog << __GRT_LOG__ <<
" Failed to recompute null rejection thresholds, null rejection is not enabled!" << std::endl;
531 nullRejectionThresholds.
resize( numClasses );
534 for(UINT k=0; k<numClasses; k++){
535 nullRejectionThresholds[k] = classClusterMean[k] + (classClusterStdDev[k]*nullRejectionCoeff);
545 errorLog << __GRT_LOG__ <<
" The file is not open!" << std::endl;
550 file <<
"GRT_DECISION_TREE_MODEL_FILE_V4.0\n";
554 errorLog << __GRT_LOG__ <<
" Failed to save classifier base settings to file!" << std::endl;
558 if( decisionTreeNode != NULL ){
559 file <<
"DecisionTreeNodeType: " << decisionTreeNode->
getNodeType() << std::endl;
560 if( !decisionTreeNode->
save( file ) ){
561 errorLog << __GRT_LOG__ <<
" Failed to save decisionTreeNode settings to file!" << std::endl;
565 file <<
"DecisionTreeNodeType: " <<
"NULL" << std::endl;
568 file <<
"MinNumSamplesPerNode: " << minNumSamplesPerNode << std::endl;
569 file <<
"MaxDepth: " << maxDepth << std::endl;
570 file <<
"RemoveFeaturesAtEachSpilt: " << removeFeaturesAtEachSplit << std::endl;
571 file <<
"TrainingMode: " << trainingMode << std::endl;
572 file <<
"NumSplittingSteps: " << numSplittingSteps << std::endl;
573 file <<
"TreeBuilt: " << (tree != NULL ? 1 : 0) << std::endl;
577 if( !tree->
save( file ) ){
578 errorLog << __GRT_LOG__ <<
" Failed to save tree to file!" << std::endl;
583 if( useNullRejection ){
585 file <<
"ClassClusterMean:";
586 for(UINT k=0; k<numClasses; k++){
587 file <<
" " << classClusterMean[k];
591 file <<
"ClassClusterStdDev:";
592 for(UINT k=0; k<numClasses; k++){
593 file <<
" " << classClusterStdDev[k];
597 file <<
"NumNodes: " << nodeClusters.size() << std::endl;
598 file <<
"NodeClusters:\n";
600 std::map< UINT, VectorFloat >::const_iterator iter = nodeClusters.begin();
602 while( iter != nodeClusters.end() ){
608 for(UINT j=0; j<numInputDimensions; j++){
609 file <<
" " << iter->second[j];
626 UINT tempTrainingMode = 0;
628 if( decisionTreeNode != NULL ){
629 delete decisionTreeNode;
630 decisionTreeNode = NULL;
633 if( !file.is_open() )
635 errorLog << __GRT_LOG__ <<
" Could not open file to load model" << std::endl;
643 if( word ==
"GRT_DECISION_TREE_MODEL_FILE_V1.0" ){
647 if( word ==
"GRT_DECISION_TREE_MODEL_FILE_V2.0" ){
648 return loadLegacyModelFromFile_v2( file );
651 if( word ==
"GRT_DECISION_TREE_MODEL_FILE_V3.0" ){
652 return loadLegacyModelFromFile_v3( file );
656 if( word !=
"GRT_DECISION_TREE_MODEL_FILE_V4.0" ){
657 errorLog << __GRT_LOG__ <<
" Could not find Model File Header" << std::endl;
663 errorLog << __GRT_LOG__ <<
" Failed to load base settings from file!" << std::endl;
668 if(word !=
"DecisionTreeNodeType:"){
669 errorLog << __GRT_LOG__ <<
" Could not find the DecisionTreeNodeType!" << std::endl;
674 if( word !=
"NULL" ){
678 if( decisionTreeNode == NULL ){
679 errorLog << __GRT_LOG__ <<
" Could not create new DecisionTreeNode from type: " << word << std::endl;
683 if( !decisionTreeNode->
load( file ) ){
684 errorLog << __GRT_LOG__ <<
" Failed to load decisionTreeNode settings from file!" << std::endl;
688 errorLog << __GRT_LOG__ <<
" Failed to load decisionTreeNode! DecisionTreeNodeType is NULL!" << std::endl;
693 if(word !=
"MinNumSamplesPerNode:"){
694 errorLog << __GRT_LOG__ <<
" Could not find the MinNumSamplesPerNode!" << std::endl;
697 file >> minNumSamplesPerNode;
700 if(word !=
"MaxDepth:"){
701 errorLog << __GRT_LOG__ <<
" Could not find the MaxDepth!" << std::endl;
707 if(word !=
"RemoveFeaturesAtEachSpilt:"){
708 errorLog << __GRT_LOG__ <<
" Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
711 file >> removeFeaturesAtEachSplit;
714 if(word !=
"TrainingMode:"){
715 errorLog << __GRT_LOG__ <<
" Could not find the TrainingMode!" << std::endl;
718 file >> tempTrainingMode;
719 trainingMode =
static_cast<Tree::TrainingMode
>(tempTrainingMode);
722 if(word !=
"NumSplittingSteps:"){
723 errorLog << __GRT_LOG__ <<
" Could not find the NumSplittingSteps!" << std::endl;
726 file >> numSplittingSteps;
729 if(word !=
"TreeBuilt:"){
730 errorLog << __GRT_LOG__ <<
" Could not find the TreeBuilt!" << std::endl;
738 errorLog << __GRT_LOG__ <<
" Could not find the Tree!" << std::endl;
747 errorLog << __GRT_LOG__ <<
" Failed to create new DecisionTreeNode!" << std::endl;
751 tree->setParent( NULL );
752 if( !tree->
load( file ) ){
754 errorLog << __GRT_LOG__ <<
" Failed to load tree from file!" << std::endl;
759 if( useNullRejection ){
762 classClusterMean.
resize( numClasses );
763 classClusterStdDev.
resize( numClasses );
766 if(word !=
"ClassClusterMean:"){
767 errorLog << __GRT_LOG__ <<
" Could not find the ClassClusterMean header!" << std::endl;
770 for(UINT k=0; k<numClasses; k++){
771 file >> classClusterMean[k];
775 if(word !=
"ClassClusterStdDev:"){
776 errorLog << __GRT_LOG__ <<
" Could not find the ClassClusterStdDev header!" << std::endl;
779 for(UINT k=0; k<numClasses; k++){
780 file >> classClusterStdDev[k];
784 if(word !=
"NumNodes:"){
785 errorLog << __GRT_LOG__ <<
" Could not find the NumNodes header!" << std::endl;
791 if(word !=
"NodeClusters:"){
792 errorLog << __GRT_LOG__ <<
" Could not find the NodeClusters header!" << std::endl;
798 for(UINT i=0; i<numNodes; i++){
803 for(UINT j=0; j<numInputDimensions; j++){
808 nodeClusters[ nodeID ] = cluster;
817 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
819 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
844 if( decisionTreeNode == NULL ){
857 if( decisionTreeNode != NULL ){
858 decisionTreeNode->
clear();
859 delete decisionTreeNode;
860 decisionTreeNode = NULL;
892 VectorFloat classProbs = trainingData.getClassProbabilities( classLabels );
895 node->initNode( parent, depth, nodeID );
898 if( trainingData.
getNumClasses() == 1 || features.size() == 0 || M < minNumSamplesPerNode || depth >= maxDepth ){
904 if( useNullRejection ){
905 nodeClusters[ nodeID ] = trainingData.
getMean();
908 std::string info =
"Reached leaf node.";
909 if( trainingData.
getNumClasses() == 1 ) info =
"Reached pure leaf node.";
910 else if( features.size() == 0 ) info =
"Reached leaf node, no remaining features.";
911 else if( M < minNumSamplesPerNode ) info =
"Reached leaf node, hit min-samples-per-node limit.";
912 else if( depth >= maxDepth ) info =
"Reached leaf node, max depth reached.";
914 trainingLog << info <<
" Depth: " << depth <<
" NumSamples: " << trainingData.
getNumSamples();
916 trainingLog <<
" Class Probabilities: ";
917 for(UINT k=0; k<classProbs.
getSize(); k++){
918 trainingLog << classProbs[k] <<
" ";
920 trainingLog << std::endl;
926 UINT featureIndex = 0;
929 if( !node->
computeBestSplit( trainingMode, numSplittingSteps, trainingData, features, classLabels, featureIndex, minError ) ){
934 trainingLog <<
"Depth: " << depth <<
" FeatureIndex: " << featureIndex <<
" MinError: " << minError;
935 trainingLog <<
" Class Probabilities: ";
936 for(UINT k=0; k<classProbs.
getSize(); k++){
937 trainingLog << classProbs[k] <<
" ";
939 trainingLog << std::endl;
942 if( removeFeaturesAtEachSplit ){
943 for(UINT i=0; i<features.
getSize(); i++){
944 if( features[i] == featureIndex ){
945 features.erase( features.begin()+i );
959 for(UINT i=0; i<M; i++){
960 if( node->
predict_( trainingData[i].getSample() ) ){
961 rhs.
addSample(trainingData[i].getClassLabel(), trainingData[i].getSample());
962 }
else lhs.
addSample(trainingData[i].getClassLabel(), trainingData[i].getSample());
966 trainingData.
clear();
969 UINT leftNodeID = ++nodeID;
970 UINT rightNodeID = ++nodeID;
973 node->setLeftChild( buildTree( lhs, node, features, classLabels, leftNodeID ) );
974 node->setRightChild( buildTree( rhs, node, features, classLabels, rightNodeID ) );
977 if( useNullRejection ){
978 nodeClusters[ leftNodeID ] = lhs.
getMean();
979 nodeClusters[ rightNodeID ] = rhs.
getMean();
985 Float DecisionTree::getNodeDistance(
const VectorFloat &x,
const UINT nodeID ){
988 std::map< UINT,VectorFloat >::iterator iter = nodeClusters.find( nodeID );
991 if( iter == nodeClusters.end() )
return NAN;
994 return getNodeDistance( x, iter->second );
1002 for(UINT i=0; i<N; i++){
1003 distance += MLBase::SQR( x[i] - y[i] );
1011 return trainingMode;
1015 return numSplittingSteps;
1019 return minNumSamplesPerNode;
1036 return removeFeaturesAtEachSplit;
1040 if( trainingMode >= Tree::BEST_ITERATIVE_SPILT && trainingMode < Tree::NUM_TRAINING_MODES ){
1041 this->trainingMode = trainingMode;
1044 warningLog << __GRT_LOG__ <<
" Unknown trainingMode: " << trainingMode << std::endl;
1049 if( numSplittingSteps > 0 ){
1050 this->numSplittingSteps = numSplittingSteps;
1053 warningLog << __GRT_LOG__ <<
" The number of splitting steps must be greater than zero!" << std::endl;
1058 if( minNumSamplesPerNode > 0 ){
1059 this->minNumSamplesPerNode = minNumSamplesPerNode;
1062 warningLog << __GRT_LOG__ <<
" The minimum number of samples per node must be greater than zero!" << std::endl;
1068 this->maxDepth = maxDepth;
1071 warningLog << __GRT_LOG__ <<
" The maximum depth must be greater than zero!" << std::endl;
1076 this->removeFeaturesAtEachSplit = removeFeaturesAtEachSplit;
1080 bool DecisionTree::setRemoveFeaturesAtEachSpilt(
const bool removeFeaturesAtEachSpilt){
1087 UINT tempTrainingMode = 0;
1090 if(word !=
"NumFeatures:"){
1091 errorLog << __GRT_LOG__ <<
" Could not find NumFeatures!" << std::endl;
1094 file >> numInputDimensions;
1097 if(word !=
"NumClasses:"){
1098 errorLog << __GRT_LOG__ <<
" Could not find NumClasses!" << std::endl;
1104 if(word !=
"UseScaling:"){
1105 errorLog << __GRT_LOG__ <<
" Could not find UseScaling!" << std::endl;
1111 if(word !=
"UseNullRejection:"){
1112 errorLog << __GRT_LOG__ <<
" Could not find UseNullRejection!" << std::endl;
1115 file >> useNullRejection;
1120 ranges.
resize( numInputDimensions );
1123 if(word !=
"Ranges:"){
1124 errorLog << __GRT_LOG__ <<
" Could not find the Ranges!" << std::endl;
1127 for(UINT n=0; n<ranges.size(); n++){
1128 file >> ranges[n].minValue;
1129 file >> ranges[n].maxValue;
1134 if(word !=
"NumSplittingSteps:"){
1135 errorLog << __GRT_LOG__ <<
" Could not find the NumSplittingSteps!" << std::endl;
1138 file >> numSplittingSteps;
1141 if(word !=
"MinNumSamplesPerNode:"){
1142 errorLog << __GRT_LOG__ <<
" Could not find the MinNumSamplesPerNode!" << std::endl;
1145 file >> minNumSamplesPerNode;
1148 if(word !=
"MaxDepth:"){
1149 errorLog << __GRT_LOG__ <<
" Could not find the MaxDepth!" << std::endl;
1155 if(word !=
"RemoveFeaturesAtEachSpilt:"){
1156 errorLog << __GRT_LOG__ <<
" Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
1159 file >> removeFeaturesAtEachSplit;
1162 if(word !=
"TrainingMode:"){
1163 errorLog << __GRT_LOG__ <<
" Could not find the TrainingMode!" << std::endl;
1166 file >> tempTrainingMode;
1167 trainingMode =
static_cast<Tree::TrainingMode
>(tempTrainingMode);
1170 if(word !=
"TreeBuilt:"){
1171 errorLog << __GRT_LOG__ <<
" Could not find the TreeBuilt!" << std::endl;
1178 if(word !=
"Tree:"){
1179 errorLog << __GRT_LOG__ <<
" Could not find the Tree!" << std::endl;
1188 errorLog << __GRT_LOG__ <<
" Failed to create new DecisionTreeNode!" << std::endl;
1192 tree->setParent( NULL );
1193 if( !tree->
load( file ) ){
1195 errorLog << __GRT_LOG__ <<
" Failed to load tree from file!" << std::endl;
1203 bool DecisionTree::loadLegacyModelFromFile_v2( std::fstream &file ){
1206 UINT tempTrainingMode = 0;
1210 errorLog <<
"load(string filename) - Failed to load base settings from file!" << std::endl;
1215 if(word !=
"NumSplittingSteps:"){
1216 errorLog <<
"load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
1219 file >> numSplittingSteps;
1222 if(word !=
"MinNumSamplesPerNode:"){
1223 errorLog <<
"load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
1226 file >> minNumSamplesPerNode;
1229 if(word !=
"MaxDepth:"){
1230 errorLog <<
"load(string filename) - Could not find the MaxDepth!" << std::endl;
1236 if(word !=
"RemoveFeaturesAtEachSpilt:"){
1237 errorLog <<
"load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
1240 file >> removeFeaturesAtEachSplit;
1243 if(word !=
"TrainingMode:"){
1244 errorLog <<
"load(string filename) - Could not find the TrainingMode!" << std::endl;
1247 file >> tempTrainingMode;
1248 trainingMode =
static_cast<Tree::TrainingMode
>(tempTrainingMode);
1251 if(word !=
"TreeBuilt:"){
1252 errorLog <<
"load(string filename) - Could not find the TreeBuilt!" << std::endl;
1259 if(word !=
"Tree:"){
1260 errorLog <<
"load(string filename) - Could not find the Tree!" << std::endl;
1269 errorLog <<
"load(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
1273 tree->setParent( NULL );
1274 if( !tree->
load( file ) ){
1276 errorLog <<
"load(fstream &file) - Failed to load tree from file!" << std::endl;
1285 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
1287 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
1293 bool DecisionTree::loadLegacyModelFromFile_v3( std::fstream &file ){
1296 UINT tempTrainingMode = 0;
1300 errorLog <<
"load(string filename) - Failed to load base settings from file!" << std::endl;
1305 if(word !=
"NumSplittingSteps:"){
1306 errorLog <<
"load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
1309 file >> numSplittingSteps;
1312 if(word !=
"MinNumSamplesPerNode:"){
1313 errorLog <<
"load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
1316 file >> minNumSamplesPerNode;
1319 if(word !=
"MaxDepth:"){
1320 errorLog <<
"load(string filename) - Could not find the MaxDepth!" << std::endl;
1326 if(word !=
"RemoveFeaturesAtEachSpilt:"){
1327 errorLog <<
"load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
1330 file >> removeFeaturesAtEachSplit;
1333 if(word !=
"TrainingMode:"){
1334 errorLog <<
"load(string filename) - Could not find the TrainingMode!" << std::endl;
1337 file >> tempTrainingMode;
1338 trainingMode =
static_cast<Tree::TrainingMode
>(tempTrainingMode);
1341 if(word !=
"TreeBuilt:"){
1342 errorLog <<
"load(string filename) - Could not find the TreeBuilt!" << std::endl;
1349 if(word !=
"Tree:"){
1350 errorLog <<
"load(string filename) - Could not find the Tree!" << std::endl;
1359 errorLog <<
"load(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
1363 tree->setParent( NULL );
1364 if( !tree->
load( file ) ){
1366 errorLog <<
"load(fstream &file) - Failed to load tree from file!" << std::endl;
1371 if( useNullRejection ){
1374 classClusterMean.
resize( numClasses );
1375 classClusterStdDev.
resize( numClasses );
1378 if(word !=
"ClassClusterMean:"){
1379 errorLog <<
"load(string filename) - Could not find the ClassClusterMean header!" << std::endl;
1382 for(UINT k=0; k<numClasses; k++){
1383 file >> classClusterMean[k];
1387 if(word !=
"ClassClusterStdDev:"){
1388 errorLog <<
"load(string filename) - Could not find the ClassClusterStdDev header!" << std::endl;
1391 for(UINT k=0; k<numClasses; k++){
1392 file >> classClusterStdDev[k];
1396 if(word !=
"NumNodes:"){
1397 errorLog <<
"load(string filename) - Could not find the NumNodes header!" << std::endl;
1403 if(word !=
"NodeClusters:"){
1404 errorLog <<
"load(string filename) - Could not find the NodeClusters header!" << std::endl;
1410 for(UINT i=0; i<numNodes; i++){
1415 for(UINT j=0; j<numInputDimensions; j++){
1420 nodeClusters[ nodeID ] = cluster;
1429 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
1431 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
bool saveBaseSettingsToFile(std::fstream &file) const
virtual bool computeBestSplit(const UINT &trainingMode, const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError)
bool setMinNumSamplesPerNode(const UINT minNumSamplesPerNode)
#define DEFAULT_NULL_LIKELIHOOD_VALUE
virtual bool clear() override
bool setAll(const T &value)
bool addSample(const UINT classLabel, const VectorFloat &sample)
virtual bool save(std::fstream &file) const override
std::string getClassifierType() const
const DecisionTreeNode * getTree() const
DecisionTreeNode * deepCopyDecisionTreeNode() const
bool loadLegacyModelFromFile_v1(std::fstream &file)
virtual bool resize(const unsigned int size)
UINT getNumSplittingSteps() const
std::string getNodeType() const
bool setRemoveFeaturesAtEachSplit(const bool removeFeaturesAtEachSplit)
bool setTrainingMode(const Tree::TrainingMode trainingMode)
virtual bool predict_(VectorFloat &x, VectorFloat &classLikelihoods) override
Vector< UINT > getClassLabels() const
UINT getPredictedNodeID() const
virtual bool recomputeNullRejectionThresholds() override
virtual Node * deepCopy() const override
static std::string getId()
virtual bool getModel(std::ostream &stream) const override
DecisionTreeNode * deepCopyTree() const
UINT getNumSamples() const
bool getRemoveFeaturesAtEachSplit() const
bool setMaxDepth(const UINT maxDepth)
bool setNumSplittingSteps(const UINT numSplittingSteps)
bool copyBaseVariables(const Classifier *classifier)
bool loadBaseSettingsFromFile(std::fstream &file)
virtual bool save(std::fstream &file) const override
virtual bool getModel(std::ostream &stream) const override
UINT getNumDimensions() const
UINT getNumClasses() const
DecisionTree(const DecisionTreeNode &decisionTreeNode=DecisionTreeClusterNode(), const UINT minNumSamplesPerNode=5, const UINT maxDepth=10, const bool removeFeaturesAtEachSplit=false, const Tree::TrainingMode trainingMode=Tree::TrainingMode::BEST_ITERATIVE_SPILT, const UINT numSplittingSteps=100, const bool useScaling=false)
Node * createNewInstance() const
UINT getPredictedNodeID() const
UINT getClassLabelIndexValue(const UINT classLabel) const
DecisionTree & operator=(const DecisionTree &rhs)
bool setLeafNode(const UINT nodeSize, const VectorFloat &classProbabilities)
virtual bool load(std::fstream &file) override
virtual bool deepCopyFrom(const Classifier *classifier) override
virtual bool clear() override
Vector< MinMax > getRanges() const
ClassificationData split(const UINT splitPercentage, const bool useStratifiedSampling=false)
Tree::TrainingMode getTrainingMode() const
bool reserve(const UINT M)
virtual bool train_(ClassificationData &trainingData) override
static Node * createInstanceFromString(std::string const &nodeType)
static unsigned int getMaxIndex(const VectorFloat &x)
virtual bool load(std::fstream &file) override
bool setDecisionTreeNode(const DecisionTreeNode &node)
virtual ~DecisionTree(void)
bool scale(const Float minTarget, const Float maxTarget)
virtual bool predict_(VectorFloat &inputVector) override
UINT getMinNumSamplesPerNode() const
This is the main base class that all GRT Classification algorithms should inherit from...
VectorFloat getMean() const