GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
SVM.h
Go to the documentation of this file.
1 
26 #ifndef GRT_SVM_HEADER
27 #define GRT_SVM_HEADER
28 
29 #include "../../CoreModules/Classifier.h"
30 #include "LIBSVM/libsvm.h"
31 
32 GRT_BEGIN_NAMESPACE
33 
// Lower/upper bounds of the SVM feature-scaling range (presumably the target interval used
// when useScaling is enabled — confirm in SVM.cpp).
// The replacement lists are parenthesized so the negative literal always expands as a single
// primary expression, whatever operator surrounds the macro at the use site (CERT PRE02-C).
#define SVM_MIN_SCALE_RANGE (-1.0)
#define SVM_MAX_SCALE_RANGE (1.0)
36 
47 class GRT_API SVM : public Classifier{
48 public:
49  enum SVMType{ C_SVC = 0, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR };
50  enum KernelType{ LINEAR_KERNEL = 0, POLY_KERNEL, RBF_KERNEL, SIGMOID_KERNEL, PRECOMPUTED_KERNEL };
51 
70  SVM(KernelType kernelType = LINEAR_KERNEL,SVMType svmType = C_SVC,bool useScaling = true,bool useNullRejection = false,bool useAutoGamma = true,Float gamma = 0.1,UINT degree = 3,Float coef0 = 0,Float nu = 0.5,Float C = 1,bool useCrossValidation = false,UINT kFoldValue = 10);
71 
77  SVM(const SVM &rhs);
78 
82  virtual ~SVM();
83 
90  SVM &operator=(const SVM &rhs);
91 
99  virtual bool deepCopyFrom(const Classifier *classifier);
100 
108  virtual bool train_(ClassificationData &trainingData);
109 
117  virtual bool predict_(VectorFloat &inputVector);
118 
122  virtual bool clear();
123 
131  virtual bool save( std::fstream &file ) const;
132 
140  virtual bool load( std::fstream &file );
141 
158  bool init(KernelType kernelType,SVMType svmType,bool useScaling,bool useNullRejection,bool useAutoGamma,Float gamma,UINT degree,Float coef0,Float nu,Float C,bool useCrossValidation,UINT kFoldValue);
159 
163  void initDefaultSVMSettings();
164 
170  bool getIsCrossValidationTrainingEnabled() const;
171 
178  bool getIsAutoGammaEnabled() const;
179 
187  std::string getSVMType() const;
188 
196  std::string getKernelType() const;
197 
203  UINT getDegree() const;
204 
212  virtual UINT getNumClasses() const;
213 
219  Float getGamma() const;
220 
226  Float getNu() const;
227 
233  Float getCoef0() const;
234 
240  Float getC() const;
241 
247  Float getCrossValidationResult() const;
248 
254  const struct LIBSVM::svm_model *getLIBSVMModel() const;
255 
263  bool setSVMType(const SVMType svmType);
264 
272  bool setKernelType(const KernelType kernelType);
273 
280  bool setGamma(const Float gamma);
281 
289  bool setDegree(const UINT degree);
290 
298  bool setNu(const Float nu);
299 
307  bool setCoef0(const Float coef0);
308 
316  bool setC(const Float C);
317 
324  bool setKFoldCrossValidationValue(const UINT kFoldValue);
325 
332  bool enableAutoGamma(const bool useAutoGamma);
333 
340  bool enableCrossValidationTraining(const bool useCrossValidation);
341 
347  static std::string getId();
348 
349  //Tell the compiler we are using the following functions from the MLBase class to stop hidden virtual function warnings
350  using MLBase::save;
351  using MLBase::load;
352  using MLBase::train_;
353  using MLBase::predict_;
354 
355 protected:
356  void deleteProblemSet();
357  bool validateProblemAndParameters();
358  bool validateSVMType(SVMType svmType);
359  bool validateKernelType(KernelType kernelType);
360  bool convertClassificationDataToLIBSVMFormat(ClassificationData &trainingData);
361  bool trainSVM();
362 
363  bool predictSVM(VectorFloat &inputVector);
364  bool predictSVM(VectorFloat &inputVector,Float &maxProbability, VectorFloat &probabilites);
365  bool loadLegacyModelFromFile( std::fstream &file );
366 
367  struct LIBSVM::svm_model *deepCopyModel() const;
368  bool deepCopyProblem( const struct LIBSVM::svm_problem &source_problem, struct LIBSVM::svm_problem &target_problem, const unsigned int numInputDimensions ) const;
369  bool deepCopyParam( const LIBSVM::svm_parameter &source_param, LIBSVM::svm_parameter &target_param ) const;
370 
371  bool problemSet;
372  struct LIBSVM::svm_model *model;
373  struct LIBSVM::svm_parameter param;
374  struct LIBSVM::svm_problem prob;
375  UINT kFoldValue;
376  Float classificationThreshold;
377  Float crossValidationResult;
378  bool useAutoGamma;
379  bool useCrossValidation;
380 
381 private:
382  static RegisterClassifierModule< SVM > registerModule;
383  static const std::string id;
384 };
385 
386 GRT_END_NAMESPACE
387 
388 #endif //GRT_SVM_HEADER
389 
std::string getId() const
Definition: GRTBase.cpp:85
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:137
virtual UINT getNumClasses() const
Definition: Classifier.cpp:209
virtual bool save(const std::string &filename) const
Definition: MLBase.cpp:167
virtual bool deepCopyFrom(const Classifier *classifier)
Definition: Classifier.h:64
Definition: SVM.h:47
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:109
virtual bool load(const std::string &filename)
Definition: MLBase.cpp:190
virtual bool clear()
Definition: Classifier.cpp:151
This is the main base class that all GRT Classification algorithms should inherit from...
Definition: Classifier.h:41