21 #define GRT_DLL_EXPORTS
29 GMM::GMM(UINT numMixtureModels,
bool useScaling,
bool useNullRejection,Float nullRejectionCoeff,UINT maxIter,Float minChange){
31 classifierType = classType;
32 classifierMode = STANDARD_CLASSIFIER_MODE;
33 debugLog.setProceedingText(
"[DEBUG GMM]");
34 errorLog.setProceedingText(
"[ERROR GMM]");
35 warningLog.setProceedingText(
"[WARNING GMM]");
37 this->numMixtureModels = numMixtureModels;
38 this->useScaling = useScaling;
39 this->useNullRejection = useNullRejection;
40 this->nullRejectionCoeff = nullRejectionCoeff;
41 this->maxIter = maxIter;
42 this->minChange = minChange;
47 classifierType = classType;
48 classifierMode = STANDARD_CLASSIFIER_MODE;
49 debugLog.setProceedingText(
"[DEBUG GMM]");
50 errorLog.setProceedingText(
"[ERROR GMM]");
51 warningLog.setProceedingText(
"[WARNING GMM]");
60 this->numMixtureModels = rhs.numMixtureModels;
61 this->maxIter = rhs.maxIter;
62 this->minChange = rhs.minChange;
63 this->models = rhs.models;
65 this->debugLog = rhs.debugLog;
66 this->errorLog = rhs.errorLog;
67 this->warningLog = rhs.warningLog;
77 if( classifier == NULL )
return false;
81 GMM *ptr = (
GMM*)classifier;
83 this->numMixtureModels = ptr->numMixtureModels;
84 this->maxIter = ptr->maxIter;
85 this->minChange = ptr->minChange;
86 this->models = ptr->models;
88 this->debugLog = ptr->debugLog;
89 this->errorLog = ptr->errorLog;
90 this->warningLog = ptr->warningLog;
100 predictedClassLabel = 0;
102 if( classDistances.
getSize() != numClasses || classLikelihoods.
getSize() != numClasses ){
103 classDistances.
resize(numClasses);
104 classLikelihoods.
resize(numClasses);
108 errorLog <<
"predict_(VectorFloat &x) - Mixture Models have not been trained!" << std::endl;
112 if( x.
getSize() != numInputDimensions ){
113 errorLog <<
"predict_(VectorFloat &x) - The size of the input vector (" << x.
getSize() <<
") does not match that of the number of features the model was trained with (" << numInputDimensions <<
")." << std::endl;
118 for(UINT i=0; i<numInputDimensions; i++){
119 x[i] = grt_scale(x[i], ranges[i].minValue, ranges[i].maxValue,
GMM_MIN_SCALE_VALUE, GMM_MAX_SCALE_VALUE);
127 for(UINT k=0; k<numClasses; k++){
128 classDistances[k] = computeMixtureLikelihood(x,k);
131 classLikelihoods[k] = classDistances[k];
132 sum += classLikelihoods[k];
133 if( classLikelihoods[k] > bestDistance ){
134 bestDistance = classLikelihoods[k];
140 for(
unsigned int k=0; k<numClasses; k++){
141 classLikelihoods[k] /= sum;
143 maxLikelihood = classLikelihoods[bestIndex];
145 if( useNullRejection ){
151 if( classDistances[bestIndex] >= models[bestIndex].getNullRejectionThreshold() ){
152 predictedClassLabel = models[bestIndex].getClassLabel();
153 }
else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
156 predictedClassLabel = models[bestIndex].getClassLabel();
168 errorLog <<
"train_(ClassificationData &trainingData) - Training data is empty!" << std::endl;
175 models.
resize(numClasses);
177 if( numInputDimensions >= 6 ){
178 warningLog <<
"train_(ClassificationData &trainingData) - The number of features in your training data is high (" << numInputDimensions <<
"). The GMMClassifier does not work well with high dimensional data, you might get better results from one of the other classifiers." << std::endl;
184 errorLog <<
"train_(ClassificationData &trainingData) - Failed to scale training data!" << std::endl;
189 for(UINT k=0; k<numClasses; k++){
200 errorLog <<
"train_(ClassificationData &trainingData) - Failed to train Mixture Model for class " << classLabel << std::endl;
205 models[k].
resize( numMixtureModels );
206 models[k].setClassLabel( classLabel );
209 for(UINT j=0; j<numMixtureModels; j++){
211 models[k][j].sigma = gaussianMixtureModel.
getSigma()[j];
215 if( !ludcmp.inverse( models[k][j].invSigma ) ){
217 errorLog <<
"train_(ClassificationData &trainingData) - Failed to invert Matrix for class " << classLabel <<
"!" << std::endl;
220 models[k][j].det = ludcmp.det();
224 models[k].recomputeNormalizationFactor();
232 predictionResults[i] = models[k].computeMixtureLikelihood( sample );
233 mu += predictionResults[i];
241 sigma += grt_sqr( (predictionResults[i]-mu) );
242 sigma = grt_sqrt( sigma / (Float(classData.
getNumSamples())-1.0) );
246 models[k].setTrainingMuAndSigma(mu,sigma);
248 if( !models[k].recomputeNullRejectionThreshold(nullRejectionCoeff) && useNullRejection ){
249 warningLog <<
"train_(ClassificationData &trainingData) - Failed to recompute rejection threshold for class " << classLabel <<
" - the nullRjectionCoeff value is too high!" << std::endl;
257 classLabels.
resize(numClasses);
258 for(UINT k=0; k<numClasses; k++){
259 classLabels[k] = models[k].getClassLabel();
263 nullRejectionThresholds.
resize(numClasses);
264 for(UINT k=0; k<numClasses; k++){
265 nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
274 Float GMM::computeMixtureLikelihood(
const VectorFloat &x,
const UINT k){
275 if( k >= numClasses ){
276 errorLog <<
"computeMixtureLikelihood(const VectorFloat x,const UINT k) - Invalid k value!" << std::endl;
279 return models[k].computeMixtureLikelihood( x );
285 errorLog <<
"saveGMMToFile(fstream &file) - The model has not been trained!" << std::endl;
289 if( !file.is_open() )
291 errorLog <<
"saveGMMToFile(fstream &file) - The file has not been opened!" << std::endl;
296 file <<
"GRT_GMM_MODEL_FILE_V2.0\n";
300 errorLog <<
"save(fstream &file) - Failed to save classifier base settings to file!" << std::endl;
304 file <<
"NumMixtureModels: " << numMixtureModels << std::endl;
309 for(UINT k=0; k<numClasses; k++){
310 file <<
"ClassLabel: " << models[k].getClassLabel() << std::endl;
311 file <<
"K: " << models[k].getK() << std::endl;
312 file <<
"NormalizationFactor: " << models[k].getNormalizationFactor() << std::endl;
313 file <<
"TrainingMu: " << models[k].getTrainingMu() << std::endl;
314 file <<
"TrainingSigma: " << models[k].getTrainingSigma() << std::endl;
315 file <<
"NullRejectionThreshold: " << models[k].getNullRejectionThreshold() << std::endl;
317 for(UINT index=0; index<models[k].getK(); index++){
318 file <<
"Determinant: " << models[k][index].det << std::endl;
321 for(UINT j=0; j<models[k][index].mu.size(); j++) file <<
"\t" << models[k][index].mu[j];
325 for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
326 for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
327 file << models[k][index].sigma[i][j];
328 if( j < models[k][index].sigma.getNumCols()-1 ) file <<
"\t";
333 file <<
"InvSigma:\n";
334 for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
335 for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
336 file << models[k][index].invSigma[i][j];
337 if( j < models[k][index].invSigma.getNumCols()-1 ) file <<
"\t";
353 numInputDimensions = 0;
360 errorLog <<
"load(fstream &file) - Could not open file to load model" << std::endl;
368 if( word ==
"GRT_GMM_MODEL_FILE_V1.0" ){
373 if(word !=
"GRT_GMM_MODEL_FILE_V2.0"){
374 errorLog <<
"load(fstream &file) - Could not find Model File Header" << std::endl;
380 errorLog <<
"load(string filename) - Failed to load base settings from file!" << std::endl;
385 if(word !=
"NumMixtureModels:"){
386 errorLog <<
"load(fstream &file) - Could not find NumMixtureModels" << std::endl;
389 file >> numMixtureModels;
395 if(word !=
"Models:"){
396 errorLog <<
"load(fstream &file) - Could not find the Models Header" << std::endl;
401 models.
resize(numClasses);
402 classLabels.
resize(numClasses);
405 for(UINT k=0; k<numClasses; k++){
408 Float normalizationFactor;
411 Float rejectionThreshold;
414 if(word !=
"ClassLabel:"){
415 errorLog <<
"load(fstream &file) - Could not find the ClassLabel for model " << k+1 << std::endl;
419 models[k].setClassLabel( classLabel );
420 classLabels[k] = classLabel;
424 errorLog <<
"load(fstream &file) - Could not find K for model " << k+1 << std::endl;
430 if(word !=
"NormalizationFactor:"){
431 errorLog <<
"load(fstream &file) - Could not find NormalizationFactor for model " << k+1 << std::endl;
434 file >> normalizationFactor;
435 models[k].setNormalizationFactor(normalizationFactor);
438 if(word !=
"TrainingMu:"){
439 errorLog <<
"load(fstream &file) - Could not find TrainingMu for model " << k+1 << std::endl;
445 if(word !=
"TrainingSigma:"){
446 errorLog <<
"load(fstream &file) - Could not find TrainingSigma for model " << k+1 << std::endl;
449 file >> trainingSigma;
452 models[k].setTrainingMuAndSigma(trainingMu, trainingSigma);
455 if(word !=
"NullRejectionThreshold:"){
456 errorLog <<
"load(fstream &file) - Could not find NullRejectionThreshold for model " << k+1 << std::endl;
459 file >>rejectionThreshold;
462 models[k].setNullRejectionThreshold(rejectionThreshold);
468 for(UINT index=0; index<models[k].getK(); index++){
471 models[k][index].mu.
resize( numInputDimensions );
472 models[k][index].sigma.
resize( numInputDimensions, numInputDimensions );
473 models[k][index].invSigma.
resize( numInputDimensions, numInputDimensions );
476 if(word !=
"Determinant:"){
477 errorLog <<
"load(fstream &file) - Could not find the Determinant for model " << k+1 << std::endl;
480 file >> models[k][index].det;
485 errorLog <<
"load(fstream &file) - Could not find Mu for model " << k+1 << std::endl;
488 for(UINT j=0; j<models[k][index].mu.size(); j++){
489 file >> models[k][index].mu[j];
494 if(word !=
"Sigma:"){
495 errorLog <<
"load(fstream &file) - Could not find Sigma for model " << k+1 << std::endl;
498 for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
499 for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
500 file >> models[k][index].sigma[i][j];
505 if(word !=
"InvSigma:"){
506 errorLog <<
"load(fstream &file) - Could not find InvSigma for model " << k+1 << std::endl;
509 for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
510 for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
511 file >> models[k][index].invSigma[i][j];
520 nullRejectionThresholds.
resize(numClasses);
521 for(UINT k=0; k<numClasses; k++) {
522 models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
523 nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
527 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
529 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
549 for(UINT k=0; k<numClasses; k++) {
550 models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
551 nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
559 return numMixtureModels;
563 if( trained ){
return models; }
569 numMixtureModels = K;
576 this->minChange = minChange;
583 this->maxIter = maxIter;
594 if(word !=
"NumFeatures:"){
595 errorLog <<
"load(fstream &file) - Could not find NumFeatures " << std::endl;
598 file >> numInputDimensions;
601 if(word !=
"NumClasses:"){
602 errorLog <<
"load(fstream &file) - Could not find NumClasses" << std::endl;
608 if(word !=
"NumMixtureModels:"){
609 errorLog <<
"load(fstream &file) - Could not find NumMixtureModels" << std::endl;
612 file >> numMixtureModels;
615 if(word !=
"MaxIter:"){
616 errorLog <<
"load(fstream &file) - Could not find MaxIter" << std::endl;
622 if(word !=
"MinChange:"){
623 errorLog <<
"load(fstream &file) - Could not find MinChange" << std::endl;
629 if(word !=
"UseScaling:"){
630 errorLog <<
"load(fstream &file) - Could not find UseScaling" << std::endl;
636 if(word !=
"UseNullRejection:"){
637 errorLog <<
"load(fstream &file) - Could not find UseNullRejection" << std::endl;
640 file >> useNullRejection;
643 if(word !=
"NullRejectionCoeff:"){
644 errorLog <<
"load(fstream &file) - Could not find NullRejectionCoeff" << std::endl;
647 file >> nullRejectionCoeff;
652 ranges.
resize(numInputDimensions);
655 if(word !=
"Ranges:"){
656 errorLog <<
"load(fstream &file) - Could not find the Ranges" << std::endl;
659 for(UINT n=0; n<ranges.size(); n++){
660 file >> ranges[n].minValue;
661 file >> ranges[n].maxValue;
667 if(word !=
"Models:"){
668 errorLog <<
"load(fstream &file) - Could not find the Models Header" << std::endl;
673 models.
resize(numClasses);
674 classLabels.
resize(numClasses);
677 for(UINT k=0; k<numClasses; k++){
680 Float normalizationFactor;
683 Float rejectionThreshold;
686 if(word !=
"ClassLabel:"){
687 errorLog <<
"load(fstream &file) - Could not find the ClassLabel for model " << k+1 << std::endl;
691 models[k].setClassLabel( classLabel );
692 classLabels[k] = classLabel;
696 errorLog <<
"load(fstream &file) - Could not find K for model " << k+1 << std::endl;
702 if(word !=
"NormalizationFactor:"){
703 errorLog <<
"load(fstream &file) - Could not find NormalizationFactor for model " << k+1 << std::endl;
706 file >> normalizationFactor;
707 models[k].setNormalizationFactor(normalizationFactor);
710 if(word !=
"TrainingMu:"){
711 errorLog <<
"load(fstream &file) - Could not find TrainingMu for model " << k+1 << std::endl;
717 if(word !=
"TrainingSigma:"){
718 errorLog <<
"load(fstream &file) - Could not find TrainingSigma for model " << k+1 << std::endl;
721 file >> trainingSigma;
724 models[k].setTrainingMuAndSigma(trainingMu, trainingSigma);
727 if(word !=
"NullRejectionThreshold:"){
728 errorLog <<
"load(fstream &file) - Could not find NullRejectionThreshold for model " << k+1 << std::endl;
731 file >>rejectionThreshold;
734 models[k].setNullRejectionThreshold(rejectionThreshold);
740 for(UINT index=0; index<models[k].getK(); index++){
743 models[k][index].mu.
resize( numInputDimensions );
744 models[k][index].sigma.
resize( numInputDimensions, numInputDimensions );
745 models[k][index].invSigma.
resize( numInputDimensions, numInputDimensions );
748 if(word !=
"Determinant:"){
749 errorLog <<
"load(fstream &file) - Could not find the Determinant for model " << k+1 << std::endl;
752 file >> models[k][index].det;
757 errorLog <<
"load(fstream &file) - Could not find Mu for model " << k+1 << std::endl;
760 for(UINT j=0; j<models[k][index].mu.size(); j++){
761 file >> models[k][index].mu[j];
766 if(word !=
"Sigma:"){
767 errorLog <<
"load(fstream &file) - Could not find Sigma for model " << k+1 << std::endl;
770 for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
771 for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
772 file >> models[k][index].sigma[i][j];
777 if(word !=
"InvSigma:"){
778 errorLog <<
"load(fstream &file) - Could not find InvSigma for model " << k+1 << std::endl;
781 for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
782 for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
783 file >> models[k][index].invSigma[i][j];
792 nullRejectionThresholds.
resize(numClasses);
793 for(UINT k=0; k<numClasses; k++) {
794 models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
795 nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
bool saveBaseSettingsToFile(std::fstream &file) const
#define DEFAULT_NULL_LIKELIHOOD_VALUE
virtual bool train_(ClassificationData &trainingData)
std::string getClassifierType() const
Vector< ClassTracker > getClassTracker() const
virtual bool recomputeNullRejectionThresholds()
ClassificationData getClassData(const UINT classLabel) const
virtual bool resize(const unsigned int size)
virtual bool train(ClassificationData trainingData)
#define GMM_MIN_SCALE_VALUE
Vector< MixtureModel > getModels()
bool setMinChange(const Float minChange)
bool setNumMixtureModels(UINT K)
MatrixFloat getMu() const
This class implements the Gaussian Mixture Model Classifier algorithm. The Gaussian Mixture Model Cla...
UINT getNumSamples() const
virtual bool predict_(VectorFloat &inputVector)
virtual bool save(std::fstream &file) const
GMM(UINT numMixtureModels=2, bool useScaling=false, bool useNullRejection=false, Float nullRejectionCoeff=1.0, UINT maxIter=100, Float minChange=1.0e-5)
Vector< T > getRowVector(const unsigned int r) const
UINT getNumMixtureModels()
bool copyBaseVariables(const Classifier *classifier)
bool loadBaseSettingsFromFile(std::fstream &file)
UINT getNumDimensions() const
UINT getNumClasses() const
Vector< MatrixFloat > getSigma() const
GMM & operator=(const GMM &rhs)
virtual bool load(std::fstream &file)
bool loadLegacyModelFromFile(std::fstream &file)
Vector< MinMax > getRanges() const
MatrixFloat getDataAsMatrixFloat() const
virtual bool deepCopyFrom(const Classifier *classifier)
bool setMaxNumEpochs(const UINT maxNumEpochs)
bool scale(const Float minTarget, const Float maxTarget)
bool setMinChange(Float minChange)
bool setNumClusters(const UINT numClusters)
bool setMaxIter(UINT maxIter)