28 ANBC::ANBC(
bool useScaling,
bool useNullRejection,Float nullRejectionCoeff)
30 this->useScaling = useScaling;
31 this->useNullRejection = useNullRejection;
32 this->nullRejectionCoeff = nullRejectionCoeff;
33 supportsNullRejection =
true;
34 weightsDataSet =
false;
36 classifierType = classType;
37 classifierMode = STANDARD_CLASSIFIER_MODE;
38 debugLog.setProceedingText(
"[DEBUG ANBC]");
39 errorLog.setProceedingText(
"[ERROR ANBC]");
40 trainingLog.setProceedingText(
"[TRAINING ANBC]");
41 warningLog.setProceedingText(
"[WARNING ANBC]");
46 classifierType = classType;
47 classifierMode = STANDARD_CLASSIFIER_MODE;
48 debugLog.setProceedingText(
"[DEBUG ANBC]");
49 errorLog.setProceedingText(
"[ERROR ANBC]");
50 trainingLog.setProceedingText(
"[TRAINING ANBC]");
51 warningLog.setProceedingText(
"[WARNING ANBC]");
62 this->weightsDataSet = rhs.weightsDataSet;
63 this->weightsData = rhs.weightsData;
64 this->models = rhs.models;
74 if( classifier == NULL )
return false;
80 this->weightsDataSet = ptr->weightsDataSet;
81 this->weightsData = ptr->weightsData;
82 this->models = ptr->models;
100 errorLog <<
"train_(ClassificationData &trainingData) - Training data has zero samples!" << std::endl;
104 if( weightsDataSet ){
106 errorLog <<
"train_(ClassificationData &trainingData) - The number of dimensions in the weights data (" << weightsData.
getNumDimensions() <<
") is not equal to the number of dimensions of the training data (" << N <<
")" << std::endl;
111 numInputDimensions = N;
120 trainingData.
scale(0, 1);
124 for(UINT k=0; k<numClasses; k++){
130 classLabels[k] = classLabel;
134 if( weightsDataSet ){
135 bool weightsFound =
false;
137 if( weightsData[i].getClassLabel() == classLabel ){
138 weights = weightsData[i].getSample();
145 errorLog <<
"train_(ClassificationData &trainingData) - Failed to find the weights for class " << classLabel << std::endl;
150 for(UINT j=0; j<numInputDimensions; j++) weights[j] = 1.0;
158 for(UINT i=0; i<data.getNumRows(); i++){
159 for(UINT j=0; j<data.getNumCols(); j++){
160 data[i][j] = classData[i][j];
165 models[k].gamma = nullRejectionCoeff;
166 if( !models[k].
train( classLabel, data, weights ) ){
167 errorLog <<
"train_(ClassificationData &trainingData) - Failed to train model for class: " << classLabel << std::endl;
170 if( models[k].N == 0 ){
171 errorLog <<
"train_(ClassificationData &trainingData) - N == 0!" << std::endl;
175 for(UINT j=0; j<numInputDimensions; j++){
176 if( models[k].sigma[j] == 0 ){
177 errorLog <<
"train_(ClassificationData &trainingData) - The standard deviation of column " << j+1 <<
" is zero! Check the training data" << std::endl;
189 nullRejectionThresholds.
resize(numClasses);
190 for(UINT k=0; k<numClasses; k++) {
191 nullRejectionThresholds[k] = models[k].threshold;
202 errorLog <<
"predict_(VectorFloat &inputVector) - ANBC Model Not Trained!" << std::endl;
206 predictedClassLabel = 0;
207 maxLikelihood = -10000;
209 if( !trained )
return false;
211 if( inputVector.size() != numInputDimensions ){
212 errorLog <<
"predict_(VectorFloat &inputVector) - The size of the input vector (" << inputVector.size() <<
") does not match the num features in the model (" << numInputDimensions << std::endl;
217 for(UINT n=0; n<numInputDimensions; n++){
218 inputVector[n] =
scale(inputVector[n], ranges[n].minValue, ranges[n].maxValue,
MIN_SCALE_VALUE, MAX_SCALE_VALUE);
222 if( classLikelihoods.size() != numClasses ) classLikelihoods.
resize(numClasses,0);
223 if( classDistances.size() != numClasses ) classDistances.
resize(numClasses,0);
225 Float classLikelihoodsSum = 0;
226 Float minDist = -99e+99;
227 for(UINT k=0; k<numClasses; k++){
228 classDistances[k] = models[k].predict( inputVector );
231 classLikelihoods[k] = classDistances[k];
234 if( grt_isinf(classLikelihoods[k]) || grt_isnan(classLikelihoods[k]) ){
235 classLikelihoods[k] = 0;
237 classLikelihoods[k] = grt_exp( classLikelihoods[k] );
238 classLikelihoodsSum += classLikelihoods[k];
241 if( classDistances[k] > minDist ){
242 minDist = classDistances[k];
243 predictedClassLabel = k;
249 if( classLikelihoodsSum == 0 ){
250 predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
256 for(UINT k=0; k<numClasses; k++){
257 classLikelihoods[k] /= classLikelihoodsSum;
259 maxLikelihood = classLikelihoods[predictedClassLabel];
261 if( useNullRejection ){
263 if( minDist >= models[predictedClassLabel].threshold ) predictedClassLabel = models[predictedClassLabel].classLabel;
264 else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
265 }
else predictedClassLabel = models[predictedClassLabel].classLabel;
273 if( nullRejectionThresholds.size() != numClasses )
274 nullRejectionThresholds.
resize(numClasses);
275 for(UINT k=0; k<numClasses; k++) {
276 models[k].recomputeThresholdValue(nullRejectionCoeff);
277 nullRejectionThresholds[k] = models[k].threshold;
304 errorLog <<
"saveModelToFile(fstream &file) - The file is not open!" << std::endl;
309 file<<
"GRT_ANBC_MODEL_FILE_V2.0\n";
313 errorLog <<
"saveModelToFile(fstream &file) - Failed to save classifier base settings to file!" << std::endl;
319 for(UINT k=0; k<numClasses; k++){
320 file <<
"*************_MODEL_*************\n";
321 file <<
"Model_ID: " << k+1 << std::endl;
322 file <<
"N: " << models[k].N << std::endl;
323 file <<
"ClassLabel: " << models[k].classLabel << std::endl;
324 file <<
"Threshold: " << models[k].threshold << std::endl;
325 file <<
"Gamma: " << models[k].gamma << std::endl;
326 file <<
"TrainingMu: " << models[k].trainingMu << std::endl;
327 file <<
"TrainingSigma: " << models[k].trainingSigma << std::endl;
330 for(UINT j=0; j<models[k].N; j++){
331 file <<
"\t" << models[k].mu[j];
335 for(UINT j=0; j<models[k].N; j++){
336 file <<
"\t" << models[k].sigma[j];
340 for(UINT j=0; j<models[k].N; j++){
341 file <<
"\t" << models[k].weights[j];
352 numInputDimensions = 0;
359 errorLog <<
"loadModelFromFile(string filename) - Could not open file to load model" << std::endl;
367 if( word ==
"GRT_ANBC_MODEL_FILE_V1.0" ){
372 if(word !=
"GRT_ANBC_MODEL_FILE_V2.0"){
373 errorLog <<
"loadModelFromFile(string filename) - Could not find Model File Header" << std::endl;
379 errorLog <<
"loadModelFromFile(string filename) - Failed to load base settings from file!" << std::endl;
386 models.
resize(numClasses);
389 for(UINT k=0; k<numClasses; k++){
392 if(word !=
"*************_MODEL_*************"){
393 errorLog <<
"loadModelFromFile(string filename) - Could not find header for the "<<k+1<<
"th model" << std::endl;
398 if(word !=
"Model_ID:"){
399 errorLog <<
"loadModelFromFile(string filename) - Could not find model ID for the "<<k+1<<
"th model" << std::endl;
405 errorLog <<
"ANBC: Model ID does not match the current class ID for the "<<k+1<<
"th model" << std::endl;
411 errorLog <<
"ANBC: Could not find N for the "<<k+1<<
"th model" << std::endl;
417 if(word !=
"ClassLabel:"){
418 errorLog <<
"loadModelFromFile(string filename) - Could not find ClassLabel for the "<<k+1<<
"th model" << std::endl;
421 file >> models[k].classLabel;
422 classLabels[k] = models[k].classLabel;
425 if(word !=
"Threshold:"){
426 errorLog <<
"loadModelFromFile(string filename) - Could not find the threshold for the "<<k+1<<
"th model" << std::endl;
429 file >> models[k].threshold;
432 if(word !=
"Gamma:"){
433 errorLog <<
"loadModelFromFile(string filename) - Could not find the gamma parameter for the "<<k+1<<
"th model" << std::endl;
436 file >> models[k].gamma;
439 if(word !=
"TrainingMu:"){
440 errorLog <<
"loadModelFromFile(string filename) - Could not find the training mu parameter for the "<<k+1<<
"th model" << std::endl;
443 file >> models[k].trainingMu;
446 if(word !=
"TrainingSigma:"){
447 errorLog <<
"loadModelFromFile(string filename) - Could not find the training sigma parameter for the "<<k+1<<
"th model" << std::endl;
450 file >> models[k].trainingSigma;
453 models[k].mu.
resize(numInputDimensions);
454 models[k].sigma.
resize(numInputDimensions);
455 models[k].weights.
resize(numInputDimensions);
460 errorLog <<
"loadModelFromFile(string filename) - Could not find the Mu vector for the "<<k+1<<
"th model" << std::endl;
465 for(UINT j=0; j<models[k].N; j++){
468 models[k].mu[j] = value;
472 if(word !=
"Sigma:"){
473 errorLog <<
"loadModelFromFile(string filename) - Could not find the Sigma vector for the "<<k+1<<
"th model" << std::endl;
478 for(UINT j=0; j<models[k].N; j++){
481 models[k].sigma[j] = value;
485 if(word !=
"Weights:"){
486 errorLog <<
"loadModelFromFile(string filename) - Could not find the Weights vector for the "<<k+1<<
"th model" << std::endl;
491 for(UINT j=0; j<models[k].N; j++){
494 models[k].weights[j] = value;
503 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
505 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
513 return nullRejectionThresholds;
518 if( nullRejectionCoeff > 0 ){
519 this->nullRejectionCoeff = nullRejectionCoeff;
529 weightsDataSet =
true;
530 this->weightsData = weightsData;
541 if(word !=
"NumFeatures:"){
542 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find NumFeatures " << std::endl;
545 file >> numInputDimensions;
548 if(word !=
"NumClasses:"){
549 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find NumClasses" << std::endl;
555 if(word !=
"UseScaling:"){
556 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find UseScaling" << std::endl;
562 if(word !=
"UseNullRejection:"){
563 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find UseNullRejection" << std::endl;
566 file >> useNullRejection;
571 ranges.
resize(numInputDimensions);
574 if(word !=
"Ranges:"){
575 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the Ranges" << std::endl;
578 for(UINT n=0; n<ranges.size(); n++){
579 file >> ranges[n].minValue;
580 file >> ranges[n].maxValue;
585 models.
resize(numClasses);
586 classLabels.
resize(numClasses);
589 for(UINT k=0; k<numClasses; k++){
592 if(word !=
"*************_MODEL_*************"){
593 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find header for the "<<k+1<<
"th model" << std::endl;
598 if(word !=
"Model_ID:"){
599 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find model ID for the "<<k+1<<
"th model" << std::endl;
605 errorLog <<
"ANBC: Model ID does not match the current class ID for the "<<k+1<<
"th model" << std::endl;
611 errorLog <<
"ANBC: Could not find N for the "<<k+1<<
"th model" << std::endl;
617 if(word !=
"ClassLabel:"){
618 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find ClassLabel for the "<<k+1<<
"th model" << std::endl;
621 file >> models[k].classLabel;
622 classLabels[k] = models[k].classLabel;
625 if(word !=
"Threshold:"){
626 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the threshold for the "<<k+1<<
"th model" << std::endl;
629 file >> models[k].threshold;
632 if(word !=
"Gamma:"){
633 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the gamma parameter for the "<<k+1<<
"th model" << std::endl;
636 file >> models[k].gamma;
639 if(word !=
"TrainingMu:"){
640 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the training mu parameter for the "<<k+1<<
"th model" << std::endl;
643 file >> models[k].trainingMu;
646 if(word !=
"TrainingSigma:"){
647 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the training sigma parameter for the "<<k+1<<
"th model" << std::endl;
650 file >> models[k].trainingSigma;
653 models[k].mu.
resize(numInputDimensions);
654 models[k].sigma.
resize(numInputDimensions);
655 models[k].weights.
resize(numInputDimensions);
660 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the Mu vector for the "<<k+1<<
"th model" << std::endl;
665 for(UINT j=0; j<models[k].N; j++){
668 models[k].mu[j] = value;
672 if(word !=
"Sigma:"){
673 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the Sigma vector for the "<<k+1<<
"th model" << std::endl;
678 for(UINT j=0; j<models[k].N; j++){
681 models[k].sigma[j] = value;
685 if(word !=
"Weights:"){
686 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the Weights vector for the "<<k+1<<
"th model" << std::endl;
691 for(UINT j=0; j<models[k].N; j++){
694 models[k].weights[j] = value;
698 if(word !=
"*********************************"){
699 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the model footer for the "<<k+1<<
"th model" << std::endl;
712 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
714 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
bool saveBaseSettingsToFile(std::fstream &file) const
#define DEFAULT_NULL_LIKELIHOOD_VALUE
bool loadLegacyModelFromFile(std::fstream &file)
Float scale(const Float &x, const Float &minSource, const Float &maxSource, const Float &minTarget, const Float &maxTarget, const bool constrain=false)
std::string getClassifierType() const
Vector< ClassTracker > getClassTracker() const
virtual bool deepCopyFrom(const Classifier *classifier)
ClassificationData getClassData(const UINT classLabel) const
virtual bool resize(const unsigned int size)
bool setWeights(const ClassificationData &weightsData)
virtual bool train(ClassificationData trainingData)
virtual bool loadModelFromFile(std::fstream &file)
ANBC(bool useScaling=false, bool useNullRejection=false, double nullRejectionCoeff=10.0)
virtual bool recomputeNullRejectionThresholds()
VectorFloat getNullRejectionThresholds() const
This class implements the Adaptive Naive Bayes Classifier algorithm. The Adaptive Naive Bayes Classif...
virtual bool train_(ClassificationData &trainingData)
UINT getNumSamples() const
ANBC & operator=(const ANBC &rhs)
bool copyBaseVariables(const Classifier *classifier)
bool loadBaseSettingsFromFile(std::fstream &file)
virtual bool predict_(VectorFloat &inputVector)
UINT getNumDimensions() const
UINT getNumClasses() const
virtual bool saveModelToFile(std::fstream &file) const
Vector< MinMax > getRanges() const
bool setNullRejectionCoeff(double nullRejectionCoeff)
bool scale(const Float minTarget, const Float maxTarget)