GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, C++ machine learning library for real-time gesture recognition.
GMM.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 
20 */
21 #define GRT_DLL_EXPORTS
22 #include "GMM.h"
23 
24 GRT_BEGIN_NAMESPACE
25 
26 //Define the string that will be used to identify the object
27 const std::string GMM::id = "GMM";
28 std::string GMM::getId() { return GMM::id; }
29 
30 //Register the GMM module with the Classifier base class
31 RegisterClassifierModule< GMM > GMM::registerModule( GMM::getId() );
32 
33 GMM::GMM(UINT numMixtureModels,bool useScaling,bool useNullRejection,Float nullRejectionCoeff,UINT maxNumEpochs,Float minChange) : Classifier( GMM::getId() )
34 {
35  classifierMode = STANDARD_CLASSIFIER_MODE;
36  this->numMixtureModels = numMixtureModels;
37  this->useScaling = useScaling;
38  this->useNullRejection = useNullRejection;
39  this->nullRejectionCoeff = nullRejectionCoeff;
40  this->maxNumEpochs = maxNumEpochs;
41  this->minChange = minChange;
42  numRestarts = 10; //The learning algorithm can attempt to train a model N times
43 }
44 
45 GMM::GMM(const GMM &rhs) : Classifier( GMM::getId() )
46 {
47  classifierMode = STANDARD_CLASSIFIER_MODE;
48  *this = rhs;
49 }
50 
52 
53 GMM& GMM::operator=(const GMM &rhs){
54  if( this != &rhs ){
55 
56  this->numMixtureModels = rhs.numMixtureModels;
57  this->models = rhs.models;
58 
59  //Classifier variables
60  copyBaseVariables( (Classifier*)&rhs );
61  }
62  return *this;
63 }
64 
65 bool GMM::deepCopyFrom(const Classifier *classifier){
66 
67  if( classifier == NULL ) return false;
68 
69  if( this->getClassifierType() == classifier->getClassifierType() ){
70 
71  GMM *ptr = (GMM*)classifier;
72  //Clone the GMM values
73  this->numMixtureModels = ptr->numMixtureModels;
74  this->models = ptr->models;
75 
76  //Clone the classifier variables
77  return copyBaseVariables( classifier );
78  }
79  return false;
80 }
81 
83 
84  predictedClassLabel = 0;
85 
86  if( classDistances.getSize() != numClasses || classLikelihoods.getSize() != numClasses ){
87  classDistances.resize(numClasses);
88  classLikelihoods.resize(numClasses);
89  }
90 
91  if( !trained ){
92  errorLog << __GRT_LOG__ << " Mixture Models have not been trained!" << std::endl;
93  return false;
94  }
95 
96  if( x.getSize() != numInputDimensions ){
97  errorLog << __GRT_LOG__ << " The size of the input vector (" << x.getSize() << ") does not match that of the number of features the model was trained with (" << numInputDimensions << ")." << std::endl;
98  return false;
99  }
100 
101  if( useScaling ){
102  for(UINT i=0; i<numInputDimensions; i++){
103  x[i] = grt_scale(x[i], ranges[i].minValue, ranges[i].maxValue, GMM_MIN_SCALE_VALUE, GMM_MAX_SCALE_VALUE);
104  }
105  }
106 
107  UINT bestIndex = 0;
108  maxLikelihood = 0;
109  bestDistance = 0;
110  Float sum = 0;
111  for(UINT k=0; k<numClasses; k++){
112  classDistances[k] = computeMixtureLikelihood(x,k);
113 
114  //cout << "K: " << k << " Dist: " << classDistances[k] << std::endl;
115  classLikelihoods[k] = classDistances[k];
116  sum += classLikelihoods[k];
117  if( classLikelihoods[k] > bestDistance ){
118  bestDistance = classLikelihoods[k];
119  bestIndex = k;
120  }
121  }
122 
123  //Normalize the likelihoods
124  for(unsigned int k=0; k<numClasses; k++){
125  classLikelihoods[k] /= sum;
126  }
127  maxLikelihood = classLikelihoods[bestIndex];
128 
129  if( useNullRejection ){
130 
131  //cout << "Dist: " << classDistances[bestIndex] << " RejectionThreshold: " << models[bestIndex].getRejectionThreshold() << std::endl;
132 
133  //If the best distance is below the modles rejection threshold then set the predicted class label as the best class label
134  //Otherwise set the predicted class label as the default null rejection class label of 0
135  if( classDistances[bestIndex] >= models[bestIndex].getNullRejectionThreshold() ){
136  predictedClassLabel = models[bestIndex].getClassLabel();
137  }else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
138  }else{
139  //Get the predicted class label
140  predictedClassLabel = models[bestIndex].getClassLabel();
141  }
142 
143  return true;
144 }
145 
146 bool GMM::train_(ClassificationData &trainingData){
147 
148  //Clear any old models
149  clear();
150 
151  if( trainingData.getNumSamples() == 0 ){
152  errorLog << __GRT_LOG__ << " Training data is empty!" << std::endl;
153  return false;
154  }
155 
156  //Set the number of features and number of classes and resize the models buffer
157  numInputDimensions = trainingData.getNumDimensions();
158  numOutputDimensions = trainingData.getNumClasses();
159  numClasses = trainingData.getNumClasses();
160  models.resize(numClasses);
161 
162  if( numInputDimensions >= 6 ){
163  warningLog << __GRT_LOG__ << " The number of features in your training data is high (" << numInputDimensions << "). The GMMClassifier does not work well with high dimensional data, you might get better results from one of the other classifiers." << std::endl;
164  }
165 
166  //Get the ranges of the training data if the training data is going to be scaled
167  ranges = trainingData.getRanges();
168 
169  ClassificationData validationData;
170 
171  //Scale the training data if needed
172  if( useScaling ){
173  //Scale the training data between 0 and 1
174  if( !trainingData.scale(GMM_MIN_SCALE_VALUE, GMM_MAX_SCALE_VALUE) ){
175  errorLog << __GRT_LOG__ << " Failed to scale training data!" << std::endl;
176  return false;
177  }
178  }
179 
180  if( useValidationSet ){
181  validationData = trainingData.split( 100-validationSetSize );
182  }
183 
184  //Fit a Mixture Model to each class (independently)
185  for(UINT k=0; k<numClasses; k++){
186  UINT classLabel = trainingData.getClassTracker()[k].classLabel;
187  ClassificationData classData = trainingData.getClassData( classLabel );
188 
189  //Train the Mixture Model for this class
190  GaussianMixtureModels gaussianMixtureModel;
191  gaussianMixtureModel.setNumClusters( numMixtureModels );
192  gaussianMixtureModel.setMinChange( minChange );
193  gaussianMixtureModel.setMaxNumEpochs( maxNumEpochs );
194  gaussianMixtureModel.setNumRestarts( numRestarts ); //The learning algorithm can retry building a model up to N times
195 
196  if( !gaussianMixtureModel.train( classData.getDataAsMatrixFloat() ) ){
197  errorLog << __GRT_LOG__ << " Failed to train Mixture Model for class " << classLabel << std::endl;
198  return false;
199  }
200 
201  //Setup the model container
202  models[k].resize( numMixtureModels );
203  models[k].setClassLabel( classLabel );
204 
205  //Store the mixture model in the container
206  for(UINT j=0; j<numMixtureModels; j++){
207  models[k][j].mu = gaussianMixtureModel.getMu().getRowVector(j);
208  models[k][j].sigma = gaussianMixtureModel.getSigma()[j];
209 
210  //Compute the determinant and invSigma for the realtime prediction
211  LUDecomposition ludcmp( models[k][j].sigma );
212  if( !ludcmp.inverse( models[k][j].invSigma ) ){
213  models.clear();
214  errorLog << __GRT_LOG__ << " Failed to invert Matrix for class " << classLabel << "!" << std::endl;
215  return false;
216  }
217  models[k][j].det = ludcmp.det();
218  }
219 
220  //Compute the normalize factor
221  models[k].recomputeNormalizationFactor();
222 
223  //Compute the rejection thresholds
224  Float mu = 0;
225  Float sigma = 0;
226  VectorFloat predictionResults(classData.getNumSamples(),0);
227  for(UINT i=0; i<classData.getNumSamples(); i++){
228  VectorFloat sample = classData[i].getSample();
229  predictionResults[i] = models[k].computeMixtureLikelihood( sample );
230  mu += predictionResults[i];
231  }
232 
233  //Update mu
234  mu /= Float( classData.getNumSamples() );
235 
236  //Calculate the standard deviation
237  for(UINT i=0; i<classData.getNumSamples(); i++)
238  sigma += grt_sqr( (predictionResults[i]-mu) );
239  sigma = grt_sqrt( sigma / (Float(classData.getNumSamples())-1.0) );
240  sigma = 0.2;
241 
242  //Set the models training mu and sigma
243  models[k].setTrainingMuAndSigma(mu,sigma);
244 
245  if( !models[k].recomputeNullRejectionThreshold(nullRejectionCoeff) && useNullRejection ){
246  warningLog << __GRT_LOG__ << " Failed to recompute rejection threshold for class " << classLabel << " - the nullRjectionCoeff value is too high!" << std::endl;
247  }
248 
249  //cout << "Training Mu: " << mu << " TrainingSigma: " << sigma << " RejectionThreshold: " << models[k].getNullRejectionThreshold() << std::endl;
250  //models[k].printModelValues();
251  }
252 
253  //Reset the class labels
254  classLabels.resize(numClasses);
255  for(UINT k=0; k<numClasses; k++){
256  classLabels[k] = models[k].getClassLabel();
257  }
258 
259  //Resize the rejection thresholds
260  nullRejectionThresholds.resize(numClasses);
261  for(UINT k=0; k<numClasses; k++){
262  nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
263  }
264 
265  //Flag that the model has been trained
266  trained = true;
267  converged = true;
268 
269  //Compute the final training stats
270  trainingSetAccuracy = 0;
271  validationSetAccuracy = 0;
272 
273  //If scaling was on, then the data will already be scaled, so turn it off temporially so we can test the model accuracy
274  bool scalingState = useScaling;
275  useScaling = false;
276  if( !computeAccuracy( trainingData, trainingSetAccuracy ) ){
277  trained = false;
278  converged = false;
279  errorLog << __GRT_LOG__ << " Failed to compute training set accuracy! Failed to fully train model!" << std::endl;
280  return false;
281  }
282 
283  if( useValidationSet ){
284  if( !computeAccuracy( validationData, validationSetAccuracy ) ){
285  trained = false;
286  converged = false;
287  errorLog << __GRT_LOG__ << " Failed to compute validation set accuracy! Failed to fully train model!" << std::endl;
288  return false;
289  }
290 
291  }
292 
293  trainingLog << "Training set accuracy: " << trainingSetAccuracy << std::endl;
294 
295  if( useValidationSet ){
296  trainingLog << "Validation set accuracy: " << validationSetAccuracy << std::endl;
297  }
298 
299  //Reset the scaling state for future prediction
300  useScaling = scalingState;
301 
302  return trained;
303 }
304 
305 Float GMM::computeMixtureLikelihood(const VectorFloat &x,const UINT k){
306  if( k >= numClasses ){
307  errorLog << __GRT_LOG__ << " Invalid k value!" << std::endl;
308  return 0;
309  }
310  return models[k].computeMixtureLikelihood( x );
311 }
312 
313 bool GMM::save( std::fstream &file ) const{
314 
315  if( !trained ){
316  errorLog << __GRT_LOG__ << " The model has not been trained!" << std::endl;
317  return false;
318  }
319 
320  if( !file.is_open() )
321  {
322  errorLog << __GRT_LOG__ << " The file has not been opened!" << std::endl;
323  return false;
324  }
325 
326  //Write the header info
327  file << "GRT_GMM_MODEL_FILE_V2.0\n";
328 
329  //Write the classifier settings to the file
331  errorLog << __GRT_LOG__ << " Failed to save classifier base settings to file!" << std::endl;
332  return false;
333  }
334 
335  file << "NumMixtureModels: " << numMixtureModels << std::endl;
336 
337  if( trained ){
338  //Write each of the models
339  file << "Models:\n";
340  for(UINT k=0; k<numClasses; k++){
341  file << "ClassLabel: " << models[k].getClassLabel() << std::endl;
342  file << "K: " << models[k].getK() << std::endl;
343  file << "NormalizationFactor: " << models[k].getNormalizationFactor() << std::endl;
344  file << "TrainingMu: " << models[k].getTrainingMu() << std::endl;
345  file << "TrainingSigma: " << models[k].getTrainingSigma() << std::endl;
346  file << "NullRejectionThreshold: " << models[k].getNullRejectionThreshold() << std::endl;
347 
348  for(UINT index=0; index<models[k].getK(); index++){
349  file << "Determinant: " << models[k][index].det << std::endl;
350 
351  file << "Mu: ";
352  for(UINT j=0; j<models[k][index].mu.size(); j++) file << "\t" << models[k][index].mu[j];
353  file << std::endl;
354 
355  file << "Sigma:\n";
356  for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
357  for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
358  file << models[k][index].sigma[i][j];
359  if( j < models[k][index].sigma.getNumCols()-1 ) file << "\t";
360  }
361  file << std::endl;
362  }
363 
364  file << "InvSigma:\n";
365  for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
366  for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
367  file << models[k][index].invSigma[i][j];
368  if( j < models[k][index].invSigma.getNumCols()-1 ) file << "\t";
369  }
370  file << std::endl;
371  }
372  }
373 
374  file << std::endl;
375  }
376  }
377 
378  return true;
379 }
380 
381 bool GMM::load( std::fstream &file ){
382 
383  trained = false;
384  numInputDimensions = 0;
385  numClasses = 0;
386  models.clear();
387  classLabels.clear();
388 
389  if(!file.is_open())
390  {
391  errorLog << __GRT_LOG__ << " Could not open file to load model" << std::endl;
392  return false;
393  }
394 
395  std::string word;
396  file >> word;
397 
398  //Check to see if we should load a legacy file
399  if( word == "GRT_GMM_MODEL_FILE_V1.0" ){
400  return loadLegacyModelFromFile( file );
401  }
402 
403  //Find the file type header
404  if(word != "GRT_GMM_MODEL_FILE_V2.0"){
405  errorLog << __GRT_LOG__ << " Could not find Model File Header" << std::endl;
406  return false;
407  }
408 
409  //Load the base settings from the file
411  errorLog << __GRT_LOG__ << " Failed to load base settings from file!" << std::endl;
412  return false;
413  }
414 
415  file >> word;
416  if(word != "NumMixtureModels:"){
417  errorLog << __GRT_LOG__ << " Could not find NumMixtureModels" << std::endl;
418  return false;
419  }
420  file >> numMixtureModels;
421 
422  if( trained ){
423 
424  //Read the model header
425  file >> word;
426  if(word != "Models:"){
427  errorLog << __GRT_LOG__ << " Could not find the Models Header" << std::endl;
428  return false;
429  }
430 
431  //Resize the buffer
432  models.resize(numClasses);
433  classLabels.resize(numClasses);
434 
435  //Load each of the models
436  for(UINT k=0; k<numClasses; k++){
437  UINT classLabel = 0;
438  UINT K = 0;
439  Float normalizationFactor;
440  Float trainingMu;
441  Float trainingSigma;
442  Float rejectionThreshold;
443 
444  file >> word;
445  if(word != "ClassLabel:"){
446  errorLog << __GRT_LOG__ << " Could not find the ClassLabel for model " << k+1 << std::endl;
447  return false;
448  }
449  file >> classLabel;
450  models[k].setClassLabel( classLabel );
451  classLabels[k] = classLabel;
452 
453  file >> word;
454  if(word != "K:"){
455  errorLog << __GRT_LOG__ << " Could not find K for model " << k+1 << std::endl;
456  return false;
457  }
458  file >> K;
459 
460  file >> word;
461  if(word != "NormalizationFactor:"){
462  errorLog << __GRT_LOG__ << " Could not find NormalizationFactor for model " << k+1 << std::endl;
463  return false;
464  }
465  file >> normalizationFactor;
466  models[k].setNormalizationFactor(normalizationFactor);
467 
468  file >> word;
469  if(word != "TrainingMu:"){
470  errorLog << __GRT_LOG__ << " Could not find TrainingMu for model " << k+1 << std::endl;
471  return false;
472  }
473  file >> trainingMu;
474 
475  file >> word;
476  if(word != "TrainingSigma:"){
477  errorLog << __GRT_LOG__ << " Could not find TrainingSigma for model " << k+1 << std::endl;
478  return false;
479  }
480  file >> trainingSigma;
481 
482  //Set the training mu and sigma
483  models[k].setTrainingMuAndSigma(trainingMu, trainingSigma);
484 
485  file >> word;
486  if(word != "NullRejectionThreshold:"){
487  errorLog << __GRT_LOG__ << " Could not find NullRejectionThreshold for model " << k+1 << std::endl;
488  return false;
489  }
490  file >>rejectionThreshold;
491 
492  //Set the rejection threshold
493  models[k].setNullRejectionThreshold(rejectionThreshold);
494 
495  //Resize the buffer for the mixture models
496  models[k].resize(K);
497 
498  //Load the mixture models
499  for(UINT index=0; index<models[k].getK(); index++){
500 
501  //Resize the memory for the current mixture model
502  models[k][index].mu.resize( numInputDimensions );
503  models[k][index].sigma.resize( numInputDimensions, numInputDimensions );
504  models[k][index].invSigma.resize( numInputDimensions, numInputDimensions );
505 
506  file >> word;
507  if(word != "Determinant:"){
508  errorLog << __GRT_LOG__ << " Could not find the Determinant for model " << k+1 << std::endl;
509  return false;
510  }
511  file >> models[k][index].det;
512 
513 
514  file >> word;
515  if(word != "Mu:"){
516  errorLog << __GRT_LOG__ << " Could not find Mu for model " << k+1 << std::endl;
517  return false;
518  }
519  for(UINT j=0; j<models[k][index].mu.getSize(); j++){
520  file >> models[k][index].mu[j];
521  }
522 
523 
524  file >> word;
525  if(word != "Sigma:"){
526  errorLog << __GRT_LOG__ << " Could not find Sigma for model " << k+1 << std::endl;
527  return false;
528  }
529  for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
530  for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
531  file >> models[k][index].sigma[i][j];
532  }
533  }
534 
535  file >> word;
536  if(word != "InvSigma:"){
537  errorLog << __GRT_LOG__ << " Could not find InvSigma for model " << k+1 << std::endl;
538  return false;
539  }
540  for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
541  for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
542  file >> models[k][index].invSigma[i][j];
543  }
544  }
545 
546  }
547 
548  }
549 
550  //Set the null rejection thresholds
551  nullRejectionThresholds.resize(numClasses);
552  for(UINT k=0; k<numClasses; k++) {
553  models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
554  nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
555  }
556 
557  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
558  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
559  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
560  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
561  }
562 
563  return true;
564 }
565 
566 bool GMM::clear(){
567 
568  //Clear the Classifier variables
570 
571  //Clear the GMM model
572  models.clear();
573 
574  return true;
575 }
576 
578 
579  if( trained ){
580  for(UINT k=0; k<numClasses; k++) {
581  models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
582  nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
583  }
584  return true;
585  }
586  return false;
587 }
588 
590  return numMixtureModels;
591 }
592 
594  if( trained ){ return models; }
595  return Vector< MixtureModel >();
596 }
597 
598 bool GMM::setNumMixtureModels(const UINT K){
599  if( K > 0 ){
600  numMixtureModels = K;
601  return true;
602  }
603  return false;
604 }
605 
606 bool GMM::setMaxIter(UINT maxNumEpochs){
607  return setMaxNumEpochs( maxNumEpochs );
608 }
609 
610 bool GMM::loadLegacyModelFromFile( std::fstream &file ){
611 
612  std::string word;
613 
614  file >> word;
615  if(word != "NumFeatures:"){
616  errorLog << __GRT_LOG__ << " Could not find NumFeatures " << std::endl;
617  return false;
618  }
619  file >> numInputDimensions;
620 
621  file >> word;
622  if(word != "NumClasses:"){
623  errorLog << __GRT_LOG__ << " Could not find NumClasses" << std::endl;
624  return false;
625  }
626  file >> numClasses;
627 
628  file >> word;
629  if(word != "NumMixtureModels:"){
630  errorLog << __GRT_LOG__ << " Could not find NumMixtureModels" << std::endl;
631  return false;
632  }
633  file >> numMixtureModels;
634 
635  file >> word;
636  if(word != "MaxIter:"){
637  errorLog << __GRT_LOG__ << " Could not find MaxIter" << std::endl;
638  return false;
639  }
640  file >> maxNumEpochs;
641 
642  file >> word;
643  if(word != "MinChange:"){
644  errorLog << __GRT_LOG__ << " Could not find MinChange" << std::endl;
645  return false;
646  }
647  file >> minChange;
648 
649  file >> word;
650  if(word != "UseScaling:"){
651  errorLog << __GRT_LOG__ << " Could not find UseScaling" << std::endl;
652  return false;
653  }
654  file >> useScaling;
655 
656  file >> word;
657  if(word != "UseNullRejection:"){
658  errorLog << __GRT_LOG__ << " Could not find UseNullRejection" << std::endl;
659  return false;
660  }
661  file >> useNullRejection;
662 
663  file >> word;
664  if(word != "NullRejectionCoeff:"){
665  errorLog << __GRT_LOG__ << " Could not find NullRejectionCoeff" << std::endl;
666  return false;
667  }
668  file >> nullRejectionCoeff;
669 
671  if( useScaling ){
672  //Resize the ranges buffer
673  ranges.resize(numInputDimensions);
674 
675  file >> word;
676  if(word != "Ranges:"){
677  errorLog << __GRT_LOG__ << " Could not find the Ranges" << std::endl;
678  return false;
679  }
680  for(UINT n=0; n<ranges.size(); n++){
681  file >> ranges[n].minValue;
682  file >> ranges[n].maxValue;
683  }
684  }
685 
686  //Read the model header
687  file >> word;
688  if(word != "Models:"){
689  errorLog << __GRT_LOG__ << " Could not find the Models Header" << std::endl;
690  return false;
691  }
692 
693  //Resize the buffer
694  models.resize(numClasses);
695  classLabels.resize(numClasses);
696 
697  //Load each of the models
698  for(UINT k=0; k<numClasses; k++){
699  UINT classLabel = 0;
700  UINT K = 0;
701  Float normalizationFactor;
702  Float trainingMu;
703  Float trainingSigma;
704  Float rejectionThreshold;
705 
706  file >> word;
707  if(word != "ClassLabel:"){
708  errorLog << __GRT_LOG__ << " Could not find the ClassLabel for model " << k+1 << std::endl;
709  return false;
710  }
711  file >> classLabel;
712  models[k].setClassLabel( classLabel );
713  classLabels[k] = classLabel;
714 
715  file >> word;
716  if(word != "K:"){
717  errorLog << __GRT_LOG__ << " Could not find K for model " << k+1 << std::endl;
718  return false;
719  }
720  file >> K;
721 
722  file >> word;
723  if(word != "NormalizationFactor:"){
724  errorLog << __GRT_LOG__ << " Could not find NormalizationFactor for model " << k+1 << std::endl;
725  return false;
726  }
727  file >> normalizationFactor;
728  models[k].setNormalizationFactor(normalizationFactor);
729 
730  file >> word;
731  if(word != "TrainingMu:"){
732  errorLog << __GRT_LOG__ << " Could not find TrainingMu for model " << k+1 << std::endl;
733  return false;
734  }
735  file >> trainingMu;
736 
737  file >> word;
738  if(word != "TrainingSigma:"){
739  errorLog << __GRT_LOG__ << " Could not find TrainingSigma for model " << k+1 << std::endl;
740  return false;
741  }
742  file >> trainingSigma;
743 
744  //Set the training mu and sigma
745  models[k].setTrainingMuAndSigma(trainingMu, trainingSigma);
746 
747  file >> word;
748  if(word != "NullRejectionThreshold:"){
749  errorLog << __GRT_LOG__ << " Could not find NullRejectionThreshold for model " << k+1 << std::endl;
750  return false;
751  }
752  file >>rejectionThreshold;
753 
754  //Set the rejection threshold
755  models[k].setNullRejectionThreshold(rejectionThreshold);
756 
757  //Resize the buffer for the mixture models
758  models[k].resize(K);
759 
760  //Load the mixture models
761  for(UINT index=0; index<models[k].getK(); index++){
762 
763  //Resize the memory for the current mixture model
764  models[k][index].mu.resize( numInputDimensions );
765  models[k][index].sigma.resize( numInputDimensions, numInputDimensions );
766  models[k][index].invSigma.resize( numInputDimensions, numInputDimensions );
767 
768  file >> word;
769  if(word != "Determinant:"){
770  errorLog << __GRT_LOG__ << " Could not find the Determinant for model " << k+1 << std::endl;
771  return false;
772  }
773  file >> models[k][index].det;
774 
775 
776  file >> word;
777  if(word != "Mu:"){
778  errorLog << __GRT_LOG__ << " Could not find Mu for model " << k+1 << std::endl;
779  return false;
780  }
781  for(UINT j=0; j<models[k][index].mu.getSize(); j++){
782  file >> models[k][index].mu[j];
783  }
784 
785 
786  file >> word;
787  if(word != "Sigma:"){
788  errorLog << __GRT_LOG__ << " Could not find Sigma for model " << k+1 << std::endl;
789  return false;
790  }
791  for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
792  for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
793  file >> models[k][index].sigma[i][j];
794  }
795  }
796 
797  file >> word;
798  if(word != "InvSigma:"){
799  errorLog << __GRT_LOG__ << " Could not find InvSigma for model " << k+1 << std::endl;
800  return false;
801  }
802  for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
803  for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
804  file >> models[k][index].invSigma[i][j];
805  }
806  }
807 
808  }
809 
810  }
811 
812  //Set the null rejection thresholds
813  nullRejectionThresholds.resize(numClasses);
814  for(UINT k=0; k<numClasses; k++) {
815  models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
816  nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
817  }
818 
819  //Flag that the models have been trained
820  trained = true;
821  converged = true;
822 
823  return trained;
824 }
825 
826 GRT_END_NAMESPACE
bool saveBaseSettingsToFile(std::fstream &file) const
Definition: Classifier.cpp:274
static std::string getId()
Definition: GMM.cpp:28
#define DEFAULT_NULL_LIKELIHOOD_VALUE
Definition: Classifier.h:33
virtual bool train_(ClassificationData &trainingData)
Definition: GMM.cpp:146
Definition: GMM.h:48
virtual bool clear()
Definition: GMM.cpp:566
std::string getClassifierType() const
Definition: Classifier.cpp:175
Vector< ClassTracker > getClassTracker() const
virtual bool recomputeNullRejectionThresholds()
Definition: GMM.cpp:577
ClassificationData getClassData(const UINT classLabel) const
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:107
#define GMM_MIN_SCALE_VALUE
Definition: GMM.h:33
bool setNumMixtureModels(const UINT K)
Definition: GMM.cpp:598
UINT getSize() const
Definition: Vector.h:201
Vector< MixtureModel > getModels()
Definition: GMM.cpp:593
bool setMinChange(const Float minChange)
Definition: MLBase.cpp:344
virtual bool computeAccuracy(const ClassificationData &data, Float &accuracy)
Definition: Classifier.cpp:171
MatrixFloat getMu() const
UINT getNumSamples() const
virtual bool predict_(VectorFloat &inputVector)
Definition: GMM.cpp:82
virtual bool save(std::fstream &file) const
Definition: GMM.cpp:313
GMM(UINT numMixtureModels=2, bool useScaling=false, bool useNullRejection=false, Float nullRejectionCoeff=1.0, UINT maxIter=100, Float minChange=1.0e-5)
Definition: GMM.cpp:33
Vector< T > getRowVector(const unsigned int r) const
Definition: Matrix.h:184
UINT getNumMixtureModels()
Definition: GMM.cpp:589
bool copyBaseVariables(const Classifier *classifier)
Definition: Classifier.cpp:101
bool loadBaseSettingsFromFile(std::fstream &file)
Definition: Classifier.cpp:321
UINT getNumDimensions() const
UINT getNumClasses() const
Vector< MatrixFloat > getSigma() const
GMM & operator=(const GMM &rhs)
Definition: GMM.cpp:53
virtual ~GMM(void)
Definition: GMM.cpp:51
virtual bool load(std::fstream &file)
Definition: GMM.cpp:381
bool loadLegacyModelFromFile(std::fstream &file)
Definition: GMM.cpp:610
Vector< MinMax > getRanges() const
ClassificationData split(const UINT splitPercentage, const bool useStratifiedSampling=false)
MatrixFloat getDataAsMatrixFloat() const
virtual bool deepCopyFrom(const Classifier *classifier)
Definition: GMM.cpp:65
bool setMaxNumEpochs(const UINT maxNumEpochs)
Definition: MLBase.cpp:320
bool scale(const Float minTarget, const Float maxTarget)
virtual bool clear()
Definition: Classifier.cpp:151
This is the main base class that all GRT Classification algorithms should inherit from...
Definition: Classifier.h:41
bool setNumRestarts(const UINT numRestarts)
bool setNumClusters(const UINT numClusters)
Definition: Clusterer.cpp:262