28 #define GRT_DLL_EXPORTS 44 trainingLog.
setKey(
"[TRAINING DecisionStump]");
45 warningLog.setKey(
"[WARNING DecisionStump]");
46 errorLog.
setKey(
"[ERROR DecisionStump]");
69 if( weakClassifer == NULL )
return false;
85 errorLog <<
"train(ClassificationData &trainingData, VectorFloat &weights) - There should only be 2 classes in the training data, but there are : " << trainingData.
getNumClasses() << std::endl;
91 errorLog <<
"train(ClassificationData &trainingData, VectorFloat &weights) - There number of examples in the training data (" << trainingData.
getNumSamples() <<
") does not match the lenght of the weights vector (" << weights.
getSize() <<
")" << std::endl;
97 UINT bestFeatureIndex = 0;
103 Float bestThreshold = 0;
110 minRange = ranges[n].minValue;
111 maxRange = ranges[n].maxValue;
118 for(UINT i=0; i<M; i++){
120 bool rhs = trainingData[ i ][ n ] >= threshold;
121 bool lhs = trainingData[ i ][ n ] <= threshold;
122 if( (rhs && !positiveClass) || (!rhs && positiveClass) ) rhsError += weights[ i ];
123 if( (lhs && !positiveClass) || (!lhs && positiveClass) ) lhsError += weights[ i ];
127 if( rhsError < minError ){
129 bestFeatureIndex = n;
130 bestThreshold = threshold;
133 if( lhsError < minError ){
135 bestFeatureIndex = n;
136 bestThreshold = threshold;
161 errorLog <<
"saveModelToFile(fstream &file) - The file is not open!" << std::endl;
167 file <<
"Trained: "<<
trained << std::endl;
172 file <<
"Direction: "<<
direction << std::endl;
184 errorLog <<
"loadModelFromFile(fstream &file) - The file is not open!" << std::endl;
191 if( word !=
"WeakClassifierType:" ){
192 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read WeakClassifierType header!" << std::endl;
198 errorLog <<
"loadModelFromFile(fstream &file) - The weakClassifierType:" << word <<
" does not match: " <<
weakClassifierType << std::endl;
203 if( word !=
"Trained:" ){
204 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read Trained header!" << std::endl;
210 if( word !=
"NumInputDimensions:" ){
211 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
217 if( word !=
"DecisionFeatureIndex:" ){
218 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read DecisionFeatureIndex header!" << std::endl;
224 if( word !=
"Direction:" ){
225 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read Direction header!" << std::endl;
231 if( word !=
"NumRandomSplits:" ){
232 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read NumRandomSplits header!" << std::endl;
238 if( word !=
"DecisionValue:" ){
239 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read DecisionValue header!" << std::endl;
249 std::cout <<
"Trained: " <<
trained;
252 std::cout <<
"\tDirection: " <<
direction << std::endl;
UINT getNumRandomSplits() const
static RegisterWeakClassifierModule< DecisionStump > registerModule
This is used to register the DecisionStump with the WeakClassifier base class.
std::string weakClassifierType
A string that represents the weak classifier type, e.g. DecisionStump.
UINT direction
Indicates whether the decision split threshold test is greater than (1) or less than (0)
This file contains the Random class, a useful wrapper for generating cross platform random functions...
UINT numInputDimensions
The number of input dimensions to the weak classifier.
Float decisionValue
The decision split threshold.
DecisionStump & operator=(const DecisionStump &rhs)
virtual bool train(ClassificationData &trainingData, VectorFloat &weights)
virtual bool setKey(const std::string &key)
Sets the key that is written at the start of each message; this will be written in the format 'key ...
std::string getWeakClassifierType() const
Float getDecisionValue() const
virtual void print() const
DecisionStump(const UINT numRandomSplits=100)
UINT getNumSamples() const
#define WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL
UINT numRandomSplits
The number of random splits used to search for the best decision split.
virtual bool loadModelFromFile(std::fstream &file)
virtual bool saveModelToFile(std::fstream &file) const
virtual Float predict(const VectorFloat &x)
UINT getDecisionFeatureIndex() const
bool copyBaseVariables(const WeakClassifier *weakClassifer)
UINT getNumDimensions() const
UINT getNumClasses() const
bool trained
A flag to show if the weak classifier model has been trained.
Vector< MinMax > getRanges() const
Float getRandomNumberUniform(Float minRange=0.0, Float maxRange=1.0)
int getRandomNumberInt(int minRange, int maxRange)
virtual bool deepCopyFrom(const WeakClassifier *weakClassifer)
UINT decisionFeatureIndex
The dimension that the data will be split on.
This class implements a DecisionStump, which is a single node of a DecisionTree.
UINT getDirection() const