36 #ifndef GRT_SVM_HEADER
37 #define GRT_SVM_HEADER
39 #include "../../CoreModules/Classifier.h"
46 #define SVM_MIN_SCALE_RANGE -1.0
47 #define SVM_MAX_SCALE_RANGE 1.0
/**
 @brief Constructs an SVM classifier. Every parameter has a default, so `SVM()` is valid.

 Defaults declared here: kernelType = LINEAR_KERNEL, svmType = C_SVC, useScaling = true,
 useNullRejection = false, useAutoGamma = true, gamma = 0.1, degree = 3, coef0 = 0,
 nu = 0.5, C = 1, useCrossValidation = false, kFoldValue = 10.
 The parameter list matches init() one-to-one; presumably the constructor forwards to it — confirm in the .cpp.

 @param kernelType one of the SVMKernelTypes enumerators
 @param svmType one of the SVMTypes enumerators
 @param useScaling whether input scaling is enabled (NOTE(review): presumably into the
        SVM_MIN_SCALE_RANGE..SVM_MAX_SCALE_RANGE range defined in this header — confirm)
 @param useNullRejection whether null rejection is enabled
 @param useAutoGamma whether gamma is picked automatically rather than taken from `gamma` (assumption — confirm)
 @param gamma kernel hyper-parameter; presumably passed to libsvm's parameter of the same name — confirm
 @param degree kernel hyper-parameter; presumably the polynomial degree in libsvm — confirm
 @param coef0 kernel hyper-parameter; presumably libsvm's coef0 — confirm
 @param nu solver hyper-parameter; presumably used by the NU_* SVM types — confirm
 @param C solver hyper-parameter; presumably the C_SVC cost — confirm
 @param useCrossValidation whether training runs k-fold cross validation
 @param kFoldValue number of folds used when cross validation is enabled (assumption — confirm)
*/
69 SVM(UINT kernelType = LINEAR_KERNEL,UINT svmType = C_SVC,
bool useScaling =
true,
bool useNullRejection =
false,
bool useAutoGamma =
true,Float gamma = 0.1,UINT degree = 3,Float coef0 = 0,Float nu = 0.5,Float C = 1,
bool useCrossValidation =
false,UINT kFoldValue = 10);
/**
 @brief Copy-assigns another SVM into this one.
 @param rhs the SVM to copy from
 @return a reference to this object
*/
89 SVM &operator=(
const SVM &rhs);
/**
 @brief Copies the state of another classifier into this one through the base-class pointer.
 @param classifier the source classifier; NOTE(review): presumably must actually be an SVM
        (and non-null) or the call returns false — confirm in the .cpp
 @return true on success (assumption — confirm)
*/
98 virtual bool deepCopyFrom(
const Classifier *classifier);
/**
 @brief Resets the classifier.
 @return true on success (assumption — confirm). What is released or kept is not visible here.
*/
121 virtual bool clear();
/**
 @brief Writes the model to an already-open stream. Const: does not modify this object.
 @param file destination stream
 @return true on success (assumption — confirm)
*/
130 virtual bool saveModelToFile( std::fstream &file )
const;
/**
 @brief Reads a model from an already-open stream.
 @param file source stream
 @return true on success (assumption — confirm). NOTE(review): a separate
         loadLegacyModelFromFile(std::fstream&) is declared in this class; presumably this
         delegates to it for older file formats — confirm.
*/
139 virtual bool loadModelFromFile( std::fstream &file );
/**
 @brief (Re)configures all SVM settings in one call.
 Takes the same parameters, in the same order, as the constructor, but with no defaults.
 @return true if the settings were accepted (assumption — presumably validated by
         validateSVMType()/validateKernelType(); confirm in the .cpp)
*/
157 bool init(UINT kernelType,UINT svmType,
bool useScaling,
bool useNullRejection,
bool useAutoGamma,Float gamma,UINT degree,Float coef0,Float nu,Float C,
bool useCrossValidation,UINT kFoldValue);
/**
 @brief Initializes the default SVM settings. Which values it sets is not visible in this header.
*/
162 void initDefaultSVMSettings();
/// @return whether cross-validation training is enabled (presumably the useCrossValidation member — confirm)
169 bool getIsCrossValidationTrainingEnabled()
const;
/// @return whether automatic gamma selection is enabled
177 bool getIsAutoGammaEnabled()
const;
/// @return the current SVM type as a string. Note the asymmetry: setSVMType() takes a UINT (an SVMTypes value).
186 std::string getSVMType()
const;
/// @return the current kernel type as a string. Note the asymmetry: setKernelType() takes a UINT (an SVMKernelTypes value).
195 std::string getKernelType()
const;
/// @return the current degree setting
202 UINT getDegree()
const;
/// @return the number of classes. Virtual; presumably overrides the Classifier base — confirm.
211 virtual UINT getNumClasses()
const;
/// @return the current gamma setting
218 Float getGamma()
const;
/// @return the current coef0 setting
232 Float getCoef0()
const;
/// @return the cross-validation result (presumably the crossValidationResult member, only meaningful after
///         training with cross validation enabled — confirm)
246 Float getCrossValidationResult()
const;
248 struct svm_model *getModel()
const {
return model; }
/**
 Setters. Each returns bool; presumably false means the value was rejected and the
 previous setting is kept (assumption — confirm in the .cpp).
*/
/// @param svmType one of the SVMTypes enumerators (presumably checked by validateSVMType — confirm)
257 bool setSVMType(
const UINT svmType);
/// @param kernelType one of the SVMKernelTypes enumerators (presumably checked by validateKernelType — confirm)
266 bool setKernelType(
const UINT kernelType);
/// @param gamma new gamma value. NOTE(review): may interact with enableAutoGamma() — confirm.
274 bool setGamma(
const Float gamma);
/// @param degree new degree value
283 bool setDegree(
const UINT degree);
/// @param nu new nu value
292 bool setNu(
const Float nu);
/// @param coef0 new coef0 value
301 bool setCoef0(
const Float coef0);
/// @param C new C value
310 bool setC(
const Float C);
/// @param kFoldValue number of folds to use for cross-validation training
318 bool setKFoldCrossValidationValue(
const UINT kFoldValue);
/// @param useAutoGamma true to enable automatic gamma selection, false to disable it
326 bool enableAutoGamma(
const bool useAutoGamma);
/// @param useCrossValidation true to enable cross-validation training, false to disable it
334 bool enableCrossValidationTraining(
const bool useCrossValidation);
/**
 Internal helpers (the access specifier is not visible in this extract).
*/
/// Releases the stored training problem set (assumption from the name — confirm in the .cpp).
345 void deleteProblemSet();
/// @return true if the current problem and parameters are valid (assumption — confirm).
346 bool validateProblemAndParameters();
/// @return whether svmType is an acceptable SVMTypes value. NOTE(review): takes only a UINT yet is
///         not const; it could likely be const or static if it reads no members — confirm.
347 bool validateSVMType(UINT svmType);
/// @return whether kernelType is an acceptable SVMKernelTypes value. NOTE(review): same const/static remark as above.
348 bool validateKernelType(UINT kernelType);
/// Reads a model stored in an older file format from an already-open stream (assumption from the name — confirm).
354 bool loadLegacyModelFromFile( std::fstream &file );
/**
 @brief Copies source_problem into target_problem. Const: does not modify this object.
 @param source_problem problem to copy from
 @param target_problem problem to copy into
 @param numInputDimensions presumably the feature count per sample used to size the copy — confirm
 @return true on success (assumption — confirm)
*/
357 bool deepCopyProblem(
const struct svm_problem &source_problem,
struct svm_problem &target_problem,
const unsigned int numInputDimensions )
const;
/// Threshold used when classifying; presumably tied to null rejection — confirm in the .cpp.
365 Float classificationThreshold;
/// Stores the cross-validation result (presumably what getCrossValidationResult() returns — confirm).
366 Float crossValidationResult;
/// Whether cross-validation training is enabled; set via the constructor/init() parameter of the same name
/// (presumably also by enableCrossValidationTraining() — confirm).
368 bool useCrossValidation;
/**
 SVM formulations selectable via the svmType parameter (C_SVC is the constructor default).
 Every value is spelled out so the numbering is visible without counting.
 NOTE(review): presumably mirrors libsvm's svm_type constants; whether the regression
 types are accepted for this classifier is decided by validateSVMType() — confirm.
*/
enum SVMTypes
{
    C_SVC       = 0,
    NU_SVC      = 1,
    ONE_CLASS   = 2,
    EPSILON_SVR = 3,
    NU_SVR      = 4
};
/**
 Kernel functions selectable via the kernelType parameter (LINEAR_KERNEL is the constructor default).
 Every value is spelled out so the numbering is visible without counting.
 NOTE(review): presumably mirrors libsvm's kernel_type constants — confirm.
*/
enum SVMKernelTypes
{
    LINEAR_KERNEL      = 0,
    POLY_KERNEL        = 1,
    RBF_KERNEL         = 2,
    SIGMOID_KERNEL     = 3,
    PRECOMPUTED_KERNEL = 4
};
380 #endif //GRT_SVM_HEADER
virtual bool predict(VectorFloat inputVector)
virtual bool predict_(VectorFloat &inputVector)
virtual bool train(ClassificationData trainingData)
virtual bool saveModelToFile(std::string filename) const
virtual bool loadModelFromFile(std::string filename)
virtual bool train_(ClassificationData &trainingData)