GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, C++ machine learning library for real-time gesture recognition.
MLBase.h
Go to the documentation of this file.
1 
27 #ifndef GRT_MLBASE_HEADER
28 #define GRT_MLBASE_HEADER
29 
30 #include "GRTBase.h"
31 #include "../Util/Metrics.h"
32 #include "../DataStructures/UnlabelledData.h"
33 #include "../DataStructures/ClassificationData.h"
34 #include "../DataStructures/ClassificationDataStream.h"
35 #include "../DataStructures/RegressionData.h"
36 #include "../DataStructures/TimeSeriesClassificationData.h"
37 
38 GRT_BEGIN_NAMESPACE
39 
// Default null-likelihood and null-distance values (both 0).
// NOTE(review): neither macro is referenced in this header; presumably they seed
// null-rejection values in derived classifier/clusterer modules — confirm at the use sites.
#define DEFAULT_NULL_LIKELIHOOD_VALUE 0
#define DEFAULT_NULL_DISTANCE_VALUE 0
42 
46 class GRT_API TrainingResultsObserverManager : public ObserverManager< TrainingResult >
47 {
48  public:
50  }
52 
53 };
54 
58 class GRT_API TestResultsObserverManager : public ObserverManager< TestInstanceResult >
59 {
60  public:
62  }
63  virtual ~TestResultsObserverManager(){}
64 
65 };
66 
72 class GRT_API MLBase : public GRTBase, public Observer< TrainingResult >, public Observer< TestInstanceResult >
73 {
74 public:
75  enum BaseType{BASE_TYPE_NOT_SET=0,CLASSIFIER,REGRESSIFIER,CLUSTERER,PRE_PROCSSING,POST_PROCESSING,FEATURE_EXTRACTION,CONTEXT};
76 
82  MLBase( const std::string &id = "", const BaseType type = BASE_TYPE_NOT_SET );
83 
87  virtual ~MLBase(void);
88 
95  bool copyMLBaseVariables(const MLBase *mlBase);
96 
104  virtual bool train(ClassificationData trainingData);
105 
112  virtual bool train_(ClassificationData &trainingData);
113 
121  virtual bool train(RegressionData trainingData);
122 
129  virtual bool train_(RegressionData &trainingData);
130 
139  virtual bool train(RegressionData trainingData,RegressionData validationData);
140 
148  virtual bool train_(RegressionData &trainingData,RegressionData &validationData);
149 
157  virtual bool train(TimeSeriesClassificationData trainingData);
158 
165  virtual bool train_(TimeSeriesClassificationData &trainingData);
166 
174  virtual bool train(ClassificationDataStream trainingData);
175 
182  virtual bool train_(ClassificationDataStream &trainingData);
183 
191  virtual bool train(UnlabelledData trainingData);
192 
199  virtual bool train_(UnlabelledData &trainingData);
200 
208  virtual bool train(MatrixFloat data);
209 
216  virtual bool train_(MatrixFloat &data);
217 
225  virtual bool predict(VectorFloat inputVector);
226 
233  virtual bool predict_(VectorFloat &inputVector);
234 
242  virtual bool predict(MatrixFloat inputMatrix);
243 
250  virtual bool predict_(MatrixFloat &inputMatrix);
251 
259  virtual bool map(VectorFloat inputVector);
260 
267  virtual bool map_(VectorFloat &inputVector);
268 
276  virtual bool reset();
277 
284  virtual bool clear();
285 
292  virtual bool print() const;
293 
300  virtual bool save(const std::string &filename) const;
301 
308  virtual bool load(const std::string &filename);
309 
317  virtual bool save(std::fstream &file) const;
318 
326  virtual bool load(std::fstream &file);
327 
333  GRT_DEPRECATED_MSG( "saveModelToFile(std::string filename) is deprecated, use save(const std::string &filename) instead", virtual bool saveModelToFile(const std::string &filename) const );
334 
340  GRT_DEPRECATED_MSG( "saveModelToFile(std::fstream &file) is deprecated, use save(std::fstream &file) instead", virtual bool saveModelToFile(std::fstream &file) const );
341 
347  GRT_DEPRECATED_MSG( "loadModelFromFile(std::string filename) is deprecated, use load(const std::string &filename) instead",virtual bool loadModelFromFile(const std::string &filename) );
348 
354  GRT_DEPRECATED_MSG( "loadModelFromFile(std::fstream &file) is deprecated, use load(std::fstream &file) instead",virtual bool loadModelFromFile(std::fstream &file) );
355 
363  virtual bool getModel(std::ostream &stream) const;
364 
370  virtual std::string getModelAsString() const;
371 
377  DataType getInputType() const;
378 
384  DataType getOutputType() const;
385 
391  BaseType getType() const;
392 
399  UINT getNumInputFeatures() const;
400 
406  UINT getNumInputDimensions() const;
407 
413  UINT getNumOutputDimensions() const;
414 
421  UINT getMinNumEpochs() const;
422 
429  UINT getMaxNumEpochs() const;
430 
436  UINT getBatchSize() const;
437 
442  UINT getNumRestarts() const;
443 
451  UINT getValidationSetSize() const;
452 
458  UINT getNumTrainingIterationsToConverge() const;
459 
465  Float getMinChange() const;
466 
472  Float getLearningRate() const;
473 
479  Float getRMSTrainingError() const;
480 
487  GRT_DEPRECATED_MSG( "getRootMeanSquaredTrainingError() is deprecated, use getRMSTrainingError() instead", Float getRootMeanSquaredTrainingError() const );
488 
494  Float getTotalSquaredTrainingError() const;
495 
501  Float getRMSValidationError() const;
502 
508  Float getValidationSetAccuracy() const;
509 
515  VectorFloat getValidationSetPrecision() const;
516 
522  VectorFloat getValidationSetRecall() const;
523 
533  bool getUseValidationSet() const;
534 
541  bool getRandomiseTrainingOrder() const;
542 
548  bool getTrained() const;
549 
556  GRT_DEPRECATED_MSG( "getModelTrained() is deprecated, use getTrained() instead", bool getModelTrained() const );
557 
564  bool getConverged() const;
565 
571  bool getScalingEnabled() const;
572 
578  bool getIsBaseTypeClassifier() const;
579 
585  bool getIsBaseTypeRegressifier() const;
586 
592  bool getIsBaseTypeClusterer() const;
593 
599  bool getTrainingLoggingEnabled() const;
600 
606  bool getTestingLoggingEnabled() const;
607 
613  bool enableScaling(const bool useScaling);
614 
622  bool setMaxNumEpochs(const UINT maxNumEpochs);
623 
630  bool setBatchSize(const UINT batchSize);
631 
638  bool setMinNumEpochs(const UINT minNumEpochs);
639 
647  bool setNumRestarts(const UINT numRestarts);
648 
656  bool setMinChange(const Float minChange);
657 
665  bool setLearningRate(const Float learningRate);
666 
675  bool setUseValidationSet(const bool useValidationSet);
676 
685  bool setValidationSetSize(const UINT validationSetSize);
686 
694  bool setRandomiseTrainingOrder(const bool randomiseTrainingOrder);
695 
703  bool setTrainingLoggingEnabled(const bool loggingEnabled);
704 
712  bool setTestingLoggingEnabled(const bool loggingEnabled);
713 
720  bool registerTrainingResultsObserver( Observer< TrainingResult > &observer );
721 
728  bool registerTestResultsObserver( Observer< TestInstanceResult > &observer );
729 
736  bool removeTrainingResultsObserver( const Observer< TrainingResult > &observer );
737 
744  bool removeTestResultsObserver( const Observer< TestInstanceResult > &observer );
745 
751  bool removeAllTrainingObservers();
752 
758  bool removeAllTestObservers();
759 
766  bool notifyTrainingResultsObservers( const TrainingResult &data );
767 
774  bool notifyTestResultsObservers( const TestInstanceResult &data );
775 
781  MLBase* getMLBasePointer();
782 
788  const MLBase* getMLBasePointer() const;
789 
795  Vector< TrainingResult > getTrainingResults() const;
796 
797 protected:
798 
804  bool saveBaseSettingsToFile( std::fstream &file ) const;
805 
811  bool loadBaseSettingsFromFile( std::fstream &file );
812 
813  bool trained;
814  bool useScaling;
815  bool converged;
816  DataType inputType;
817  DataType outputType;
818  BaseType baseType;
819  UINT numInputDimensions;
820  UINT numOutputDimensions;
821  UINT numTrainingIterationsToConverge;
822  UINT minNumEpochs;
823  UINT maxNumEpochs;
824  UINT batchSize;
825  UINT validationSetSize;
826  UINT numRestarts;
827  Float learningRate;
828  Float minChange;
829  Float rmsTrainingError;
830  Float rmsValidationError;
831  Float totalSquaredTrainingError;
832  Float validationSetAccuracy;
833  bool useValidationSet;
834  bool randomiseTrainingOrder;
835  VectorFloat validationSetPrecision;
836  VectorFloat validationSetRecall;
837  Random random;
838  Vector< TrainingResult > trainingResults;
839  TrainingResultsObserverManager trainingResultsObserverManager;
840  TestResultsObserverManager testResultsObserverManager;
841  TrainingLog trainingLog;
842  TestingLog testingLog;
843 
844 };
845 
846 GRT_END_NAMESPACE
847 
848 #endif //GRT_MLBASE_HEADER
This file contains the Random class, a useful wrapper for generating cross-platform random functions...
Definition: Random.h:46
GRT_DEPRECATED_MSG("getClassType is deprecated, use getId() instead!", std::string getClassType() const )
This file contains the GRTBase class. This is the core base class for all the GRT modules...
This is the main base class that all GRT machine learning algorithms should inherit from...
Definition: MLBase.h:72