// NOTE(review): this text appears to be a documentation-tool extraction of DecisionTree.h.
// The bare integers (26, 27, 29, 30, 61, ...) look like original source line numbers, not
// code, and must be stripped before this can compile.
26 #ifndef GRT_DECISION_TREE_HEADER 27 #define GRT_DECISION_TREE_HEADER 29 #include "../../CoreModules/Classifier.h" 30 #include "../../CoreAlgorithms/Tree/Tree.h" 61 DecisionTree(
// Node object for the tree. The default is a DecisionTreeClusterNode temporary bound to a
// const reference; that temporary only lives for the duration of the call, so the
// constructor presumably copies it (confirm in the implementation).
const DecisionTreeNode &decisionTreeNode =
DecisionTreeClusterNode(),
// Minimum number of samples per node (default 5). Presumably a node with fewer samples
// is not split further (confirm).
const UINT minNumSamplesPerNode=5,
// Maximum depth of the tree (default 10).
const UINT maxDepth=10,
// Whether features are removed at each split (default false).
const bool removeFeaturesAtEachSplit =
false,
// Training mode (default BEST_ITERATIVE_SPILT). The "SPILT" spelling presumably mirrors
// the enumerator declared in Tree.h; do not correct it here alone.
const Tree::TrainingMode trainingMode = Tree::TrainingMode::BEST_ITERATIVE_SPILT,
// Number of splitting steps (default 100). Presumably only used by iterative training
// modes (confirm).
const UINT numSplittingSteps=100,
// Whether scaling is enabled (default false).
const bool useScaling=
false );
// Overrides of the Classifier virtual interface (each is marked `override`).
// Each returns bool, presumably true on success (confirm in the implementation).

// Clears the classifier.
116 virtual bool clear()
override;
// Saves the model through the given std::fstream. The stream is presumably opened by the
// caller (confirm). Declared const.
133 virtual bool save( std::fstream &file )
const override;
// Loads the model from the given std::fstream. The loadLegacyModelFromFile_v1..v3 helpers
// declared below suggest older formats are also handled here (confirm).
142 virtual bool load( std::fstream &file )
override;
// Writes a description of the model to the given output stream. Declared const.
151 virtual bool getModel( std::ostream &stream )
const override;
// Const accessors for the tree configuration.

// Returns the configured training mode.
181 Tree::TrainingMode getTrainingMode()
const;
// Returns the configured number of splitting steps.
191 UINT getNumSplittingSteps()
const;
// Returns the configured minimum number of samples per node.
199 UINT getMinNumSamplesPerNode()
const;
// Returns the configured maximum tree depth.
206 UINT getMaxDepth()
const;
// Returns a node ID. Presumably this is the node reached by the most recent prediction
// (confirm against predict_).
213 UINT getPredictedNodeID()
const;
// Returns whether features are removed at each split.
220 bool getRemoveFeaturesAtEachSplit()
const;
// Setters for the tree configuration. Each returns bool, presumably false when the value
// is rejected (confirm in the implementation).

// Sets the training mode.
228 bool setTrainingMode(
const Tree::TrainingMode trainingMode);
// Sets the number of splitting steps.
242 bool setNumSplittingSteps(
const UINT numSplittingSteps);
// Sets the minimum number of samples per node.
252 bool setMinNumSamplesPerNode(
const UINT minNumSamplesPerNode);
// Sets the maximum tree depth.
261 bool setMaxDepth(
const UINT maxDepth);
// Sets whether features are removed at each split.
271 bool setRemoveFeaturesAtEachSplit(
const bool removeFeaturesAtEachSplit);
// Deprecated, misspelled ("Spilt") alias of setRemoveFeaturesAtEachSplit.
// GRT_DEPRECATED_MSG takes the deprecation message followed by the declaration it wraps.
// New code should call setRemoveFeaturesAtEachSplit.
276 GRT_DEPRECATED_MSG(
"setRemoveFeaturesAtEachSpilt(const bool removeFeaturesAtEachSpilt) is deprecated, use setRemoveFeaturesAtEachSplit(const bool removeFeaturesAtEachSplit) instead",
bool setRemoveFeaturesAtEachSpilt(
const bool removeFeaturesAtEachSpilt) );
// Static accessor for the class identifier string. Presumably it returns the static `id`
// member declared further down (confirm).
290 static std::string
getId();
// Loaders for older model file formats (versions 1-3). Presumably dispatched from
// load() based on a version header (confirm).
300 bool loadLegacyModelFromFile_v1( std::fstream &file );
301 bool loadLegacyModelFromFile_v2( std::fstream &file );
302 bool loadLegacyModelFromFile_v3( std::fstream &file );
// Returns a distance between sample `x` and the node identified by `nodeID`. Presumably
// computed against that node's entry in nodeClusters (confirm).
// NOTE(review): this is not declared const; check the body before adding const.
306 Float getNodeDistance(
const VectorFloat &x,
const UINT nodeID );
// Map from node ID to a VectorFloat. Presumably the per-node cluster vectors used by
// cluster-based nodes (confirm).
310 std::map< UINT, VectorFloat > nodeClusters;
// Training configuration backing the getters/setters of the same names.
// NOTE(review): no maxDepth member appears in this excerpt (source line 316 is missing).
// It may be declared in a base class or elided by the extraction.
315 UINT minNumSamplesPerNode;
317 UINT numSplittingSteps;
318 bool removeFeaturesAtEachSplit;
319 Tree::TrainingMode trainingMode;
// Class identifier string. Presumably the value returned by getId() (confirm).
323 static const std::string id;
328 #endif //GRT_DECISION_TREE_HEADER
// NOTE(review): the signatures below follow the header's #endif and have no terminating
// semicolons. They look like a documentation tool's listing of members inherited from the
// Classifier base class, not real declarations. Confirm, and remove them before compiling.
std::string getId() const
virtual bool recomputeNullRejectionThresholds()
virtual bool predict_(VectorFloat &inputVector)
virtual bool getModel(std::ostream &stream) const
virtual bool save(const std::string &filename) const
virtual bool deepCopyFrom(const Classifier *classifier)
virtual bool print() const
virtual bool train_(ClassificationData &trainingData)
virtual bool load(const std::string &filename)
This is the main base class that all GRT Classification algorithms should inherit from...