// NOTE(review): this listing looks like a Doxygen source-browser extraction: the
// original file line numbers (26, 27, 29, 30, 53, ...) are fused into the code and
// the original comment lines are missing, so it does not compile as-is.
26 #ifndef GRT_REGRESSION_TREE_HEADER 27 #define GRT_REGRESSION_TREE_HEADER 29 #include "../../CoreModules/Regressifier.h" 30 #include "../../CoreAlgorithms/Tree/Tree.h" 53 RegressionTree(
// Constructor. Every parameter has a default, so RegressionTree() is a valid call.
const UINT numSplittingSteps=100,   // default: 100
const UINT minNumSamplesPerNode=5,  // default: 5
const UINT maxDepth=10,             // default: 10
const bool removeFeaturesAtEachSpilt =
false,                              // default: false. "Spilt" (sic) is part of the public name; renaming it would break callers
const Tree::TrainingMode trainingMode = Tree::BEST_ITERATIVE_SPILT,  // default: Tree::BEST_ITERATIVE_SPILT
const bool useScaling=
false,                              // default: false. No matching data member is declared in this listing; presumably stored by the base class -- confirm
const Float minRMSErrorPerNode = 0.01);  // default: 0.01
// Overrides of base-class virtuals. The `override` specifiers require matching
// virtuals in a base class, presumably Regressifier (included above) -- confirm.
// NOTE(review): declaring save/load taking std::fstream& hides every other base-class
// save/load overload, e.g. the `const std::string &filename` versions named in the
// cross-reference residue at the end of this listing, unless a using-declaration
// re-exposes them. Confirm one exists in the full header.
108 virtual bool clear()
override;
// const member: printing cannot modify the object.
115 virtual bool print()
const override;
// Presumably serialises the model to the caller-supplied stream; const, so it cannot modify the object.
124 virtual bool save( std::fstream &file )
const override;
// Presumably deserialises the model from the caller-supplied stream (non-const: it may modify the object).
133 virtual bool load( std::fstream &file )
override;
// --- Accessors. All are const member functions, so none modifies the object. ---

// Presumably returns the minRMSErrorPerNode data member.
157 Float getMinRMSErrorPerNode()
const;
// Presumably returns the trainingMode data member.
164 Tree::TrainingMode getTrainingMode()
const;
// Presumably returns the numSplittingSteps data member.
174 UINT getNumSplittingSteps()
const;
// Presumably returns the minNumSamplesPerNode data member.
182 UINT getMinNumSamplesPerNode()
const;
// NOTE(review): no maxDepth data member is visible in this listing (it may be on the
// omitted original line 280 or inherited) -- confirm where the value is stored.
189 UINT getMaxDepth()
const;
// NOTE(review): no predicted-node member is declared in this listing; the value is
// presumably recorded during prediction -- confirm in the implementation.
196 UINT getPredictedNodeID()
const;
// Presumably returns the removeFeaturesAtEachSpilt data member ("Spilt" sic, public name).
203 bool getRemoveFeaturesAtEachSpilt()
const;
// --- Mutators. Each takes its new value by const value and returns bool;
// the bool presumably reports whether the value was accepted -- confirm in the .cpp. ---

211 bool setTrainingMode(
const Tree::TrainingMode trainingMode);
225 bool setNumSplittingSteps(
const UINT numSplittingSteps);
235 bool setMinNumSamplesPerNode(
const UINT minNumSamplesPerNode);
244 bool setMaxDepth(
const UINT maxDepth);
// "Spilt" (sic) is part of the public name; do not rename without a compatibility shim.
254 bool setRemoveFeaturesAtEachSpilt(
const bool removeFeaturesAtEachSpilt);
262 bool setMinRMSErrorPerNode(
const Float minRMSErrorPerNode);
// Static accessor, callable without an instance. Presumably returns the static
// `id` string declared at the end of this class -- confirm in the .cpp.
269 static std::string
getId();
// --- Data members. The access specifier (original lines 270-278) is not visible here. ---
// NOTE(review): original line 280 is missing from this listing; maxDepth may have been
// declared there -- confirm against the full header.
279 UINT minNumSamplesPerNode;
281 UINT numSplittingSteps;
282 bool removeFeaturesAtEachSpilt;
283 Tree::TrainingMode trainingMode;
284 Float minRMSErrorPerNode;
// --- Split-search helpers. Both read the training data and the candidate feature
// indices through const references. featureIndex, threshold and minError are
// non-const references, presumably out-parameters for the chosen split, and the bool
// return presumably signals whether a split was found -- confirm in the .cpp. ---
287 bool computeBestSpilt(
const RegressionData &trainingData,
const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError );
// Presumably the strategy used for Tree::BEST_ITERATIVE_SPILT training mode -- confirm.
288 bool computeBestSpiltBestIterativeSpilt(
const RegressionData &trainingData,
const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError );
// Class-wide identifier string shared by all instances (static const).
294 static const std::string id;
300 #endif //GRT_REGRESSION_TREE_HEADER std::string getId() const
virtual bool predict_(VectorFloat &inputVector)
virtual bool clear() override
This file implements a RegressionTreeNode, which is a specific type of node used for a RegressionTree...
virtual bool save(const std::string &filename) const
This class implements a basic Regression Tree.
virtual bool deepCopyFrom(const Regressifier *regressifier)
virtual bool print() const
Node * tree
Tell the compiler we are using the base class predict method, to suppress the hidden virtual function warning...
virtual bool train_(ClassificationData &trainingData)
virtual bool load(const std::string &filename)