GestureRecognitionToolkit  Version: 0.1.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, c++ machine learning library for real-time gesture recognition.
ClassificationData.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 */
20 
21 #include "ClassificationData.h"
22 
23 GRT_BEGIN_NAMESPACE
24 
25 ClassificationData::ClassificationData(const UINT numDimensions,const std::string datasetName,const std::string infoText){
26  this->datasetName = datasetName;
27  this->numDimensions = numDimensions;
28  this->infoText = infoText;
29  totalNumSamples = 0;
30  crossValidationSetup = false;
31  useExternalRanges = false;
32  allowNullGestureClass = true;
33  if( numDimensions > 0 ) setNumDimensions( numDimensions );
34  infoLog.setProceedingText("[ClassificationData]");
35  debugLog.setProceedingText("[DEBUG ClassificationData]");
36  errorLog.setProceedingText("[ERROR ClassificationData]");
37  warningLog.setProceedingText("[WARNING ClassificationData]");
38 }
39 
41  *this = rhs;
42 }
43 
45 }
46 
48  if( this != &rhs){
49  this->datasetName = rhs.datasetName;
50  this->infoText = rhs.infoText;
51  this->numDimensions = rhs.numDimensions;
52  this->totalNumSamples = rhs.totalNumSamples;
53  this->kFoldValue = rhs.kFoldValue;
54  this->crossValidationSetup = rhs.crossValidationSetup;
55  this->useExternalRanges = rhs.useExternalRanges;
56  this->allowNullGestureClass = rhs.allowNullGestureClass;
57  this->externalRanges = rhs.externalRanges;
58  this->classTracker = rhs.classTracker;
59  this->data = rhs.data;
60  this->crossValidationIndexs = rhs.crossValidationIndexs;
61  this->infoLog = rhs.infoLog;
62  this->debugLog = rhs.debugLog;
63  this->errorLog = rhs.errorLog;
64  this->warningLog = rhs.warningLog;
65  }
66  return *this;
67 }
68 
70  totalNumSamples = 0;
71  data.clear();
72  classTracker.clear();
73  crossValidationSetup = false;
74  crossValidationIndexs.clear();
75 }
76 
77 bool ClassificationData::setNumDimensions(const UINT numDimensions){
78 
79  if( numDimensions > 0 ){
80  //Clear any previous training data
81  clear();
82 
83  //Set the dimensionality of the data
84  this->numDimensions = numDimensions;
85 
86  //Clear the external ranges
87  useExternalRanges = false;
88  externalRanges.clear();
89 
90  return true;
91  }
92 
93  errorLog << "setNumDimensions(const UINT numDimensions) - The number of dimensions of the dataset must be greater than zero!" << std::endl;
94  return false;
95 }
96 
97 bool ClassificationData::setDatasetName(const std::string datasetName){
98 
99  //Make sure there are no spaces in the std::string
100  if( datasetName.find(" ") == std::string::npos ){
101  this->datasetName = datasetName;
102  return true;
103  }
104 
105  errorLog << "setDatasetName(const std::string datasetName) - The dataset name cannot contain any spaces!" << std::endl;
106  return false;
107 }
108 
109 bool ClassificationData::setInfoText(const std::string infoText){
110  this->infoText = infoText;
111  return true;
112 }
113 
114 bool ClassificationData::setClassNameForCorrespondingClassLabel(const std::string className,const UINT classLabel){
115 
116  for(UINT i=0; i<classTracker.getSize(); i++){
117  if( classTracker[i].classLabel == classLabel ){
118  classTracker[i].className = className;
119  return true;
120  }
121  }
122 
123  errorLog << "setClassNameForCorrespondingClassLabel(const std::string className,const UINT classLabel) - Failed to find class with label: " << classLabel << std::endl;
124  return false;
125 }
126 
127 bool ClassificationData::setAllowNullGestureClass(const bool allowNullGestureClass){
128  this->allowNullGestureClass = allowNullGestureClass;
129  return true;
130 }
131 
132 bool ClassificationData::addSample(const UINT classLabel,const VectorFloat &sample){
133 
134  if( sample.getSize() != numDimensions ){
135  errorLog << "addSample(const UINT classLabel, VectorFloat &sample) - the size of the new sample (" << sample.getSize() << ") does not match the number of dimensions of the dataset (" << numDimensions << ")" << std::endl;
136  return false;
137  }
138 
139  //The class label must be greater than zero (as zero is used for the null rejection class label
140  if( classLabel == GRT_DEFAULT_NULL_CLASS_LABEL && !allowNullGestureClass ){
141  errorLog << "addSample(const UINT classLabel, VectorFloat &sample) - the class label can not be 0!" << std::endl;
142  return false;
143  }
144 
145  //The dataset has changed so flag that any previous cross validation setup will now not work
146  crossValidationSetup = false;
147  crossValidationIndexs.clear();
148 
149  ClassificationSample newSample(classLabel,sample);
150  data.push_back( newSample );
151  totalNumSamples++;
152 
153  if( classTracker.getSize() == 0 ){
154  ClassTracker tracker(classLabel,1);
155  classTracker.push_back(tracker);
156  }else{
157  bool labelFound = false;
158  for(UINT i=0; i<classTracker.getSize(); i++){
159  if( classLabel == classTracker[i].classLabel ){
160  classTracker[i].counter++;
161  labelFound = true;
162  break;
163  }
164  }
165  if( !labelFound ){
166  ClassTracker tracker(classLabel,1);
167  classTracker.push_back(tracker);
168  }
169  }
170 
171  //Update the class labels
172  sortClassLabels();
173 
174  return true;
175 }
176 
177 bool ClassificationData::removeSample( const UINT index ){
178 
179  if( totalNumSamples == 0 ){
180  warningLog << "removeSample( const UINT index ) - Failed to remove sample, the training dataset is empty!" << std::endl;
181  return false;
182  }
183 
184  if( index >= totalNumSamples ){
185  warningLog << "removeSample( const UINT index ) - Failed to remove sample, the index is out of bounds! Number of training samples: " << totalNumSamples << " index: " << index << std::endl;
186  return false;
187  }
188 
189  //The dataset has changed so flag that any previous cross validation setup will now not work
190  crossValidationSetup = false;
191  crossValidationIndexs.clear();
192 
193  //Find the corresponding class ID for the last training example
194  UINT classLabel = data[ index ].getClassLabel();
195 
196  //Remove the training example from the buffer
197  data.erase( data.begin()+index );
198 
199  totalNumSamples = data.getSize();
200 
201  //Remove the value from the counter
202  for(size_t i=0; i<classTracker.getSize(); i++){
203  if( classTracker[i].classLabel == classLabel ){
204  classTracker[i].counter--;
205  break;
206  }
207  }
208 
209  return true;
210 }
211 
213 
214  if( totalNumSamples == 0 ){
215  warningLog << "removeLastSample() - Failed to remove sample, the training dataset is empty!" << std::endl;
216  return false;
217  }
218 
219  return removeSample( totalNumSamples-1 );
220 }
221 
222 bool ClassificationData::reserve(const UINT N){
223 
224  data.reserve( N );
225 
226  if( data.capacity() >= N ) return true;
227 
228  return false;
229 }
230 
232  return removeClass( classLabel );
233 }
234 
235 bool ClassificationData::addClass(const UINT classLabel,const std::string className){
236 
237  //Check to make sure the class label does not exist
238  for(size_t i=0; i<classTracker.getSize(); i++){
239  if( classTracker[i].classLabel == classLabel ){
240  warningLog << "addClass(const UINT classLabel,const std::string className) - Failed to add class, it already exists! Class label: " << classLabel << std::endl;
241  return false;
242  }
243  }
244 
245  //Add the class label to the class tracker
246  classTracker.push_back( ClassTracker(classLabel,0,className) );
247 
248  //Sort the class labels
249  sortClassLabels();
250 
251  return true;
252 }
253 
//Removes every sample belonging to classLabel (and the class's tracker entry).
//Returns the number of samples that were removed; returns 0 if the label was
//not found. Any cross validation setup is invalidated.
UINT ClassificationData::removeClass(const UINT classLabel){

    UINT numExamplesRemoved = 0;
    UINT numExamplesToRemove = 0;

    //The dataset has changed so flag that any previous cross validation setup will now not work
    crossValidationSetup = false;
    crossValidationIndexs.clear();

    //Find out how many training examples we need to remove
    //(also erases the class's tracker entry in the same pass)
    for(UINT i=0; i<classTracker.getSize(); i++){
        if( classTracker[i].classLabel == classLabel ){
            numExamplesToRemove = classTracker[i].counter;
            classTracker.erase(classTracker.begin()+i);
            break;
        }
    }

    //Remove the samples with the matching class ID
    //NOTE: i is deliberately NOT advanced after an erase — erasing shifts the
    //next element into position i, so it must be re-examined. The else-branch
    //advances i and bails out when the end of the (shrinking) buffer is reached.
    if( numExamplesToRemove > 0 ){
        UINT i=0;
        while( numExamplesRemoved < numExamplesToRemove ){
            if( data[i].getClassLabel() == classLabel ){
                data.erase(data.begin()+i);
                numExamplesRemoved++;
            }else if( ++i == data.getSize() ) break;
        }
    }

    //Resync the sample count with the buffer size
    totalNumSamples = data.getSize();

    return numExamplesRemoved;
}
287 
288 bool ClassificationData::relabelAllSamplesWithClassLabel(const UINT oldClassLabel,const UINT newClassLabel){
289  bool oldClassLabelFound = false;
290  bool newClassLabelAllReadyExists = false;
291  UINT indexOfOldClassLabel = 0;
292  UINT indexOfNewClassLabel = 0;
293 
294  //Find out how many training examples we need to relabel
295  for(UINT i=0; i<classTracker.getSize(); i++){
296  if( classTracker[i].classLabel == oldClassLabel ){
297  indexOfOldClassLabel = i;
298  oldClassLabelFound = true;
299  }
300  if( classTracker[i].classLabel == newClassLabel ){
301  indexOfNewClassLabel = i;
302  newClassLabelAllReadyExists = true;
303  }
304  }
305 
306  //If the old class label was not found then we can't do anything
307  if( !oldClassLabelFound ){
308  return false;
309  }
310 
311  //Relabel the old class labels
312  for(UINT i=0; i<totalNumSamples; i++){
313  if( data[i].getClassLabel() == oldClassLabel ){
314  data[i].setClassLabel(newClassLabel);
315  }
316  }
317 
318  //Update the class tracler
319  if( newClassLabelAllReadyExists ){
320  //Add the old sample count to the new sample count
321  classTracker[ indexOfNewClassLabel ].counter += classTracker[ indexOfOldClassLabel ].counter;
322  }else{
323  //Create a new class tracker
324  classTracker.push_back( ClassTracker(newClassLabel,classTracker[ indexOfOldClassLabel ].counter,classTracker[ indexOfOldClassLabel ].className) );
325  }
326 
327  //Erase the old class tracker
328  classTracker.erase( classTracker.begin() + indexOfOldClassLabel );
329 
330  //Sort the class labels
331  sortClassLabels();
332 
333  return true;
334 }
335 
336 bool ClassificationData::setExternalRanges(const Vector< MinMax > &externalRanges, const bool useExternalRanges){
337 
338  if( externalRanges.size() != numDimensions ) return false;
339 
340  this->externalRanges = externalRanges;
341  this->useExternalRanges = useExternalRanges;
342 
343  return true;
344 }
345 
346 bool ClassificationData::enableExternalRangeScaling(const bool useExternalRanges){
347  if( externalRanges.getSize() == numDimensions ){
348  this->useExternalRanges = useExternalRanges;
349  return true;
350  }
351  return false;
352 }
353 
354 bool ClassificationData::scale(const Float minTarget,const Float maxTarget){
355  Vector< MinMax > ranges = getRanges();
356  return scale(ranges,minTarget,maxTarget);
357 }
358 
359 bool ClassificationData::scale(const Vector<MinMax> &ranges,const Float minTarget,const Float maxTarget){
360  if( ranges.getSize() != numDimensions ) return false;
361 
362  //Scale the training data
363  for(UINT i=0; i<totalNumSamples; i++){
364  for(UINT j=0; j<numDimensions; j++){
365  data[i][j] = grt_scale(data[i][j],ranges[j].minValue,ranges[j].maxValue,minTarget,maxTarget);
366  }
367  }
368 
369  return true;
370 }
371 
372 bool ClassificationData::save(const std::string &filename) const{
373 
374  //Check if the file should be saved as a csv file
375  if( Util::stringEndsWith( filename, ".csv" ) ){
376  return saveDatasetToCSVFile( filename );
377  }
378 
379  //Otherwise save it as a custom GRT file
380  return saveDatasetToFile( filename );
381 }
382 
383 bool ClassificationData::load(const std::string &filename){
384 
385  //Check if the file should be loaded as a csv file
386  if( Util::stringEndsWith( filename, ".csv" ) ){
387  return loadDatasetFromCSVFile( filename );
388  }
389 
390  //Otherwise save it as a custom GRT file
391  return loadDatasetFromFile( filename );
392 }
393 
394 bool ClassificationData::saveDatasetToFile(const std::string &filename) const{
395 
396  std::fstream file;
397  file.open(filename.c_str(), std::ios::out);
398 
399  if( !file.is_open() ){
400  return false;
401  }
402 
403  file << "GRT_LABELLED_CLASSIFICATION_DATA_FILE_V1.0\n";
404  file << "DatasetName: " << datasetName << std::endl;
405  file << "InfoText: " << infoText << std::endl;
406  file << "NumDimensions: " << numDimensions << std::endl;
407  file << "TotalNumExamples: " << totalNumSamples << std::endl;
408  file << "NumberOfClasses: " << classTracker.size() << std::endl;
409  file << "ClassIDsAndCounters: " << std::endl;
410 
411  for(UINT i=0; i<classTracker.size(); i++){
412  file << classTracker[i].classLabel << "\t" << classTracker[i].counter << "\t" << classTracker[i].className << std::endl;
413  }
414 
415  file << "UseExternalRanges: " << useExternalRanges << std::endl;
416 
417  if( useExternalRanges ){
418  for(UINT i=0; i<externalRanges.size(); i++){
419  file << externalRanges[i].minValue << "\t" << externalRanges[i].maxValue << std::endl;
420  }
421  }
422 
423  file << "Data:\n";
424 
425  for(UINT i=0; i<totalNumSamples; i++){
426  file << data[i].getClassLabel();
427  for(UINT j=0; j<numDimensions; j++){
428  file << "\t" << data[i][j];
429  }
430  file << std::endl;
431  }
432 
433  file.close();
434  return true;
435 }
436 
//Loads a dataset saved in the custom GRT file format (see saveDatasetToFile).
//The file is parsed strictly in header order; any missing or misspelled header
//aborts the load and returns false. Any previously stored data is cleared
//first, even if the load subsequently fails.
bool ClassificationData::loadDatasetFromFile(const std::string &filename){

    std::fstream file;
    file.open(filename.c_str(), std::ios::in);
    UINT numClasses = 0;
    clear();

    if( !file.is_open() ){
        errorLog << "loadDatasetFromFile(const std::string &filename) - could not open file!" << std::endl;
        return false;
    }

    std::string word;

    //Check to make sure this is a file with the Training File Format
    file >> word;
    if(word != "GRT_LABELLED_CLASSIFICATION_DATA_FILE_V1.0"){
        errorLog << "loadDatasetFromFile(const std::string &filename) - could not find file header!" << std::endl;
        file.close();
        return false;
    }

    //Get the name of the dataset
    file >> word;
    if(word != "DatasetName:"){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find DatasetName header!" << std::endl;
        errorLog << word << std::endl;
        file.close();
        return false;
    }
    file >> datasetName;

    file >> word;
    if(word != "InfoText:"){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find InfoText header!" << std::endl;
        file.close();
        return false;
    }

    //Load the info text: accumulate whitespace-separated words until the next header.
    //NOTE(review): if the file is truncated before "NumDimensions:" this loop spins
    //on a failed stream read — confirm upstream files are always well-formed.
    file >> word;
    infoText = "";
    while( word != "NumDimensions:" ){
        infoText += word + " ";
        file >> word;
    }

    //Get the number of dimensions in the training data
    if( word != "NumDimensions:" ){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find NumDimensions header!" << std::endl;
        file.close();
        return false;
    }
    file >> numDimensions;

    //Get the total number of training examples in the training data
    //(both the legacy and the current header spelling are accepted)
    file >> word;
    if( word != "TotalNumTrainingExamples:" && word != "TotalNumExamples:" ){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find TotalNumTrainingExamples header!" << std::endl;
        file.close();
        return false;
    }
    file >> totalNumSamples;

    //Get the total number of classes in the training data
    file >> word;
    if(word != "NumberOfClasses:"){
        errorLog << "loadDatasetFromFile(string filename) - failed to find NumberOfClasses header!" << std::endl;
        file.close();
        return false;
    }
    file >> numClasses;

    //Resize the class counter buffer and load the counters
    classTracker.resize(numClasses);

    //Get the total number of classes in the training data
    file >> word;
    if(word != "ClassIDsAndCounters:"){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find ClassIDsAndCounters header!" << std::endl;
        file.close();
        return false;
    }

    //One entry per class: label, counter, class name
    for(UINT i=0; i<classTracker.getSize(); i++){
        file >> classTracker[i].classLabel;
        file >> classTracker[i].counter;
        file >> classTracker[i].className;
    }

    //Check if the dataset should be scaled using external ranges
    file >> word;
    if(word != "UseExternalRanges:"){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find UseExternalRanges header!" << std::endl;
        file.close();
        return false;
    }
    file >> useExternalRanges;

    //If we are using external ranges then load them
    if( useExternalRanges ){
        externalRanges.resize(numDimensions);
        for(UINT i=0; i<externalRanges.getSize(); i++){
            file >> externalRanges[i].minValue;
            file >> externalRanges[i].maxValue;
        }
    }

    //Get the main training data (both the legacy and the current header are accepted)
    file >> word;
    if( word != "LabelledTrainingData:" && word != "Data:"){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find LabelledTrainingData header!" << std::endl;
        file.close();
        return false;
    }

    //Pre-size the sample buffer, then read each sample as: label + numDimensions values
    ClassificationSample tempSample( numDimensions );
    data.resize( totalNumSamples, tempSample );

    for(UINT i=0; i<totalNumSamples; i++){
        UINT classLabel = 0;
        VectorFloat sample(numDimensions,0);
        file >> classLabel;
        for(UINT j=0; j<numDimensions; j++){
            file >> sample[j];
        }
        data[i].set(classLabel, sample);
    }

    file.close();

    //Sort the class labels
    sortClassLabels();

    return true;
}
573 
574 bool ClassificationData::saveDatasetToCSVFile(const std::string &filename) const{
575 
576  std::fstream file;
577  file.open(filename.c_str(), std::ios::out );
578 
579  if( !file.is_open() ){
580  return false;
581  }
582 
583  //Write the data to the CSV file
584  for(UINT i=0; i<totalNumSamples; i++){
585  file << data[i].getClassLabel();
586  for(UINT j=0; j<numDimensions; j++){
587  file << "," << data[i][j];
588  }
589  file << std::endl;
590  }
591 
592  file.close();
593 
594  return true;
595 }
596 
597 bool ClassificationData::loadDatasetFromCSVFile(const std::string &filename,const UINT classLabelColumnIndex){
598 
599  numDimensions = 0;
600  datasetName = "NOT_SET";
601  infoText = "";
602 
603  //Clear any previous data
604  clear();
605 
606  //Parse the CSV file
607  FileParser parser;
608 
609  Timer timer;
610 
611  timer.start();
612 
613  if( !parser.parseCSVFile(filename,true) ){
614  errorLog << "loadDatasetFromCSVFile(const std::string &filename,const UINT classLabelColumnIndex) - Failed to parse CSV file!" << std::endl;
615  return false;
616  }
617 
618  if( !parser.getConsistentColumnSize() ){
619  errorLog << "loadDatasetFromCSVFile(const std::string &filename,const UINT classLabelColumnIndexe) - The CSV file does not have a consistent number of columns!" << std::endl;
620  return false;
621  }
622 
623  if( parser.getColumnSize() <= 1 ){
624  errorLog << "loadDatasetFromCSVFile(const std::string &filename,const UINT classLabelColumnIndex) - The CSV file does not have enough columns! It should contain at least two columns!" << std::endl;
625  return false;
626  }
627 
628  //Set the number of dimensions
629  numDimensions = parser.getColumnSize()-1;
630 
631  timer.start();
632 
633  //Reserve the memory for the data
634  data.resize( parser.getRowSize(), ClassificationSample(numDimensions) );
635 
636  timer.start();
637 
638  //Loop over the samples and add them to the data set
639  UINT classLabel = 0;
640  UINT j = 0;
641  UINT n = 0;
642  totalNumSamples = parser.getRowSize();
643  for(UINT i=0; i<totalNumSamples; i++){
644  //Get the class label
645  classLabel = grt_from_str< UINT >( parser[i][classLabelColumnIndex] );
646 
647  //Set the class label
648  data[i].setClassLabel( classLabel );
649 
650  //Get the sample data
651  j=0;
652  n=0;
653  while( j != numDimensions ){
654  if( n != classLabelColumnIndex ){
655  data[i][j++] = grt_from_str< Float >( parser[i][n] );
656  }
657  n++;
658  }
659 
660  //Update the class tracker
661  if( classTracker.size() == 0 ){
662  ClassTracker tracker(classLabel,1);
663  classTracker.push_back(tracker);
664  }else{
665  bool labelFound = false;
666  const size_t numClasses = classTracker.size();
667  for(size_t i=0; i<numClasses; i++){
668  if( classLabel == classTracker[i].classLabel ){
669  classTracker[i].counter++;
670  labelFound = true;
671  break;
672  }
673  }
674  if( !labelFound ){
675  ClassTracker tracker(classLabel,1);
676  classTracker.push_back(tracker);
677  }
678  }
679  }
680 
681  //Sort the class labels
682  sortClassLabels();
683 
684  return true;
685 }
686 
688 
689  std::cout << getStatsAsString();
690 
691  return true;
692 }
693 
695 
696  sort(classTracker.begin(),classTracker.end(),ClassTracker::sortByClassLabelAscending);
697 
698  return true;
699 }
700 
//Splits the dataset into a training portion (kept in this instance) and a test
//portion (returned). trainingSizePercentage sets the share of samples kept.
//With useStratifiedSampling the split preserves per-class proportions;
//otherwise samples are shuffled globally before splitting.
ClassificationData ClassificationData::partition(const UINT trainingSizePercentage,const bool useStratifiedSampling){

    //Partitions the dataset into a training dataset (which is kept by this instance of the ClassificationData) and
    //a testing/validation dataset (which is return as a new instance of the ClassificationData). The trainingSizePercentage
    //therefore sets the size of the data which remains in this instance and the remaining percentage of data is then added to
    //the testing/validation dataset

    //The dataset has changed so flag that any previous cross validation setup will now not work
    crossValidationSetup = false;
    crossValidationIndexs.clear();

    ClassificationData trainingSet(numDimensions);
    ClassificationData testSet(numDimensions);
    trainingSet.setAllowNullGestureClass( allowNullGestureClass );
    testSet.setAllowNullGestureClass( allowNullGestureClass );
    Vector< UINT > indexs( totalNumSamples );

    //Create the random partion indexs
    Random random;
    UINT randomIndex = 0;

    if( useStratifiedSampling ){
        //Break the data into seperate classes
        Vector< Vector< UINT > > classData( getNumClasses() );

        //Add the indexs to their respective classes
        for(UINT i=0; i<totalNumSamples; i++){
            classData[ getClassLabelIndexValue( data[i].getClassLabel() ) ].push_back( i );
        }

        //Randomize the order of the indexs in each of the class index buffers
        //NOTE(review): this assumes Random::getRandomNumberInt's upper bound is
        //exclusive — if it is inclusive, randomIndex can be one past the end. Verify.
        for(UINT k=0; k<getNumClasses(); k++){
            UINT numSamples = classData[k].getSize();
            for(UINT x=0; x<numSamples; x++){
                //Pick a random index
                randomIndex = random.getRandomNumberInt(0,numSamples);

                //Swap the indexs
                SWAP(classData[k][ x ], classData[k][ randomIndex ]);
            }
        }

        //Reserve the memory: count the training/test sizes over all classes first
        UINT numTrainingSamples = 0;
        UINT numTestSamples = 0;

        for(UINT k=0; k<getNumClasses(); k++){
            UINT numTrainingExamples = (UINT) floor( Float(classData[k].size()) / 100.0 * Float(trainingSizePercentage) );
            UINT numTestExamples = ((UINT)classData[k].size())-numTrainingExamples;
            numTrainingSamples += numTrainingExamples;
            numTestSamples += numTestExamples;
        }

        trainingSet.reserve( numTrainingSamples );
        testSet.reserve( numTestSamples );

        //Loop over each class and add the data to the trainingSet and testSet
        for(UINT k=0; k<getNumClasses(); k++){
            UINT numTrainingExamples = (UINT) floor( Float(classData[k].getSize()) / 100.0 * Float(trainingSizePercentage) );

            //Add the data to the training and test sets
            for(UINT i=0; i<numTrainingExamples; i++){
                trainingSet.addSample( data[ classData[k][i] ].getClassLabel(), data[ classData[k][i] ].getSample() );
            }
            for(UINT i=numTrainingExamples; i<classData[k].getSize(); i++){
                testSet.addSample( data[ classData[k][i] ].getClassLabel(), data[ classData[k][i] ].getSample() );
            }
        }
    }else{

        const UINT numTrainingExamples = (UINT) floor( Float(totalNumSamples) / 100.0 * Float(trainingSizePercentage) );
        //Create the random partion indexs
        //(this inner Random/randomIndex pair shadows the outer declarations above)
        Random random;
        UINT randomIndex = 0;
        //NOTE(review): std::random_shuffle was removed in C++17 — this will not
        //compile under newer standards; std::shuffle is the replacement. TODO confirm build standard.
        for(UINT i=0; i<totalNumSamples; i++) indexs[i] = i;
        std::random_shuffle(indexs.begin(), indexs.end());

        //Reserve the memory
        trainingSet.reserve( numTrainingExamples );
        testSet.reserve( totalNumSamples-numTrainingExamples );

        //Add the data to the training and test sets
        for(UINT i=0; i<numTrainingExamples; i++){
            trainingSet.addSample( data[ indexs[i] ].getClassLabel(), data[ indexs[i] ].getSample() );
        }
        for(UINT i=numTrainingExamples; i<totalNumSamples; i++){
            testSet.addSample( data[ indexs[i] ].getClassLabel(), data[ indexs[i] ].getSample() );
        }
    }

    //Overwrite the training data in this instance with the training data of the trainingSet
    *this = trainingSet;

    //Sort the class labels in this dataset
    sortClassLabels();

    //Sort the class labels of the test dataset
    testSet.sortClassLabels();

    return testSet;
}
802 
804 
805  if( labelledData.getNumDimensions() != numDimensions ){
806  errorLog << "merge(const ClassificationData &labelledData) - The number of dimensions in the labelledData (" << labelledData.getNumDimensions() << ") does not match the number of dimensions of this dataset (" << numDimensions << ")" << std::endl;
807  return false;
808  }
809 
810  //The dataset has changed so flag that any previous cross validation setup will now not work
811  crossValidationSetup = false;
812  crossValidationIndexs.clear();
813 
814  //Reserve the memory
815  reserve( getNumSamples() + labelledData.getNumSamples() );
816 
817  //Add the data from the labelledData to this instance
818  for(UINT i=0; i<labelledData.getNumSamples(); i++){
819  addSample(labelledData[i].getClassLabel(), labelledData[i].getSample());
820  }
821 
822  //Set the class names from the dataset
823  Vector< ClassTracker > classTracker = labelledData.getClassTracker();
824  for(UINT i=0; i<classTracker.size(); i++){
825  setClassNameForCorrespondingClassLabel(classTracker[i].className, classTracker[i].classLabel);
826  }
827 
828  //Sort the class labels
829  sortClassLabels();
830 
831  return true;
832 }
833 
834 bool ClassificationData::spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling){
835 
836  crossValidationSetup = false;
837  crossValidationIndexs.clear();
838 
839  //K can not be zero
840  if( K > totalNumSamples ){
841  errorLog << "spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling) - K can not be zero!" << std::endl;
842  return false;
843  }
844 
845  //K can not be larger than the number of examples
846  if( K > totalNumSamples ){
847  errorLog << "spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling) - K can not be larger than the total number of samples in the dataset!" << std::endl;
848  return false;
849  }
850 
851  //K can not be larger than the number of examples in a specific class if the stratified sampling option is true
852  if( useStratifiedSampling ){
853  for(UINT c=0; c<classTracker.size(); c++){
854  if( K > classTracker[c].counter ){
855  errorLog << "spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling) - K can not be larger than the number of samples in any given class!" << std::endl;
856  return false;
857  }
858  }
859  }
860 
861  //Setup the dataset for k-fold cross validation
862  kFoldValue = K;
863  Vector< UINT > indexs( totalNumSamples );
864 
865  //Work out how many samples are in each fold, the last fold might have more samples than the others
866  UINT numSamplesPerFold = (UINT) floor( totalNumSamples/Float(K) );
867 
868  //Add the random indexs to each fold
869  crossValidationIndexs.resize(K);
870 
871  //Create the random partion indexs
872  Random random;
873  UINT randomIndex = 0;
874 
875  if( useStratifiedSampling ){
876  //Break the data into seperate classes
877  Vector< Vector< UINT > > classData( getNumClasses() );
878 
879  //Add the indexs to their respective classes
880  for(UINT i=0; i<totalNumSamples; i++){
881  classData[ getClassLabelIndexValue( data[i].getClassLabel() ) ].push_back( i );
882  }
883 
884  //Randomize the order of the indexs in each of the class index buffers
885  for(UINT c=0; c<getNumClasses(); c++){
886  UINT numSamples = (UINT)classData[c].size();
887  for(UINT x=0; x<numSamples; x++){
888  //Pick a random indexs
889  randomIndex = random.getRandomNumberInt(0,numSamples);
890 
891  //Swap the indexs
892  SWAP(classData[c][ x ] , classData[c][ randomIndex ]);
893  }
894  }
895 
896  //Loop over each of the k folds, at each fold add a sample from each class
898  for(UINT c=0; c<getNumClasses(); c++){
899  iter = classData[ c ].begin();
900  UINT k = 0;
901  while( iter != classData[c].end() ){
902  crossValidationIndexs[ k ].push_back( *iter );
903  iter++;
904  k++;
905  k = k % K;
906  }
907  }
908 
909  }else{
910  //Randomize the order of the data
911  for(UINT i=0; i<totalNumSamples; i++) indexs[i] = i;
912  for(UINT x=0; x<totalNumSamples; x++){
913  //Pick a random index
914  randomIndex = random.getRandomNumberInt(0,totalNumSamples);
915 
916  //Swap the indexs
917  SWAP(indexs[ x ] , indexs[ randomIndex ]);
918  }
919 
920  UINT counter = 0;
921  UINT foldIndex = 0;
922  for(UINT i=0; i<totalNumSamples; i++){
923  //Add the index to the current fold
924  crossValidationIndexs[ foldIndex ].push_back( indexs[i] );
925 
926  //Move to the next fold if ready
927  if( ++counter == numSamplesPerFold && foldIndex < K-1 ){
928  foldIndex++;
929  counter = 0;
930  }
931  }
932  }
933 
934  crossValidationSetup = true;
935  return true;
936 
937 }
938 
940 
941  ClassificationData trainingData;
942  trainingData.setNumDimensions( numDimensions );
943  trainingData.setAllowNullGestureClass( allowNullGestureClass );
944 
945  if( !crossValidationSetup ){
946  errorLog << "getTrainingFoldData(const UINT foldIndex) - Cross Validation has not been setup! You need to call the spiltDataIntoKFolds(UINT K,bool useStratifiedSampling) function first before calling this function!" << std::endl;
947  return trainingData;
948  }
949 
950  if( foldIndex >= kFoldValue ) return trainingData;
951 
952  //Add the class labels to make sure they all exist
953  for(UINT k=0; k<getNumClasses(); k++){
954  trainingData.addClass( classTracker[k].classLabel, classTracker[k].className );
955  }
956 
957  //Add the data to the training set, this will consist of all the data that is NOT in the foldIndex
958  UINT index = 0;
959  for(UINT k=0; k<kFoldValue; k++){
960  if( k != foldIndex ){
961  for(UINT i=0; i<crossValidationIndexs[k].getSize(); i++){
962 
963  index = crossValidationIndexs[k][i];
964  trainingData.addSample( data[ index ].getClassLabel(), data[ index ].getSample() );
965  }
966  }
967  }
968 
969  //Sort the class labels
970  trainingData.sortClassLabels();
971 
972  return trainingData;
973 }
974 
976 
977  ClassificationData testData;
978  testData.setNumDimensions( numDimensions );
979  testData.setAllowNullGestureClass( allowNullGestureClass );
980 
981  if( !crossValidationSetup ) return testData;
982 
983  if( foldIndex >= kFoldValue ) return testData;
984 
985  //Add the class labels to make sure they all exist
986  for(UINT k=0; k<getNumClasses(); k++){
987  testData.addClass( classTracker[k].classLabel, classTracker[k].className );
988  }
989 
990  testData.reserve( crossValidationIndexs[ foldIndex ].getSize() );
991 
992  //Add the data to the test fold
993  UINT index = 0;
994  for(UINT i=0; i<crossValidationIndexs[ foldIndex ].getSize(); i++){
995 
996  index = crossValidationIndexs[ foldIndex ][i];
997  testData.addSample( data[ index ].getClassLabel(), data[ index ].getSample() );
998  }
999 
1000  //Sort the class labels
1001  testData.sortClassLabels();
1002 
1003  return testData;
1004 }
1005 
1007 
1008  ClassificationData classData;
1009  classData.setNumDimensions( this->numDimensions );
1010  classData.setAllowNullGestureClass( allowNullGestureClass );
1011 
1012  //Reserve the memory for the class data
1013  for(UINT i=0; i<classTracker.getSize(); i++){
1014  if( classTracker[i].classLabel == classLabel ){
1015  classData.reserve( classTracker[i].counter );
1016  break;
1017  }
1018  }
1019 
1020  for(UINT i=0; i<totalNumSamples; i++){
1021  if( data[i].getClassLabel() == classLabel ){
1022  classData.addSample(classLabel, data[i].getSample());
1023  }
1024  }
1025 
1026  return classData;
1027 }
1028 
1029 ClassificationData ClassificationData::getBootstrappedDataset(UINT numSamples,bool balanceDataset) const{
1030 
1031  Random rand;
1032  ClassificationData newDataset;
1033  newDataset.setNumDimensions( getNumDimensions() );
1034  newDataset.setAllowNullGestureClass( allowNullGestureClass );
1035  newDataset.setExternalRanges( externalRanges, useExternalRanges );
1036 
1037  if( numSamples == 0 ) numSamples = totalNumSamples;
1038 
1039  newDataset.reserve( numSamples );
1040 
1041  const UINT K = getNumClasses();
1042 
1043  //Add all the class labels to the new dataset to ensure the dataset has a list of all the labels
1044  for(UINT k=0; k<K; k++){
1045  newDataset.addClass( classTracker[k].classLabel );
1046  }
1047 
1048  if( balanceDataset ){
1049  //Group the class indexs
1050  Vector< Vector< UINT > > classIndexs( K );
1051  for(UINT i=0; i<totalNumSamples; i++){
1052  classIndexs[ getClassLabelIndexValue( data[i].getClassLabel() ) ].push_back( i );
1053  }
1054 
1055  //Get the class with the minimum number of examples
1056  UINT numSamplesPerClass = (UINT)floor( numSamples / Float(K) );
1057 
1058  //Randomly select the training samples from each class
1059  UINT classIndex = 0;
1060  UINT classCounter = 0;
1061  UINT randomIndex = 0;
1062  for(UINT i=0; i<numSamples; i++){
1063  randomIndex = rand.getRandomNumberInt(0, (UINT)classIndexs[ classIndex ].size() );
1064  randomIndex = classIndexs[ classIndex ][ randomIndex ];
1065  newDataset.addSample(data[ randomIndex ].getClassLabel(), data[ randomIndex ].getSample());
1066  if( classCounter++ >= numSamplesPerClass && classIndex+1 < K ){
1067  classCounter = 0;
1068  classIndex++;
1069  }
1070  }
1071 
1072  }else{
1073  //Randomly select the training samples to add to the new data set
1074  UINT randomIndex;
1075  for(UINT i=0; i<numSamples; i++){
1076  randomIndex = rand.getRandomNumberInt(0, totalNumSamples);
1077  newDataset.addSample( data[randomIndex].getClassLabel(), data[randomIndex].getSample() );
1078  }
1079  }
1080 
1081  //Sort the class labels so they are in order
1082  newDataset.sortClassLabels();
1083 
1084  return newDataset;
1085 }
1086 
1088 
1089  //Turns the classification into a regression data to enable regression algorithms like the MLP to be used as a classifier
1090  //This sets the number of targets in the regression data equal to the number of classes in the classification data
1091  //The output of each regression training sample will then be all 0's, except for the index matching the classLabel, which will be 1
1092  //For this to work, the labelled classification data cannot have any samples with a classLabel of 0!
1093  RegressionData regressionData;
1094 
1095  if( totalNumSamples == 0 ){
1096  return regressionData;
1097  }
1098 
1099  const UINT numInputDimensions = numDimensions;
1100  const UINT numTargetDimensions = getNumClasses();
1101  regressionData.setInputAndTargetDimensions(numInputDimensions, numTargetDimensions);
1102 
1103  for(UINT i=0; i<totalNumSamples; i++){
1104  VectorFloat targetVector(numTargetDimensions,0);
1105 
1106  //Set the class index in the target Vector to 1 and all other values in the target Vector to 0
1107  UINT classLabel = data[i].getClassLabel();
1108 
1109  if( classLabel > 0 ){
1110  targetVector[ classLabel-1 ] = 1;
1111  }else{
1112  regressionData.clear();
1113  return regressionData;
1114  }
1115 
1116  regressionData.addSample(data[i].getSample(),targetVector);
1117  }
1118 
1119  return regressionData;
1120 }
1121 
1123 
1124  UnlabelledData unlabelledData;
1125 
1126  if( totalNumSamples == 0 ){
1127  return unlabelledData;
1128  }
1129 
1130  unlabelledData.setNumDimensions( numDimensions );
1131 
1132  for(UINT i=0; i<totalNumSamples; i++){
1133  unlabelledData.addSample( data[i].getSample() );
1134  }
1135 
1136  return unlabelledData;
1137 }
1138 
1140  UINT minClassLabel = grt_numeric_limits< UINT >::max();
1141 
1142  for(UINT i=0; i<classTracker.getSize(); i++){
1143  if( classTracker[i].classLabel < minClassLabel ){
1144  minClassLabel = classTracker[i].classLabel;
1145  }
1146  }
1147 
1148  return minClassLabel;
1149 }
1150 
1151 
1153  UINT maxClassLabel = 0;
1154 
1155  for(UINT i=0; i<classTracker.getSize(); i++){
1156  if( classTracker[i].classLabel > maxClassLabel ){
1157  maxClassLabel = classTracker[i].classLabel;
1158  }
1159  }
1160 
1161  return maxClassLabel;
1162 }
1163 
1165  for(UINT k=0; k<classTracker.getSize(); k++){
1166  if( classTracker[k].classLabel == classLabel ){
1167  return k;
1168  }
1169  }
1170  warningLog << "getClassLabelIndexValue(UINT classLabel) - Failed to find class label: " << classLabel << " in class tracker!" << std::endl;
1171  return 0;
1172 }
1173 
1175 
1176  for(UINT i=0; i<classTracker.getSize(); i++){
1177  if( classTracker[i].classLabel == classLabel ){
1178  return classTracker[i].className;
1179  }
1180  }
1181 
1182  return "CLASS_LABEL_NOT_FOUND";
1183 }
1184 
1186  std::string statsText;
1187  statsText += "DatasetName:\t" + datasetName + "\n";
1188  statsText += "DatasetInfo:\t" + infoText + "\n";
1189  statsText += "Number of Dimensions:\t" + Util::toString( numDimensions ) + "\n";
1190  statsText += "Number of Samples:\t" + Util::toString( totalNumSamples ) + "\n";
1191  statsText += "Number of Classes:\t" + Util::toString( getNumClasses() ) + "\n";
1192  statsText += "ClassStats:\n";
1193 
1194  for(UINT k=0; k<getNumClasses(); k++){
1195  statsText += "ClassLabel:\t" + Util::toString( classTracker[k].classLabel );
1196  statsText += "\tNumber of Samples:\t" + Util::toString(classTracker[k].counter);
1197  statsText += "\tClassName:\t" + classTracker[k].className + "\n";
1198  }
1199 
1200  Vector< MinMax > ranges = getRanges();
1201 
1202  statsText += "Dataset Ranges:\n";
1203  for(UINT j=0; j<ranges.size(); j++){
1204  statsText += "[" + Util::toString( j+1 ) + "] Min:\t" + Util::toString( ranges[j].minValue ) + "\tMax: " + Util::toString( ranges[j].maxValue ) + "\n";
1205  }
1206 
1207  return statsText;
1208 }
1209 
1211 
1212  //If the dataset should be scaled using the external ranges then return the external ranges
1213  if( useExternalRanges ) return externalRanges;
1214 
1215  Vector< MinMax > ranges(numDimensions);
1216 
1217  //Otherwise return the min and max values for each column in the dataset
1218  if( totalNumSamples > 0 ){
1219  for(UINT j=0; j<numDimensions; j++){
1220  ranges[j].minValue = data[0][j];
1221  ranges[j].maxValue = data[0][j];
1222  for(UINT i=0; i<totalNumSamples; i++){
1223  if( data[i][j] < ranges[j].minValue ){ ranges[j].minValue = data[i][j]; } //Search for the min value
1224  else if( data[i][j] > ranges[j].maxValue ){ ranges[j].maxValue = data[i][j]; } //Search for the max value
1225  }
1226  }
1227  }
1228  return ranges;
1229 }
1230 
1232  Vector< UINT > classLabels( getNumClasses(), 0 );
1233 
1234  if( getNumClasses() == 0 ) return classLabels;
1235 
1236  for(UINT i=0; i<getNumClasses(); i++){
1237  classLabels[i] = classTracker[i].classLabel;
1238  }
1239 
1240  return classLabels;
1241 }
1242 
1244  Vector< UINT > classSampleCounts( getNumClasses(), 0 );
1245 
1246  if( getNumSamples() == 0 ) return classSampleCounts;
1247 
1248  for(UINT i=0; i<getNumClasses(); i++){
1249  classSampleCounts[i] = classTracker[i].counter;
1250  }
1251 
1252  return classSampleCounts;
1253 }
1254 
1256 
1257  VectorFloat mean(numDimensions,0);
1258 
1259  for(UINT j=0; j<numDimensions; j++){
1260  for(UINT i=0; i<totalNumSamples; i++){
1261  mean[j] += data[i][j];
1262  }
1263  mean[j] /= Float(totalNumSamples);
1264  }
1265 
1266  return mean;
1267 }
1268 
1270 
1271  VectorFloat mean = getMean();
1272  VectorFloat stdDev(numDimensions,0);
1273 
1274  for(UINT j=0; j<numDimensions; j++){
1275  for(UINT i=0; i<totalNumSamples; i++){
1276  stdDev[j] += SQR(data[i][j]-mean[j]);
1277  }
1278  stdDev[j] = sqrt( stdDev[j] / Float(totalNumSamples-1) );
1279  }
1280 
1281  return stdDev;
1282 }
1283 
1284 MatrixFloat ClassificationData::getClassHistogramData(UINT classLabel,UINT numBins) const{
1285 
1286  const UINT M = getNumSamples();
1287  const UINT N = getNumDimensions();
1288 
1289  Vector< MinMax > ranges = getRanges();
1290  VectorFloat binRange(N);
1291  for(UINT i=0; i<ranges.size(); i++){
1292  binRange[i] = (ranges[i].maxValue-ranges[i].minValue)/Float(numBins);
1293  }
1294 
1295  MatrixFloat histData(N,numBins);
1296  histData.setAllValues(0);
1297 
1298  Float norm = 0;
1299  for(UINT i=0; i<M; i++){
1300  if( data[i].getClassLabel() == classLabel ){
1301  for(UINT j=0; j<N; j++){
1302  UINT binIndex = 0;
1303  bool binFound = false;
1304  for(UINT k=0; k<numBins-1; k++){
1305  if( data[i][j] >= ranges[i].minValue + (binRange[j]*k) && data[i][j] >= ranges[i].minValue + (binRange[j]*(k+1)) ){
1306  binIndex = k;
1307  binFound = true;
1308  break;
1309  }
1310  }
1311  if( !binFound ) binIndex = numBins-1;
1312  histData[j][binIndex]++;
1313  }
1314  norm++;
1315  }
1316  }
1317 
1318  if( norm == 0 ) return histData;
1319 
1320  //Is this the best way to normalize a multidimensional histogram???
1321  for(UINT i=0; i<histData.getNumRows(); i++){
1322  for(UINT j=0; j<histData.getNumCols(); j++){
1323  histData[i][j] /= norm;
1324  }
1325  }
1326 
1327  return histData;
1328 }
1329 
1331 
1332  MatrixFloat mean(getNumClasses(),numDimensions);
1333  VectorFloat counter(getNumClasses(),0);
1334 
1335  mean.setAllValues( 0 );
1336 
1337  for(UINT i=0; i<totalNumSamples; i++){
1338  UINT classIndex = getClassLabelIndexValue( data[i].getClassLabel() );
1339  for(UINT j=0; j<numDimensions; j++){
1340  mean[classIndex][j] += data[i][j];
1341  }
1342  counter[ classIndex ]++;
1343  }
1344 
1345  for(UINT k=0; k<getNumClasses(); k++){
1346  for(UINT j=0; j<numDimensions; j++){
1347  mean[k][j] = counter[k] > 0 ? mean[k][j]/counter[k] : 0;
1348  }
1349  }
1350 
1351  return mean;
1352 }
1353 
1355 
1356  MatrixFloat mean = getClassMean();
1357  MatrixFloat stdDev(getNumClasses(),numDimensions);
1358  VectorFloat counter(getNumClasses(),0);
1359 
1360  stdDev.setAllValues( 0 );
1361 
1362  for(UINT i=0; i<totalNumSamples; i++){
1363  UINT classIndex = getClassLabelIndexValue( data[i].getClassLabel() );
1364  for(UINT j=0; j<numDimensions; j++){
1365  stdDev[classIndex][j] += SQR(data[i][j]-mean[classIndex][j]);
1366  }
1367  counter[ classIndex ]++;
1368  }
1369 
1370  for(UINT k=0; k<getNumClasses(); k++){
1371  for(UINT j=0; j<numDimensions; j++){
1372  stdDev[k][j] = sqrt( stdDev[k][j] / Float(counter[k]-1) );
1373  }
1374  }
1375 
1376  return stdDev;
1377 }
1378 
1380 
1381  VectorFloat mean = getMean();
1382  MatrixFloat covariance(numDimensions,numDimensions);
1383 
1384  for(UINT j=0; j<numDimensions; j++){
1385  for(UINT k=0; k<numDimensions; k++){
1386  for(UINT i=0; i<totalNumSamples; i++){
1387  covariance[j][k] += (data[i][j]-mean[j]) * (data[i][k]-mean[k]) ;
1388  }
1389  covariance[j][k] /= Float(totalNumSamples-1);
1390  }
1391  }
1392 
1393  return covariance;
1394 }
1395 
1397  const UINT K = getNumClasses();
1398  Vector< MatrixFloat > histData(K);
1399 
1400  for(UINT k=0; k<K; k++){
1401  histData[k] = getClassHistogramData( classTracker[k].classLabel, numBins );
1402  }
1403 
1404  return histData;
1405 }
1406 
1407 VectorFloat ClassificationData::getClassProbabilities() const {
1408  return getClassProbabilities( getClassLabels() );
1409 }
1410 
1411 VectorFloat ClassificationData::getClassProbabilities( const Vector< UINT > &classLabels ) const {
1412  const UINT K = (UINT)classLabels.size();
1413  const UINT N = getNumClasses();
1414  Float sum = 0;
1415  VectorFloat x(K,0);
1416  for(UINT k=0; k<K; k++){
1417  for(UINT n=0; n<N; n++){
1418  if( classLabels[k] == classTracker[n].classLabel ){
1419  x[k] = classTracker[n].counter;
1420  sum += classTracker[n].counter;
1421  break;
1422  }
1423  }
1424  }
1425 
1426  //Normalize the class probabilities
1427  if( sum > 0 ){
1428  for(UINT k=0; k<K; k++){
1429  x[k] /= sum;
1430  }
1431  }
1432 
1433  return x;
1434 }
1435 
1437 
1438  const UINT M = getNumSamples();
1439  const UINT K = getNumClasses();
1440  UINT N = 0;
1441 
1442  //Get the number of samples in the class
1443  for(UINT k=0; k<K; k++){
1444  if( classTracker[k].classLabel == classLabel){
1445  N = classTracker[k].counter;
1446  break;
1447  }
1448  }
1449 
1450  UINT index = 0;
1451  Vector< UINT > classIndexes(N);
1452  for(UINT i=0; i<M; i++){
1453  if( data[i].getClassLabel() == classLabel ){
1454  classIndexes[index++] = i;
1455  }
1456  }
1457 
1458  return classIndexes;
1459 }
1460 
1462 
1463  const UINT M = getNumSamples();
1464  const UINT N = getNumDimensions();
1465  MatrixDouble d(M,N);
1466 
1467  for(UINT i=0; i<M; i++){
1468  for(UINT j=0; j<N; j++){
1469  d[i][j] = data[i][j];
1470  }
1471  }
1472 
1473  return d;
1474 }
1475 
1477  const UINT M = getNumSamples();
1478  const UINT N = getNumDimensions();
1479  MatrixFloat d(M,N);
1480 
1481  for(UINT i=0; i<M; i++){
1482  for(UINT j=0; j<N; j++){
1483  d[i][j] = data[i][j];
1484  }
1485  }
1486 
1487  return d;
1488 }
1489 
1490 bool ClassificationData::generateGaussDataset( const std::string filename, const UINT numSamples, const UINT numClasses, const UINT numDimensions, const Float range, const Float sigma ){
1491 
1492  Random random;
1493 
1494  //Generate a simple model that will be used to generate the main dataset
1495  MatrixFloat model(numClasses,numDimensions);
1496  for(UINT k=0; k<numClasses; k++){
1497  for(UINT j=0; j<numDimensions; j++){
1498  model[k][j] = random.getRandomNumberUniform(-range,range);
1499  }
1500  }
1501 
1502  //Use the model above to generate the main dataset
1503  ClassificationData data;
1504  data.setNumDimensions( numDimensions );
1505 
1506  for(UINT i=0; i<numSamples; i++){
1507 
1508  //Randomly select which class this sample belongs to
1509  UINT k = random.getRandomNumberInt( 0, numClasses );
1510 
1511  //Generate a sample using the model (+ some Gaussian noise)
1512  VectorFloat sample( numDimensions );
1513  for(UINT j=0; j<numDimensions; j++){
1514  sample[j] = model[k][j] + random.getRandomNumberGauss(0,sigma);
1515  }
1516 
1517  //By default in the GRT, the class label should not be 0, so add 1
1518  UINT classLabel = k + 1;
1519 
1520  //Add the labeled sample to the dataset
1521  data.addSample( classLabel, sample );
1522  }
1523 
1524  //Save the dataset to a CSV file
1525  return data.save( filename );
1526 }
1527 
1528 GRT_END_NAMESPACE
1529 
bool saveDatasetToFile(const std::string &filename) const
bool setDatasetName(std::string datasetName)
bool loadDatasetFromFile(const std::string &filename)
static std::string toString(const int &i)
Definition: Util.cpp:73
RegressionData reformatAsRegressionData() const
ClassificationData & operator=(const ClassificationData &rhs)
Definition: Timer.h:43
static bool generateGaussDataset(const std::string filename, const UINT numSamples=10000, const UINT numClasses=10, const UINT numDimensions=3, const Float range=10, const Float sigma=1)
bool addSample(const VectorFloat &sample)
The ClassificationData is the main data structure for recording, labeling, managing, saving, and loading training data for supervised learning problems.
bool relabelAllSamplesWithClassLabel(const UINT oldClassLabel, const UINT newClassLabel)
bool addSample(UINT classLabel, const VectorFloat &sample)
ClassificationData getTestFoldData(const UINT foldIndex) const
bool addClass(const UINT classLabel, const std::string className="NOT_SET")
Vector< ClassTracker > getClassTracker() const
Definition: Random.h:40
ClassificationData getClassData(const UINT classLabel) const
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
bool setNumDimensions(UINT numDimensions)
UINT eraseAllSamplesWithClassLabel(const UINT classLabel)
MatrixDouble getDataAsMatrixDouble() const
MatrixFloat getClassMean() const
Float getRandomNumberGauss(Float mu=0.0, Float sigma=1.0)
Definition: Random.h:209
std::string getClassNameForCorrespondingClassLabel(const UINT classLabel) const
bool setClassNameForCorrespondingClassLabel(std::string className, UINT classLabel)
Vector< UINT > getClassLabels() const
bool loadDatasetFromCSVFile(const std::string &filename, const UINT classLabelColumnIndex=0)
bool setAllowNullGestureClass(bool allowNullGestureClass)
UINT getMinimumClassLabel() const
Vector< MatrixFloat > getHistogramData(const UINT numBins) const
unsigned int getSize() const
Definition: Vector.h:193
UINT removeClass(const UINT classLabel)
ClassificationData(UINT numDimensions=0, std::string datasetName="NOT_SET", std::string infoText="")
bool setAllValues(const T &value)
Definition: Matrix.h:336
bool setInputAndTargetDimensions(const UINT numInputDimensions, const UINT numTargetDimensions)
bool setInfoText(std::string infoText)
Vector< UINT > getNumSamplesPerClass() const
MatrixFloat getCovarianceMatrix() const
UnlabelledData reformatAsUnlabelledData() const
bool removeSample(const UINT index)
UINT getNumSamples() const
bool spiltDataIntoKFolds(const UINT K, const bool useStratifiedSampling=false)
bool save(const std::string &filename) const
bool setNumDimensions(const UINT numDimensions)
bool enableExternalRangeScaling(const bool useExternalRanges)
bool setExternalRanges(const Vector< MinMax > &externalRanges, const bool useExternalRanges=false)
bool reserve(const UINT N)
bool saveDatasetToCSVFile(const std::string &filename) const
ClassificationData partition(const UINT partitionPercentage, const bool useStratifiedSampling=false)
unsigned int getNumRows() const
Definition: Matrix.h:542
UINT getNumDimensions() const
UINT getNumClasses() const
unsigned int getNumCols() const
Definition: Matrix.h:549
bool start()
Definition: Timer.h:64
Vector< MinMax > getRanges() const
Float getRandomNumberUniform(Float minRange=0.0, Float maxRange=1.0)
Definition: Random.h:198
bool merge(const ClassificationData &data)
VectorFloat getStdDev() const
Vector< UINT > getClassDataIndexes(const UINT classLabel) const
int getRandomNumberInt(int minRange, int maxRange)
Definition: Random.h:88
MatrixFloat getDataAsMatrixFloat() const
static bool stringEndsWith(const std::string &str, const std::string &ending)
Definition: Util.cpp:156
UINT getClassLabelIndexValue(const UINT classLabel) const
ClassificationData getBootstrappedDataset(UINT numSamples=0, bool balanceDataset=false) const
MatrixFloat getClassHistogramData(const UINT classLabel, const UINT numBins) const
ClassificationData getTrainingFoldData(const UINT foldIndex) const
UINT getMaximumClassLabel() const
bool scale(const Float minTarget, const Float maxTarget)
bool load(const std::string &filename)
MatrixFloat getClassStdDev() const
bool addSample(const VectorFloat &inputVector, const VectorFloat &targetVector)
std::string getStatsAsString() const
VectorFloat getMean() const