GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, c++ machine learning library for real-time gesture recognition.
ClassificationData.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 */
20 
21 #define GRT_DLL_EXPORTS
22 #include "ClassificationData.h"
23 
24 GRT_BEGIN_NAMESPACE
25 
26 ClassificationData::ClassificationData(const UINT numDimensions,const std::string datasetName,const std::string infoText){
27  this->datasetName = datasetName;
28  this->numDimensions = numDimensions;
29  this->infoText = infoText;
30  totalNumSamples = 0;
31  crossValidationSetup = false;
32  useExternalRanges = false;
33  allowNullGestureClass = true;
34  if( numDimensions > 0 ) setNumDimensions( numDimensions );
35  infoLog.setProceedingText("[ClassificationData]");
36  debugLog.setProceedingText("[DEBUG ClassificationData]");
37  errorLog.setProceedingText("[ERROR ClassificationData]");
38  warningLog.setProceedingText("[WARNING ClassificationData]");
39 }
40 
42  *this = rhs;
43 }
44 
46 }
47 
49  if( this != &rhs){
50  this->datasetName = rhs.datasetName;
51  this->infoText = rhs.infoText;
52  this->numDimensions = rhs.numDimensions;
53  this->totalNumSamples = rhs.totalNumSamples;
54  this->kFoldValue = rhs.kFoldValue;
55  this->crossValidationSetup = rhs.crossValidationSetup;
56  this->useExternalRanges = rhs.useExternalRanges;
57  this->allowNullGestureClass = rhs.allowNullGestureClass;
58  this->externalRanges = rhs.externalRanges;
59  this->classTracker = rhs.classTracker;
60  this->data = rhs.data;
61  this->crossValidationIndexs = rhs.crossValidationIndexs;
62  this->infoLog = rhs.infoLog;
63  this->debugLog = rhs.debugLog;
64  this->errorLog = rhs.errorLog;
65  this->warningLog = rhs.warningLog;
66  }
67  return *this;
68 }
69 
71  totalNumSamples = 0;
72  data.clear();
73  classTracker.clear();
74  crossValidationSetup = false;
75  crossValidationIndexs.clear();
76 }
77 
78 bool ClassificationData::setNumDimensions(const UINT numDimensions){
79 
80  if( numDimensions > 0 ){
81  //Clear any previous training data
82  clear();
83 
84  //Set the dimensionality of the data
85  this->numDimensions = numDimensions;
86 
87  //Clear the external ranges
88  useExternalRanges = false;
89  externalRanges.clear();
90 
91  return true;
92  }
93 
94  errorLog << "setNumDimensions(const UINT numDimensions) - The number of dimensions of the dataset must be greater than zero!" << std::endl;
95  return false;
96 }
97 
98 bool ClassificationData::setDatasetName(const std::string datasetName){
99 
100  //Make sure there are no spaces in the std::string
101  if( datasetName.find(" ") == std::string::npos ){
102  this->datasetName = datasetName;
103  return true;
104  }
105 
106  errorLog << "setDatasetName(const std::string datasetName) - The dataset name cannot contain any spaces!" << std::endl;
107  return false;
108 }
109 
110 bool ClassificationData::setInfoText(const std::string infoText){
111  this->infoText = infoText;
112  return true;
113 }
114 
115 bool ClassificationData::setClassNameForCorrespondingClassLabel(const std::string className,const UINT classLabel){
116 
117  for(UINT i=0; i<classTracker.getSize(); i++){
118  if( classTracker[i].classLabel == classLabel ){
119  classTracker[i].className = className;
120  return true;
121  }
122  }
123 
124  errorLog << "setClassNameForCorrespondingClassLabel(const std::string className,const UINT classLabel) - Failed to find class with label: " << classLabel << std::endl;
125  return false;
126 }
127 
128 bool ClassificationData::setAllowNullGestureClass(const bool allowNullGestureClass){
129  this->allowNullGestureClass = allowNullGestureClass;
130  return true;
131 }
132 
133 bool ClassificationData::addSample(const UINT classLabel,const VectorFloat &sample){
134 
135  if( sample.getSize() != numDimensions ){
136  if( totalNumSamples == 0 ){
137  warningLog << "addSample(const UINT classLabel, VectorFloat &sample) - the size of the new sample (" << sample.getSize() << ") does not match the number of dimensions of the dataset (" << numDimensions << "), setting dimensionality to: " << numDimensions << std::endl;
138  numDimensions = sample.getSize();
139  }else{
140  errorLog << "addSample(const UINT classLabel, VectorFloat &sample) - the size of the new sample (" << sample.getSize() << ") does not match the number of dimensions of the dataset (" << numDimensions << ")" << std::endl;
141  return false;
142  }
143  }
144 
145  //The class label must be greater than zero (as zero is used for the null rejection class label
146  if( classLabel == GRT_DEFAULT_NULL_CLASS_LABEL && !allowNullGestureClass ){
147  errorLog << "addSample(const UINT classLabel, VectorFloat &sample) - the class label can not be 0!" << std::endl;
148  return false;
149  }
150 
151  //The dataset has changed so flag that any previous cross validation setup will now not work
152  crossValidationSetup = false;
153  crossValidationIndexs.clear();
154 
155  ClassificationSample newSample(classLabel,sample);
156  data.push_back( newSample );
157  totalNumSamples++;
158 
159  if( classTracker.getSize() == 0 ){
160  ClassTracker tracker(classLabel,1);
161  classTracker.push_back(tracker);
162  }else{
163  bool labelFound = false;
164  for(UINT i=0; i<classTracker.getSize(); i++){
165  if( classLabel == classTracker[i].classLabel ){
166  classTracker[i].counter++;
167  labelFound = true;
168  break;
169  }
170  }
171  if( !labelFound ){
172  ClassTracker tracker(classLabel,1);
173  classTracker.push_back(tracker);
174  }
175  }
176 
177  //Update the class labels
178  sortClassLabels();
179 
180  return true;
181 }
182 
//Removes the sample at the given index from the dataset and decrements the counter of
//the sample's class in the class tracker. Returns false if the dataset is empty or the
//index is out of bounds, true otherwise.
bool ClassificationData::removeSample( const UINT index ){

    if( totalNumSamples == 0 ){
        warningLog << "removeSample( const UINT index ) - Failed to remove sample, the training dataset is empty!" << std::endl;
        return false;
    }

    if( index >= totalNumSamples ){
        warningLog << "removeSample( const UINT index ) - Failed to remove sample, the index is out of bounds! Number of training samples: " << totalNumSamples << " index: " << index << std::endl;
        return false;
    }

    //The dataset has changed so flag that any previous cross validation setup will now not work
    crossValidationSetup = false;
    crossValidationIndexs.clear();

    //Remember the class label of the sample before it is erased, so the tracker can be updated
    UINT classLabel = data[ index ].getClassLabel();

    //Remove the training example from the buffer
    data.erase( data.begin()+index );

    totalNumSamples = data.getSize();

    //Decrement the class counter for the removed sample
    //NOTE(review): the tracker entry is kept even if its counter reaches zero, unlike
    //removeClass() which erases the entry - confirm this asymmetry is intended
    for(size_t i=0; i<classTracker.getSize(); i++){
        if( classTracker[i].classLabel == classLabel ){
            classTracker[i].counter--;
            break;
        }
    }

    return true;
}
217 
219 
220  if( totalNumSamples == 0 ){
221  warningLog << "removeLastSample() - Failed to remove sample, the training dataset is empty!" << std::endl;
222  return false;
223  }
224 
225  return removeSample( totalNumSamples-1 );
226 }
227 
228 bool ClassificationData::reserve(const UINT N){
229 
230  data.reserve( N );
231 
232  if( data.capacity() >= N ) return true;
233 
234  return false;
235 }
236 
238  return removeClass( classLabel );
239 }
240 
241 bool ClassificationData::addClass(const UINT classLabel,const std::string className){
242 
243  //Check to make sure the class label does not exist
244  for(size_t i=0; i<classTracker.getSize(); i++){
245  if( classTracker[i].classLabel == classLabel ){
246  warningLog << "addClass(const UINT classLabel,const std::string className) - Failed to add class, it already exists! Class label: " << classLabel << std::endl;
247  return false;
248  }
249  }
250 
251  //Add the class label to the class tracker
252  classTracker.push_back( ClassTracker(classLabel,0,className) );
253 
254  //Sort the class labels
255  sortClassLabels();
256 
257  return true;
258 }
259 
//Removes every sample with the given class label from the dataset and erases the
//corresponding entry from the class tracker. Returns the number of samples removed
//(zero if the class label was not found).
UINT ClassificationData::removeClass(const UINT classLabel){

    UINT numExamplesRemoved = 0;
    UINT numExamplesToRemove = 0;

    //The dataset has changed so flag that any previous cross validation setup will now not work
    crossValidationSetup = false;
    crossValidationIndexs.clear();

    //Find out how many training examples we need to remove and drop the tracker entry
    for(UINT i=0; i<classTracker.getSize(); i++){
        if( classTracker[i].classLabel == classLabel ){
            numExamplesToRemove = classTracker[i].counter;
            classTracker.erase(classTracker.begin()+i);
            break;
        }
    }

    //Remove the samples with the matching class ID.
    //When data[i] is erased the next sample shifts into slot i, so i is only incremented
    //when the current sample does NOT match; the loop ends once all tracked examples have
    //been removed or the end of the data buffer is reached.
    if( numExamplesToRemove > 0 ){
        UINT i=0;
        while( numExamplesRemoved < numExamplesToRemove ){
            if( data[i].getClassLabel() == classLabel ){
                data.erase(data.begin()+i);
                numExamplesRemoved++;
            }else if( ++i == data.getSize() ) break;
        }
    }

    totalNumSamples = data.getSize();

    return numExamplesRemoved;
}
293 
294 bool ClassificationData::relabelAllSamplesWithClassLabel(const UINT oldClassLabel,const UINT newClassLabel){
295  bool oldClassLabelFound = false;
296  bool newClassLabelAllReadyExists = false;
297  UINT indexOfOldClassLabel = 0;
298  UINT indexOfNewClassLabel = 0;
299 
300  //Find out how many training examples we need to relabel
301  for(UINT i=0; i<classTracker.getSize(); i++){
302  if( classTracker[i].classLabel == oldClassLabel ){
303  indexOfOldClassLabel = i;
304  oldClassLabelFound = true;
305  }
306  if( classTracker[i].classLabel == newClassLabel ){
307  indexOfNewClassLabel = i;
308  newClassLabelAllReadyExists = true;
309  }
310  }
311 
312  //If the old class label was not found then we can't do anything
313  if( !oldClassLabelFound ){
314  return false;
315  }
316 
317  //Relabel the old class labels
318  for(UINT i=0; i<totalNumSamples; i++){
319  if( data[i].getClassLabel() == oldClassLabel ){
320  data[i].setClassLabel(newClassLabel);
321  }
322  }
323 
324  //Update the class tracler
325  if( newClassLabelAllReadyExists ){
326  //Add the old sample count to the new sample count
327  classTracker[ indexOfNewClassLabel ].counter += classTracker[ indexOfOldClassLabel ].counter;
328  }else{
329  //Create a new class tracker
330  classTracker.push_back( ClassTracker(newClassLabel,classTracker[ indexOfOldClassLabel ].counter,classTracker[ indexOfOldClassLabel ].className) );
331  }
332 
333  //Erase the old class tracker
334  classTracker.erase( classTracker.begin() + indexOfOldClassLabel );
335 
336  //Sort the class labels
337  sortClassLabels();
338 
339  return true;
340 }
341 
342 bool ClassificationData::setExternalRanges(const Vector< MinMax > &externalRanges, const bool useExternalRanges){
343 
344  if( externalRanges.size() != numDimensions ) return false;
345 
346  this->externalRanges = externalRanges;
347  this->useExternalRanges = useExternalRanges;
348 
349  return true;
350 }
351 
352 bool ClassificationData::enableExternalRangeScaling(const bool useExternalRanges){
353  if( externalRanges.getSize() == numDimensions ){
354  this->useExternalRanges = useExternalRanges;
355  return true;
356  }
357  return false;
358 }
359 
360 bool ClassificationData::scale(const Float minTarget,const Float maxTarget){
361  Vector< MinMax > ranges = getRanges();
362  return scale(ranges,minTarget,maxTarget);
363 }
364 
365 bool ClassificationData::scale(const Vector<MinMax> &ranges,const Float minTarget,const Float maxTarget){
366  if( ranges.getSize() != numDimensions ) return false;
367 
368  //Scale the training data
369  for(UINT i=0; i<totalNumSamples; i++){
370  for(UINT j=0; j<numDimensions; j++){
371  data[i][j] = grt_scale(data[i][j],ranges[j].minValue,ranges[j].maxValue,minTarget,maxTarget);
372  }
373  }
374 
375  return true;
376 }
377 
378 bool ClassificationData::save(const std::string &filename) const{
379 
380  //Check if the file should be saved as a csv file
381  if( Util::stringEndsWith( filename, ".csv" ) ){
382  return saveDatasetToCSVFile( filename );
383  }
384 
385  //Otherwise save it as a custom GRT file
386  return saveDatasetToFile( filename );
387 }
388 
//Loads the dataset from the given file, choosing the parser from the file extension.
//Returns true if the dataset was loaded successfully.
bool ClassificationData::load(const std::string &filename){

    //Check if the file should be loaded as a csv file
    if( Util::stringEndsWith( filename, ".csv" ) ){
        return loadDatasetFromCSVFile( filename );
    }

    //Otherwise load it as a custom GRT file
    return loadDatasetFromFile( filename );
}
399 
//Saves the dataset in the custom GRT (V1.0) classification data file format.
//Returns true if the file was written successfully, false if it could not be opened.
bool ClassificationData::saveDatasetToFile(const std::string &filename) const{

    std::fstream file;
    file.open(filename.c_str(), std::ios::out);

    if( !file.is_open() ){
        return false;
    }

    //Write the file header followed by the dataset meta data
    file << "GRT_LABELLED_CLASSIFICATION_DATA_FILE_V1.0\n";
    file << "DatasetName: " << datasetName << std::endl;
    file << "InfoText: " << infoText << std::endl;
    file << "NumDimensions: " << numDimensions << std::endl;
    file << "TotalNumExamples: " << totalNumSamples << std::endl;
    file << "NumberOfClasses: " << classTracker.size() << std::endl;
    file << "ClassIDsAndCounters: " << std::endl;

    //One line per class: label, sample count and class name (tab separated)
    for(UINT i=0; i<classTracker.size(); i++){
        file << classTracker[i].classLabel << "\t" << classTracker[i].counter << "\t" << classTracker[i].className << std::endl;
    }

    file << "UseExternalRanges: " << useExternalRanges << std::endl;

    //The external ranges are only written when they are in use
    if( useExternalRanges ){
        for(UINT i=0; i<externalRanges.size(); i++){
            file << externalRanges[i].minValue << "\t" << externalRanges[i].maxValue << std::endl;
        }
    }

    file << "Data:\n";

    //One line per sample: the class label followed by the tab separated data values
    for(UINT i=0; i<totalNumSamples; i++){
        file << data[i].getClassLabel();
        for(UINT j=0; j<numDimensions; j++){
            file << "\t" << data[i][j];
        }
        file << std::endl;
    }

    file.close();
    return true;
}
442 
//Loads a dataset written in the custom GRT (V1.0) classification data file format,
//mirroring the layout written by saveDatasetToFile. Any existing data is cleared
//first. Returns false (with an error log message) on any header/format mismatch.
bool ClassificationData::loadDatasetFromFile(const std::string &filename){

    std::fstream file;
    file.open(filename.c_str(), std::ios::in);
    UINT numClasses = 0;
    clear();

    if( !file.is_open() ){
        errorLog << "loadDatasetFromFile(const std::string &filename) - could not open file!" << std::endl;
        return false;
    }

    std::string word;

    //Check to make sure this is a file with the Training File Format
    file >> word;
    if(word != "GRT_LABELLED_CLASSIFICATION_DATA_FILE_V1.0"){
        errorLog << "loadDatasetFromFile(const std::string &filename) - could not find file header!" << std::endl;
        file.close();
        return false;
    }

    //Get the name of the dataset
    file >> word;
    if(word != "DatasetName:"){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find DatasetName header!" << std::endl;
        errorLog << word << std::endl;
        file.close();
        return false;
    }
    file >> datasetName;

    file >> word;
    if(word != "InfoText:"){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find InfoText header!" << std::endl;
        file.close();
        return false;
    }

    //Load the info text word by word until the next header is reached
    //(the trailing space appended after each word is preserved behaviour)
    file >> word;
    infoText = "";
    while( word != "NumDimensions:" ){
        infoText += word + " ";
        file >> word;
    }

    //Get the number of dimensions in the training data
    //NOTE(review): this condition is always true here because the loop above only
    //exits when word == "NumDimensions:" - kept for safety/symmetry
    if( word != "NumDimensions:" ){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find NumDimensions header!" << std::endl;
        file.close();
        return false;
    }
    file >> numDimensions;

    //Get the total number of training examples in the training data
    //(both the legacy and the current header spellings are accepted)
    file >> word;
    if( word != "TotalNumTrainingExamples:" && word != "TotalNumExamples:" ){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find TotalNumTrainingExamples header!" << std::endl;
        file.close();
        return false;
    }
    file >> totalNumSamples;

    //Get the total number of classes in the training data
    file >> word;
    if(word != "NumberOfClasses:"){
        errorLog << "loadDatasetFromFile(string filename) - failed to find NumberOfClasses header!" << std::endl;
        file.close();
        return false;
    }
    file >> numClasses;

    //Resize the class counter buffer and load the counters
    classTracker.resize(numClasses);

    //Load the per-class label, counter and name entries
    file >> word;
    if(word != "ClassIDsAndCounters:"){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find ClassIDsAndCounters header!" << std::endl;
        file.close();
        return false;
    }

    for(UINT i=0; i<classTracker.getSize(); i++){
        file >> classTracker[i].classLabel;
        file >> classTracker[i].counter;
        file >> classTracker[i].className;
    }

    //Check if the dataset should be scaled using external ranges
    file >> word;
    if(word != "UseExternalRanges:"){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find UseExternalRanges header!" << std::endl;
        file.close();
        return false;
    }
    file >> useExternalRanges;

    //If we are using external ranges then load them
    if( useExternalRanges ){
        externalRanges.resize(numDimensions);
        for(UINT i=0; i<externalRanges.getSize(); i++){
            file >> externalRanges[i].minValue;
            file >> externalRanges[i].maxValue;
        }
    }

    //Get the main training data (both the legacy and the current headers are accepted)
    file >> word;
    if( word != "LabelledTrainingData:" && word != "Data:"){
        errorLog << "loadDatasetFromFile(const std::string &filename) - failed to find LabelledTrainingData header!" << std::endl;
        file.close();
        return false;
    }

    //Read one class label followed by numDimensions values per sample
    ClassificationSample tempSample( numDimensions );
    data.resize( totalNumSamples, tempSample );

    for(UINT i=0; i<totalNumSamples; i++){
        UINT classLabel = 0;
        VectorFloat sample(numDimensions,0);
        file >> classLabel;
        for(UINT j=0; j<numDimensions; j++){
            file >> sample[j];
        }
        data[i].set(classLabel, sample);
    }

    file.close();

    //Sort the class labels
    sortClassLabels();

    return true;
}
579 
580 bool ClassificationData::saveDatasetToCSVFile(const std::string &filename) const{
581 
582  std::fstream file;
583  file.open(filename.c_str(), std::ios::out );
584 
585  if( !file.is_open() ){
586  return false;
587  }
588 
589  //Write the data to the CSV file
590  for(UINT i=0; i<totalNumSamples; i++){
591  file << data[i].getClassLabel();
592  for(UINT j=0; j<numDimensions; j++){
593  file << "," << data[i][j];
594  }
595  file << std::endl;
596  }
597 
598  file.close();
599 
600  return true;
601 }
602 
603 bool ClassificationData::loadDatasetFromCSVFile(const std::string &filename,const UINT classLabelColumnIndex){
604 
605  numDimensions = 0;
606  datasetName = "NOT_SET";
607  infoText = "";
608 
609  //Clear any previous data
610  clear();
611 
612  //Parse the CSV file
613  FileParser parser;
614 
615  Timer timer;
616 
617  timer.start();
618 
619  if( !parser.parseCSVFile(filename,true) ){
620  errorLog << "loadDatasetFromCSVFile(const std::string &filename,const UINT classLabelColumnIndex) - Failed to parse CSV file!" << std::endl;
621  return false;
622  }
623 
624  if( !parser.getConsistentColumnSize() ){
625  errorLog << "loadDatasetFromCSVFile(const std::string &filename,const UINT classLabelColumnIndexe) - The CSV file does not have a consistent number of columns!" << std::endl;
626  return false;
627  }
628 
629  if( parser.getColumnSize() <= 1 ){
630  errorLog << "loadDatasetFromCSVFile(const std::string &filename,const UINT classLabelColumnIndex) - The CSV file does not have enough columns! It should contain at least two columns!" << std::endl;
631  return false;
632  }
633 
634  //Set the number of dimensions
635  numDimensions = parser.getColumnSize()-1;
636 
637  timer.start();
638 
639  //Reserve the memory for the data
640  data.resize( parser.getRowSize(), ClassificationSample(numDimensions) );
641 
642  timer.start();
643 
644  //Loop over the samples and add them to the data set
645  UINT classLabel = 0;
646  UINT j = 0;
647  UINT n = 0;
648  totalNumSamples = parser.getRowSize();
649  for(UINT i=0; i<totalNumSamples; i++){
650  //Get the class label
651  classLabel = grt_from_str< UINT >( parser[i][classLabelColumnIndex] );
652 
653  //Set the class label
654  data[i].setClassLabel( classLabel );
655 
656  //Get the sample data
657  j=0;
658  n=0;
659  while( j != numDimensions ){
660  if( n != classLabelColumnIndex ){
661  data[i][j++] = grt_from_str< Float >( parser[i][n] );
662  }
663  n++;
664  }
665 
666  //Update the class tracker
667  if( classTracker.size() == 0 ){
668  ClassTracker tracker(classLabel,1);
669  classTracker.push_back(tracker);
670  }else{
671  bool labelFound = false;
672  const size_t numClasses = classTracker.size();
673  for(size_t i=0; i<numClasses; i++){
674  if( classLabel == classTracker[i].classLabel ){
675  classTracker[i].counter++;
676  labelFound = true;
677  break;
678  }
679  }
680  if( !labelFound ){
681  ClassTracker tracker(classLabel,1);
682  classTracker.push_back(tracker);
683  }
684  }
685  }
686 
687  //Sort the class labels
688  sortClassLabels();
689 
690  return true;
691 }
692 
694 
695  std::cout << getStatsAsString();
696 
697  return true;
698 }
699 
701 
702  sort(classTracker.begin(),classTracker.end(),ClassTracker::sortByClassLabelAscending);
703 
704  return true;
705 }
706 
//Legacy alias for split(...), kept for backwards compatibility with older GRT code.
//See split() for the partitioning semantics.
ClassificationData ClassificationData::partition(const UINT trainingSizePercentage,const bool useStratifiedSampling){
    return split(trainingSizePercentage, useStratifiedSampling);
}
710 
//Splits the dataset: trainingSizePercentage percent of the samples are kept in this
//instance and the remainder are returned as a new ClassificationData test set. When
//useStratifiedSampling is true the percentage is applied per class.
ClassificationData ClassificationData::split(const UINT trainingSizePercentage,const bool useStratifiedSampling){

    //Partitions the dataset into a training dataset (which is kept by this instance of the ClassificationData) and
    //a testing/validation dataset (which is return as a new instance of the ClassificationData). The trainingSizePercentage
    //therefore sets the size of the data which remains in this instance and the remaining percentage of data is then added to
    //the testing/validation dataset

    //The dataset has changed so flag that any previous cross validation setup will now not work
    crossValidationSetup = false;
    crossValidationIndexs.clear();

    ClassificationData trainingSet(numDimensions);
    ClassificationData testSet(numDimensions);
    trainingSet.setAllowNullGestureClass( allowNullGestureClass );
    testSet.setAllowNullGestureClass( allowNullGestureClass );
    //NOTE(review): indexs is only used by the non-stratified branch below
    Vector< UINT > indexs( totalNumSamples );

    //Create the random partion indexs
    Random random;
    UINT randomIndex = 0;

    if( useStratifiedSampling ){
        //Break the data into seperate classes
        Vector< Vector< UINT > > classData( getNumClasses() );

        //Add the indexs to their respective classes
        for(UINT i=0; i<totalNumSamples; i++){
            classData[ getClassLabelIndexValue( data[i].getClassLabel() ) ].push_back( i );
        }

        //Randomize the order of the indexs in each of the class index buffers
        //NOTE(review): assumes getRandomNumberInt(0,numSamples) returns an index in
        //[0,numSamples) - verify against Random, otherwise this reads one past the end
        for(UINT k=0; k<getNumClasses(); k++){
            UINT numSamples = classData[k].getSize();
            for(UINT x=0; x<numSamples; x++){
                //Pick a random index
                randomIndex = random.getRandomNumberInt(0,numSamples);

                //Swap the indexs
                SWAP(classData[k][ x ], classData[k][ randomIndex ]);
            }
        }

        //Reserve the memory: per-class training/test counts are accumulated first
        UINT numTrainingSamples = 0;
        UINT numTestSamples = 0;

        for(UINT k=0; k<getNumClasses(); k++){
            UINT numTrainingExamples = (UINT) floor( Float(classData[k].size()) / 100.0 * Float(trainingSizePercentage) );
            UINT numTestExamples = ((UINT)classData[k].size())-numTrainingExamples;
            numTrainingSamples += numTrainingExamples;
            numTestSamples += numTestExamples;
        }

        trainingSet.reserve( numTrainingSamples );
        testSet.reserve( numTestSamples );

        //Loop over each class and add the data to the trainingSet and testSet
        for(UINT k=0; k<getNumClasses(); k++){
            UINT numTrainingExamples = (UINT) floor( Float(classData[k].getSize()) / 100.0 * Float(trainingSizePercentage) );

            //Add the data to the training and test sets
            for(UINT i=0; i<numTrainingExamples; i++){
                trainingSet.addSample( data[ classData[k][i] ].getClassLabel(), data[ classData[k][i] ].getSample() );
            }
            for(UINT i=numTrainingExamples; i<classData[k].getSize(); i++){
                testSet.addSample( data[ classData[k][i] ].getClassLabel(), data[ classData[k][i] ].getSample() );
            }
        }
    }else{

        const UINT numTrainingExamples = (UINT) floor( Float(totalNumSamples) / 100.0 * Float(trainingSizePercentage) );
        //Shuffle the sample indexs, then take the first numTrainingExamples for training
        //NOTE(review): std::random_shuffle is deprecated in C++14 and removed in C++17
        Random random;
        UINT randomIndex = 0;
        for(UINT i=0; i<totalNumSamples; i++) indexs[i] = i;
        std::random_shuffle(indexs.begin(), indexs.end());

        //Reserve the memory
        trainingSet.reserve( numTrainingExamples );
        testSet.reserve( totalNumSamples-numTrainingExamples );

        //Add the data to the training and test sets
        for(UINT i=0; i<numTrainingExamples; i++){
            trainingSet.addSample( data[ indexs[i] ].getClassLabel(), data[ indexs[i] ].getSample() );
        }
        for(UINT i=numTrainingExamples; i<totalNumSamples; i++){
            testSet.addSample( data[ indexs[i] ].getClassLabel(), data[ indexs[i] ].getSample() );
        }
    }

    //Overwrite the training data in this instance with the training data of the trainingSet
    *this = trainingSet;

    //Sort the class labels in this dataset
    sortClassLabels();

    //Sort the class labels of the test dataset
    testSet.sortClassLabels();

    return testSet;
}
812 
814 
815  if( labelledData.getNumDimensions() != numDimensions ){
816  errorLog << "merge(const ClassificationData &labelledData) - The number of dimensions in the labelledData (" << labelledData.getNumDimensions() << ") does not match the number of dimensions of this dataset (" << numDimensions << ")" << std::endl;
817  return false;
818  }
819 
820  //The dataset has changed so flag that any previous cross validation setup will now not work
821  crossValidationSetup = false;
822  crossValidationIndexs.clear();
823 
824  //Reserve the memory
825  reserve( getNumSamples() + labelledData.getNumSamples() );
826 
827  //Add the data from the labelledData to this instance
828  for(UINT i=0; i<labelledData.getNumSamples(); i++){
829  addSample(labelledData[i].getClassLabel(), labelledData[i].getSample());
830  }
831 
832  //Set the class names from the dataset
833  Vector< ClassTracker > classTracker = labelledData.getClassTracker();
834  for(UINT i=0; i<classTracker.size(); i++){
835  setClassNameForCorrespondingClassLabel(classTracker[i].className, classTracker[i].classLabel);
836  }
837 
838  //Sort the class labels
839  sortClassLabels();
840 
841  return true;
842 }
843 
844 bool ClassificationData::spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling){
845 
846  crossValidationSetup = false;
847  crossValidationIndexs.clear();
848 
849  //K can not be zero
850  if( K == 0 ){
851  errorLog << "spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling) - K can not be zero!" << std::endl;
852  return false;
853  }
854 
855  //K can not be larger than the number of examples
856  if( K > totalNumSamples ){
857  errorLog << "spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling) - K can not be larger than the total number of samples in the dataset!" << std::endl;
858  return false;
859  }
860 
861  //K can not be larger than the number of examples in a specific class if the stratified sampling option is true
862  if( useStratifiedSampling ){
863  for(UINT c=0; c<classTracker.size(); c++){
864  if( K > classTracker[c].counter ){
865  errorLog << "spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling) - K can not be larger than the number of samples in any given class!" << std::endl;
866  return false;
867  }
868  }
869  }
870 
871  //Setup the dataset for k-fold cross validation
872  kFoldValue = K;
873  Vector< UINT > indexs( totalNumSamples );
874 
875  //Work out how many samples are in each fold, the last fold might have more samples than the others
876  UINT numSamplesPerFold = (UINT) floor( totalNumSamples/Float(K) );
877 
878  //Add the random indexs to each fold
879  crossValidationIndexs.resize(K);
880 
881  //Create the random partion indexs
882  Random random;
883  UINT randomIndex = 0;
884 
885  if( useStratifiedSampling ){
886  //Break the data into seperate classes
887  Vector< Vector< UINT > > classData( getNumClasses() );
888 
889  //Add the indexs to their respective classes
890  for(UINT i=0; i<totalNumSamples; i++){
891  classData[ getClassLabelIndexValue( data[i].getClassLabel() ) ].push_back( i );
892  }
893 
894  //Randomize the order of the indexs in each of the class index buffers
895  for(UINT c=0; c<getNumClasses(); c++){
896  UINT numSamples = (UINT)classData[c].size();
897  for(UINT x=0; x<numSamples; x++){
898  //Pick a random indexs
899  randomIndex = random.getRandomNumberInt(0,numSamples);
900 
901  //Swap the indexs
902  SWAP(classData[c][ x ] , classData[c][ randomIndex ]);
903  }
904  }
905 
906  //Loop over each of the k folds, at each fold add a sample from each class
908  for(UINT c=0; c<getNumClasses(); c++){
909  iter = classData[ c ].begin();
910  UINT k = 0;
911  while( iter != classData[c].end() ){
912  crossValidationIndexs[ k ].push_back( *iter );
913  iter++;
914  k++;
915  k = k % K;
916  }
917  }
918 
919  }else{
920  //Randomize the order of the data
921  for(UINT i=0; i<totalNumSamples; i++) indexs[i] = i;
922  for(UINT x=0; x<totalNumSamples; x++){
923  //Pick a random index
924  randomIndex = random.getRandomNumberInt(0,totalNumSamples);
925 
926  //Swap the indexs
927  SWAP(indexs[ x ] , indexs[ randomIndex ]);
928  }
929 
930  UINT counter = 0;
931  UINT foldIndex = 0;
932  for(UINT i=0; i<totalNumSamples; i++){
933  //Add the index to the current fold
934  crossValidationIndexs[ foldIndex ].push_back( indexs[i] );
935 
936  //Move to the next fold if ready
937  if( ++counter == numSamplesPerFold && foldIndex < K-1 ){
938  foldIndex++;
939  counter = 0;
940  }
941  }
942  }
943 
944  crossValidationSetup = true;
945  return true;
946 
947 }
948 
950 
951  ClassificationData trainingData;
952  trainingData.setNumDimensions( numDimensions );
953  trainingData.setAllowNullGestureClass( allowNullGestureClass );
954 
955  if( !crossValidationSetup ){
956  errorLog << "getTrainingFoldData(const UINT foldIndex) - Cross Validation has not been setup! You need to call the spiltDataIntoKFolds(UINT K,bool useStratifiedSampling) function first before calling this function!" << std::endl;
957  return trainingData;
958  }
959 
960  if( foldIndex >= kFoldValue ) return trainingData;
961 
962  //Add the class labels to make sure they all exist
963  for(UINT k=0; k<getNumClasses(); k++){
964  trainingData.addClass( classTracker[k].classLabel, classTracker[k].className );
965  }
966 
967  //Add the data to the training set, this will consist of all the data that is NOT in the foldIndex
968  UINT index = 0;
969  for(UINT k=0; k<kFoldValue; k++){
970  if( k != foldIndex ){
971  for(UINT i=0; i<crossValidationIndexs[k].getSize(); i++){
972 
973  index = crossValidationIndexs[k][i];
974  trainingData.addSample( data[ index ].getClassLabel(), data[ index ].getSample() );
975  }
976  }
977  }
978 
979  //Sort the class labels
980  trainingData.sortClassLabels();
981 
982  return trainingData;
983 }
984 
986 
987  ClassificationData testData;
988  testData.setNumDimensions( numDimensions );
989  testData.setAllowNullGestureClass( allowNullGestureClass );
990 
991  if( !crossValidationSetup ) return testData;
992 
993  if( foldIndex >= kFoldValue ) return testData;
994 
995  //Add the class labels to make sure they all exist
996  for(UINT k=0; k<getNumClasses(); k++){
997  testData.addClass( classTracker[k].classLabel, classTracker[k].className );
998  }
999 
1000  testData.reserve( crossValidationIndexs[ foldIndex ].getSize() );
1001 
1002  //Add the data to the test fold
1003  UINT index = 0;
1004  for(UINT i=0; i<crossValidationIndexs[ foldIndex ].getSize(); i++){
1005 
1006  index = crossValidationIndexs[ foldIndex ][i];
1007  testData.addSample( data[ index ].getClassLabel(), data[ index ].getSample() );
1008  }
1009 
1010  //Sort the class labels
1011  testData.sortClassLabels();
1012 
1013  return testData;
1014 }
1015 
1017 
1018  ClassificationData classData;
1019  classData.setNumDimensions( this->numDimensions );
1020  classData.setAllowNullGestureClass( allowNullGestureClass );
1021 
1022  //Reserve the memory for the class data
1023  for(UINT i=0; i<classTracker.getSize(); i++){
1024  if( classTracker[i].classLabel == classLabel ){
1025  classData.reserve( classTracker[i].counter );
1026  break;
1027  }
1028  }
1029 
1030  for(UINT i=0; i<totalNumSamples; i++){
1031  if( data[i].getClassLabel() == classLabel ){
1032  classData.addSample(classLabel, data[i].getSample());
1033  }
1034  }
1035 
1036  return classData;
1037 }
1038 
1039 ClassificationData ClassificationData::getBootstrappedDataset(UINT numSamples,bool balanceDataset) const{
1040 
1041  Random rand;
1042  ClassificationData newDataset;
1043  newDataset.setNumDimensions( getNumDimensions() );
1044  newDataset.setAllowNullGestureClass( allowNullGestureClass );
1045  newDataset.setExternalRanges( externalRanges, useExternalRanges );
1046 
1047  if( numSamples == 0 ) numSamples = totalNumSamples;
1048 
1049  newDataset.reserve( numSamples );
1050 
1051  const UINT K = getNumClasses();
1052 
1053  //Add all the class labels to the new dataset to ensure the dataset has a list of all the labels
1054  for(UINT k=0; k<K; k++){
1055  newDataset.addClass( classTracker[k].classLabel );
1056  }
1057 
1058  if( balanceDataset ){
1059  //Group the class indexs
1060  Vector< Vector< UINT > > classIndexs( K );
1061  for(UINT i=0; i<totalNumSamples; i++){
1062  classIndexs[ getClassLabelIndexValue( data[i].getClassLabel() ) ].push_back( i );
1063  }
1064 
1065  //Get the class with the minimum number of examples
1066  UINT numSamplesPerClass = (UINT)floor( numSamples / Float(K) );
1067 
1068  //Randomly select the training samples from each class
1069  UINT classIndex = 0;
1070  UINT classCounter = 0;
1071  UINT randomIndex = 0;
1072  for(UINT i=0; i<numSamples; i++){
1073  randomIndex = rand.getRandomNumberInt(0, (UINT)classIndexs[ classIndex ].size() );
1074  randomIndex = classIndexs[ classIndex ][ randomIndex ];
1075  newDataset.addSample(data[ randomIndex ].getClassLabel(), data[ randomIndex ].getSample());
1076  if( classCounter++ >= numSamplesPerClass && classIndex+1 < K ){
1077  classCounter = 0;
1078  classIndex++;
1079  }
1080  }
1081 
1082  }else{
1083  //Randomly select the training samples to add to the new data set
1084  UINT randomIndex;
1085  for(UINT i=0; i<numSamples; i++){
1086  randomIndex = rand.getRandomNumberInt(0, totalNumSamples);
1087  newDataset.addSample( data[randomIndex].getClassLabel(), data[randomIndex].getSample() );
1088  }
1089  }
1090 
1091  //Sort the class labels so they are in order
1092  newDataset.sortClassLabels();
1093 
1094  return newDataset;
1095 }
1096 
1098 
1099  //Turns the classification into a regression data to enable regression algorithms like the MLP to be used as a classifier
1100  //This sets the number of targets in the regression data equal to the number of classes in the classification data
1101  //The output of each regression training sample will then be all 0's, except for the index matching the classLabel, which will be 1
1102  //For this to work, the labelled classification data cannot have any samples with a classLabel of 0!
1103  RegressionData regressionData;
1104 
1105  if( totalNumSamples == 0 ){
1106  return regressionData;
1107  }
1108 
1109  const UINT numInputDimensions = numDimensions;
1110  const UINT numTargetDimensions = getNumClasses();
1111  regressionData.setInputAndTargetDimensions(numInputDimensions, numTargetDimensions);
1112 
1113  for(UINT i=0; i<totalNumSamples; i++){
1114  VectorFloat targetVector(numTargetDimensions,0);
1115 
1116  //Set the class index in the target Vector to 1 and all other values in the target Vector to 0
1117  UINT classLabel = data[i].getClassLabel();
1118 
1119  if( classLabel > 0 ){
1120  targetVector[ classLabel-1 ] = 1;
1121  }else{
1122  regressionData.clear();
1123  return regressionData;
1124  }
1125 
1126  regressionData.addSample(data[i].getSample(),targetVector);
1127  }
1128 
1129  return regressionData;
1130 }
1131 
1133 
1134  UnlabelledData unlabelledData;
1135 
1136  if( totalNumSamples == 0 ){
1137  return unlabelledData;
1138  }
1139 
1140  unlabelledData.setNumDimensions( numDimensions );
1141 
1142  for(UINT i=0; i<totalNumSamples; i++){
1143  unlabelledData.addSample( data[i].getSample() );
1144  }
1145 
1146  return unlabelledData;
1147 }
1148 
1150  UINT minClassLabel = grt_numeric_limits< UINT >::max();
1151 
1152  for(UINT i=0; i<classTracker.getSize(); i++){
1153  if( classTracker[i].classLabel < minClassLabel ){
1154  minClassLabel = classTracker[i].classLabel;
1155  }
1156  }
1157 
1158  return minClassLabel;
1159 }
1160 
1161 
1163  UINT maxClassLabel = 0;
1164 
1165  for(UINT i=0; i<classTracker.getSize(); i++){
1166  if( classTracker[i].classLabel > maxClassLabel ){
1167  maxClassLabel = classTracker[i].classLabel;
1168  }
1169  }
1170 
1171  return maxClassLabel;
1172 }
1173 
1175  for(UINT k=0; k<classTracker.getSize(); k++){
1176  if( classTracker[k].classLabel == classLabel ){
1177  return k;
1178  }
1179  }
1180  warningLog << "getClassLabelIndexValue(UINT classLabel) - Failed to find class label: " << classLabel << " in class tracker!" << std::endl;
1181  return 0;
1182 }
1183 
1185 
1186  for(UINT i=0; i<classTracker.getSize(); i++){
1187  if( classTracker[i].classLabel == classLabel ){
1188  return classTracker[i].className;
1189  }
1190  }
1191 
1192  return "CLASS_LABEL_NOT_FOUND";
1193 }
1194 
1196  std::string statsText;
1197  statsText += "DatasetName:\t" + datasetName + "\n";
1198  statsText += "DatasetInfo:\t" + infoText + "\n";
1199  statsText += "Number of Dimensions:\t" + Util::toString( numDimensions ) + "\n";
1200  statsText += "Number of Samples:\t" + Util::toString( totalNumSamples ) + "\n";
1201  statsText += "Number of Classes:\t" + Util::toString( getNumClasses() ) + "\n";
1202  statsText += "ClassStats:\n";
1203 
1204  for(UINT k=0; k<getNumClasses(); k++){
1205  statsText += "ClassLabel:\t" + Util::toString( classTracker[k].classLabel );
1206  statsText += "\tNumber of Samples:\t" + Util::toString(classTracker[k].counter);
1207  statsText += "\tClassName:\t" + classTracker[k].className + "\n";
1208  }
1209 
1210  Vector< MinMax > ranges = getRanges();
1211 
1212  statsText += "Dataset Ranges:\n";
1213  for(UINT j=0; j<ranges.size(); j++){
1214  statsText += "[" + Util::toString( j+1 ) + "] Min:\t" + Util::toString( ranges[j].minValue ) + "\tMax: " + Util::toString( ranges[j].maxValue ) + "\n";
1215  }
1216 
1217  return statsText;
1218 }
1219 
1221 
1222  //If the dataset should be scaled using the external ranges then return the external ranges
1223  if( useExternalRanges ) return externalRanges;
1224 
1225  Vector< MinMax > ranges(numDimensions);
1226 
1227  //Otherwise return the min and max values for each column in the dataset
1228  if( totalNumSamples > 0 ){
1229  for(UINT j=0; j<numDimensions; j++){
1230  ranges[j].minValue = data[0][j];
1231  ranges[j].maxValue = data[0][j];
1232  for(UINT i=0; i<totalNumSamples; i++){
1233  if( data[i][j] < ranges[j].minValue ){ ranges[j].minValue = data[i][j]; } //Search for the min value
1234  else if( data[i][j] > ranges[j].maxValue ){ ranges[j].maxValue = data[i][j]; } //Search for the max value
1235  }
1236  }
1237  }
1238  return ranges;
1239 }
1240 
1242  Vector< UINT > classLabels( getNumClasses(), 0 );
1243 
1244  if( getNumClasses() == 0 ) return classLabels;
1245 
1246  for(UINT i=0; i<getNumClasses(); i++){
1247  classLabels[i] = classTracker[i].classLabel;
1248  }
1249 
1250  return classLabels;
1251 }
1252 
1254  Vector< UINT > classSampleCounts( getNumClasses(), 0 );
1255 
1256  if( getNumSamples() == 0 ) return classSampleCounts;
1257 
1258  for(UINT i=0; i<getNumClasses(); i++){
1259  classSampleCounts[i] = classTracker[i].counter;
1260  }
1261 
1262  return classSampleCounts;
1263 }
1264 
1266 
1267  VectorFloat mean(numDimensions,0);
1268 
1269  for(UINT j=0; j<numDimensions; j++){
1270  for(UINT i=0; i<totalNumSamples; i++){
1271  mean[j] += data[i][j];
1272  }
1273  mean[j] /= Float(totalNumSamples);
1274  }
1275 
1276  return mean;
1277 }
1278 
1280 
1281  VectorFloat mean = getMean();
1282  VectorFloat stdDev(numDimensions,0);
1283 
1284  for(UINT j=0; j<numDimensions; j++){
1285  for(UINT i=0; i<totalNumSamples; i++){
1286  stdDev[j] += SQR(data[i][j]-mean[j]);
1287  }
1288  stdDev[j] = sqrt( stdDev[j] / Float(totalNumSamples-1) );
1289  }
1290 
1291  return stdDev;
1292 }
1293 
1294 MatrixFloat ClassificationData::getClassHistogramData(UINT classLabel,UINT numBins) const{
1295 
1296  const UINT M = getNumSamples();
1297  const UINT N = getNumDimensions();
1298 
1299  Vector< MinMax > ranges = getRanges();
1300  VectorFloat binRange(N);
1301  for(UINT i=0; i<ranges.size(); i++){
1302  binRange[i] = (ranges[i].maxValue-ranges[i].minValue)/Float(numBins);
1303  }
1304 
1305  MatrixFloat histData(N,numBins);
1306  histData.setAllValues(0);
1307 
1308  Float norm = 0;
1309  for(UINT i=0; i<M; i++){
1310  if( data[i].getClassLabel() == classLabel ){
1311  for(UINT j=0; j<N; j++){
1312  UINT binIndex = 0;
1313  bool binFound = false;
1314  for(UINT k=0; k<numBins-1; k++){
1315  if( data[i][j] >= ranges[i].minValue + (binRange[j]*k) && data[i][j] >= ranges[i].minValue + (binRange[j]*(k+1)) ){
1316  binIndex = k;
1317  binFound = true;
1318  break;
1319  }
1320  }
1321  if( !binFound ) binIndex = numBins-1;
1322  histData[j][binIndex]++;
1323  }
1324  norm++;
1325  }
1326  }
1327 
1328  if( norm == 0 ) return histData;
1329 
1330  //Is this the best way to normalize a multidimensional histogram???
1331  for(UINT i=0; i<histData.getNumRows(); i++){
1332  for(UINT j=0; j<histData.getNumCols(); j++){
1333  histData[i][j] /= norm;
1334  }
1335  }
1336 
1337  return histData;
1338 }
1339 
1341 
1342  MatrixFloat mean(getNumClasses(),numDimensions);
1343  VectorFloat counter(getNumClasses(),0);
1344 
1345  mean.setAllValues( 0 );
1346 
1347  for(UINT i=0; i<totalNumSamples; i++){
1348  UINT classIndex = getClassLabelIndexValue( data[i].getClassLabel() );
1349  for(UINT j=0; j<numDimensions; j++){
1350  mean[classIndex][j] += data[i][j];
1351  }
1352  counter[ classIndex ]++;
1353  }
1354 
1355  for(UINT k=0; k<getNumClasses(); k++){
1356  for(UINT j=0; j<numDimensions; j++){
1357  mean[k][j] = counter[k] > 0 ? mean[k][j]/counter[k] : 0;
1358  }
1359  }
1360 
1361  return mean;
1362 }
1363 
1365 
1366  MatrixFloat mean = getClassMean();
1367  MatrixFloat stdDev(getNumClasses(),numDimensions);
1368  VectorFloat counter(getNumClasses(),0);
1369 
1370  stdDev.setAllValues( 0 );
1371 
1372  for(UINT i=0; i<totalNumSamples; i++){
1373  UINT classIndex = getClassLabelIndexValue( data[i].getClassLabel() );
1374  for(UINT j=0; j<numDimensions; j++){
1375  stdDev[classIndex][j] += SQR(data[i][j]-mean[classIndex][j]);
1376  }
1377  counter[ classIndex ]++;
1378  }
1379 
1380  for(UINT k=0; k<getNumClasses(); k++){
1381  for(UINT j=0; j<numDimensions; j++){
1382  stdDev[k][j] = sqrt( stdDev[k][j] / Float(counter[k]-1) );
1383  }
1384  }
1385 
1386  return stdDev;
1387 }
1388 
1390 
1391  VectorFloat mean = getMean();
1392  MatrixFloat covariance(numDimensions,numDimensions);
1393 
1394  for(UINT j=0; j<numDimensions; j++){
1395  for(UINT k=0; k<numDimensions; k++){
1396  for(UINT i=0; i<totalNumSamples; i++){
1397  covariance[j][k] += (data[i][j]-mean[j]) * (data[i][k]-mean[k]) ;
1398  }
1399  covariance[j][k] /= Float(totalNumSamples-1);
1400  }
1401  }
1402 
1403  return covariance;
1404 }
1405 
1407  const UINT K = getNumClasses();
1408  Vector< MatrixFloat > histData(K);
1409 
1410  for(UINT k=0; k<K; k++){
1411  histData[k] = getClassHistogramData( classTracker[k].classLabel, numBins );
1412  }
1413 
1414  return histData;
1415 }
1416 
VectorFloat ClassificationData::getClassProbabilities() const {
    //Convenience overload: computes the class probabilities for every class label in the dataset
    return getClassProbabilities( getClassLabels() );
}
1420 
1421 VectorFloat ClassificationData::getClassProbabilities( const Vector< UINT > &classLabels ) const {
1422  const UINT K = (UINT)classLabels.size();
1423  const UINT N = getNumClasses();
1424  Float sum = 0;
1425  VectorFloat x(K,0);
1426  for(UINT k=0; k<K; k++){
1427  for(UINT n=0; n<N; n++){
1428  if( classLabels[k] == classTracker[n].classLabel ){
1429  x[k] = classTracker[n].counter;
1430  sum += classTracker[n].counter;
1431  break;
1432  }
1433  }
1434  }
1435 
1436  //Normalize the class probabilities
1437  if( sum > 0 ){
1438  for(UINT k=0; k<K; k++){
1439  x[k] /= sum;
1440  }
1441  }
1442 
1443  return x;
1444 }
1445 
1447 
1448  const UINT M = getNumSamples();
1449  const UINT K = getNumClasses();
1450  UINT N = 0;
1451 
1452  //Get the number of samples in the class
1453  for(UINT k=0; k<K; k++){
1454  if( classTracker[k].classLabel == classLabel){
1455  N = classTracker[k].counter;
1456  break;
1457  }
1458  }
1459 
1460  UINT index = 0;
1461  Vector< UINT > classIndexes(N);
1462  for(UINT i=0; i<M; i++){
1463  if( data[i].getClassLabel() == classLabel ){
1464  classIndexes[index++] = i;
1465  }
1466  }
1467 
1468  return classIndexes;
1469 }
1470 
1472 
1473  const UINT M = getNumSamples();
1474  const UINT N = getNumDimensions();
1475  MatrixDouble d(M,N);
1476 
1477  for(UINT i=0; i<M; i++){
1478  for(UINT j=0; j<N; j++){
1479  d[i][j] = data[i][j];
1480  }
1481  }
1482 
1483  return d;
1484 }
1485 
1487  const UINT M = getNumSamples();
1488  const UINT N = getNumDimensions();
1489  MatrixFloat d(M,N);
1490 
1491  for(UINT i=0; i<M; i++){
1492  for(UINT j=0; j<N; j++){
1493  d[i][j] = data[i][j];
1494  }
1495  }
1496 
1497  return d;
1498 }
1499 
1500 bool ClassificationData::generateGaussDataset( const std::string filename, const UINT numSamples, const UINT numClasses, const UINT numDimensions, const Float range, const Float sigma ){
1501 
1502  Random random;
1503 
1504  //Generate a simple model that will be used to generate the main dataset
1505  MatrixFloat model(numClasses,numDimensions);
1506  for(UINT k=0; k<numClasses; k++){
1507  for(UINT j=0; j<numDimensions; j++){
1508  model[k][j] = random.getRandomNumberUniform(-range,range);
1509  }
1510  }
1511 
1512  //Use the model above to generate the main dataset
1513  ClassificationData data;
1514  data.setNumDimensions( numDimensions );
1515 
1516  for(UINT i=0; i<numSamples; i++){
1517 
1518  //Randomly select which class this sample belongs to
1519  UINT k = random.getRandomNumberInt( 0, numClasses );
1520 
1521  //Generate a sample using the model (+ some Gaussian noise)
1522  VectorFloat sample( numDimensions );
1523  for(UINT j=0; j<numDimensions; j++){
1524  sample[j] = model[k][j] + random.getRandomNumberGauss(0,sigma);
1525  }
1526 
1527  //By default in the GRT, the class label should not be 0, so add 1
1528  UINT classLabel = k + 1;
1529 
1530  //Add the labeled sample to the dataset
1531  data.addSample( classLabel, sample );
1532  }
1533 
1534  //Save the dataset to a CSV file
1535  return data.save( filename );
1536 }
1537 
1538 GRT_END_NAMESPACE
1539 
bool saveDatasetToFile(const std::string &filename) const
bool setDatasetName(std::string datasetName)
bool loadDatasetFromFile(const std::string &filename)
static std::string toString(const int &i)
Definition: Util.cpp:74
RegressionData reformatAsRegressionData() const
ClassificationData & operator=(const ClassificationData &rhs)
Definition: Timer.h:43
static bool generateGaussDataset(const std::string filename, const UINT numSamples=10000, const UINT numClasses=10, const UINT numDimensions=3, const Float range=10, const Float sigma=1)
bool addSample(const VectorFloat &sample)
The ClassificationData is the main data structure for recording, labeling, managing, saving, and loading training data for supervised learning problems.
bool relabelAllSamplesWithClassLabel(const UINT oldClassLabel, const UINT newClassLabel)
bool addSample(UINT classLabel, const VectorFloat &sample)
ClassificationData getTestFoldData(const UINT foldIndex) const
bool addClass(const UINT classLabel, const std::string className="NOT_SET")
Vector< ClassTracker > getClassTracker() const
Definition: Random.h:40
ClassificationData getClassData(const UINT classLabel) const
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
bool setNumDimensions(UINT numDimensions)
UINT eraseAllSamplesWithClassLabel(const UINT classLabel)
MatrixDouble getDataAsMatrixDouble() const
MatrixFloat getClassMean() const
UINT getSize() const
Definition: Vector.h:191
Float getRandomNumberGauss(Float mu=0.0, Float sigma=1.0)
Definition: Random.h:209
std::string getClassNameForCorrespondingClassLabel(const UINT classLabel) const
bool setClassNameForCorrespondingClassLabel(std::string className, UINT classLabel)
Vector< UINT > getClassLabels() const
bool loadDatasetFromCSVFile(const std::string &filename, const UINT classLabelColumnIndex=0)
bool setAllowNullGestureClass(bool allowNullGestureClass)
UINT getMinimumClassLabel() const
Vector< MatrixFloat > getHistogramData(const UINT numBins) const
UINT removeClass(const UINT classLabel)
ClassificationData(UINT numDimensions=0, std::string datasetName="NOT_SET", std::string infoText="")
bool setAllValues(const T &value)
Definition: Matrix.h:336
bool setInputAndTargetDimensions(const UINT numInputDimensions, const UINT numTargetDimensions)
bool setInfoText(std::string infoText)
Vector< UINT > getNumSamplesPerClass() const
MatrixFloat getCovarianceMatrix() const
UnlabelledData reformatAsUnlabelledData() const
bool removeSample(const UINT index)
UINT getNumSamples() const
bool spiltDataIntoKFolds(const UINT K, const bool useStratifiedSampling=false)
bool save(const std::string &filename) const
bool setNumDimensions(const UINT numDimensions)
bool enableExternalRangeScaling(const bool useExternalRanges)
bool setExternalRanges(const Vector< MinMax > &externalRanges, const bool useExternalRanges=false)
bool reserve(const UINT N)
bool saveDatasetToCSVFile(const std::string &filename) const
unsigned int getNumRows() const
Definition: Matrix.h:542
UINT getNumDimensions() const
UINT getNumClasses() const
unsigned int getNumCols() const
Definition: Matrix.h:549
bool start()
Definition: Timer.h:64
Vector< MinMax > getRanges() const
Float getRandomNumberUniform(Float minRange=0.0, Float maxRange=1.0)
Definition: Random.h:198
bool merge(const ClassificationData &data)
ClassificationData split(const UINT splitPercentage, const bool useStratifiedSampling=false)
VectorFloat getStdDev() const
Vector< UINT > getClassDataIndexes(const UINT classLabel) const
int getRandomNumberInt(int minRange, int maxRange)
Definition: Random.h:88
MatrixFloat getDataAsMatrixFloat() const
static bool stringEndsWith(const std::string &str, const std::string &ending)
Definition: Util.cpp:157
UINT getClassLabelIndexValue(const UINT classLabel) const
ClassificationData getBootstrappedDataset(UINT numSamples=0, bool balanceDataset=false) const
MatrixFloat getClassHistogramData(const UINT classLabel, const UINT numBins) const
ClassificationData getTrainingFoldData(const UINT foldIndex) const
UINT getMaximumClassLabel() const
bool scale(const Float minTarget, const Float maxTarget)
bool load(const std::string &filename)
MatrixFloat getClassStdDev() const
bool addSample(const VectorFloat &inputVector, const VectorFloat &targetVector)
std::string getStatsAsString() const
VectorFloat getMean() const