GestureRecognitionToolkit Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine-learning library for real-time gesture recognition.
MLBase.h
Go to the documentation of this file.
1 
31 #ifndef GRT_MLBASE_HEADER
32 #define GRT_MLBASE_HEADER
33 
34 #include "GRTBase.h"
35 #include "../DataStructures/UnlabelledData.h"
36 #include "../DataStructures/ClassificationData.h"
37 #include "../DataStructures/ClassificationDataStream.h"
38 #include "../DataStructures/RegressionData.h"
39 #include "../DataStructures/TimeSeriesClassificationData.h"
40 
41 GRT_BEGIN_NAMESPACE
42 
// Default values (both 0) for a "null" likelihood and a "null" distance.
// NOTE(review): presumably used as initial null-rejection values by the derived
// learners — confirm against the classes that reference these macros.
43 #define DEFAULT_NULL_LIKELIHOOD_VALUE 0
44 #define DEFAULT_NULL_DISTANCE_VALUE 0
45 
49 class GRT_API TrainingResultsObserverManager : public ObserverManager< TrainingResult >
50 {
51  public:
53  }
55 
56 };
57 
61 class GRT_API TestResultsObserverManager : public ObserverManager< TestInstanceResult >
62 {
63  public:
65  }
66  virtual ~TestResultsObserverManager(){}
67 
68 };
69 
70 class GRT_API MLBase : public GRTBase, public Observer< TrainingResult >, public Observer< TestInstanceResult >
71 {
72 public:
73  enum BaseTypes{BASE_TYPE_NOT_SET=0,CLASSIFIER,REGRESSIFIER,CLUSTERER};
74 
78  MLBase(void);
79 
83  virtual ~MLBase(void);
84 
91  bool copyMLBaseVariables(const MLBase *mlBase);
92 
100  virtual bool train(ClassificationData trainingData);
101 
108  virtual bool train_(ClassificationData &trainingData);
109 
117  virtual bool train(RegressionData trainingData);
118 
125  virtual bool train_(RegressionData &trainingData);
126 
134  virtual bool train(TimeSeriesClassificationData trainingData);
135 
142  virtual bool train_(TimeSeriesClassificationData &trainingData);
143 
151  virtual bool train(ClassificationDataStream trainingData);
152 
159  virtual bool train_(ClassificationDataStream &trainingData);
160 
168  virtual bool train(UnlabelledData trainingData);
169 
176  virtual bool train_(UnlabelledData &trainingData);
177 
185  virtual bool train(MatrixFloat data);
186 
193  virtual bool train_(MatrixFloat &data);
194 
202  virtual bool predict(VectorFloat inputVector);
203 
210  virtual bool predict_(VectorFloat &inputVector);
211 
219  virtual bool predict(MatrixFloat inputMatrix);
220 
227  virtual bool predict_(MatrixFloat &inputMatrix);
228 
236  virtual bool map(VectorFloat inputVector);
237 
244  virtual bool map_(VectorFloat &inputVector);
245 
253  virtual bool reset();
254 
261  virtual bool clear();
262 
269  virtual bool print() const;
270 
277  virtual bool save(const std::string filename) const;
278 
285  virtual bool load(const std::string filename);
286 
294  virtual bool save(std::fstream &file) const;
295 
303  virtual bool load(std::fstream &file);
304 
310  GRT_DEPRECATED_MSG( "saveModelToFile(std::string filename) is deprecated, use save(std::string filename) instead", virtual bool saveModelToFile(std::string filename) const );
311 
317  GRT_DEPRECATED_MSG( "saveModelToFile(std::fstream &file) is deprecated, use save(std::fstream &file) instead", virtual bool saveModelToFile(std::fstream &file) const );
318 
324  GRT_DEPRECATED_MSG( "loadModelFromFile(std::string filename) is deprecated, use load(std::string filename) instead",virtual bool loadModelFromFile(std::string filename) );
325 
331  GRT_DEPRECATED_MSG( "loadModelFromFile(std::fstream &file) is deprecated, use load(std::fstream &file) instead",virtual bool loadModelFromFile(std::fstream &file) );
332 
340  virtual bool getModel(std::ostream &stream) const;
341 
353  Float inline scale(const Float &x,const Float &minSource,const Float &maxSource,const Float &minTarget,const Float &maxTarget,const bool constrain=false){
354  if( constrain ){
355  if( x <= minSource ) return minTarget;
356  if( x >= maxSource ) return maxTarget;
357  }
358  if( minSource == maxSource ) return minTarget;
359  return (((x-minSource)*(maxTarget-minTarget))/(maxSource-minSource))+minTarget;
360  }
361 
367  virtual std::string getModelAsString() const;
368 
374  DataType getInputType() const;
375 
381  DataType getOutputType() const;
382 
388  UINT getBaseType() const;
389 
396  UINT getNumInputFeatures() const;
397 
403  UINT getNumInputDimensions() const;
404 
410  UINT getNumOutputDimensions() const;
411 
418  UINT getMinNumEpochs() const;
419 
426  UINT getMaxNumEpochs() const;
427 
435  UINT getValidationSetSize() const;
436 
442  UINT getNumTrainingIterationsToConverge() const;
443 
449  Float getMinChange() const;
450 
456  Float getLearningRate() const;
457 
463  Float getRootMeanSquaredTrainingError() const;
464 
470  Float getTotalSquaredTrainingError() const;
471 
477  Float getValidationSetAccuracy() const;
478 
484  VectorFloat getValidationSetPrecision() const;
485 
491  VectorFloat getValidationSetRecall() const;
492 
502  bool getUseValidationSet() const;
503 
510  bool getRandomiseTrainingOrder() const;
511 
517  bool getTrained() const;
518 
524  bool getModelTrained() const;
525 
531  bool getScalingEnabled() const;
532 
538  bool getIsBaseTypeClassifier() const;
539 
545  bool getIsBaseTypeRegressifier() const;
546 
552  bool getIsBaseTypeClusterer() const;
553 
559  bool enableScaling(const bool useScaling);
560 
568  bool setMaxNumEpochs(const UINT maxNumEpochs);
569 
576  bool setMinNumEpochs(const UINT minNumEpochs);
577 
585  bool setMinChange(const Float minChange);
586 
594  bool setLearningRate(const Float learningRate);
595 
604  bool setUseValidationSet(const bool useValidationSet);
605 
614  bool setValidationSetSize(const UINT validationSetSize);
615 
623  bool setRandomiseTrainingOrder(const bool randomiseTrainingOrder);
624 
632  bool setTrainingLoggingEnabled(const bool loggingEnabled);
633 
640  bool registerTrainingResultsObserver( Observer< TrainingResult > &observer );
641 
648  bool registerTestResultsObserver( Observer< TestInstanceResult > &observer );
649 
656  bool removeTrainingResultsObserver( const Observer< TrainingResult > &observer );
657 
664  bool removeTestResultsObserver( const Observer< TestInstanceResult > &observer );
665 
671  bool removeAllTrainingObservers();
672 
678  bool removeAllTestObservers();
679 
686  bool notifyTrainingResultsObservers( const TrainingResult &data );
687 
694  bool notifyTestResultsObservers( const TestInstanceResult &data );
695 
701  MLBase* getMLBasePointer();
702 
708  const MLBase* getMLBasePointer() const;
709 
715  Vector< TrainingResult > getTrainingResults() const;
716 
717 protected:
718 
724  bool saveBaseSettingsToFile( std::fstream &file ) const;
725 
731  bool loadBaseSettingsFromFile( std::fstream &file );
732 
733  bool trained;
734  bool useScaling;
735  DataType inputType;
736  DataType outputType;
737  UINT baseType;
738  UINT numInputDimensions;
739  UINT numOutputDimensions;
740  UINT numTrainingIterationsToConverge;
741  UINT minNumEpochs;
742  UINT maxNumEpochs;
743  UINT validationSetSize;
744  Float learningRate;
745  Float minChange;
746  Float rootMeanSquaredTrainingError;
747  Float totalSquaredTrainingError;
748  Float validationSetAccuracy;
749  bool useValidationSet;
750  bool randomiseTrainingOrder;
751  VectorFloat validationSetPrecision;
752  VectorFloat validationSetRecall;
753  Random random;
754  std::vector< TrainingResult > trainingResults;
755  TrainingResultsObserverManager trainingResultsObserverManager;
756  TestResultsObserverManager testResultsObserverManager;
757 
758 };
759 
760 GRT_END_NAMESPACE
761 
762 #endif //GRT_MLBASE_HEADER
Float scale(const Float &x, const Float &minSource, const Float &maxSource, const Float &minTarget, const Float &maxTarget, const bool constrain=false)
Definition: MLBase.h:353
Definition: Random.h:40
This file contains the GRTBase class. This is the core base class for all the GRT modules...
Definition: Vector.h:41
Definition: MLBase.h:70