2 #define GRT_DLL_EXPORTS
11 nodeType =
"DecisionTreeNode";
27 classLikelihoods = classProbabilities;
28 predictedNodeID = nodeID;
32 if( leftChild == NULL && rightChild == NULL )
37 if( rightChild->predict( x, classLikelihoods ) ){
38 predictedNodeID = rightChild->getPredictedNodeID();
41 warningLog <<
"predict(const VectorFloat &x,VectorFloat &classLikelihoods) - Right child failed prediction!" << std::endl;
46 if( leftChild->predict( x, classLikelihoods ) ){
47 predictedNodeID = leftChild->getPredictedNodeID();
50 warningLog <<
"predict(const VectorFloat &x,VectorFloat &classLikelihoods) - Left child failed prediction!" << std::endl;
60 switch( trainingMode ){
61 case Tree::BEST_ITERATIVE_SPILT:
62 return computeBestSpiltBestIterativeSpilt( numSplittingSteps, trainingData, features, classLabels, featureIndex, minError );
64 case Tree::BEST_RANDOM_SPLIT:
65 return computeBestSpiltBestRandomSpilt( numSplittingSteps, trainingData, features, classLabels, featureIndex, minError );
68 errorLog <<
"computeBestSpilt(...) - Uknown trainingMode!" << std::endl;
82 classProbabilities.clear();
90 for(UINT i=0; i<depth; i++) tab +=
"\t";
92 stream << tab <<
"depth: " << depth <<
" nodeSize: " << nodeSize <<
" isLeafNode: " << isLeafNode << std::endl;
93 stream << tab <<
"ClassProbabilities: ";
94 for(UINT i=0; i<classProbabilities.size(); i++){
95 stream << classProbabilities[i] <<
"\t";
99 if( leftChild != NULL ){
100 stream << tab <<
"LeftChild: " << std::endl;
101 leftChild->getModel( stream );
104 if( rightChild != NULL ){
105 stream << tab <<
"RightChild: " << std::endl;
106 rightChild->getModel( stream );
122 node->isLeafNode = isLeafNode;
123 node->nodeID = nodeID;
124 node->predictedNodeID = predictedNodeID;
125 node->nodeSize = nodeSize;
126 node->classProbabilities = classProbabilities;
131 node->leftChild->setParent( node );
137 node->rightChild->setParent( node );
152 return (UINT)classProbabilities.size();
156 return classProbabilities;
160 this->nodeSize = nodeSize;
161 this->classProbabilities = classProbabilities;
162 this->isLeafNode =
true;
167 this->nodeSize = nodeSize;
172 this->classProbabilities = classProbabilities;
176 UINT DecisionTreeNode::getClassLabelIndexValue(UINT classLabel,
const Vector< UINT > &classLabels){
177 const UINT N = classLabels.
getSize();
178 for(UINT i=0; i<N; i++){
179 if( classLabel == classLabels[i] )
virtual bool computeBestSpilt(const UINT &trainingMode, const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError)
bool setClassProbabilities(const VectorFloat &classProbabilities)
UINT getNumClasses() const
virtual Node * deepCopyNode() const
This file implements a DecisionTreeNode, which is a specific base node used for a DecisionTree...
virtual bool predict(const VectorFloat &x, VectorFloat &classLikelihoods)
bool setLeafNode(const UINT nodeSize, const VectorFloat &classProbabilities)
DecisionTreeNode * deepCopy() const
static Node * createInstanceFromString(std::string const &nodeType)
virtual Node * deepCopyNode() const
virtual ~DecisionTreeNode()
virtual bool getModel(std::ostream &stream) const
VectorFloat getClassProbabilities() const
bool setNodeSize(const UINT nodeSize)