GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
ClassificationDataStream.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 */
20 
21 #define GRT_DLL_EXPORTS
23 
24 GRT_BEGIN_NAMESPACE
25 
26 //Constructors and Destructors
27 ClassificationDataStream::ClassificationDataStream(const UINT numDimensions,const std::string datasetName,const std::string infoText){
28 
29  this->numDimensions= numDimensions;
30  this->datasetName = datasetName;
31  this->infoText = infoText;
32 
33  playbackIndex = 0;
34  trackingClass = false;
35  useExternalRanges = false;
36  debugLog.setKey("[DEBUG ClassificationDataStream]");
37  errorLog.setKey("[ERROR ClassificationDataStream]");
38  warningLog.setKey("[WARNING ClassificationDataStream]");
39 
40  if( numDimensions > 0 ){
41  setNumDimensions(numDimensions);
42  }
43 }
44 
46  *this = rhs;
47 }
48 
50 
52  if( this != &rhs){
53  this->datasetName = rhs.datasetName;
54  this->infoText = rhs.infoText;
55  this->numDimensions = rhs.numDimensions;
56  this->totalNumSamples = rhs.totalNumSamples;
57  this->lastClassID = rhs.lastClassID;
58  this->playbackIndex = rhs.playbackIndex;
59  this->trackingClass = rhs.trackingClass;
61  this->externalRanges = rhs.externalRanges;
62  this->data = rhs.data;
63  this->classTracker = rhs.classTracker;
64  this->timeSeriesPositionTracker = rhs.timeSeriesPositionTracker;
65  this->debugLog = rhs.debugLog;
66  this->warningLog = rhs.warningLog;
67  this->errorLog = rhs.errorLog;
68 
69  }
70  return *this;
71 }
72 
74  totalNumSamples = 0;
75  playbackIndex = 0;
76  trackingClass = false;
77  data.clear();
78  classTracker.clear();
79  timeSeriesPositionTracker.clear();
80 }
81 
83  if( numDimensions > 0 ){
84  //Clear any previous data
85  clear();
86 
87  //Set the dimensionality of the time series data
88  this->numDimensions = numDimensions;
89 
90  return true;
91  }
92 
93  errorLog << "setNumDimensions(const UINT numDimensions) - The number of dimensions of the dataset must be greater than zero!" << std::endl;
94  return false;
95 }
96 
97 
99 
100  //Make sure there are no spaces in the std::string
101  if( datasetName.find(" ") == std::string::npos ){
102  this->datasetName = datasetName;
103  return true;
104  }
105 
106  errorLog << "setDatasetName(const std::string datasetName) - The dataset name cannot contain any spaces!" << std::endl;
107  return false;
108 }
109 
111  this->infoText = infoText;
112  return true;
113 }
114 
115 bool ClassificationDataStream::setClassNameForCorrespondingClassLabel(const std::string className,const UINT classLabel){
116 
117  for(UINT i=0; i<classTracker.size(); i++){
118  if( classTracker[i].classLabel == classLabel ){
119  classTracker[i].className = className;
120  return true;
121  }
122  }
123 
124  errorLog << "setClassNameForCorrespondingClassLabel(const std::string className,const UINT classLabel) - Failed to find class with label: " << classLabel << std::endl;
125  return false;
126 }
127 
128 bool ClassificationDataStream::addSample(const UINT classLabel,const VectorFloat &sample){
129 
130  if( numDimensions != sample.size() ){
131  errorLog << "addSample(const UINT classLabel, VectorFloat sample) - the size of the new sample (" << sample.size() << ") does not match the number of dimensions of the dataset (" << numDimensions << ")" << std::endl;
132  return false;
133  }
134 
135  bool searchForNewClass = true;
136  if( trackingClass ){
137  if( classLabel != lastClassID ){
138  //The class ID has changed so update the time series tracker
139  timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].setEndIndex( totalNumSamples-1 );
140  }else searchForNewClass = false;
141  }
142 
143  if( searchForNewClass ){
144  bool newClass = true;
145  //Search to see if this class has been found before
146  for(UINT k=0; k<classTracker.size(); k++){
147  if( classTracker[k].classLabel == classLabel ){
148  newClass = false;
149  classTracker[k].counter++;
150  }
151  }
152  if( newClass ){
153  ClassTracker newCounter(classLabel,1);
154  classTracker.push_back( newCounter );
155  }
156 
157  //Set the timeSeriesPositionTracker start position
158  trackingClass = true;
159  lastClassID = classLabel;
160  TimeSeriesPositionTracker newTracker(totalNumSamples,0,classLabel);
161  timeSeriesPositionTracker.push_back( newTracker );
162  }
163 
164  ClassificationSample labelledSample(classLabel,sample);
165  data.push_back( labelledSample );
166  totalNumSamples++;
167  return true;
168 }
169 
170 bool ClassificationDataStream::addSample(const UINT classLabel,const MatrixFloat &sample){
171 
172  if( numDimensions != sample.getNumCols() ){
173  errorLog << "addSample(const UINT classLabel, const MatrixFloat &sample) - the number of columns in the sample (" << sample.getNumCols() << ") does not match the number of dimensions of the dataset (" << numDimensions << ")" << std::endl;
174  return false;
175  }
176 
177  bool searchForNewClass = true;
178  if( trackingClass ){
179  if( classLabel != lastClassID ){
180  //The class ID has changed so update the time series tracker
181  timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].setEndIndex( totalNumSamples-1 );
182  }else searchForNewClass = false;
183  }
184 
185  if( searchForNewClass ){
186  bool newClass = true;
187  //Search to see if this class has been found before
188  for(UINT k=0; k<classTracker.size(); k++){
189  if( classTracker[k].classLabel == classLabel ){
190  newClass = false;
191  classTracker[k].counter += sample.getNumRows();
192  }
193  }
194  if( newClass ){
195  ClassTracker newCounter(classLabel,1);
196  classTracker.push_back( newCounter );
197  }
198 
199  //Set the timeSeriesPositionTracker start position
200  trackingClass = true;
201  lastClassID = classLabel;
202  TimeSeriesPositionTracker newTracker(totalNumSamples,0,classLabel);
203  timeSeriesPositionTracker.push_back( newTracker );
204  }
205 
206  ClassificationSample labelledSample( numDimensions );
207  for(UINT i=0; i<sample.getNumRows(); i++){
208  data.push_back( labelledSample );
209  data.back().setClassLabel( classLabel );
210  for(UINT j=0; j<numDimensions; j++){
211  data.back()[j] = sample[i][j];
212  }
213  }
214  totalNumSamples += sample.getNumRows();
215  return true;
216 
217 }
218 
220 
221  if( totalNumSamples > 0 ){
222 
223  //Find the corresponding class ID for the last training example
224  UINT classLabel = data[ totalNumSamples-1 ].getClassLabel();
225 
226  //Remove the training example from the buffer
227  data.erase( data.end()-1 );
228 
229  totalNumSamples = (UINT)data.size();
230 
231  //Remove the value from the counter
232  for(UINT i=0; i<classTracker.size(); i++){
233  if( classTracker[i].classLabel == classLabel ){
234  classTracker[i].counter--;
235  break;
236  }
237  }
238 
239  //If we are not tracking a class then decrement the end index of the timeseries position tracker
240  if( !trackingClass ){
241  UINT endIndex = timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].getEndIndex();
242  timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].setEndIndex( endIndex-1 );
243  }
244 
245  return true;
246 
247  }else return false;
248 
249 }
250 
252  UINT numExamplesRemoved = 0;
253  UINT numExamplesToRemove = 0;
254 
255  //Find out how many training examples we need to remove
256  for(UINT i=0; i<classTracker.size(); i++){
257  if( classTracker[i].classLabel == classLabel ){
258  numExamplesToRemove = classTracker[i].counter;
259  classTracker.erase(classTracker.begin()+i);
260  break; //There should only be one class with this classLabel so break
261  }
262  }
263 
264  //Remove the samples with the matching class ID
265  if( numExamplesToRemove > 0 ){
266  UINT i=0;
267  while( numExamplesRemoved < numExamplesToRemove ){
268  if( data[i].getClassLabel() == classLabel ){
269  data.erase(data.begin()+i);
270  numExamplesRemoved++;
271  }else if( ++i == data.size() ) break;
272  }
273  }
274 
275  //Update the time series position tracker
276  Vector< TimeSeriesPositionTracker >::iterator iter = timeSeriesPositionTracker.begin();
277 
278  while( iter != timeSeriesPositionTracker.end() ){
279  if( iter->getClassLabel() == classLabel ){
280  UINT length = iter->getLength();
281  //Update the start and end positions of all the following position trackers
282  Vector< TimeSeriesPositionTracker >::iterator updateIter = iter + 1;
283 
284  while( updateIter != timeSeriesPositionTracker.end() ){
285  updateIter->setStartIndex( updateIter->getStartIndex() - length );
286  updateIter->setEndIndex( updateIter->getEndIndex() - length );
287  updateIter++;
288  }
289 
290  //Erase the current position tracker
291  iter = timeSeriesPositionTracker.erase( iter );
292  }else iter++;
293  }
294 
295  totalNumSamples = (UINT)data.size();
296 
297  return numExamplesRemoved;
298 }
299 
300 bool ClassificationDataStream::relabelAllSamplesWithClassLabel(const UINT oldClassLabel,const UINT newClassLabel){
301  bool oldClassLabelFound = false;
302  bool newClassLabelAllReadyExists = false;
303  UINT indexOfOldClassLabel = 0;
304  UINT indexOfNewClassLabel = 0;
305 
306  //Find out how many training examples we need to relabel
307  for(UINT i=0; i<classTracker.size(); i++){
308  if( classTracker[i].classLabel == oldClassLabel ){
309  indexOfOldClassLabel = i;
310  oldClassLabelFound = true;
311  }
312  if( classTracker[i].classLabel == newClassLabel ){
313  indexOfNewClassLabel = i;
314  newClassLabelAllReadyExists = true;
315  }
316  }
317 
318  //If the old class label was not found then we can't do anything
319  if( !oldClassLabelFound ){
320  return false;
321  }
322 
323  //Relabel the old class labels
324  for(UINT i=0; i<totalNumSamples; i++){
325  if( data[i].getClassLabel() == oldClassLabel ){
326  data[i].set(newClassLabel, data[i].getSample());
327  }
328  }
329 
330  //Update the class label counters
331  if( newClassLabelAllReadyExists ){
332  //Add the old sample count to the new sample count
333  classTracker[ indexOfNewClassLabel ].counter += classTracker[ indexOfOldClassLabel ].counter;
334 
335  //Erase the old class tracker
336  classTracker.erase( classTracker.begin() + indexOfOldClassLabel );
337  }else{
338  //Create a new class tracker
339  classTracker.push_back( ClassTracker(newClassLabel,classTracker[ indexOfOldClassLabel ].counter,classTracker[ indexOfOldClassLabel ].className) );
340  }
341 
342  //Update the timeseries position tracker
343  for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
344  if( timeSeriesPositionTracker[i].getClassLabel() == oldClassLabel ){
345  timeSeriesPositionTracker[i].setClassLabel( newClassLabel );
346  }
347  }
348 
349  return true;
350 }
351 
353 
354  if( externalRanges.size() != numDimensions ) return false;
355 
356  this->externalRanges = externalRanges;
357  this->useExternalRanges = useExternalRanges;
358 
359  return true;
360 }
361 
363  if( externalRanges.size() == numDimensions ){
364  this->useExternalRanges = useExternalRanges;
365  return true;
366  }
367  return false;
368 }
369 
370 bool ClassificationDataStream::scale(const Float minTarget,const Float maxTarget){
371  Vector< MinMax > ranges = getRanges();
372  return scale(ranges,minTarget,maxTarget);
373 }
374 
375 bool ClassificationDataStream::scale(const Vector<MinMax> &ranges,const Float minTarget,const Float maxTarget){
376  if( ranges.size() != numDimensions ) return false;
377 
378  //Scale the training data
379  for(UINT i=0; i<totalNumSamples; i++){
380  for(UINT j=0; j<numDimensions; j++){
381  data[i][j] = Util::scale(data[i][j],ranges[j].minValue,ranges[j].maxValue,minTarget,maxTarget);
382  }
383  }
384  return true;
385 }
386 
387 bool ClassificationDataStream::resetPlaybackIndex(const UINT playbackIndex){
388  if( playbackIndex < totalNumSamples ){
389  this->playbackIndex = playbackIndex;
390  return true;
391  }
392  return false;
393 }
394 
396  if( totalNumSamples == 0 ) return ClassificationSample();
397 
398  UINT index = playbackIndex++ % totalNumSamples;
399  return data[ index ];
400 }
401 
404  for(UINT x=0; x<timeSeriesPositionTracker.size(); x++){
405  if( timeSeriesPositionTracker[x].getClassLabel() == classLabel && timeSeriesPositionTracker[x].getEndIndex() > 0){
406  Matrix<Float> timeSeries;
407  for(UINT i=timeSeriesPositionTracker[x].getStartIndex(); i<timeSeriesPositionTracker[x].getEndIndex(); i++){
408  timeSeries.push_back( data[ i ].getSample() );
409  }
410  classData.addSample(classLabel,timeSeries);
411  }
412  }
413  return classData;
414 }
415 
417  UINT minClassLabel = 99999;
418 
419  for(UINT i=0; i<classTracker.size(); i++){
420  if( classTracker[i].classLabel < minClassLabel ){
421  minClassLabel = classTracker[i].classLabel;
422  }
423  }
424 
425  return minClassLabel;
426 }
427 
428 
430  UINT maxClassLabel = 0;
431 
432  for(UINT i=0; i<classTracker.size(); i++){
433  if( classTracker[i].classLabel > maxClassLabel ){
434  maxClassLabel = classTracker[i].classLabel;
435  }
436  }
437 
438  return maxClassLabel;
439 }
440 
441 UINT ClassificationDataStream::getClassLabelIndexValue(const UINT classLabel) const {
442  for(UINT k=0; k<classTracker.size(); k++){
443  if( classTracker[k].classLabel == classLabel ){
444  return k;
445  }
446  }
447  warningLog << "getClassLabelIndexValue(const UINT classLabel) - Failed to find class label: " << classLabel << " in class tracker!" << std::endl;
448  return 0;
449 }
450 
452 
453  for(UINT i=0; i<classTracker.size(); i++){
454  if( classTracker[i].classLabel == classLabel ){
455  return classTracker[i].className;
456  }
457  }
458  return "CLASS_LABEL_NOT_FOUND";
459 }
460 
462 
464 
465  //If the dataset should be scaled using the external ranges then return the external ranges
466  if( useExternalRanges ) return externalRanges;
467 
468  //Otherwise return the min and max values for each column in the dataset
469  if( totalNumSamples > 0 ){
470  for(UINT j=0; j<numDimensions; j++){
471  ranges[j].minValue = data[0][0];
472  ranges[j].maxValue = data[0][0];
473  for(UINT i=0; i<totalNumSamples; i++){
474  if( data[i][j] < ranges[j].minValue ){ ranges[j].minValue = data[i][j]; } //Search for the min value
475  else if( data[i][j] > ranges[j].maxValue ){ ranges[j].maxValue = data[i][j]; } //Search for the max value
476  }
477  }
478  }
479  return ranges;
480 }
481 
482 bool ClassificationDataStream::save(const std::string &filename){
483 
484  //Check if the file should be saved as a csv file
485  if( Util::stringEndsWith( filename, ".csv" ) ){
486  return saveDatasetToCSVFile( filename );
487  }
488 
489  //Otherwise save it as a custom GRT file
490  return saveDatasetToFile( filename );
491 }
492 
493 bool ClassificationDataStream::load(const std::string &filename){
494 
495  //Check if the file should be loaded as a csv file
496  if( Util::stringEndsWith( filename, ".csv" ) ){
497  return loadDatasetFromCSVFile( filename );
498  }
499 
500  //Otherwise save it as a custom GRT file
501  return loadDatasetFromFile( filename );
502 }
503 
504 bool ClassificationDataStream::saveDatasetToFile(const std::string &filename) {
505 
506  std::fstream file;
507  file.open(filename.c_str(), std::ios::out);
508 
509  if( !file.is_open() ){
510  errorLog << "saveDatasetToFile(const std::string &filename) - Failed to open file!" << std::endl;
511  return false;
512  }
513 
514  if( trackingClass ){
515  //The class tracker was not stopped so assume the last sample is the end
516  trackingClass = false;
517  timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].setEndIndex( totalNumSamples-1 );
518  }
519 
520  file << "GRT_LABELLED_CONTINUOUS_TIME_SERIES_CLASSIFICATION_FILE_V1.0\n";
521  file << "DatasetName: " << datasetName << std::endl;
522  file << "InfoText: " << infoText << std::endl;
523  file << "NumDimensions: " << numDimensions << std::endl;
524  file << "TotalNumSamples: " << totalNumSamples << std::endl;
525  file << "NumberOfClasses: " << classTracker.size() << std::endl;
526  file << "ClassIDsAndCounters: " << std::endl;
527  for(UINT i=0; i<classTracker.size(); i++){
528  file << classTracker[i].classLabel << "\t" << classTracker[i].counter << std::endl;
529  }
530 
531  file << "NumberOfPositionTrackers: " << timeSeriesPositionTracker.size() << std::endl;
532  file << "TimeSeriesPositionTrackers: " << std::endl;
533  for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
534  file << timeSeriesPositionTracker[i].getClassLabel() << "\t" << timeSeriesPositionTracker[i].getStartIndex() << "\t" << timeSeriesPositionTracker[i].getEndIndex() << std::endl;
535  }
536 
537  file << "UseExternalRanges: " << useExternalRanges << std::endl;
538 
539  if( useExternalRanges ){
540  for(UINT i=0; i<externalRanges.size(); i++){
541  file << externalRanges[i].minValue << "\t" << externalRanges[i].maxValue << std::endl;
542  }
543  }
544 
545  file << "LabelledContinuousTimeSeriesClassificationData:\n";
546  for(UINT i=0; i<totalNumSamples; i++){
547  file << data[i].getClassLabel();
548  for(UINT j=0; j<numDimensions; j++){
549  file << "\t" << data[i][j];
550  }
551  file << std::endl;
552  }
553 
554  file.close();
555  return true;
556 }
557 
558 bool ClassificationDataStream::loadDatasetFromFile(const std::string &filename){
559 
560  std::fstream file;
561  file.open(filename.c_str(), std::ios::in);
562  UINT numClasses = 0;
563  UINT numTrackingPoints = 0;
564  clear();
565 
566  if( !file.is_open() ){
567  errorLog<< "loadDatasetFromFile(string fileName) - Failed to open file!" << std::endl;
568  return false;
569  }
570 
571  std::string word;
572 
573  //Check to make sure this is a file with the Training File Format
574  file >> word;
575  if(word != "GRT_LABELLED_CONTINUOUS_TIME_SERIES_CLASSIFICATION_FILE_V1.0"){
576  file.close();
577  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find file header!" << std::endl;
578  return false;
579  }
580 
581  //Get the name of the dataset
582  file >> word;
583  if(word != "DatasetName:"){
584  errorLog << "loadDatasetFromFile(string filename) - failed to find DatasetName!" << std::endl;
585  file.close();
586  return false;
587  }
588  file >> datasetName;
589 
590  file >> word;
591  if(word != "InfoText:"){
592  errorLog << "loadDatasetFromFile(string filename) - failed to find InfoText!" << std::endl;
593  file.close();
594  return false;
595  }
596 
597  //Load the info text
598  file >> word;
599  infoText = "";
600  while( word != "NumDimensions:" ){
601  infoText += word + " ";
602  file >> word;
603  }
604 
605  //Get the number of dimensions in the training data
606  if(word != "NumDimensions:"){
607  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find NumDimensions!" << std::endl;
608  file.close();
609  return false;
610  }
611  file >> numDimensions;
612 
613  //Get the total number of training examples in the training data
614  file >> word;
615  if(word != "TotalNumSamples:"){
616  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find TotalNumSamples!" << std::endl;
617  file.close();
618  return false;
619  }
620  file >> totalNumSamples;
621 
622  //Get the total number of classes in the training data
623  file >> word;
624  if(word != "NumberOfClasses:"){
625  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find NumberOfClasses!" << std::endl;
626  file.close();
627  return false;
628  }
629  file >> numClasses;
630 
631  //Resize the class counter buffer and load the counters
632  classTracker.resize(numClasses);
633 
634  //Get the total number of classes in the training data
635  file >> word;
636  if(word != "ClassIDsAndCounters:"){
637  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find ClassIDsAndCounters!" << std::endl;
638  file.close();
639  return false;
640  }
641 
642  for(UINT i=0; i<classTracker.size(); i++){
643  file >> classTracker[i].classLabel;
644  file >> classTracker[i].counter;
645  }
646 
647  //Get the NumberOfPositionTrackers
648  file >> word;
649  if(word != "NumberOfPositionTrackers:"){
650  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find NumberOfPositionTrackers!" << std::endl;
651  file.close();
652  return false;
653  }
654  file >> numTrackingPoints;
655  timeSeriesPositionTracker.resize( numTrackingPoints );
656 
657  //Get the TimeSeriesPositionTrackers
658  file >> word;
659  if(word != "TimeSeriesPositionTrackers:"){
660  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find TimeSeriesPositionTrackers!" << std::endl;
661  file.close();
662  return false;
663  }
664 
665  for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
666  UINT classLabel;
667  UINT startIndex;
668  UINT endIndex;
669  file >> classLabel;
670  file >> startIndex;
671  file >> endIndex;
672  timeSeriesPositionTracker[i].setTracker(startIndex,endIndex,classLabel);
673  }
674 
675  //Check if the dataset should be scaled using external ranges
676  file >> word;
677  if(word != "UseExternalRanges:"){
678  errorLog << "loadDatasetFromFile(string filename) - failed to find DatasetName!" << std::endl;
679  file.close();
680  return false;
681  }
682  file >> useExternalRanges;
683 
684  //If we are using external ranges then load them
685  if( useExternalRanges ){
686  externalRanges.resize(numDimensions);
687  for(UINT i=0; i<externalRanges.size(); i++){
688  file >> externalRanges[i].minValue;
689  file >> externalRanges[i].maxValue;
690  }
691  }
692 
693  //Get the main time series data
694  file >> word;
695  if(word != "LabelledContinuousTimeSeriesClassificationData:"){
696  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find LabelledContinuousTimeSeriesClassificationData!" << std::endl;
697  file.close();
698  return false;
699  }
700 
701  //Reset the memory
702  data.resize( totalNumSamples, ClassificationSample() );
703 
704  //Load each sample
705  UINT classLabel = 0;
706  VectorFloat sample(numDimensions);
707  for(UINT i=0; i<totalNumSamples; i++){
708 
709  file >> classLabel;
710  for(UINT j=0; j<numDimensions; j++){
711  file >> sample[j];
712  }
713 
714  data[i].set(classLabel,sample);
715  }
716 
717  file.close();
718  return true;
719 }
720 
721 bool ClassificationDataStream::saveDatasetToCSVFile(const std::string &filename) {
722  std::fstream file;
723  file.open(filename.c_str(), std::ios::out );
724 
725  if( !file.is_open() ){
726  return false;
727  }
728 
729  //Write the data to the CSV file
730 
731  for(UINT i=0; i<data.size(); i++){
732  file << data[i].getClassLabel();
733  for(UINT j=0; j<numDimensions; j++){
734  file << "," << data[i][j];
735  }
736  file << std::endl;
737  }
738 
739  file.close();
740 
741  return true;
742 }
743 
744 bool ClassificationDataStream::loadDatasetFromCSVFile(const std::string &filename,const UINT classLabelColumnIndex){
745 
746  datasetName = "NOT_SET";
747  infoText = "";
748 
749  //Clear any previous data
750  clear();
751 
752  //Parse the CSV file
753  FileParser parser;
754 
755  if( !parser.parseCSVFile(filename,true) ){
756  errorLog << "loadDatasetFromCSVFile(const std::string filename,const UINT classLabelColumnIndex) - Failed to parse CSV file!" << std::endl;
757  return false;
758  }
759 
760  if( !parser.getConsistentColumnSize() ){
761  errorLog << "loadDatasetFromCSVFile(const std::string filename,const UINT classLabelColumnIndex) - The CSV file does not have a consistent number of columns!" << std::endl;
762  return false;
763  }
764 
765  if( parser.getColumnSize() <= 1 ){
766  errorLog << "loadDatasetFromCSVFile(const std::string filename,const UINT classLabelColumnIndex) - The CSV file does not have enough columns! It should contain at least two columns!" << std::endl;
767  return false;
768  }
769 
770  //Set the number of dimensions
771  numDimensions = parser.getColumnSize()-1;
772  UINT classLabel = 0;
773  UINT j = 0;
774  UINT n = 0;
775  VectorFloat sample(numDimensions);
776  for(UINT i=0; i<parser.getRowSize(); i++){
777  //Get the class label
778  classLabel = Util::stringToInt( parser[i][classLabelColumnIndex] );
779 
780  //Get the sample data
781  j=0;
782  n=0;
783  while( j != numDimensions ){
784  if( n != classLabelColumnIndex ){
785  sample[j++] = Util::stringToFloat( parser[i][n] );
786  }
787  n++;
788  }
789 
790  //Add the labelled sample to the dataset
791  if( !addSample(classLabel, sample) ){
792  warningLog << "loadDatasetFromCSVFile(const std::string filename,const UINT classLabelColumnIndex) - Could not add sample " << i << " to the dataset!" << std::endl;
793  }
794  }
795 
796  return true;
797 }
798 
800 
801  std::cout << "DatasetName:\t" << datasetName << std::endl;
802  std::cout << "DatasetInfo:\t" << infoText << std::endl;
803  std::cout << "Number of Dimensions:\t" << numDimensions << std::endl;
804  std::cout << "Number of Samples:\t" << totalNumSamples << std::endl;
805  std::cout << "Number of Classes:\t" << getNumClasses() << std::endl;
806  std::cout << "ClassStats:\n";
807 
808  for(UINT k=0; k<getNumClasses(); k++){
809  std::cout << "ClassLabel:\t" << classTracker[k].classLabel;
810  std::cout << "\tNumber of Samples:\t" << classTracker[k].counter;
811  std::cout << "\tClassName:\t" << classTracker[k].className << std::endl;
812  }
813 
814  std::cout << "TimeSeriesMarkerStats:\n";
815  for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
816  std::cout << "ClassLabel: " << timeSeriesPositionTracker[i].getClassLabel();
817  std::cout << "\tStartIndex: " << timeSeriesPositionTracker[i].getStartIndex();
818  std::cout << "\tEndIndex: " << timeSeriesPositionTracker[i].getEndIndex();
819  std::cout << "\tLength: " << timeSeriesPositionTracker[i].getLength() << std::endl;
820  }
821 
822  Vector< MinMax > ranges = getRanges();
823 
824  std::cout << "Dataset Ranges:\n";
825  for(UINT j=0; j<ranges.size(); j++){
826  std::cout << "[" << j+1 << "] Min:\t" << ranges[j].minValue << "\tMax: " << ranges[j].maxValue << std::endl;
827  }
828 
829  return true;
830 }
831 
//Copies the samples in the index range [startIndex endIndex] (both ends inclusive) into a new stream and returns it.
//A warning is logged and `subset` is returned without any samples added when endIndex is not smaller
//than totalNumSamples or when startIndex is not smaller than endIndex (so a one-sample range is rejected).
//NOTE(review): the declaration of `subset` is not visible in this listing (the listing's line numbering skips
//lines here) - confirm how it is constructed against the original source file.
832 ClassificationDataStream ClassificationDataStream::getSubset(const UINT startIndex,const UINT endIndex) const {
833 
835 
836  if( endIndex >= totalNumSamples ){
837  warningLog << "getSubset(const UINT startIndex,const UINT endIndex) - The endIndex is greater than or equal to the number of samples in the current dataset!" << std::endl;
838  return subset;
839  }
840 
841  if( startIndex >= endIndex ){
842  warningLog << "getSubset(const UINT startIndex,const UINT endIndex) - The startIndex is greater than or equal to the endIndex!" << std::endl;
843  return subset;
844  }
845 
846  //Set the header info
//The dataset name and info text are copied from this stream through its getters
848  subset.setDatasetName( getDatasetName() );
849  subset.setInfoText( getInfoText() );
850 
851  //Add the data
//The loop bound is inclusive, so the sample at endIndex is part of the subset
852  for(UINT i=startIndex; i<=endIndex; i++){
853  subset.addSample(data[i].getClassLabel(), data[i].getSample());
854  }
855 
856  return subset;
857 }
858 
860 
862 
864  tsData.setAllowNullGestureClass( includeNullGestures );
865 
866  bool addSample = false;
867  const UINT numTimeseries = (UINT)timeSeriesPositionTracker.size();
868  for(UINT i=0; i<numTimeseries; i++){
869  addSample = includeNullGestures ? true : timeSeriesPositionTracker[i].getClassLabel() != GRT_DEFAULT_NULL_CLASS_LABEL;
870  if( addSample ){
871  tsData.addSample(timeSeriesPositionTracker[i].getClassLabel(), getTimeSeriesData( timeSeriesPositionTracker[i] ) );
872  }
873  }
874 
875  return tsData;
876 }
877 
879 
880  ClassificationData classificationData;
881 
882  classificationData.setNumDimensions( getNumDimensions() );
883  classificationData.setAllowNullGestureClass( includeNullGestures );
884 
885  bool addSample = false;
886  for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
887  addSample = includeNullGestures ? true : timeSeriesPositionTracker[i].getClassLabel() != GRT_DEFAULT_NULL_CLASS_LABEL;
888  if( addSample ){
889  MatrixFloat dataSegment = getTimeSeriesData( timeSeriesPositionTracker[i] );
890  for(UINT j=0; j<dataSegment.getNumRows(); j++){
891  classificationData.addSample(timeSeriesPositionTracker[i].getClassLabel(), dataSegment.getRow(j) );
892  }
893  }
894  }
895 
896  return classificationData;
897 }
898 
900 
901  if( trackerInfo.getStartIndex() >= totalNumSamples || trackerInfo.getEndIndex() > totalNumSamples ){
902  warningLog << "getTimeSeriesData(TimeSeriesPositionTracker trackerInfo) - Invalid tracker indexs!" << std::endl;
903  return MatrixFloat();
904  }
905 
906  UINT startIndex = trackerInfo.getStartIndex();
907  UINT endIndex = trackerInfo.getEndIndex();
908  UINT M = endIndex > 0 ? trackerInfo.getLength() : totalNumSamples - startIndex;
909  UINT N = getNumDimensions();
910 
911  MatrixFloat tsData(M,N);
912  for(UINT i=0; i<M; i++){
913  for(UINT j=0; j<N; j++){
914  tsData[i][j] = data[ i+startIndex ][j];
915  }
916  }
917  return tsData;
918 }
919 
921  UINT M = getNumSamples();
922  UINT N = getNumDimensions();
923  MatrixFloat matrixData(M,N);
924  for(UINT i=0; i<M; i++){
925  for(UINT j=0; j<N; j++){
926  matrixData[i][j] = data[i][j];
927  }
928  }
929  return matrixData;
930 }
931 
933  const UINT K = (UINT)classTracker.size();
934  Vector< UINT > classLabels( K );
935 
936  for(UINT i=0; i<K; i++){
937  classLabels[i] = classTracker[i].classLabel;
938  }
939 
940  return classLabels;
941 }
942 
943 GRT_END_NAMESPACE
944 
bool loadDatasetFromCSVFile(const std::string &filename, const UINT classLabelColumnIndex=0)
TimeSeriesClassificationData getTimeSeriesClassificationData(const bool includeNullGestures=false) const
UINT eraseAllSamplesWithClassLabel(const UINT classLabel)
bool setAllowNullGestureClass(const bool allowNullGestureClass)
bool enableExternalRangeScaling(const bool useExternalRanges)
bool save(const std::string &filename)
bool addSample(const UINT classLabel, const VectorFloat &sample)
static Float scale(const Float &x, const Float &minSource, const Float &maxSource, const Float &minTarget, const Float &maxTarget, const bool constrain=false)
Definition: Util.cpp:55
bool setDatasetName(const std::string datasetName)
static Float stringToFloat(const std::string &s)
Definition: Util.cpp:140
ClassificationDataStream(const UINT numDimensions=0, const std::string datasetName="NOT_SET", const std::string infoText="")
bool setNumDimensions(const UINT numDimensions)
ClassificationDataStream & operator=(const ClassificationDataStream &rhs)
bool resetPlaybackIndex(const UINT playbackIndex)
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
bool setNumDimensions(UINT numDimensions)
bool setExternalRanges(const Vector< MinMax > &externalRanges, const bool useExternalRanges=false)
virtual bool setKey(const std::string &key)
Sets the key that gets written at the start of each message; this will be written in the format 'key ...
Definition: Log.h:166
std::string getClassNameForCorrespondingClassLabel(const UINT classLabel)
bool saveDatasetToCSVFile(const std::string &filename)
bool scale(const Float minTarget, const Float maxTarget)
bool load(const std::string &filename)
DebugLog debugLog
Default debugging log.
ClassificationData getClassificationData(const bool includeNullGestures=false) const
Vector< UINT > getClassLabels() const
ErrorLog errorLog
Default error log.
bool useExternalRanges
A flag to show if the dataset should be scaled using the externalRanges values.
ClassificationDataStream getSubset(const UINT startIndex, const UINT endIndex) const
bool loadDatasetFromFile(const std::string &filename)
Vector< MinMax > externalRanges
A Vector containing a set of externalRanges set by the user.
ClassificationSample getNextSample()
std::string infoText
Some infoText about the dataset.
bool setInfoText(const std::string infoText)
UINT numDimensions
The number of dimensions in the dataset.
std::string datasetName
The name of the dataset.
The ClassificationDataStream is the main data structure for recording, labeling, managing, saving, and loading datasets that can be used to test the continuous classification abilities of the GRT supervised learning algorithms.
bool relabelAllSamplesWithClassLabel(const UINT oldClassLabel, const UINT newClassLabel)
unsigned int getNumRows() const
Definition: Matrix.h:574
MatrixFloat getDataAsMatrixFloat() const
unsigned int getNumCols() const
Definition: Matrix.h:581
bool addSample(const UINT classLabel, const VectorFloat &sample)
bool saveDatasetToFile(const std::string &filename)
bool addSample(const UINT classLabel, const MatrixFloat &trainingSample)
VectorFloat getRow(const unsigned int r) const
Definition: MatrixFloat.h:107
bool setNumDimensions(const UINT numDimensions)
static bool stringEndsWith(const std::string &str, const std::string &ending)
Definition: Util.cpp:164
MatrixFloat getTimeSeriesData(const TimeSeriesPositionTracker &trackerInfo) const
Vector< MinMax > getRanges() const
This class stores the class label and raw data for a single labelled classification sample...
Definition: Vector.h:41
static int stringToInt(const std::string &s)
Definition: Util.cpp:133
WarningLog warningLog
Default warning log.
TimeSeriesClassificationData getAllTrainingExamplesWithClassLabel(const UINT classLabel) const
bool push_back(const Vector< T > &sample)
Definition: Matrix.h:431
UINT getClassLabelIndexValue(const UINT classLabel) const
bool setClassNameForCorrespondingClassLabel(const std::string className, const UINT classLabel)
bool setAllowNullGestureClass(const bool allowNullGestureClass)
std::string getDatasetName() const