28 baseType = BASE_TYPE_NOT_SET;
29 inputType = DATA_TYPE_UNKNOWN;
30 outputType = DATA_TYPE_UNKNOWN;
31 numInputDimensions = 0;
32 numOutputDimensions = 0;
35 validationSetSize = 20;
36 validationSetAccuracy = 0;
39 useValidationSet =
false;
40 randomiseTrainingOrder =
true;
41 rootMeanSquaredTrainingError = 0;
42 totalSquaredTrainingError = 0;
52 errorLog <<
"copyMLBaseVariables(MLBase *mlBase) - mlBase pointer is NULL!" << std::endl;
57 errorLog <<
"copyMLBaseVariables(MLBase *mlBase) - Failed to copy GRT Base variables!" << std::endl;
61 this->trained = mlBase->trained;
62 this->useScaling = mlBase->useScaling;
63 this->baseType = mlBase->baseType;
64 this->inputType = mlBase->inputType;
65 this->outputType = mlBase->outputType;
66 this->numInputDimensions = mlBase->numInputDimensions;
67 this->numOutputDimensions = mlBase->numOutputDimensions;
68 this->minNumEpochs = mlBase->minNumEpochs;
69 this->maxNumEpochs = mlBase->maxNumEpochs;
70 this->validationSetSize = mlBase->validationSetSize;
71 this->validationSetAccuracy = mlBase->validationSetAccuracy;
72 this->validationSetPrecision = mlBase->validationSetPrecision;
73 this->validationSetRecall = mlBase->validationSetRecall;
74 this->minChange = mlBase->minChange;
75 this->learningRate = mlBase->learningRate;
76 this->rootMeanSquaredTrainingError = mlBase->rootMeanSquaredTrainingError;
77 this->totalSquaredTrainingError = mlBase->totalSquaredTrainingError;
78 this->useValidationSet = mlBase->useValidationSet;
79 this->randomiseTrainingOrder = mlBase->randomiseTrainingOrder;
80 this->numTrainingIterationsToConverge = mlBase->numTrainingIterationsToConverge;
81 this->trainingResults = mlBase->trainingResults;
82 this->trainingResultsObserverManager = mlBase->trainingResultsObserverManager;
83 this->testResultsObserverManager = mlBase->testResultsObserverManager;
128 numInputDimensions = 0;
129 numOutputDimensions = 0;
130 numTrainingIterationsToConverge = 0;
131 rootMeanSquaredTrainingError = 0;
132 totalSquaredTrainingError = 0;
133 trainingResults.clear();
134 validationSetPrecision.clear();
135 validationSetRecall.clear();
136 validationSetAccuracy = 0;
148 if( !trained )
return false;
151 file.open(filename.c_str(), std::ios::out);
171 file.open(filename.c_str(), std::ios::in);
188 std::stringstream stream;
213 return numTrainingIterationsToConverge;
227 return validationSetSize;
235 return rootMeanSquaredTrainingError;
239 return totalSquaredTrainingError;
243 return validationSetAccuracy;
247 return validationSetPrecision;
251 return validationSetRecall;
269 if( maxNumEpochs == 0 ){
270 warningLog <<
"setMaxNumEpochs(const UINT maxNumEpochs) - The maxNumEpochs must be greater than 0!" << std::endl;
273 this->maxNumEpochs = maxNumEpochs;
278 this->minNumEpochs = minNumEpochs;
284 warningLog <<
"setMinChange(const Float minChange) - The minChange must be greater than or equal to 0!" << std::endl;
287 this->minChange = minChange;
292 if( learningRate > 0 ){
293 this->learningRate = learningRate;
301 if( validationSetSize > 0 && validationSetSize < 100 ){
302 this->validationSetSize = validationSetSize;
306 warningLog <<
"setValidationSetSize(const UINT validationSetSize) - The validation size must be in the range [1 99]!" << std::endl;
312 this->useValidationSet = useValidationSet;
317 this->randomiseTrainingOrder = randomiseTrainingOrder;
322 this->trainingLog.setEnableInstanceLogging( loggingEnabled );
327 return trainingResultsObserverManager.registerObserver( observer );
331 return testResultsObserverManager.registerObserver( observer );
335 return trainingResultsObserverManager.removeObserver( observer );
339 return testResultsObserverManager.removeObserver( observer );
343 return trainingResultsObserverManager.removeAllObservers();
347 return testResultsObserverManager.removeAllObservers();
351 return trainingResultsObserverManager.notifyObservers( data );
355 return testResultsObserverManager.notifyObservers( data );
367 return trainingResults;
372 if( !file.is_open() ){
373 errorLog <<
"saveBaseSettingsToFile(fstream &file) - The file is not open!" << std::endl;
377 file <<
"Trained: " << trained << std::endl;
378 file <<
"UseScaling: " << useScaling << std::endl;
379 file <<
"NumInputDimensions: " << numInputDimensions << std::endl;
380 file <<
"NumOutputDimensions: " << numOutputDimensions << std::endl;
381 file <<
"NumTrainingIterationsToConverge: " << numTrainingIterationsToConverge << std::endl;
382 file <<
"MinNumEpochs: " << minNumEpochs << std::endl;
383 file <<
"MaxNumEpochs: " << maxNumEpochs << std::endl;
384 file <<
"ValidationSetSize: " << validationSetSize << std::endl;
385 file <<
"LearningRate: " << learningRate << std::endl;
386 file <<
"MinChange: " << minChange << std::endl;
387 file <<
"UseValidationSet: " << useValidationSet << std::endl;
388 file <<
"RandomiseTrainingOrder: " << randomiseTrainingOrder << std::endl;
398 if( !file.is_open() ){
399 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - The file is not open!" << std::endl;
407 if( word !=
"Trained:" ){
408 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read Trained header!" << std::endl;
415 if( word !=
"UseScaling:" ){
416 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read UseScaling header!" << std::endl;
423 if( word !=
"NumInputDimensions:" ){
424 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
427 file >> numInputDimensions;
431 if( word !=
"NumOutputDimensions:" ){
432 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read NumOutputDimensions header!" << std::endl;
435 file >> numOutputDimensions;
439 if( word !=
"NumTrainingIterationsToConverge:" ){
440 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read NumTrainingIterationsToConverge header!" << std::endl;
443 file >> numTrainingIterationsToConverge;
447 if( word !=
"MinNumEpochs:" ){
448 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read MinNumEpochs header!" << std::endl;
451 file >> minNumEpochs;
455 if( word !=
"MaxNumEpochs:" ){
456 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read MaxNumEpochs header!" << std::endl;
459 file >> maxNumEpochs;
463 if( word !=
"ValidationSetSize:" ){
464 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read ValidationSetSize header!" << std::endl;
467 file >> validationSetSize;
471 if( word !=
"LearningRate:" ){
472 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read LearningRate header!" << std::endl;
475 file >> learningRate;
479 if( word !=
"MinChange:" ){
480 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read MinChange header!" << std::endl;
487 if( word !=
"UseValidationSet:" ){
488 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read UseValidationSet header!" << std::endl;
491 file >> useValidationSet;
495 if( word !=
"RandomiseTrainingOrder:" ){
496 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read RandomiseTrainingOrder header!" << std::endl;
499 file >> randomiseTrainingOrder;
bool saveBaseSettingsToFile(std::fstream &file) const
bool setLearningRate(const Float learningRate)
virtual bool predict(VectorFloat inputVector)
bool setRandomiseTrainingOrder(const bool randomiseTrainingOrder)
bool notifyTrainingResultsObservers(const TrainingResult &data)
bool registerTrainingResultsObserver(Observer< TrainingResult > &observer)
virtual bool predict_(VectorFloat &inputVector)
bool setTrainingLoggingEnabled(const bool loggingEnabled)
Float getLearningRate() const
bool removeAllTestObservers()
bool enableScaling(const bool useScaling)
DataType getOutputType() const
virtual bool getModel(std::ostream &stream) const
virtual bool train(ClassificationData trainingData)
Float getRootMeanSquaredTrainingError() const
UINT getMinNumEpochs() const
UINT getNumOutputDimensions() const
bool getScalingEnabled() const
bool registerTestResultsObserver(Observer< TestInstanceResult > &observer)
bool setMinChange(const Float minChange)
UINT getValidationSetSize() const
virtual bool save(const std::string filename) const
virtual bool load(const std::string filename)
UINT getMaxNumEpochs() const
This is the main base class that all GRT machine learning algorithms should inherit from...
bool getModelTrained() const
Float getTotalSquaredTrainingError() const
bool setValidationSetSize(const UINT validationSetSize)
virtual bool saveModelToFile(std::string filename) const
bool copyMLBaseVariables(const MLBase *mlBase)
virtual bool print() const
Float getValidationSetAccuracy() const
bool getIsBaseTypeClusterer() const
virtual std::string getModelAsString() const
virtual bool loadModelFromFile(std::string filename)
bool setMinNumEpochs(const UINT minNumEpochs)
bool loadBaseSettingsFromFile(std::fstream &file)
MLBase * getMLBasePointer()
bool removeAllTrainingObservers()
UINT getNumInputFeatures() const
virtual bool train_(ClassificationData &trainingData)
bool notifyTestResultsObservers(const TestInstanceResult &data)
bool copyGRTBaseVariables(const GRTBase *GRTBase)
bool getIsBaseTypeClassifier() const
VectorFloat getValidationSetPrecision() const
DataType getInputType() const
bool removeTrainingResultsObserver(const Observer< TrainingResult > &observer)
UINT getNumInputDimensions() const
bool removeTestResultsObserver(const Observer< TestInstanceResult > &observer)
virtual bool map_(VectorFloat &inputVector)
UINT getNumTrainingIterationsToConverge() const
bool setUseValidationSet(const bool useValidationSet)
bool setMaxNumEpochs(const UINT maxNumEpochs)
Vector< TrainingResult > getTrainingResults() const
virtual bool map(VectorFloat inputVector)
VectorFloat getValidationSetRecall() const
bool getIsBaseTypeRegressifier() const