21 #define GRT_DLL_EXPORTS
29 ANBC::ANBC(
bool useScaling,
bool useNullRejection,Float nullRejectionCoeff)
31 this->useScaling = useScaling;
32 this->useNullRejection = useNullRejection;
33 this->nullRejectionCoeff = nullRejectionCoeff;
34 supportsNullRejection =
true;
35 weightsDataSet =
false;
37 classifierType = classType;
38 classifierMode = STANDARD_CLASSIFIER_MODE;
39 debugLog.setProceedingText(
"[DEBUG ANBC]");
40 errorLog.setProceedingText(
"[ERROR ANBC]");
41 trainingLog.setProceedingText(
"[TRAINING ANBC]");
42 warningLog.setProceedingText(
"[WARNING ANBC]");
47 classifierType = classType;
48 classifierMode = STANDARD_CLASSIFIER_MODE;
49 debugLog.setProceedingText(
"[DEBUG ANBC]");
50 errorLog.setProceedingText(
"[ERROR ANBC]");
51 trainingLog.setProceedingText(
"[TRAINING ANBC]");
52 warningLog.setProceedingText(
"[WARNING ANBC]");
63 this->weightsDataSet = rhs.weightsDataSet;
64 this->weightsData = rhs.weightsData;
65 this->models = rhs.models;
75 if( classifier == NULL )
return false;
81 this->weightsDataSet = ptr->weightsDataSet;
82 this->weightsData = ptr->weightsData;
83 this->models = ptr->models;
101 errorLog <<
"train_(ClassificationData &trainingData) - Training data has zero samples!" << std::endl;
105 if( weightsDataSet ){
107 errorLog <<
"train_(ClassificationData &trainingData) - The number of dimensions in the weights data (" << weightsData.
getNumDimensions() <<
") is not equal to the number of dimensions of the training data (" << N <<
")" << std::endl;
112 numInputDimensions = N;
121 trainingData.
scale(0, 1);
125 for(UINT k=0; k<numClasses; k++){
131 classLabels[k] = classLabel;
135 if( weightsDataSet ){
136 bool weightsFound =
false;
138 if( weightsData[i].getClassLabel() == classLabel ){
139 weights = weightsData[i].getSample();
146 errorLog <<
"train_(ClassificationData &trainingData) - Failed to find the weights for class " << classLabel << std::endl;
151 for(UINT j=0; j<numInputDimensions; j++) weights[j] = 1.0;
159 for(UINT i=0; i<data.getNumRows(); i++){
160 for(UINT j=0; j<data.getNumCols(); j++){
161 data[i][j] = classData[i][j];
166 models[k].gamma = nullRejectionCoeff;
167 if( !models[k].
train( classLabel, data, weights ) ){
168 errorLog <<
"train_(ClassificationData &trainingData) - Failed to train model for class: " << classLabel << std::endl;
171 if( models[k].N == 0 ){
172 errorLog <<
"train_(ClassificationData &trainingData) - N == 0!" << std::endl;
176 for(UINT j=0; j<numInputDimensions; j++){
177 if( models[k].sigma[j] == 0 ){
178 errorLog <<
"train_(ClassificationData &trainingData) - The standard deviation of column " << j+1 <<
" is zero! Check the training data" << std::endl;
190 nullRejectionThresholds.
resize(numClasses);
191 for(UINT k=0; k<numClasses; k++) {
192 nullRejectionThresholds[k] = models[k].threshold;
203 errorLog <<
"predict_(VectorFloat &inputVector) - ANBC Model Not Trained!" << std::endl;
207 predictedClassLabel = 0;
208 maxLikelihood = -10000;
210 if( !trained )
return false;
212 if( inputVector.size() != numInputDimensions ){
213 errorLog <<
"predict_(VectorFloat &inputVector) - The size of the input vector (" << inputVector.size() <<
") does not match the num features in the model (" << numInputDimensions << std::endl;
218 for(UINT n=0; n<numInputDimensions; n++){
219 inputVector[n] =
scale(inputVector[n], ranges[n].minValue, ranges[n].maxValue,
MIN_SCALE_VALUE, MAX_SCALE_VALUE);
223 if( classLikelihoods.size() != numClasses ) classLikelihoods.
resize(numClasses,0);
224 if( classDistances.size() != numClasses ) classDistances.
resize(numClasses,0);
226 Float classLikelihoodsSum = 0;
227 Float minDist = -99e+99;
228 for(UINT k=0; k<numClasses; k++){
229 classDistances[k] = models[k].predict( inputVector );
232 classLikelihoods[k] = classDistances[k];
235 if( grt_isinf(classLikelihoods[k]) || grt_isnan(classLikelihoods[k]) ){
236 classLikelihoods[k] = 0;
238 classLikelihoods[k] = grt_exp( classLikelihoods[k] );
239 classLikelihoodsSum += classLikelihoods[k];
242 if( classDistances[k] > minDist ){
243 minDist = classDistances[k];
244 predictedClassLabel = k;
250 if( classLikelihoodsSum == 0 ){
251 predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
257 for(UINT k=0; k<numClasses; k++){
258 classLikelihoods[k] /= classLikelihoodsSum;
260 maxLikelihood = classLikelihoods[predictedClassLabel];
262 if( useNullRejection ){
264 if( minDist >= models[predictedClassLabel].threshold ) predictedClassLabel = models[predictedClassLabel].classLabel;
265 else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
266 }
else predictedClassLabel = models[predictedClassLabel].classLabel;
274 if( nullRejectionThresholds.size() != numClasses )
275 nullRejectionThresholds.
resize(numClasses);
276 for(UINT k=0; k<numClasses; k++) {
277 models[k].recomputeThresholdValue(nullRejectionCoeff);
278 nullRejectionThresholds[k] = models[k].threshold;
305 errorLog <<
"save(fstream &file) - The file is not open!" << std::endl;
310 file<<
"GRT_ANBC_MODEL_FILE_V2.0\n";
314 errorLog <<
"save(fstream &file) - Failed to save classifier base settings to file!" << std::endl;
320 for(UINT k=0; k<numClasses; k++){
321 file <<
"*************_MODEL_*************\n";
322 file <<
"Model_ID: " << k+1 << std::endl;
323 file <<
"N: " << models[k].N << std::endl;
324 file <<
"ClassLabel: " << models[k].classLabel << std::endl;
325 file <<
"Threshold: " << models[k].threshold << std::endl;
326 file <<
"Gamma: " << models[k].gamma << std::endl;
327 file <<
"TrainingMu: " << models[k].trainingMu << std::endl;
328 file <<
"TrainingSigma: " << models[k].trainingSigma << std::endl;
331 for(UINT j=0; j<models[k].N; j++){
332 file <<
"\t" << models[k].mu[j];
336 for(UINT j=0; j<models[k].N; j++){
337 file <<
"\t" << models[k].sigma[j];
341 for(UINT j=0; j<models[k].N; j++){
342 file <<
"\t" << models[k].weights[j];
353 numInputDimensions = 0;
360 errorLog <<
"load(string filename) - Could not open file to load model" << std::endl;
368 if( word ==
"GRT_ANBC_MODEL_FILE_V1.0" ){
373 if(word !=
"GRT_ANBC_MODEL_FILE_V2.0"){
374 errorLog <<
"load(string filename) - Could not find Model File Header" << std::endl;
380 errorLog <<
"load(string filename) - Failed to load base settings from file!" << std::endl;
387 models.
resize(numClasses);
390 for(UINT k=0; k<numClasses; k++){
393 if(word !=
"*************_MODEL_*************"){
394 errorLog <<
"load(string filename) - Could not find header for the "<<k+1<<
"th model" << std::endl;
399 if(word !=
"Model_ID:"){
400 errorLog <<
"load(string filename) - Could not find model ID for the "<<k+1<<
"th model" << std::endl;
406 errorLog <<
"ANBC: Model ID does not match the current class ID for the "<<k+1<<
"th model" << std::endl;
412 errorLog <<
"ANBC: Could not find N for the "<<k+1<<
"th model" << std::endl;
418 if(word !=
"ClassLabel:"){
419 errorLog <<
"load(string filename) - Could not find ClassLabel for the "<<k+1<<
"th model" << std::endl;
422 file >> models[k].classLabel;
423 classLabels[k] = models[k].classLabel;
426 if(word !=
"Threshold:"){
427 errorLog <<
"load(string filename) - Could not find the threshold for the "<<k+1<<
"th model" << std::endl;
430 file >> models[k].threshold;
433 if(word !=
"Gamma:"){
434 errorLog <<
"load(string filename) - Could not find the gamma parameter for the "<<k+1<<
"th model" << std::endl;
437 file >> models[k].gamma;
440 if(word !=
"TrainingMu:"){
441 errorLog <<
"load(string filename) - Could not find the training mu parameter for the "<<k+1<<
"th model" << std::endl;
444 file >> models[k].trainingMu;
447 if(word !=
"TrainingSigma:"){
448 errorLog <<
"load(string filename) - Could not find the training sigma parameter for the "<<k+1<<
"th model" << std::endl;
451 file >> models[k].trainingSigma;
454 models[k].mu.
resize(numInputDimensions);
455 models[k].sigma.
resize(numInputDimensions);
456 models[k].weights.
resize(numInputDimensions);
461 errorLog <<
"load(string filename) - Could not find the Mu vector for the "<<k+1<<
"th model" << std::endl;
466 for(UINT j=0; j<models[k].N; j++){
469 models[k].mu[j] = value;
473 if(word !=
"Sigma:"){
474 errorLog <<
"load(string filename) - Could not find the Sigma vector for the "<<k+1<<
"th model" << std::endl;
479 for(UINT j=0; j<models[k].N; j++){
482 models[k].sigma[j] = value;
486 if(word !=
"Weights:"){
487 errorLog <<
"load(string filename) - Could not find the Weights vector for the "<<k+1<<
"th model" << std::endl;
492 for(UINT j=0; j<models[k].N; j++){
495 models[k].weights[j] = value;
504 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
506 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
514 return nullRejectionThresholds;
519 if( nullRejectionCoeff > 0 ){
520 this->nullRejectionCoeff = nullRejectionCoeff;
530 weightsDataSet =
true;
531 this->weightsData = weightsData;
542 if(word !=
"NumFeatures:"){
543 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find NumFeatures " << std::endl;
546 file >> numInputDimensions;
549 if(word !=
"NumClasses:"){
550 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find NumClasses" << std::endl;
556 if(word !=
"UseScaling:"){
557 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find UseScaling" << std::endl;
563 if(word !=
"UseNullRejection:"){
564 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find UseNullRejection" << std::endl;
567 file >> useNullRejection;
572 ranges.
resize(numInputDimensions);
575 if(word !=
"Ranges:"){
576 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the Ranges" << std::endl;
579 for(UINT n=0; n<ranges.size(); n++){
580 file >> ranges[n].minValue;
581 file >> ranges[n].maxValue;
586 models.
resize(numClasses);
587 classLabels.
resize(numClasses);
590 for(UINT k=0; k<numClasses; k++){
593 if(word !=
"*************_MODEL_*************"){
594 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find header for the "<<k+1<<
"th model" << std::endl;
599 if(word !=
"Model_ID:"){
600 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find model ID for the "<<k+1<<
"th model" << std::endl;
606 errorLog <<
"ANBC: Model ID does not match the current class ID for the "<<k+1<<
"th model" << std::endl;
612 errorLog <<
"ANBC: Could not find N for the "<<k+1<<
"th model" << std::endl;
618 if(word !=
"ClassLabel:"){
619 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find ClassLabel for the "<<k+1<<
"th model" << std::endl;
622 file >> models[k].classLabel;
623 classLabels[k] = models[k].classLabel;
626 if(word !=
"Threshold:"){
627 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the threshold for the "<<k+1<<
"th model" << std::endl;
630 file >> models[k].threshold;
633 if(word !=
"Gamma:"){
634 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the gamma parameter for the "<<k+1<<
"th model" << std::endl;
637 file >> models[k].gamma;
640 if(word !=
"TrainingMu:"){
641 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the training mu parameter for the "<<k+1<<
"th model" << std::endl;
644 file >> models[k].trainingMu;
647 if(word !=
"TrainingSigma:"){
648 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the training sigma parameter for the "<<k+1<<
"th model" << std::endl;
651 file >> models[k].trainingSigma;
654 models[k].mu.
resize(numInputDimensions);
655 models[k].sigma.
resize(numInputDimensions);
656 models[k].weights.
resize(numInputDimensions);
661 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the Mu vector for the "<<k+1<<
"th model" << std::endl;
666 for(UINT j=0; j<models[k].N; j++){
669 models[k].mu[j] = value;
673 if(word !=
"Sigma:"){
674 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the Sigma vector for the "<<k+1<<
"th model" << std::endl;
679 for(UINT j=0; j<models[k].N; j++){
682 models[k].sigma[j] = value;
686 if(word !=
"Weights:"){
687 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the Weights vector for the "<<k+1<<
"th model" << std::endl;
692 for(UINT j=0; j<models[k].N; j++){
695 models[k].weights[j] = value;
699 if(word !=
"*********************************"){
700 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the model footer for the "<<k+1<<
"th model" << std::endl;
713 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
715 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
bool saveBaseSettingsToFile(std::fstream &file) const
#define DEFAULT_NULL_LIKELIHOOD_VALUE
bool loadLegacyModelFromFile(std::fstream &file)
Float scale(const Float &x, const Float &minSource, const Float &maxSource, const Float &minTarget, const Float &maxTarget, const bool constrain=false)
std::string getClassifierType() const
Vector< ClassTracker > getClassTracker() const
virtual bool load(std::fstream &file)
virtual bool deepCopyFrom(const Classifier *classifier)
ClassificationData getClassData(const UINT classLabel) const
virtual bool resize(const unsigned int size)
bool setWeights(const ClassificationData &weightsData)
virtual bool train(ClassificationData trainingData)
ANBC(bool useScaling=false, bool useNullRejection=false, double nullRejectionCoeff=10.0)
virtual bool recomputeNullRejectionThresholds()
VectorFloat getNullRejectionThresholds() const
This class implements the Adaptive Naive Bayes Classifier algorithm. The Adaptive Naive Bayes Classif...
virtual bool train_(ClassificationData &trainingData)
UINT getNumSamples() const
ANBC & operator=(const ANBC &rhs)
bool copyBaseVariables(const Classifier *classifier)
bool loadBaseSettingsFromFile(std::fstream &file)
virtual bool predict_(VectorFloat &inputVector)
UINT getNumDimensions() const
UINT getNumClasses() const
Vector< MinMax > getRanges() const
bool setNullRejectionCoeff(double nullRejectionCoeff)
bool scale(const Float minTarget, const Float maxTarget)
virtual bool save(std::fstream &file) const