GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, C++ machine learning library for real-time gesture recognition.
RadialBasisFunction.cpp
Go to the documentation of this file.
1 
28 #define GRT_DLL_EXPORTS
29 #include "RadialBasisFunction.h"
30 
31 GRT_BEGIN_NAMESPACE
32 
33 //Register the RadialBasisFunction module with the WeakClassifier base class
35 
36 RadialBasisFunction::RadialBasisFunction(UINT numSteps,Float positiveClassificationThreshold,Float minAlphaSearchRange,Float maxAlphaSearchRange){
37  this->numSteps = numSteps;
38  this->positiveClassificationThreshold = positiveClassificationThreshold;
39  this->minAlphaSearchRange = minAlphaSearchRange;
40  this->maxAlphaSearchRange = maxAlphaSearchRange;
41  trained = false;
43  alpha = 0;
44  gamma = 0;
45  weakClassifierType = "RadialBasisFunction";
46  trainingLog.setKey("[DEBUG RadialBasisFunction]");
47  warningLog.setKey("[WARNING RadialBasisFunction]");
48  errorLog.setKey("[ERROR RadialBasisFunction]");
49 }
50 
52 
53 }
54 
56  *this = rhs;
57 }
58 
60  if( this != &rhs ){
61  this->numSteps = rhs.numSteps;
62  this->alpha = rhs.alpha;
63  this->gamma = rhs.gamma;
64  this->positiveClassificationThreshold = rhs.positiveClassificationThreshold;
65  this->minAlphaSearchRange = rhs.minAlphaSearchRange;
66  this->maxAlphaSearchRange = rhs.maxAlphaSearchRange;
67  this->rbfCentre = rhs.rbfCentre;
68  this->copyBaseVariables( &rhs );
69  }
70  return *this;
71 }
72 
74  if( weakClassifer == NULL ) return false;
75 
76  if( this->getWeakClassifierType() == weakClassifer->getWeakClassifierType() ){
77  //Call the = operator
78  *this = *(RadialBasisFunction*)weakClassifer;
79  return true;
80  }
81  return false;
82 }
83 
85 
86  trained = false;
87  numInputDimensions = trainingData.getNumDimensions();
88  rbfCentre.clear();
89 
90  //There should only be two classes in the dataset, the positive class (classLable==1) and the negative class (classLabel==2)
91  if( trainingData.getNumClasses() != 2 ){
92  errorLog << "train(ClassificationData &trainingData, VectorFloat &weights) - There should only be 2 classes in the training data, but there are : " << trainingData.getNumClasses() << std::endl;
93  return false;
94  }
95 
96  //There should be one weight for every training sample
97  if( trainingData.getNumSamples() != weights.getSize() ){
98  errorLog << "train(ClassificationData &trainingData, VectorFloat &weights) - There number of examples in the training data (" << trainingData.getNumSamples() << ") does not match the lenght of the weights vector (" << weights.size() << ")" << std::endl;
99  return false;
100  }
101 
102  //STEP 1: Estimate the centre of the RBF function as the weighted mean of the positive examples
103  const UINT M = trainingData.getNumSamples();
104  rbfCentre.resize(numInputDimensions,0);
105 
106  //Search for the sample(s) with the maximum weight(s)
107  Float maxWeight = 0;
108  Vector< UINT > bestWeights;
109  for(UINT i=0; i<M; i++){
110  if( trainingData[i].getClassLabel() == WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL ){
111  if( weights[i] > maxWeight ){
112  maxWeight = weights[i];
113  bestWeights.clear();
114  bestWeights.push_back(i);
115  }else if( weights[i] == maxWeight ){
116  bestWeights.push_back( i );
117  }
118  }
119  }
120 
121  //Estimate the centre of the RBF function as the weighted mean of the most important sample(s)
122  const UINT N = (UINT)bestWeights.size();
123 
124  if( N == 0 ){
125  errorLog << "train(ClassificationData &trainingData, VectorFloat &weights) - There are no positive class weigts!" << std::endl;
126  return false;
127  }
128 
129  for(UINT i=0; i<N; i++){
130  for(UINT j=0; j<numInputDimensions; j++){
131  rbfCentre[j] += trainingData[ bestWeights[i] ][j];
132  }
133  }
134 
135  //Normalize the RBF centre by the positiveWeightSum so we get the weighted mean
136  for(UINT j=0; j<numInputDimensions; j++){
137  rbfCentre[j] /= Float(N);
138  }
139 
140  //STEP 2: Estimate the best value for alpha
141  Float step = (maxAlphaSearchRange-minAlphaSearchRange)/numSteps;
142  Float bestAlpha = 0;
143  Float minError = grt_numeric_limits< Float >::max();
144 
145  alpha = minAlphaSearchRange;
146  while( alpha <= maxAlphaSearchRange ){
147 
148  //Update gamma (this is used in the rbf function)
149  gamma = -1.0/(2.0*grt_sqr(alpha));
150 
151  //Compute the weighted error over all the training samples given the current alpha value
152  Float error = 0;
153  for(UINT i=0; i<M; i++){
154  bool positiveSample = trainingData[ i ].getClassLabel() == WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL;
155  Float v = rbf(trainingData[ i ].getSample(),rbfCentre);
156 
157  if( (v >= positiveClassificationThreshold && !positiveSample) || (v<positiveClassificationThreshold && positiveSample) ){
158  error += weights[i];
159  }
160  }
161 
162  //Check if the current error is the best so far
163  if( error < minError ){
164  minError = error;
165  bestAlpha = alpha;
166 
167  //If the minimum error is zero then we can stop the search
168  if( minError == 0 )
169  break;
170  }
171 
172  //Update alpha
173  alpha += step;
174  }
175 
176  alpha = bestAlpha;
177  gamma = -1.0/(2.0*grt_sqr(alpha));
178  trained = true;
179 
180  std::cout << "BestAlpha: " << bestAlpha << " Error: " << minError << std::endl;
181 
182  return true;
183 }
184 
186  if( rbf(x,rbfCentre) >= positiveClassificationThreshold ) return 1;
187  return -1;
188 }
189 
190 Float RadialBasisFunction::rbf(const VectorFloat &a,const VectorFloat &b){
191  const UINT N = (UINT)a.size();
192  //Compute the RBF distance, this uses the squared euclidean distance
193  Float r = 0;
194  for(UINT i=0; i<N; i++){
195  r += SQR(a[i]-b[i]);
196  }
197  return exp( gamma * r );
198 }
199 
200 bool RadialBasisFunction::saveModelToFile( std::fstream &file ) const{
201 
202  if(!file.is_open())
203  {
204  errorLog <<"saveModelToFile(fstream &file) - The file is not open!" << std::endl;
205  return false;
206  }
207 
208  //Write the WeakClassifierType data
209  file << "WeakClassifierType: " << weakClassifierType << std::endl;
210  file << "Trained: "<< trained << std::endl;
211  file << "NumInputDimensions: " << numInputDimensions << std::endl;
212 
213  //Write the RadialBasisFunction data
214  file << "NumSteps: " << numSteps << std::endl;
215  file << "PositiveClassificationThreshold: " << positiveClassificationThreshold << std::endl;
216  file << "Alpha: " << alpha << std::endl;
217  file << "MinAlphaSearchRange: " << minAlphaSearchRange << std::endl;
218  file << "MaxAlphaSearchRange: " << maxAlphaSearchRange << std::endl;
219  file << "RBF: ";
220  for(UINT i=0; i<numInputDimensions; i++){
221  if( trained ){
222  file << rbfCentre[i] << "\t";
223  }else file << 0 << "\t";
224  }
225  file << std::endl;
226 
227  //We don't need to close the file as the function that called this function should handle that
228  return true;
229 }
230 
231 bool RadialBasisFunction::loadModelFromFile( std::fstream &file ){
232 
233  if(!file.is_open())
234  {
235  errorLog <<"loadModelFromFile(fstream &file) - The file is not open!" << std::endl;
236  return false;
237  }
238 
239  std::string word;
240 
241  file >> word;
242  if( word != "WeakClassifierType:" ){
243  errorLog <<"loadModelFromFile(fstream &file) - Failed to read WeakClassifierType header!" << std::endl;
244  return false;
245  }
246  file >> word;
247 
248  if( word != weakClassifierType ){
249  errorLog <<"loadModelFromFile(fstream &file) - The weakClassifierType:" << word << " does not match: " << weakClassifierType << std::endl;
250  return false;
251  }
252 
253  file >> word;
254  if( word != "Trained:" ){
255  errorLog <<"loadModelFromFile(fstream &file) - Failed to read Trained header!" << std::endl;
256  return false;
257  }
258  file >> trained;
259 
260  file >> word;
261  if( word != "NumInputDimensions:" ){
262  errorLog <<"loadModelFromFile(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
263  return false;
264  }
265  file >> numInputDimensions;
266 
267 
268  file >> word;
269  if( word != "NumSteps:" ){
270  errorLog <<"loadModelFromFile(fstream &file) - Failed to read NumSteps header!" << std::endl;
271  return false;
272  }
273  file >> numSteps;
274 
275  file >> word;
276  if( word != "PositiveClassificationThreshold:" ){
277  errorLog <<"loadModelFromFile(fstream &file) - Failed to read PositiveClassificationThreshold header!" << std::endl;
278  return false;
279  }
280  file >> positiveClassificationThreshold;
281 
282  file >> word;
283  if( word != "Alpha:" ){
284  errorLog <<"loadModelFromFile(fstream &file) - Failed to read Alpha header!" << std::endl;
285  return false;
286  }
287  file >> alpha;
288 
289  file >> word;
290  if( word != "MinAlphaSearchRange:" ){
291  errorLog <<"loadModelFromFile(fstream &file) - Failed to read MinAlphaSearchRange header!" << std::endl;
292  return false;
293  }
294  file >> minAlphaSearchRange;
295 
296  file >> word;
297  if( word != "MaxAlphaSearchRange:" ){
298  errorLog <<"loadModelFromFile(fstream &file) - Failed to read MaxAlphaSearchRange header!" << std::endl;
299  return false;
300  }
301  file >> maxAlphaSearchRange;
302 
303  file >> word;
304  if( word != "RBF:" ){
305  errorLog <<"loadModelFromFile(fstream &file) - Failed to read RBF header!" << std::endl;
306  return false;
307  }
308  rbfCentre.resize(numInputDimensions);
309 
310  for(UINT i=0; i<numInputDimensions; i++){
311  file >> rbfCentre[i];
312  }
313 
314  //Compute gamma using alpha
315  gamma = -1.0/(2.0*SQR(alpha));
316 
317  //We don't need to close the file as the function that called this function should handle that
318  return true;
319 }
320 
322 }
323 
325  return rbfCentre;
326 }
327 
329  return numSteps;
330 }
331 
333  return positiveClassificationThreshold;
334 }
335 
337  return alpha;
338 }
339 
341  return minAlphaSearchRange;
342 }
343 
345  return maxAlphaSearchRange;
346 }
347 
348 GRT_END_NAMESPACE
349 
virtual bool train(ClassificationData &trainingData, VectorFloat &weights)
virtual Float predict(const VectorFloat &x)
virtual bool loadModelFromFile(std::fstream &file)
std::string weakClassifierType
A string that represents the weak classifier type, e.g. DecisionStump.
virtual bool deepCopyFrom(const WeakClassifier *weakClassifer)
UINT numInputDimensions
The number of input dimensions to the weak classifier.
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
virtual bool setKey(const std::string &key)
Sets the key that gets written at the start of each message; this will be written in the format 'key ...
Definition: Log.h:166
std::string getWeakClassifierType() const
UINT getSize() const
Definition: Vector.h:201
RadialBasisFunction & operator=(const RadialBasisFunction &rhs)
virtual void print() const
Float getMinAlphaSearchRange() const
static RegisterWeakClassifierModule< RadialBasisFunction > registerModule
This is used to register the RadialBasisFunction with the WeakClassifier base class.
UINT getNumSamples() const
#define WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL
virtual bool saveModelToFile(std::fstream &file) const
bool copyBaseVariables(const WeakClassifier *weakClassifer)
UINT getNumDimensions() const
UINT getNumClasses() const
VectorFloat getRBFCentre() const
bool trained
A flag to show if the weak classifier model has been trained.
Float getMaxAlphaSearchRange() const
RadialBasisFunction(UINT numSteps=100, Float positiveClassificationThreshold=0.9, Float minAlphaSearchRange=0.001, Float maxAlphaSearchRange=1.0)
Float getPositiveClassificationThreshold() const
This class implements a Radial Basis Function Weak Classifier. The Radial Basis Function (RBF) class ...