GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
ANBC.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 */
20 
21 #define GRT_DLL_EXPORTS
22 #include "ANBC.h"
23 
24 GRT_BEGIN_NAMESPACE
25 
//Define the string that will be used to identify the object
const std::string ANBC::id = "ANBC";

//Returns the static id of this class; used to identify and register the module by name
std::string ANBC::getId() { return ANBC::id; }

//Register the ANBC module with the Classifier base class so it can be created dynamically from its string id
RegisterClassifierModule< ANBC > ANBC::registerModule( ANBC::getId() );
32 
33 ANBC::ANBC(bool useScaling,bool useNullRejection,Float nullRejectionCoeff) : Classifier( ANBC::getId() )
34 {
35  this->useScaling = useScaling;
36  this->useNullRejection = useNullRejection;
37  this->nullRejectionCoeff = nullRejectionCoeff;
38  supportsNullRejection = true;
39  weightsDataSet = false;
40  classifierMode = STANDARD_CLASSIFIER_MODE;
41 }
42 
43 ANBC::ANBC(const ANBC &rhs) : Classifier( ANBC::getId() )
44 {
45  classifierMode = STANDARD_CLASSIFIER_MODE;
46  *this = rhs;
47 }
48 
50 {
51 }
52 
53 ANBC& ANBC::operator=(const ANBC &rhs){
54  if( this != &rhs ){
55  //ANBC variables
56  this->weightsDataSet = rhs.weightsDataSet;
57  this->weightsData = rhs.weightsData;
58  this->models = rhs.models;
59 
60  //Classifier variables
61  copyBaseVariables( (Classifier*)&rhs );
62  }
63  return *this;
64 }
65 
66 bool ANBC::deepCopyFrom(const Classifier *classifier){
67 
68  if( classifier == NULL ) return false;
69 
70  if( this->getId() == classifier->getId() ){
71 
72  const ANBC *ptr = dynamic_cast<const ANBC*>(classifier);
73 
74  //Clone the ANBC values
75  this->weightsDataSet = ptr->weightsDataSet;
76  this->weightsData = ptr->weightsData;
77  this->models = ptr->models;
78 
79  //Clone the classifier variables
80  return copyBaseVariables( classifier );
81  }
82  return false;
83 }
84 
//Trains the ANBC model: one naive-Bayes model per class, built from that class's
//training samples and (optionally) a per-dimension weights vector supplied via setWeights(...).
//The input dataset may be scaled in place (if useScaling) and split in place (if useValidationSet).
//Returns true if every class model trained successfully and the training accuracy could be computed.
bool ANBC::train_(ClassificationData &trainingData){

    //Clear any previous model
    clear();

    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();

    if( trainingData.getNumSamples() == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << std::endl;
        return false;
    }

    //If weights were supplied, they must have the same dimensionality as the training data
    if( weightsDataSet ){
        if( weightsData.getNumDimensions() != N ){
            errorLog << "train_(ClassificationData &trainingData) - The number of dimensions in the weights data (" << weightsData.getNumDimensions() << ") is not equal to the number of dimensions of the training data (" << N << ")" << std::endl;
            return false;
        }
    }

    numInputDimensions = N;
    numOutputDimensions = K;
    numClasses = K;
    models.resize(K);
    classLabels.resize(K);
    ranges = trainingData.getRanges(); //Capture the ranges BEFORE scaling, so prediction can rescale inputs
    ClassificationData validationData;

    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }

    //split(...) removes the validation samples from trainingData in place
    if( useValidationSet ){
        validationData = trainingData.split( 100-validationSetSize );
    }

    const UINT M = trainingData.getNumSamples();
    trainingLog << "Training Naive Bayes model, num training examples: " << M << ", num validation examples: " << validationData.getNumSamples() << ", num classes: " << numClasses << std::endl;

    //Train each of the models
    for(UINT k=0; k<numClasses; k++){

        //Get the class label for the kth class
        UINT classLabel = trainingData.getClassTracker()[k].classLabel;

        //Set the kth class label
        classLabels[k] = classLabel;

        //Get the weights for this class
        VectorFloat weights(numInputDimensions);
        if( weightsDataSet ){
            //Search the weights dataset for the sample whose label matches this class
            bool weightsFound = false;
            for(UINT i=0; i<weightsData.getNumSamples(); i++){
                if( weightsData[i].getClassLabel() == classLabel ){
                    weights = weightsData[i].getSample();
                    weightsFound = true;
                    break;
                }
            }

            if( !weightsFound ){
                errorLog << __GRT_LOG__ << " Failed to find the weights for class " << classLabel << std::endl;
                return false;
            }
        }else{
            //If the weights data has not been set then all the weights are 1
            for(UINT j=0; j<numInputDimensions; j++) weights[j] = 1.0;
        }

        //Get all the training data for this class
        ClassificationData classData = trainingData.getClassData(classLabel);
        MatrixFloat data(classData.getNumSamples(),N);

        //Copy the training data into a matrix
        for(UINT i=0; i<data.getNumRows(); i++){
            for(UINT j=0; j<data.getNumCols(); j++){
                data[i][j] = classData[i][j];
            }
        }

        //Train the model for this class
        models[k].gamma = nullRejectionCoeff;
        if( !models[k].train( classLabel, data, weights ) ){
            errorLog << __GRT_LOG__ << " Failed to train model for class: " << classLabel << std::endl;

            //Try and work out why the training failed
            if( models[k].N == 0 ){
                errorLog << __GRT_LOG__ << " N == 0!" << std::endl;
                models.clear();
                return false;
            }
            //A zero standard deviation in any column makes the Gaussian model degenerate
            for(UINT j=0; j<numInputDimensions; j++){
                if( models[k].sigma[j] == 0 ){
                    errorLog << __GRT_LOG__ << " The standard deviation of column " << j+1 << " is zero! Check the training data" << std::endl;
                    models.clear();
                    return false;
                }
            }
            models.clear();
            return false;
        }

    }

    //Store the null rejection thresholds
    nullRejectionThresholds.resize(numClasses);
    for(UINT k=0; k<numClasses; k++) {
        nullRejectionThresholds[k] = models[k].threshold;
    }

    //Flag that the model has been trained
    trained = true;
    converged = true;

    //Compute the final training stats
    trainingSetAccuracy = 0;
    validationSetAccuracy = 0;

    //If scaling was on, then the data will already be scaled, so turn it off temporially so we can test the model accuracy
    bool scalingState = useScaling;
    useScaling = false;
    if( !computeAccuracy( trainingData, trainingSetAccuracy ) ){
        trained = false;
        errorLog << __GRT_LOG__ << " Failed to compute training set accuracy! Failed to fully train model!" << std::endl;
        return false;
    }

    if( useValidationSet ){
        if( !computeAccuracy( validationData, validationSetAccuracy ) ){
            trained = false;
            errorLog << __GRT_LOG__ << " Failed to compute validation set accuracy! Failed to fully train model!" << std::endl;
            return false;
        }

    }

    trainingLog << "Training set accuracy: " << trainingSetAccuracy << std::endl;

    if( useValidationSet ){
        trainingLog << "Validation set accuracy: " << validationSetAccuracy << std::endl;
    }

    //Reset the scaling state for future prediction
    useScaling = scalingState;

    return trained;
}
234 
235 bool ANBC::predict_(VectorFloat &inputVector){
236 
237  if( !trained ){
238  errorLog << "predict_(VectorFloat &inputVector) - ANBC Model Not Trained!" << std::endl;
239  return false;
240  }
241 
242  predictedClassLabel = 0;
243  maxLikelihood = -10000;
244 
245  if( !trained ) return false;
246 
247  if( inputVector.size() != numInputDimensions ){
248  errorLog << "predict_(VectorFloat &inputVector) - The size of the input vector (" << inputVector.getSize() << ") does not match the num features in the model (" << numInputDimensions << std::endl;
249  return false;
250  }
251 
252  if( useScaling ){
253  for(UINT n=0; n<numInputDimensions; n++){
254  inputVector[n] = scale(inputVector[n], ranges[n].minValue, ranges[n].maxValue, MIN_SCALE_VALUE, MAX_SCALE_VALUE);
255  }
256  }
257 
258  if( classLikelihoods.size() != numClasses ) classLikelihoods.resize(numClasses,0);
259  if( classDistances.size() != numClasses ) classDistances.resize(numClasses,0);
260 
261  Float classLikelihoodsSum = 0;
262  Float minDist = 0;
263  for(UINT k=0; k<numClasses; k++){
264  classDistances[k] = models[k].predict( inputVector );
265 
266  //At this point the class likelihoods and class distances are the same thing
267  classLikelihoods[k] = classDistances[k];
268 
269  //If the distances are very far away then they could be -inf or nan so catch this so the sum still works
270  if( grt_isinf(classLikelihoods[k]) || grt_isnan(classLikelihoods[k]) ){
271  classLikelihoods[k] = 0;
272  }else{
273  classLikelihoods[k] = grt_exp( classLikelihoods[k] );
274  classLikelihoodsSum += classLikelihoods[k];
275 
276  //The loglikelihood values are negative so we want the values closest to 0
277  if( classDistances[k] > minDist || k==0 ){
278  minDist = classDistances[k];
279  predictedClassLabel = k;
280  }
281  }
282  }
283 
284  //If the class likelihoods sum is zero then all classes are -INF
285  if( classLikelihoodsSum == 0 ){
286  predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
287  maxLikelihood = 0;
288  return true;
289  }
290 
291  //Normalize the classlikelihoods
292  for(UINT k=0; k<numClasses; k++){
293  classLikelihoods[k] /= classLikelihoodsSum;
294  }
295  maxLikelihood = classLikelihoods[predictedClassLabel];
296 
297  if( useNullRejection ){
298  //Check to see if the best result is greater than the models threshold
299  if( minDist >= models[predictedClassLabel].threshold ) predictedClassLabel = models[predictedClassLabel].classLabel;
300  else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
301  }else predictedClassLabel = models[predictedClassLabel].classLabel;
302 
303  return true;
304 }
305 
307 
308  if( trained ){
309  if( nullRejectionThresholds.size() != numClasses )
310  nullRejectionThresholds.resize(numClasses);
311  for(UINT k=0; k<numClasses; k++) {
312  models[k].recomputeThresholdValue(nullRejectionCoeff);
313  nullRejectionThresholds[k] = models[k].threshold;
314  }
315  return true;
316  }
317  return false;
318 }
319 
//Resets the classifier's real-time prediction state.
//ANBC holds no per-prediction state of its own, so there is nothing to do here.
bool ANBC::reset(){
    return true;
}
323 
324 bool ANBC::clear(){
325 
326  //Clear the Classifier variables
328 
329  //Clear the ANBC model
330  weightsData.clear();
331  models.clear();
332 
333  return true;
334 }
335 
336 bool ANBC::save( std::fstream &file ) const{
337 
338  if(!file.is_open())
339  {
340  errorLog <<"save(fstream &file) - The file is not open!" << std::endl;
341  return false;
342  }
343 
344  //Write the header info
345  file<<"GRT_ANBC_MODEL_FILE_V2.0\n";
346 
347  //Write the classifier settings to the file
349  errorLog <<"save(fstream &file) - Failed to save classifier base settings to file!" << std::endl;
350  return false;
351  }
352 
353  if( trained ){
354  //Write each of the models
355  for(UINT k=0; k<numClasses; k++){
356  file << "*************_MODEL_*************\n";
357  file << "Model_ID: " << k+1 << std::endl;
358  file << "N: " << models[k].N << std::endl;
359  file << "ClassLabel: " << models[k].classLabel << std::endl;
360  file << "Threshold: " << models[k].threshold << std::endl;
361  file << "Gamma: " << models[k].gamma << std::endl;
362  file << "TrainingMu: " << models[k].trainingMu << std::endl;
363  file << "TrainingSigma: " << models[k].trainingSigma << std::endl;
364 
365  file<<"Mu:";
366  for(UINT j=0; j<models[k].N; j++){
367  file << "\t" << models[k].mu[j];
368  }file << std::endl;
369 
370  file<<"Sigma:";
371  for(UINT j=0; j<models[k].N; j++){
372  file << "\t" << models[k].sigma[j];
373  }file << std::endl;
374 
375  file<<"Weights:";
376  for(UINT j=0; j<models[k].N; j++){
377  file << "\t" << models[k].weights[j];
378  }file << std::endl;
379  }
380  }
381 
382  return true;
383 }
384 
385 bool ANBC::load( std::fstream &file ){
386 
387  trained = false;
388  numInputDimensions = 0;
389  numClasses = 0;
390  models.clear();
391  classLabels.clear();
392 
393  if(!file.is_open())
394  {
395  errorLog << "load(string filename) - Could not open file to load model" << std::endl;
396  return false;
397  }
398 
399  std::string word;
400  file >> word;
401 
402  //Check to see if we should load a legacy file
403  if( word == "GRT_ANBC_MODEL_FILE_V1.0" ){
404  return loadLegacyModelFromFile( file );
405  }
406 
407  //Find the file type header
408  if(word != "GRT_ANBC_MODEL_FILE_V2.0"){
409  errorLog << "load(string filename) - Could not find Model File Header" << std::endl;
410  return false;
411  }
412 
413  //Load the base settings from the file
415  errorLog << "load(string filename) - Failed to load base settings from file!" << std::endl;
416  return false;
417  }
418 
419  if( trained ){
420 
421  //Resize the buffer
422  models.resize(numClasses);
423 
424  //Load each of the K models
425  for(UINT k=0; k<numClasses; k++){
426  UINT modelID;
427  file >> word;
428  if(word != "*************_MODEL_*************"){
429  errorLog << "load(string filename) - Could not find header for the "<<k+1<<"th model" << std::endl;
430  return false;
431  }
432 
433  file >> word;
434  if(word != "Model_ID:"){
435  errorLog << "load(string filename) - Could not find model ID for the "<<k+1<<"th model" << std::endl;
436  return false;
437  }
438  file >> modelID;
439 
440  if(modelID-1!=k){
441  errorLog << "ANBC: Model ID does not match the current class ID for the "<<k+1<<"th model" << std::endl;
442  return false;
443  }
444 
445  file >> word;
446  if(word != "N:"){
447  errorLog << "ANBC: Could not find N for the "<<k+1<<"th model" << std::endl;
448  return false;
449  }
450  file >> models[k].N;
451 
452  file >> word;
453  if(word != "ClassLabel:"){
454  errorLog << "load(string filename) - Could not find ClassLabel for the "<<k+1<<"th model" << std::endl;
455  return false;
456  }
457  file >> models[k].classLabel;
458  classLabels[k] = models[k].classLabel;
459 
460  file >> word;
461  if(word != "Threshold:"){
462  errorLog << "load(string filename) - Could not find the threshold for the "<<k+1<<"th model" << std::endl;
463  return false;
464  }
465  file >> models[k].threshold;
466 
467  file >> word;
468  if(word != "Gamma:"){
469  errorLog << "load(string filename) - Could not find the gamma parameter for the "<<k+1<<"th model" << std::endl;
470  return false;
471  }
472  file >> models[k].gamma;
473 
474  file >> word;
475  if(word != "TrainingMu:"){
476  errorLog << "load(string filename) - Could not find the training mu parameter for the "<<k+1<<"th model" << std::endl;
477  return false;
478  }
479  file >> models[k].trainingMu;
480 
481  file >> word;
482  if(word != "TrainingSigma:"){
483  errorLog << "load(string filename) - Could not find the training sigma parameter for the "<<k+1<<"th model" << std::endl;
484  return false;
485  }
486  file >> models[k].trainingSigma;
487 
488  //Resize the buffers
489  models[k].mu.resize(numInputDimensions);
490  models[k].sigma.resize(numInputDimensions);
491  models[k].weights.resize(numInputDimensions);
492 
493  //Load Mu, Sigma and Weights
494  file >> word;
495  if(word != "Mu:"){
496  errorLog << "load(string filename) - Could not find the Mu vector for the "<<k+1<<"th model" << std::endl;
497  return false;
498  }
499 
500  //Load Mu
501  for(UINT j=0; j<models[k].N; j++){
502  Float value;
503  file >> value;
504  models[k].mu[j] = value;
505  }
506 
507  file >> word;
508  if(word != "Sigma:"){
509  errorLog << "load(string filename) - Could not find the Sigma vector for the "<<k+1<<"th model" << std::endl;
510  return false;
511  }
512 
513  //Load Sigma
514  for(UINT j=0; j<models[k].N; j++){
515  Float value;
516  file >> value;
517  models[k].sigma[j] = value;
518  }
519 
520  file >> word;
521  if(word != "Weights:"){
522  errorLog << "load(string filename) - Could not find the Weights vector for the "<<k+1<<"th model" << std::endl;
523  return false;
524  }
525 
526  //Load Weights
527  for(UINT j=0; j<models[k].N; j++){
528  Float value;
529  file >> value;
530  models[k].weights[j] = value;
531  }
532  }
533 
534  //Recompute the null rejection thresholds
536 
537  //Resize the prediction results to make sure it is setup for realtime prediction
538  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
539  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
540  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
541  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
542  }
543 
544  return true;
545 }
546 
548  if( !trained ) return VectorFloat();
549  return nullRejectionThresholds;
550 }
551 
552 bool ANBC::setNullRejectionCoeff(Float nullRejectionCoeff){
553 
554  if( nullRejectionCoeff > 0 ){
555  this->nullRejectionCoeff = nullRejectionCoeff;
557  return true;
558  }
559  return false;
560 }
561 
562 bool ANBC::setWeights(const ClassificationData &weightsData){
563 
564  if( weightsData.getNumSamples() > 0 ){
565  weightsDataSet = true;
566  this->weightsData = weightsData;
567  return true;
568  }
569  return false;
570 }
571 
572 bool ANBC::loadLegacyModelFromFile( std::fstream &file ){
573 
574  std::string word;
575 
576  file >> word;
577  if(word != "NumFeatures:"){
578  errorLog << "loadANBCModelFromFile(string filename) - Could not find NumFeatures " << std::endl;
579  return false;
580  }
581  file >> numInputDimensions;
582 
583  file >> word;
584  if(word != "NumClasses:"){
585  errorLog << "loadANBCModelFromFile(string filename) - Could not find NumClasses" << std::endl;
586  return false;
587  }
588  file >> numClasses;
589 
590  file >> word;
591  if(word != "UseScaling:"){
592  errorLog << "loadANBCModelFromFile(string filename) - Could not find UseScaling" << std::endl;
593  return false;
594  }
595  file >> useScaling;
596 
597  file >> word;
598  if(word != "UseNullRejection:"){
599  errorLog << "loadANBCModelFromFile(string filename) - Could not find UseNullRejection" << std::endl;
600  return false;
601  }
602  file >> useNullRejection;
603 
605  if( useScaling ){
606  //Resize the ranges buffer
607  ranges.resize(numInputDimensions);
608 
609  file >> word;
610  if(word != "Ranges:"){
611  errorLog << "loadANBCModelFromFile(string filename) - Could not find the Ranges" << std::endl;
612  return false;
613  }
614  for(UINT n=0; n<ranges.size(); n++){
615  file >> ranges[n].minValue;
616  file >> ranges[n].maxValue;
617  }
618  }
619 
620  //Resize the buffer
621  models.resize(numClasses);
622  classLabels.resize(numClasses);
623 
624  //Load each of the K models
625  for(UINT k=0; k<numClasses; k++){
626  UINT modelID;
627  file >> word;
628  if(word != "*************_MODEL_*************"){
629  errorLog << "loadANBCModelFromFile(string filename) - Could not find header for the "<<k+1<<"th model" << std::endl;
630  return false;
631  }
632 
633  file >> word;
634  if(word != "Model_ID:"){
635  errorLog << "loadANBCModelFromFile(string filename) - Could not find model ID for the "<<k+1<<"th model" << std::endl;
636  return false;
637  }
638  file >> modelID;
639 
640  if(modelID-1!=k){
641  errorLog << "ANBC: Model ID does not match the current class ID for the "<<k+1<<"th model" << std::endl;
642  return false;
643  }
644 
645  file >> word;
646  if(word != "N:"){
647  errorLog << "ANBC: Could not find N for the "<<k+1<<"th model" << std::endl;
648  return false;
649  }
650  file >> models[k].N;
651 
652  file >> word;
653  if(word != "ClassLabel:"){
654  errorLog << "loadANBCModelFromFile(string filename) - Could not find ClassLabel for the "<<k+1<<"th model" << std::endl;
655  return false;
656  }
657  file >> models[k].classLabel;
658  classLabels[k] = models[k].classLabel;
659 
660  file >> word;
661  if(word != "Threshold:"){
662  errorLog << "loadANBCModelFromFile(string filename) - Could not find the threshold for the "<<k+1<<"th model" << std::endl;
663  return false;
664  }
665  file >> models[k].threshold;
666 
667  file >> word;
668  if(word != "Gamma:"){
669  errorLog << "loadANBCModelFromFile(string filename) - Could not find the gamma parameter for the "<<k+1<<"th model" << std::endl;
670  return false;
671  }
672  file >> models[k].gamma;
673 
674  file >> word;
675  if(word != "TrainingMu:"){
676  errorLog << "loadANBCModelFromFile(string filename) - Could not find the training mu parameter for the "<<k+1<<"th model" << std::endl;
677  return false;
678  }
679  file >> models[k].trainingMu;
680 
681  file >> word;
682  if(word != "TrainingSigma:"){
683  errorLog << "loadANBCModelFromFile(string filename) - Could not find the training sigma parameter for the "<<k+1<<"th model" << std::endl;
684  return false;
685  }
686  file >> models[k].trainingSigma;
687 
688  //Resize the buffers
689  models[k].mu.resize(numInputDimensions);
690  models[k].sigma.resize(numInputDimensions);
691  models[k].weights.resize(numInputDimensions);
692 
693  //Load Mu, Sigma and Weights
694  file >> word;
695  if(word != "Mu:"){
696  errorLog << "loadANBCModelFromFile(string filename) - Could not find the Mu vector for the "<<k+1<<"th model" << std::endl;
697  return false;
698  }
699 
700  //Load Mu
701  for(UINT j=0; j<models[k].N; j++){
702  Float value;
703  file >> value;
704  models[k].mu[j] = value;
705  }
706 
707  file >> word;
708  if(word != "Sigma:"){
709  errorLog << "loadANBCModelFromFile(string filename) - Could not find the Sigma vector for the "<<k+1<<"th model" << std::endl;
710  return false;
711  }
712 
713  //Load Sigma
714  for(UINT j=0; j<models[k].N; j++){
715  Float value;
716  file >> value;
717  models[k].sigma[j] = value;
718  }
719 
720  file >> word;
721  if(word != "Weights:"){
722  errorLog << "loadANBCModelFromFile(string filename) - Could not find the Weights vector for the "<<k+1<<"th model" << std::endl;
723  return false;
724  }
725 
726  //Load Weights
727  for(UINT j=0; j<models[k].N; j++){
728  Float value;
729  file >> value;
730  models[k].weights[j] = value;
731  }
732 
733  file >> word;
734  if(word != "*********************************"){
735  errorLog << "loadANBCModelFromFile(string filename) - Could not find the model footer for the "<<k+1<<"th model" << std::endl;
736  return false;
737  }
738  }
739 
740  //Flag that the model is trained
741  trained = true;
742 
743  //Recompute the null rejection thresholds
745 
746  //Resize the prediction results to make sure it is setup for realtime prediction
747  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
748  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
749  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
750  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
751 
752  return true;
753 
754 }
755 
756 GRT_END_NAMESPACE
bool saveBaseSettingsToFile(std::fstream &file) const
Definition: Classifier.cpp:274
std::string getId() const
Definition: GRTBase.cpp:85
virtual bool reset()
Definition: ANBC.cpp:320
#define DEFAULT_NULL_LIKELIHOOD_VALUE
Definition: Classifier.h:33
bool loadLegacyModelFromFile(std::fstream &file)
Definition: ANBC.cpp:572
virtual ~ANBC(void)
Definition: ANBC.cpp:49
Classifier(const std::string &classifierId="")
Definition: Classifier.cpp:77
Vector< ClassTracker > getClassTracker() const
virtual bool load(std::fstream &file)
Definition: ANBC.cpp:385
virtual bool deepCopyFrom(const Classifier *classifier)
Definition: ANBC.cpp:66
ClassificationData getClassData(const UINT classLabel) const
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
bool setWeights(const ClassificationData &weightsData)
Definition: ANBC.cpp:562
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:107
#define MIN_SCALE_VALUE
Definition: ANBC.h:34
ANBC(bool useScaling=false, bool useNullRejection=false, double nullRejectionCoeff=10.0)
UINT getSize() const
Definition: Vector.h:201
virtual bool recomputeNullRejectionThresholds()
Definition: ANBC.cpp:306
VectorFloat getNullRejectionThresholds() const
Definition: ANBC.cpp:547
virtual bool computeAccuracy(const ClassificationData &data, Float &accuracy)
Definition: Classifier.cpp:171
virtual bool train_(ClassificationData &trainingData)
Definition: ANBC.cpp:85
UINT getNumSamples() const
ANBC & operator=(const ANBC &rhs)
Definition: ANBC.cpp:53
virtual bool clear()
Definition: ANBC.cpp:324
Definition: ANBC.h:50
bool copyBaseVariables(const Classifier *classifier)
Definition: Classifier.cpp:101
bool loadBaseSettingsFromFile(std::fstream &file)
Definition: Classifier.cpp:321
virtual bool predict_(VectorFloat &inputVector)
Definition: ANBC.cpp:235
UINT getNumDimensions() const
UINT getNumClasses() const
static std::string getId()
Definition: ANBC.cpp:28
Vector< MinMax > getRanges() const
ClassificationData split(const UINT splitPercentage, const bool useStratifiedSampling=false)
bool setNullRejectionCoeff(double nullRejectionCoeff)
Definition: ANBC.cpp:552
bool scale(const Float minTarget, const Float maxTarget)
virtual bool clear()
Definition: Classifier.cpp:151
This is the main base class that all GRT Classification algorithms should inherit from...
Definition: Classifier.h:41
virtual bool save(std::fstream &file) const
Definition: ANBC.cpp:336
Float scale(const Float &x, const Float &minSource, const Float &maxSource, const Float &minTarget, const Float &maxTarget, const bool constrain=false)
Definition: GRTBase.h:184