31 #ifndef GRT_CLUSTER_TREE_HEADER 32 #define GRT_CLUSTER_TREE_HEADER 34 #include "../../CoreModules/Clusterer.h" 35 #include "../../CoreAlgorithms/Tree/Tree.h" 54 ClusterTree(
// Constructor declaration. Every parameter has a default, so the class is default-constructible.
// numSplittingSteps: defaults to 100.
const UINT numSplittingSteps=100,
// minNumSamplesPerNode: defaults to 5.
const UINT minNumSamplesPerNode=5,
// maxDepth: defaults to 10.
const UINT maxDepth=10,
// removeFeaturesAtEachSplit: defaults to false.
const bool removeFeaturesAtEachSplit =
false,
// trainingMode: defaults to Tree::BEST_ITERATIVE_SPILT.
// NOTE(review): the enumerator is spelled "SPILT" here; it is declared in Tree (not visible in
// this file), so it must not be "corrected" locally without changing that declaration too.
const Tree::TrainingMode trainingMode = Tree::BEST_ITERATIVE_SPILT,
// useScaling: defaults to false. Presumably forwarded to the Clusterer base class - TODO confirm.
const bool useScaling=
false,
// minRMSErrorPerNode: defaults to 0.01.
const Float minRMSErrorPerNode = 0.01);
/**
 Overrides the virtual clear() inherited from a base class (marked `override`).
 @return a bool; presumably true when the clear succeeded - confirm against the implementation
*/
109 virtual bool clear()
override;
/**
 Overrides the virtual print() inherited from a base class. Declared const, so it may be
 called on a const instance.
 @return a bool; presumably true when printing succeeded - confirm against the implementation
*/
116 virtual bool print()
const override;
/**
 Overrides the base-class saveModelToFile overload that takes a std::fstream. Declared const.
 @param file: stream passed by non-const reference; presumably the model is written to it,
 and it is presumably expected to be open already - TODO confirm
 @return a bool; presumably true on success - confirm against the implementation
*/
125 virtual bool saveModelToFile( std::fstream &file )
const override;
/**
 Overrides the base-class loadModelFromFile overload that takes a std::fstream. Not const,
 so it is allowed to modify this instance.
 @param file: stream passed by non-const reference; presumably the model is read from it,
 and it is presumably expected to be open already - TODO confirm
 @return a bool; presumably true on success - confirm against the implementation
*/
134 virtual bool loadModelFromFile( std::fstream &file )
override;
/**
 Const accessor returning a Float by value.
 @return presumably the current value of the minRMSErrorPerNode member - confirm in the implementation
*/
166 Float getMinRMSErrorPerNode()
const;
/**
 Const accessor returning a Tree::TrainingMode by value.
 @return presumably the current value of the trainingMode member - confirm in the implementation
*/
173 Tree::TrainingMode getTrainingMode()
const;
/**
 Const accessor returning a UINT by value.
 @return presumably the current value of the numSplittingSteps member - confirm in the implementation
*/
183 UINT getNumSplittingSteps()
const;
/**
 Const accessor returning a UINT by value.
 @return presumably the current value of the minNumSamplesPerNode member - confirm in the implementation
*/
191 UINT getMinNumSamplesPerNode()
const;
/**
 Const accessor returning a UINT by value.
 NOTE(review): no maxDepth data member is visible in this part of the header, so where the
 value is stored cannot be determined from here - confirm in the implementation.
 @return presumably the configured maximum tree depth
*/
198 UINT getMaxDepth()
const;
/**
 Const accessor returning a UINT by value.
 @return presumably the ID of the tree node reached by the most recent prediction; the
 value returned before any prediction has been made is not determinable from this header - TODO confirm
*/
205 UINT getPredictedNodeID()
const;
/**
 Const accessor returning a bool by value.
 @return presumably the current value of the removeFeaturesAtEachSplit member - confirm in the implementation
*/
212 bool getRemoveFeaturesAtEachSplit()
const;
/**
 Non-const setter taking a Tree::TrainingMode by value.
 @param trainingMode: the new mode; presumably stored in the trainingMode member
 @return a bool; presumably true when the value was accepted. Any validation rules are
 not visible in this header - confirm against the implementation
*/
220 bool setTrainingMode(
const Tree::TrainingMode trainingMode);
/**
 Non-const setter taking a UINT by value.
 @param numSplittingSteps: the new value; presumably stored in the numSplittingSteps member
 @return a bool; presumably true when the value was accepted. Any validation rules are
 not visible in this header - confirm against the implementation
*/
234 bool setNumSplittingSteps(
const UINT numSplittingSteps);
/**
 Non-const setter taking a UINT by value.
 @param minNumSamplesPerNode: the new value; presumably stored in the minNumSamplesPerNode member
 @return a bool; presumably true when the value was accepted. Any validation rules are
 not visible in this header - confirm against the implementation
*/
244 bool setMinNumSamplesPerNode(
const UINT minNumSamplesPerNode);
/**
 Non-const setter taking a UINT by value.
 NOTE(review): no maxDepth data member is visible in this part of the header, so where the
 value is stored cannot be determined from here.
 @param maxDepth: the new maximum depth
 @return a bool; presumably true when the value was accepted. Any validation rules are
 not visible in this header - confirm against the implementation
*/
253 bool setMaxDepth(
const UINT maxDepth);
/**
 Non-const setter taking a bool by value.
 @param removeFeaturesAtEachSplit: the new flag value; presumably stored in the
 removeFeaturesAtEachSplit member
 @return a bool; presumably true when the value was stored - confirm against the implementation
*/
263 bool setRemoveFeaturesAtEachSplit(
const bool removeFeaturesAtEachSplit);
/**
 Non-const setter taking a Float by value.
 @param minRMSErrorPerNode: the new value; presumably stored in the minRMSErrorPerNode member
 @return a bool; presumably true when the value was accepted. Any validation rules are
 not visible in this header - confirm against the implementation
*/
270 bool setMinRMSErrorPerNode(
const Float minRMSErrorPerNode);
// Re-expose every MLBase overload of saveModelToFile / loadModelFromFile in this class.
// Without these using-declarations, the std::fstream overrides declared in this class would
// hide any other overloads that MLBase provides under the same names.
273 using MLBase::saveModelToFile;
274 using MLBase::loadModelFromFile;
/**
 Static member function, so it can be called without an instance.
 @return a std::string by value; presumably the contents of the static `id` member - TODO confirm
*/
283 static std::string
getId();
// Training settings. Each shares its name and type with a constructor parameter and with a
// getter/setter pair declared in this class; presumably the constructor initialises them
// from those arguments - confirm in the implementation.
// Presumably the minimum number of samples per node.
287 UINT minNumSamplesPerNode;
// Presumably the number of splitting steps.
289 UINT numSplittingSteps;
// Presumably whether features are removed at each split.
290 bool removeFeaturesAtEachSplit;
// Presumably selects which split-search variant is used during training - TODO confirm.
291 Tree::TrainingMode trainingMode;
// Presumably the RMS error threshold per node - TODO confirm how it is used.
292 Float minRMSErrorPerNode;
/**
 Declares a split-search routine. It has the same parameter list as
 computeBestSplitBestIterativeSplit and computeBestSplitBestRandomSplit; presumably it
 dispatches to one of them based on trainingMode - TODO confirm in the implementation.
 @param trainingData: read-only input matrix (const reference)
 @param features: read-only list of UINT values (const reference); presumably the indices of
 the candidate features
 @param featureIndex: non-const reference; presumably receives the chosen feature index
 @param threshold: non-const reference; presumably receives the chosen split threshold
 @param minError: non-const reference; presumably receives the error of the chosen split
 @return a bool; presumably true when a split was found - confirm against the implementation
*/
295 bool computeBestSplit(
const MatrixFloat &trainingData,
const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError );
/**
 Declares a split-search routine with the same parameter list as computeBestSplit.
 Presumably the variant used for the Tree::BEST_ITERATIVE_SPILT training mode - TODO confirm.
 @param trainingData: read-only input matrix (const reference)
 @param features: read-only list of UINT values (const reference); presumably the indices of
 the candidate features
 @param featureIndex: non-const reference; presumably receives the chosen feature index
 @param threshold: non-const reference; presumably receives the chosen split threshold
 @param minError: non-const reference; presumably receives the error of the chosen split
 @return a bool; presumably true when a split was found - confirm against the implementation
*/
296 bool computeBestSplitBestIterativeSplit(
const MatrixFloat &trainingData,
const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError );
/**
 Declares a split-search routine with the same parameter list as computeBestSplit.
 Presumably the variant that evaluates randomly chosen candidate splits - TODO confirm.
 @param trainingData: read-only input matrix (const reference)
 @param features: read-only list of UINT values (const reference); presumably the indices of
 the candidate features
 @param featureIndex: non-const reference; presumably receives the chosen feature index
 @param threshold: non-const reference; presumably receives the chosen split threshold
 @param minError: non-const reference; presumably receives the error of the chosen split
 @return a bool; presumably true when a split was found - confirm against the implementation
*/
297 bool computeBestSplitBestRandomSplit(
const MatrixFloat &trainingData,
const Vector< UINT > &features, UINT &featureIndex, Float &threshold, Float &minError );
// Class-wide constant string, declared here without an initializer; its definition must
// live outside the class. Presumably the value returned by getId() - TODO confirm.
301 static const std::string id;
306 #endif //GRT_CLUSTER_TREE_HEADER std::string getId() const
virtual bool predict(VectorFloat inputVector)
virtual bool predict_(VectorFloat &inputVector)
virtual bool clear() override
virtual bool train(ClassificationData trainingData)
virtual bool deepCopyFrom(const Clusterer *clusterer)
virtual bool train_(MatrixFloat &trainingData) override
This file implements a ClusterTreeNode, which is a specific type of node used for a ClusterTree...
virtual bool print() const
UINT getPredictedClusterLabel() const