// File header for the MLBase interface. The line below contains, in order:
//   - the GRT_MLBASE_HEADER include guard;
//   - the project data-structure headers the interface depends on;
//   - DEFAULT_NULL_LIKELIHOOD_VALUE and DEFAULT_NULL_DISTANCE_VALUE, both defined as 0
//     (presumably fallback values for null-rejection likelihood/distance; confirm at use sites);
//   - enum BaseType, the module-category tag. The MLBase constructor takes it as `type`
//     (default BASE_TYPE_NOT_SET) and getType() returns it. getIsBaseTypeClassifier(),
//     getIsBaseTypeRegressifier() and getIsBaseTypeClusterer() presumably test this tag.
// NOTE(review): the enumerator PRE_PROCSSING is misspelled. It is kept unchanged because renaming
//   an enumerator would break existing callers. If a correct spelling is wanted, add an alias
//   (PRE_PROCESSING = PRE_PROCSSING) instead of renaming.
27 #ifndef GRT_MLBASE_HEADER 28 #define GRT_MLBASE_HEADER 31 #include "../Util/Metrics.h" 32 #include "../DataStructures/UnlabelledData.h" 33 #include "../DataStructures/ClassificationData.h" 34 #include "../DataStructures/ClassificationDataStream.h" 35 #include "../DataStructures/RegressionData.h" 36 #include "../DataStructures/TimeSeriesClassificationData.h" 40 #define DEFAULT_NULL_LIKELIHOOD_VALUE 0 41 #define DEFAULT_NULL_DISTANCE_VALUE 0 75 enum BaseType{BASE_TYPE_NOT_SET=0,CLASSIFIER,REGRESSIFIER,CLUSTERER,PRE_PROCSSING,POST_PROCESSING,FEATURE_EXTRACTION,CONTEXT};
/**
 Constructs the base object.

 @param id: identifier string for this instance; defaults to the empty string.
            Presumably the value later returned by getId(). Confirm in the implementation.
 @param type: the BaseType category of the derived module; defaults to BASE_TYPE_NOT_SET.
 */
82 MLBase(
const std::string &
id =
"",
const BaseType type = BASE_TYPE_NOT_SET );
/**
 Presumably copies the MLBase-level variables from another instance into this one, as the name suggests.

 @param mlBase: pointer to the source instance. NOTE(review): null handling is not visible here.
                Confirm whether a nullptr argument is rejected.
 @return a bool status (presumably true on success). The implementation is not visible here.
 */
95 bool copyMLBaseVariables(
const MLBase *mlBase);
/**
 Virtual reset hook for the module.
 NOTE(review): how reset() differs from clear() is defined by the implementation, which is not visible here.
 @return a bool status (presumably true on success)
 */
276 virtual bool reset();
/**
 Virtual clear hook for the module. See the note on reset().
 @return a bool status (presumably true on success)
 */
284 virtual bool clear();
/**
 Virtual, const print hook. Presumably writes a description of the module to a log or console.
 @return a bool status (presumably true on success)
 */
292 virtual bool print()
const;
/**
 Saves the module to the file named by `filename`.
 The deprecated saveModelToFile(const std::string &) overload points callers here.
 @param filename: path of the file to write
 @return a bool status (presumably true on success)
 */
300 virtual bool save(
const std::string &filename)
const;
/**
 Loads the module from the file named by `filename`.
 The deprecated loadModelFromFile(const std::string &) overload points callers here.
 @param filename: path of the file to read
 @return a bool status (presumably true on success)
 */
308 virtual bool load(
const std::string &filename);
/**
 Stream overload of save(). The stream is taken by reference, so the caller keeps ownership of it.
 Presumably expects a stream already opened for writing. Confirm in the implementation.
 @param file: the stream to write to
 @return a bool status (presumably true on success)
 */
317 virtual bool save(std::fstream &file)
const;
/**
 Stream overload of load(). The stream is taken by reference, so the caller keeps ownership of it.
 Presumably expects a stream already opened for reading. Confirm in the implementation.
 @param file: the stream to read from
 @return a bool status (presumably true on success)
 */
326 virtual bool load(std::fstream &file);
/**
 Deprecated. Use save(const std::string &filename) instead.
 GRT_DEPRECATED_MSG is defined elsewhere. Presumably it attaches the message below as a compiler deprecation warning.
 */
333 GRT_DEPRECATED_MSG(
"saveModelToFile(std::string filename) is deprecated, use save(const std::string &filename) instead",
virtual bool saveModelToFile(
const std::string &filename)
const );
/**
 Deprecated. Use save(std::fstream &file) instead.
 */
340 GRT_DEPRECATED_MSG(
"saveModelToFile(std::fstream &file) is deprecated, use save(std::fstream &file) instead",
virtual bool saveModelToFile(std::fstream &file)
const );
/**
 Deprecated. Use load(const std::string &filename) instead.
 */
