GestureRecognitionToolkit Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
DTW.h
Go to the documentation of this file.
1 
26 #ifndef GRT_DTW_HEADER
27 #define GRT_DTW_HEADER
28 
29 #include "../../CoreModules/Classifier.h"
30 #include "../../Util/TimeSeriesClassificationSampleTrimmer.h"
31 
32 GRT_BEGIN_NAMESPACE
33 
34 class GRT_API IndexDist{
35  public:
36  IndexDist(int x=0,int y=0,Float dist=0){
37  this->x = x;
38  this->y = y;
39  this->dist = dist;
40  }
41  ~IndexDist(){};
42  IndexDist& operator=(const IndexDist &rhs){
43  if(this!=&rhs){
44  this->x = rhs.x;
45  this->y = rhs.y;
46  this->dist = rhs.dist;
47  }
48  return (*this);
49  }
50 
51  int x;
52  int y;
53  Float dist;
54 };
55 
57 class GRT_API DTWTemplate{
58  public:
59  DTWTemplate(){
60  classLabel = 0;
61  trainingMu = 0.0;
62  trainingSigma = 0.0;
63  averageTemplateLength=0;
64  }
65  ~DTWTemplate(){};
66 
67  UINT classLabel; //The class that this template belongs to
68  MatrixFloat timeSeries; //The raw time series
69  Float trainingMu; //The mean distance value of the training data with the trained template
70  Float trainingSigma; //The sigma of the distance value of the training data with the trained template
71  UINT averageTemplateLength; //The average length of the examples used to train this template
72 };
73 
91 class GRT_API DTW : public Classifier
92 {
93 public:
94 
95  enum DistanceMethods{ABSOLUTE_DIST=0,EUCLIDEAN_DIST,NORM_ABSOLUTE_DIST};
96  enum RejectionModes{TEMPLATE_THRESHOLDS=0,CLASS_LIKELIHOODS,THRESHOLDS_AND_LIKELIHOODS};
97 
112  DTW(bool useScaling=false,bool useNullRejection=false,Float nullRejectionCoeff=3.0,UINT rejectionMode = DTW::TEMPLATE_THRESHOLDS,bool dtwConstrain=true,Float radius=0.2,bool offsetUsingFirstSample=false,bool useSmoothing = false,UINT smoothingFactor = 5, Float nullRejectionLikelihoodThreshold = 0.99);
113 
121  DTW(const DTW &rhs);
122 
126  virtual ~DTW(void);
127 
134  DTW& operator=(const DTW &rhs);
135 
143  virtual bool deepCopyFrom(const Classifier *classifier);
144 
152  virtual bool train_(TimeSeriesClassificationData &trainingData);
153 
161  virtual bool predict_(VectorFloat &inputVector);
162 
170  virtual bool predict_(MatrixFloat &timeSeries);
171 
177  virtual bool reset();
178 
185  virtual bool clear();
186 
194  virtual bool save( std::fstream &file ) const;
195 
203  virtual bool load( std::fstream &file );
204 
212  virtual bool recomputeNullRejectionThresholds();
213 
219  UINT getNumTemplates() const { return numTemplates; }
220 
226  bool setRejectionMode(UINT rejectionMode);
227 
233  bool setNullRejectionThreshold(Float nullRejectionLikelihoodThreshold);
234 
243  bool setOffsetTimeseriesUsingFirstSample(bool offsetUsingFirstSample);
244 
251  bool setContrainWarpingPath(bool constrain);
252 
261  bool setWarpingRadius(Float radius);
262 
268  UINT getRejectionMode() const { return rejectionMode; }
269 
277  bool enableZNormalization(bool useZNormalization,bool constrainZNorm = true);
278 
295  bool enableTrimTrainingData(bool trimTrainingData,Float trimThreshold,Float maximumTrimPercentage);
296 
302  Vector< DTWTemplate > getModels() const { return templatesBuffer; }
303 
309  bool setModels( Vector< DTWTemplate > newTemplates );
310 
316  Vector< VectorFloat > getInputDataBuffer() const { return continuousInputDataBuffer.getData(); }
317 
323  const Vector< MatrixFloat >& getDistanceMatrices() const { return distanceMatrices; }
324 
330  const Vector< Vector< IndexDist > >& getWarpingPaths() const { return warpPaths; }
331 
337  static std::string getId();
338 
339  //Tell the compiler we are using the base class train method to stop hidden virtual function warnings
340  using MLBase::save;
341  using MLBase::load;
342  using MLBase::train_;
343  using MLBase::predict_;
344 
345 protected:
346  //Public training and prediction methods
347  bool train_NDDTW(TimeSeriesClassificationData &trainingData,DTWTemplate &dtwTemplate,UINT &bestIndex);
348 
349  //The actual DTW function
350  Float computeDistance(MatrixFloat &timeSeriesA,MatrixFloat &timeSeriesB,MatrixFloat &distanceMatrix,Vector< IndexDist > &warpPath);
351  Float d(int m,int n,MatrixFloat &distanceMatrix,const int M,const int N);
352  Float inline MIN_(Float a,Float b, Float c);
353 
354  //Scaling and Utility Functions
355  void scaleData(TimeSeriesClassificationData &trainingData);
356  void scaleData(MatrixFloat &data,MatrixFloat &scaledData);
357  void znormData(TimeSeriesClassificationData &trainingData);
358  void znormData(MatrixFloat &data,MatrixFloat &normData);
359  void smoothData(VectorFloat &data,UINT smoothFactor,VectorFloat &resultsData);
360  void smoothData(MatrixFloat &data,UINT smoothFactor,MatrixFloat &resultsData);
361  void offsetTimeseries(MatrixFloat &timeseries);
362  bool loadLegacyModelFromFile( std::fstream &file );
363 
364  Vector< DTWTemplate > templatesBuffer; //A buffer to store the templates for each time series
365  Vector< MatrixFloat > distanceMatrices;
366  Vector< Vector< IndexDist > > warpPaths;
367  CircularBuffer< VectorFloat > continuousInputDataBuffer;
368  UINT numTemplates; //The number of templates in our buffer
369  UINT rejectionMode; //The rejection mode used to reject null gestures during the prediction phase
370 
371  //Flags
372  bool useSmoothing; //A flag to check if we need to smooth the data
373  bool useZNormalisation; //A flag to check if we need to znorm the training and prediction data
374  bool offsetUsingFirstSample; //A flag to check if each timeseries should be offset by the first sample in the time series
375  bool constrainZNorm; //A flag to check if we need to constrain zNorm (only zNorm if stdDev > zNormConstrainThreshold)
376  bool constrainWarpingPath; //A flag to check if we need to constrain the dtw cost matrix and search
377  bool trimTrainingData; //A flag to check if we need to trim the training data first before training
378 
379  Float zNormConstrainThreshold;//The threshold value to be used if constrainZNorm is turned on
380  Float radius;
381  Float trimThreshold; //Sets the threshold under which training data should be trimmed (default 0.1)
382  Float maximumTrimPercentage; //Sets the maximum amount of data that can be trimmed for each training sample (default 20)
383  Float nullRejectionLikelihoodThreshold; //Sets the threshold for null rejection based on likelihoods
384 
385  UINT smoothingFactor; //The smoothing factor if smoothing is used
386  UINT distanceMethod; //The distance method to be used (should be of enum DISTANCE_METHOD)
387  UINT averageTemplateLength; //The overall average template length (over all the templates)
388 
389 private:
390  static RegisterClassifierModule< DTW > registerModule;
391  static const std::string id;
392 };
393 
394 GRT_END_NAMESPACE
395 
396 #endif //GRT_DTW_HEADER
397 
UINT getRejectionMode() const
Definition: DTW.h:268
std::string getId() const
Definition: GRTBase.cpp:85
UINT getNumTemplates() const
Definition: DTW.h:219
virtual bool recomputeNullRejectionThresholds()
Definition: Classifier.h:255
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:137
Vector< VectorFloat > getInputDataBuffer() const
Definition: DTW.h:316
Vector< DTWTemplate > getModels() const
Definition: DTW.h:302
virtual bool save(const std::string &filename) const
Definition: MLBase.cpp:167
const Vector< MatrixFloat > & getDistanceMatrices() const
Definition: DTW.h:323
virtual bool deepCopyFrom(const Classifier *classifier)
Definition: Classifier.h:64
Definition: DTW.h:91
virtual bool reset()
Definition: Classifier.cpp:132
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:109
T * getData()
Definition: Vector.h:208
Definition: DTW.h:34
const Vector< Vector< IndexDist > > & getWarpingPaths() const
Definition: DTW.h:330
virtual bool load(const std::string &filename)
Definition: MLBase.cpp:190
virtual bool clear()
Definition: Classifier.cpp:151
This is the main base class that all GRT Classification algorithms should inherit from...
Definition: Classifier.h:41