/*
 * NOTE(review): the next line is an extraction artifact. The include guard,
 * the three #include directives and the opening of computeBestSplit's
 * declaration are fused onto a single line, with source line numbers
 * (27, 28, 30, 31, 32, 88) embedded in the text. Restore one directive per
 * line and drop the numbers before compiling. The doc comment below belongs
 * to computeBestSplit.
 */
/**
 * Computes a split of the training data for this node (virtual, so
 * subclasses may replace it).
 *
 * NOTE(review): the contract below is inferred from the signature only,
 * because the implementation is not visible here. Confirm it against the
 * definition.
 *
 * @param trainingMode      selects how the split is searched for; the accepted values are not visible here
 * @param numSplittingSteps presumably the number of candidate thresholds tried per feature (confirm)
 * @param trainingData      the training samples (ClassificationData) the split is computed from
 * @param features          presumably the indices of the features that may be split on (confirm)
 * @param classLabels       the class labels considered when scoring the split (assumed)
 * @param featureIndex      passed by non-const reference; presumably receives the chosen feature index
 * @param minError          passed by non-const reference; presumably receives the error of the chosen split
 * @return bool             presumably true when a split was found (confirm)
 */
27 #ifndef GRT_DECISION_TREE_NODE_HEADER 28 #define GRT_DECISION_TREE_NODE_HEADER 30 #include "../../CoreAlgorithms/Tree/Node.h" 31 #include "../../CoreAlgorithms/Tree/Tree.h" 32 #include "../../DataStructures/ClassificationData.h" 88 virtual bool computeBestSplit(
const UINT &trainingMode,
const UINT &numSplittingSteps,
const ClassificationData &trainingData,
const Vector< UINT > &features,
const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError );
/**
 * Clears this node. It overrides a virtual clear() declared in a base
 * class (`override`).
 *
 * NOTE(review): presumably this also resets the node-specific state
 * (nodeSize and classProbabilities). Confirm against the implementation,
 * which is not visible here.
 * NOTE(review): the leading "96" is a fused source line number, not code.
 *
 * @return bool presumably true on success
 */
96 virtual bool clear()
override;
/**
 * Writes a description of this node to an output stream. It overrides a
 * base-class virtual (`override`) and is const, so it does not modify the
 * node.
 *
 * NOTE(review): the output format is not visible here. Confirm against the
 * implementation.
 * NOTE(review): the leading "105" is a fused source line number, not code.
 *
 * @param stream the stream the description is written to
 * @return bool  presumably false if writing failed (confirm)
 */
105 virtual bool getModel( std::ostream &stream )
const override;
/**
 * Returns this node's size. The function is const.
 *
 * NOTE(review): presumably this returns the `nodeSize` member, which is the
 * value the parameter-saving code writes under the "NodeSize:" header.
 * Confirm against the definition.
 * NOTE(review): the leading "120" is a fused source line number, not code.
 *
 * @return UINT the node size
 */
120 UINT getNodeSize()
const;
/**
 * Returns the number of classes held by this node. The function is const.
 *
 * NOTE(review): presumably this equals classProbabilities.size(), which is
 * what the parameter-saving code writes under the "NumClasses:" header.
 * Confirm against the definition.
 * NOTE(review): the leading "127" is a fused source line number, not code.
 *
 * @return UINT the number of classes
 */
127 UINT getNumClasses()
const;
/**
 * Configures this node as a leaf, giving it a size and per-class
 * probabilities.
 *
 * NOTE(review): this is inferred from the name and parameters, because the
 * definition is not visible here. In particular, confirm whether it also
 * sets a leaf flag on the base node.
 * NOTE(review): the leading "143" is a fused source line number, not code.
 *
 * @param nodeSize           the size to store for this node
 * @param classProbabilities one probability per class (assumed ordering: same as the class labels; confirm)
 * @return bool              presumably true on success
 */
143 bool setLeafNode(
const UINT nodeSize,
const VectorFloat &classProbabilities );
/**
 * Sets this node's size.
 *
 * NOTE(review): presumably this assigns the `nodeSize` member. Confirm
 * against the definition.
 * NOTE(review): the leading "151" is a fused source line number, not code.
 *
 * @param nodeSize the new node size
 * @return bool    presumably true on success
 */
151 bool setNodeSize(
const UINT nodeSize);
/**
 * Sets this node's per-class probabilities.
 *
 * NOTE(review): presumably this replaces the `classProbabilities` member,
 * which would also change the value reported as NumClasses. Confirm against
 * the definition.
 * NOTE(review): the leading "159" is a fused source line number, not code.
 *
 * @param classProbabilities one probability per class
 * @return bool              presumably true on success
 */
159 bool setClassProbabilities(
const VectorFloat &classProbabilities);
/**
 * Maps a class label to a UINT using the given list of class labels. The
 * function is static, so it needs no node instance.
 *
 * NOTE(review): presumably this returns the position of classLabel within
 * classLabels. The result when the label is absent is not visible here;
 * confirm against the definition.
 * NOTE(review): the leading "161" is a fused source line number, not code.
 *
 * @param classLabel  the label to look up
 * @param classLabels the list of known class labels
 * @return UINT       the index value for classLabel
 */
161 static UINT getClassLabelIndexValue(UINT classLabel,
const Vector< UINT > &classLabels);
169 errorLog << __GRT_LOG__ <<
" Base class not overwritten!" << std::endl;
176 errorLog << __GRT_LOG__ <<
" Base class not overwritten!" << std::endl;
190 if( !file.is_open() )
192 errorLog << __GRT_LOG__ <<
" File is not open!" << std::endl;
197 file <<
"NodeSize: " << nodeSize << std::endl;
198 file <<
"NumClasses: " << classProbabilities.size() << std::endl;
199 file <<
"ClassProbabilities: ";
200 if( classProbabilities.size() > 0 ){
201 for(UINT i=0; i<classProbabilities.size(); i++){
202 file << classProbabilities[i];
203 if( i < classProbabilities.size()-1 ) file <<
"\t";
204 else file << std::endl;
219 if( !file.is_open() )
221 errorLog << __GRT_LOG__ <<
" File is not open!" << std::endl;
225 classProbabilities.clear();
232 if( word !=
"NodeSize:" ){
233 errorLog << __GRT_LOG__ <<
" Failed to find NodeSize header!" << std::endl;
239 if( word !=
"NumClasses:" ){
240 errorLog << __GRT_LOG__ <<
" Failed to find NumClasses header!" << std::endl;
245 classProbabilities.
resize( numClasses );
248 if( word !=
"ClassProbabilities:" ){
249 errorLog << __GRT_LOG__ <<
" Failed to find ClassProbabilities header!" << std::endl;
252 if( numClasses > 0 ){
253 for(UINT i=0; i<numClasses; i++){
254 file >> classProbabilities[i];
269 #endif //GRT_DECISION_TREE_NODE_HEADER
virtual bool predict(VectorFloat inputVector)
virtual bool getModel(std::ostream &stream) const override
virtual bool predict_(VectorFloat &x) override
virtual bool resize(const unsigned int size)
virtual bool loadParametersFromFile(std::fstream &file) override
virtual bool saveParametersToFile(std::fstream &file) const override
virtual bool clear() override
virtual Node * deepCopy() const