GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
GMM.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 
20 */
21 #define GRT_DLL_EXPORTS
22 #include "GMM.h"
23 
24 GRT_BEGIN_NAMESPACE
25 
26 //Register the GMM module with the Classifier base class
27 RegisterClassifierModule< GMM > GMM::registerModule("GMM");
28 
29 GMM::GMM(UINT numMixtureModels,bool useScaling,bool useNullRejection,Float nullRejectionCoeff,UINT maxIter,Float minChange){
30  classType = "GMM";
31  classifierType = classType;
32  classifierMode = STANDARD_CLASSIFIER_MODE;
33  debugLog.setProceedingText("[DEBUG GMM]");
34  errorLog.setProceedingText("[ERROR GMM]");
35  warningLog.setProceedingText("[WARNING GMM]");
36 
37  this->numMixtureModels = numMixtureModels;
38  this->useScaling = useScaling;
39  this->useNullRejection = useNullRejection;
40  this->nullRejectionCoeff = nullRejectionCoeff;
41  this->maxIter = maxIter;
42  this->minChange = minChange;
43 }
44 
45 GMM::GMM(const GMM &rhs){
46  classType = "GMM";
47  classifierType = classType;
48  classifierMode = STANDARD_CLASSIFIER_MODE;
49  debugLog.setProceedingText("[DEBUG GMM]");
50  errorLog.setProceedingText("[ERROR GMM]");
51  warningLog.setProceedingText("[WARNING GMM]");
52  *this = rhs;
53 }
54 
55 GMM::~GMM(void){}
56 
57 GMM& GMM::operator=(const GMM &rhs){
58  if( this != &rhs ){
59 
60  this->numMixtureModels = rhs.numMixtureModels;
61  this->maxIter = rhs.maxIter;
62  this->minChange = rhs.minChange;
63  this->models = rhs.models;
64 
65  this->debugLog = rhs.debugLog;
66  this->errorLog = rhs.errorLog;
67  this->warningLog = rhs.warningLog;
68 
69  //Classifier variables
70  copyBaseVariables( (Classifier*)&rhs );
71  }
72  return *this;
73 }
74 
75 bool GMM::deepCopyFrom(const Classifier *classifier){
76 
77  if( classifier == NULL ) return false;
78 
79  if( this->getClassifierType() == classifier->getClassifierType() ){
80 
81  GMM *ptr = (GMM*)classifier;
82  //Clone the GMM values
83  this->numMixtureModels = ptr->numMixtureModels;
84  this->maxIter = ptr->maxIter;
85  this->minChange = ptr->minChange;
86  this->models = ptr->models;
87 
88  this->debugLog = ptr->debugLog;
89  this->errorLog = ptr->errorLog;
90  this->warningLog = ptr->warningLog;
91 
92  //Clone the classifier variables
93  return copyBaseVariables( classifier );
94  }
95  return false;
96 }
97 
98 bool GMM::predict_(VectorFloat &x){
99 
100  predictedClassLabel = 0;
101 
102  if( classDistances.getSize() != numClasses || classLikelihoods.getSize() != numClasses ){
103  classDistances.resize(numClasses);
104  classLikelihoods.resize(numClasses);
105  }
106 
107  if( !trained ){
108  errorLog << "predict_(VectorFloat &x) - Mixture Models have not been trained!" << std::endl;
109  return false;
110  }
111 
112  if( x.getSize() != numInputDimensions ){
113  errorLog << "predict_(VectorFloat &x) - The size of the input vector (" << x.getSize() << ") does not match that of the number of features the model was trained with (" << numInputDimensions << ")." << std::endl;
114  return false;
115  }
116 
117  if( useScaling ){
118  for(UINT i=0; i<numInputDimensions; i++){
119  x[i] = grt_scale(x[i], ranges[i].minValue, ranges[i].maxValue, GMM_MIN_SCALE_VALUE, GMM_MAX_SCALE_VALUE);
120  }
121  }
122 
123  UINT bestIndex = 0;
124  maxLikelihood = 0;
125  bestDistance = 0;
126  Float sum = 0;
127  for(UINT k=0; k<numClasses; k++){
128  classDistances[k] = computeMixtureLikelihood(x,k);
129 
130  //cout << "K: " << k << " Dist: " << classDistances[k] << std::endl;
131  classLikelihoods[k] = classDistances[k];
132  sum += classLikelihoods[k];
133  if( classLikelihoods[k] > bestDistance ){
134  bestDistance = classLikelihoods[k];
135  bestIndex = k;
136  }
137  }
138 
139  //Normalize the likelihoods
140  for(unsigned int k=0; k<numClasses; k++){
141  classLikelihoods[k] /= sum;
142  }
143  maxLikelihood = classLikelihoods[bestIndex];
144 
145  if( useNullRejection ){
146 
147  //cout << "Dist: " << classDistances[bestIndex] << " RejectionThreshold: " << models[bestIndex].getRejectionThreshold() << std::endl;
148 
149  //If the best distance is at or above the model's rejection threshold then set the predicted class label as the best class label
150  //Otherwise set the predicted class label as the default null rejection class label of 0
151  if( classDistances[bestIndex] >= models[bestIndex].getNullRejectionThreshold() ){
152  predictedClassLabel = models[bestIndex].getClassLabel();
153  }else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
154  }else{
155  //Get the predicted class label
156  predictedClassLabel = models[bestIndex].getClassLabel();
157  }
158 
159  return true;
160 }
161 
162 bool GMM::train_(ClassificationData &trainingData){
163 
164  //Clear any old models
165  clear();
166 
167  if( trainingData.getNumSamples() == 0 ){
168  errorLog << "train_(ClassificationData &trainingData) - Training data is empty!" << std::endl;
169  return false;
170  }
171 
172  //Set the number of features and number of classes and resize the models buffer
173  numInputDimensions = trainingData.getNumDimensions();
174  numClasses = trainingData.getNumClasses();
175  models.resize(numClasses);
176 
177  if( numInputDimensions >= 6 ){
178  warningLog << "train_(ClassificationData &trainingData) - The number of features in your training data is high (" << numInputDimensions << "). The GMMClassifier does not work well with high dimensional data, you might get better results from one of the other classifiers." << std::endl;
179  }
180 
181  //Get the ranges of the training data if the training data is going to be scaled
182  ranges = trainingData.getRanges();
183  if( !trainingData.scale(GMM_MIN_SCALE_VALUE, GMM_MAX_SCALE_VALUE) ){
184  errorLog << "train_(ClassificationData &trainingData) - Failed to scale training data!" << std::endl;
185  return false;
186  }
187 
188  //Fit a Mixture Model to each class (independently)
189  for(UINT k=0; k<numClasses; k++){
190  UINT classLabel = trainingData.getClassTracker()[k].classLabel;
191  ClassificationData classData = trainingData.getClassData( classLabel );
192 
193  //Train the Mixture Model for this class
194  GaussianMixtureModels gaussianMixtureModel;
195  gaussianMixtureModel.setNumClusters( numMixtureModels );
196  gaussianMixtureModel.setMinChange( minChange );
197  gaussianMixtureModel.setMaxNumEpochs( maxIter );
198 
199  if( !gaussianMixtureModel.train( classData.getDataAsMatrixFloat() ) ){
200  errorLog << "train_(ClassificationData &trainingData) - Failed to train Mixture Model for class " << classLabel << std::endl;
201  return false;
202  }
203 
204  //Setup the model container
205  models[k].resize( numMixtureModels );
206  models[k].setClassLabel( classLabel );
207 
208  //Store the mixture model in the container
209  for(UINT j=0; j<numMixtureModels; j++){
210  models[k][j].mu = gaussianMixtureModel.getMu().getRowVector(j);
211  models[k][j].sigma = gaussianMixtureModel.getSigma()[j];
212 
213  //Compute the determinant and invSigma for the realtime prediction
214  LUDecomposition ludcmp( models[k][j].sigma );
215  if( !ludcmp.inverse( models[k][j].invSigma ) ){
216  models.clear();
217  errorLog << "train_(ClassificationData &trainingData) - Failed to invert Matrix for class " << classLabel << "!" << std::endl;
218  return false;
219  }
220  models[k][j].det = ludcmp.det();
221  }
222 
223  //Compute the normalization factor
224  models[k].recomputeNormalizationFactor();
225 
226  //Compute the rejection thresholds
227  Float mu = 0;
228  Float sigma = 0;
229  VectorFloat predictionResults(classData.getNumSamples(),0);
230  for(UINT i=0; i<classData.getNumSamples(); i++){
231  VectorFloat sample = classData[i].getSample();
232  predictionResults[i] = models[k].computeMixtureLikelihood( sample );
233  mu += predictionResults[i];
234  }
235 
236  //Update mu
237  mu /= Float( classData.getNumSamples() );
238 
239  //Calculate the standard deviation
240  for(UINT i=0; i<classData.getNumSamples(); i++)
241  sigma += grt_sqr( (predictionResults[i]-mu) );
242  sigma = grt_sqrt( sigma / (Float(classData.getNumSamples())-1.0) );
243  sigma = 0.2; //Note: the standard deviation computed above is overridden with a fixed value
244 
245  //Set the models training mu and sigma
246  models[k].setTrainingMuAndSigma(mu,sigma);
247 
248  if( !models[k].recomputeNullRejectionThreshold(nullRejectionCoeff) && useNullRejection ){
249  warningLog << "train_(ClassificationData &trainingData) - Failed to recompute rejection threshold for class " << classLabel << " - the nullRejectionCoeff value is too high!" << std::endl;
250  }
251 
252  //cout << "Training Mu: " << mu << " TrainingSigma: " << sigma << " RejectionThreshold: " << models[k].getNullRejectionThreshold() << std::endl;
253  //models[k].printModelValues();
254  }
255 
256  //Reset the class labels
257  classLabels.resize(numClasses);
258  for(UINT k=0; k<numClasses; k++){
259  classLabels[k] = models[k].getClassLabel();
260  }
261 
262  //Resize the rejection thresholds
263  nullRejectionThresholds.resize(numClasses);
264  for(UINT k=0; k<numClasses; k++){
265  nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
266  }
267 
268  //Flag that the models have been trained
269  trained = true;
270 
271  return true;
272 }
273 
274 Float GMM::computeMixtureLikelihood(const VectorFloat &x,const UINT k){
275  if( k >= numClasses ){
276  errorLog << "computeMixtureLikelihood(const VectorFloat x,const UINT k) - Invalid k value!" << std::endl;
277  return 0;
278  }
279  return models[k].computeMixtureLikelihood( x );
280 }
281 
282 bool GMM::save( std::fstream &file ) const{
283 
284  if( !trained ){
285  errorLog <<"saveGMMToFile(fstream &file) - The model has not been trained!" << std::endl;
286  return false;
287  }
288 
289  if( !file.is_open() )
290  {
291  errorLog <<"saveGMMToFile(fstream &file) - The file has not been opened!" << std::endl;
292  return false;
293  }
294 
295  //Write the header info
296  file << "GRT_GMM_MODEL_FILE_V2.0\n";
297 
298  //Write the classifier settings to the file
299  if( !Classifier::saveBaseSettingsToFile(file) ){
300  errorLog <<"save(fstream &file) - Failed to save classifier base settings to file!" << std::endl;
301  return false;
302  }
303 
304  file << "NumMixtureModels: " << numMixtureModels << std::endl;
305 
306  if( trained ){
307  //Write each of the models
308  file << "Models:\n";
309  for(UINT k=0; k<numClasses; k++){
310  file << "ClassLabel: " << models[k].getClassLabel() << std::endl;
311  file << "K: " << models[k].getK() << std::endl;
312  file << "NormalizationFactor: " << models[k].getNormalizationFactor() << std::endl;
313  file << "TrainingMu: " << models[k].getTrainingMu() << std::endl;
314  file << "TrainingSigma: " << models[k].getTrainingSigma() << std::endl;
315  file << "NullRejectionThreshold: " << models[k].getNullRejectionThreshold() << std::endl;
316 
317  for(UINT index=0; index<models[k].getK(); index++){
318  file << "Determinant: " << models[k][index].det << std::endl;
319 
320  file << "Mu: ";
321  for(UINT j=0; j<models[k][index].mu.size(); j++) file << "\t" << models[k][index].mu[j];
322  file << std::endl;
323 
324  file << "Sigma:\n";
325  for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
326  for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
327  file << models[k][index].sigma[i][j];
328  if( j < models[k][index].sigma.getNumCols()-1 ) file << "\t";
329  }
330  file << std::endl;
331  }
332 
333  file << "InvSigma:\n";
334  for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
335  for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
336  file << models[k][index].invSigma[i][j];
337  if( j < models[k][index].invSigma.getNumCols()-1 ) file << "\t";
338  }
339  file << std::endl;
340  }
341  }
342 
343  file << std::endl;
344  }
345  }
346 
347  return true;
348 }
349 
350 bool GMM::load( std::fstream &file ){
351 
352  trained = false;
353  numInputDimensions = 0;
354  numClasses = 0;
355  models.clear();
356  classLabels.clear();
357 
358  if(!file.is_open())
359  {
360  errorLog << "load(fstream &file) - Could not open file to load model" << std::endl;
361  return false;
362  }
363 
364  std::string word;
365  file >> word;
366 
367  //Check to see if we should load a legacy file
368  if( word == "GRT_GMM_MODEL_FILE_V1.0" ){
369  return loadLegacyModelFromFile( file );
370  }
371 
372  //Find the file type header
373  if(word != "GRT_GMM_MODEL_FILE_V2.0"){
374  errorLog << "load(fstream &file) - Could not find Model File Header" << std::endl;
375  return false;
376  }
377 
378  //Load the base settings from the file
379  if( !Classifier::loadBaseSettingsFromFile(file) ){
380  errorLog << "load(string filename) - Failed to load base settings from file!" << std::endl;
381  return false;
382  }
383 
384  file >> word;
385  if(word != "NumMixtureModels:"){
386  errorLog << "load(fstream &file) - Could not find NumMixtureModels" << std::endl;
387  return false;
388  }
389  file >> numMixtureModels;
390 
391  if( trained ){
392 
393  //Read the model header
394  file >> word;
395  if(word != "Models:"){
396  errorLog << "load(fstream &file) - Could not find the Models Header" << std::endl;
397  return false;
398  }
399 
400  //Resize the buffer
401  models.resize(numClasses);
402  classLabels.resize(numClasses);
403 
404  //Load each of the models
405  for(UINT k=0; k<numClasses; k++){
406  UINT classLabel = 0;
407  UINT K = 0;
408  Float normalizationFactor;
409  Float trainingMu;
410  Float trainingSigma;
411  Float rejectionThreshold;
412 
413  file >> word;
414  if(word != "ClassLabel:"){
415  errorLog << "load(fstream &file) - Could not find the ClassLabel for model " << k+1 << std::endl;
416  return false;
417  }
418  file >> classLabel;
419  models[k].setClassLabel( classLabel );
420  classLabels[k] = classLabel;
421 
422  file >> word;
423  if(word != "K:"){
424  errorLog << "load(fstream &file) - Could not find K for model " << k+1 << std::endl;
425  return false;
426  }
427  file >> K;
428 
429  file >> word;
430  if(word != "NormalizationFactor:"){
431  errorLog << "load(fstream &file) - Could not find NormalizationFactor for model " << k+1 << std::endl;
432  return false;
433  }
434  file >> normalizationFactor;
435  models[k].setNormalizationFactor(normalizationFactor);
436 
437  file >> word;
438  if(word != "TrainingMu:"){
439  errorLog << "load(fstream &file) - Could not find TrainingMu for model " << k+1 << std::endl;
440  return false;
441  }
442  file >> trainingMu;
443 
444  file >> word;
445  if(word != "TrainingSigma:"){
446  errorLog << "load(fstream &file) - Could not find TrainingSigma for model " << k+1 << std::endl;
447  return false;
448  }
449  file >> trainingSigma;
450 
451  //Set the training mu and sigma
452  models[k].setTrainingMuAndSigma(trainingMu, trainingSigma);
453 
454  file >> word;
455  if(word != "NullRejectionThreshold:"){
456  errorLog << "load(fstream &file) - Could not find NullRejectionThreshold for model " << k+1 << std::endl;
457  return false;
458  }
459  file >>rejectionThreshold;
460 
461  //Set the rejection threshold
462  models[k].setNullRejectionThreshold(rejectionThreshold);
463 
464  //Resize the buffer for the mixture models
465  models[k].resize(K);
466 
467  //Load the mixture models
468  for(UINT index=0; index<models[k].getK(); index++){
469 
470  //Resize the memory for the current mixture model
471  models[k][index].mu.resize( numInputDimensions );
472  models[k][index].sigma.resize( numInputDimensions, numInputDimensions );
473  models[k][index].invSigma.resize( numInputDimensions, numInputDimensions );
474 
475  file >> word;
476  if(word != "Determinant:"){
477  errorLog << "load(fstream &file) - Could not find the Determinant for model " << k+1 << std::endl;
478  return false;
479  }
480  file >> models[k][index].det;
481 
482 
483  file >> word;
484  if(word != "Mu:"){
485  errorLog << "load(fstream &file) - Could not find Mu for model " << k+1 << std::endl;
486  return false;
487  }
488  for(UINT j=0; j<models[k][index].mu.size(); j++){
489  file >> models[k][index].mu[j];
490  }
491 
492 
493  file >> word;
494  if(word != "Sigma:"){
495  errorLog << "load(fstream &file) - Could not find Sigma for model " << k+1 << std::endl;
496  return false;
497  }
498  for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
499  for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
500  file >> models[k][index].sigma[i][j];
501  }
502  }
503 
504  file >> word;
505  if(word != "InvSigma:"){
506  errorLog << "load(fstream &file) - Could not find InvSigma for model " << k+1 << std::endl;
507  return false;
508  }
509  for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
510  for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
511  file >> models[k][index].invSigma[i][j];
512  }
513  }
514 
515  }
516 
517  }
518 
519  //Set the null rejection thresholds
520  nullRejectionThresholds.resize(numClasses);
521  for(UINT k=0; k<numClasses; k++) {
522  models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
523  nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
524  }
525 
526  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
527  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
528  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
529  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
530  }
531 
532  return true;
533 }
534 
535 bool GMM::clear(){
536 
537  //Clear the Classifier variables
538  Classifier::clear();
539 
540  //Clear the GMM model
541  models.clear();
542 
543  return true;
544 }
545 
546 bool GMM::recomputeNullRejectionThresholds(){
547 
548  if( trained ){
549  for(UINT k=0; k<numClasses; k++) {
550  models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
551  nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
552  }
553  return true;
554  }
555  return false;
556 }
557 
558 UINT GMM::getNumMixtureModels(){
559  return numMixtureModels;
560 }
561 
562 Vector< MixtureModel > GMM::getModels(){
563  if( trained ){ return models; }
564  return Vector< MixtureModel >();
565 }
566 
567 bool GMM::setNumMixtureModels(UINT K){
568  if( K > 0 ){
569  numMixtureModels = K;
570  return true;
571  }
572  return false;
573 }
574 bool GMM::setMinChange(Float minChange){
575  if( minChange > 0 ){
576  this->minChange = minChange;
577  return true;
578  }
579  return false;
580 }
581 bool GMM::setMaxIter(UINT maxIter){
582  if( maxIter > 0 ){
583  this->maxIter = maxIter;
584  return true;
585  }
586  return false;
587 }
588 
589 bool GMM::loadLegacyModelFromFile( std::fstream &file ){
590 
591  std::string word;
592 
593  file >> word;
594  if(word != "NumFeatures:"){
595  errorLog << "load(fstream &file) - Could not find NumFeatures " << std::endl;
596  return false;
597  }
598  file >> numInputDimensions;
599 
600  file >> word;
601  if(word != "NumClasses:"){
602  errorLog << "load(fstream &file) - Could not find NumClasses" << std::endl;
603  return false;
604  }
605  file >> numClasses;
606 
607  file >> word;
608  if(word != "NumMixtureModels:"){
609  errorLog << "load(fstream &file) - Could not find NumMixtureModels" << std::endl;
610  return false;
611  }
612  file >> numMixtureModels;
613 
614  file >> word;
615  if(word != "MaxIter:"){
616  errorLog << "load(fstream &file) - Could not find MaxIter" << std::endl;
617  return false;
618  }
619  file >> maxIter;
620 
621  file >> word;
622  if(word != "MinChange:"){
623  errorLog << "load(fstream &file) - Could not find MinChange" << std::endl;
624  return false;
625  }
626  file >> minChange;
627 
628  file >> word;
629  if(word != "UseScaling:"){
630  errorLog << "load(fstream &file) - Could not find UseScaling" << std::endl;
631  return false;
632  }
633  file >> useScaling;
634 
635  file >> word;
636  if(word != "UseNullRejection:"){
637  errorLog << "load(fstream &file) - Could not find UseNullRejection" << std::endl;
638  return false;
639  }
640  file >> useNullRejection;
641 
642  file >> word;
643  if(word != "NullRejectionCoeff:"){
644  errorLog << "load(fstream &file) - Could not find NullRejectionCoeff" << std::endl;
645  return false;
646  }
647  file >> nullRejectionCoeff;
648 
650  if( useScaling ){
651  //Resize the ranges buffer
652  ranges.resize(numInputDimensions);
653 
654  file >> word;
655  if(word != "Ranges:"){
656  errorLog << "load(fstream &file) - Could not find the Ranges" << std::endl;
657  return false;
658  }
659  for(UINT n=0; n<ranges.size(); n++){
660  file >> ranges[n].minValue;
661  file >> ranges[n].maxValue;
662  }
663  }
664 
665  //Read the model header
666  file >> word;
667  if(word != "Models:"){
668  errorLog << "load(fstream &file) - Could not find the Models Header" << std::endl;
669  return false;
670  }
671 
672  //Resize the buffer
673  models.resize(numClasses);
674  classLabels.resize(numClasses);
675 
676  //Load each of the models
677  for(UINT k=0; k<numClasses; k++){
678  UINT classLabel = 0;
679  UINT K = 0;
680  Float normalizationFactor;
681  Float trainingMu;
682  Float trainingSigma;
683  Float rejectionThreshold;
684 
685  file >> word;
686  if(word != "ClassLabel:"){
687  errorLog << "load(fstream &file) - Could not find the ClassLabel for model " << k+1 << std::endl;
688  return false;
689  }
690  file >> classLabel;
691  models[k].setClassLabel( classLabel );
692  classLabels[k] = classLabel;
693 
694  file >> word;
695  if(word != "K:"){
696  errorLog << "load(fstream &file) - Could not find K for model " << k+1 << std::endl;
697  return false;
698  }
699  file >> K;
700 
701  file >> word;
702  if(word != "NormalizationFactor:"){
703  errorLog << "load(fstream &file) - Could not find NormalizationFactor for model " << k+1 << std::endl;
704  return false;
705  }
706  file >> normalizationFactor;
707  models[k].setNormalizationFactor(normalizationFactor);
708 
709  file >> word;
710  if(word != "TrainingMu:"){
711  errorLog << "load(fstream &file) - Could not find TrainingMu for model " << k+1 << std::endl;
712  return false;
713  }
714  file >> trainingMu;
715 
716  file >> word;
717  if(word != "TrainingSigma:"){
718  errorLog << "load(fstream &file) - Could not find TrainingSigma for model " << k+1 << std::endl;
719  return false;
720  }
721  file >> trainingSigma;
722 
723  //Set the training mu and sigma
724  models[k].setTrainingMuAndSigma(trainingMu, trainingSigma);
725 
726  file >> word;
727  if(word != "NullRejectionThreshold:"){
728  errorLog << "load(fstream &file) - Could not find NullRejectionThreshold for model " << k+1 << std::endl;
729  return false;
730  }
731  file >>rejectionThreshold;
732 
733  //Set the rejection threshold
734  models[k].setNullRejectionThreshold(rejectionThreshold);
735 
736  //Resize the buffer for the mixture models
737  models[k].resize(K);
738 
739  //Load the mixture models
740  for(UINT index=0; index<models[k].getK(); index++){
741 
742  //Resize the memory for the current mixture model
743  models[k][index].mu.resize( numInputDimensions );
744  models[k][index].sigma.resize( numInputDimensions, numInputDimensions );
745  models[k][index].invSigma.resize( numInputDimensions, numInputDimensions );
746 
747  file >> word;
748  if(word != "Determinant:"){
749  errorLog << "load(fstream &file) - Could not find the Determinant for model " << k+1 << std::endl;
750  return false;
751  }
752  file >> models[k][index].det;
753 
754 
755  file >> word;
756  if(word != "Mu:"){
757  errorLog << "load(fstream &file) - Could not find Mu for model " << k+1 << std::endl;
758  return false;
759  }
760  for(UINT j=0; j<models[k][index].mu.size(); j++){
761  file >> models[k][index].mu[j];
762  }
763 
764 
765  file >> word;
766  if(word != "Sigma:"){
767  errorLog << "load(fstream &file) - Could not find Sigma for model " << k+1 << std::endl;
768  return false;
769  }
770  for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
771  for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
772  file >> models[k][index].sigma[i][j];
773  }
774  }
775 
776  file >> word;
777  if(word != "InvSigma:"){
778  errorLog << "load(fstream &file) - Could not find InvSigma for model " << k+1 << std::endl;
779  return false;
780  }
781  for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
782  for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
783  file >> models[k][index].invSigma[i][j];
784  }
785  }
786 
787  }
788 
789  }
790 
791  //Set the null rejection thresholds
792  nullRejectionThresholds.resize(numClasses);
793  for(UINT k=0; k<numClasses; k++) {
794  models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
795  nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
796  }
797 
798  //Flag that the models have been trained
799  trained = true;
800 
801  return true;
802 }
803 
804 GRT_END_NAMESPACE
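
Example usage

The following is a minimal usage sketch, not part of the library source, showing how the GMM classifier implemented above is typically driven. It assumes the GRT umbrella header (GRT/GRT.h) and a pre-recorded training data file; the file name "TrainingData.grt" is a placeholder, and the constructor arguments simply request 2 mixture models per class with scaling enabled.

#include <GRT/GRT.h>
#include <iostream>
using namespace GRT;

int main(){
    //Load pre-recorded training data (the file name is a placeholder)
    ClassificationData trainingData;
    if( !trainingData.load( "TrainingData.grt" ) ) return 1;

    //Create a GMM classifier with 2 mixture models per class and scaling enabled
    GMM gmm( 2, true );

    //Fit a Gaussian Mixture Model to each class in the training data
    if( !gmm.train( trainingData ) ) return 1;

    //Predict the class label of a sample (here, simply the first training sample)
    VectorFloat sample = trainingData[0].getSample();
    if( gmm.predict( sample ) ){
        std::cout << "Predicted class label: " << gmm.getPredictedClassLabel() << std::endl;
        std::cout << "Maximum likelihood: " << gmm.getMaximumLikelihood() << std::endl;
    }

    return 0;
}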