#define GRT_DLL_EXPORTS

const std::string RegressionTree::id = "RegressionTree";

// RegressionTree constructor: store the user-supplied training parameters.
this->numSplittingSteps = numSplittingSteps;
this->minNumSamplesPerNode = minNumSamplesPerNode;
this->maxDepth = maxDepth;
this->removeFeaturesAtEachSpilt = removeFeaturesAtEachSpilt;
this->trainingMode = trainingMode;
this->useScaling = useScaling;
this->minRMSErrorPerNode = minRMSErrorPerNode;
// RegressionTree::operator=(const RegressionTree &rhs): copy the training parameters from rhs.
this->numSplittingSteps = rhs.numSplittingSteps;
this->minNumSamplesPerNode = rhs.minNumSamplesPerNode;
this->maxDepth = rhs.maxDepth;
this->removeFeaturesAtEachSpilt = rhs.removeFeaturesAtEachSpilt;
this->trainingMode = rhs.trainingMode;
this->minRMSErrorPerNode = rhs.minRMSErrorPerNode;
// RegressionTree::deepCopyFrom(const Regressifier *regressifier): copy the settings from the source model.
if( regressifier == NULL ) return false;

this->numSplittingSteps = ptr->numSplittingSteps;
this->minNumSamplesPerNode = ptr->minNumSamplesPerNode;
this->maxDepth = ptr->maxDepth;
this->removeFeaturesAtEachSpilt = ptr->removeFeaturesAtEachSpilt;
this->trainingMode = ptr->trainingMode;
this->minRMSErrorPerNode = ptr->minRMSErrorPerNode;
Regressifier::errorLog << "train_(RegressionData &trainingData) - Training data has zero samples!" << std::endl;

numInputDimensions = N;
numOutputDimensions = T;

trainingData.scale(0, 1);

for(UINT i=0; i<N; i++){

tree = buildTree( trainingData, NULL, features, nodeID );

Regressifier::errorLog << "train_(RegressionData &trainingData) - Failed to build tree!" << std::endl;
Regressifier::errorLog << "predict_(VectorFloat &inputVector) - Model Not Trained!" << std::endl;

Regressifier::errorLog << "predict_(VectorFloat &inputVector) - Tree pointer is null!" << std::endl;

if( inputVector.size() != numInputDimensions ){
    Regressifier::errorLog << "predict_(VectorFloat &inputVector) - The size of the input Vector (" << inputVector.size() << ") does not match the num features in the model (" << numInputDimensions << ")" << std::endl;

for(UINT n=0; n<numInputDimensions; n++){
    inputVector[n] = scale(inputVector[n], inputVectorRanges[n].minValue, inputVectorRanges[n].maxValue, 0, 1);

Regressifier::errorLog << "predict_(VectorFloat &inputVector) - Failed to predict!" << std::endl;
Regressifier::errorLog << "save(fstream &file) - The file is not open!" << std::endl;

file << "GRT_REGRESSION_TREE_MODEL_FILE_V1.0\n";

Regressifier::errorLog << "save(fstream &file) - Failed to save classifier base settings to file!" << std::endl;

file << "NumSplittingSteps: " << numSplittingSteps << std::endl;
file << "MinNumSamplesPerNode: " << minNumSamplesPerNode << std::endl;
file << "MaxDepth: " << maxDepth << std::endl;
file << "RemoveFeaturesAtEachSpilt: " << removeFeaturesAtEachSpilt << std::endl;
file << "TrainingMode: " << trainingMode << std::endl;
file << "TreeBuilt: " << ( tree != NULL ? 1 : 0 ) << std::endl;

Regressifier::errorLog << "save(fstream &file) - Failed to save tree to file!" << std::endl;
Regressifier::errorLog << "load(string filename) - Could not open file to load model" << std::endl;

if( word != "GRT_REGRESSION_TREE_MODEL_FILE_V1.0" ){
    Regressifier::errorLog << "load(string filename) - Could not find Model File Header" << std::endl;

Regressifier::errorLog << "load(string filename) - Failed to load base settings from file!" << std::endl;

if( word != "NumSplittingSteps:" ){
    Regressifier::errorLog << "load(string filename) - Could not find the NumSplittingSteps!" << std::endl;
file >> numSplittingSteps;

if( word != "MinNumSamplesPerNode:" ){
    Regressifier::errorLog << "load(string filename) - Could not find the MinNumSamplesPerNode!" << std::endl;
file >> minNumSamplesPerNode;

if( word != "MaxDepth:" ){
    Regressifier::errorLog << "load(string filename) - Could not find the MaxDepth!" << std::endl;

if( word != "RemoveFeaturesAtEachSpilt:" ){
    Regressifier::errorLog << "load(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << std::endl;
file >> removeFeaturesAtEachSpilt;

if( word != "TrainingMode:" ){
    Regressifier::errorLog << "load(string filename) - Could not find the TrainingMode!" << std::endl;
UINT tempTrainingMode;
file >> tempTrainingMode;
trainingMode = static_cast< Tree::TrainingMode >( tempTrainingMode );

if( word != "TreeBuilt:" ){
    Regressifier::errorLog << "load(string filename) - Could not find the TreeBuilt!" << std::endl;

Regressifier::errorLog << "load(string filename) - Could not find the Tree!" << std::endl;

Regressifier::errorLog << "load(fstream &file) - Failed to create new RegressionTreeNode!" << std::endl;

tree->setParent( NULL );

Regressifier::errorLog << "load(fstream &file) - Failed to load tree from file!" << std::endl;
// Getters simply return the stored training parameters.
return minRMSErrorPerNode;
return numSplittingSteps;
return minNumSamplesPerNode;
return removeFeaturesAtEachSpilt;
if( trainingMode >= Tree::BEST_ITERATIVE_SPILT && trainingMode < Tree::NUM_TRAINING_MODES ){
    this->trainingMode = trainingMode;
warningLog << "Unknown trainingMode: " << trainingMode << std::endl;

if( numSplittingSteps > 0 ){
    this->numSplittingSteps = numSplittingSteps;
warningLog << "setNumSplittingSteps(const UINT numSplittingSteps) - The number of splitting steps must be greater than zero!" << std::endl;

if( minNumSamplesPerNode > 0 ){
    this->minNumSamplesPerNode = minNumSamplesPerNode;
warningLog << "setMinNumSamplesPerNode(const UINT minNumSamplesPerNode) - The minimum number of samples per node must be greater than zero!" << std::endl;

this->maxDepth = maxDepth;
warningLog << "setMaxDepth(const UINT maxDepth) - The maximum depth must be greater than zero!" << std::endl;

this->removeFeaturesAtEachSpilt = removeFeaturesAtEachSpilt;

this->minRMSErrorPerNode = minRMSErrorPerNode;
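// A minimal configuration sketch (assumed usage, not taken from the original source) showing how
// the setters above can be used before training; the parameter values here are illustrative only.
RegressionTree regressionTree;
regressionTree.setTrainingMode( Tree::BEST_ITERATIVE_SPILT );
regressionTree.setNumSplittingSteps( 100 );
regressionTree.setMinNumSamplesPerNode( 5 );
regressionTree.setMaxDepth( 10 );
regressionTree.setRemoveFeaturesAtEachSpilt( false );
regressionTree.setMinRMSErrorPerNode( 0.01 );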
// RegressionTree::buildTree(...) grows the tree recursively.
node->initNode( parent, depth, nodeID );

// Stop growing if no features remain, there are too few samples, or the maximum depth has been reached.
if( features.size() == 0 || M < minNumSamplesPerNode || depth >= maxDepth ){
    node->setIsLeafNode( true );

computeNodeRegressionData( trainingData, regressionData );

Regressifier::trainingLog << "Reached leaf node. Depth: " << depth << " NumSamples: " << trainingData.getNumSamples() << std::endl;

// Search for the best split over the remaining features.
UINT featureIndex = 0;
if( !computeBestSpilt( trainingData, features, featureIndex, threshold, minError ) ){

trainingLog << "Depth: " << depth << " FeatureIndex: " << featureIndex << " Threshold: " << threshold << " MinError: " << minError << std::endl;

// If the split error is already below the minimum RMS error, make this node a leaf.
if( minError <= minRMSErrorPerNode ){
    computeNodeRegressionData( trainingData, regressionData );
    node->set( trainingData.getNumSamples(), featureIndex, threshold, regressionData );
    trainingLog << "Reached leaf node. Depth: " << depth << " NumSamples: " << M << std::endl;

// Otherwise store the chosen split in this node.
node->set( trainingData.getNumSamples(), featureIndex, threshold, regressionData );

// Optionally remove the chosen feature so it is not reused deeper in this branch.
if( removeFeaturesAtEachSpilt ){
    for(UINT i=0; i<features.getSize(); i++){
        if( features[i] == featureIndex ){
            features.erase( features.begin()+i );

// Partition the samples into the left and right subsets and recurse on each branch.
for(UINT i=0; i<M; i++){
    if( node->predict( trainingData[i].getInputVector() ) ){
        rhs.addSample(trainingData[i].getInputVector(), trainingData[i].getTargetVector());
    }
    else lhs.addSample(trainingData[i].getInputVector(), trainingData[i].getTargetVector());

node->setLeftChild( buildTree( lhs, node, features, nodeID ) );
node->setRightChild( buildTree( rhs, node, features, nodeID ) );
bool RegressionTree::computeBestSpilt( const RegressionData &trainingData, const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError ){

    switch( trainingMode ){
        case Tree::BEST_ITERATIVE_SPILT:
            return computeBestSpiltBestIterativeSpilt( trainingData, features, featureIndex, threshold, minError );
        case Tree::BEST_RANDOM_SPLIT:

    Regressifier::errorLog << "Unknown trainingMode!" << std::endl;
bool RegressionTree::computeBestSpiltBestIterativeSpilt( const RegressionData &trainingData, const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError ){

    const UINT N = (UINT)features.size();

    if( N == 0 ) return false;

    UINT bestFeatureIndex = 0;
    Float bestThreshold = 0;

    // For each candidate feature, sweep numSplittingSteps thresholds across that feature's range.
    for(UINT n=0; n<N; n++){
        minRange = ranges[n].minValue;
        maxRange = ranges[n].maxValue;
        step = (maxRange-minRange)/Float(numSplittingSteps);
        threshold = minRange;
        featureIndex = features[n];
        while( threshold <= maxRange ){

            // Assign each sample to a group using the current threshold and accumulate the group means.
            for(UINT i=0; i<M; i++){
                groupID = trainingData[i].getInputVector()[featureIndex] >= threshold ? 1 : 0;
                groupIndex[i] = groupID;
                groupMean[ groupID ] += trainingData[i].getInputVector()[featureIndex];
                groupCounter[ groupID ]++;

            groupMean[0] /= groupCounter[0] > 0 ? groupCounter[0] : 1;
            groupMean[1] /= groupCounter[1] > 0 ? groupCounter[1] : 1;

            // Compute the mean squared error of each group about its mean.
            for(UINT i=0; i<M; i++){
                groupMSE[ groupIndex[i] ] += grt_sqr( groupMean[ groupIndex[i] ] - trainingData[i].getInputVector()[ features[n] ] );

            groupMSE[0] /= groupCounter[0] > 0 ? groupCounter[0] : 1;
            groupMSE[1] /= groupCounter[1] > 0 ? groupCounter[1] : 1;

            error = sqrt( groupMSE[0] + groupMSE[1] );

            // Keep the split with the smallest error seen so far.
            if( error < minError ){
                bestThreshold = threshold;
                bestFeatureIndex = featureIndex;

    featureIndex = bestFeatureIndex;
    threshold = bestThreshold;
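// In formula terms, the quantity minimised by the sweep above is, as written in this implementation
// (note that it measures the spread of the chosen input feature within each group):
//
//   error = sqrt( (1/M0) * sum_{i in group 0} (x_i - mu0)^2 + (1/M1) * sum_{i in group 1} (x_i - mu1)^2 )
//
// where x_i is the candidate feature value of sample i, mu0/mu1 are the group means and M0/M1 the
// group sizes. For example, if a threshold splits the feature values into {1, 3} and {10, 12}, the
// group means are 2 and 11, both per-group MSEs are 1, and error = sqrt(1 + 1) ≈ 1.41.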
Regressifier::errorLog << "computeNodeRegressionData(...) - Failed to compute regression data, there are zero training samples!" << std::endl;

regressionData.clear();
regressionData.resize( T, 0 );

// The regression value stored at a node is the mean target vector of the samples that reach it.
for(unsigned int j=0; j<T; j++){
    for(unsigned int i=0; i<M; i++){
        regressionData[j] += trainingData[i].getTargetVector()[j];
    }
    regressionData[j] /= M;
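// A small worked example of the computation above: with T = 1 and three samples whose target values
// are 2.0, 4.0 and 6.0, the node stores regressionData[0] = (2.0 + 4.0 + 6.0) / 3 = 4.0, which is the
// value predicted for any input that reaches this leaf.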
std::string getId() const
virtual bool predict(VectorFloat inputVector)
bool setMinNumSamplesPerNode(const UINT minNumSamplesPerNode)
virtual bool print() const override
bool setMaxDepth(const UINT maxDepth)
Float getMinRMSErrorPerNode() const
virtual bool train_(RegressionData &trainingData) override
static std::string getId()
virtual bool predict_(VectorFloat &x) override
virtual bool load(std::fstream &file) override
virtual bool deepCopyFrom(const Regressifier *regressifier) override
virtual bool clear() override
Vector< MinMax > getInputRanges() const
const RegressionTreeNode * getTree() const
virtual bool resize(const unsigned int size)
virtual bool predict_(VectorFloat &inputVector) override
UINT getMinNumSamplesPerNode() const
bool copyBaseVariables(const Regressifier *regressifier)
UINT getNumInputDimensions() const
bool set(const UINT nodeSize, const UINT featureIndex, const Float threshold, const VectorFloat ®ressionData)
bool getRemoveFeaturesAtEachSpilt() const
bool setTrainingMode(const Tree::TrainingMode trainingMode)
This class implements a basic Regression Tree.
UINT getPredictedNodeID() const
Vector< MinMax > getTargetRanges() const
bool saveBaseSettingsToFile(std::fstream &file) const
bool scale(const Float minTarget, const Float maxTarget)
UINT getNumTargetDimensions() const
bool setMinRMSErrorPerNode(const Float minRMSErrorPerNode)
RegressionTree & operator=(const RegressionTree &rhs)
virtual bool save(std::fstream &file) const override
bool loadBaseSettingsFromFile(std::fstream &file)
bool setNumSplittingSteps(const UINT numSplittingSteps)
UINT getNumSplittingSteps() const
bool setRemoveFeaturesAtEachSpilt(const bool removeFeaturesAtEachSpilt)
Node * tree
// Tell the compiler we are using the base class predict method, to stop the hidden virtual function warning.
RegressionTreeNode * deepCopyTree() const
RegressionTree(const UINT numSplittingSteps=100, const UINT minNumSamplesPerNode=5, const UINT maxDepth=10, const bool removeFeaturesAtEachSpilt=false, const Tree::TrainingMode trainingMode=Tree::BEST_ITERATIVE_SPILT, const bool useScaling=false, const Float minRMSErrorPerNode=0.01)
Tree::TrainingMode getTrainingMode() const
virtual ~RegressionTree(void)
bool addSample(const VectorFloat &inputVector, const VectorFloat &targetVector)
UINT getNumSamples() const
virtual Node * deepCopy() const
Float scale(const Float &x, const Float &minSource, const Float &maxSource, const Float &minTarget, const Float &maxTarget, const bool constrain=false)
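A minimal end-to-end usage sketch based on the public interface listed above; the dataset filename and the specific parameter values are illustrative assumptions rather than details from the original source:

#include <GRT/GRT.h>
#include <fstream>
using namespace GRT;

int main(){
    // Load a regression dataset; the file path is hypothetical.
    RegressionData trainingData;
    if( !trainingData.loadDatasetFromFile( "regression-data.grt" ) ) return 1;

    // Create the tree with the constructor defaults shown above, enabling input/target scaling.
    RegressionTree regressionTree( 100, 5, 10, false, Tree::BEST_ITERATIVE_SPILT, true, 0.01 );

    // Train the model (train() wraps the train_() shown earlier).
    if( !regressionTree.train( trainingData ) ) return 1;

    // Run a prediction on the first training sample; getRegressionData() is assumed to be
    // provided by the Regressifier base class.
    VectorFloat inputVector = trainingData[0].getInputVector();
    if( regressionTree.predict( inputVector ) ){
        VectorFloat regressionEstimate = regressionTree.getRegressionData();
    }

    // Save the trained model using the save(std::fstream&) overload listed above.
    std::fstream file( "regression-tree-model.grt", std::ios::out );
    regressionTree.save( file );
    file.close();

    return 0;
}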