#define GRT_DLL_EXPORTS

const std::string ANBC::id = "ANBC";
// ANBC( bool useScaling, bool useNullRejection, double nullRejectionCoeff ) constructor
this->useScaling = useScaling;
this->useNullRejection = useNullRejection;
this->nullRejectionCoeff = nullRejectionCoeff;
supportsNullRejection = true;
weightsDataSet = false;
classifierMode = STANDARD_CLASSIFIER_MODE;

// ANBC( const ANBC &rhs ) copy constructor
classifierMode = STANDARD_CLASSIFIER_MODE;
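// Usage sketch (illustrative, not part of ANBC.cpp): constructing the classifier with
// scaling and null rejection enabled, assuming the GRT constructor
// ANBC(bool useScaling=false, bool useNullRejection=false, double nullRejectionCoeff=10.0):
//
//   GRT::ANBC anbc( true,    // scale the training/input data
//                   true,    // enable null rejection
//                   5.0 );   // null rejection coefficient (gamma)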
// operator=( const ANBC &rhs ): copy the ANBC-specific members
this->weightsDataSet = rhs.weightsDataSet;
this->weightsData = rhs.weightsData;
this->models = rhs.models;
// deepCopyFrom( const Classifier *classifier )
if( classifier == NULL ) return false;

const ANBC *ptr = dynamic_cast<const ANBC*>(classifier);

this->weightsDataSet = ptr->weightsDataSet;
this->weightsData = ptr->weightsData;
this->models = ptr->models;
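// Usage sketch (illustrative, assumed GRT API): deepCopyFrom() clones a trained ANBC
// through its Classifier base pointer, e.g. when a pipeline owns classifiers polymorphically.
//
//   GRT::ANBC trainedModel;                 // assume this has already been trained
//   GRT::ANBC copy;
//   const GRT::Classifier *base = &trainedModel;
//   if( !copy.deepCopyFrom( base ) ){ /* handle the copy failure */ }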
// train_( ClassificationData &trainingData ): validate the training data before training
errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << std::endl;

errorLog << "train_(ClassificationData &trainingData) - The number of dimensions in the weights data (" << weightsData.getNumDimensions() << ") is not equal to the number of dimensions of the training data (" << N << ")" << std::endl;
numInputDimensions = N;
numOutputDimensions = K;

// Scale the training data to [0 1] when scaling is enabled
trainingData.scale(0, 1);

// Split off a validation set when one has been requested
if( useValidationSet ){
    validationData = trainingData.split( 100-validationSetSize );

trainingLog << "Training Naive Bayes model, num training examples: " << M << ", num validation examples: " << validationData.getNumSamples() << ", num classes: " << numClasses << std::endl;

// Train one Gaussian model per class
for(UINT k=0; k<numClasses; k++){
    classLabels[k] = classLabel;
    // Find the per-dimension weights supplied via setWeights() for this class;
    // if no weights data was set, every dimension defaults to a weight of 1.0
    if( weightsDataSet ){
        bool weightsFound = false;
        if( weightsData[i].getClassLabel() == classLabel ){
            weights = weightsData[i].getSample();

    errorLog << __GRT_LOG__ << " Failed to find the weights for class " << classLabel << std::endl;

    for(UINT j=0; j<numInputDimensions; j++) weights[j] = 1.0;
    // Copy this class's training examples into the data matrix used to fit the model
    for(UINT i=0; i<data.getNumRows(); i++){
        for(UINT j=0; j<data.getNumCols(); j++){
            data[i][j] = classData[i][j];
    // Train the Gaussian model for this class and sanity-check the result
    models[k].gamma = nullRejectionCoeff;
    if( !models[k].train( classLabel, data, weights ) ){
        errorLog << __GRT_LOG__ << " Failed to train model for class: " << classLabel << std::endl;

    if( models[k].N == 0 ){
        errorLog << __GRT_LOG__ << " N == 0!" << std::endl;
    // A zero standard deviation in any dimension would make the Gaussian model degenerate
    for(UINT j=0; j<numInputDimensions; j++){
        if( models[k].sigma[j] == 0 ){
            errorLog << __GRT_LOG__ << " The standard deviation of column " << j+1 << " is zero! Check the training data" << std::endl;
// Store the rejection threshold computed by each class model
nullRejectionThresholds.resize(numClasses);
for(UINT k=0; k<numClasses; k++) {
    nullRejectionThresholds[k] = models[k].threshold;
// After training each class model, compute the training (and optional validation) set accuracy
trainingSetAccuracy = 0;
validationSetAccuracy = 0;

// Remember the scaling state so the accuracy check does not re-scale the already-scaled data
bool scalingState = useScaling;

errorLog << __GRT_LOG__ << " Failed to compute training set accuracy! Failed to fully train model!" << std::endl;

if( useValidationSet ){
    errorLog << __GRT_LOG__ << " Failed to compute validation set accuracy! Failed to fully train model!" << std::endl;

trainingLog << "Training set accuracy: " << trainingSetAccuracy << std::endl;

if( useValidationSet ){
    trainingLog << "Validation set accuracy: " << validationSetAccuracy << std::endl;

// Restore the original scaling state
useScaling = scalingState;
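// Usage sketch for train_() (illustrative, assumed GRT API; the dimensions, labels,
// and values below are made up): build a labeled ClassificationData set and train.
//
//   GRT::ClassificationData trainingData;
//   trainingData.setNumDimensions( 3 );
//   GRT::VectorFloat sample( 3 );
//   sample[0] = 0.1; sample[1] = 0.5; sample[2] = 0.9;
//   trainingData.addSample( 1, sample );          // class label 1
//   // ... add more samples for every class ...
//   GRT::ANBC anbc( true, true, 5.0 );
//   if( !anbc.train( trainingData ) ){ /* training failed, check the error log */ }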
// predict_( VectorFloat &inputVector ): validate the model and the input dimensions
errorLog << "predict_(VectorFloat &inputVector) - ANBC Model Not Trained!" << std::endl;

predictedClassLabel = 0;
maxLikelihood = -10000;

if( !trained ) return false;

if( inputVector.size() != numInputDimensions ){
    errorLog << "predict_(VectorFloat &inputVector) - The size of the input vector (" << inputVector.getSize() << ") does not match the num features in the model (" << numInputDimensions << ")" << std::endl;
// Scale the input vector using the ranges recorded at training time
for(UINT n=0; n<numInputDimensions; n++){
    inputVector[n] = scale(inputVector[n], ranges[n].minValue, ranges[n].maxValue, MIN_SCALE_VALUE, MAX_SCALE_VALUE);
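// For reference, the scale() helper used above applies a linear min-max mapping.
// A sketch of the arithmetic (reconstructed from the signature, not copied from this file):
//
//   scaled = ((x - minSource) / (maxSource - minSource)) * (maxTarget - minTarget) + minTarget;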
if( classLikelihoods.size() != numClasses ) classLikelihoods.resize(numClasses,0);
if( classDistances.size() != numClasses ) classDistances.resize(numClasses,0);

Float classLikelihoodsSum = 0;

// Each class model returns a log-likelihood; non-finite values are zeroed, finite values
// are exponentiated and accumulated so the likelihoods can be normalized below.
// Note that, despite its name, minDist tracks the largest class distance seen so far.
for(UINT k=0; k<numClasses; k++){
    classDistances[k] = models[k].predict( inputVector );
    classLikelihoods[k] = classDistances[k];

    if( grt_isinf(classLikelihoods[k]) || grt_isnan(classLikelihoods[k]) ){
        classLikelihoods[k] = 0;
    classLikelihoods[k] = grt_exp( classLikelihoods[k] );
    classLikelihoodsSum += classLikelihoods[k];

    if( classDistances[k] > minDist || k==0 ){
        minDist = classDistances[k];
        predictedClassLabel = k;
// If every class likelihood underflowed to zero, reject the prediction
if( classLikelihoodsSum == 0 ){
    predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;

// Normalize the class likelihoods, then map the winning index to its class label
for(UINT k=0; k<numClasses; k++){
    classLikelihoods[k] /= classLikelihoodsSum;
maxLikelihood = classLikelihoods[predictedClassLabel];

if( useNullRejection ){
    // Accept the prediction only if the winning distance is at or above the class threshold
    if( minDist >= models[predictedClassLabel].threshold ) predictedClassLabel = models[predictedClassLabel].classLabel;
    else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
}else predictedClassLabel = models[predictedClassLabel].classLabel;
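// Usage sketch for predict_() (illustrative, assumed GRT API; assumes an 'anbc' instance
// trained as in the earlier sketch): pass a feature vector with the same dimensionality
// as the training data and read back the prediction.
//
//   GRT::VectorFloat input( 3 );
//   input[0] = 0.2; input[1] = 0.4; input[2] = 0.8;
//   if( anbc.predict( input ) ){
//       GRT::UINT label = anbc.getPredictedClassLabel();   // GRT_DEFAULT_NULL_CLASS_LABEL if rejected
//       GRT::Float likelihood = anbc.getMaximumLikelihood();
//   }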
// recomputeNullRejectionThresholds(): refresh each class threshold using the current nullRejectionCoeff
if( nullRejectionThresholds.size() != numClasses ) nullRejectionThresholds.resize(numClasses);
for(UINT k=0; k<numClasses; k++) {
    models[k].recomputeThresholdValue(nullRejectionCoeff);
    nullRejectionThresholds[k] = models[k].threshold;
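// Usage sketch (illustrative): the rejection thresholds can be adjusted after training
// by changing the coefficient and recomputing, without retraining the class models.
//
//   anbc.setNullRejectionCoeff( 2.0 );
//   anbc.recomputeNullRejectionThresholds();
//   GRT::VectorFloat thresholds = anbc.getNullRejectionThresholds();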
// save( std::fstream &file ): write the model header, the base settings, and each class model
errorLog << "save(fstream &file) - The file is not open!" << std::endl;

file << "GRT_ANBC_MODEL_FILE_V2.0\n";

errorLog << "save(fstream &file) - Failed to save classifier base settings to file!" << std::endl;

for(UINT k=0; k<numClasses; k++){
    file << "*************_MODEL_*************\n";
    file << "Model_ID: " << k+1 << std::endl;
    file << "N: " << models[k].N << std::endl;
    file << "ClassLabel: " << models[k].classLabel << std::endl;
    file << "Threshold: " << models[k].threshold << std::endl;
    file << "Gamma: " << models[k].gamma << std::endl;
    file << "TrainingMu: " << models[k].trainingMu << std::endl;
    file << "TrainingSigma: " << models[k].trainingSigma << std::endl;

    for(UINT j=0; j<models[k].N; j++){
        file << "\t" << models[k].mu[j];

    for(UINT j=0; j<models[k].N; j++){
        file << "\t" << models[k].sigma[j];

    for(UINT j=0; j<models[k].N; j++){
        file << "\t" << models[k].weights[j];
// load( std::fstream &file ): clear the existing model, check the file header, then read each class model
numInputDimensions = 0;

errorLog << "load(string filename) - Could not open file to load model" << std::endl;

// A legacy (V1.0) header hands parsing over to the legacy loader
if( word == "GRT_ANBC_MODEL_FILE_V1.0" ){

if(word != "GRT_ANBC_MODEL_FILE_V2.0"){
    errorLog << "load(string filename) - Could not find Model File Header" << std::endl;

errorLog << "load(string filename) - Failed to load base settings from file!" << std::endl;

models.resize(numClasses);

for(UINT k=0; k<numClasses; k++){
    if(word != "*************_MODEL_*************"){
        errorLog << "load(string filename) - Could not find header for the "<<k+1<<"th model" << std::endl;

    if(word != "Model_ID:"){
        errorLog << "load(string filename) - Could not find model ID for the "<<k+1<<"th model" << std::endl;

    errorLog << "ANBC: Model ID does not match the current class ID for the "<<k+1<<"th model" << std::endl;

    errorLog << "ANBC: Could not find N for the "<<k+1<<"th model" << std::endl;

    if(word != "ClassLabel:"){
        errorLog << "load(string filename) - Could not find ClassLabel for the "<<k+1<<"th model" << std::endl;
    file >> models[k].classLabel;
    classLabels[k] = models[k].classLabel;

    if(word != "Threshold:"){
        errorLog << "load(string filename) - Could not find the threshold for the "<<k+1<<"th model" << std::endl;
    file >> models[k].threshold;

    if(word != "Gamma:"){
        errorLog << "load(string filename) - Could not find the gamma parameter for the "<<k+1<<"th model" << std::endl;
    file >> models[k].gamma;

    if(word != "TrainingMu:"){
        errorLog << "load(string filename) - Could not find the training mu parameter for the "<<k+1<<"th model" << std::endl;
    file >> models[k].trainingMu;

    if(word != "TrainingSigma:"){
        errorLog << "load(string filename) - Could not find the training sigma parameter for the "<<k+1<<"th model" << std::endl;
    file >> models[k].trainingSigma;

    models[k].mu.resize(numInputDimensions);
    models[k].sigma.resize(numInputDimensions);
    models[k].weights.resize(numInputDimensions);

    errorLog << "load(string filename) - Could not find the Mu vector for the "<<k+1<<"th model" << std::endl;
    for(UINT j=0; j<models[k].N; j++){
        models[k].mu[j] = value;

    if(word != "Sigma:"){
        errorLog << "load(string filename) - Could not find the Sigma vector for the "<<k+1<<"th model" << std::endl;

    for(UINT j=0; j<models[k].N; j++){
        models[k].sigma[j] = value;

    if(word != "Weights:"){
        errorLog << "load(string filename) - Could not find the Weights vector for the "<<k+1<<"th model" << std::endl;

    for(UINT j=0; j<models[k].N; j++){
        models[k].weights[j] = value;
// Reset the prediction buffers so the loaded model is ready for real-time prediction
bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);

// getNullRejectionThresholds()
return nullRejectionThresholds;
// setNullRejectionCoeff( double nullRejectionCoeff ): only positive coefficients are accepted
if( nullRejectionCoeff > 0 ){
    this->nullRejectionCoeff = nullRejectionCoeff;

// setWeights( const ClassificationData &weightsData )
weightsDataSet = true;
this->weightsData = weightsData;
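// Usage sketch for setWeights() (illustrative; the values are made up): the weights data
// holds one sample per class label, and each dimension's value weights that dimension's
// contribution to the class model (a weight of 0 effectively ignores the dimension).
//
//   GRT::ClassificationData weights;
//   weights.setNumDimensions( 3 );
//   GRT::VectorFloat w( 3 );
//   w[0] = 1.0; w[1] = 0.5; w[2] = 0.0;
//   weights.addSample( 1, w );                    // weights for class label 1
//   anbc.setWeights( weights );                   // must be called before train()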
// loadLegacyModelFromFile( std::fstream &file ): parser for the legacy V1.0 model format
if(word != "NumFeatures:"){
    errorLog << "loadANBCModelFromFile(string filename) - Could not find NumFeatures " << std::endl;
file >> numInputDimensions;

if(word != "NumClasses:"){
    errorLog << "loadANBCModelFromFile(string filename) - Could not find NumClasses" << std::endl;

if(word != "UseScaling:"){
    errorLog << "loadANBCModelFromFile(string filename) - Could not find UseScaling" << std::endl;

if(word != "UseNullRejection:"){
    errorLog << "loadANBCModelFromFile(string filename) - Could not find UseNullRejection" << std::endl;
file >> useNullRejection;
ranges.resize(numInputDimensions);

if(word != "Ranges:"){
    errorLog << "loadANBCModelFromFile(string filename) - Could not find the Ranges" << std::endl;

for(UINT n=0; n<ranges.size(); n++){
    file >> ranges[n].minValue;
    file >> ranges[n].maxValue;

models.resize(numClasses);
classLabels.resize(numClasses);
for(UINT k=0; k<numClasses; k++){

    if(word != "*************_MODEL_*************"){
        errorLog << "loadANBCModelFromFile(string filename) - Could not find header for the "<<k+1<<"th model" << std::endl;

    if(word != "Model_ID:"){
        errorLog << "loadANBCModelFromFile(string filename) - Could not find model ID for the "<<k+1<<"th model" << std::endl;

    errorLog << "ANBC: Model ID does not match the current class ID for the "<<k+1<<"th model" << std::endl;

    errorLog << "ANBC: Could not find N for the "<<k+1<<"th model" << std::endl;

    if(word != "ClassLabel:"){
        errorLog << "loadANBCModelFromFile(string filename) - Could not find ClassLabel for the "<<k+1<<"th model" << std::endl;
    file >> models[k].classLabel;
    classLabels[k] = models[k].classLabel;

    if(word != "Threshold:"){
        errorLog << "loadANBCModelFromFile(string filename) - Could not find the threshold for the "<<k+1<<"th model" << std::endl;
    file >> models[k].threshold;

    if(word != "Gamma:"){
        errorLog << "loadANBCModelFromFile(string filename) - Could not find the gamma parameter for the "<<k+1<<"th model" << std::endl;
    file >> models[k].gamma;

    if(word != "TrainingMu:"){
        errorLog << "loadANBCModelFromFile(string filename) - Could not find the training mu parameter for the "<<k+1<<"th model" << std::endl;
    file >> models[k].trainingMu;

    if(word != "TrainingSigma:"){
        errorLog << "loadANBCModelFromFile(string filename) - Could not find the training sigma parameter for the "<<k+1<<"th model" << std::endl;
    file >> models[k].trainingSigma;

    models[k].mu.resize(numInputDimensions);
    models[k].sigma.resize(numInputDimensions);
    models[k].weights.resize(numInputDimensions);

    errorLog << "loadANBCModelFromFile(string filename) - Could not find the Mu vector for the "<<k+1<<"th model" << std::endl;
    for(UINT j=0; j<models[k].N; j++){
        models[k].mu[j] = value;

    if(word != "Sigma:"){
        errorLog << "loadANBCModelFromFile(string filename) - Could not find the Sigma vector for the "<<k+1<<"th model" << std::endl;

    for(UINT j=0; j<models[k].N; j++){
        models[k].sigma[j] = value;

    if(word != "Weights:"){
        errorLog << "loadANBCModelFromFile(string filename) - Could not find the Weights vector for the "<<k+1<<"th model" << std::endl;

    for(UINT j=0; j<models[k].N; j++){
        models[k].weights[j] = value;

    if(word != "*********************************"){
        errorLog << "loadANBCModelFromFile(string filename) - Could not find the model footer for the "<<k+1<<"th model" << std::endl;
// Reset the prediction buffers so the loaded legacy model is ready for real-time prediction
bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);