21 #define GRT_DLL_EXPORTS
27 std::string DecisionTree::id =
"DecisionTree";
// NOTE(review): this span is a fragmentary extraction — the numeric prefixes are the
// original file's line numbers and interior lines have been elided. Code left byte-identical.
// Main constructor: stores the tree hyperparameters and deep-copies the node prototype.
// "Spilt" is the library's historical misspelling of "Split", kept for API compatibility.
33 DecisionTree::DecisionTree(
const DecisionTreeNode &decisionTreeNode,
const UINT minNumSamplesPerNode,
const UINT maxDepth,
const bool removeFeaturesAtEachSpilt,
const UINT trainingMode,
const UINT numSplittingSteps,
const bool useScaling){
// Null the owned pointer first so a later deep copy starts from a known state
36 this->decisionTreeNode = NULL;
37 this->minNumSamplesPerNode = minNumSamplesPerNode;
38 this->maxDepth = maxDepth;
39 this->removeFeaturesAtEachSpilt = removeFeaturesAtEachSpilt;
40 this->trainingMode = trainingMode;
41 this->numSplittingSteps = numSplittingSteps;
42 this->useScaling = useScaling;
// Decision trees support null rejection (see train_/predict_ below)
43 this->supportsNullRejection =
true;
45 classifierType = Classifier::classType;
46 classifierMode = STANDARD_CLASSIFIER_MODE;
// Take ownership of a deep copy of the caller's node prototype
52 this->decisionTreeNode = decisionTreeNode.
deepCopy();
// NOTE(review): the lines below (original lines 58-61) appear to belong to a second
// constructor (copy constructor) whose signature was elided from this view — confirm.
58 decisionTreeNode = NULL;
60 classifierType = Classifier::classType;
61 classifierMode = STANDARD_CLASSIFIER_MODE;
// Fragment of the destructor: releases the owned node prototype and nulls the
// pointer to avoid a dangling reference. Code left byte-identical.
73 if( decisionTreeNode != NULL ){
74 delete decisionTreeNode;
75 decisionTreeNode = NULL;
// Fragment of the copy-assignment operator (interior lines elided).
// Frees the current node prototype, then copies the hyperparameters and the
// per-node cluster map from rhs. Code left byte-identical.
90 if( this->decisionTreeNode != NULL ){
91 delete decisionTreeNode;
92 decisionTreeNode = NULL;
96 this->minNumSamplesPerNode = rhs.minNumSamplesPerNode;
97 this->maxDepth = rhs.maxDepth;
98 this->removeFeaturesAtEachSpilt = rhs.removeFeaturesAtEachSpilt;
99 this->trainingMode = rhs.trainingMode;
100 this->numSplittingSteps = rhs.numSplittingSteps;
101 this->nodeClusters = rhs.nodeClusters;
// Fragment of deepCopyFrom(const Classifier*) (interior lines elided).
// Guards against a null source, releases the current node prototype, then copies
// settings from the source (via the downcast pointer `ptr`). Code left byte-identical.
111 if( classifier == NULL )
return false;
126 if( this->decisionTreeNode != NULL ){
127 delete decisionTreeNode;
128 decisionTreeNode = NULL;
132 this->minNumSamplesPerNode = ptr->minNumSamplesPerNode;
133 this->maxDepth = ptr->maxDepth;
134 this->removeFeaturesAtEachSpilt = ptr->removeFeaturesAtEachSpilt;
135 this->trainingMode = ptr->trainingMode;
136 this->numSplittingSteps = ptr->numSplittingSteps;
137 this->nodeClusters = ptr->nodeClusters;
// Fragment of train_(ClassificationData&) (interior lines elided). Phases visible here:
// precondition checks, optional validation split, optional input scaling, tree build,
// optional null-rejection model (per-class distance statistics), validation evaluation.
// Code left byte-identical.
150 if( decisionTreeNode == NULL ){
151 Classifier::errorLog <<
"train_(ClassificationData &trainingData) - The decision tree node has not been set! You must set this first before training a model." << std::endl;
160 Classifier::errorLog <<
"train_(ClassificationData &trainingData) - Training data has zero samples!" << std::endl;
164 numInputDimensions = N;
// Optionally carve a validation set out of the training data; precision/recall
// vectors get one extra slot when null rejection is enabled (the null class).
171 if( useValidationSet ){
172 validationData = trainingData.
split( validationSetSize );
173 validationSetAccuracy = 0;
174 validationSetPrecision.
resize( useNullRejection ? K+1 : K, 0 );
175 validationSetRecall.
resize( useNullRejection ? K+1 : K, 0 );
// Scale all inputs into [0,1] when scaling is enabled
181 trainingData.
scale(0, 1);
// Keep an unmodified copy for building the null-rejection model afterwards
186 if( useNullRejection ){
187 trainingDataCopy = trainingData;
192 for(UINT i=0; i<N; i++){
// Recursively build the decision tree from the root
198 tree = buildTree( trainingData, NULL, features, classLabels, nodeID );
202 Classifier::errorLog <<
"train_(ClassificationData &trainingData) - Failed to build tree!" << std::endl;
// Null-rejection model: run every training sample through the tree and record
// its distance to the predicted leaf node's cluster.
210 if( useNullRejection ){
217 for(UINT i=0; i<M; i++){
219 if( !tree->
predict( trainingDataCopy[i].getSample(), classLikelihoods ) ){
220 Classifier::errorLog <<
"train_(ClassificationData &trainingData) - Failed to predict training sample while building null rejection model!" << std::endl;
226 distances[i] = getNodeDistance(trainingDataCopy[i].getSample(), tree->
getPredictedNodeID() );
228 classCounter[ predictions[i] ]++;
// Per-class mean and standard deviation of the leaf distances; std-dev starts at
// 0.01 (not 0) so a degenerate class still yields a usable threshold.
232 classClusterMean.clear();
233 classClusterStdDev.clear();
234 classClusterMean.
resize( numClasses, 0 );
235 classClusterStdDev.
resize( numClasses, 0.01 );
237 for(UINT i=0; i<M; i++){
238 classClusterMean[ predictions[i] ] += distances[ i ];
240 for(UINT k=0; k<numClasses; k++){
// grt_max(...,1) guards against division by zero for empty classes
241 classClusterMean[k] /= grt_max( classCounter[k], 1 );
245 for(UINT i=0; i<M; i++){
246 classClusterStdDev[ predictions[i] ] += MLBase::SQR( distances[ i ] - classClusterMean[ predictions[i] ] );
248 for(UINT k=0; k<numClasses; k++){
249 classClusterStdDev[k] = sqrt( classClusterStdDev[k] / grt_max( classCounter[k], 1 ) );
// Evaluate the trained model on the held-out validation set
256 if( useValidationSet ){
258 double numCorrect = 0;
261 VectorFloat validationSetPrecisionCounter( validationSetPrecision.size(), 0.0 );
262 VectorFloat validationSetRecallCounter( validationSetRecall.size(), 0.0 );
263 Classifier::trainingLog <<
"Testing model with validation set..." << std::endl;
264 for(UINT i=0; i<numTestSamples; i++){
265 testLabel = validationData[i].getClassLabel();
266 testSample = validationData[i].getSample();
268 if( predictedClassLabel == testLabel ){
277 validationSetAccuracy = (numCorrect / numTestSamples) * 100.0;
// Normalize precision/recall; a zero counter leaves the value unchanged (divide by 1)
278 for(UINT i=0; i<validationSetPrecision.
getSize(); i++){
279 validationSetPrecision[i] /= validationSetPrecisionCounter[i] > 0 ? validationSetPrecisionCounter[i] : 1;
281 for(UINT i=0; i<validationSetRecall.
getSize(); i++){
282 validationSetRecall[i] /= validationSetRecallCounter[i] > 0 ? validationSetRecallCounter[i] : 1;
285 Classifier::trainingLog <<
"Validation set accuracy: " << validationSetAccuracy << std::endl;
287 Classifier::trainingLog <<
"Validation set precision: ";
288 for(UINT i=0; i<validationSetPrecision.
getSize(); i++){
289 Classifier::trainingLog << validationSetPrecision[i] <<
" ";
291 Classifier::trainingLog << std::endl;
293 Classifier::trainingLog <<
"Validation set recall: ";
294 for(UINT i=0; i<validationSetRecall.
getSize(); i++){
295 Classifier::trainingLog << validationSetRecall[i] <<
" ";
297 Classifier::trainingLog << std::endl;
// Fragment of predict_(VectorFloat&) (interior lines elided). Flow: reset output,
// precondition checks, optional scaling, tree prediction, arg-max over the class
// likelihoods, then optional null rejection using the leaf-node distance.
// Code left byte-identical.
305 predictedClassLabel = 0;
310 Classifier::errorLog <<
"predict_(VectorFloat &inputVector) - Model Not Trained!" << std::endl;
315 Classifier::errorLog <<
"predict_(VectorFloat &inputVector) - DecisionTree pointer is null!" << std::endl;
319 if( inputVector.
getSize() != numInputDimensions ){
320 Classifier::errorLog <<
"predict_(VectorFloat &inputVector) - The size of the input Vector (" << inputVector.
getSize() <<
") does not match the num features in the model (" << numInputDimensions << std::endl;
// Scale the input into [0,1] with the ranges captured at training time
326 for(UINT n=0; n<numInputDimensions; n++){
327 inputVector[n] = grt_scale(inputVector[n], ranges[n].minValue, ranges[n].maxValue, 0.0, 1.0);
331 if( classLikelihoods.size() != numClasses ) classLikelihoods.
resize(numClasses,0);
332 if( classDistances.size() != numClasses ) classDistances.
resize(numClasses,0);
335 if( !tree->
predict( inputVector, classLikelihoods ) ){
336 Classifier::errorLog <<
"predict_(VectorFloat &inputVector) - Failed to predict!" << std::endl;
// Arg-max over the per-class likelihoods
344 for(UINT k=0; k<numClasses; k++){
345 if( classLikelihoods[k] > maxLikelihood ){
346 maxLikelihood = classLikelihoods[k];
// Null rejection: compare the input's distance to the winning leaf's cluster
// against the per-class threshold; reject to the null class when it exceeds it.
352 if( useNullRejection ){
357 if( grt_isnan(leafDistance) ){
358 Classifier::errorLog <<
"predict_(VectorFloat &inputVector) - Failed to match leaf node ID to compute node distance!" << std::endl;
363 classDistances.
setAll(0.0);
364 classDistances[ maxIndex ] = leafDistance;
367 if( leafDistance <= nullRejectionThresholds[ maxIndex ] ){
368 predictedClassLabel = classLabels[ maxIndex ];
369 }
else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
// Without null rejection the arg-max class wins directly
373 predictedClassLabel = classLabels[ maxIndex ];
// Fragment of clear(): drops the per-node cluster map built during training.
385 nodeClusters.clear();
// Fragment of recomputeNullRejectionThresholds() (interior lines elided).
// Threshold per class = cluster mean + coeff * cluster std-dev. Code left byte-identical.
402 Classifier::warningLog <<
"recomputeNullRejectionThresholds() - Failed to recompute null rejection thresholds, the model has not been trained!" << std::endl;
406 if( !useNullRejection ){
407 Classifier::warningLog <<
"recomputeNullRejectionThresholds() - Failed to recompute null rejection thresholds, null rejection is not enabled!" << std::endl;
411 nullRejectionThresholds.
resize( numClasses );
414 for(UINT k=0; k<numClasses; k++){
415 nullRejectionThresholds[k] = classClusterMean[k] + (classClusterStdDev[k]*nullRejectionCoeff);
// Fragment of save(std::fstream&) const (interior lines elided). Writes the V4.0
// text model format: header, base settings, node prototype, hyperparameters, the
// tree itself, and (when null rejection is on) the cluster statistics.
// The keys (including the misspelled "RemoveFeaturesAtEachSpilt") are part of the
// file format and must not be changed. Code left byte-identical.
425 Classifier::errorLog <<
"save(fstream &file) - The file is not open!" << std::endl;
430 file <<
"GRT_DECISION_TREE_MODEL_FILE_V4.0\n";
434 Classifier::errorLog <<
"save(fstream &file) - Failed to save classifier base settings to file!" << std::endl;
// Persist the node prototype's type then its settings; "NULL" marks no prototype
438 if( decisionTreeNode != NULL ){
439 file <<
"DecisionTreeNodeType: " << decisionTreeNode->
getNodeType() << std::endl;
440 if( !decisionTreeNode->
save( file ) ){
441 Classifier::errorLog <<
"save(fstream &file) - Failed to save decisionTreeNode settings to file!" << std::endl;
445 file <<
"DecisionTreeNodeType: " <<
"NULL" << std::endl;
448 file <<
"MinNumSamplesPerNode: " << minNumSamplesPerNode << std::endl;
449 file <<
"MaxDepth: " << maxDepth << std::endl;
450 file <<
"RemoveFeaturesAtEachSpilt: " << removeFeaturesAtEachSpilt << std::endl;
451 file <<
"TrainingMode: " << trainingMode << std::endl;
452 file <<
"NumSplittingSteps: " << numSplittingSteps << std::endl;
453 file <<
"TreeBuilt: " << (tree != NULL ? 1 : 0) << std::endl;
457 if( !tree->
save( file ) ){
458 Classifier::errorLog <<
"save(fstream &file) - Failed to save tree to file!" << std::endl;
// Null-rejection statistics: per-class means/std-devs and per-node clusters
463 if( useNullRejection ){
465 file <<
"ClassClusterMean:";
466 for(UINT k=0; k<numClasses; k++){
467 file <<
" " << classClusterMean[k];
471 file <<
"ClassClusterStdDev:";
472 for(UINT k=0; k<numClasses; k++){
473 file <<
" " << classClusterStdDev[k];
477 file <<
"NumNodes: " << nodeClusters.size() << std::endl;
478 file <<
"NodeClusters:\n";
480 std::map< UINT, VectorFloat >::const_iterator iter = nodeClusters.begin();
482 while( iter != nodeClusters.end() ){
488 for(UINT j=0; j<numInputDimensions; j++){
489 file <<
" " << iter->second[j];
// Fragment of load(std::fstream&) (interior lines elided). Mirrors save(): reads the
// header word, dispatches V1/V2/V3 files to the legacy loaders, then parses the V4.0
// format key-by-key. Each `if(word != "...")` check validates the next expected key.
// Code left byte-identical.
506 if( decisionTreeNode != NULL ){
507 delete decisionTreeNode;
508 decisionTreeNode = NULL;
511 if( !file.is_open() )
513 Classifier::errorLog <<
"load(std::fstream &file) - Could not open file to load model" << std::endl;
// Version dispatch: older headers are handled by the legacy loaders below
521 if( word ==
"GRT_DECISION_TREE_MODEL_FILE_V1.0" ){
525 if( word ==
"GRT_DECISION_TREE_MODEL_FILE_V2.0" ){
526 return loadLegacyModelFromFile_v2( file );
529 if( word ==
"GRT_DECISION_TREE_MODEL_FILE_V3.0" ){
530 return loadLegacyModelFromFile_v3( file );
534 if( word !=
"GRT_DECISION_TREE_MODEL_FILE_V4.0" ){
535 Classifier::errorLog <<
"load(string filename) - Could not find Model File Header" << std::endl;
541 Classifier::errorLog <<
"load(string filename) - Failed to load base settings from file!" << std::endl;
// Recreate the node prototype from its serialized type name ("NULL" = none saved)
546 if(word !=
"DecisionTreeNodeType:"){
547 Classifier::errorLog <<
"load(string filename) - Could not find the DecisionTreeNodeType!" << std::endl;
552 if( word !=
"NULL" ){
556 if( decisionTreeNode == NULL ){
557 Classifier::errorLog <<
"load(string filename) - Could not create new DecisionTreeNode from type: " << word << std::endl;
561 if( !decisionTreeNode->
load( file ) ){
562 Classifier::errorLog <<
"load(fstream &file) - Failed to load decisionTreeNode settings from file!" << std::endl;
566 Classifier::errorLog <<
"load(fstream &file) - Failed to load decisionTreeNode! DecisionTreeNodeType is NULL!" << std::endl;
// Hyperparameters, in the same order save() wrote them
571 if(word !=
"MinNumSamplesPerNode:"){
572 Classifier::errorLog <<
"load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
575 file >> minNumSamplesPerNode;
578 if(word !=
"MaxDepth:"){
579 Classifier::errorLog <<
"load(string filename) - Could not find the MaxDepth!" << std::endl;
585 if(word !=
"RemoveFeaturesAtEachSpilt:"){
586 Classifier::errorLog <<
"load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
589 file >> removeFeaturesAtEachSpilt;
592 if(word !=
"TrainingMode:"){
593 Classifier::errorLog <<
"load(string filename) - Could not find the TrainingMode!" << std::endl;
596 file >> trainingMode;
599 if(word !=
"NumSplittingSteps:"){
600 Classifier::errorLog <<
"load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
603 file >> numSplittingSteps;
606 if(word !=
"TreeBuilt:"){
607 Classifier::errorLog <<
"load(string filename) - Could not find the TreeBuilt!" << std::endl;
615 Classifier::errorLog <<
"load(string filename) - Could not find the Tree!" << std::endl;
624 Classifier::errorLog <<
"load(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
// Load the tree itself; the root has no parent
628 tree->setParent( NULL );
629 if( !tree->
load( file ) ){
631 Classifier::errorLog <<
"load(fstream &file) - Failed to load tree from file!" << std::endl;
// Null-rejection statistics (only present when useNullRejection was saved as true)
636 if( useNullRejection ){
639 classClusterMean.
resize( numClasses );
640 classClusterStdDev.
resize( numClasses );
643 if(word !=
"ClassClusterMean:"){
644 Classifier::errorLog <<
"load(string filename) - Could not find the ClassClusterMean header!" << std::endl;
647 for(UINT k=0; k<numClasses; k++){
648 file >> classClusterMean[k];
652 if(word !=
"ClassClusterStdDev:"){
653 Classifier::errorLog <<
"load(string filename) - Could not find the ClassClusterStdDev header!" << std::endl;
656 for(UINT k=0; k<numClasses; k++){
657 file >> classClusterStdDev[k];
661 if(word !=
"NumNodes:"){
662 Classifier::errorLog <<
"load(string filename) - Could not find the NumNodes header!" << std::endl;
668 if(word !=
"NodeClusters:"){
669 Classifier::errorLog <<
"load(string filename) - Could not find the NodeClusters header!" << std::endl;
675 for(UINT i=0; i<numNodes; i++){
680 for(UINT j=0; j<numInputDimensions; j++){
685 nodeClusters[ nodeID ] = cluster;
// Fragment (original lines 694-696): resets the prediction-result defaults —
// presumably the tail of load() or a reset helper; the enclosing signature was
// elided from this view, so confirm against the full source.
694 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
696 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
// Fragment of an accessor that hands the caller a deep copy of the node prototype
// (likely deepCopyDecisionTreeNode(); signature elided — confirm). Returns NULL-safe:
// the copy is only taken when a prototype exists.
721 if( decisionTreeNode == NULL ){
725 return decisionTreeNode->
deepCopy();
// Fragment of setDecisionTreeNode(const DecisionTreeNode&): releases the current
// prototype then stores a deep copy of the caller's node. Code left byte-identical.
734 if( decisionTreeNode != NULL ){
735 delete decisionTreeNode;
736 decisionTreeNode = NULL;
738 this->decisionTreeNode = node.
deepCopy();
// Fragment of the recursive buildTree(...) (interior lines elided). Flow: init the
// node, stop at a leaf when the data is pure / features are exhausted / too few
// samples / max depth; otherwise find the best split, optionally consume the chosen
// feature, partition the samples, and recurse left/right. Code left byte-identical.
768 VectorFloat classProbs = trainingData.getClassProbabilities( classLabels );
771 node->initNode( parent, depth, nodeID );
// Leaf conditions: pure node, no features left, too few samples, or depth limit
774 if( trainingData.
getNumClasses() == 1 || features.size() == 0 || M < minNumSamplesPerNode || depth >= maxDepth ){
// Record this leaf's cluster (mean of its samples) for null rejection
780 if( useNullRejection ){
781 nodeClusters[ nodeID ] = trainingData.
getMean();
784 std::string info =
"Reached leaf node.";
785 if( trainingData.
getNumClasses() == 1 ) info =
"Reached pure leaf node.";
786 else if( features.size() == 0 ) info =
"Reached leaf node, no remaining features.";
787 else if( M < minNumSamplesPerNode ) info =
"Reached leaf node, hit min-samples-per-node limit.";
788 else if( depth >= maxDepth ) info =
"Reached leaf node, max depth reached.";
790 Classifier::trainingLog << info <<
" Depth: " << depth <<
" NumSamples: " << trainingData.
getNumSamples();
792 Classifier::trainingLog <<
" Class Probabilities: ";
793 for(UINT k=0; k<classProbs.
getSize(); k++){
794 Classifier::trainingLog << classProbs[k] <<
" ";
796 Classifier::trainingLog << std::endl;
// Find the best split for this node (computeBestSpilt is the library's historical
// misspelling of "Split" — part of the DecisionTreeNode API).
802 UINT featureIndex = 0;
805 if( !node->
computeBestSpilt( trainingMode, numSplittingSteps, trainingData, features, classLabels, featureIndex, minError ) ){
810 Classifier::trainingLog <<
"Depth: " << depth <<
" FeatureIndex: " << featureIndex <<
" MinError: " << minError;
811 Classifier::trainingLog <<
" Class Probabilities: ";
812 for(
size_t k=0; k<classProbs.size(); k++){
813 Classifier::trainingLog << classProbs[k] <<
" ";
815 Classifier::trainingLog << std::endl;
// Optionally remove the chosen feature so deeper nodes cannot reuse it
818 if( removeFeaturesAtEachSpilt ){
819 for(
size_t i=0; i<features.size(); i++){
820 if( features[i] == featureIndex ){
821 features.erase( features.begin()+i );
// Partition the samples: node->predict(...) true -> right child, false -> left
835 for(UINT i=0; i<M; i++){
836 if( node->
predict( trainingData[i].getSample() ) ){
837 rhs.addSample(trainingData[i].getClassLabel(), trainingData[i].getSample());
838 }
else lhs.addSample(trainingData[i].getClassLabel(), trainingData[i].getSample());
// Free this node's copy of the data before recursing
842 trainingData.
clear();
// Pre-increment so the two children get distinct, fresh node IDs
845 UINT leftNodeID = ++nodeID;
846 UINT rightNodeID = ++nodeID;
849 node->setLeftChild( buildTree( lhs, node, features, classLabels, leftNodeID ) );
850 node->setRightChild( buildTree( rhs, node, features, classLabels, rightNodeID ) );
853 if( useNullRejection ){
854 nodeClusters[ leftNodeID ] = lhs.getMean();
855 nodeClusters[ rightNodeID ] = rhs.getMean();
861 Float DecisionTree::getNodeDistance(
const VectorFloat &x,
const UINT nodeID ){
864 std::map< UINT,VectorFloat >::iterator iter = nodeClusters.find( nodeID );
867 if( iter == nodeClusters.end() )
return NAN;
870 return getNodeDistance( x, iter->second );
// Fragment of getNodeDistance(const VectorFloat&, const VectorFloat&): accumulates
// the squared per-dimension differences between x and y. The declaration of
// `distance` and the return (presumably sqrt of the sum, i.e. Euclidean distance —
// confirm against full source) were elided from this view.
876 const size_t N = x.size();
878 for(
size_t i=0; i<N; i++){
879 distance += MLBase::SQR( x[i] - y[i] );
// Fragment of loadLegacyModelFromFile_v1 (interior lines elided). The V1 format
// stored the base classifier fields (NumFeatures/NumClasses/UseScaling/
// UseNullRejection/Ranges) inline rather than via loadBaseSettingsFromFile.
// Code left byte-identical.
891 if(word !=
"NumFeatures:"){
892 Classifier::errorLog <<
"load(string filename) - Could not find NumFeatures!" << std::endl;
895 file >> numInputDimensions;
898 if(word !=
"NumClasses:"){
899 Classifier::errorLog <<
"load(string filename) - Could not find NumClasses!" << std::endl;
905 if(word !=
"UseScaling:"){
906 Classifier::errorLog <<
"load(string filename) - Could not find UseScaling!" << std::endl;
912 if(word !=
"UseNullRejection:"){
913 Classifier::errorLog <<
"load(string filename) - Could not find UseNullRejection!" << std::endl;
916 file >> useNullRejection;
// Scaling ranges, one min/max pair per input dimension
921 ranges.
resize( numInputDimensions );
924 if(word !=
"Ranges:"){
925 Classifier::errorLog <<
"load(string filename) - Could not find the Ranges!" << std::endl;
928 for(UINT n=0; n<ranges.size(); n++){
929 file >> ranges[n].minValue;
930 file >> ranges[n].maxValue;
935 if(word !=
"NumSplittingSteps:"){
936 Classifier::errorLog <<
"load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
939 file >> numSplittingSteps;
942 if(word !=
"MinNumSamplesPerNode:"){
943 Classifier::errorLog <<
"load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
946 file >> minNumSamplesPerNode;
949 if(word !=
"MaxDepth:"){
950 Classifier::errorLog <<
"load(string filename) - Could not find the MaxDepth!" << std::endl;
956 if(word !=
"RemoveFeaturesAtEachSpilt:"){
957 Classifier::errorLog <<
"load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
960 file >> removeFeaturesAtEachSpilt;
963 if(word !=
"TrainingMode:"){
964 Classifier::errorLog <<
"load(string filename) - Could not find the TrainingMode!" << std::endl;
967 file >> trainingMode;
970 if(word !=
"TreeBuilt:"){
971 Classifier::errorLog <<
"load(string filename) - Could not find the TreeBuilt!" << std::endl;
979 Classifier::errorLog <<
"load(string filename) - Could not find the Tree!" << std::endl;
988 Classifier::errorLog <<
"load(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
992 tree->setParent( NULL );
993 if( !tree->
load( file ) ){
995 Classifier::errorLog <<
"load(fstream &file) - Failed to load tree from file!" << std::endl;
// Fragment of loadLegacyModelFromFile_v2 (interior lines elided). V2 uses
// loadBaseSettingsFromFile for the base fields, then reads the hyperparameters and
// tree in the same key order as V1. Code left byte-identical.
1003 bool DecisionTree::loadLegacyModelFromFile_v2( std::fstream &file ){
1009 Classifier::errorLog <<
"load(string filename) - Failed to load base settings from file!" << std::endl;
1014 if(word !=
"NumSplittingSteps:"){
1015 Classifier::errorLog <<
"load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
1018 file >> numSplittingSteps;
1021 if(word !=
"MinNumSamplesPerNode:"){
1022 Classifier::errorLog <<
"load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
1025 file >> minNumSamplesPerNode;
1028 if(word !=
"MaxDepth:"){
1029 Classifier::errorLog <<
"load(string filename) - Could not find the MaxDepth!" << std::endl;
1035 if(word !=
"RemoveFeaturesAtEachSpilt:"){
1036 Classifier::errorLog <<
"load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
1039 file >> removeFeaturesAtEachSpilt;
1042 if(word !=
"TrainingMode:"){
1043 Classifier::errorLog <<
"load(string filename) - Could not find the TrainingMode!" << std::endl;
1046 file >> trainingMode;
1049 if(word !=
"TreeBuilt:"){
1050 Classifier::errorLog <<
"load(string filename) - Could not find the TreeBuilt!" << std::endl;
1057 if(word !=
"Tree:"){
1058 Classifier::errorLog <<
"load(string filename) - Could not find the Tree!" << std::endl;
1067 Classifier::errorLog <<
"load(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
1071 tree->setParent( NULL );
1072 if( !tree->
load( file ) ){
1074 Classifier::errorLog <<
"load(fstream &file) - Failed to load tree from file!" << std::endl;
// Trailing lines (1083-1085) reset the prediction-result defaults after a
// successful load — the intervening lines were elided from this view.
1083 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
1085 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
// Fragment of loadLegacyModelFromFile_v3 (interior lines elided). V3 is V2 plus the
// null-rejection statistics (ClassClusterMean/StdDev and per-node clusters) that
// the V4 loader also reads. Code left byte-identical.
1091 bool DecisionTree::loadLegacyModelFromFile_v3( std::fstream &file ){
1097 Classifier::errorLog <<
"load(string filename) - Failed to load base settings from file!" << std::endl;
1102 if(word !=
"NumSplittingSteps:"){
1103 Classifier::errorLog <<
"load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
1106 file >> numSplittingSteps;
1109 if(word !=
"MinNumSamplesPerNode:"){
1110 Classifier::errorLog <<
"load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
1113 file >> minNumSamplesPerNode;
1116 if(word !=
"MaxDepth:"){
1117 Classifier::errorLog <<
"load(string filename) - Could not find the MaxDepth!" << std::endl;
1123 if(word !=
"RemoveFeaturesAtEachSpilt:"){
1124 Classifier::errorLog <<
"load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
1127 file >> removeFeaturesAtEachSpilt;
1130 if(word !=
"TrainingMode:"){
1131 Classifier::errorLog <<
"load(string filename) - Could not find the TrainingMode!" << std::endl;
1134 file >> trainingMode;
1137 if(word !=
"TreeBuilt:"){
1138 Classifier::errorLog <<
"load(string filename) - Could not find the TreeBuilt!" << std::endl;
1145 if(word !=
"Tree:"){
1146 Classifier::errorLog <<
"load(string filename) - Could not find the Tree!" << std::endl;
1155 Classifier::errorLog <<
"load(fstream &file) - Failed to create new DecisionTreeNode!" << std::endl;
1159 tree->setParent( NULL );
1160 if( !tree->
load( file ) ){
1162 Classifier::errorLog <<
"load(fstream &file) - Failed to load tree from file!" << std::endl;
// Null-rejection statistics, same layout as the V4 loader
1167 if( useNullRejection ){
1170 classClusterMean.
resize( numClasses );
1171 classClusterStdDev.
resize( numClasses );
1174 if(word !=
"ClassClusterMean:"){
1175 Classifier::errorLog <<
"load(string filename) - Could not find the ClassClusterMean header!" << std::endl;
1178 for(UINT k=0; k<numClasses; k++){
1179 file >> classClusterMean[k];
1183 if(word !=
"ClassClusterStdDev:"){
1184 Classifier::errorLog <<
"load(string filename) - Could not find the ClassClusterStdDev header!" << std::endl;
1187 for(UINT k=0; k<numClasses; k++){
1188 file >> classClusterStdDev[k];
1192 if(word !=
"NumNodes:"){
1193 Classifier::errorLog <<
"load(string filename) - Could not find the NumNodes header!" << std::endl;
1199 if(word !=
"NodeClusters:"){
1200 Classifier::errorLog <<
"load(string filename) - Could not find the NodeClusters header!" << std::endl;
1206 for(UINT i=0; i<numNodes; i++){
1211 for(UINT j=0; j<numInputDimensions; j++){
1216 nodeClusters[ nodeID ] = cluster;
// Reset prediction-result defaults after a successful load
1225 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
1227 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
bool saveBaseSettingsToFile(std::fstream &file) const
This class implements a basic Decision Tree classifier. Decision Trees are conceptually simple classi...
virtual bool load(std::fstream &file)
#define DEFAULT_NULL_LIKELIHOOD_VALUE
virtual bool computeBestSpilt(const UINT &trainingMode, const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError)
bool setAll(const T &value)
std::string getClassifierType() const
const DecisionTreeNode * getTree() const
DecisionTreeNode * deepCopyDecisionTreeNode() const
bool loadLegacyModelFromFile_v1(std::fstream &file)
virtual bool resize(const unsigned int size)
virtual bool train_(ClassificationData &trainingData)
std::string getNodeType() const
virtual bool getModel(std::ostream &stream) const
virtual bool predict_(VectorFloat &inputVector)
UINT getClassLabelIndexValue(UINT classLabel) const
Vector< UINT > getClassLabels() const
DecisionTree(const DecisionTreeNode &decisionTreeNode=DecisionTreeClusterNode(), const UINT minNumSamplesPerNode=5, const UINT maxDepth=10, const bool removeFeaturesAtEachSpilt=false, const UINT trainingMode=BEST_ITERATIVE_SPILT, const UINT numSplittingSteps=100, const bool useScaling=false)
virtual bool deepCopyFrom(const Classifier *classifier)
static std::string getId()
DecisionTreeNode * deepCopyTree() const
virtual bool save(std::fstream &file) const
virtual bool recomputeNullRejectionThresholds()
virtual bool save(std::fstream &file) const
UINT getNumSamples() const
virtual Node * deepCopyNode() const
bool copyBaseVariables(const Classifier *classifier)
bool loadBaseSettingsFromFile(std::fstream &file)
UINT getNumDimensions() const
UINT getNumClasses() const
Node * createNewInstance() const
UINT getPredictedNodeID() const
DecisionTree & operator=(const DecisionTree &rhs)
virtual bool predict(const VectorFloat &x, VectorFloat &classLikelihoods)
bool setLeafNode(const UINT nodeSize, const VectorFloat &classProbabilities)
DecisionTreeNode * deepCopy() const
Vector< MinMax > getRanges() const
ClassificationData split(const UINT splitPercentage, const bool useStratifiedSampling=false)
static Node * createInstanceFromString(std::string const &nodeType)
static unsigned int getMaxIndex(const VectorFloat &x)
bool setDecisionTreeNode(const DecisionTreeNode &node)
virtual ~DecisionTree(void)
bool scale(const Float minTarget, const Float maxTarget)
virtual bool getModel(std::ostream &stream) const
virtual bool load(std::fstream &file)
virtual bool predict(const VectorFloat &x)
VectorFloat getMean() const