25 LDA::LDA(
bool useScaling,
bool useNullRejection,Float nullRejectionCoeff)
27 this->useScaling = useScaling;
28 this->useNullRejection = useNullRejection;
29 this->nullRejectionCoeff = nullRejectionCoeff;
31 classifierType = classType;
32 classifierMode = STANDARD_CLASSIFIER_MODE;
33 debugLog.setProceedingText(
"[DEBUG LDA]");
34 errorLog.setProceedingText(
"[ERROR LDA]");
35 trainingLog.setProceedingText(
"[TRAINING LDA]");
36 warningLog.setProceedingText(
"[WARNING LDA]");
45 errorLog <<
"SORRY - this module is still under development and can't be used yet!" << std::endl;
49 numInputDimensions = 0;
56 errorLog <<
"train(LabelledClassificationData trainingData) - There is no training data to train the model!" << std::endl;
64 MatrixFloat SB = computeBetweenClassScatterMatrix( trainingData );
67 MatrixFloat SW = computeWithinClassScatterMatrix( trainingData );
202 errorLog <<
"predict(vector< Float > inputVector) - LDA Model Not Trained!" << std::endl;
206 predictedClassLabel = 0;
207 maxLikelihood = -10000;
209 if( !trained )
return false;
211 if( inputVector.
getSize() != numInputDimensions ){
212 errorLog <<
"predict(vector< Float > inputVector) - The size of the input vector (" << inputVector.
getSize() <<
") does not match the num features in the model (" << numInputDimensions << std::endl;
217 if( classLikelihoods.
getSize() != numClasses || classDistances.
getSize() != numClasses ){
218 classLikelihoods.
resize(numClasses);
219 classDistances.
resize(numClasses);
227 for(UINT k=0; k<numClasses; k++){
229 for(UINT j=0; j<numInputDimensions+1; j++){
230 if( j==0 ) classDistances[k] = models[k].weights[j];
231 else classDistances[k] += inputVector[j-1] * models[k].weights[j];
233 classLikelihoods[k] = exp( classDistances[k] );
234 sum += classLikelihoods[k];
236 if( classLikelihoods[k] > maxLikelihood ){
238 maxLikelihood = classLikelihoods[k];
243 for(UINT k=0; k<numClasses; k++){
244 classLikelihoods[k] /= sum;
247 maxLikelihood = classLikelihoods[ bestIndex ];
249 predictedClassLabel = models[ bestIndex ].classLabel;
258 errorLog <<
"saveModelToFile(fstream &file) - Could not open file to save model" << std::endl;
263 file<<
"GRT_LDA_MODEL_FILE_V1.0\n";
264 file<<
"NumFeatures: "<<numInputDimensions<< std::endl;
265 file<<
"NumClasses: "<<numClasses<< std::endl;
266 file <<
"UseScaling: " << useScaling << std::endl;
267 file<<
"UseNullRejection: " << useNullRejection << std::endl;
271 file <<
"Ranges: \n";
272 for(UINT n=0; n<ranges.size(); n++){
273 file << ranges[n].minValue <<
"\t" << ranges[n].maxValue << std::endl;
278 for(UINT k=0; k<numClasses; k++){
279 file<<
"ClassLabel: "<<models[k].classLabel<< std::endl;
280 file<<
"PriorProbability: "<<models[k].priorProb<< std::endl;
283 for(UINT j=0; j<models[k].getNumDimensions(); j++){
284 file <<
"\t" << models[k].weights[j];
294 numInputDimensions = 0;
301 errorLog <<
"loadModelFromFile(fstream &file) - The file is not open!" << std::endl;
309 if(word !=
"GRT_LDA_MODEL_FILE_V1.0"){
310 errorLog <<
"loadModelFromFile(fstream &file) - Could not find Model File Header" << std::endl;
315 if(word !=
"NumFeatures:"){
316 errorLog <<
"loadModelFromFile(fstream &file) - Could not find NumFeatures " << std::endl;
319 file >> numInputDimensions;
322 if(word !=
"NumClasses:"){
323 errorLog <<
"loadModelFromFile(fstream &file) - Could not find NumClasses" << std::endl;
329 if(word !=
"UseScaling:"){
330 errorLog <<
"loadModelFromFile(fstream &file) - Could not find UseScaling" << std::endl;
336 if(word !=
"UseNullRejection:"){
337 errorLog <<
"loadModelFromFile(fstream &file) - Could not find UseNullRejection" << std::endl;
340 file >> useNullRejection;
345 ranges.
resize(numInputDimensions);
348 if(word !=
"Ranges:"){
349 errorLog <<
"loadModelFromFile(fstream &file) - Could not find the Ranges" << std::endl;
352 for(UINT n=0; n<ranges.size(); n++){
353 file >> ranges[n].minValue;
354 file >> ranges[n].maxValue;
359 models.
resize(numClasses);
360 classLabels.
resize(numClasses);
363 for(UINT k=0; k<numClasses; k++){
365 if(word !=
"ClassLabel:"){
366 errorLog <<
"loadModelFromFile(fstream &file) - Could not find ClassLabel for the "<<k+1<<
"th model" << std::endl;
369 file >> models[k].classLabel;
370 classLabels[k] = models[k].classLabel;
373 if(word !=
"PriorProbability:"){
374 errorLog <<
"loadModelFromFile(fstream &file) - Could not find the PriorProbability for the "<<k+1<<
"th model" << std::endl;
377 file >> models[k].priorProb;
379 models[k].weights.
resize(numInputDimensions+1);
383 if(word !=
"Weights:"){
384 errorLog <<
"loadModelFromFile(fstream &file) - Could not find the Weights vector for the "<<k+1<<
"th model" << std::endl;
389 for(UINT j=0; j<numInputDimensions+1; j++){
392 models[k].weights[j] = value;
398 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
400 classDistances.
resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
409 MatrixFloat sb(numInputDimensions,numInputDimensions);
412 sb.setAllValues( 0 );
414 for(UINT k=0; k<numClasses; k++){
418 for(UINT m=0; m<numInputDimensions; m++){
419 for(UINT n=0; n<numInputDimensions; n++){
420 sb[m][n] += (classMean[k][m]-totalMean[m]) * (classMean[k][n]-totalMean[n]) * Float(numSamplesInClass);
430 MatrixFloat sw(numInputDimensions,numInputDimensions);
431 sw.setAllValues( 0 );
433 for(UINT k=0; k<numClasses; k++){
440 for(UINT m=0; m<numInputDimensions; m++){
441 for(UINT n=0; n<numInputDimensions; n++){
442 sw[m][n] += scatterMatrix[m][n];
#define DEFAULT_NULL_LIKELIHOOD_VALUE
Vector< ClassTracker > getClassTracker() const
ClassificationData getClassData(const UINT classLabel) const
virtual bool resize(const unsigned int size)
This class implements the Linear Discriminant Analysis Classification algorithm.
MatrixFloat getClassMean() const
virtual bool predict(VectorDouble inputVector)
unsigned int getSize() const
MatrixFloat getCovarianceMatrix() const
UINT getNumSamples() const
virtual bool train(ClassificationData trainingData)
LDA(bool useScaling=false, bool useNullRejection=true, Float nullRejectionCoeff=10.0)
UINT getNumDimensions() const
UINT getNumClasses() const
virtual bool saveModelToFile(std::fstream &file) const
virtual bool loadModelFromFile(std::fstream &file)
VectorFloat getMean() const