26 #ifndef GRT_SVM_HEADER 27 #define GRT_SVM_HEADER 29 #include "../../CoreModules/Classifier.h" 34 #define SVM_MIN_SCALE_RANGE -1.0 35 #define SVM_MAX_SCALE_RANGE 1.0 49 enum SVMType{ C_SVC = 0, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR };
50 enum KernelType{ LINEAR_KERNEL = 0, POLY_KERNEL, RBF_KERNEL, SIGMOID_KERNEL, PRECOMPUTED_KERNEL };
70 SVM(KernelType kernelType = LINEAR_KERNEL,SVMType svmType = C_SVC,
bool useScaling =
true,
bool useNullRejection =
false,
bool useAutoGamma =
true,Float gamma = 0.1,UINT degree = 3,Float coef0 = 0,Float nu = 0.5,Float C = 1,
bool useCrossValidation =
false,UINT kFoldValue = 10);
90 SVM &operator=(
const SVM &rhs);
122 virtual bool clear();
131 virtual bool save( std::fstream &file )
const;
140 virtual bool load( std::fstream &file );
158 bool init(KernelType kernelType,SVMType svmType,
bool useScaling,
bool useNullRejection,
bool useAutoGamma,Float gamma,UINT degree,Float coef0,Float nu,Float C,
bool useCrossValidation,UINT kFoldValue);
163 void initDefaultSVMSettings();
170 bool getIsCrossValidationTrainingEnabled()
const;
178 bool getIsAutoGammaEnabled()
const;
187 std::string getSVMType()
const;
196 std::string getKernelType()
const;
203 UINT getDegree()
const;
219 Float getGamma()
const;
233 Float getCoef0()
const;
247 Float getCrossValidationResult()
const;
263 bool setSVMType(
const SVMType svmType);
272 bool setKernelType(
const KernelType kernelType);
280 bool setGamma(
const Float gamma);
289 bool setDegree(
const UINT degree);
298 bool setNu(
const Float nu);
307 bool setCoef0(
const Float coef0);
316 bool setC(
const Float C);
324 bool setKFoldCrossValidationValue(
const UINT kFoldValue);
332 bool enableAutoGamma(
const bool useAutoGamma);
340 bool enableCrossValidationTraining(
const bool useCrossValidation);
347 static std::string
getId();
356 void deleteProblemSet();
357 bool validateProblemAndParameters();
358 bool validateSVMType(SVMType svmType);
359 bool validateKernelType(KernelType kernelType);
365 bool loadLegacyModelFromFile( std::fstream &file );
376 Float classificationThreshold;
377 Float crossValidationResult;
379 bool useCrossValidation;
383 static const std::string id;
388 #endif //GRT_SVM_HEADER std::string getId() const
virtual bool predict_(VectorFloat &inputVector)
virtual UINT getNumClasses() const
virtual bool save(const std::string &filename) const
virtual bool deepCopyFrom(const Classifier *classifier)
virtual bool train_(ClassificationData &trainingData)
virtual bool load(const std::string &filename)
This is the main base class that all GRT Classification algorithms should inherit from...