347 GRT_DEPRECATED_MSG(
"loadModelFromFile(std::string filename) is deprecated, use load(const std::string &filename) instead",
virtual bool loadModelFromFile(
const std::string &filename) );
/**
 Deprecated. Use load(std::fstream &file) instead.
 */
354 GRT_DEPRECATED_MSG(
"loadModelFromFile(std::fstream &file) is deprecated, use load(std::fstream &file) instead",
virtual bool loadModelFromFile(std::fstream &file) );
/**
 Virtual hook that writes a representation of the model to `stream`.
 The exact format is defined by the implementation, which is not visible here.
 @param stream: the output stream to write to (owned by the caller)
 @return a bool status (presumably true on success)
 */
363 virtual bool getModel(std::ostream &stream)
const;
/**
 Virtual hook that returns a representation of the model as a std::string.
 Presumably the same content getModel() writes to a stream. Confirm in the implementation.
 */
370 virtual std::string getModelAsString()
const;
/** @return the DataType of this module's input */
377 DataType getInputType()
const;
/** @return the DataType of this module's output */
384 DataType getOutputType()
const;
/** @return the BaseType category tag this module was constructed with */
391 BaseType getType()
const;
/**
 @return the number of input features.
 NOTE(review): this looks like a synonym for getNumInputDimensions(). Only a numInputDimensions
 member is visible in this header. Confirm whether the two can ever differ.
 */
399 UINT getNumInputFeatures()
const;
/** @return the number of input dimensions (presumably the numInputDimensions member) */
406 UINT getNumInputDimensions()
const;
/** @return the number of output dimensions (presumably the numOutputDimensions member) */
413 UINT getNumOutputDimensions()
const;
/** @return the minimum number of training epochs (see setMinNumEpochs) */
421 UINT getMinNumEpochs()
const;
/** @return the maximum number of training epochs (see setMaxNumEpochs) */
429 UINT getMaxNumEpochs()
const;
/** @return the training batch size (see setBatchSize) */
436 UINT getBatchSize()
const;
/** @return the number of training restarts (see setNumRestarts) */
442 UINT getNumRestarts()
const;
/**
 @return the validation set size (presumably the validationSetSize member).
 NOTE(review): whether this is a sample count or a percentage is not visible here. Confirm before use.
 */
451 UINT getValidationSetSize()
const;
/** @return the number of training iterations taken to converge (presumably the numTrainingIterationsToConverge member) */
458 UINT getNumTrainingIterationsToConverge()
const;
/** @return the minimum-change value (see setMinChange); presumably a convergence threshold. Confirm. */
465 Float getMinChange()
const;
/** @return the learning rate (see setLearningRate) */
472 Float getLearningRate()
const;
/** @return the root-mean-squared training error (presumably the rmsTrainingError member) */
479 Float getRMSTrainingError()
const;
/**
 Deprecated. Use getRMSTrainingError() instead.
 */
487 GRT_DEPRECATED_MSG(
"getRootMeanSquaredTrainingError() is deprecated, use getRMSTrainingError() instead", Float getRootMeanSquaredTrainingError()
const );
/** @return the total squared training error (presumably the totalSquaredTrainingError member) */
494 Float getTotalSquaredTrainingError()
const;
/** @return the root-mean-squared validation error (presumably the rmsValidationError member) */
501 Float getRMSValidationError()
const;
/** @return the validation set accuracy (presumably the validationSetAccuracy member; units not visible here) */
508 Float getValidationSetAccuracy()
const;
/** @return whether a validation set is used during training (presumably the useValidationSet member) */
533 bool getUseValidationSet()
const;
/** @return whether the training order is randomised (presumably the randomiseTrainingOrder member) */
541 bool getRandomiseTrainingOrder()
const;
/** @return whether the module has been trained */
548 bool getTrained()
const;
/**
 Deprecated. Use getTrained() instead.
 */
556 GRT_DEPRECATED_MSG(
"getModelTrained() is deprecated, use getTrained() instead",
bool getModelTrained()
const );
/** @return whether the last training run converged */
564 bool getConverged()
const;
/** @return whether scaling is enabled (see enableScaling) */
571 bool getScalingEnabled()
const;
/** @return presumably true when getType() is CLASSIFIER. Confirm in the implementation. */
578 bool getIsBaseTypeClassifier()
const;
/** @return presumably true when getType() is REGRESSIFIER. Confirm in the implementation. */
585 bool getIsBaseTypeRegressifier()
const;
/** @return presumably true when getType() is CLUSTERER. Confirm in the implementation. */
592 bool getIsBaseTypeClusterer()
const;
/** @return whether training logging is enabled (see setTrainingLoggingEnabled) */
599 bool getTrainingLoggingEnabled()
const;
/** @return whether testing logging is enabled (see setTestingLoggingEnabled) */
606 bool getTestingLoggingEnabled()
const;
/**
 Setters for the training configuration. Each returns a bool status.
 NOTE(review): validation rules, such as which values are rejected, are not visible in this header.
 Presumably false means the value was not accepted. Confirm in the implementation.
 */
