GestureRecognitionToolkit  Version: 0.1.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
RadialBasisFunction.cpp
Go to the documentation of this file.
1 
#include "RadialBasisFunction.h"
#include <limits>
29 
30 GRT_BEGIN_NAMESPACE
31 
32 //Register the RadialBasisFunction module with the WeakClassifier base class
34 
35 RadialBasisFunction::RadialBasisFunction(UINT numSteps,Float positiveClassificationThreshold,Float minAlphaSearchRange,Float maxAlphaSearchRange){
36  this->numSteps = numSteps;
37  this->positiveClassificationThreshold = positiveClassificationThreshold;
38  this->minAlphaSearchRange = minAlphaSearchRange;
39  this->maxAlphaSearchRange = maxAlphaSearchRange;
40  trained = false;
42  alpha = 0;
43  gamma = 0;
44  weakClassifierType = "RadialBasisFunction";
45  trainingLog.setProceedingText("[DEBUG RadialBasisFunction]");
46  warningLog.setProceedingText("[WARNING RadialBasisFunction]");
47  errorLog.setProceedingText("[ERROR RadialBasisFunction]");
48 }
49 
51 
52 }
53 
55  *this = rhs;
56 }
57 
59  if( this != &rhs ){
60  this->numSteps = rhs.numSteps;
61  this->alpha = rhs.alpha;
62  this->gamma = rhs.gamma;
63  this->positiveClassificationThreshold = rhs.positiveClassificationThreshold;
64  this->minAlphaSearchRange = rhs.minAlphaSearchRange;
65  this->maxAlphaSearchRange = rhs.maxAlphaSearchRange;
66  this->rbfCentre = rhs.rbfCentre;
67  this->copyBaseVariables( &rhs );
68  }
69  return *this;
70 }
71 
73  if( weakClassifer == NULL ) return false;
74 
75  if( this->getWeakClassifierType() == weakClassifer->getWeakClassifierType() ){
76  //Call the = operator
77  *this = *(RadialBasisFunction*)weakClassifer;
78  return true;
79  }
80  return false;
81 }
82 
84 
85  trained = false;
86  numInputDimensions = trainingData.getNumDimensions();
87  rbfCentre.clear();
88 
89  //There should only be two classes in the dataset, the positive class (classLable==1) and the negative class (classLabel==2)
90  if( trainingData.getNumClasses() != 2 ){
91  errorLog << "train(ClassificationData &trainingData, VectorFloat &weights) - There should only be 2 classes in the training data, but there are : " << trainingData.getNumClasses() << std::endl;
92  return false;
93  }
94 
95  //There should be one weight for every training sample
96  if( trainingData.getNumSamples() != weights.getSize() ){
97  errorLog << "train(ClassificationData &trainingData, VectorFloat &weights) - There number of examples in the training data (" << trainingData.getNumSamples() << ") does not match the lenght of the weights vector (" << weights.size() << ")" << std::endl;
98  return false;
99  }
100 
101  //STEP 1: Estimate the centre of the RBF function as the weighted mean of the positive examples
102  const UINT M = trainingData.getNumSamples();
103  rbfCentre.resize(numInputDimensions,0);
104 
105  //Search for the sample(s) with the maximum weight(s)
106  Float maxWeight = 0;
107  Vector< UINT > bestWeights;
108  for(UINT i=0; i<M; i++){
109  if( trainingData[i].getClassLabel() == WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL ){
110  if( weights[i] > maxWeight ){
111  maxWeight = weights[i];
112  bestWeights.clear();
113  bestWeights.push_back(i);
114  }else if( weights[i] == maxWeight ){
115  bestWeights.push_back( i );
116  }
117  }
118  }
119 
120  //Estimate the centre of the RBF function as the weighted mean of the most important sample(s)
121  const UINT N = (UINT)bestWeights.size();
122 
123  if( N == 0 ){
124  errorLog << "train(ClassificationData &trainingData, VectorFloat &weights) - There are no positive class weigts!" << std::endl;
125  return false;
126  }
127 
128  for(UINT i=0; i<N; i++){
129  for(UINT j=0; j<numInputDimensions; j++){
130  rbfCentre[j] += trainingData[ bestWeights[i] ][j];
131  }
132  }
133 
134  //Normalize the RBF centre by the positiveWeightSum so we get the weighted mean
135  for(UINT j=0; j<numInputDimensions; j++){
136  rbfCentre[j] /= Float(N);
137  }
138 
139  //STEP 2: Estimate the best value for alpha
140  Float step = (maxAlphaSearchRange-minAlphaSearchRange)/numSteps;
141  Float bestAlpha = 0;
142  Float minError = grt_numeric_limits< Float >::max();
143 
144  alpha = minAlphaSearchRange;
145  while( alpha <= maxAlphaSearchRange ){
146 
147  //Update gamma (this is used in the rbf function)
148  gamma = -1.0/(2.0*grt_sqr(alpha));
149 
150  //Compute the weighted error over all the training samples given the current alpha value
151  Float error = 0;
152  for(UINT i=0; i<M; i++){
153  bool positiveSample = trainingData[ i ].getClassLabel() == WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL;
154  Float v = rbf(trainingData[ i ].getSample(),rbfCentre);
155 
156  if( (v >= positiveClassificationThreshold && !positiveSample) || (v<positiveClassificationThreshold && positiveSample) ){
157  error += weights[i];
158  }
159  }
160 
161  //Check if the current error is the best so far
162  if( error < minError ){
163  minError = error;
164  bestAlpha = alpha;
165 
166  //If the minimum error is zero then we can stop the search
167  if( minError == 0 )
168  break;
169  }
170 
171  //Update alpha
172  alpha += step;
173  }
174 
175  alpha = bestAlpha;
176  gamma = -1.0/(2.0*grt_sqr(alpha));
177  trained = true;
178 
179  std::cout << "BestAlpha: " << bestAlpha << " Error: " << minError << std::endl;
180 
181  return true;
182 }
183 
185  if( rbf(x,rbfCentre) >= positiveClassificationThreshold ) return 1;
186  return -1;
187 }
188 
189 Float RadialBasisFunction::rbf(const VectorFloat &a,const VectorFloat &b){
190  const UINT N = (UINT)a.size();
191  //Compute the RBF distance, this uses the squared euclidean distance
192  Float r = 0;
193  for(UINT i=0; i<N; i++){
194  r += SQR(a[i]-b[i]);
195  }
196  return exp( gamma * r );
197 }
198 
199 bool RadialBasisFunction::saveModelToFile( std::fstream &file ) const{
200 
201  if(!file.is_open())
202  {
203  errorLog <<"saveModelToFile(fstream &file) - The file is not open!" << std::endl;
204  return false;
205  }
206 
207  //Write the WeakClassifierType data
208  file << "WeakClassifierType: " << weakClassifierType << std::endl;
209  file << "Trained: "<< trained << std::endl;
210  file << "NumInputDimensions: " << numInputDimensions << std::endl;
211 
212  //Write the RadialBasisFunction data
213  file << "NumSteps: " << numSteps << std::endl;
214  file << "PositiveClassificationThreshold: " << positiveClassificationThreshold << std::endl;
215  file << "Alpha: " << alpha << std::endl;
216  file << "MinAlphaSearchRange: " << minAlphaSearchRange << std::endl;
217  file << "MaxAlphaSearchRange: " << maxAlphaSearchRange << std::endl;
218  file << "RBF: ";
219  for(UINT i=0; i<numInputDimensions; i++){
220  if( trained ){
221  file << rbfCentre[i] << "\t";
222  }else file << 0 << "\t";
223  }
224  file << std::endl;
225 
226  //We don't need to close the file as the function that called this function should handle that
227  return true;
228 }
229 
230 bool RadialBasisFunction::loadModelFromFile( std::fstream &file ){
231 
232  if(!file.is_open())
233  {
234  errorLog <<"loadModelFromFile(fstream &file) - The file is not open!" << std::endl;
235  return false;
236  }
237 
238  std::string word;
239 
240  file >> word;
241  if( word != "WeakClassifierType:" ){
242  errorLog <<"loadModelFromFile(fstream &file) - Failed to read WeakClassifierType header!" << std::endl;
243  return false;
244  }
245  file >> word;
246 
247  if( word != weakClassifierType ){
248  errorLog <<"loadModelFromFile(fstream &file) - The weakClassifierType:" << word << " does not match: " << weakClassifierType << std::endl;
249  return false;
250  }
251 
252  file >> word;
253  if( word != "Trained:" ){
254  errorLog <<"loadModelFromFile(fstream &file) - Failed to read Trained header!" << std::endl;
255  return false;
256  }
257  file >> trained;
258 
259  file >> word;
260  if( word != "NumInputDimensions:" ){
261  errorLog <<"loadModelFromFile(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
262  return false;
263  }
264  file >> numInputDimensions;
265 
266 
267  file >> word;
268  if( word != "NumSteps:" ){
269  errorLog <<"loadModelFromFile(fstream &file) - Failed to read NumSteps header!" << std::endl;
270  return false;
271  }
272  file >> numSteps;
273 
274  file >> word;
275  if( word != "PositiveClassificationThreshold:" ){
276  errorLog <<"loadModelFromFile(fstream &file) - Failed to read PositiveClassificationThreshold header!" << std::endl;
277  return false;
278  }
279  file >> positiveClassificationThreshold;
280 
281  file >> word;
282  if( word != "Alpha:" ){
283  errorLog <<"loadModelFromFile(fstream &file) - Failed to read Alpha header!" << std::endl;
284  return false;
285  }
286  file >> alpha;
287 
288  file >> word;
289  if( word != "MinAlphaSearchRange:" ){
290  errorLog <<"loadModelFromFile(fstream &file) - Failed to read MinAlphaSearchRange header!" << std::endl;
291  return false;
292  }
293  file >> minAlphaSearchRange;
294 
295  file >> word;
296  if( word != "MaxAlphaSearchRange:" ){
297  errorLog <<"loadModelFromFile(fstream &file) - Failed to read MaxAlphaSearchRange header!" << std::endl;
298  return false;
299  }
300  file >> maxAlphaSearchRange;
301 
302  file >> word;
303  if( word != "RBF:" ){
304  errorLog <<"loadModelFromFile(fstream &file) - Failed to read RBF header!" << std::endl;
305  return false;
306  }
307  rbfCentre.resize(numInputDimensions);
308 
309  for(UINT i=0; i<numInputDimensions; i++){
310  file >> rbfCentre[i];
311  }
312 
313  //Compute gamma using alpha
314  gamma = -1.0/(2.0*SQR(alpha));
315 
316  //We don't need to close the file as the function that called this function should handle that
317  return true;
318 }
319 
321 }
322 
324  return rbfCentre;
325 }
326 
328  return numSteps;
329 }
330 
332  return positiveClassificationThreshold;
333 }
334 
336  return alpha;
337 }
338 
340  return minAlphaSearchRange;
341 }
342 
344  return maxAlphaSearchRange;
345 }
346 
347 GRT_END_NAMESPACE
348 
virtual bool train(ClassificationData &trainingData, VectorFloat &weights)
virtual Float predict(const VectorFloat &x)
virtual bool loadModelFromFile(std::fstream &file)
std::string weakClassifierType
A string that represents the weak classifier type, e.g. DecisionStump.
virtual bool deepCopyFrom(const WeakClassifier *weakClassifer)
UINT numInputDimensions
The number of input dimensions to the weak classifier.
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
std::string getWeakClassifierType() const
RadialBasisFunction & operator=(const RadialBasisFunction &rhs)
virtual void print() const
Float getMinAlphaSearchRange() const
unsigned int getSize() const
Definition: Vector.h:193
static RegisterWeakClassifierModule< RadialBasisFunction > registerModule
This is used to register the RadialBasisFunction with the WeakClassifier base class.
UINT getNumSamples() const
#define WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL
virtual bool saveModelToFile(std::fstream &file) const
bool copyBaseVariables(const WeakClassifier *weakClassifer)
UINT getNumDimensions() const
UINT getNumClasses() const
VectorFloat getRBFCentre() const
bool trained
A flag to show if the weak classifier model has been trained.
Float getMaxAlphaSearchRange() const
RadialBasisFunction(UINT numSteps=100, Float positiveClassificationThreshold=0.9, Float minAlphaSearchRange=0.001, Float maxAlphaSearchRange=1.0)
Float getPositiveClassificationThreshold() const
This class implements a Radial Basis Function Weak Classifier. The Radial Basis Function (RBF) class ...