#define GRT_DLL_EXPORTS
#include "GMM.h"

GRT_BEGIN_NAMESPACE

//Define the string that will be used to identify the object
const std::string GMM::id = "GMM";
GMM::GMM( UINT numMixtureModels, bool useScaling, bool useNullRejection, Float nullRejectionCoeff, UINT maxNumEpochs, Float minChange ) : Classifier( GMM::getId() )
{
    classifierMode = STANDARD_CLASSIFIER_MODE;
    this->numMixtureModels = numMixtureModels;
    this->useScaling = useScaling;
    this->useNullRejection = useNullRejection;
    this->nullRejectionCoeff = nullRejectionCoeff;
    this->maxNumEpochs = maxNumEpochs;
    this->minChange = minChange;
}
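//Usage sketch: a minimal, illustrative example of constructing a GMM with explicit
//training parameters. The helper name and the parameter values are assumptions made
//for this example, not values taken from this file.
static GMM makeExampleGMM(){
    const UINT numMixtureModels = 3;      //mixture components per class
    const bool useScaling = true;         //scale inputs into the GMM scale range
    const bool useNullRejection = true;   //reject samples that are far from every class model
    const Float nullRejectionCoeff = 2.0;
    const UINT maxNumEpochs = 100;        //EM training epoch limit
    const Float minChange = 1.0e-5;       //EM convergence threshold
    return GMM( numMixtureModels, useScaling, useNullRejection, nullRejectionCoeff, maxNumEpochs, minChange );
}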
GMM::GMM( const GMM &rhs ) : Classifier( GMM::getId() )
{
    classifierMode = STANDARD_CLASSIFIER_MODE;
    *this = rhs;
}

GMM& GMM::operator=( const GMM &rhs ){
    if( this != &rhs ){
        //Copy the GMM variables, then the base classifier variables
        this->numMixtureModels = rhs.numMixtureModels;
        this->models = rhs.models;
        copyBaseVariables( (Classifier*)&rhs );
    }
    return *this;
}
bool GMM::deepCopyFrom( const Classifier *classifier ){
    
    if( classifier == NULL ) return false;
    
    if( this->getId() == classifier->getId() ){
        //Copy the GMM variables, then the base classifier variables
        GMM *ptr = (GMM*)classifier;
        this->numMixtureModels = ptr->numMixtureModels;
        this->models = ptr->models;
        return copyBaseVariables( classifier );
    }
    return false;
}
bool GMM::predict_( VectorFloat &x ){
    
    predictedClassLabel = 0;
    
    if( classDistances.getSize() != numClasses || classLikelihoods.getSize() != numClasses ){
        classDistances.resize(numClasses);
        classLikelihoods.resize(numClasses);
    }
    
    if( !trained ){
        errorLog << __GRT_LOG__ << " Mixture Models have not been trained!" << std::endl;
        return false;
    }
    
    if( x.getSize() != numInputDimensions ){
        errorLog << __GRT_LOG__ << " The size of the input vector (" << x.getSize() << ") does not match that of the number of features the model was trained with (" << numInputDimensions << ")." << std::endl;
        return false;
    }
    
    //Scale the input data if needed
    if( useScaling ){
        for(UINT i=0; i<numInputDimensions; i++){
            x[i] = grt_scale(x[i], ranges[i].minValue, ranges[i].maxValue, GMM_MIN_SCALE_VALUE, GMM_MAX_SCALE_VALUE);
        }
    }
    
    //Compute the mixture likelihood for each class and track the best match
    UINT bestIndex = 0;
    Float sum = 0;
    bestDistance = 0;
    for(UINT k=0; k<numClasses; k++){
        classDistances[k] = computeMixtureLikelihood(x,k);
        classLikelihoods[k] = classDistances[k];
        sum += classLikelihoods[k];
        if( classLikelihoods[k] > bestDistance ){
            bestDistance = classLikelihoods[k];
            bestIndex = k;
        }
    }
    
    //Normalize the class likelihoods
    for(unsigned int k=0; k<numClasses; k++){
        classLikelihoods[k] /= sum;
    }
    maxLikelihood = classLikelihoods[bestIndex];
    
    if( useNullRejection ){
        //Only accept the best class if its distance is above the model's null rejection threshold
        if( classDistances[bestIndex] >= models[bestIndex].getNullRejectionThreshold() ){
            predictedClassLabel = models[bestIndex].getClassLabel();
        }else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
    }else{
        predictedClassLabel = models[bestIndex].getClassLabel();
    }
    
    return true;
}
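//Usage sketch: a minimal, illustrative example of running prediction with a trained GMM.
//The helper name is hypothetical; getPredictedClassLabel() is assumed to come from the
//Classifier base class.
static UINT examplePredictGMM( GMM &gmm, const VectorFloat &sample ){
    VectorFloat input = sample;  //copy, as predict_ may scale the vector in place
    if( !gmm.predict_( input ) ) return GRT_DEFAULT_NULL_CLASS_LABEL;
    return gmm.getPredictedClassLabel();
}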
bool GMM::train_( ClassificationData &trainingData ){
    
    //Clear any previous models
    clear();
    
    if( trainingData.getNumSamples() == 0 ){
        errorLog << __GRT_LOG__ << " Training data is empty!" << std::endl;
        return false;
    }
    
    //Set the number of features and classes and resize the models buffer
    numInputDimensions = trainingData.getNumDimensions();
    numClasses = trainingData.getNumClasses();
    models.resize(numClasses);
    
    if( numInputDimensions >= 6 ){
        warningLog << __GRT_LOG__ << " The number of features in your training data is high (" << numInputDimensions << "). The GMMClassifier does not work well with high dimensional data, you might get better results from one of the other classifiers." << std::endl;
    }
    
    //Get the ranges of the training data and scale it to the GMM range
    ranges = trainingData.getRanges();
    if( !trainingData.scale(GMM_MIN_SCALE_VALUE, GMM_MAX_SCALE_VALUE) ){
        errorLog << __GRT_LOG__ << " Failed to scale training data!" << std::endl;
        return false;
    }
    
    //Split off a validation set if needed
    ClassificationData validationData;
    if( useValidationSet ){
        validationData = trainingData.split( 100-validationSetSize );
    }
    
    //Fit a Gaussian Mixture Model to the data from each class
    for(UINT k=0; k<numClasses; k++){
        const UINT classLabel = trainingData.getClassTracker()[k].classLabel;
        ClassificationData classData = trainingData.getClassData( classLabel );
        
        //Train the mixture model for this class
        GaussianMixtureModels gaussianMixtureModel;
        gaussianMixtureModel.setNumClusters( numMixtureModels );
        gaussianMixtureModel.setMinChange( minChange );
        gaussianMixtureModel.setMaxNumEpochs( maxNumEpochs );
        if( !gaussianMixtureModel.train( classData.getDataAsMatrixFloat() ) ){
            errorLog << __GRT_LOG__ << " Failed to train Mixture Model for class " << classLabel << std::endl;
            return false;
        }
        
        //Set up the model container for this class
        models[k].resize( numMixtureModels );
        models[k].setClassLabel( classLabel );
        
        //Store each mixture component, precomputing the inverse covariance and determinant for prediction
        for(UINT j=0; j<numMixtureModels; j++){
            models[k][j].mu = gaussianMixtureModel.getMu().getRowVector(j);
            models[k][j].sigma = gaussianMixtureModel.getSigma()[j];
            
            LUDecomposition ludcmp( models[k][j].sigma );
            if( !ludcmp.inverse( models[k][j].invSigma ) ){
                models.clear();
                errorLog << __GRT_LOG__ << " Failed to invert Matrix for class " << classLabel << "!" << std::endl;
                return false;
            }
            models[k][j].det = ludcmp.det();
        }
        
        //Compute the normalization factor for this class model
        models[k].recomputeNormalizationFactor();
        
        //Estimate the mean and standard deviation of the training likelihoods for null rejection
        Float mu = 0;
        Float sigma = 0;
        VectorFloat predictionResults(classData.getNumSamples(),0);
        for(UINT i=0; i<classData.getNumSamples(); i++){
            VectorFloat sample = classData[i].getSample();
            predictionResults[i] = models[k].computeMixtureLikelihood( sample );
            mu += predictionResults[i];
        }
        mu /= Float( classData.getNumSamples() );
        
        for(UINT i=0; i<classData.getNumSamples(); i++){
            sigma += grt_sqr( (predictionResults[i]-mu) );
        }
        sigma = grt_sqrt( sigma / (Float(classData.getNumSamples())-1.0) );
        
        models[k].setTrainingMuAndSigma(mu,sigma);
        
        if( !models[k].recomputeNullRejectionThreshold(nullRejectionCoeff) && useNullRejection ){
            warningLog << __GRT_LOG__ << " Failed to recompute rejection threshold for class " << classLabel << " - the nullRejectionCoeff value is too high!" << std::endl;
        }
    }
    
    //Reset the class labels and null rejection thresholds
    classLabels.resize(numClasses);
    for(UINT k=0; k<numClasses; k++){
        classLabels[k] = models[k].getClassLabel();
    }
    nullRejectionThresholds.resize(numClasses);
    for(UINT k=0; k<numClasses; k++){
        nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
    }
    
    //Flag that the models have been trained and compute the final training stats
    trained = true;
    trainingSetAccuracy = 0;
    validationSetAccuracy = 0;
    
    //The training data has already been scaled, so temporarily disable scaling while computing the accuracy
    bool scalingState = useScaling;
    useScaling = false;
    if( !computeAccuracy( trainingData, trainingSetAccuracy ) ){
        trained = false;
        errorLog << __GRT_LOG__ << " Failed to compute training set accuracy! Failed to fully train model!" << std::endl;
        return false;
    }
    
    if( useValidationSet ){
        if( !computeAccuracy( validationData, validationSetAccuracy ) ){
            trained = false;
            errorLog << __GRT_LOG__ << " Failed to compute validation set accuracy! Failed to fully train model!" << std::endl;
            return false;
        }
    }
    
    trainingLog << "Training set accuracy: " << trainingSetAccuracy << std::endl;
    if( useValidationSet ){
        trainingLog << "Validation set accuracy: " << validationSetAccuracy << std::endl;
    }
    
    //Restore the scaling state for future predictions
    useScaling = scalingState;
    
    return trained;
}
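//Usage sketch: a minimal, illustrative example of training this classifier from a saved
//dataset. The helper name and file path are hypothetical; ClassificationData::load is
//assumed to be the string-based convenience loader provided by GRT.
static bool exampleTrainGMM(){
    ClassificationData trainingData;
    if( !trainingData.load( "training_data.grt" ) ) return false;  //hypothetical path
    GMM gmm( 2, true, true, 2.0 );  //2 mixtures per class, scaling and null rejection enabled
    return gmm.train( trainingData );
}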
Float GMM::computeMixtureLikelihood( const VectorFloat &x, const UINT k ){
    if( k >= numClasses ){
        errorLog << __GRT_LOG__ << " Invalid k value!" << std::endl;
        return 0;
    }
    return models[k].computeMixtureLikelihood( x );
}
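//Note: the per-class likelihood delegated to MixtureModel::computeMixtureLikelihood follows
//the usual Gaussian mixture form (stated here as a general reference, not taken from this file):
//
//    p(x | k) ~ sum_j [ 1 / sqrt( (2*pi)^D * det(Sigma_kj) ) ] * exp( -0.5 * (x - mu_kj)' * invSigma_kj * (x - mu_kj) )
//
//which is why each component's determinant (det) and inverse covariance (invSigma) are
//precomputed during training and serialized with the model below.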
bool GMM::save( std::fstream &file ) const{
    
    if( !trained ){
        errorLog << __GRT_LOG__ << " The model has not been trained!" << std::endl;
        return false;
    }
    
    if( !file.is_open() ){
        errorLog << __GRT_LOG__ << " The file has not been opened!" << std::endl;
        return false;
    }
    
    //Write the header info
    file << "GRT_GMM_MODEL_FILE_V2.0\n";
    
    //Write the classifier settings to the file
    if( !Classifier::saveBaseSettingsToFile(file) ){
        errorLog << __GRT_LOG__ << " Failed to save classifier base settings to file!" << std::endl;
        return false;
    }
    
    file << "NumMixtureModels: " << numMixtureModels << std::endl;
    
    //Write each of the class models
    file << "Models:\n";
    for(UINT k=0; k<numClasses; k++){
        file << "ClassLabel: " << models[k].getClassLabel() << std::endl;
        file << "K: " << models[k].getK() << std::endl;
        file << "NormalizationFactor: " << models[k].getNormalizationFactor() << std::endl;
        file << "TrainingMu: " << models[k].getTrainingMu() << std::endl;
        file << "TrainingSigma: " << models[k].getTrainingSigma() << std::endl;
        file << "NullRejectionThreshold: " << models[k].getNullRejectionThreshold() << std::endl;
        
        for(UINT index=0; index<models[k].getK(); index++){
            file << "Determinant: " << models[k][index].det << std::endl;
            
            file << "Mu:";
            for(UINT j=0; j<models[k][index].mu.size(); j++) file << "\t" << models[k][index].mu[j];
            file << std::endl;
            
            file << "Sigma:\n";
            for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
                for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
                    file << models[k][index].sigma[i][j];
                    if( j < models[k][index].sigma.getNumCols()-1 ) file << "\t";
                }
                file << "\n";
            }
            
            file << "InvSigma:\n";
            for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
                for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
                    file << models[k][index].invSigma[i][j];
                    if( j < models[k][index].invSigma.getNumCols()-1 ) file << "\t";
                }
                file << "\n";
            }
        }
    }
    
    return true;
}
bool GMM::load( std::fstream &file ){
    
    //Reset any previous model
    clear();
    numInputDimensions = 0;
    numClasses = 0;
    models.clear();
    classLabels.clear();
    
    if( !file.is_open() ){
        errorLog << __GRT_LOG__ << " Could not open file to load model" << std::endl;
        return false;
    }
    
    std::string word;
    file >> word;
    
    //Check to see if we should load a legacy file
    if( word == "GRT_GMM_MODEL_FILE_V1.0" ){
        return loadLegacyModelFromFile( file );
    }
    
    //Find the file type header
    if(word != "GRT_GMM_MODEL_FILE_V2.0"){
        errorLog << __GRT_LOG__ << " Could not find Model File Header" << std::endl;
        return false;
    }
    
    //Load the base settings from the file
    if( !Classifier::loadBaseSettingsFromFile(file) ){
        errorLog << __GRT_LOG__ << " Failed to load base settings from file!" << std::endl;
        return false;
    }
    
    file >> word;
    if(word != "NumMixtureModels:"){
        errorLog << __GRT_LOG__ << " Could not find NumMixtureModels" << std::endl;
        return false;
    }
    file >> numMixtureModels;
    
    if( trained ){
        
        //Read the model header
        file >> word;
        if(word != "Models:"){
            errorLog << __GRT_LOG__ << " Could not find the Models Header" << std::endl;
            return false;
        }
        
        //Resize the buffers
        models.resize(numClasses);
        classLabels.resize(numClasses);
        
        //Load each of the class models
        for(UINT k=0; k<numClasses; k++){
            UINT classLabel = 0;
            UINT K = 0;
            Float normalizationFactor;
            Float trainingMu;
            Float trainingSigma;
            Float rejectionThreshold;
            
            file >> word;
            if(word != "ClassLabel:"){
                errorLog << __GRT_LOG__ << " Could not find the ClassLabel for model " << k+1 << std::endl;
                return false;
            }
            file >> classLabel;
            models[k].setClassLabel( classLabel );
            classLabels[k] = classLabel;
            
            file >> word;
            if(word != "K:"){
                errorLog << __GRT_LOG__ << " Could not find K for model " << k+1 << std::endl;
                return false;
            }
            file >> K;
            
            file >> word;
            if(word != "NormalizationFactor:"){
                errorLog << __GRT_LOG__ << " Could not find NormalizationFactor for model " << k+1 << std::endl;
                return false;
            }
            file >> normalizationFactor;
            models[k].setNormalizationFactor(normalizationFactor);
            
            file >> word;
            if(word != "TrainingMu:"){
                errorLog << __GRT_LOG__ << " Could not find TrainingMu for model " << k+1 << std::endl;
                return false;
            }
            file >> trainingMu;
            
            file >> word;
            if(word != "TrainingSigma:"){
                errorLog << __GRT_LOG__ << " Could not find TrainingSigma for model " << k+1 << std::endl;
                return false;
            }
            file >> trainingSigma;
            
            //Set the training mu and sigma
            models[k].setTrainingMuAndSigma(trainingMu, trainingSigma);
            
            file >> word;
            if(word != "NullRejectionThreshold:"){
                errorLog << __GRT_LOG__ << " Could not find NullRejectionThreshold for model " << k+1 << std::endl;
                return false;
            }
            file >> rejectionThreshold;
            
            //Set the loaded rejection threshold in the model
            models[k].setNullRejectionThreshold(rejectionThreshold);
            
            //Resize the buffer for the mixture models
            models[k].resize(K);
            
            //Load each mixture component
            for(UINT index=0; index<models[k].getK(); index++){
                
                //Resize the memory for the current mixture model
                models[k][index].mu.resize( numInputDimensions );
                models[k][index].sigma.resize( numInputDimensions, numInputDimensions );
                models[k][index].invSigma.resize( numInputDimensions, numInputDimensions );
                
                //Load the determinant
                file >> word;
                if(word != "Determinant:"){
                    errorLog << __GRT_LOG__ << " Could not find the Determinant for model " << k+1 << std::endl;
                    return false;
                }
                file >> models[k][index].det;
                
                //Load the mean vector
                file >> word;
                if(word != "Mu:"){
                    errorLog << __GRT_LOG__ << " Could not find Mu for model " << k+1 << std::endl;
                    return false;
                }
                for(UINT j=0; j<models[k][index].mu.getSize(); j++){
                    file >> models[k][index].mu[j];
                }
                
                //Load the covariance matrix
                file >> word;
                if(word != "Sigma:"){
                    errorLog << __GRT_LOG__ << " Could not find Sigma for model " << k+1 << std::endl;
                    return false;
                }
                for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
                    for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
                        file >> models[k][index].sigma[i][j];
                    }
                }
                
                //Load the inverse covariance matrix
                file >> word;
                if(word != "InvSigma:"){
                    errorLog << __GRT_LOG__ << " Could not find InvSigma for model " << k+1 << std::endl;
                    return false;
                }
                for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
                    for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
                        file >> models[k][index].invSigma[i][j];
                    }
                }
            }
        }
        
        //Recompute the null rejection thresholds
        nullRejectionThresholds.resize(numClasses);
        for(UINT k=0; k<numClasses; k++) {
            models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
            nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
        }
        
        //Resize the prediction buffers so the model is ready for realtime prediction
        maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
        bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
        classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
        classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
    }
    
    return true;
}
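//Note: the V2.0 model file parsed above mirrors what save() writes, approximately:
//
//    GRT_GMM_MODEL_FILE_V2.0
//    <base classifier settings written by Classifier::saveBaseSettingsToFile>
//    NumMixtureModels: <UINT>
//    Models:
//    ClassLabel: <UINT>                  (repeated once per class...)
//    K: <UINT>
//    NormalizationFactor: <Float>
//    TrainingMu: <Float>
//    TrainingSigma: <Float>
//    NullRejectionThreshold: <Float>
//    Determinant: <Float>                (...and once per mixture component:)
//    Mu: <numInputDimensions values>
//    Sigma: <numInputDimensions x numInputDimensions values>
//    InvSigma: <numInputDimensions x numInputDimensions values>

//Usage sketch: a minimal, illustrative example of persisting a trained model and restoring it
//into a fresh instance. The helper name and file path are hypothetical; the string-based
//save/load overloads are assumed to come from the GRT MLBase interface.
static bool exampleSaveAndReloadGMM( const GMM &trainedGMM ){
    if( !trainedGMM.save( "gmm_model.grt" ) ) return false;  //hypothetical path
    GMM restored;
    return restored.load( "gmm_model.grt" );
}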
bool GMM::recomputeNullRejectionThresholds(){
    if( !trained ) return false;
    for(UINT k=0; k<numClasses; k++) {
        models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
        nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
    }
    return true;
}
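//Note: the exact threshold update lives in MixtureModel::recomputeNullRejectionThreshold; the
//assumption here is that it is derived from the training likelihood statistics stored with the
//model (roughly trainingMu minus nullRejectionCoeff times trainingSigma), so larger coefficients
//lower the threshold and make null rejection less aggressive.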
UINT GMM::getNumMixtureModels(){
    return numMixtureModels;
}

Vector< MixtureModel > GMM::getModels(){
    if( trained ){ return models; }
    return Vector< MixtureModel >();
}

bool GMM::setNumMixtureModels(const UINT K){
    if( K == 0 ) return false;
    numMixtureModels = K;
    return true;
}

bool GMM::setMaxIter(UINT maxNumEpochs){
    return setMaxNumEpochs( maxNumEpochs );
}
bool GMM::loadLegacyModelFromFile( std::fstream &file ){
    
    std::string word;
    
    file >> word;
    if(word != "NumFeatures:"){
        errorLog << __GRT_LOG__ << " Could not find NumFeatures " << std::endl;
        return false;
    }
    file >> numInputDimensions;
    
    file >> word;
    if(word != "NumClasses:"){
        errorLog << __GRT_LOG__ << " Could not find NumClasses" << std::endl;
        return false;
    }
    file >> numClasses;
    
    file >> word;
    if(word != "NumMixtureModels:"){
        errorLog << __GRT_LOG__ << " Could not find NumMixtureModels" << std::endl;
        return false;
    }
    file >> numMixtureModels;
    
    file >> word;
    if(word != "MaxIter:"){
        errorLog << __GRT_LOG__ << " Could not find MaxIter" << std::endl;
        return false;
    }
    file >> maxNumEpochs;
    
    file >> word;
    if(word != "MinChange:"){
        errorLog << __GRT_LOG__ << " Could not find MinChange" << std::endl;
        return false;
    }
    file >> minChange;
    
    file >> word;
    if(word != "UseScaling:"){
        errorLog << __GRT_LOG__ << " Could not find UseScaling" << std::endl;
        return false;
    }
    file >> useScaling;
    
    file >> word;
    if(word != "UseNullRejection:"){
        errorLog << __GRT_LOG__ << " Could not find UseNullRejection" << std::endl;
        return false;
    }
    file >> useNullRejection;
    
    file >> word;
    if(word != "NullRejectionCoeff:"){
        errorLog << __GRT_LOG__ << " Could not find NullRejectionCoeff" << std::endl;
        return false;
    }
    file >> nullRejectionCoeff;
    
    //Load the ranges if the data was scaled during training
    if( useScaling ){
        ranges.resize(numInputDimensions);
        
        file >> word;
        if(word != "Ranges:"){
            errorLog << __GRT_LOG__ << " Could not find the Ranges" << std::endl;
            return false;
        }
        for(UINT n=0; n<ranges.size(); n++){
            file >> ranges[n].minValue;
            file >> ranges[n].maxValue;
        }
    }
    
    //Read the model header
    file >> word;
    if(word != "Models:"){
        errorLog << __GRT_LOG__ << " Could not find the Models Header" << std::endl;
        return false;
    }
    
    //Resize the buffers
    models.resize(numClasses);
    classLabels.resize(numClasses);
    
    //Load each of the class models
    for(UINT k=0; k<numClasses; k++){
        UINT classLabel = 0;
        UINT K = 0;
        Float normalizationFactor;
        Float trainingMu;
        Float trainingSigma;
        Float rejectionThreshold;
        
        file >> word;
        if(word != "ClassLabel:"){
            errorLog << __GRT_LOG__ << " Could not find the ClassLabel for model " << k+1 << std::endl;
            return false;
        }
        file >> classLabel;
        models[k].setClassLabel( classLabel );
        classLabels[k] = classLabel;
        
        file >> word;
        if(word != "K:"){
            errorLog << __GRT_LOG__ << " Could not find K for model " << k+1 << std::endl;
            return false;
        }
        file >> K;
        
        file >> word;
        if(word != "NormalizationFactor:"){
            errorLog << __GRT_LOG__ << " Could not find NormalizationFactor for model " << k+1 << std::endl;
            return false;
        }
        file >> normalizationFactor;
        models[k].setNormalizationFactor(normalizationFactor);
        
        file >> word;
        if(word != "TrainingMu:"){
            errorLog << __GRT_LOG__ << " Could not find TrainingMu for model " << k+1 << std::endl;
            return false;
        }
        file >> trainingMu;
        
        file >> word;
        if(word != "TrainingSigma:"){
            errorLog << __GRT_LOG__ << " Could not find TrainingSigma for model " << k+1 << std::endl;
            return false;
        }
        file >> trainingSigma;
        
        //Set the training mu and sigma
        models[k].setTrainingMuAndSigma(trainingMu, trainingSigma);
        
        file >> word;
        if(word != "NullRejectionThreshold:"){
            errorLog << __GRT_LOG__ << " Could not find NullRejectionThreshold for model " << k+1 << std::endl;
            return false;
        }
        file >> rejectionThreshold;
        
        //Set the loaded rejection threshold in the model
        models[k].setNullRejectionThreshold(rejectionThreshold);
        
        //Resize the buffer for the mixture models
        models[k].resize(K);
        
        //Load each mixture component
        for(UINT index=0; index<models[k].getK(); index++){
            
            //Resize the memory for the current mixture model
            models[k][index].mu.resize( numInputDimensions );
            models[k][index].sigma.resize( numInputDimensions, numInputDimensions );
            models[k][index].invSigma.resize( numInputDimensions, numInputDimensions );
            
            //Load the determinant
            file >> word;
            if(word != "Determinant:"){
                errorLog << __GRT_LOG__ << " Could not find the Determinant for model " << k+1 << std::endl;
                return false;
            }
            file >> models[k][index].det;
            
            //Load the mean vector
            file >> word;
            if(word != "Mu:"){
                errorLog << __GRT_LOG__ << " Could not find Mu for model " << k+1 << std::endl;
                return false;
            }
            for(UINT j=0; j<models[k][index].mu.getSize(); j++){
                file >> models[k][index].mu[j];
            }
            
            //Load the covariance matrix
            file >> word;
            if(word != "Sigma:"){
                errorLog << __GRT_LOG__ << " Could not find Sigma for model " << k+1 << std::endl;
                return false;
            }
            for(UINT i=0; i<models[k][index].sigma.getNumRows(); i++){
                for(UINT j=0; j<models[k][index].sigma.getNumCols(); j++){
                    file >> models[k][index].sigma[i][j];
                }
            }
            
            //Load the inverse covariance matrix
            file >> word;
            if(word != "InvSigma:"){
                errorLog << __GRT_LOG__ << " Could not find InvSigma for model " << k+1 << std::endl;
                return false;
            }
            for(UINT i=0; i<models[k][index].invSigma.getNumRows(); i++){
                for(UINT j=0; j<models[k][index].invSigma.getNumCols(); j++){
                    file >> models[k][index].invSigma[i][j];
                }
            }
        }
    }
    
    //Recompute the null rejection thresholds
    nullRejectionThresholds.resize(numClasses);
    for(UINT k=0; k<numClasses; k++) {
        models[k].recomputeNullRejectionThreshold(nullRejectionCoeff);
        nullRejectionThresholds[k] = models[k].getNullRejectionThreshold();
    }
    
    //Flag that the model has been trained
    trained = true;
    
    return true;
}

GRT_END_NAMESPACE