21 #define GRT_DLL_EXPORTS 31 inputType = DATA_TYPE_UNKNOWN;
32 outputType = DATA_TYPE_UNKNOWN;
33 numInputDimensions = 0;
34 numOutputDimensions = 0;
39 validationSetSize = 20;
40 validationSetAccuracy = 0;
43 useValidationSet =
false;
44 randomiseTrainingOrder =
true;
46 rmsValidationError = 0;
47 totalSquaredTrainingError = 0;
50 trainingLog.
setKey(
"[TRAINING]");
51 testingLog.
setKey(
"[TESTING]");
53 trainingLog.
setKey(
"[TRAINING " +
id +
"]");
54 testingLog.
setKey(
"[TESTING " +
id +
"]");
65 errorLog <<
"copyMLBaseVariables(MLBase *mlBase) - mlBase pointer is NULL!" << std::endl;
70 errorLog <<
"copyMLBaseVariables(MLBase *mlBase) - Failed to copy GRT Base variables!" << std::endl;
74 this->trained = mlBase->trained;
75 this->converged = mlBase->converged;
76 this->useScaling = mlBase->useScaling;
77 this->baseType = mlBase->baseType;
78 this->inputType = mlBase->inputType;
79 this->outputType = mlBase->outputType;
80 this->numInputDimensions = mlBase->numInputDimensions;
81 this->numOutputDimensions = mlBase->numOutputDimensions;
82 this->minNumEpochs = mlBase->minNumEpochs;
83 this->maxNumEpochs = mlBase->maxNumEpochs;
84 this->batchSize = mlBase->batchSize;
85 this->numRestarts = mlBase->numRestarts;
86 this->validationSetSize = mlBase->validationSetSize;
87 this->validationSetAccuracy = mlBase->validationSetAccuracy;
88 this->validationSetPrecision = mlBase->validationSetPrecision;
89 this->validationSetRecall = mlBase->validationSetRecall;
90 this->minChange = mlBase->minChange;
91 this->learningRate = mlBase->learningRate;
92 this->rmsTrainingError = mlBase->rmsTrainingError;
93 this->rmsValidationError = mlBase->rmsValidationError;
94 this->totalSquaredTrainingError = mlBase->totalSquaredTrainingError;
95 this->useValidationSet = mlBase->useValidationSet;
96 this->randomiseTrainingOrder = mlBase->randomiseTrainingOrder;
97 this->numTrainingIterationsToConverge = mlBase->numTrainingIterationsToConverge;
98 this->trainingResults = mlBase->trainingResults;
99 this->trainingResultsObserverManager = mlBase->trainingResultsObserverManager;
100 this->testResultsObserverManager = mlBase->testResultsObserverManager;
101 this->trainingLog = mlBase->trainingLog;
102 this->testingLog = mlBase->testingLog;
152 numInputDimensions = 0;
153 numOutputDimensions = 0;
154 numTrainingIterationsToConverge = 0;
155 rmsTrainingError = 0;
156 rmsValidationError = 0;
157 totalSquaredTrainingError = 0;
158 trainingResults.clear();
159 validationSetPrecision.clear();
160 validationSetRecall.clear();
161 validationSetAccuracy = 0;
170 file.open(filename.c_str(), std::ios::out);
186 bool MLBase::saveModelToFile(
const std::string &filename)
const {
return save( filename ); }
188 bool MLBase::saveModelToFile(std::fstream &file)
const {
return save( file ); }
193 file.open(filename.c_str(), std::ios::in);
209 bool MLBase::loadModelFromFile(
const std::string &filename){
return load( filename ); }
211 bool MLBase::loadModelFromFile(std::fstream &file){
return load( file ); }
216 std::stringstream stream;
237 return numTrainingIterationsToConverge;
259 return validationSetSize;
267 return rmsTrainingError;
270 Float MLBase::getRootMeanSquaredTrainingError()
const{
275 return totalSquaredTrainingError;
279 return rmsValidationError;
283 return validationSetAccuracy;
287 return validationSetPrecision;
291 return validationSetRecall;
296 bool MLBase::getModelTrained()
const{
return getTrained(); }
321 if( maxNumEpochs == 0 ){
322 warningLog <<
"setMaxNumEpochs(const UINT maxNumEpochs) - The maxNumEpochs must be greater than 0!" << std::endl;
325 this->maxNumEpochs = maxNumEpochs;
330 this->minNumEpochs = minNumEpochs;
335 this->batchSize = batchSize;
340 this->numRestarts = numRestarts;
346 warningLog <<
"setMinChange(const Float minChange) - The minChange must be greater than or equal to 0!" << std::endl;
349 this->minChange = minChange;
354 if( learningRate > 0 ){
355 this->learningRate = learningRate;
363 if( validationSetSize > 0 && validationSetSize < 100 ){
364 this->validationSetSize = validationSetSize;
368 warningLog <<
"setValidationSetSize(const UINT validationSetSize) - The validation size must be in the range [1 99]!" << std::endl;
374 this->useValidationSet = useValidationSet;
379 this->randomiseTrainingOrder = randomiseTrainingOrder;
392 return trainingResultsObserverManager.registerObserver( observer );
396 return testResultsObserverManager.registerObserver( observer );
400 return trainingResultsObserverManager.removeObserver( observer );
404 return testResultsObserverManager.removeObserver( observer );
408 return trainingResultsObserverManager.removeAllObservers();
412 return testResultsObserverManager.removeAllObservers();
416 return trainingResultsObserverManager.notifyObservers( data );
420 return testResultsObserverManager.notifyObservers( data );
432 return trainingResults;
437 if( !file.is_open() ){
438 errorLog <<
"saveBaseSettingsToFile(fstream &file) - The file is not open!" << std::endl;
442 file <<
"Trained: " << trained << std::endl;
443 file <<
"UseScaling: " << useScaling << std::endl;
444 file <<
"NumInputDimensions: " << numInputDimensions << std::endl;
445 file <<
"NumOutputDimensions: " << numOutputDimensions << std::endl;
446 file <<
"NumTrainingIterationsToConverge: " << numTrainingIterationsToConverge << std::endl;
447 file <<
"MinNumEpochs: " << minNumEpochs << std::endl;
448 file <<
"MaxNumEpochs: " << maxNumEpochs << std::endl;
449 file <<
"ValidationSetSize: " << validationSetSize << std::endl;
450 file <<
"LearningRate: " << learningRate << std::endl;
451 file <<
"MinChange: " << minChange << std::endl;
452 file <<
"UseValidationSet: " << useValidationSet << std::endl;
453 file <<
"RandomiseTrainingOrder: " << randomiseTrainingOrder << std::endl;
463 if( !file.is_open() ){
464 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - The file is not open!" << std::endl;
472 if( word !=
"Trained:" ){
473 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read Trained header!" << std::endl;
480 if( word !=
"UseScaling:" ){
481 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read UseScaling header!" << std::endl;
488 if( word !=
"NumInputDimensions:" ){
489 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
492 file >> numInputDimensions;
496 if( word !=
"NumOutputDimensions:" ){
497 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read NumOutputDimensions header!" << std::endl;
500 file >> numOutputDimensions;
504 if( word !=
"NumTrainingIterationsToConverge:" ){
505 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read NumTrainingIterationsToConverge header!" << std::endl;
508 file >> numTrainingIterationsToConverge;
512 if( word !=
"MinNumEpochs:" ){
513 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read MinNumEpochs header!" << std::endl;
516 file >> minNumEpochs;
520 if( word !=
"MaxNumEpochs:" ){
521 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read MaxNumEpochs header!" << std::endl;
524 file >> maxNumEpochs;
528 if( word !=
"ValidationSetSize:" ){
529 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read ValidationSetSize header!" << std::endl;
532 file >> validationSetSize;
536 if( word !=
"LearningRate:" ){
537 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read LearningRate header!" << std::endl;
540 file >> learningRate;
544 if( word !=
"MinChange:" ){
545 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read MinChange header!" << std::endl;
552 if( word !=
"UseValidationSet:" ){
553 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read UseValidationSet header!" << std::endl;
556 file >> useValidationSet;
560 if( word !=
"RandomiseTrainingOrder:" ){
561 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read RandomiseTrainingOrder header!" << std::endl;
564 file >> randomiseTrainingOrder;
bool saveBaseSettingsToFile(std::fstream &file) const
bool setLearningRate(const Float learningRate)
virtual bool predict(VectorFloat inputVector)
bool setRandomiseTrainingOrder(const bool randomiseTrainingOrder)
bool notifyTrainingResultsObservers(const TrainingResult &data)
MLBase(const std::string &id="", const BaseType type=BASE_TYPE_NOT_SET)
bool registerTrainingResultsObserver(Observer< TrainingResult > &observer)
Float getRMSValidationError() const
virtual bool predict_(VectorFloat &inputVector)
bool setTrainingLoggingEnabled(const bool loggingEnabled)
Float getLearningRate() const
bool getTrainingLoggingEnabled() const
bool removeAllTestObservers()
bool setNumRestarts(const UINT numRestarts)
bool enableScaling(const bool useScaling)
DataType getOutputType() const
virtual bool getModel(std::ostream &stream) const
virtual bool train(ClassificationData trainingData)
virtual bool setKey(const std::string &key)
sets the key that gets written at the start of each message, this will be written in the format 'key ...
bool getConverged() const
UINT getMinNumEpochs() const
UINT getNumOutputDimensions() const
bool getScalingEnabled() const
bool registerTestResultsObserver(Observer< TestInstanceResult > &observer)
UINT getNumRestarts() const
virtual bool save(const std::string &filename) const
bool setMinChange(const Float minChange)
UINT getValidationSetSize() const
bool getTestingLoggingEnabled() const
bool getUseValidationSet() const
Float getRMSTrainingError() const
UINT getMaxNumEpochs() const
virtual bool setInstanceLoggingEnabled(const bool loggingEnabled)
sets if logging is enabled for this specific instance
Float getTotalSquaredTrainingError() const
bool setValidationSetSize(const UINT validationSetSize)
bool copyMLBaseVariables(const MLBase *mlBase)
virtual bool print() const
virtual bool getInstanceLoggingEnabled() const
returns true if logging is enabled for this specific instance
Float getValidationSetAccuracy() const
bool getIsBaseTypeClusterer() const
virtual std::string getModelAsString() const
bool setMinNumEpochs(const UINT minNumEpochs)
bool loadBaseSettingsFromFile(std::fstream &file)
UINT getBatchSize() const
MLBase * getMLBasePointer()
bool setBatchSize(const UINT batchSize)
bool removeAllTrainingObservers()
UINT getNumInputFeatures() const
virtual bool train_(ClassificationData &trainingData)
bool notifyTestResultsObservers(const TestInstanceResult &data)
bool copyGRTBaseVariables(const GRTBase *GRTBase)
bool getIsBaseTypeClassifier() const
VectorFloat getValidationSetPrecision() const
DataType getInputType() const
bool removeTrainingResultsObserver(const Observer< TrainingResult > &observer)
UINT getNumInputDimensions() const
bool removeTestResultsObserver(const Observer< TestInstanceResult > &observer)
virtual bool map_(VectorFloat &inputVector)
UINT getNumTrainingIterationsToConverge() const
bool setUseValidationSet(const bool useValidationSet)
bool setTestingLoggingEnabled(const bool loggingEnabled)
virtual bool load(const std::string &filename)
bool setMaxNumEpochs(const UINT maxNumEpochs)
Vector< TrainingResult > getTrainingResults() const
virtual bool map(VectorFloat inputVector)
VectorFloat getValidationSetRecall() const
This is the main base class that all GRT machine learning algorithms should inherit from...
bool getIsBaseTypeRegressifier() const