21 #define GRT_DLL_EXPORTS
26 LDA::LDA(
bool useScaling,
bool useNullRejection,Float nullRejectionCoeff)
28 this->useScaling = useScaling;
29 this->useNullRejection = useNullRejection;
30 this->nullRejectionCoeff = nullRejectionCoeff;
32 classifierType = classType;
33 classifierMode = STANDARD_CLASSIFIER_MODE;
34 debugLog.setProceedingText(
"[DEBUG LDA]");
35 errorLog.setProceedingText(
"[ERROR LDA]");
36 trainingLog.setProceedingText(
"[TRAINING LDA]");
37 warningLog.setProceedingText(
"[WARNING LDA]");
46 errorLog <<
"SORRY - this module is still under development and can't be used yet!" << std::endl;
50 numInputDimensions = 0;
57 errorLog <<
"train(LabelledClassificationData trainingData) - There is no training data to train the model!" << std::endl;
65 MatrixFloat SB = computeBetweenClassScatterMatrix( trainingData );
68 MatrixFloat SW = computeWithinClassScatterMatrix( trainingData );
203 errorLog <<
"predict(vector< Float > inputVector) - LDA Model Not Trained!" << std::endl;
207 predictedClassLabel = 0;
208 maxLikelihood = -10000;
210 if( !trained )
return false;
212 if( inputVector.
getSize() != numInputDimensions ){
213 errorLog <<
"predict(vector< Float > inputVector) - The size of the input vector (" << inputVector.
getSize() <<
") does not match the num features in the model (" << numInputDimensions << std::endl;
218 if( classLikelihoods.
getSize() != numClasses || classDistances.
getSize() != numClasses ){
219 classLikelihoods.
resize(numClasses);
220 classDistances.
resize(numClasses);
228 for(UINT k=0; k<numClasses; k++){
230 for(UINT j=0; j<numInputDimensions+1; j++){
231 if( j==0 ) classDistances[k] = models[k].weights[j];
232 else classDistances[k] += inputVector[j-1] * models[k].weights[j];
234 classLikelihoods[k] = exp( classDistances[k] );
235 sum += classLikelihoods[k];
237 if( classLikelihoods[k] > maxLikelihood ){
239 maxLikelihood = classLikelihoods[k];
244 for(UINT k=0; k<numClasses; k++){
245 classLikelihoods[k] /= sum;
248 maxLikelihood = classLikelihoods[ bestIndex ];
250 predictedClassLabel = models[ bestIndex ].classLabel;
259 errorLog <<
"saveModelToFile(fstream &file) - Could not open file to save model" << std::endl;
264 file<<
"GRT_LDA_MODEL_FILE_V1.0\n";
265 file<<
"NumFeatures: "<<numInputDimensions<< std::endl;
266 file<<
"NumClasses: "<<numClasses<< std::endl;
267 file <<
"UseScaling: " << useScaling << std::endl;
268 file<<
"UseNullRejection: " << useNullRejection << std::endl;
272 file <<
"Ranges: \n";
273 for(UINT n=0; n<ranges.size(); n++){
274 file << ranges[n].minValue <<
"\t" << ranges[n].maxValue << std::endl;
279 for(UINT k=0; k<numClasses; k++){
280 file<<
"ClassLabel: "<<models[k].classLabel<< std::endl;
281 file<<
"PriorProbability: "<<models[k].priorProb<< std::endl;
284 for(UINT j=0; j<models[k].getNumDimensions(); j++){
285 file <<
"\t" << models[k].weights[j];
295 numInputDimensions = 0;
302 errorLog <<
"loadModelFromFile(fstream &file) - The file is not open!" << std::endl;
310 if(word !=
"GRT_LDA_MODEL_FILE_V1.0"){
311 errorLog <<
"loadModelFromFile(fstream &file) - Could not find Model File Header" << std::endl;
316 if(word !=
"NumFeatures:"){
317 errorLog <<
"loadModelFromFile(fstream &file) - Could not find NumFeatures " << std::endl;
320 file >> numInputDimensions;
323 if(word !=
"NumClasses:"){
324 errorLog <<
"loadModelFromFile(fstream &file) - Could not find NumClasses" << std::endl;
330 if(word !=
"UseScaling:"){
331 errorLog <<
"loadModelFromFile(fstream &file) - Could not find UseScaling" << std::endl;
337 if(word !=
"UseNullRejection:"){
338 errorLog <<
"loadModelFromFile(fstream &file) - Could not find UseNullRejection" << std::endl;
341 file >> useNullRejection;
346 ranges.
resize(numInputDimensions);
349 if(word !=
"Ranges:"){
350 errorLog <<
"loadModelFromFile(fstream &file) - Could not find the Ranges" << std::endl;
353 for(UINT n=0; n<ranges.size(); n++){
354 file >> ranges[n].minValue;
355 file >> ranges[n].maxValue;
360 models.
resize(numClasses);
361 classLabels.
resize(numClasses);
364 for(UINT k=0; k<numClasses; k++){
366 if(word !=
"ClassLabel:"){
367 errorLog <<
"loadModelFromFile(fstream &file) - Could not find ClassLabel for the "<<k+1<<
"th model" << std::endl;
370 file >> models[k].classLabel;
371 classLabels[k] = models[k].classLabel;
374 if(word !=
"PriorProbability:"){
375 errorLog <<
"loadModelFromFile(fstream &file) - Could not find the PriorProbability for the "<<k+1<<
"th model" << std::endl;
378 file >> models[k].priorProb;
380 models[k].weights.
resize(numInputDimensions+1);
384 if(word !=
"Weights:"){
385 errorLog <<
"loadModelFromFile(fstream &file) - Could not find the Weights vector for the "<<k+1<<
"th model" << std::endl;
390 for(UINT j=0; j<numInputDimensions+1; j++){
393 models[k].weights[j] = value;
399 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
401 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
410 MatrixFloat sb(numInputDimensions,numInputDimensions);
413 sb.setAllValues( 0 );
415 for(UINT k=0; k<numClasses; k++){
419 for(UINT m=0; m<numInputDimensions; m++){
420 for(UINT n=0; n<numInputDimensions; n++){
421 sb[m][n] += (classMean[k][m]-totalMean[m]) * (classMean[k][n]-totalMean[n]) * Float(numSamplesInClass);
431 MatrixFloat sw(numInputDimensions,numInputDimensions);
432 sw.setAllValues( 0 );
434 for(UINT k=0; k<numClasses; k++){
441 for(UINT m=0; m<numInputDimensions; m++){
442 for(UINT n=0; n<numInputDimensions; n++){
443 sw[m][n] += scatterMatrix[m][n];
#define DEFAULT_NULL_LIKELIHOOD_VALUE
Vector< ClassTracker > getClassTracker() const
ClassificationData getClassData(const UINT classLabel) const
virtual bool resize(const unsigned int size)
This class implements the Linear Discriminant Analysis Classification algorithm.
MatrixFloat getClassMean() const
virtual bool predict(VectorDouble inputVector)
MatrixFloat getCovarianceMatrix() const
UINT getNumSamples() const
virtual bool train(ClassificationData trainingData)
LDA(bool useScaling=false, bool useNullRejection=true, Float nullRejectionCoeff=10.0)
UINT getNumDimensions() const
UINT getNumClasses() const
virtual bool saveModelToFile(std::fstream &file) const
virtual bool loadModelFromFile(std::fstream &file)
VectorFloat getMean() const