GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
KMeans.h
Go to the documentation of this file.
1 
31 #ifndef GRT_KMEANS_HEADER
32 #define GRT_KMEANS_HEADER
33 
34 #include "../../Util/GRTCommon.h"
35 #include "../../CoreModules/Clusterer.h"
36 #include "../../DataStructures/ClassificationData.h"
37 #include "../../DataStructures/UnlabelledData.h"
38 
39 GRT_BEGIN_NAMESPACE
40 
41 class GRT_API KMeans : public Clusterer{
42 
43 public:
47  KMeans(const UINT numClusters=10,const UINT minNumEpochs=5,const UINT maxNumEpochs=1000,const Float minChange=1.0e-5,const bool computeTheta=true);
48 
54  KMeans(const KMeans &rhs);
55 
59  virtual ~KMeans();
60 
67  KMeans &operator=(const KMeans &rhs);
68 
76  virtual bool deepCopyFrom(const Clusterer *clusterer);
77 
84  virtual bool reset();
85 
91  virtual bool clear();
92 
100  bool trainModel(MatrixFloat &data);
101 
108  virtual bool train_(MatrixFloat &data);
109 
116  virtual bool train_(ClassificationData &trainingData);
117 
124  virtual bool train_(UnlabelledData &trainingData);
125 
132  virtual bool predict_(VectorFloat &inputVector);
133 
141  virtual bool saveModelToFile( std::fstream &file ) const;
142 
150  virtual bool loadModelFromFile( std::fstream &file );
151 
152  //bool predict(VectorFloat inputVector,UINT &predictedClusterLabel,Float &maxLikelihood,VectorFloat &clusterLikelihoods);
153 
154  //Getters
155  Float getTheta(){ return finalTheta; }
156  bool getModelTrained(){ return trained; }
157 
158  const VectorFloat& getTrainingThetaLog() const { return thetaTracker; }
159  const MatrixFloat& getClusters() const { return clusters; }
160  const Vector< UINT >& getClassLabelsVector() const { return assign; }
161  const Vector< UINT >& getClassCountVector() const { return count; }
162 
163  //Setters
164  bool setComputeTheta(const bool computeTheta);
165 
173  bool setClusters(const MatrixFloat &clusters);
174 
175  //Tell the compiler we are using the following functions from the MLBase class to stop hidden virtual function warnings
176  using MLBase::saveModelToFile;
177  using MLBase::loadModelFromFile;
178  using MLBase::train;
179  using MLBase::train_;
180  using MLBase::predict;
181  using MLBase::predict_;
182 
183 protected:
184  UINT estep(const MatrixFloat &data);
185  void mstep(const MatrixFloat &data);
186  Float calculateTheta(const MatrixFloat &data);
187  inline Float SQR(const Float a) {return a*a;};
188 
189  bool computeTheta;
191  UINT nchg;
192  Float finalTheta;
193  MatrixFloat clusters;
194  Vector< UINT > assign, count;
195  VectorFloat thetaTracker;
196 
197 private:
198  static RegisterClustererModule< KMeans > registerModule;
199 };
200 
201 GRT_END_NAMESPACE
202 
203 #endif //GRT_KMEANS_HEADER
virtual bool predict(VectorFloat inputVector)
Definition: MLBase.cpp:113
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:115
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:89
virtual bool deepCopyFrom(const Clusterer *clusterer)
Definition: Clusterer.h:58
UINT nchg
Number of values changed.
Definition: KMeans.h:191
bool getModelTrained() const
Definition: MLBase.cpp:261
UINT numTrainingSamples
Number of training examples.
Definition: KMeans.h:190
virtual bool reset()
Definition: Clusterer.cpp:128
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:91
virtual bool train_(MatrixFloat &trainingData)
Definition: Clusterer.cpp:114
class KMeans
Definition: KMeans.h:41
virtual bool clear()
Definition: Clusterer.cpp:142