GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
DiscreteHiddenMarkovModel.h
Go to the documentation of this file.
1 
31 #ifndef GRT_DISCRETE_HIDDEN_MARKOV_MODEL_HEADER
32 #define GRT_DISCRETE_HIDDEN_MARKOV_MODEL_HEADER
33 
34 #include "HMMEnums.h"
35 #include "../../Util/GRTCommon.h"
36 #include "../../CoreModules/MLBase.h"
37 
38 GRT_BEGIN_NAMESPACE
39 
40 //This class is used for the HMM batch training
41 class GRT_API HMMTrainingObject{
42  public:
44  pk = 0.0;
45  }
47  MatrixFloat alpha; //The forward estimate matrix
48  MatrixFloat beta; //The backward estimate matrix
49  VectorFloat c; //The scaling coefficient Vector
50  Float pk; //P( O | Model )
51 };
52 
53 class GRT_API DiscreteHiddenMarkovModel : public MLBase {
54 
55 public:
57 
58  DiscreteHiddenMarkovModel(const UINT numStates,const UINT numSymbols,const UINT modelType,const UINT delta);
59 
60  DiscreteHiddenMarkovModel(const MatrixFloat &a,const MatrixFloat &b,const VectorFloat &pi,const UINT modelType,const UINT delta);
61 
63 
64  virtual ~DiscreteHiddenMarkovModel();
65 
66  Float predict(const UINT newSample);
67  Float predict(const Vector<UINT> &obs);
68 
69  bool resetModel(const UINT numStates,const UINT numSymbols,const UINT modelType,const UINT delta);
70  bool train(const Vector< Vector<UINT> > &trainingData);
71 
72  virtual bool reset();
73 
80  virtual bool save( std::fstream &file ) const;
81 
88  virtual bool load( std::fstream &file );
89 
90  bool randomizeMatrices(const UINT numStates,const UINT numSymbols);
91  Float predictLogLikelihood(const Vector<UINT> &obs);
92  bool forwardBackward(HMMTrainingObject &trainingObject,const Vector<UINT> &obs);
93  bool train_(const Vector< Vector<UINT> > &obs,const UINT maxIter, UINT &currentIter,Float &newLoglikelihood);
94  virtual bool print() const;
95 
96  VectorFloat getTrainingIterationLog() const;
97 
98  using MLBase::save;
99  using MLBase::load;
100 
101 protected:
102  UINT numStates; //The number of states for this model
103  UINT numSymbols; //The number of symbols for this model
104  MatrixFloat a; //The transitions probability matrix
105  MatrixFloat b; //The emissions probability matrix
106  VectorFloat pi; //The state start probability Vector
107  VectorFloat trainingIterationLog; //Stores the loglikelihood at each iteration the BaumWelch algorithm
108 
109  UINT modelType;
110  UINT delta; //The number of states a model can move to in a LeftRight model
111  UINT numRandomTrainingIterations; //The number of training loops to find the best starting values
112  Float logLikelihood; //The log likelihood of an observation sequence given the modal, calculated by the forward method
113  Float cThreshold; //The classification threshold for this model
114  CircularBuffer<UINT> observationSequence;
115  Vector< UINT > estimatedStates;
116 };
117 
118 GRT_END_NAMESPACE
119 
120 #endif //GRT_HIDDEN_MARKOV_MODEL_HEADER
virtual bool predict(VectorFloat inputVector)
Definition: MLBase.cpp:113
virtual bool reset()
Definition: MLBase.cpp:125
This class acts as the main interface for using a Hidden Markov Model.
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:89
virtual bool save(const std::string filename) const
Definition: MLBase.cpp:143
virtual bool load(const std::string filename)
Definition: MLBase.cpp:167
virtual bool print() const
Definition: MLBase.cpp:141
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:91
Definition: MLBase.h:70