31 #ifndef GRT_MLBASE_HEADER
32 #define GRT_MLBASE_HEADER
35 #include "../DataStructures/UnlabelledData.h"
36 #include "../DataStructures/ClassificationData.h"
37 #include "../DataStructures/ClassificationDataStream.h"
38 #include "../DataStructures/RegressionData.h"
39 #include "../DataStructures/TimeSeriesClassificationData.h"
// Default "null" values; both are 0.
// NOTE(review): presumably the fallback null-rejection likelihood / distance used by
// derived modules before a real value is computed — confirm at the use sites.
#define DEFAULT_NULL_LIKELIHOOD_VALUE 0
#define DEFAULT_NULL_DISTANCE_VALUE 0
73 enum BaseTypes{BASE_TYPE_NOT_SET=0,CLASSIFIER,REGRESSIFIER,CLUSTERER};
91 bool copyMLBaseVariables(
const MLBase *mlBase);
253 virtual bool reset();
261 virtual bool clear();
269 virtual bool print()
const;
277 virtual bool save(
const std::string filename)
const;
285 virtual bool load(
const std::string filename);
294 virtual bool save(std::fstream &file)
const;
303 virtual bool load(std::fstream &file);
310 GRT_DEPRECATED_MSG(
"saveModelToFile(std::string filename) is deprecated, use save(std::string filename) instead",
virtual bool saveModelToFile(std::string filename)
const );
317 GRT_DEPRECATED_MSG(
"saveModelToFile(std::fstream &file) is deprecated, use save(std::fstream &file) instead",
virtual bool saveModelToFile(std::fstream &file)
const );
324 GRT_DEPRECATED_MSG(
"loadModelFromFile(std::string filename) is deprecated, use load(std::string filename) instead",
virtual bool loadModelFromFile(std::string filename) );
331 GRT_DEPRECATED_MSG(
"loadModelFromFile(std::fstream &file) is deprecated, use load(std::fstream &file) instead",
virtual bool loadModelFromFile(std::fstream &file) );
340 virtual bool getModel(std::ostream &stream)
const;
353 Float
inline scale(
const Float &x,
const Float &minSource,
const Float &maxSource,
const Float &minTarget,
const Float &maxTarget,
const bool constrain=
false){
355 if( x <= minSource )
return minTarget;
356 if( x >= maxSource )
return maxTarget;
358 if( minSource == maxSource )
return minTarget;
359 return (((x-minSource)*(maxTarget-minTarget))/(maxSource-minSource))+minTarget;
367 virtual std::string getModelAsString()
const;
374 DataType getInputType()
const;
381 DataType getOutputType()
const;
388 UINT getBaseType()
const;
396 UINT getNumInputFeatures()
const;
403 UINT getNumInputDimensions()
const;
410 UINT getNumOutputDimensions()
const;
418 UINT getMinNumEpochs()
const;
426 UINT getMaxNumEpochs()
const;
435 UINT getValidationSetSize()
const;
442 UINT getNumTrainingIterationsToConverge()
const;
449 Float getMinChange()
const;
456 Float getLearningRate()
const;
463 Float getRootMeanSquaredTrainingError()
const;
470 Float getTotalSquaredTrainingError()
const;
477 Float getValidationSetAccuracy()
const;
502 bool getUseValidationSet()
const;
510 bool getRandomiseTrainingOrder()
const;
517 bool getTrained()
const;
524 bool getModelTrained()
const;
531 bool getScalingEnabled()
const;
538 bool getIsBaseTypeClassifier()
const;
545 bool getIsBaseTypeRegressifier()
const;
552 bool getIsBaseTypeClusterer()
const;
559 bool enableScaling(
const bool useScaling);
568 bool setMaxNumEpochs(
const UINT maxNumEpochs);
576 bool setMinNumEpochs(
const UINT minNumEpochs);
585 bool setMinChange(
const Float minChange);
594 bool setLearningRate(
const Float learningRate);
604 bool setUseValidationSet(
const bool useValidationSet);
614 bool setValidationSetSize(
const UINT validationSetSize);
623 bool setRandomiseTrainingOrder(
const bool randomiseTrainingOrder);
632 bool setTrainingLoggingEnabled(
const bool loggingEnabled);
671 bool removeAllTrainingObservers();
678 bool removeAllTestObservers();
686 bool notifyTrainingResultsObservers(
const TrainingResult &data );
694 bool notifyTestResultsObservers(
const TestInstanceResult &data );
701 MLBase* getMLBasePointer();
708 const MLBase* getMLBasePointer()
const;
724 bool saveBaseSettingsToFile( std::fstream &file )
const;
731 bool loadBaseSettingsFromFile( std::fstream &file );
738 UINT numInputDimensions;
739 UINT numOutputDimensions;
740 UINT numTrainingIterationsToConverge;
743 UINT validationSetSize;
746 Float rootMeanSquaredTrainingError;
747 Float totalSquaredTrainingError;
748 Float validationSetAccuracy;
749 bool useValidationSet;
750 bool randomiseTrainingOrder;
754 std::vector< TrainingResult > trainingResults;
762 #endif //GRT_MLBASE_HEADER
Float scale(const Float &x, const Float &minSource, const Float &maxSource, const Float &minTarget, const Float &maxTarget, const bool constrain=false)
This file contains the MLBase class, the common base class for all GRT machine-learning modules (classifiers, regressifiers and clusterers)...