26 #ifndef GRT_RANDOM_FORESTS_HEADER 27 #define GRT_RANDOM_FORESTS_HEADER 29 #include "../DecisionTree/DecisionTree.h" 59 const UINT forestSize=10,
60 const UINT numRandomSplits=100,
61 const UINT minNumSamplesPerNode=5,
62 const UINT maxDepth=10,
63 const Tree::TrainingMode trainingMode = Tree::BEST_RANDOM_SPLIT,
64 const bool removeFeaturesAtEachSplit =
true,
65 const bool useScaling=
false,
66 const Float bootstrappedDatasetWeight = 0.8);
120 virtual bool clear();
127 virtual bool print()
const;
136 virtual bool save( std::fstream &file )
const;
145 virtual bool load( std::fstream &file );
162 UINT getForestSize()
const;
169 UINT getNumRandomSplits()
const;
177 UINT getMinNumSamplesPerNode()
const;
184 UINT getMaxDepth()
const;
191 UINT getTrainingMode()
const;
207 bool getRemoveFeaturesAtEachSplit()
const;
215 Float getBootstrappedDatasetWeight()
const;
244 VectorDouble getFeatureWeights(
const bool normWeights =
true )
const;
258 MatrixDouble getLeafNodeFeatureWeights(
const bool normWeights =
true )
const;
266 bool setForestSize(
const UINT forestSize);
277 bool setNumRandomSplits(
const UINT numSplittingSteps);
287 bool setMinNumSamplesPerNode(
const UINT minNumSamplesPerNode);
296 bool setMaxDepth(
const UINT maxDepth);
306 bool setRemoveFeaturesAtEachSplit(
const bool removeFeaturesAtEachSplit);
311 GRT_DEPRECATED_MSG(
"setRemoveFeaturesAtEachSpilt(const bool removeFeaturesAtEachSpilt) is deprecated, use setRemoveFeaturesAtEachSplit(const bool removeFeaturesAtEachSplit) instead",
bool setRemoveFeaturesAtEachSpilt(
const bool removeFeaturesAtEachSpilt) );
319 bool setTrainingMode(
const Tree::TrainingMode trainingMode);
335 bool setBootstrappedDatasetWeight(
const Float bootstrappedDatasetWeight );
342 static std::string
getId();
351 UINT numRandomSplits;
352 UINT minNumSamplesPerNode;
354 Tree::TrainingMode trainingMode;
355 bool removeFeaturesAtEachSplit;
356 Float bootstrappedDatasetWeight;
362 static const std::string id;
368 #endif //GRT_RANDOM_FORESTS_HEADER
std::string getId() const
virtual bool predict_(VectorFloat &inputVector)
virtual bool save(const std::string &filename) const
virtual bool deepCopyFrom(const Classifier *classifier)
virtual bool print() const
virtual bool train_(ClassificationData &trainingData)
virtual bool load(const std::string &filename)
This is the main base class that all GRT Classification algorithms should inherit from...