GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, C++ machine learning library for real-time gesture recognition.
SVM.h
Go to the documentation of this file.
1 
36 #ifndef GRT_SVM_HEADER
37 #define GRT_SVM_HEADER
38 
39 #include "../../CoreModules/Classifier.h"
40 #include "LIBSVM/libsvm.h"
41 
42 GRT_BEGIN_NAMESPACE
43 
// NOTE(review): this using-directive lives in a header, after GRT_BEGIN_NAMESPACE. Assuming that
// macro opens the GRT namespace, every translation unit that includes SVM.h gets the LIBSVM names
// visible unqualified inside that namespace. The SVM class below depends on it: svm_model,
// svm_parameter and svm_problem are all referenced without a LIBSVM:: qualifier.
44 using namespace LIBSVM;
45 
// Lower and upper bounds of the interval [-1.0, 1.0].
// NOTE(review): presumably the range that input data is rescaled into when scaling is enabled
// (see the useScaling constructor argument) - confirm against SVM.cpp, which is not visible here.
46 #define SVM_MIN_SCALE_RANGE -1.0
47 #define SVM_MAX_SCALE_RANGE 1.0
48 
// SVM: a Classifier subclass that holds LIBSVM data structures (svm_model, svm_parameter,
// svm_problem, from LIBSVM/libsvm.h) as members. Only declarations are visible in this header;
// any statement below that goes beyond what a signature shows is marked as an assumption to be
// confirmed against SVM.cpp.
49 class GRT_API SVM : public Classifier{
50 public:
    // SVM formulation identifiers, numbered consecutively from 0 in the order listed.
    // The names mirror LIBSVM's svm_type constants; presumably mapped onto svm_parameter - confirm.
51  enum SVMTypes{ C_SVC = 0, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR };
    // Kernel identifiers, numbered consecutively from 0 in the order listed.
    // The names mirror LIBSVM's kernel_type constants (with a _KERNEL suffix added).
52  enum SVMKernelTypes{ LINEAR_KERNEL = 0, POLY_KERNEL, RBF_KERNEL, SIGMOID_KERNEL, PRECOMPUTED_KERNEL };
53 
    // Constructor. Every parameter has a default, so this also acts as the default constructor.
    //   kernelType         - kernel id, expected to be an SVMKernelTypes value (default LINEAR_KERNEL)
    //   svmType            - formulation id, expected to be an SVMTypes value (default C_SVC)
    //   useScaling         - default true; presumably enables input scaling - confirm
    //   useNullRejection   - default false; presumably enables null rejection in the base class - confirm
    //   useAutoGamma       - default true; presumably lets training choose gamma itself - confirm
    //   gamma, degree, coef0, nu, C - defaults 0.1, 3, 0, 0.5, 1; names match LIBSVM svm_parameter fields
    //   useCrossValidation - default false; cross-validation toggle
    //   kFoldValue         - default 10; fold count used with cross validation
72  SVM(UINT kernelType = LINEAR_KERNEL,UINT svmType = C_SVC,bool useScaling = true,bool useNullRejection = false,bool useAutoGamma = true,Float gamma = 0.1,UINT degree = 3,Float coef0 = 0,Float nu = 0.5,Float C = 1,bool useCrossValidation = false,UINT kFoldValue = 10);
73 
    // Copy constructor. The class owns raw-pointer state (see `model` below); how it is duplicated
    // is defined in SVM.cpp (deepCopyModel/deepCopyProblem/deepCopyParam are declared further down).
79  SVM(const SVM &rhs);
80 
    // Virtual destructor.
84  virtual ~SVM();
85 
    // Copy assignment; returns a reference to an SVM (presumably *this).
92  SVM &operator=(const SVM &rhs);
93 
    // Copies from another classifier supplied through a base-class pointer.
    // Returns bool; presumably true on success - confirm.
101  virtual bool deepCopyFrom(const Classifier *classifier);
102 
    // Training entry point taking labelled ClassificationData by non-const reference.
    // Returns bool; presumably true when a model was trained - confirm.
110  virtual bool train_(ClassificationData &trainingData);
111 
    // Prediction entry point taking the input vector by non-const reference.
    // Returns bool; presumably true when prediction succeeded - confirm.
119  virtual bool predict_(VectorFloat &inputVector);
120 
    // Clears the classifier; exactly what is released is defined in SVM.cpp.
124  virtual bool clear();
125 
    // Writes this object to an already-constructed std::fstream (const: does not modify the object).
133  virtual bool save( std::fstream &file ) const;
134 
    // Reads this object from an already-constructed std::fstream.
142  virtual bool load( std::fstream &file );
143 
    // Takes the same twelve settings as the constructor, in the same order, but with no defaults.
    // Returns bool; presumably false when a setting is rejected - confirm.
160  bool init(UINT kernelType,UINT svmType,bool useScaling,bool useNullRejection,bool useAutoGamma,Float gamma,UINT degree,Float coef0,Float nu,Float C,bool useCrossValidation,UINT kFoldValue);
161 
    // Resets SVM settings to defaults; the actual default values live in SVM.cpp.
165  void initDefaultSVMSettings();
166 
    // Getter for the cross-validation toggle (presumably returns useCrossValidation).
172  bool getIsCrossValidationTrainingEnabled() const;
173 
    // Getter for the auto-gamma toggle (presumably returns useAutoGamma).
180  bool getIsAutoGammaEnabled() const;
181 
    // Returns the SVM type as a std::string (not as an SVMTypes value).
189  std::string getSVMType() const;
190 
    // Returns the kernel type as a std::string (not as an SVMKernelTypes value).
198  std::string getKernelType() const;
199 
    // Returns the degree setting.
205  UINT getDegree() const;
206 
    // Returns the number of classes; virtual, so presumably overriding a base-class getter.
214  virtual UINT getNumClasses() const;
215 
    // Returns the gamma setting.
221  Float getGamma() const;
222 
    // Returns the nu setting.
228  Float getNu() const;
229 
    // Returns the coef0 setting.
235  Float getCoef0() const;
236 
    // Returns the C setting.
242  Float getC() const;
243 
    // Returns the stored cross-validation result (presumably crossValidationResult).
249  Float getCrossValidationResult() const;
250 
    // Returns the internal `model` pointer as-is. Note this is a non-const pointer handed out by a
    // const method, and nothing here transfers ownership; whether it can be null is not visible here.
251  struct svm_model *getModel() const { return model; }
252 
    // Sets the SVM type; returns bool (presumably false if the value fails validateSVMType - confirm).
260  bool setSVMType(const UINT svmType);
261 
    // Sets the kernel type; returns bool (presumably false if the value fails validateKernelType - confirm).
269  bool setKernelType(const UINT kernelType);
270 
    // Sets gamma; returns bool.
277  bool setGamma(const Float gamma);
278 
    // Sets degree; returns bool.
286  bool setDegree(const UINT degree);
287 
    // Sets nu; returns bool.
295  bool setNu(const Float nu);
296 
    // Sets coef0; returns bool.
304  bool setCoef0(const Float coef0);
305 
    // Sets C; returns bool.
313  bool setC(const Float C);
314 
    // Sets the K used for K-fold cross validation; returns bool.
321  bool setKFoldCrossValidationValue(const UINT kFoldValue);
322 
    // Turns auto-gamma on or off; returns bool.
329  bool enableAutoGamma(const bool useAutoGamma);
330 
    // Turns cross-validation training on or off; returns bool.
337  bool enableCrossValidationTraining(const bool useCrossValidation);
338 
339  //Tell the compiler we are using the following functions from the MLBase class to stop hidden virtual function warnings
340  using MLBase::save;
341  using MLBase::load;
342  using MLBase::train_;
343  using MLBase::predict_;
344 
345 protected:
    // Internal helpers; all are defined in SVM.cpp. Descriptions below are read off the names and
    // signatures only.
346  void deleteProblemSet();
347  bool validateProblemAndParameters();
348  bool validateSVMType(UINT svmType);
349  bool validateKernelType(UINT kernelType);
350  bool convertClassificationDataToLIBSVMFormat(ClassificationData &trainingData);
351  bool trainSVM();
352 
    // Two prediction overloads; the second additionally reports a maximum probability and a
    // probability vector through reference out-parameters.
    // NOTE(review): the parameter name `probabilites` is a misspelling of "probabilities".
353  bool predictSVM(VectorFloat &inputVector);
354  bool predictSVM(VectorFloat &inputVector,Float &maxProbability, VectorFloat &probabilites);
355  bool loadLegacyModelFromFile( std::fstream &file );
356 
    // Deep-copy helpers for the three LIBSVM structures held by this class. deepCopyModel returns
    // a raw svm_model pointer; presumably the caller takes ownership - confirm.
357  struct svm_model *deepCopyModel() const;
358  bool deepCopyProblem( const struct svm_problem &source_problem, struct svm_problem &target_problem, const unsigned int numInputDimensions ) const;
359  bool deepCopyParam( const svm_parameter &source_param, svm_parameter &target_param ) const;
360 
    // State. `model` is a raw pointer; `param` and `prob` are held by value.
    // problemSet presumably records whether `prob` currently holds data - confirm.
361  bool problemSet;
362  struct svm_model *model;
363  struct svm_parameter param;
364  struct svm_problem prob;
365  UINT kFoldValue;
366  Float classificationThreshold;
367  Float crossValidationResult;
368  bool useAutoGamma;
369  bool useCrossValidation;
370 
    // Static registration object; presumably registers SVM with the Classifier factory during
    // static initialization - confirm against RegisterClassifierModule.
371  static RegisterClassifierModule< SVM > registerModule;
372 };
373 
374 GRT_END_NAMESPACE
375 
376 #endif //GRT_SVM_HEADER
377 
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:115
virtual bool save(const std::string filename) const
Definition: MLBase.cpp:143
virtual bool load(const std::string filename)
Definition: MLBase.cpp:167
Definition: SVM.h:49
Definition: libsvm.cpp:4
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:91