GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
MLBase.cpp
1 
21 #define GRT_DLL_EXPORTS
22 #include "MLBase.h"
23 
24 GRT_BEGIN_NAMESPACE
25 
26 MLBase::MLBase( const std::string &id, const BaseType type ) : GRTBase( id ){
27  baseType = type;
28  trained = false;
29  converged = false;
30  useScaling = false;
31  inputType = DATA_TYPE_UNKNOWN;
32  outputType = DATA_TYPE_UNKNOWN;
33  numInputDimensions = 0;
34  numOutputDimensions = 0;
35  minNumEpochs = 0;
36  maxNumEpochs = 100;
37  batchSize = 1;
38  numRestarts = 1;
39  validationSetSize = 20;
40  validationSetAccuracy = 0;
41  minChange = 1.0e-5;
42  learningRate = 0.1;
43  useValidationSet = false;
44  randomiseTrainingOrder = true;
45  rmsTrainingError = 0;
46  rmsValidationError = 0;
47  totalSquaredTrainingError = 0;
48 
49  if( id == "" ){
50  trainingLog.setKey("[TRAINING]");
51  testingLog.setKey("[TESTING]");
52  }else{
53  trainingLog.setKey("[TRAINING " + id + "]");
54  testingLog.setKey("[TESTING " + id + "]");
55  }
56 }
57 
59  clear();
60 }
61 
62 bool MLBase::copyMLBaseVariables(const MLBase *mlBase){
63 
64  if( mlBase == NULL ){
65  errorLog << "copyMLBaseVariables(MLBase *mlBase) - mlBase pointer is NULL!" << std::endl;
66  return false;
67  }
68 
69  if( !copyGRTBaseVariables( mlBase ) ){
70  errorLog << "copyMLBaseVariables(MLBase *mlBase) - Failed to copy GRT Base variables!" << std::endl;
71  return false;
72  }
73 
74  this->trained = mlBase->trained;
75  this->converged = mlBase->converged;
76  this->useScaling = mlBase->useScaling;
77  this->baseType = mlBase->baseType;
78  this->inputType = mlBase->inputType;
79  this->outputType = mlBase->outputType;
80  this->numInputDimensions = mlBase->numInputDimensions;
81  this->numOutputDimensions = mlBase->numOutputDimensions;
82  this->minNumEpochs = mlBase->minNumEpochs;
83  this->maxNumEpochs = mlBase->maxNumEpochs;
84  this->batchSize = mlBase->batchSize;
85  this->numRestarts = mlBase->numRestarts;
86  this->validationSetSize = mlBase->validationSetSize;
87  this->validationSetAccuracy = mlBase->validationSetAccuracy;
88  this->validationSetPrecision = mlBase->validationSetPrecision;
89  this->validationSetRecall = mlBase->validationSetRecall;
90  this->minChange = mlBase->minChange;
91  this->learningRate = mlBase->learningRate;
92  this->rmsTrainingError = mlBase->rmsTrainingError;
93  this->rmsValidationError = mlBase->rmsValidationError;
94  this->totalSquaredTrainingError = mlBase->totalSquaredTrainingError;
95  this->useValidationSet = mlBase->useValidationSet;
96  this->randomiseTrainingOrder = mlBase->randomiseTrainingOrder;
97  this->numTrainingIterationsToConverge = mlBase->numTrainingIterationsToConverge;
98  this->trainingResults = mlBase->trainingResults;
99  this->trainingResultsObserverManager = mlBase->trainingResultsObserverManager;
100  this->testResultsObserverManager = mlBase->testResultsObserverManager;
101  this->trainingLog = mlBase->trainingLog;
102  this->testingLog = mlBase->testingLog;
103 
104  return true;
105 }
106 
107 bool MLBase::train(ClassificationData trainingData){ return train_( trainingData ); }
108 
109 bool MLBase::train_(ClassificationData &trainingData){ return false; }
110 
111 bool MLBase::train(RegressionData trainingData){ return train_( trainingData ); }
112 
113 bool MLBase::train_(RegressionData &trainingData){ return false; }
114 
115 bool MLBase::train(RegressionData trainingData,RegressionData validationData){ return train_( trainingData, validationData ); }
116 
117 bool MLBase::train_(RegressionData &trainingData,RegressionData &validationData){ return false; }
118 
119 bool MLBase::train(TimeSeriesClassificationData trainingData){ return train_( trainingData ); }
120 
121 bool MLBase::train_(TimeSeriesClassificationData &trainingData){ return false; }
122 
123 bool MLBase::train(ClassificationDataStream trainingData){ return train_( trainingData ); }
124 
125 bool MLBase::train_(ClassificationDataStream &trainingData){ return false; }
126 
127 bool MLBase::train(UnlabelledData trainingData){ return train_( trainingData ); }
128 
129 bool MLBase::train_(UnlabelledData &trainingData){ return false; }
130 
131 bool MLBase::train(MatrixFloat data){ return train_( data ); }
132 
133 bool MLBase::train_(MatrixFloat &data){ return false; }
134 
135 bool MLBase::predict(VectorFloat inputVector){ return predict_( inputVector ); }
136 
137 bool MLBase::predict_(VectorFloat &inputVector){ return false; }
138 
139 bool MLBase::predict(MatrixFloat inputMatrix){ return predict_( inputMatrix ); }
140 
141 bool MLBase::predict_(MatrixFloat &inputMatrix){ return false; }
142 
143 bool MLBase::map(VectorFloat inputVector){ return map_( inputVector ); }
144 
145 bool MLBase::map_(VectorFloat &inputVector){ return false; }
146 
147 bool MLBase::reset(){ return true; }
148 
150  trained = false;
151  converged = false;
152  numInputDimensions = 0;
153  numOutputDimensions = 0;
154  numTrainingIterationsToConverge = 0;
155  rmsTrainingError = 0;
156  rmsValidationError = 0;
157  totalSquaredTrainingError = 0;
158  trainingResults.clear();
159  validationSetPrecision.clear();
160  validationSetRecall.clear();
161  validationSetAccuracy = 0;
162  return true;
163 }
164 
165 bool MLBase::print() const { std::cout << getModelAsString(); return true; }
166 
167 bool MLBase::save(const std::string &filename) const {
168 
169  std::fstream file;
170  file.open(filename.c_str(), std::ios::out);
171 
172  if( !save( file ) )
173  {
174  return false;
175  }
176 
177  file.close();
178 
179  return true;
180 }
181 
182 bool MLBase::save(std::fstream &file) const {
183  return false; //The base class returns false, as this should be overwritten by the inheriting class
184 }
185 
186 bool MLBase::saveModelToFile(const std::string &filename) const { return save( filename ); }
187 
188 bool MLBase::saveModelToFile(std::fstream &file) const { return save( file ); }
189 
190 bool MLBase::load(const std::string &filename){
191 
192  std::fstream file;
193  file.open(filename.c_str(), std::ios::in);
194 
195  if( !load( file ) ){
196  return false;
197  }
198 
199  //Close the file
200  file.close();
201 
202  return true;
203 }
204 
205 bool MLBase::load(std::fstream &file) {
206  return false; //The base class returns false, as this should be overwritten by the inheriting class
207 }
208 
209 bool MLBase::loadModelFromFile(const std::string &filename){ return load( filename ); }
210 
211 bool MLBase::loadModelFromFile(std::fstream &file){ return load( file ); }
212 
213 bool MLBase::getModel(std::ostream &stream) const { return true; }
214 
215 std::string MLBase::getModelAsString() const{
216  std::stringstream stream;
217  if( getModel( stream ) ){
218  return stream.str();
219  }
220  return "";
221 }
222 
223 DataType MLBase::getInputType() const { return inputType; }
224 
225 DataType MLBase::getOutputType() const { return outputType; }
226 
227 MLBase::BaseType MLBase::getType() const{ return baseType; }
228 
230 
231 UINT MLBase::getNumInputDimensions() const{ return numInputDimensions; }
232 
233 UINT MLBase::getNumOutputDimensions() const{ return numOutputDimensions; }
234 
236  if( trained ){
237  return numTrainingIterationsToConverge;
238  }
239  return 0;
240 }
241 
243  return minNumEpochs;
244 }
245 
247  return maxNumEpochs;
248 }
249 
250 UINT MLBase::getBatchSize() const{
251  return batchSize;
252 }
253 
255  return numRestarts;
256 }
257 
259  return validationSetSize;
260 }
261 
263  return learningRate;
264 }
265 
267  return rmsTrainingError;
268 }
269 
270 Float MLBase::getRootMeanSquaredTrainingError() const{
271  return getRMSTrainingError();
272 }
273 
275  return totalSquaredTrainingError;
276 }
277 
279  return rmsValidationError;
280 }
281 
283  return validationSetAccuracy;
284 }
285 
287  return validationSetPrecision;
288 }
289 
291  return validationSetRecall;
292 }
293 
294 bool MLBase::getTrained() const{ return trained; }
295 
296 bool MLBase::getModelTrained() const{ return getTrained(); }
297 
298 bool MLBase::getConverged() const { return converged; }
299 
300 bool MLBase::getScalingEnabled() const{ return useScaling; }
301 
302 bool MLBase::getIsBaseTypeClassifier() const{ return baseType==CLASSIFIER; }
303 
304 bool MLBase::getIsBaseTypeRegressifier() const{ return baseType==REGRESSIFIER; }
305 
306 bool MLBase::getIsBaseTypeClusterer() const{ return baseType==CLUSTERER; }
307 
308 bool MLBase::enableScaling(bool useScaling){ this->useScaling = useScaling; return true; }
309 
310 bool MLBase::getUseValidationSet() const { return useValidationSet; }
311 
313  return trainingLog.getInstanceLoggingEnabled();
314 }
315 
317  return testingLog.getInstanceLoggingEnabled();
318 }
319 
320 bool MLBase::setMaxNumEpochs(const UINT maxNumEpochs){
321  if( maxNumEpochs == 0 ){
322  warningLog << "setMaxNumEpochs(const UINT maxNumEpochs) - The maxNumEpochs must be greater than 0!" << std::endl;
323  return false;
324  }
325  this->maxNumEpochs = maxNumEpochs;
326  return true;
327 }
328 
329 bool MLBase::setMinNumEpochs(const UINT minNumEpochs){
330  this->minNumEpochs = minNumEpochs;
331  return true;
332 }
333 
334 bool MLBase::setBatchSize(const UINT batchSize){
335  this->batchSize = batchSize;
336  return true;
337 }
338 
339 bool MLBase::setNumRestarts(const UINT numRestarts){
340  this->numRestarts = numRestarts;
341  return true;
342 }
343 
344 bool MLBase::setMinChange(const Float minChange){
345  if( minChange < 0 ){
346  warningLog << "setMinChange(const Float minChange) - The minChange must be greater than or equal to 0!" << std::endl;
347  return false;
348  }
349  this->minChange = minChange;
350  return true;
351 }
352 
353 bool MLBase::setLearningRate(const Float learningRate){
354  if( learningRate > 0 ){
355  this->learningRate = learningRate;
356  return true;
357  }
358  return false;
359 }
360 
361 bool MLBase::setValidationSetSize(const UINT validationSetSize){
362 
363  if( validationSetSize > 0 && validationSetSize < 100 ){
364  this->validationSetSize = validationSetSize;
365  return true;
366  }
367 
368  warningLog << "setValidationSetSize(const UINT validationSetSize) - The validation size must be in the range [1 99]!" << std::endl;
369 
370  return false;
371 }
372 
373 bool MLBase::setUseValidationSet(const bool useValidationSet){
374  this->useValidationSet = useValidationSet;
375  return true;
376 }
377 
378 bool MLBase::setRandomiseTrainingOrder(const bool randomiseTrainingOrder){
379  this->randomiseTrainingOrder = randomiseTrainingOrder;
380  return true;
381 }
382 
383 bool MLBase::setTrainingLoggingEnabled(const bool loggingEnabled){
384  return this->trainingLog.setInstanceLoggingEnabled( loggingEnabled );
385 }
386 
387 bool MLBase::setTestingLoggingEnabled(const bool loggingEnabled){
388  return this->testingLog.setInstanceLoggingEnabled( loggingEnabled );
389 }
390 
392  return trainingResultsObserverManager.registerObserver( observer );
393 }
394 
396  return testResultsObserverManager.registerObserver( observer );
397 }
398 
400  return trainingResultsObserverManager.removeObserver( observer );
401 }
402 
404  return testResultsObserverManager.removeObserver( observer );
405 }
406 
408  return trainingResultsObserverManager.removeAllObservers();
409 }
410 
412  return testResultsObserverManager.removeAllObservers();
413 }
414 
416  return trainingResultsObserverManager.notifyObservers( data );
417 }
418 
420  return testResultsObserverManager.notifyObservers( data );
421 }
422 
424  return this;
425 }
426 
428  return this;
429 }
430 
432  return trainingResults;
433 }
434 
435 bool MLBase::saveBaseSettingsToFile( std::fstream &file ) const{
436 
437  if( !file.is_open() ){
438  errorLog << "saveBaseSettingsToFile(fstream &file) - The file is not open!" << std::endl;
439  return false;
440  }
441 
442  file << "Trained: " << trained << std::endl;
443  file << "UseScaling: " << useScaling << std::endl;
444  file << "NumInputDimensions: " << numInputDimensions << std::endl;
445  file << "NumOutputDimensions: " << numOutputDimensions << std::endl;
446  file << "NumTrainingIterationsToConverge: " << numTrainingIterationsToConverge << std::endl;
447  file << "MinNumEpochs: " << minNumEpochs << std::endl;
448  file << "MaxNumEpochs: " << maxNumEpochs << std::endl;
449  file << "ValidationSetSize: " << validationSetSize << std::endl;
450  file << "LearningRate: " << learningRate << std::endl;
451  file << "MinChange: " << minChange << std::endl;
452  file << "UseValidationSet: " << useValidationSet << std::endl;
453  file << "RandomiseTrainingOrder: " << randomiseTrainingOrder << std::endl;
454 
455  return true;
456 }
457 
458 bool MLBase::loadBaseSettingsFromFile( std::fstream &file ){
459 
460  //Clear any previous setup
461  clear();
462 
463  if( !file.is_open() ){
464  errorLog << "loadBaseSettingsFromFile(fstream &file) - The file is not open!" << std::endl;
465  return false;
466  }
467 
468  std::string word;
469 
470  //Load the trained state
471  file >> word;
472  if( word != "Trained:" ){
473  errorLog << "loadBaseSettingsFromFile(fstream &file) - Failed to read Trained header!" << std::endl;
474  return false;
475  }
476  file >> trained;
477 
478  //Load the scaling state
479  file >> word;
480  if( word != "UseScaling:" ){
481  errorLog << "loadBaseSettingsFromFile(fstream &file) - Failed to read UseScaling header!" << std::endl;
482  return false;
483  }
484  file >> useScaling;
485 
486  //Load the NumInputDimensions
487  file >> word;
488  if( word != "NumInputDimensions:" ){
489  errorLog << "loadBaseSettingsFromFile(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
490  return false;
491  }
492  file >> numInputDimensions;
493 
494  //Load the NumOutputDimensions
495  file >> word;
496  if( word != "NumOutputDimensions:" ){
497  errorLog << "loadBaseSettingsFromFile(fstream &file) - Failed to read NumOutputDimensions header!" << std::endl;
498  return false;
499  }
500  file >> numOutputDimensions;
501 
502  //Load the numTrainingIterationsToConverge
503  file >> word;
504  if( word != "NumTrainingIterationsToConverge:" ){
505  errorLog << "loadBaseSettingsFromFile(fstream &file) - Failed to read NumTrainingIterationsToConverge header!" << std::endl;
506  return false;
507  }
508  file >> numTrainingIterationsToConverge;
509 
510  //Load the MinNumEpochs
511  file >> word;
512  if( word != "MinNumEpochs:" ){
513  errorLog << "loadBaseSettingsFromFile(fstream &file) - Failed to read MinNumEpochs header!" << std::endl;
514  return false;
515  }
516  file >> minNumEpochs;
517 
518  //Load the maxNumEpochs
519  file >> word;
520  if( word != "MaxNumEpochs:" ){
521  errorLog << "loadBaseSettingsFromFile(fstream &file) - Failed to read MaxNumEpochs header!" << std::endl;
522  return false;
523  }
524  file >> maxNumEpochs;
525 
526  //Load the ValidationSetSize
527  file >> word;
528  if( word != "ValidationSetSize:" ){
529  errorLog << "loadBaseSettingsFromFile(fstream &file) - Failed to read ValidationSetSize header!" << std::endl;
530  return false;
531  }
532  file >> validationSetSize;
533 
534  //Load the LearningRate
535  file >> word;
536  if( word != "LearningRate:" ){
537  errorLog << "loadBaseSettingsFromFile(fstream &file) - Failed to read LearningRate header!" << std::endl;
538  return false;
539  }
540  file >> learningRate;
541 
542  //Load the MinChange
543  file >> word;
544  if( word != "MinChange:" ){
545  errorLog << "loadBaseSettingsFromFile(fstream &file) - Failed to read MinChange header!" << std::endl;
546  return false;
547  }
548  file >> minChange;
549 
550  //Load the UseValidationSet
551  file >> word;
552  if( word != "UseValidationSet:" ){
553  errorLog << "loadBaseSettingsFromFile(fstream &file) - Failed to read UseValidationSet header!" << std::endl;
554  return false;
555  }
556  file >> useValidationSet;
557 
558  //Load the RandomiseTrainingOrder
559  file >> word;
560  if( word != "RandomiseTrainingOrder:" ){
561  errorLog << "loadBaseSettingsFromFile(fstream &file) - Failed to read RandomiseTrainingOrder header!" << std::endl;
562  return false;
563  }
564  file >> randomiseTrainingOrder;
565 
566  return true;
567 }
568 
569 GRT_END_NAMESPACE
bool saveBaseSettingsToFile(std::fstream &file) const
Definition: MLBase.cpp:435
bool setLearningRate(const Float learningRate)
Definition: MLBase.cpp:353
virtual bool predict(VectorFloat inputVector)
Definition: MLBase.cpp:135
bool setRandomiseTrainingOrder(const bool randomiseTrainingOrder)
Definition: MLBase.cpp:378
bool notifyTrainingResultsObservers(const TrainingResult &data)
Definition: MLBase.cpp:415
MLBase(const std::string &id="", const BaseType type=BASE_TYPE_NOT_SET)
Definition: MLBase.cpp:26
bool registerTrainingResultsObserver(Observer< TrainingResult > &observer)
Definition: MLBase.cpp:391
Float getRMSValidationError() const
Definition: MLBase.cpp:278
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:137
virtual bool reset()
Definition: MLBase.cpp:147
bool setTrainingLoggingEnabled(const bool loggingEnabled)
Definition: MLBase.cpp:383
Float getLearningRate() const
Definition: MLBase.cpp:262
bool getTrainingLoggingEnabled() const
Definition: MLBase.cpp:312
bool removeAllTestObservers()
Definition: MLBase.cpp:411
bool setNumRestarts(const UINT numRestarts)
Definition: MLBase.cpp:339
bool enableScaling(const bool useScaling)
Definition: MLBase.cpp:308
DataType getOutputType() const
Definition: MLBase.cpp:225
virtual bool getModel(std::ostream &stream) const
Definition: MLBase.cpp:213
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:107
bool getTrained() const
Definition: MLBase.cpp:294
virtual bool setKey(const std::string &key)
Sets the key that gets written at the start of each message; this will be written in the format 'key ...
Definition: Log.h:166
bool getConverged() const
Definition: MLBase.cpp:298
UINT getMinNumEpochs() const
Definition: MLBase.cpp:242
UINT getNumOutputDimensions() const
Definition: MLBase.cpp:233
bool getScalingEnabled() const
Definition: MLBase.cpp:300
bool registerTestResultsObserver(Observer< TestInstanceResult > &observer)
Definition: MLBase.cpp:395
UINT getNumRestarts() const
Definition: MLBase.cpp:254
virtual bool save(const std::string &filename) const
Definition: MLBase.cpp:167
bool setMinChange(const Float minChange)
Definition: MLBase.cpp:344
UINT getValidationSetSize() const
Definition: MLBase.cpp:258
bool getTestingLoggingEnabled() const
Definition: MLBase.cpp:316
bool getUseValidationSet() const
Definition: MLBase.cpp:310
Float getRMSTrainingError() const
Definition: MLBase.cpp:266
UINT getMaxNumEpochs() const
Definition: MLBase.cpp:246
virtual bool setInstanceLoggingEnabled(const bool loggingEnabled)
sets if logging is enabled for this specific instance
Definition: Log.h:201
Float getTotalSquaredTrainingError() const
Definition: MLBase.cpp:274
bool setValidationSetSize(const UINT validationSetSize)
Definition: MLBase.cpp:361
bool copyMLBaseVariables(const MLBase *mlBase)
Definition: MLBase.cpp:62
virtual bool print() const
Definition: MLBase.cpp:165
virtual bool getInstanceLoggingEnabled() const
returns true if logging is enabled for this specific instance
Definition: Log.h:141
Float getValidationSetAccuracy() const
Definition: MLBase.cpp:282
BaseType getType() const
Definition: MLBase.cpp:227
bool getIsBaseTypeClusterer() const
Definition: MLBase.cpp:306
virtual std::string getModelAsString() const
Definition: MLBase.cpp:215
bool setMinNumEpochs(const UINT minNumEpochs)
Definition: MLBase.cpp:329
bool loadBaseSettingsFromFile(std::fstream &file)
Definition: MLBase.cpp:458
UINT getBatchSize() const
Definition: MLBase.cpp:250
MLBase * getMLBasePointer()
Definition: MLBase.cpp:423
bool setBatchSize(const UINT batchSize)
Definition: MLBase.cpp:334
bool removeAllTrainingObservers()
Definition: MLBase.cpp:407
UINT getNumInputFeatures() const
Definition: MLBase.cpp:229
virtual bool clear()
Definition: MLBase.cpp:149
virtual ~MLBase(void)
Definition: MLBase.cpp:58
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:109
bool notifyTestResultsObservers(const TestInstanceResult &data)
Definition: MLBase.cpp:419
bool copyGRTBaseVariables(const GRTBase *GRTBase)
Definition: GRTBase.cpp:43
bool getIsBaseTypeClassifier() const
Definition: MLBase.cpp:302
VectorFloat getValidationSetPrecision() const
Definition: MLBase.cpp:286
DataType getInputType() const
Definition: MLBase.cpp:223
bool removeTrainingResultsObserver(const Observer< TrainingResult > &observer)
Definition: MLBase.cpp:399
UINT getNumInputDimensions() const
Definition: MLBase.cpp:231
bool removeTestResultsObserver(const Observer< TestInstanceResult > &observer)
Definition: MLBase.cpp:403
virtual bool map_(VectorFloat &inputVector)
Definition: MLBase.cpp:145
UINT getNumTrainingIterationsToConverge() const
Definition: MLBase.cpp:235
bool setUseValidationSet(const bool useValidationSet)
Definition: MLBase.cpp:373
bool setTestingLoggingEnabled(const bool loggingEnabled)
Definition: MLBase.cpp:387
virtual bool load(const std::string &filename)
Definition: MLBase.cpp:190
bool setMaxNumEpochs(const UINT maxNumEpochs)
Definition: MLBase.cpp:320
Vector< TrainingResult > getTrainingResults() const
Definition: MLBase.cpp:431
virtual bool map(VectorFloat inputVector)
Definition: MLBase.cpp:143
VectorFloat getValidationSetRecall() const
Definition: MLBase.cpp:290
This is the main base class that all GRT machine learning algorithms should inherit from...
Definition: MLBase.h:72
bool getIsBaseTypeRegressifier() const
Definition: MLBase.cpp:304