33 #ifndef GRT_BERNOULLI_RBM_HEADER 34 #define GRT_BERNOULLI_RBM_HEADER 36 #include "../../Util/GRTTypedefs.h" 37 #include "../../DataStructures/MatrixFloat.h" 38 #include "../../CoreModules/MLBase.h" 45 BernoulliRBM(
const UINT numHiddenUnits = 100,
const UINT maxNumEpochs = 1000,
const Float learningRate = 1,
const Float learningRateUpdate = 1,
const Float momentum = 0.5,
const bool useScaling =
true,
const bool randomiseTrainingOrder =
true);
103 virtual bool clear();
111 virtual bool save( std::fstream &file )
const;
119 virtual bool load( std::fstream &file );
123 virtual bool print()
const;
125 bool getRandomizeWeightsForTraining()
const;
126 UINT getNumVisibleUnits()
const;
127 UINT getNumHiddenUnits()
const;
131 bool setNumHiddenUnits(
const UINT numHiddenUnits);
132 bool setMomentum(
const Float momentum);
133 bool setLearningRateUpdate(
const Float learningRateUpdate);
134 bool setRandomizeWeightsForTraining(
const bool randomizeWeightsForTraining);
136 bool setBatchStepSize(
const UINT batchStepSize);
147 bool loadLegacyModelFromFile( std::fstream &file );
149 inline Float sigmoidRandom(
const Float &x){
150 return (1.0 / (1.0 + exp(-x)) > rand.getRandomNumberUniform(0.0,1.0)) ? 1.0 : 0.0;
153 bool randomizeWeightsForTraining;
154 UINT numVisibleUnits;
159 Float learningRateUpdate;
184 #endif //GRT_BERNOULLI_RBM_HEADER virtual bool predict(VectorFloat inputVector)
virtual bool predict_(VectorFloat &inputVector)
This file contains the Random class, a useful wrapper for generating cross platform random functions...
virtual bool train(ClassificationData trainingData)
virtual bool save(const std::string &filename) const
virtual bool print() const
bool setBatchSize(const UINT batchSize)
virtual bool train_(ClassificationData &trainingData)
virtual bool load(const std::string &filename)
This is the main base class that all GRT machine learning algorithms should inherit from...