#define GRT_DLL_EXPORTS
// ...
const std::string KNN::id = "KNN";
this->useScaling = useScaling;
this->useNullRejection = useNullRejection;
this->nullRejectionCoeff = nullRejectionCoeff;
// ...
supportsNullRejection = true;
classifierMode = STANDARD_CLASSIFIER_MODE;
// ...
classifierMode = STANDARD_CLASSIFIER_MODE;
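The fragments above are from the KNN constructors: the first block copies the user-supplied parameters into the classifier, and the second classifierMode assignment is the same initialisation repeated in the copy constructor. A minimal usage sketch of those parameters through the public API (hedged: it assumes the standard GRT headers, and "my_dataset.grt" is a hypothetical placeholder file):

#include <GRT/GRT.h>
#include <cstdlib>
using namespace GRT;

int main(){
    ClassificationData data;
    if( !data.load( "my_dataset.grt" ) ) return EXIT_FAILURE; // hypothetical dataset file

    KNN knn( 5 );                     // K = 5 neighbours
    knn.enableScaling( true );        // drives the useScaling flag copied above
    knn.enableNullRejection( true );  // drives the useNullRejection flag copied above

    if( !knn.train( data ) ) return EXIT_FAILURE;

    VectorFloat sample = data[0].getSample();
    if( knn.predict( sample ) ){
        const UINT predictedLabel = knn.getPredictedClassLabel();
        const Float likelihood = knn.getMaximumLikelihood();
    }
    return EXIT_SUCCESS;
}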
if( classifier == NULL ) return false;
// ...
const KNN *ptr = dynamic_cast<const KNN*>(classifier);
errorLog << __GRT_LOG__ << " Training data has zero samples!" << std::endl;
// ...
trainingData.scale(0, 1);
// ...
if( useValidationSet ){
    validationData = trainingData.split( 100-validationSetSize );
}
// ...
classLabels.resize( numClasses );
for(UINT k=0; k<numClasses; k++){
    // ...
}
// ...
if( !train_(trainingData,K) ){
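train_ begins by validating the data, scaling it into [0 1] when scaling is enabled, optionally splitting off a validation set, and caching the class labels. The validation split above is driven from the public MLBase setters; a short sketch (hedged: setter names as in recent GRT versions):

KNN knn( 5 );
knn.setUseValidationSet( true ); // makes train_ call trainingData.split( 100-validationSetSize )
knn.setValidationSetSize( 20 );  // hold out 20% of the samples for validation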
Float bestAccuracy = 0;
// ...
if( !train_(trainingData, k) ){
    errorLog << __GRT_LOG__ << " Failed to train model for a k value of " << k << std::endl;
}
// ...
if( !predict( sample, k ) ){
    errorLog << __GRT_LOG__ << " Failed to predict label for test sample with a k value of " << k << std::endl;
}
// ...
if( testSet[i].getClassLabel() == predictedClassLabel ){
    // ...
}
// ...
accuracy = accuracy / Float( testSet.getNumSamples() ) * 100.0;
// ...
trainingLog << "K:\t" << k << "\tAccuracy:\t" << accuracy << std::endl;
// ...
if( accuracy > bestAccuracy ){
    bestAccuracy = accuracy;
    // ...
}
// ...
if( bestAccuracy > 0 ){
    // ...
    std::sort(trainingAccuracyLog.begin(),trainingAccuracyLog.end(),IndexedDouble::sortIndexedDoubleByValueDescending);
    // ...
    tempLog.push_back( trainingAccuracyLog[0] );
    // ...
    for(UINT i=1; i<trainingAccuracyLog.size(); i++){
        if( trainingAccuracyLog[i].value == tempLog[0].value ){
            tempLog.push_back( trainingAccuracyLog[i] );
        }
    }
    // ...
    std::sort(tempLog.begin(),tempLog.end(),IndexedDouble::sortIndexedDoubleByIndexAscending);
    // ...
    trainingLog << "Best K Value: " << tempLog[0].index << "\tAccuracy:\t" << tempLog[0].value << std::endl;
    // ...
    if( !train_(trainingData,tempLog[0].index) ){
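The block above implements the best-K search: a model is trained and tested for each candidate k, every accuracy is pushed into trainingAccuracyLog, the log is sorted by accuracy descending, all k values tied for the best accuracy are gathered into tempLog, and sorting the ties by index ascending makes the final train_ call use the smallest of the tied k values. Enabling the search through the setters declared in KNN.h:

KNN knn;
knn.enableBestKValueSearch( true ); // search for K instead of using a fixed value
knn.setMinKSearchValue( 1 );        // lower bound of the search range
knn.setMaxKSearchValue( 20 );       // upper bound of the search range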
trainingSetAccuracy = 0;
validationSetAccuracy = 0;
// ...
bool scalingState = useScaling;
// ...
if( !predict_( trainingData[i].getSample() ) ){
    errorLog << __GRT_LOG__ << " Failed to run prediction for training sample: " << i << "! Failed to fully train model!" << std::endl;
    // ...
}
if( predictedClassLabel == trainingData[i].getClassLabel() ){
    trainingSetAccuracy++;
}
// ...
if( useValidationSet ){
    // ...
    if( !predict_( validationData[i].getSample() ) ){
        errorLog << __GRT_LOG__ << " Failed to run prediction for validation sample: " << i << "! Failed to fully train model!" << std::endl;
        // ...
    }
    if( predictedClassLabel == validationData[i].getClassLabel() ){
        validationSetAccuracy++;
    }
}
// ...
trainingSetAccuracy = trainingSetAccuracy / trainingData.getNumSamples() * 100.0;
// ...
trainingLog << "Training set accuracy: " << trainingSetAccuracy << std::endl;
// ...
if( useValidationSet ){
    validationSetAccuracy = validationSetAccuracy / validationData.getNumSamples() * 100.0;
    trainingLog << "Validation set accuracy: " << validationSetAccuracy << std::endl;
}
// ...
useScaling = scalingState;
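// Why the scalingState bookkeeping above: train_ scaled trainingData in place near its
// start, so per-sample scaling inside predict_ has to stay off while the training and
// validation accuracies are computed; the user's original useScaling setting is
// restored once the loops finish.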
if( useNullRejection ){
    // ...
}
// ...
useNullRejection = false;
nullRejectionThresholds.clear();
// ...
nullRejectionThresholds.resize( numClasses, 0 );
// ...
const unsigned int numTrainingExamples = trainingData.getNumSamples();
// ...
for(UINT i=0; i<numTrainingExamples; i++){
    predict( trainingData[i].getSample(), K );
    // ...
    UINT classLabelIndex = 0;
    for(UINT k=0; k<numClasses; k++){
        if( predictedClassLabel == classLabels[k] ){
            // ...
        }
    }
    // ...
    predictionResults[ i ].index = classLabelIndex;
    predictionResults[ i ].value = classDistances[ classLabelIndex ];
    // ...
    trainingMu[ classLabelIndex ] += predictionResults[ i ].value;
    counter[ classLabelIndex ]++;
}
// ...
for(UINT j=0; j<numClasses; j++){
    // ...
}
// ...
for(UINT i=0; i<numTrainingExamples; i++){
    trainingSigma[ predictionResults[i].index ] += SQR( predictionResults[i].value - trainingMu[ predictionResults[i].index ] );
}
// ...
for(UINT j=0; j<numClasses; j++){
    Float count = counter[j];
bool errorFound = false;
for(UINT j=0; j<numClasses; j++){
    // ...
    warningLog << __GRT_LOG__ << " TrainingMu[ " << j << " ] is zero for a K value of " << K << std::endl;
    // ...
    warningLog << __GRT_LOG__ << " TrainingSigma[ " << j << " ] is zero for a K value of " << K << std::endl;
    // ...
    errorLog << __GRT_LOG__ << " TrainingMu[ " << j << " ] is NAN for a K value of " << K << std::endl;
    // ...
    errorLog << __GRT_LOG__ << " TrainingSigma[ " << j << " ] is NAN for a K value of " << K << std::endl;
    // ...
}
// ...
for(unsigned int j=0; j<numClasses; j++){
    // ...
}
// ...
useNullRejection = true;
// ...
nullRejectionThresholds.clear();
nullRejectionThresholds.resize( numClasses, 0 );
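The per-class statistics computed above (trainingMu from the mean of each sample's best-class distance, trainingSigma from the squared deviations) are turned into rejection thresholds in the elided loop over classes. A sketch of one standard formulation, consistent with the statistics computed above (hedged: the exact elided lines are not shown in this listing):

for(unsigned int j=0; j<numClasses; j++){
    nullRejectionThresholds[j] = trainingMu[j] + ( trainingSigma[j] * nullRejectionCoeff );
}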
errorLog << __GRT_LOG__ << " KNN model has not been trained" << std::endl;
// ...
if( inputVector.getSize() != numInputDimensions ){
    errorLog << __GRT_LOG__ << " The size of the input vector " << inputVector.getSize() << " does not match the number of features " << numInputDimensions << std::endl;
    // ...
}
// ...
for(UINT i=0; i<numInputDimensions; i++){
    inputVector[i] = scale(inputVector[i], ranges[i].minValue, ranges[i].maxValue, 0, 1);
}
// ...
return predict(inputVector,K);
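predict_ rescales each input dimension with the ranges recorded at training time before delegating to predict(inputVector,K). The scale helper is a plain min-max map; a sketch consistent with the signature listed at the end of this page (hedged: the Float alias and the zero-range guard are hypothetical additions, and GRT's own implementation may differ in detail):

using Float = double; // GRT's Float is float or double depending on the build

Float scale( const Float &x, const Float &minSource, const Float &maxSource,
             const Float &minTarget, const Float &maxTarget, const bool constrain = false ){
    if( constrain ){
        if( x <= minSource ) return minTarget;
        if( x >= maxSource ) return maxTarget;
    }
    if( minSource == maxSource ) return minTarget; // hypothetical guard against division by zero
    return ( ( x - minSource ) * ( maxTarget - minTarget ) ) / ( maxSource - minSource ) + minTarget;
}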
bool KNN::predict(const VectorFloat &inputVector,const UINT K){
    // ...
    errorLog << __GRT_LOG__ << " KNN model has not been trained" << std::endl;
    // ...
    if( inputVector.getSize() != numInputDimensions ){
        errorLog << __GRT_LOG__ << " The size of the input vector " << inputVector.getSize() << " does not match the number of features " << numInputDimensions << std::endl;
        // ...
    }
    // ...
    errorLog << __GRT_LOG__ << " K is greater than the number of training samples!" << std::endl;
    for(UINT i=0; i<M; i++){
        // ...
        UINT classLabel = trainingData[i].getClassLabel();
        VectorFloat trainingSample = trainingData[i].getSample();

        switch( distanceMethod ){
            case EUCLIDEAN_DISTANCE:
                dist = computeEuclideanDistance(inputVector,trainingSample);
                break;
            case COSINE_DISTANCE:
                dist = computeCosineDistance(inputVector,trainingSample);
                break;
            case MANHATTAN_DISTANCE:
                dist = computeManhattanDistance(inputVector, trainingSample);
                break;
            default:
                errorLog << __GRT_LOG__ << " unknown distance measure!" << std::endl;
                // ...
        }

        if( neighbours.size() < K ){
            // ...
        }else{
            Float maxValue = neighbours[0].value;
            // ...
            for(UINT n=1; n<neighbours.size(); n++){
                if( neighbours[n].value > maxValue ){
                    maxValue = neighbours[n].value;
                    // ...
                }
            }
            // ...
            if( dist < maxValue ){
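The neighbour buffer above is capped at K entries by a linear scan for the current worst (largest) distance, which costs O(K) per training sample. For large K, the same bookkeeping can be done with a max-heap in O(log K) per sample; a standalone illustration, not GRT code:

#include <cstddef>
#include <queue>
#include <utility>

// Keeps the K smallest distances seen so far; the heap top is always the current worst.
void addNeighbour( std::priority_queue< std::pair<double,unsigned int> > &heap,
                   const double dist, const unsigned int classLabel, const std::size_t K ){
    if( heap.size() < K ){
        heap.push( { dist, classLabel } );  // still filling the buffer
    }else if( dist < heap.top().first ){
        heap.pop();                         // evict the current worst neighbour
        heap.push( { dist, classLabel } );
    }
}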
    if( classLikelihoods.size() != numClasses ) classLikelihoods.resize(numClasses);
    if( classDistances.size() != numClasses ) classDistances.resize(numClasses);
    // ...
    std::fill(classLikelihoods.begin(),classLikelihoods.end(),0);
    std::fill(classDistances.begin(),classDistances.end(),0);
    // ...
    for(UINT k=0; k<neighbours.size(); k++){
        UINT classLabel = neighbours[k].index;
        if( classLabel == 0 ){
            errorLog << __GRT_LOG__ << " Class label of training example cannot be zero!" << std::endl;
            // ...
        }
        // ...
        UINT classLabelIndex = 0;
        for(UINT j=0; j<numClasses; j++){
            if( classLabel == classLabels[j] ){
                // ...
            }
        }
        classLikelihoods[ classLabelIndex ] += 1;
        classDistances[ classLabelIndex ] += neighbours[k].value;
    }
    // ...
    Float maxCount = classLikelihoods[0];
    // ...
    for(UINT i=1; i<classLikelihoods.size(); i++){
        if( classLikelihoods[i] > maxCount ){
            maxCount = classLikelihoods[i];
            // ...
        }
    }
    // ...
    for(UINT i=0; i<numClasses; i++){
        if( classLikelihoods[i] > 0 ) classDistances[i] /= classLikelihoods[i];
    }
    // ...
    for(UINT i=0; i<numClasses; i++){
        classLikelihoods[i] /= Float( neighbours.size() );
    }
    // ...
    maxLikelihood = classLikelihoods[ maxIndex ];
    // ...
    if( useNullRejection ){
        if( classDistances[ maxIndex ] <= nullRejectionThresholds[ maxIndex ] ){
            predictedClassLabel = classLabels[maxIndex];
        }else{
            predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
        }
    }else{
        predictedClassLabel = classLabels[maxIndex];
    }
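After the vote, classLikelihoods holds each class's share of the K votes, classDistances holds the mean neighbour distance of each voted class, and the winning class is swapped for GRT_DEFAULT_NULL_CLASS_LABEL when its mean distance exceeds the class threshold. Reading the results back through the Classifier accessors (a sketch, assuming a trained KNN knn and a VectorFloat sample as in the earlier example):

if( knn.predict( sample ) ){
    const UINT label = knn.getPredictedClassLabel();  // GRT_DEFAULT_NULL_CLASS_LABEL if rejected
    const Float maxLikelihood = knn.getMaximumLikelihood();
    const VectorFloat likelihoods = knn.getClassLikelihoods();
    const VectorFloat distances = knn.getClassDistances();
}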
trainingData.clear();
// ...
errorLog << __GRT_LOG__ << " Could not open file to save model!" << std::endl;
// ...
file << "GRT_KNN_MODEL_FILE_V2.0\n";
// ...
errorLog << __GRT_LOG__ << " Failed to save classifier base settings to file!" << std::endl;
// ...
file << "K: " << K << std::endl;
// ...
if( useNullRejection ){
    file << "TrainingMu: ";
    // ...
    file << "TrainingSigma: ";
    // ...
}
// ...
file << "NumTrainingSamples: " << trainingData.getNumSamples() << std::endl;
file << "TrainingData: \n";
// ...
file << trainingData[i].getClassLabel() << "\t";
// ...
for(UINT j=0; j<numInputDimensions; j++){
    file << trainingData[i][j] << "\t";
errorLog << __GRT_LOG__ << " Could not open file to load model!" << std::endl;
// ...
if( word == "GRT_KNN_MODEL_FILE_V1.0" ){
    // ...
}
if( word != "GRT_KNN_MODEL_FILE_V2.0" ){
    errorLog << __GRT_LOG__ << " Could not find Model File Header!" << std::endl;
    // ...
}
// ...
errorLog << __GRT_LOG__ << " Failed to load base settings from file!" << std::endl;
// ...
errorLog << __GRT_LOG__ << " Could not find K!" << std::endl;
// ...
if( word != "DistanceMethod:" ){
    errorLog << __GRT_LOG__ << " Could not find DistanceMethod!" << std::endl;
    // ...
}
// ...
if( word != "SearchForBestKValue:" ){
    errorLog << __GRT_LOG__ << " Could not find SearchForBestKValue!" << std::endl;
    // ...
}
// ...
if( word != "MinKSearchValue:" ){
    errorLog << __GRT_LOG__ << " Could not find MinKSearchValue!" << std::endl;
    // ...
}
// ...
if( word != "MaxKSearchValue:" ){
    errorLog << __GRT_LOG__ << " Could not find MaxKSearchValue!" << std::endl;
    // ...
}
// ...
if( useNullRejection ){
    // ...
    if( word != "TrainingMu:" ){
        errorLog << __GRT_LOG__ << " Could not find TrainingMu!" << std::endl;
        // ...
    }
    for(UINT j=0; j<numClasses; j++){
        // ...
    }
    // ...
    if( word != "TrainingSigma:" ){
        errorLog << __GRT_LOG__ << " Could not find TrainingSigma!" << std::endl;
        // ...
    }
    for(UINT j=0; j<numClasses; j++){
        // ...
    }
}
// ...
if( word != "NumTrainingSamples:" ){
    errorLog << __GRT_LOG__ << " Could not find NumTrainingSamples!" << std::endl;
    // ...
}
unsigned int numTrainingSamples = 0;
file >> numTrainingSamples;
// ...
if( word != "TrainingData:" ){
    errorLog << __GRT_LOG__ << " Could not find TrainingData!" << std::endl;
    // ...
}
// ...
unsigned int classLabel = 0;
// ...
for(UINT i=0; i<numTrainingSamples; i++){
    // ...
    for(UINT j=0; j<numInputDimensions; j++){
        // ...
    }
    trainingData.addSample(classLabel, sample);
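Because KNN is a lazy learner, save and load serialize the full training set alongside the settings, which is what the NumTrainingSamples and TrainingData blocks above read back sample by sample via addSample. A round-trip sketch through the std::fstream overloads shown on this page (the path is a placeholder):

#include <fstream>

std::fstream file;
file.open( "knn_model.grt", std::ios::out ); // placeholder path
knn.save( file );   // writes GRT_KNN_MODEL_FILE_V2.0, the settings, and the training data
file.close();

KNN restored;
file.open( "knn_model.grt", std::ios::in );
restored.load( file ); // rebuilds trainingData via addSample(classLabel, sample)
file.close();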
bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
// ...
classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
// ...
nullRejectionThresholds.resize(numClasses,0);
// ...
for(unsigned int j=0; j<numClasses; j++){
if( nullRejectionCoeff > 0 ){
    this->nullRejectionCoeff = nullRejectionCoeff;
    // ...
}
// ...
if( distanceMethod == EUCLIDEAN_DISTANCE || distanceMethod == COSINE_DISTANCE || distanceMethod == MANHATTAN_DISTANCE ){
for(UINT j=0; j<numInputDimensions; j++){
    dist += SQR( a[j] - b[j] );
}
// ...
for(UINT j=0; j<numInputDimensions; j++){
    dotAB += a[j] * b[j];
    // ...
}
dist = dotAB / (sqrt(magA) * sqrt(magB));
// ...
for(UINT j=0; j<numInputDimensions; j++){
    dist += fabs( a[j] - b[j] );
}
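The three measures above behave differently: Euclidean accumulates squared differences (the elided return presumably takes the square root), Manhattan accumulates absolute differences, and the cosine branch returns dotAB / (|a||b|), which is a similarity where larger means more aligned. Since predict keeps the neighbours with the smallest dist values, a similarity orders neighbours opposite to the two true distances, which is worth keeping in mind when selecting COSINE_DISTANCE. A small self-contained check (standalone code, not GRT's member functions):

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main(){
    const std::vector<double> a = { 1.0, 2.0, 3.0 };
    const std::vector<double> b = { 2.0, 4.0, 6.0 };
    double sq = 0, man = 0, dot = 0, magA = 0, magB = 0;
    for( std::size_t j = 0; j < a.size(); j++ ){
        sq   += ( a[j] - b[j] ) * ( a[j] - b[j] ); // squared difference
        man  += std::fabs( a[j] - b[j] );          // absolute difference
        dot  += a[j] * b[j];
        magA += a[j] * a[j];
        magB += b[j] * b[j];
    }
    std::printf( "euclidean: %f\n", std::sqrt( sq ) ); // sqrt(1+4+9) ~= 3.7417
    std::printf( "manhattan: %f\n", man );             // 1+2+3 = 6
    std::printf( "cosine:    %f\n", dot / ( std::sqrt( magA ) * std::sqrt( magB ) ) ); // parallel vectors = 1
    return 0;
}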
if( word != "NumFeatures:" ){
    errorLog << __GRT_LOG__ << " Could not find NumFeatures!" << std::endl;
    // ...
}
file >> numInputDimensions;
// ...
if( word != "NumClasses:" ){
    errorLog << __GRT_LOG__ << " Could not find NumClasses!" << std::endl;
    // ...
}
// ...
errorLog << __GRT_LOG__ << " Could not find K!" << std::endl;
// ...
if( word != "DistanceMethod:" ){
    errorLog << __GRT_LOG__ << " Could not find DistanceMethod!" << std::endl;
    // ...
}
// ...
if( word != "SearchForBestKValue:" ){
    errorLog << __GRT_LOG__ << " Could not find SearchForBestKValue!" << std::endl;
    // ...
}
// ...
if( word != "MinKSearchValue:" ){
    errorLog << __GRT_LOG__ << " Could not find MinKSearchValue!" << std::endl;
    // ...
}
// ...
if( word != "MaxKSearchValue:" ){
    errorLog << __GRT_LOG__ << " Could not find MaxKSearchValue!" << std::endl;
    // ...
}
// ...
if( word != "UseScaling:" ){
    errorLog << __GRT_LOG__ << " Could not find UseScaling!" << std::endl;
    // ...
}
// ...
if( word != "UseNullRejection:" ){
    errorLog << __GRT_LOG__ << " Could not find UseNullRejection!" << std::endl;
    // ...
}
file >> useNullRejection;
// ...
if( word != "NullRejectionCoeff:" ){
    errorLog << __GRT_LOG__ << " Could not find NullRejectionCoeff!" << std::endl;
    // ...
}
file >> nullRejectionCoeff;
// ...
ranges.resize( numInputDimensions );
// ...
if( word != "Ranges:" ){
    errorLog << __GRT_LOG__ << " Could not find Ranges!" << std::endl;
    // ...
}
for(UINT n=0; n<ranges.getSize(); n++){
    file >> ranges[n].minValue;
    file >> ranges[n].maxValue;
}
// ...
if( word != "TrainingMu:" ){
    errorLog << __GRT_LOG__ << " Could not find TrainingMu!" << std::endl;
    // ...
}
for(UINT j=0; j<numClasses; j++){
    // ...
}
// ...
if( word != "TrainingSigma:" ){
    errorLog << __GRT_LOG__ << " Could not find TrainingSigma!" << std::endl;
    // ...
}
for(UINT j=0; j<numClasses; j++){
    // ...
}
// ...
if( word != "NumTrainingSamples:" ){
    errorLog << __GRT_LOG__ << " Could not find NumTrainingSamples!" << std::endl;
    // ...
}
unsigned int numTrainingSamples = 0;
file >> numTrainingSamples;
// ...
if( word != "TrainingData:" ){
    errorLog << __GRT_LOG__ << " Could not find TrainingData!" << std::endl;
    // ...
}
// ...
unsigned int classLabel = 0;
// ...
for(UINT i=0; i<numTrainingSamples; i++){
    // ...
    for(UINT j=0; j<numInputDimensions; j++){
        // ...
    }
    trainingData.addSample(classLabel, sample);
}
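// Note: load() dispatches to loadLegacyModelFromFile() when the header word reads
// GRT_KNN_MODEL_FILE_V1.0. The legacy parser above reads the same fields as the V2
// format, but parses UseScaling, UseNullRejection, NullRejectionCoeff, and the
// per-dimension Ranges inline instead of via loadBaseSettingsFromFile().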
// KNN member reference (from the KNN.h documentation):
//
//   static std::string getId()
//   KNN(UINT K=10, bool useScaling=false, bool useNullRejection=false, Float nullRejectionCoeff=10.0,
//       bool searchForBestKValue=false, UINT minKSearchValue=1, UINT maxKSearchValue=10)
//   KNN & operator=(const KNN &rhs)
//   virtual bool deepCopyFrom(const Classifier *classifier)
//   virtual bool train_(ClassificationData &trainingData)
//   virtual bool predict_(VectorFloat &inputVector)
//   virtual bool save(std::fstream &file) const
//   virtual bool load(std::fstream &file)
//   virtual bool recomputeNullRejectionThresholds()
//   bool loadLegacyModelFromFile(std::fstream &file)
//   bool enableBestKValueSearch(bool searchForBestKValue): sets if the best K value should be
//       searched for or if the model should be trained with K ...
//   bool setMinKSearchValue(UINT minKSearchValue)
//   bool setMaxKSearchValue(UINT maxKSearchValue)
//   bool setNullRejectionCoeff(Float nullRejectionCoeff)
//   bool setDistanceMethod(UINT distanceMethod)
//
//   UINT K: the number of neighbours to search for
//   UINT distanceMethod: the distance method used to compute the distance between each data point
//   UINT minKSearchValue: the minimum K value to start the search from
//   UINT maxKSearchValue: the maximum K value to end the search at
//   ClassificationData trainingData: holds the trainingData to perform the predictions
//   VectorFloat trainingMu: holds the average max-class distance of the training data for each
//       of the classes
//   VectorFloat trainingSigma: holds the spread of those distances (accumulated as squared
//       deviations in train_ above)
//
// KNN inherits from Classifier, the main base class that all GRT Classification
// algorithms should inherit from...