GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
MeanShift.h
Go to the documentation of this file.
1 
33 #ifndef GRT_MEAN_SHIFT_HEADER
34 #define GRT_MEAN_SHIFT_HEADER
35 
36 #include "../../CoreModules/MLBase.h"
37 
38 GRT_BEGIN_NAMESPACE
39 
40 class MeanShift : public MLBase{
41 public:
42  MeanShift() {
43  classType = "MeanShift";
44  infoLog.setProceedingText("[MeanShift]");
45  debugLog.setProceedingText("[DEBUG MeanShift]");
46  errorLog.setProceedingText("[ERROR MeanShift]");
47  trainingLog.setProceedingText("[TRAINING MeanShift]");
48  warningLog.setProceedingText("[WARNING MeanShift]");
49  }
50 
51  virtual ~MeanShift(){
52 
53  }
54 
55  bool search( const VectorFloat &meanStart, const Vector< VectorFloat > &points, const Float searchRadius, const Float sigma = 20.0 ){
56 
57  //clear the results from any previous search
58  clear();
59 
60  const unsigned int numDimensions = (unsigned int)meanStart.size();
61  const unsigned int numPoints = (unsigned int)points.size();
62  const Float gamma = 1.0 / (2 * SQR(sigma) );
63  unsigned int iteration = 0;
64  VectorFloat numer(2,0);
65  VectorFloat denom(2,0);
66  VectorFloat kernelDist(2,0);
67  Float pointsWithinSearchRadius = 0;
68 
69  mean = meanStart;
70  VectorFloat lastMean = mean;
71 
72  //Start the search loop
73  while( true ){
74 
75  //Reset the counters
76  pointsWithinSearchRadius = 0;
77  std::fill(numer.begin(),numer.end(),0);
78  std::fill(denom.begin(),denom.end(),0);
79  std::fill(kernelDist.begin(),kernelDist.end(),0);
80 
81  //Update the numerator and denominator for points that are with the search radius
82  for(unsigned int i=0; i<numPoints; i++){
83 
84  //Compute the distance of the current point to the mean
85  Float distToMean = euclideanDist( mean, points[i] );
86 
87  //If the point is within the search radius then update numer and denom
88  if( distToMean < searchRadius ){
89 
90  for(unsigned int j=0; j<numDimensions; j++){
91  kernelDist[j] = gaussKernel( points[i][j], mean[j], gamma );
92  numer[j] += kernelDist[j] * points[i][j];
93  denom[j] += kernelDist[j];
94  }
95 
96  pointsWithinSearchRadius++;
97  }
98  }
99 
100  //Update the mean
101  Float change = 0;
102  for(unsigned int j=0; j<numDimensions; j++){
103 
104  mean[j] = numer[j] / denom[j];
105 
106  change += grt_sqr( mean[j] - lastMean[j] );
107 
108  lastMean[j] = mean[j];
109  }
110  change = grt_sqrt( change );
111 
112  trainingLog << "iteration: " << iteration;
113  trainingLog << " mean: ";
114  for(unsigned int j=0; j<numDimensions; j++){
115  trainingLog << mean[j] << " ";
116  }
117  trainingLog << " change: " << change << std::endl;
118 
119  if( change < minChange ){
120  trainingLog << "min changed limit reached - stopping search" << std::endl;
121  break;
122  }
123 
124  if( ++iteration >= maxNumEpochs ){
125  trainingLog << "max number of iterations reached - stopping search." << std::endl;
126  break;
127  }
128 
129  }
130  numTrainingIterationsToConverge = iteration;
131  trained = true;
132 
133  return true;
134  }
135 
136  VectorFloat getMean() const {
137  return mean;
138  }
139 
140  Float gaussKernel( const Float &x, const Float &mu, const Float gamma ){
141  return exp( gamma * grt_sqr(x-mu) );
142  }
143 
144  Float gaussKernel( const VectorFloat &x, const VectorFloat &mu, const Float gamma ){
145 
146  Float y = 0;
147  const UINT N = x.getSize();
148  for(UINT i=0; i<N; i++){
149  y += grt_sqr(x[i]-mu[i]);
150  }
151  return exp( gamma * y );
152  }
153 
154  Float euclideanDist( const VectorFloat &x, const VectorFloat &y ){
155 
156  Float z = 0;
157  const UINT N = x.getSize();
158  for(UINT i=0; i<N; i++){
159  z += grt_sqr(x[i]-y[i]);
160  }
161  return sqrt( z );
162 
163  }
164 
165 protected:
166 
167  VectorFloat mean;
168 
169 };
170 
171 GRT_END_NAMESPACE
172 
173 #endif //GRT_MEAN_SHIFT_HEADER
UINT getSize() const
Definition: Vector.h:191
virtual bool clear()
Definition: MLBase.cpp:127
Definition: MLBase.h:70