/** @param useScaling: enable (true) or disable (false) scaling; read back with getScalingEnabled() */
613 bool enableScaling(
const bool useScaling);
/** @param maxNumEpochs: maximum number of training epochs; read back with getMaxNumEpochs() */
622 bool setMaxNumEpochs(
const UINT maxNumEpochs);
/** @param batchSize: training batch size; read back with getBatchSize() */
630 bool setBatchSize(
const UINT batchSize);
/** @param minNumEpochs: minimum number of training epochs; read back with getMinNumEpochs() */
638 bool setMinNumEpochs(
const UINT minNumEpochs);
/** @param numRestarts: number of training restarts; read back with getNumRestarts() */
647 bool setNumRestarts(
const UINT numRestarts);
/** @param minChange: minimum-change value; read back with getMinChange() */
656 bool setMinChange(
const Float minChange);
/** @param learningRate: learning rate; read back with getLearningRate() */
665 bool setLearningRate(
const Float learningRate);
/** @param useValidationSet: whether to use a validation set; read back with getUseValidationSet() */
675 bool setUseValidationSet(
const bool useValidationSet);
/** @param validationSetSize: validation set size (count vs. percentage not visible here); read back with getValidationSetSize() */
685 bool setValidationSetSize(
const UINT validationSetSize);
/** @param randomiseTrainingOrder: whether to randomise training order; read back with getRandomiseTrainingOrder() */
694 bool setRandomiseTrainingOrder(
const bool randomiseTrainingOrder);
/** @param loggingEnabled: enable or disable training logging; read back with getTrainingLoggingEnabled() */
703 bool setTrainingLoggingEnabled(
const bool loggingEnabled);
/** @param loggingEnabled: enable or disable testing logging; read back with getTestingLoggingEnabled() */
712 bool setTestingLoggingEnabled(
const bool loggingEnabled);
/**
 Removes all registered training observers.
 The registration API is not visible in this chunk.
 @return a bool status (presumably true on success)
 */
751 bool removeAllTrainingObservers();
/**
 Removes all registered test observers.
 @return a bool status (presumably true on success)
 */
758 bool removeAllTestObservers();
/**
 Passes `data` to the registered training-results observers.
 @param data: the TrainingResult to broadcast (taken by const reference)
 @return a bool status (presumably true on success)
 */
766 bool notifyTrainingResultsObservers(
const TrainingResult &data );
/**
 @return a non-owning pointer to this object as an MLBase (presumably `this`). Do not delete it.
 */
781 MLBase* getMLBasePointer();
/**
 Const overload of getMLBasePointer().
 @return a non-owning const pointer to this object as an MLBase (presumably `this`)
 */
788 const MLBase* getMLBasePointer()
const;
/**
 Writes the MLBase-level settings to `file`.
 Presumably derived classes call this from their own save() implementations. Confirm.
 @param file: stream to write to (owned by the caller)
 @return a bool status (presumably true on success)
 */
804 bool saveBaseSettingsToFile( std::fstream &file )
const;
/**
 Reads the MLBase-level settings back from `file`. Counterpart of saveBaseSettingsToFile().
 @param file: stream to read from (owned by the caller)
 @return a bool status (presumably true on success)
 */
811 bool loadBaseSettingsFromFile( std::fstream &file );
// Member state. Each member is presumably the storage behind the getter or setter of the same name.
// NOTE(review): the initial values are set outside this chunk. Confirm them in the constructor.
819 UINT numInputDimensions;
820 UINT numOutputDimensions;
821 UINT numTrainingIterationsToConverge;
825 UINT validationSetSize;
// Training / validation metrics, reported via getRMSTrainingError() and the related getters.
829 Float rmsTrainingError;
830 Float rmsValidationError;
831 Float totalSquaredTrainingError;
832 Float validationSetAccuracy;
// Training-configuration flags, set via setUseValidationSet() / setRandomiseTrainingOrder().
833 bool useValidationSet;
834 bool randomiseTrainingOrder;
848 #endif //GRT_MLBASE_HEADER
This file contains the Random class, a useful wrapper that provides cross-platform random-number generation functions...
/**
 Deprecated. Use getId() instead.
 GRT_DEPRECATED_MSG is defined elsewhere. Presumably it attaches the message as a compiler deprecation warning.
 */
GRT_DEPRECATED_MSG("getClassType is deprecated, use getId() instead!", std::string getClassType() const )
This file contains the GRTBase class. This is the core base class for all the GRT modules...
This is the main base class that all GRT machine learning algorithms should inherit from...