28 #define GRT_DLL_EXPORTS
44 trainingLog.setProceedingText(
"[TRAINING DecisionStump]");
45 warningLog.setProceedingText(
"[WARNING DecisionStump]");
46 errorLog.setProceedingText(
"[ERROR DecisionStump]");
69 if( weakClassifer == NULL )
return false;
85 errorLog <<
"train(ClassificationData &trainingData, VectorFloat &weights) - There should only be 2 classes in the training data, but there are : " << trainingData.
getNumClasses() << std::endl;
91 errorLog <<
"train(ClassificationData &trainingData, VectorFloat &weights) - There number of examples in the training data (" << trainingData.
getNumSamples() <<
") does not match the lenght of the weights vector (" << weights.
getSize() <<
")" << std::endl;
97 UINT bestFeatureIndex = 0;
104 Float bestThreshold = 0;
111 minRange = ranges[n].minValue;
112 maxRange = ranges[n].maxValue;
119 for(UINT i=0; i<M; i++){
121 bool rhs = trainingData[ i ][ n ] >= threshold;
122 bool lhs = trainingData[ i ][ n ] <= threshold;
123 if( (rhs && !positiveClass) || (!rhs && positiveClass) ) rhsError += weights[ i ];
124 if( (lhs && !positiveClass) || (!lhs && positiveClass) ) lhsError += weights[ i ];
128 if( rhsError < minError ){
130 bestFeatureIndex = n;
131 bestThreshold = threshold;
134 if( lhsError < minError ){
136 bestFeatureIndex = n;
137 bestThreshold = threshold;
162 errorLog <<
"saveModelToFile(fstream &file) - The file is not open!" << std::endl;
168 file <<
"Trained: "<<
trained << std::endl;
173 file <<
"Direction: "<<
direction << std::endl;
185 errorLog <<
"loadModelFromFile(fstream &file) - The file is not open!" << std::endl;
192 if( word !=
"WeakClassifierType:" ){
193 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read WeakClassifierType header!" << std::endl;
199 errorLog <<
"loadModelFromFile(fstream &file) - The weakClassifierType:" << word <<
" does not match: " <<
weakClassifierType << std::endl;
204 if( word !=
"Trained:" ){
205 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read Trained header!" << std::endl;
211 if( word !=
"NumInputDimensions:" ){
212 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
218 if( word !=
"DecisionFeatureIndex:" ){
219 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read DecisionFeatureIndex header!" << std::endl;
225 if( word !=
"Direction:" ){
226 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read Direction header!" << std::endl;
232 if( word !=
"NumRandomSplits:" ){
233 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read NumRandomSplits header!" << std::endl;
239 if( word !=
"DecisionValue:" ){
240 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read DecisionValue header!" << std::endl;
250 std::cout <<
"Trained: " <<
trained;
253 std::cout <<
"\tDirection: " <<
direction << std::endl;
UINT getNumRandomSplits() const
static RegisterWeakClassifierModule< DecisionStump > registerModule
This is used to register the DecisionStump with the WeakClassifier base class.
std::string weakClassifierType
A string that represents the weak classifier type, e.g. DecisionStump.
UINT direction
Indicates whether the decision split threshold is greater than (1) or less than (0).
UINT numInputDimensions
The number of input dimensions to the weak classifier.
Float decisionValue
The decision split threshold.
DecisionStump & operator=(const DecisionStump &rhs)
virtual bool train(ClassificationData &trainingData, VectorFloat &weights)
std::string getWeakClassifierType() const
Float getDecisionValue() const
virtual void print() const
DecisionStump(const UINT numRandomSplits=100)
UINT getNumSamples() const
#define WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL
UINT numRandomSplits
The number of random splits used to search for the best decision split.
virtual bool loadModelFromFile(std::fstream &file)
virtual bool saveModelToFile(std::fstream &file) const
virtual Float predict(const VectorFloat &x)
UINT getDecisionFeatureIndex() const
bool copyBaseVariables(const WeakClassifier *weakClassifer)
UINT getNumDimensions() const
UINT getNumClasses() const
bool trained
A flag to show if the weak classifier model has been trained.
Vector< MinMax > getRanges() const
Float getRandomNumberUniform(Float minRange=0.0, Float maxRange=1.0)
int getRandomNumberInt(int minRange, int maxRange)
virtual bool deepCopyFrom(const WeakClassifier *weakClassifer)
UINT decisionFeatureIndex
The dimension that the data will be split on.
This class implements a DecisionStump, which is a single node of a DecisionTree.
UINT getDirection() const