GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
BernoulliRBM.h
Go to the documentation of this file.
1 
33 #ifndef GRT_BERNOULLI_RBM_HEADER
34 #define GRT_BERNOULLI_RBM_HEADER
35 
36 #include "../../Util/GRTTypedefs.h"
37 #include "../../DataStructures/MatrixFloat.h"
38 #include "../../CoreModules/MLBase.h"
39 
40 GRT_BEGIN_NAMESPACE
41 
// BernoulliRBM: an MLBase-derived module that declares training, prediction,
// reconstruction and model save/load entry points, together with getters and
// setters for its configuration values.
// NOTE(review): presumably a Bernoulli restricted Boltzmann machine, judging by the
// class name and the visible/hidden unit members - confirm against the .cpp file.
42 class GRT_API BernoulliRBM : public MLBase{
43 
44  public:
    // Constructor. Every argument has a default (100 hidden units, 1000 epochs,
    // learning rate 1, learning rate update 1, momentum 0.5, scaling enabled,
    // randomised training order enabled), so it also acts as the default constructor.
    // NOTE(review): it is not marked explicit, so a lone UINT converts implicitly.
45  BernoulliRBM(const UINT numHiddenUnits = 100,const UINT maxNumEpochs = 1000,const Float learningRate = 1,const Float learningRateUpdate = 1,const Float momentum = 0.5,const bool useScaling = true,const bool randomiseTrainingOrder = true);
46 
    // Virtual destructor.
47  virtual ~BernoulliRBM();
48 
    // predict_ overload taking only the input vector (by non-const reference).
    // NOTE(review): presumably the result is then read via getOutputData() - confirm.
57  bool predict_(VectorFloat &inputData);
58 
    // predict_ overload taking the input vector and a caller-supplied output vector.
68  bool predict_(VectorFloat &inputData,VectorFloat &outputData);
69 
    // predict_ overload taking input/output matrices plus a row index.
    // NOTE(review): presumably rowIndex selects the row to read and write - confirm.
80  bool predict_(const MatrixFloat &inputData,MatrixFloat &outputData,const UINT rowIndex);
81 
    // Virtual train_ overload taking the training data as a MatrixFloat.
88  virtual bool train_(MatrixFloat &data);
89 
    // reset() and clear() have the same signatures as the virtual MLBase::reset()
    // and MLBase::clear() listed in the cross-references of this listing.
96  virtual bool reset();
97 
103  virtual bool clear();
104 
    // Stream-based save/load. The filename-based MLBase::save/MLBase::load
    // overloads are re-exposed by the using-declarations further down.
111  virtual bool save( std::fstream &file ) const;
112 
119  virtual bool load( std::fstream &file );
120 
    // Takes a const input vector and a caller-supplied output vector.
    // NOTE(review): presumably writes the model's reconstruction of input - confirm.
121  bool reconstruct(const VectorFloat &input,VectorFloat &output);
122 
    // Same signature as the virtual MLBase::print() const.
123  virtual bool print() const;
124 
    // Getters. getOutputData() returns a VectorFloat by value (a copy), whereas
    // getWeights() returns a const reference to a MatrixFloat.
125  bool getRandomizeWeightsForTraining() const;
126  UINT getNumVisibleUnits() const;
127  UINT getNumHiddenUnits() const;
128  VectorFloat getOutputData() const;
129  const MatrixFloat& getWeights() const;
130 
    // Setters; each returns bool.
    // NOTE(review): the cross-references of this listing attribute setBatchSize to
    // MLBase.cpp:334, which suggests MLBase declares a setBatchSize too; if so, this
    // declaration hides it - confirm which one is intended.
131  bool setNumHiddenUnits(const UINT numHiddenUnits);
132  bool setMomentum(const Float momentum);
133  bool setLearningRateUpdate(const Float learningRateUpdate);
134  bool setRandomizeWeightsForTraining(const bool randomizeWeightsForTraining);
135  bool setBatchSize(const UINT batchSize);
136  bool setBatchStepSize(const UINT batchStepSize);
137 
138  //Bring the MLBase overloads of save, load, train, predict, train_ and predict_ into scope so that the overloads declared in this class do not hide them (this also stops hidden virtual function warnings)
139  using MLBase::save;
140  using MLBase::load;
141  using MLBase::train;
142  using MLBase::predict;
143  using MLBase::train_;
144  using MLBase::predict_;
145 
146 protected:
    // NOTE(review): presumably reads an older model file format - confirm in the .cpp file.
147  bool loadLegacyModelFromFile( std::fstream &file );
148 
    // Evaluates the logistic function 1/(1+exp(-x)) and compares it with a value
    // drawn from rand.getRandomNumberUniform(0.0,1.0); returns 1.0 when the logistic
    // value is strictly greater than the drawn value and 0.0 otherwise.
    // Not const: it draws from the member random number generator.
149  inline Float sigmoidRandom(const Float &x){
150  return (1.0 / (1.0 + exp(-x)) > rand.getRandomNumberUniform(0.0,1.0)) ? 1.0 : 0.0;
151  }
152 
    // Configuration values exposed through the getters/setters above.
153  bool randomizeWeightsForTraining;
154  UINT numVisibleUnits;
155  UINT numHiddenUnits;
156  UINT batchSize;
157  UINT batchStepSize;
158  Float momentum;
159  Float learningRateUpdate;
    // Model and working buffers. Their roles are not documented in this header.
    // NOTE(review): the ph_/nv_/nh_ names presumably denote positive-hidden,
    // negative-visible and negative-hidden quantities - confirm in the .cpp file.
160  MatrixFloat weightsMatrix;
161  VectorFloat visibleLayerBias;
162  VectorFloat hiddenLayerBias;
163  VectorFloat ph_mean;
164  VectorFloat ph_sample;
165  VectorFloat nv_means;
166  VectorFloat nv_samples;
167  VectorFloat nh_means;
168  VectorFloat nh_samples;
169  VectorFloat outputData;
    // NOTE(review): presumably per-dimension min/max used when useScaling is set - confirm.
170  Vector<MinMax> ranges;
    // Random number generator used by sigmoidRandom().
171  Random rand;
172 
    // Plain record holding a start index, an end index and a batch size.
    // It is not referenced anywhere else in this header.
173  struct BatchIndexs{
174  UINT startIndex;
175  UINT endIndex;
176  UINT batchSize;
177  };
    // Redundant in C++ (the struct name is already a type name), but harmless.
178  typedef struct BatchIndexs BatchIndexs;
179 
180 };
181 
182 GRT_END_NAMESPACE
183 
184 #endif //GRT_BERNOULLI_RBM_HEADER
virtual bool predict(VectorFloat inputVector)
Definition: MLBase.cpp:135
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:137
virtual bool reset()
Definition: MLBase.cpp:147
This file contains the Random class, a useful wrapper for generating cross-platform random functions...
Definition: Random.h:46
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:107
virtual bool save(const std::string &filename) const
Definition: MLBase.cpp:167
virtual bool print() const
Definition: MLBase.cpp:165
bool setBatchSize(const UINT batchSize)
Definition: MLBase.cpp:334
virtual bool clear()
Definition: MLBase.cpp:149
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:109
virtual bool load(const std::string &filename)
Definition: MLBase.cpp:190
This is the main base class that all GRT machine learning algorithms should inherit from...
Definition: MLBase.h:72