21 #define GRT_DLL_EXPORTS
29 baseType = BASE_TYPE_NOT_SET;
30 inputType = DATA_TYPE_UNKNOWN;
31 outputType = DATA_TYPE_UNKNOWN;
32 numInputDimensions = 0;
33 numOutputDimensions = 0;
36 validationSetSize = 20;
37 validationSetAccuracy = 0;
40 useValidationSet =
false;
41 randomiseTrainingOrder =
true;
42 rootMeanSquaredTrainingError = 0;
43 totalSquaredTrainingError = 0;
53 errorLog <<
"copyMLBaseVariables(MLBase *mlBase) - mlBase pointer is NULL!" << std::endl;
58 errorLog <<
"copyMLBaseVariables(MLBase *mlBase) - Failed to copy GRT Base variables!" << std::endl;
62 this->trained = mlBase->trained;
63 this->useScaling = mlBase->useScaling;
64 this->baseType = mlBase->baseType;
65 this->inputType = mlBase->inputType;
66 this->outputType = mlBase->outputType;
67 this->numInputDimensions = mlBase->numInputDimensions;
68 this->numOutputDimensions = mlBase->numOutputDimensions;
69 this->minNumEpochs = mlBase->minNumEpochs;
70 this->maxNumEpochs = mlBase->maxNumEpochs;
71 this->validationSetSize = mlBase->validationSetSize;
72 this->validationSetAccuracy = mlBase->validationSetAccuracy;
73 this->validationSetPrecision = mlBase->validationSetPrecision;
74 this->validationSetRecall = mlBase->validationSetRecall;
75 this->minChange = mlBase->minChange;
76 this->learningRate = mlBase->learningRate;
77 this->rootMeanSquaredTrainingError = mlBase->rootMeanSquaredTrainingError;
78 this->totalSquaredTrainingError = mlBase->totalSquaredTrainingError;
79 this->useValidationSet = mlBase->useValidationSet;
80 this->randomiseTrainingOrder = mlBase->randomiseTrainingOrder;
81 this->numTrainingIterationsToConverge = mlBase->numTrainingIterationsToConverge;
82 this->trainingResults = mlBase->trainingResults;
83 this->trainingResultsObserverManager = mlBase->trainingResultsObserverManager;
84 this->testResultsObserverManager = mlBase->testResultsObserverManager;
129 numInputDimensions = 0;
130 numOutputDimensions = 0;
131 numTrainingIterationsToConverge = 0;
132 rootMeanSquaredTrainingError = 0;
133 totalSquaredTrainingError = 0;
134 trainingResults.clear();
135 validationSetPrecision.clear();
136 validationSetRecall.clear();
137 validationSetAccuracy = 0;
145 if( !trained )
return false;
148 file.open(filename.c_str(), std::ios::out);
163 bool MLBase::saveModelToFile(std::string filename)
const {
return save( filename ); }
165 bool MLBase::saveModelToFile(std::fstream &file)
const {
return save( file ); }
170 file.open(filename.c_str(), std::ios::in);
186 bool MLBase::loadModelFromFile(std::string filename){
return load( filename ); }
188 bool MLBase::loadModelFromFile(std::fstream &file){
return load( file ); }
193 std::stringstream stream;
218 return numTrainingIterationsToConverge;
232 return validationSetSize;
240 return rootMeanSquaredTrainingError;
244 return totalSquaredTrainingError;
248 return validationSetAccuracy;
252 return validationSetPrecision;
256 return validationSetRecall;
274 if( maxNumEpochs == 0 ){
275 warningLog <<
"setMaxNumEpochs(const UINT maxNumEpochs) - The maxNumEpochs must be greater than 0!" << std::endl;
278 this->maxNumEpochs = maxNumEpochs;
283 this->minNumEpochs = minNumEpochs;
289 warningLog <<
"setMinChange(const Float minChange) - The minChange must be greater than or equal to 0!" << std::endl;
292 this->minChange = minChange;
297 if( learningRate > 0 ){
298 this->learningRate = learningRate;
306 if( validationSetSize > 0 && validationSetSize < 100 ){
307 this->validationSetSize = validationSetSize;
311 warningLog <<
"setValidationSetSize(const UINT validationSetSize) - The validation size must be in the range [1 99]!" << std::endl;
317 this->useValidationSet = useValidationSet;
322 this->randomiseTrainingOrder = randomiseTrainingOrder;
327 this->trainingLog.setEnableInstanceLogging( loggingEnabled );
332 return trainingResultsObserverManager.registerObserver( observer );
336 return testResultsObserverManager.registerObserver( observer );
340 return trainingResultsObserverManager.removeObserver( observer );
344 return testResultsObserverManager.removeObserver( observer );
348 return trainingResultsObserverManager.removeAllObservers();
352 return testResultsObserverManager.removeAllObservers();
356 return trainingResultsObserverManager.notifyObservers( data );
360 return testResultsObserverManager.notifyObservers( data );
372 return trainingResults;
377 if( !file.is_open() ){
378 errorLog <<
"saveBaseSettingsToFile(fstream &file) - The file is not open!" << std::endl;
382 file <<
"Trained: " << trained << std::endl;
383 file <<
"UseScaling: " << useScaling << std::endl;
384 file <<
"NumInputDimensions: " << numInputDimensions << std::endl;
385 file <<
"NumOutputDimensions: " << numOutputDimensions << std::endl;
386 file <<
"NumTrainingIterationsToConverge: " << numTrainingIterationsToConverge << std::endl;
387 file <<
"MinNumEpochs: " << minNumEpochs << std::endl;
388 file <<
"MaxNumEpochs: " << maxNumEpochs << std::endl;
389 file <<
"ValidationSetSize: " << validationSetSize << std::endl;
390 file <<
"LearningRate: " << learningRate << std::endl;
391 file <<
"MinChange: " << minChange << std::endl;
392 file <<
"UseValidationSet: " << useValidationSet << std::endl;
393 file <<
"RandomiseTrainingOrder: " << randomiseTrainingOrder << std::endl;
403 if( !file.is_open() ){
404 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - The file is not open!" << std::endl;
412 if( word !=
"Trained:" ){
413 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read Trained header!" << std::endl;
420 if( word !=
"UseScaling:" ){
421 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read UseScaling header!" << std::endl;
428 if( word !=
"NumInputDimensions:" ){
429 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
432 file >> numInputDimensions;
436 if( word !=
"NumOutputDimensions:" ){
437 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read NumOutputDimensions header!" << std::endl;
440 file >> numOutputDimensions;
444 if( word !=
"NumTrainingIterationsToConverge:" ){
445 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read NumTrainingIterationsToConverge header!" << std::endl;
448 file >> numTrainingIterationsToConverge;
452 if( word !=
"MinNumEpochs:" ){
453 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read MinNumEpochs header!" << std::endl;
456 file >> minNumEpochs;
460 if( word !=
"MaxNumEpochs:" ){
461 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read MaxNumEpochs header!" << std::endl;
464 file >> maxNumEpochs;
468 if( word !=
"ValidationSetSize:" ){
469 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read ValidationSetSize header!" << std::endl;
472 file >> validationSetSize;
476 if( word !=
"LearningRate:" ){
477 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read LearningRate header!" << std::endl;
480 file >> learningRate;
484 if( word !=
"MinChange:" ){
485 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read MinChange header!" << std::endl;
492 if( word !=
"UseValidationSet:" ){
493 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read UseValidationSet header!" << std::endl;
496 file >> useValidationSet;
500 if( word !=
"RandomiseTrainingOrder:" ){
501 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read RandomiseTrainingOrder header!" << std::endl;
504 file >> randomiseTrainingOrder;
bool saveBaseSettingsToFile(std::fstream &file) const
bool setLearningRate(const Float learningRate)
virtual bool predict(VectorFloat inputVector)
bool setRandomiseTrainingOrder(const bool randomiseTrainingOrder)
bool notifyTrainingResultsObservers(const TrainingResult &data)
bool registerTrainingResultsObserver(Observer< TrainingResult > &observer)
virtual bool predict_(VectorFloat &inputVector)
bool setTrainingLoggingEnabled(const bool loggingEnabled)
Float getLearningRate() const
bool removeAllTestObservers()
bool enableScaling(const bool useScaling)
DataType getOutputType() const
virtual bool getModel(std::ostream &stream) const
virtual bool train(ClassificationData trainingData)
Float getRootMeanSquaredTrainingError() const
UINT getMinNumEpochs() const
UINT getNumOutputDimensions() const
bool getScalingEnabled() const
bool registerTestResultsObserver(Observer< TestInstanceResult > &observer)
bool setMinChange(const Float minChange)
UINT getValidationSetSize() const
virtual bool save(const std::string filename) const
virtual bool load(const std::string filename)
UINT getMaxNumEpochs() const
This is the main base class that all GRT machine learning algorithms should inherit from...
bool getModelTrained() const
Float getTotalSquaredTrainingError() const
bool setValidationSetSize(const UINT validationSetSize)
bool copyMLBaseVariables(const MLBase *mlBase)
virtual bool print() const
Float getValidationSetAccuracy() const
bool getIsBaseTypeClusterer() const
virtual std::string getModelAsString() const
bool setMinNumEpochs(const UINT minNumEpochs)
bool loadBaseSettingsFromFile(std::fstream &file)
MLBase * getMLBasePointer()
bool removeAllTrainingObservers()
UINT getNumInputFeatures() const
virtual bool train_(ClassificationData &trainingData)
bool notifyTestResultsObservers(const TestInstanceResult &data)
bool copyGRTBaseVariables(const GRTBase *GRTBase)
bool getIsBaseTypeClassifier() const
VectorFloat getValidationSetPrecision() const
DataType getInputType() const
bool removeTrainingResultsObserver(const Observer< TrainingResult > &observer)
UINT getNumInputDimensions() const
bool removeTestResultsObserver(const Observer< TestInstanceResult > &observer)
virtual bool map_(VectorFloat &inputVector)
UINT getNumTrainingIterationsToConverge() const
bool setUseValidationSet(const bool useValidationSet)
bool setMaxNumEpochs(const UINT maxNumEpochs)
Vector< TrainingResult > getTrainingResults() const
virtual bool map(VectorFloat inputVector)
VectorFloat getValidationSetRecall() const
bool getIsBaseTypeRegressifier() const