GestureRecognitionToolkit  Version: 0.1.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
DecisionStump.cpp
Go to the documentation of this file.
1 
#include "DecisionStump.h"
#include <limits>
29 
30 GRT_BEGIN_NAMESPACE
31 
32 //Register the DecisionStump module with the WeakClassifier base class
34 
35 DecisionStump::DecisionStump(const UINT numRandomSplits){
36  this->numRandomSplits = numRandomSplits;
37  trained = false;
40  decisionValue = 0;
41  direction = 0;
42  weakClassifierType = "DecisionStump";
43  trainingLog.setProceedingText("[TRAINING DecisionStump]");
44  warningLog.setProceedingText("[WARNING DecisionStump]");
45  errorLog.setProceedingText("[ERROR DecisionStump]");
46 }
47 
49 
50 }
51 
53  *this = rhs;
54 }
55 
57  if( this != &rhs ){
59  this->decisionValue = rhs.decisionValue;
60  this->direction = rhs.direction;
61  this->numRandomSplits = rhs.numRandomSplits;
62  this->copyBaseVariables( &rhs );
63  }
64  return *this;
65 }
66 
67 bool DecisionStump::deepCopyFrom(const WeakClassifier *weakClassifer){
68  if( weakClassifer == NULL ) return false;
69 
70  if( this->getWeakClassifierType() == weakClassifer->getWeakClassifierType() ){
71  *this = *(DecisionStump*)weakClassifer;
72  return true;
73  }
74  return false;
75 }
76 
77 bool DecisionStump::train(ClassificationData &trainingData, VectorFloat &weights){
78 
79  trained = false;
80  numInputDimensions = trainingData.getNumDimensions();
81 
82  //There should only be two classes in the dataset, the positive class (classLable==1) and the negative class (classLabel==2)
83  if( trainingData.getNumClasses() != 2 ){
84  errorLog << "train(ClassificationData &trainingData, VectorFloat &weights) - There should only be 2 classes in the training data, but there are : " << trainingData.getNumClasses() << std::endl;
85  return false;
86  }
87 
88  //There should be one weight for every training sample
89  if( trainingData.getNumSamples() != weights.size() ){
90  errorLog << "train(ClassificationData &trainingData, VectorFloat &weights) - There number of examples in the training data (" << trainingData.getNumSamples() << ") does not match the lenght of the weights vector (" << weights.getSize() << ")" << std::endl;
91  return false;
92  }
93 
94  //Pick the training sample to use as the stump feature
95  const UINT M = trainingData.getNumSamples();
96  UINT bestFeatureIndex = 0;
97  Vector< MinMax > ranges = trainingData.getRanges();
98  Float minError = grt_numeric_limits< Float >::max();
99  Float minRange = 0;
100  Float maxRange = 0;
101  Float step = 0;
102  Float threshold = 0;
103  Float bestThreshold = 0;
104  Random random;
105 
106  for(UINT k=0; k<numRandomSplits; k++){
107 
108  //Randomly select a feature and a threshold
109  UINT n = random.getRandomNumberInt(0,numInputDimensions);
110  minRange = ranges[n].minValue;
111  maxRange = ranges[n].maxValue;
112  threshold = random.getRandomNumberUniform( minRange, maxRange );
113 
114  //Compute the error using the current threshold on the current input dimension
115  //We need to check both sides of the threshold
116  Float rhsError = 0;
117  Float lhsError = 0;
118  for(UINT i=0; i<M; i++){
119  bool positiveClass = trainingData[ i ].getClassLabel() == WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL;
120  bool rhs = trainingData[ i ][ n ] >= threshold;
121  bool lhs = trainingData[ i ][ n ] <= threshold;
122  if( (rhs && !positiveClass) || (!rhs && positiveClass) ) rhsError += weights[ i ];
123  if( (lhs && !positiveClass) || (!lhs && positiveClass) ) lhsError += weights[ i ];
124  }
125 
126  //Check to see if either the rhsError or lhsError beats the minError, if so then store the results
127  if( rhsError < minError ){
128  minError = rhsError;
129  bestFeatureIndex = n;
130  bestThreshold = threshold;
131  direction = 1; //1 means rhs
132  }
133  if( lhsError < minError ){
134  minError = lhsError;
135  bestFeatureIndex = n;
136  bestThreshold = threshold;
137  direction = 0; //0 means lhs
138  }
139 
140  }
141 
142  decisionFeatureIndex = bestFeatureIndex;
143  decisionValue = bestThreshold;
144  trained = true;
145 
146  trainingLog << "Best Feature Index: " << decisionFeatureIndex << " Value: " << decisionValue << " Direction: " << direction << " Error: " << minError << std::endl;
147  return true;
148 }
149 
151  if( direction == 1){
152  if( x[ decisionFeatureIndex ] >= decisionValue ) return 1;
153  }else if( x[ decisionFeatureIndex ] <= decisionValue ) return 1;
154  return -1;
155 }
156 
157 bool DecisionStump::saveModelToFile( std::fstream &file ) const{
158 
159  if(!file.is_open())
160  {
161  errorLog <<"saveModelToFile(fstream &file) - The file is not open!" << std::endl;
162  return false;
163  }
164 
165  //Write the WeakClassifierType data
166  file << "WeakClassifierType: " << weakClassifierType << std::endl;
167  file << "Trained: "<< trained << std::endl;
168  file << "NumInputDimensions: " << numInputDimensions << std::endl;
169 
170  //Write the DecisionStump data
171  file << "DecisionFeatureIndex: " << decisionFeatureIndex << std::endl;
172  file << "Direction: "<< direction << std::endl;
173  file << "NumRandomSplits: " << numRandomSplits << std::endl;
174  file << "DecisionValue: " << decisionValue << std::endl;
175 
176  //We don't need to close the file as the function that called this function should handle that
177  return true;
178 }
179 
180 bool DecisionStump::loadModelFromFile( std::fstream &file ){
181 
182  if(!file.is_open())
183  {
184  errorLog <<"loadModelFromFile(fstream &file) - The file is not open!" << std::endl;
185  return false;
186  }
187 
188  std::string word;
189 
190  file >> word;
191  if( word != "WeakClassifierType:" ){
192  errorLog <<"loadModelFromFile(fstream &file) - Failed to read WeakClassifierType header!" << std::endl;
193  return false;
194  }
195  file >> word;
196 
197  if( word != weakClassifierType ){
198  errorLog <<"loadModelFromFile(fstream &file) - The weakClassifierType:" << word << " does not match: " << weakClassifierType << std::endl;
199  return false;
200  }
201 
202  file >> word;
203  if( word != "Trained:" ){
204  errorLog <<"loadModelFromFile(fstream &file) - Failed to read Trained header!" << std::endl;
205  return false;
206  }
207  file >> trained;
208 
209  file >> word;
210  if( word != "NumInputDimensions:" ){
211  errorLog <<"loadModelFromFile(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
212  return false;
213  }
214  file >> numInputDimensions;
215 
216  file >> word;
217  if( word != "DecisionFeatureIndex:" ){
218  errorLog <<"loadModelFromFile(fstream &file) - Failed to read DecisionFeatureIndex header!" << std::endl;
219  return false;
220  }
221  file >> decisionFeatureIndex;
222 
223  file >> word;
224  if( word != "Direction:" ){
225  errorLog <<"loadModelFromFile(fstream &file) - Failed to read Direction header!" << std::endl;
226  return false;
227  }
228  file >> direction;
229 
230  file >> word;
231  if( word != "NumRandomSplits:" ){
232  errorLog <<"loadModelFromFile(fstream &file) - Failed to read NumRandomSplits header!" << std::endl;
233  return false;
234  }
235  file >> numRandomSplits;
236 
237  file >> word;
238  if( word != "DecisionValue:" ){
239  errorLog <<"loadModelFromFile(fstream &file) - Failed to read DecisionValue header!" << std::endl;
240  return false;
241  }
242  file >> decisionValue;
243 
244  //We don't need to close the file as the function that called this function should handle that
245  return true;
246 }
247 
248 void DecisionStump::print() const{
249  std::cout << "Trained: " << trained;
250  std::cout << "\tDecisionValue: " << decisionValue;
251  std::cout << "\tDecisionFeatureIndex: " << decisionFeatureIndex;
252  std::cout << "\tDirection: " << direction << std::endl;
253 }
254 
256  return decisionFeatureIndex;
257 }
258 
260  return direction;
261 }
262 
264  return numRandomSplits;
265 }
266 
268  return decisionValue;
269 }
270 
271 GRT_END_NAMESPACE
272 
273 
UINT getNumRandomSplits() const
static RegisterWeakClassifierModule< DecisionStump > registerModule
This is used to register the DecisionStump with the WeakClassifier base class.
std::string weakClassifierType
A string that represents the weak classifier type, e.g. DecisionStump.
UINT direction
Indicates which side of the decision split is classified as positive: values greater than or equal to the threshold (1), or values less than or equal to the threshold (0).
Definition: Random.h:40
UINT numInputDimensions
The number of input dimensions to the weak classifier.
Float decisionValue
The decision split threshold.
DecisionStump & operator=(const DecisionStump &rhs)
virtual ~DecisionStump()
virtual bool train(ClassificationData &trainingData, VectorFloat &weights)
std::string getWeakClassifierType() const
Float getDecisionValue() const
unsigned int getSize() const
Definition: Vector.h:193
virtual void print() const
DecisionStump(const UINT numRandomSplits=100)
UINT getNumSamples() const
#define WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL
UINT numRandomSplits
The number of random splits used to search for the best decision split.
virtual bool loadModelFromFile(std::fstream &file)
virtual bool saveModelToFile(std::fstream &file) const
virtual Float predict(const VectorFloat &x)
UINT getDecisionFeatureIndex() const
bool copyBaseVariables(const WeakClassifier *weakClassifer)
UINT getNumDimensions() const
UINT getNumClasses() const
bool trained
A flag to show if the weak classifier model has been trained.
Vector< MinMax > getRanges() const
Float getRandomNumberUniform(Float minRange=0.0, Float maxRange=1.0)
Definition: Random.h:198
int getRandomNumberInt(int minRange, int maxRange)
Definition: Random.h:88
virtual bool deepCopyFrom(const WeakClassifier *weakClassifer)
UINT decisionFeatureIndex
The dimension that the data will be split on.
This class implements a DecisionStump, which is a single node of a DecisionTree.
UINT getDirection() const