GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
DTW.h
Go to the documentation of this file.
1 
43 #ifndef GRT_DTW_HEADER
44 #define GRT_DTW_HEADER
45 
46 #include "../../CoreModules/Classifier.h"
47 #include "../../Util/TimeSeriesClassificationSampleTrimmer.h"
48 
49 GRT_BEGIN_NAMESPACE
50 
51 class GRT_API IndexDist{
52  public:
53  IndexDist(int x=0,int y=0,Float dist=0){
54  this->x = x;
55  this->y = y;
56  this->dist = dist;
57  }
58  ~IndexDist(){};
59  IndexDist& operator=(const IndexDist &rhs){
60  if(this!=&rhs){
61  this->x = rhs.x;
62  this->y = rhs.y;
63  this->dist = rhs.dist;
64  }
65  return (*this);
66  }
67 
68  int x;
69  int y;
70  Float dist;
71 };
72 
74 class GRT_API DTWTemplate{
75  public:
76  DTWTemplate(){
77  classLabel = 0;
78  trainingMu = 0.0;
79  trainingSigma = 0.0;
80  averageTemplateLength=0;
81  }
82  ~DTWTemplate(){};
83 
84  UINT classLabel; //The class that this template belongs to
85  MatrixFloat timeSeries; //The raw time series
86  Float trainingMu; //The mean distance value of the training data with the trained template
87  Float trainingSigma; //The sigma of the distance value of the training data with the trained template
88  UINT averageTemplateLength; //The average length of the examples used to train this template
89 };
90 
91 class GRT_API DTW : public Classifier
92 {
93 public:
94 
109  DTW(bool useScaling=false,bool useNullRejection=false,Float nullRejectionCoeff=3.0,UINT rejectionMode = DTW::TEMPLATE_THRESHOLDS,bool dtwConstrain=true,Float radius=0.2,bool offsetUsingFirstSample=false,bool useSmoothing = false,UINT smoothingFactor = 5, Float nullRejectionLikelihoodThreshold = 0.99);
110 
118  DTW(const DTW &rhs);
119 
123  virtual ~DTW(void);
124 
131  DTW& operator=(const DTW &rhs);
132 
140  virtual bool deepCopyFrom(const Classifier *classifier);
141 
149  virtual bool train_(TimeSeriesClassificationData &trainingData);
150 
158  virtual bool predict_(VectorFloat &inputVector);
159 
167  virtual bool predict_(MatrixFloat &timeSeries);
168 
174  virtual bool reset();
175 
182  virtual bool clear();
183 
191  virtual bool save( std::fstream &file ) const;
192 
200  virtual bool load( std::fstream &file );
201 
209  virtual bool recomputeNullRejectionThresholds();
210 
216  UINT getNumTemplates() const { return numTemplates; }
217 
223  bool setRejectionMode(UINT rejectionMode);
224 
230  bool setNullRejectionThreshold(Float nullRejectionLikelihoodThreshold);
231 
240  bool setOffsetTimeseriesUsingFirstSample(bool offsetUsingFirstSample);
241 
248  bool setContrainWarpingPath(bool constrain);
249 
258  bool setWarpingRadius(Float radius);
259 
265  UINT getRejectionMode() const { return rejectionMode; }
266 
274  bool enableZNormalization(bool useZNormalization,bool constrainZNorm = true);
275 
292  bool enableTrimTrainingData(bool trimTrainingData,Float trimThreshold,Float maximumTrimPercentage);
293 
299  Vector< DTWTemplate > getModels() const { return templatesBuffer; }
300 
306  bool setModels( Vector< DTWTemplate > newTemplates );
307 
313  Vector< VectorFloat > getInputDataBuffer() const { return continuousInputDataBuffer.getData(); }
314 
320  const Vector< MatrixFloat >& getDistanceMatrices() const { return distanceMatrices; }
321 
327  const Vector< Vector< IndexDist > >& getWarpingPaths() const { return warpPaths; }
328 
329  //Tell the compiler we are using the base class train method to stop hidden virtual function warnings
330  using MLBase::save;
331  using MLBase::load;
332  using MLBase::train_;
333  using MLBase::predict_;
334 
335 private:
336  //Public training and prediction methods
337  bool train_NDDTW(TimeSeriesClassificationData &trainingData,DTWTemplate &dtwTemplate,UINT &bestIndex);
338 
339  //The actual DTW function
340  Float computeDistance(MatrixFloat &timeSeriesA,MatrixFloat &timeSeriesB,MatrixFloat &distanceMatrix,Vector< IndexDist > &warpPath);
341  Float d(int m,int n,MatrixFloat &distanceMatrix,const int M,const int N);
342  Float inline MIN_(Float a,Float b, Float c);
343 
344  //Private Scaling and Utility Functions
345  void scaleData(TimeSeriesClassificationData &trainingData);
346  void scaleData(MatrixFloat &data,MatrixFloat &scaledData);
347  void znormData(TimeSeriesClassificationData &trainingData);
348  void znormData(MatrixFloat &data,MatrixFloat &normData);
349  void smoothData(VectorFloat &data,UINT smoothFactor,VectorFloat &resultsData);
350  void smoothData(MatrixFloat &data,UINT smoothFactor,MatrixFloat &resultsData);
351  void offsetTimeseries(MatrixFloat &timeseries);
352 
353  static RegisterClassifierModule< DTW > registerModule;
354 
355 protected:
356  bool loadLegacyModelFromFile( std::fstream &file );
357 
358  Vector< DTWTemplate > templatesBuffer; //A buffer to store the templates for each time series
359  Vector< MatrixFloat > distanceMatrices;
360  Vector< Vector< IndexDist > > warpPaths;
361  CircularBuffer< VectorFloat > continuousInputDataBuffer;
362  UINT numTemplates; //The number of templates in our buffer
363  UINT rejectionMode; //The rejection mode used to reject null gestures during the prediction phase
364 
365  //Flags
366  bool useSmoothing; //A flag to check if we need to smooth the data
367  bool useZNormalisation; //A flag to check if we need to znorm the training and prediction data
368  bool offsetUsingFirstSample; //A flag to check if each timeseries should be offset by the first sample in the time series
369  bool constrainZNorm; //A flag to check if we need to constrain zNorm (only zNorm if stdDev > zNormConstrainThreshold)
370  bool constrainWarpingPath; //A flag to check if we need to constrain the dtw cost matrix and search
371  bool trimTrainingData; //A flag to check if we need to trim the training data first before training
372 
373  Float zNormConstrainThreshold;//The threshold value to be used if constrainZNorm is turned on
374  Float radius;
375  Float trimThreshold; //Sets the threshold under which training data should be trimmed (default 0.1)
376  Float maximumTrimPercentage; //Sets the maximum amount of data that can be trimmed for each training sample (default 20)
377  Float nullRejectionLikelihoodThreshold; //Sets the threshold for null rejection based on likelihoods
378 
379  UINT smoothingFactor; //The smoothing factor if smoothing is used
380  UINT distanceMethod; //The distance method to be used (should be of enum DISTANCE_METHOD)
381  UINT averageTemplateLength; //The overall average template length (over all the templates)
382 
383 public:
384  enum DistanceMethods{ABSOLUTE_DIST=0,EUCLIDEAN_DIST,NORM_ABSOLUTE_DIST};
385  enum RejectionModes{TEMPLATE_THRESHOLDS=0,CLASS_LIKELIHOODS,THRESHOLDS_AND_LIKELIHOODS};
386 
387 };
388 
389 GRT_END_NAMESPACE
390 
391 #endif //GRT_DTW_HEADER
392 
UINT getRejectionMode() const
Definition: DTW.h:265
UINT getNumTemplates() const
Definition: DTW.h:216
virtual bool recomputeNullRejectionThresholds()
Definition: Classifier.h:237
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:115
Vector< VectorFloat > getInputDataBuffer() const
Definition: DTW.h:313
Vector< DTWTemplate > getModels() const
Definition: DTW.h:299
const Vector< MatrixFloat > & getDistanceMatrices() const
Definition: DTW.h:320
virtual bool save(const std::string filename) const
Definition: MLBase.cpp:143
virtual bool load(const std::string filename)
Definition: MLBase.cpp:167
virtual bool deepCopyFrom(const Classifier *classifier)
Definition: Classifier.h:63
class DTW
Definition: DTW.h:91
virtual bool reset()
Definition: Classifier.cpp:123
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:91
T * getData()
Definition: Vector.h:198
class IndexDist
Definition: DTW.h:51
const Vector< Vector< IndexDist > > & getWarpingPaths() const
Definition: DTW.h:327
virtual bool clear()
Definition: Classifier.cpp:142