GestureRecognitionToolkit Version: 0.1.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
SVM.h
Go to the documentation of this file.
36 #ifndef GRT_SVM_HEADER
37 #define GRT_SVM_HEADER
38 
39 #include "../../CoreModules/Classifier.h"
40 #include "LIBSVM/libsvm.h"
41 
42 GRT_BEGIN_NAMESPACE
43 
44 using namespace LIBSVM;
45 
// Bounds of the scaling range used by the SVM module; presumably the target
// interval for input scaling when useScaling is enabled — confirm in SVM.cpp.
// The replacement lists are parenthesized so the negative literal always
// expands as a single operand, whatever expression the macro lands in.
#define SVM_MIN_SCALE_RANGE (-1.0)
#define SVM_MAX_SCALE_RANGE (1.0)
48 
49 class SVM : public Classifier{
50 public:
69  SVM(UINT kernelType = LINEAR_KERNEL,UINT svmType = C_SVC,bool useScaling = true,bool useNullRejection = false,bool useAutoGamma = true,Float gamma = 0.1,UINT degree = 3,Float coef0 = 0,Float nu = 0.5,Float C = 1,bool useCrossValidation = false,UINT kFoldValue = 10);
70 
76  SVM(const SVM &rhs);
77 
81  virtual ~SVM();
82 
89  SVM &operator=(const SVM &rhs);
90 
98  virtual bool deepCopyFrom(const Classifier *classifier);
99 
107  virtual bool train_(ClassificationData &trainingData);
108 
116  virtual bool predict_(VectorFloat &inputVector);
117 
121  virtual bool clear();
122 
130  virtual bool saveModelToFile( std::fstream &file ) const;
131 
139  virtual bool loadModelFromFile( std::fstream &file );
140 
157  bool init(UINT kernelType,UINT svmType,bool useScaling,bool useNullRejection,bool useAutoGamma,Float gamma,UINT degree,Float coef0,Float nu,Float C,bool useCrossValidation,UINT kFoldValue);
158 
162  void initDefaultSVMSettings();
163 
169  bool getIsCrossValidationTrainingEnabled() const;
170 
177  bool getIsAutoGammaEnabled() const;
178 
186  std::string getSVMType() const;
187 
195  std::string getKernelType() const;
196 
202  UINT getDegree() const;
203 
211  virtual UINT getNumClasses() const;
212 
218  Float getGamma() const;
219 
225  Float getNu() const;
226 
232  Float getCoef0() const;
233 
239  Float getC() const;
240 
246  Float getCrossValidationResult() const;
247 
248  struct svm_model *getModel() const { return model; }
249 
257  bool setSVMType(const UINT svmType);
258 
266  bool setKernelType(const UINT kernelType);
267 
274  bool setGamma(const Float gamma);
275 
283  bool setDegree(const UINT degree);
284 
292  bool setNu(const Float nu);
293 
301  bool setCoef0(const Float coef0);
302 
310  bool setC(const Float C);
311 
318  bool setKFoldCrossValidationValue(const UINT kFoldValue);
319 
326  bool enableAutoGamma(const bool useAutoGamma);
327 
334  bool enableCrossValidationTraining(const bool useCrossValidation);
335 
336  //Tell the compiler we are using the following functions from the MLBase class to stop hidden virtual function warnings
339  using MLBase::train;
340  using MLBase::train_;
341  using MLBase::predict;
342  using MLBase::predict_;
343 
344 protected:
345  void deleteProblemSet();
346  bool validateProblemAndParameters();
347  bool validateSVMType(UINT svmType);
348  bool validateKernelType(UINT kernelType);
349  bool convertClassificationDataToLIBSVMFormat(ClassificationData &trainingData);
350  bool trainSVM();
351 
352  bool predictSVM(VectorFloat &inputVector);
353  bool predictSVM(VectorFloat &inputVector,Float &maxProbability, VectorFloat &probabilites);
354  bool loadLegacyModelFromFile( std::fstream &file );
355 
356  struct svm_model *deepCopyModel() const;
357  bool deepCopyProblem( const struct svm_problem &source_problem, struct svm_problem &target_problem, const unsigned int numInputDimensions ) const;
358  bool deepCopyParam( const svm_parameter &source_param, svm_parameter &target_param ) const;
359 
360  bool problemSet;
361  struct svm_model *model;
362  struct svm_parameter param;
363  struct svm_problem prob;
364  UINT kFoldValue;
365  Float classificationThreshold;
366  Float crossValidationResult;
367  bool useAutoGamma;
368  bool useCrossValidation;
369 
370  static RegisterClassifierModule< SVM > registerModule;
371 
372 public:
373  enum SVMTypes{ C_SVC = 0, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR };
374  enum SVMKernelTypes{ LINEAR_KERNEL = 0, POLY_KERNEL, RBF_KERNEL, SIGMOID_KERNEL, PRECOMPUTED_KERNEL };
375 
376 };
377 
378 GRT_END_NAMESPACE
379 
380 #endif //GRT_SVM_HEADER
381 
virtual bool predict(VectorFloat inputVector)
Definition: MLBase.cpp:112
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:114
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:88
class SVM
Definition: SVM.h:49
virtual bool saveModelToFile(std::string filename) const
Definition: MLBase.cpp:146
Definition: libsvm.cpp:4
virtual bool loadModelFromFile(std::string filename)
Definition: MLBase.cpp:168
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:90