GestureRecognitionToolkit  Version: 0.1.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine-learning library for real-time gesture recognition.
GMM.cpp
1 /*
2  GRT MIT License
3  Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5  Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6  and associated documentation files (the "Software"), to deal in the Software without restriction,
7  including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9  subject to the following conditions:
10 
11  The above copyright notice and this permission notice shall be included in all copies or substantial
12  portions of the Software.
13 
14  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15  LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16  IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17  WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18  SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 
20  */
21 #include "GMM.h"
22 
23 GRT_BEGIN_NAMESPACE
24 
25 //Register the GMM module with the Classifier base class
26 RegisterClassifierModule< GMM > GMM::registerModule("GMM");
27 
28 GMM::GMM(UINT numMixtureModels,bool useScaling,bool useNullRejection,Float nullRejectionCoeff,UINT maxIter,Float minChange){
29  classType = "GMM";
30  classifierType = classType;
31  classifierMode = STANDARD_CLASSIFIER_MODE;
32  debugLog.setProceedingText("[DEBUG GMM]");
33  errorLog.setProceedingText("[ERROR GMM]");
34  warningLog.setProceedingText("[WARNING GMM]");
35 
36  this->numMixtureModels = numMixtureModels;
37  this->useScaling = useScaling;
38  this->useNullRejection = useNullRejection;
39  this->nullRejectionCoeff = nullRejectionCoeff;
40  this->maxIter = maxIter;
41  this->minChange = minChange;
42 }
43 
44 GMM::GMM(const GMM &rhs){
45  classType = "GMM";
46  classifierType = classType;
47  classifierMode = STANDARD_CLASSIFIER_MODE;
48  debugLog.setProceedingText("[DEBUG GMM]");
49  errorLog.setProceedingText("[ERROR GMM]");
50  warningLog.setProceedingText("[WARNING GMM]");
51  *this = rhs;
52 }
53 
55 
56 GMM& GMM::operator=(const GMM &rhs){
57  if( this != &rhs ){
58 
59  this->numMixtureModels = rhs.numMixtureModels;
60  this->maxIter = rhs.maxIter;
61  this->minChange = rhs.minChange;
62  this->models = rhs.models;
63 
64  this->debugLog = rhs.debugLog;
65  this->errorLog = rhs.errorLog;
66  this->warningLog = rhs.warningLog;
67 
68  //Classifier variables
69  copyBaseVariables( (Classifier*)&rhs );
70  }
71  return *this;
72 }
73 
74 bool GMM::deepCopyFrom(const Classifier *classifier){
75 
76  if( classifier == NULL ) return false;
77 
78  if( this->getClassifierType() == classifier->getClassifierType() ){
79 
80  GMM *ptr = (GMM*)classifier;
81  //Clone the GMM values
82  this->numMixtureModels = ptr->numMixtureModels;
83  this->maxIter = ptr->maxIter;
84  this->minChange = ptr->minChange;
85  this->models = ptr->models;
86 
87  this->debugLog = ptr->debugLog;
88  this->errorLog = ptr->errorLog;
89  this->warningLog = ptr->warningLog;
90 
91  //Clone the classifier variables
92  return copyBaseVariables( classifier );
93  }
94  return false;
95 }
96 
98 
99  predictedClassLabel = 0;
100 
101  if( classDistances.getSize() != numClasses || classLikelihoods.getSize() != numClasses ){
102  classDistances.resize(numClasses);
103  classLikelihoods.resize(numClasses);
104  }
105 
106  if( !trained ){
107  errorLog << "predict_(VectorFloat &x) - Mixture Models have not been trained!" << std::endl;
108  return false;
109  }
110 
111  if( x.getSize() != numInputDimensions ){
112  errorLog << "predict_(VectorFloat &x) - The size of the input vector (" << x.getSize() << ") does not match that of the number of features the model was trained with (" << numInputDimensions << ")." << std::endl;
113  return false;
114  }
115 
116  if( useScaling ){
117  for(UINT i=0; i<numInputDimensions; i++){
118  x[i] = grt_scale(x[i], ranges[i].minValue, ranges[i].maxValue, GMM_MIN_SCALE_VALUE, GMM_MAX_SCALE_VALUE);
119  }
120  }
121 
122  UINT bestIndex = 0;
123  maxLikelihood = 0;
124  bestDistance = 0;
125  Float sum = 0;
126  for(UINT k=0; k<numClasses; k++){
127  classDistances[k] = computeMixtureLikelihood(x,k);
128 
129  //cout << "K: " << k << " Dist: " << classDistances[k] << std::endl;
130  classLikelihoods[k] = classDistances[k];
131  sum += classLikelihoods[k];
132  if( classLikelihoods[k] > bestDistance ){
133  bestDistance = classLikelihoods[k];
134  bestIndex = k;
135  }
136  }
137 
138  //Normalize the likelihoods
139  for(unsigned int k=0; k<numClasses; k++){
140  classLikelihoods[k] /= sum;
141  }
142  maxLikelihood = classLikelihoods[bestIndex];
143 
144  if( useNullRejection ){
145 
146  //cout << "Dist: " << classDistances[bestIndex] << " RejectionThreshold: " << models[bestIndex].getRejectionThreshold() << std::endl;
147 
148  //If the best distance is below the modles rejection threshold then set the predicted class label as the best class label
149  //Otherwise set the predicted class label as the default null rejection class label of 0
150  if( classDistances[bestIndex] >= models[bestIndex].getNullRejectionThreshold() ){
151  predictedClassLabel = models[bestIndex].getClassLabel();
152  }else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
153  }else{
154  //Get the predicted class label
155  predictedClassLabel = models[bestIndex].getClassLabel();
156  }
157 
158  return true;
159 }
160 
161 bool GMM::train_(ClassificationData &trainingData){
162 
163  //Clear any old models
164  clear();
165 
166  if( trainingData.getNumSamples() == 0 ){
167  errorLog << "train_(ClassificationData &trainingData) - Training data is empty!" << std::endl;
168  return false;
169  }
170 
171  //Set the number of features and number of classes and resize the models buffer
172  numInputDimensions = trainingData.getNumDimensions();
173  numClasses = trainingData.getNumClasses();
174  models.resize(numClasses);
175 
176  if( numInputDimensions >= 6 ){
177  warningLog << "train_(ClassificationData &trainingData) - The number of features in your training data is high (" << numInputDimensions << "). The GMMClassifier does not work well with high dimensional data, you might get better results from one of the other classifiers." << std::endl;
178  }
179 
180  //Get the ranges of the training data if the training data is going to be scaled
181  ranges = trainingData.getRanges();
182  if( !trainingData.scale(GMM_MIN_SCALE_VALUE, GMM_MAX_SCALE_VALUE) ){
183  errorLog << "train_(ClassificationData &trainingData) - Failed to scale training data!" << std::endl;
184  return false;
185  }
186 
187  //Fit a Mixture Model to each class (independently)
188  for(UINT k=0; k<numClasses; k++){
189  UINT classLabel = trainingData.getClassTracker()[k].classLabel;
190  ClassificationData classData = trainingData.getClassData( classLabel );
191 
192  //Train the Mixture Model for this class
193  GaussianMixtureModels gaussianMixtureModel;
194  gaussianMixtureModel.setNumClusters( numMixtureModels );
195  gaussianMixtureModel.setMinChange( minChange );
196  gaussianMixtureModel.setMaxNumEpochs( maxIter );
197 
198  if( !gaussianMixtureModel.train( classData.getDataAsMatrixFloat() ) ){
199  errorLog << "train_(ClassificationData &trainingData) - Failed to train Mixture Model for class " << classLabel << std::endl;
200  return false;
201  }
202 
203  //Setup the model container
204  models[k].resize( numMixtureModels );
205  models[k].setClassLabel( classLabel );
206 
207  //Store the mixture model in the container
208  for(UINT j=0; j<numMixtureModels; j++){
209  models[k][j].mu = gaussianMixtureModel.getMu().getRowVector(j);
210  models[k][j].sigma = gaussianMixtureModel.getSigma()[j];
211 
212  //Compute the determinant and invSigma for the realtime prediction
213  LUDecomposition ludcmp( models[k][j].sigma );
214  if( !ludcmp.inverse( models[k][j].invSigma ) ){
215  models.clear();
216  errorLog << "train_(ClassificationData &trainingData) - Failed to invert Matrix for class " << classLabel << "!" << std::endl;
217  return false;
218  }
219  models[k][j].det = ludcmp.det();
220  }
221 
222  //Compute the normalize factor
223  models[k].recomputeNormalizationFactor();
224 
225  //Compute the rejection thresholds
226  Float mu = 0;
227  Float sigma = 0;
228  VectorFloat predictionResults(classData.getNumSamples(),0);
229  for(UINT i=0; i<classData.getNumSamples(); i++){
230  VectorFloat sample = classData[i].getSample();
231  predictionResults[i] = models[k].computeMixtureLikelihood( sample );
232  mu += predictionResults[i];
233  }
234 
235  //Update mu
236  mu /= Float( classData.getNumSamples() );
237 
238  //Calculate the standard deviation
239  for(UINT i=0; i<classData.getNumSamples(); i++)
240  sigma += grt_sqr( (predictionResults[i]-mu) );
241  sigma = grt_sqrt( sigma / (Float(classData.getNumSamples())-1.0) );
242  sigma = 0.2;
243 
244  //Set the models training mu and sigma
245  models[k].setTrainingMuAndSigma(mu,sigma);
246 
247  if( !models[k].recomputeNullRejectionThreshold(nullRejectionCoeff) && useNullRejection ){
248  warningLog << "train_(ClassificationData &trainingData) - Failed to recompute rejection threshold for class " << classLabel << " - the nullRjectionCoeff value is too high!" << std::endl;
249  }
250 
251  //cout << "Training Mu: " << mu << " TrainingSigma: " << sigma << " RejectionThreshold: " << models[k].getNullRejectionThreshold() << std::endl;
252  //models[k].printModelValues();
253  }
254 
255  //Reset the class labels
256  classLabels.resize(numClasses);
257  for(UINT k=0; k<numClasses; k++){
258  classLabels[k] = models[k].getClassLabel();
259  }
260 
261  //Resize the rejection thresholds
262  nullRejectionThresholds.resize(numClasses);
263  for(UINT k=0; k<numClasses; k++){
264  nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
265  }
266 
267  //Flag that the models have been trained
268  trained = true;
269 
270  return true;
271 }
272 
273 Float GMM::computeMixtureLikelihood(const VectorFloat &x,const UINT k){
274  if( k >= numClasses ){
275  errorLog << "computeMixtureLikelihood(const VectorFloat x,const UINT k) - Invalid k value!" << std::endl;
276  return 0;
277  }
278  return models[k].computeMixtureLikelihood( x );
279 }
280 
281 bool GMM::saveModelToFile( std::fstream &file ) const{
282 
283  if( !trained ){
284  errorLog <<"saveGMMToFile(fstream &file) - The model has not been trained!" << std::endl;
285  return false;
286  }
287 
288  if( !file.is_open() )
289  {
290  errorLog <<"saveGMMToFile(fstream &file) - The file has not been opened!" << std::endl;
291  return false;
292  }
293 
294  //Write the header info
295  file << "GRT_GMM_MODEL_FILE_V2.0\n";
296 
297  //Write the classifier settings to the file
299  errorLog <<"saveModelToFile(fstream &file) - Failed to save classifier base settings to file!" << std::endl;
300  return false;
301  }
302 
303  file << "NumMixtureModels: " << numMixtureModels << std::endl;
304 
305  if( trained ){
306  //Write each of the models
307  file << "Models:\n";
308  for(UINT k=0; k<numClasses; k++){
309  file << "ClassLabel: " << models[k].getClassLabel() << std::endl;
310  file << "K: " << models[k].getK() << std::endl;
311  file << "NormalizationFactor: " << models[k].getNormalizationFactor() << std::endl;
312  file << "TrainingMu: " << models[k].getTrainingMu() << std::endl;
313  file << "TrainingSigma: " << models[k].getTrainingSigma() << std::endl;
314  file << "NullRejectionThreshold: " << models[k].getNullRejectionThreshold() << std::endl;
315 
316  for(UINT index=0; index<models[k].getK(); index++){
317  file << "Determinant: " << models[k][index].det << std::endl;
318 
319  file << "Mu: ";
320  for(UINT j=0; j<models[k][index].mu.size(); j++) file << "\t" << models[k][index].mu[j];
321  file << std::endl;
322 
323  file << "Sigma:\n";
324  for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
325  for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
326  file << models[k][index].sigma[i][j];
327  if( j < models[k][index].sigma.getNumCols()-1 ) file << "\t";
328  }
329  file << std::endl;
330  }
331 
332  file << "InvSigma:\n";
333  for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
334  for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
335  file << models[k][index].invSigma[i][j];
336  if( j < models[k][index].invSigma.getNumCols()-1 ) file << "\t";
337  }
338  file << std::endl;
339  }
340  }
341 
342  file << std::endl;
343  }
344  }
345 
346  return true;
347 }
348 
349 bool GMM::loadModelFromFile( std::fstream &file ){
350 
351  trained = false;
352  numInputDimensions = 0;
353  numClasses = 0;
354  models.clear();
355  classLabels.clear();
356 
357  if(!file.is_open())
358  {
359  errorLog << "loadModelFromFile(fstream &file) - Could not open file to load model" << std::endl;
360  return false;
361  }
362 
363  std::string word;
364  file >> word;
365 
366  //Check to see if we should load a legacy file
367  if( word == "GRT_GMM_MODEL_FILE_V1.0" ){
368  return loadLegacyModelFromFile( file );
369  }
370 
371  //Find the file type header
372  if(word != "GRT_GMM_MODEL_FILE_V2.0"){
373  errorLog << "loadModelFromFile(fstream &file) - Could not find Model File Header" << std::endl;
374  return false;
375  }
376 
377  //Load the base settings from the file
379  errorLog << "loadModelFromFile(string filename) - Failed to load base settings from file!" << std::endl;
380  return false;
381  }
382 
383  file >> word;
384  if(word != "NumMixtureModels:"){
385  errorLog << "loadModelFromFile(fstream &file) - Could not find NumMixtureModels" << std::endl;
386  return false;
387  }
388  file >> numMixtureModels;
389 
390  if( trained ){
391 
392  //Read the model header
393  file >> word;
394  if(word != "Models:"){
395  errorLog << "loadModelFromFile(fstream &file) - Could not find the Models Header" << std::endl;
396  return false;
397  }
398 
399  //Resize the buffer
400  models.resize(numClasses);
401  classLabels.resize(numClasses);
402 
403  //Load each of the models
404  for(UINT k=0; k<numClasses; k++){
405  UINT classLabel = 0;
406  UINT K = 0;
407  Float normalizationFactor;
408  Float trainingMu;
409  Float trainingSigma;
410  Float rejectionThreshold;
411 
412  file >> word;
413  if(word != "ClassLabel:"){
414  errorLog << "loadModelFromFile(fstream &file) - Could not find the ClassLabel for model " << k+1 << std::endl;
415  return false;
416  }
417  file >> classLabel;
418  models[k].setClassLabel( classLabel );
419  classLabels[k] = classLabel;
420 
421  file >> word;
422  if(word != "K:"){
423  errorLog << "loadModelFromFile(fstream &file) - Could not find K for model " << k+1 << std::endl;
424  return false;
425  }
426  file >> K;
427 
428  file >> word;
429  if(word != "NormalizationFactor:"){
430  errorLog << "loadModelFromFile(fstream &file) - Could not find NormalizationFactor for model " << k+1 << std::endl;
431  return false;
432  }
433  file >> normalizationFactor;
434  models[k].setNormalizationFactor(normalizationFactor);
435 
436  file >> word;
437  if(word != "TrainingMu:"){
438  errorLog << "loadModelFromFile(fstream &file) - Could not find TrainingMu for model " << k+1 << std::endl;
439  return false;
440  }
441  file >> trainingMu;
442 
443  file >> word;
444  if(word != "TrainingSigma:"){
445  errorLog << "loadModelFromFile(fstream &file) - Could not find TrainingSigma for model " << k+1 << std::endl;
446  return false;
447  }
448  file >> trainingSigma;
449 
450  //Set the training mu and sigma
451  models[k].setTrainingMuAndSigma(trainingMu, trainingSigma);
452 
453  file >> word;
454  if(word != "NullRejectionThreshold:"){
455  errorLog << "loadModelFromFile(fstream &file) - Could not find NullRejectionThreshold for model " << k+1 << std::endl;
456  return false;
457  }
458  file >>rejectionThreshold;
459 
460  //Set the rejection threshold
461  models[k].setNullRejectionThreshold(rejectionThreshold);
462 
463  //Resize the buffer for the mixture models
464  models[k].resize(K);
465 
466  //Load the mixture models
467  for(UINT index=0; index<models[k].getK(); index++){
468 
469  //Resize the memory for the current mixture model
470  models[k][index].mu.resize( numInputDimensions );
471  models[k][index].sigma.resize( numInputDimensions, numInputDimensions );
472  models[k][index].invSigma.resize( numInputDimensions, numInputDimensions );
473 
474  file >> word;
475  if(word != "Determinant:"){
476  errorLog << "loadModelFromFile(fstream &file) - Could not find the Determinant for model " << k+1 << std::endl;
477  return false;
478  }
479  file >> models[k][index].det;
480 
481 
482  file >> word;
483  if(word != "Mu:"){
484  errorLog << "loadModelFromFile(fstream &file) - Could not find Mu for model " << k+1 << std::endl;
485  return false;
486  }
487  for(UINT j=0; j<models[k][index].mu.size(); j++){
488  file >> models[k][index].mu[j];
489  }
490 
491 
492  file >> word;
493  if(word != "Sigma:"){
494  errorLog << "loadModelFromFile(fstream &file) - Could not find Sigma for model " << k+1 << std::endl;
495  return false;
496  }
497  for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
498  for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
499  file >> models[k][index].sigma[i][j];
500  }
501  }
502 
503  file >> word;
504  if(word != "InvSigma:"){
505  errorLog << "loadModelFromFile(fstream &file) - Could not find InvSigma for model " << k+1 << std::endl;
506  return false;
507  }
508  for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
509  for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
510  file >> models[k][index].invSigma[i][j];
511  }
512  }
513 
514  }
515 
516  }
517 
518  //Set the null rejection thresholds
519  nullRejectionThresholds.resize(numClasses);
520  for(UINT k=0; k<numClasses; k++) {
521  models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
522  nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
523  }
524 
525  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
526  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
527  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
528  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
529  }
530 
531  return true;
532 }
533 
534 bool GMM::clear(){
535 
536  //Clear the Classifier variables
538 
539  //Clear the GMM model
540  models.clear();
541 
542  return true;
543 }
544 
546 
547  if( trained ){
548  for(UINT k=0; k<numClasses; k++) {
549  models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
550  nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
551  }
552  return true;
553  }
554  return false;
555 }
556 
558  return numMixtureModels;
559 }
560 
562  if( trained ){ return models; }
563  return Vector< MixtureModel >();
564 }
565 
567  if( K > 0 ){
568  numMixtureModels = K;
569  return true;
570  }
571  return false;
572 }
573 bool GMM::setMinChange(Float minChange){
574  if( minChange > 0 ){
575  this->minChange = minChange;
576  return true;
577  }
578  return false;
579 }
580 bool GMM::setMaxIter(UINT maxIter){
581  if( maxIter > 0 ){
582  this->maxIter = maxIter;
583  return true;
584  }
585  return false;
586 }
587 
588 bool GMM::loadLegacyModelFromFile( std::fstream &file ){
589 
590  std::string word;
591 
592  file >> word;
593  if(word != "NumFeatures:"){
594  errorLog << "loadModelFromFile(fstream &file) - Could not find NumFeatures " << std::endl;
595  return false;
596  }
597  file >> numInputDimensions;
598 
599  file >> word;
600  if(word != "NumClasses:"){
601  errorLog << "loadModelFromFile(fstream &file) - Could not find NumClasses" << std::endl;
602  return false;
603  }
604  file >> numClasses;
605 
606  file >> word;
607  if(word != "NumMixtureModels:"){
608  errorLog << "loadModelFromFile(fstream &file) - Could not find NumMixtureModels" << std::endl;
609  return false;
610  }
611  file >> numMixtureModels;
612 
613  file >> word;
614  if(word != "MaxIter:"){
615  errorLog << "loadModelFromFile(fstream &file) - Could not find MaxIter" << std::endl;
616  return false;
617  }
618  file >> maxIter;
619 
620  file >> word;
621  if(word != "MinChange:"){
622  errorLog << "loadModelFromFile(fstream &file) - Could not find MinChange" << std::endl;
623  return false;
624  }
625  file >> minChange;
626 
627  file >> word;
628  if(word != "UseScaling:"){
629  errorLog << "loadModelFromFile(fstream &file) - Could not find UseScaling" << std::endl;
630  return false;
631  }
632  file >> useScaling;
633 
634  file >> word;
635  if(word != "UseNullRejection:"){
636  errorLog << "loadModelFromFile(fstream &file) - Could not find UseNullRejection" << std::endl;
637  return false;
638  }
639  file >> useNullRejection;
640 
641  file >> word;
642  if(word != "NullRejectionCoeff:"){
643  errorLog << "loadModelFromFile(fstream &file) - Could not find NullRejectionCoeff" << std::endl;
644  return false;
645  }
646  file >> nullRejectionCoeff;
647 
649  if( useScaling ){
650  //Resize the ranges buffer
651  ranges.resize(numInputDimensions);
652 
653  file >> word;
654  if(word != "Ranges:"){
655  errorLog << "loadModelFromFile(fstream &file) - Could not find the Ranges" << std::endl;
656  return false;
657  }
658  for(UINT n=0; n<ranges.size(); n++){
659  file >> ranges[n].minValue;
660  file >> ranges[n].maxValue;
661  }
662  }
663 
664  //Read the model header
665  file >> word;
666  if(word != "Models:"){
667  errorLog << "loadModelFromFile(fstream &file) - Could not find the Models Header" << std::endl;
668  return false;
669  }
670 
671  //Resize the buffer
672  models.resize(numClasses);
673  classLabels.resize(numClasses);
674 
675  //Load each of the models
676  for(UINT k=0; k<numClasses; k++){
677  UINT classLabel = 0;
678  UINT K = 0;
679  Float normalizationFactor;
680  Float trainingMu;
681  Float trainingSigma;
682  Float rejectionThreshold;
683 
684  file >> word;
685  if(word != "ClassLabel:"){
686  errorLog << "loadModelFromFile(fstream &file) - Could not find the ClassLabel for model " << k+1 << std::endl;
687  return false;
688  }
689  file >> classLabel;
690  models[k].setClassLabel( classLabel );
691  classLabels[k] = classLabel;
692 
693  file >> word;
694  if(word != "K:"){
695  errorLog << "loadModelFromFile(fstream &file) - Could not find K for model " << k+1 << std::endl;
696  return false;
697  }
698  file >> K;
699 
700  file >> word;
701  if(word != "NormalizationFactor:"){
702  errorLog << "loadModelFromFile(fstream &file) - Could not find NormalizationFactor for model " << k+1 << std::endl;
703  return false;
704  }
705  file >> normalizationFactor;
706  models[k].setNormalizationFactor(normalizationFactor);
707 
708  file >> word;
709  if(word != "TrainingMu:"){
710  errorLog << "loadModelFromFile(fstream &file) - Could not find TrainingMu for model " << k+1 << std::endl;
711  return false;
712  }
713  file >> trainingMu;
714 
715  file >> word;
716  if(word != "TrainingSigma:"){
717  errorLog << "loadModelFromFile(fstream &file) - Could not find TrainingSigma for model " << k+1 << std::endl;
718  return false;
719  }
720  file >> trainingSigma;
721 
722  //Set the training mu and sigma
723  models[k].setTrainingMuAndSigma(trainingMu, trainingSigma);
724 
725  file >> word;
726  if(word != "NullRejectionThreshold:"){
727  errorLog << "loadModelFromFile(fstream &file) - Could not find NullRejectionThreshold for model " << k+1 << std::endl;
728  return false;
729  }
730  file >>rejectionThreshold;
731 
732  //Set the rejection threshold
733  models[k].setNullRejectionThreshold(rejectionThreshold);
734 
735  //Resize the buffer for the mixture models
736  models[k].resize(K);
737 
738  //Load the mixture models
739  for(UINT index=0; index<models[k].getK(); index++){
740 
741  //Resize the memory for the current mixture model
742  models[k][index].mu.resize( numInputDimensions );
743  models[k][index].sigma.resize( numInputDimensions, numInputDimensions );
744  models[k][index].invSigma.resize( numInputDimensions, numInputDimensions );
745 
746  file >> word;
747  if(word != "Determinant:"){
748  errorLog << "loadModelFromFile(fstream &file) - Could not find the Determinant for model " << k+1 << std::endl;
749  return false;
750  }
751  file >> models[k][index].det;
752 
753 
754  file >> word;
755  if(word != "Mu:"){
756  errorLog << "loadModelFromFile(fstream &file) - Could not find Mu for model " << k+1 << std::endl;
757  return false;
758  }
759  for(UINT j=0; j<models[k][index].mu.size(); j++){
760  file >> models[k][index].mu[j];
761  }
762 
763 
764  file >> word;
765  if(word != "Sigma:"){
766  errorLog << "loadModelFromFile(fstream &file) - Could not find Sigma for model " << k+1 << std::endl;
767  return false;
768  }
769  for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
770  for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
771  file >> models[k][index].sigma[i][j];
772  }
773  }
774 
775  file >> word;
776  if(word != "InvSigma:"){
777  errorLog << "loadModelFromFile(fstream &file) - Could not find InvSigma for model " << k+1 << std::endl;
778  return false;
779  }
780  for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
781  for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
782  file >> models[k][index].invSigma[i][j];
783  }
784  }
785 
786  }
787 
788  }
789 
790  //Set the null rejection thresholds
791  nullRejectionThresholds.resize(numClasses);
792  for(UINT k=0; k<numClasses; k++) {
793  models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
794  nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
795  }
796 
797  //Flag that the models have been trained
798  trained = true;
799 
800  return true;
801 }
802 
803 GRT_END_NAMESPACE
bool saveBaseSettingsToFile(std::fstream &file) const
Definition: Classifier.cpp:255
#define DEFAULT_NULL_LIKELIHOOD_VALUE
Definition: Classifier.h:38
virtual bool train_(ClassificationData &trainingData)
Definition: GMM.cpp:161
Definition: GMM.h:49
virtual bool loadModelFromFile(std::fstream &file)
Definition: GMM.cpp:349
virtual bool clear()
Definition: GMM.cpp:534
std::string getClassifierType() const
Definition: Classifier.cpp:160
Vector< ClassTracker > getClassTracker() const
virtual bool recomputeNullRejectionThresholds()
Definition: GMM.cpp:545
ClassificationData getClassData(const UINT classLabel) const
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:88
#define GMM_MIN_SCALE_VALUE
Definition: GMM.h:44
Vector< MixtureModel > getModels()
Definition: GMM.cpp:561
bool setMinChange(const Float minChange)
Definition: MLBase.cpp:282
bool setNumMixtureModels(UINT K)
Definition: GMM.cpp:566
unsigned int getSize() const
Definition: Vector.h:193
virtual bool saveModelToFile(std::fstream &file) const
Definition: GMM.cpp:281
MatrixFloat getMu() const
This class implements the Gaussian Mixture Model Classifier algorithm. The Gaussian Mixture Model Cla...
UINT getNumSamples() const
virtual bool predict_(VectorFloat &inputVector)
Definition: GMM.cpp:97
GMM(UINT numMixtureModels=2, bool useScaling=false, bool useNullRejection=false, Float nullRejectionCoeff=1.0, UINT maxIter=100, Float minChange=1.0e-5)
Definition: GMM.cpp:28
Vector< T > getRowVector(const unsigned int r) const
Definition: Matrix.h:171
UINT getNumMixtureModels()
Definition: GMM.cpp:557
bool copyBaseVariables(const Classifier *classifier)
Definition: Classifier.cpp:92
bool loadBaseSettingsFromFile(std::fstream &file)
Definition: Classifier.cpp:302
UINT getNumDimensions() const
UINT getNumClasses() const
Vector< MatrixFloat > getSigma() const
GMM & operator=(const GMM &rhs)
Definition: GMM.cpp:56
virtual ~GMM(void)
Definition: GMM.cpp:54
bool loadLegacyModelFromFile(std::fstream &file)
Definition: GMM.cpp:588
Vector< MinMax > getRanges() const
MatrixFloat getDataAsMatrixFloat() const
virtual bool deepCopyFrom(const Classifier *classifier)
Definition: GMM.cpp:74
bool setMaxNumEpochs(const UINT maxNumEpochs)
Definition: MLBase.cpp:268
bool scale(const Float minTarget, const Float maxTarget)
virtual bool clear()
Definition: Classifier.cpp:141
bool setMinChange(Float minChange)
Definition: GMM.cpp:573
bool setNumClusters(const UINT numClusters)
Definition: Clusterer.cpp:265
bool setMaxIter(UINT maxIter)
Definition: GMM.cpp:580