GestureRecognitionToolkit Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
KMeans.h
Go to the documentation of this file.
1 
31 #ifndef GRT_KMEANS_HEADER
32 #define GRT_KMEANS_HEADER
33 
34 #include "../../Util/GRTCommon.h"
35 #include "../../CoreModules/Clusterer.h"
36 #include "../../DataStructures/ClassificationData.h"
37 #include "../../DataStructures/UnlabelledData.h"
38 
39 GRT_BEGIN_NAMESPACE
40 
41 class GRT_API KMeans : public Clusterer{
42 
43 public:
47  KMeans(const UINT numClusters=10,const UINT minNumEpochs=5,const UINT maxNumEpochs=1000,const Float minChange=1.0e-5,const bool computeTheta=true);
48 
54  KMeans(const KMeans &rhs);
55 
59  virtual ~KMeans();
60 
67  KMeans &operator=(const KMeans &rhs);
68 
76  virtual bool deepCopyFrom(const Clusterer *clusterer);
77 
84  virtual bool reset();
85 
91  virtual bool clear();
92 
100  bool trainModel(MatrixFloat &data);
101 
108  virtual bool train_(MatrixFloat &data);
109 
116  virtual bool train_(ClassificationData &trainingData);
117 
124  virtual bool train_(UnlabelledData &trainingData);
125 
132  virtual bool predict_(VectorFloat &inputVector);
133 
141  virtual bool saveModelToFile( std::fstream &file ) const;
142 
150  virtual bool loadModelFromFile( std::fstream &file );
151 
152  //bool predict(VectorFloat inputVector,UINT &predictedClusterLabel,Float &maxLikelihood,VectorFloat &clusterLikelihoods);
153 
154  //Getters
155  Float getTheta(){ return finalTheta; }
156  bool getModelTrained(){ return trained; }
157 
158  const VectorFloat& getTrainingThetaLog() const { return thetaTracker; }
159  const MatrixFloat& getClusters() const { return clusters; }
160  const Vector< UINT >& getClassLabelsVector() const { return assign; }
161  const Vector< UINT >& getClassCountVector() const { return count; }
162 
163  //Setters
164  bool setComputeTheta(const bool computeTheta);
165 
173  bool setClusters(const MatrixFloat &clusters);
174 
175  //Tell the compiler we are using the following functions from the MLBase class to stop hidden virtual function warnings
176  using MLBase::saveModelToFile;
177  using MLBase::loadModelFromFile;
178  using MLBase::train;
179  using MLBase::train_;
180  using MLBase::predict;
181  using MLBase::predict_;
182 
188  static std::string getId();
189 
190 protected:
191  UINT estep(const MatrixFloat &data);
192  void mstep(const MatrixFloat &data);
193  Float calculateTheta(const MatrixFloat &data);
194  inline Float SQR(const Float a) {return a*a;};
195 
196  bool computeTheta;
198  UINT nchg;
199  Float finalTheta;
200  MatrixFloat clusters;
201  Vector< UINT > assign, count;
202  VectorFloat thetaTracker;
203 
204 private:
205  static RegisterClustererModule< KMeans > registerModule;
206  static const std::string id;
207 };
208 
209 GRT_END_NAMESPACE
210 
211 #endif //GRT_KMEANS_HEADER
std::string getId() const
Definition: GRTBase.cpp:85
virtual bool predict(VectorFloat inputVector)
Definition: MLBase.cpp:135
virtual bool reset() override
Definition: Clusterer.cpp:130
virtual bool predict_(VectorFloat &inputVector)
Definition: MLBase.cpp:137
virtual bool clear() override
Definition: Clusterer.cpp:144
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:107
virtual bool deepCopyFrom(const Clusterer *clusterer)
Definition: Clusterer.h:59
virtual bool train_(MatrixFloat &trainingData) override
Definition: Clusterer.cpp:116
UINT nchg
Number of value changes.
Definition: KMeans.h:198
UINT numTrainingSamples
Number of training examples.
Definition: KMeans.h:197
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:109
class KMeans
Definition: KMeans.h:41