GestureRecognitionToolkit  Version: 0.1.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, C++ machine learning library for real-time gesture recognition.
ClassificationDataStream.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 */
20 
22 
23 GRT_BEGIN_NAMESPACE
24 
25 //Constructors and Destructors
26 ClassificationDataStream::ClassificationDataStream(const UINT numDimensions,const std::string datasetName,const std::string infoText){
27 
28  this->numDimensions= numDimensions;
29  this->datasetName = datasetName;
30  this->infoText = infoText;
31 
32  playbackIndex = 0;
33  trackingClass = false;
34  useExternalRanges = false;
35  debugLog.setProceedingText("[DEBUG ClassificationDataStream]");
36  errorLog.setProceedingText("[ERROR ClassificationDataStream]");
37  warningLog.setProceedingText("[WARNING ClassificationDataStream]");
38 
39  if( numDimensions > 0 ){
40  setNumDimensions(numDimensions);
41  }
42 }
43 
45  *this = rhs;
46 }
47 
49 
51  if( this != &rhs){
52  this->datasetName = rhs.datasetName;
53  this->infoText = rhs.infoText;
54  this->numDimensions = rhs.numDimensions;
55  this->totalNumSamples = rhs.totalNumSamples;
56  this->lastClassID = rhs.lastClassID;
57  this->playbackIndex = rhs.playbackIndex;
58  this->trackingClass = rhs.trackingClass;
60  this->externalRanges = rhs.externalRanges;
61  this->data = rhs.data;
62  this->classTracker = rhs.classTracker;
63  this->timeSeriesPositionTracker = rhs.timeSeriesPositionTracker;
64  this->debugLog = rhs.debugLog;
65  this->warningLog = rhs.warningLog;
66  this->errorLog = rhs.errorLog;
67 
68  }
69  return *this;
70 }
71 
73  totalNumSamples = 0;
74  playbackIndex = 0;
75  trackingClass = false;
76  data.clear();
77  classTracker.clear();
78  timeSeriesPositionTracker.clear();
79 }
80 
81 bool ClassificationDataStream::setNumDimensions(const UINT numDimensions){
82  if( numDimensions > 0 ){
83  //Clear any previous data
84  clear();
85 
86  //Set the dimensionality of the time series data
87  this->numDimensions = numDimensions;
88 
89  return true;
90  }
91 
92  errorLog << "setNumDimensions(const UINT numDimensions) - The number of dimensions of the dataset must be greater than zero!" << std::endl;
93  return false;
94 }
95 
96 
97 bool ClassificationDataStream::setDatasetName(const std::string datasetName){
98 
99  //Make sure there are no spaces in the std::string
100  if( datasetName.find(" ") == std::string::npos ){
101  this->datasetName = datasetName;
102  return true;
103  }
104 
105  errorLog << "setDatasetName(const std::string datasetName) - The dataset name cannot contain any spaces!" << std::endl;
106  return false;
107 }
108 
109 bool ClassificationDataStream::setInfoText(std::string infoText){
110  this->infoText = infoText;
111  return true;
112 }
113 
114 bool ClassificationDataStream::setClassNameForCorrespondingClassLabel(const std::string className,const UINT classLabel){
115 
116  for(UINT i=0; i<classTracker.size(); i++){
117  if( classTracker[i].classLabel == classLabel ){
118  classTracker[i].className = className;
119  return true;
120  }
121  }
122 
123  errorLog << "setClassNameForCorrespondingClassLabel(const std::string className,const UINT classLabel) - Failed to find class with label: " << classLabel << std::endl;
124  return false;
125 }
126 
127 bool ClassificationDataStream::addSample(const UINT classLabel,const VectorFloat &sample){
128 
129  if( numDimensions != sample.size() ){
130  errorLog << "addSample(const UINT classLabel, VectorFloat sample) - the size of the new sample (" << sample.size() << ") does not match the number of dimensions of the dataset (" << numDimensions << ")" << std::endl;
131  return false;
132  }
133 
134  bool searchForNewClass = true;
135  if( trackingClass ){
136  if( classLabel != lastClassID ){
137  //The class ID has changed so update the time series tracker
138  timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].setEndIndex( totalNumSamples-1 );
139  }else searchForNewClass = false;
140  }
141 
142  if( searchForNewClass ){
143  bool newClass = true;
144  //Search to see if this class has been found before
145  for(UINT k=0; k<classTracker.size(); k++){
146  if( classTracker[k].classLabel == classLabel ){
147  newClass = false;
148  classTracker[k].counter++;
149  }
150  }
151  if( newClass ){
152  ClassTracker newCounter(classLabel,1);
153  classTracker.push_back( newCounter );
154  }
155 
156  //Set the timeSeriesPositionTracker start position
157  trackingClass = true;
158  lastClassID = classLabel;
159  TimeSeriesPositionTracker newTracker(totalNumSamples,0,classLabel);
160  timeSeriesPositionTracker.push_back( newTracker );
161  }
162 
163  ClassificationSample labelledSample(classLabel,sample);
164  data.push_back( labelledSample );
165  totalNumSamples++;
166  return true;
167 }
168 
169 bool ClassificationDataStream::addSample(const UINT classLabel,const MatrixFloat &sample){
170 
171  if( numDimensions != sample.getNumCols() ){
172  errorLog << "addSample(const UINT classLabel, const MatrixFloat &sample) - the number of columns in the sample (" << sample.getNumCols() << ") does not match the number of dimensions of the dataset (" << numDimensions << ")" << std::endl;
173  return false;
174  }
175 
176  bool searchForNewClass = true;
177  if( trackingClass ){
178  if( classLabel != lastClassID ){
179  //The class ID has changed so update the time series tracker
180  timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].setEndIndex( totalNumSamples-1 );
181  }else searchForNewClass = false;
182  }
183 
184  if( searchForNewClass ){
185  bool newClass = true;
186  //Search to see if this class has been found before
187  for(UINT k=0; k<classTracker.size(); k++){
188  if( classTracker[k].classLabel == classLabel ){
189  newClass = false;
190  classTracker[k].counter += sample.getNumRows();
191  }
192  }
193  if( newClass ){
194  ClassTracker newCounter(classLabel,1);
195  classTracker.push_back( newCounter );
196  }
197 
198  //Set the timeSeriesPositionTracker start position
199  trackingClass = true;
200  lastClassID = classLabel;
201  TimeSeriesPositionTracker newTracker(totalNumSamples,0,classLabel);
202  timeSeriesPositionTracker.push_back( newTracker );
203  }
204 
205  ClassificationSample labelledSample( numDimensions );
206  for(UINT i=0; i<sample.getNumRows(); i++){
207  data.push_back( labelledSample );
208  data.back().setClassLabel( classLabel );
209  for(UINT j=0; j<numDimensions; j++){
210  data.back()[j] = sample[i][j];
211  }
212  }
213  totalNumSamples += sample.getNumRows();
214  return true;
215 
216 }
217 
219 
220  if( totalNumSamples > 0 ){
221 
222  //Find the corresponding class ID for the last training example
223  UINT classLabel = data[ totalNumSamples-1 ].getClassLabel();
224 
225  //Remove the training example from the buffer
226  data.erase( data.end()-1 );
227 
228  totalNumSamples = (UINT)data.size();
229 
230  //Remove the value from the counter
231  for(UINT i=0; i<classTracker.size(); i++){
232  if( classTracker[i].classLabel == classLabel ){
233  classTracker[i].counter--;
234  break;
235  }
236  }
237 
238  //If we are not tracking a class then decrement the end index of the timeseries position tracker
239  if( !trackingClass ){
240  UINT endIndex = timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].getEndIndex();
241  timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].setEndIndex( endIndex-1 );
242  }
243 
244  return true;
245 
246  }else return false;
247 
248 }
249 
251  UINT numExamplesRemoved = 0;
252  UINT numExamplesToRemove = 0;
253 
254  //Find out how many training examples we need to remove
255  for(UINT i=0; i<classTracker.size(); i++){
256  if( classTracker[i].classLabel == classLabel ){
257  numExamplesToRemove = classTracker[i].counter;
258  classTracker.erase(classTracker.begin()+i);
259  break; //There should only be one class with this classLabel so break
260  }
261  }
262 
263  //Remove the samples with the matching class ID
264  if( numExamplesToRemove > 0 ){
265  UINT i=0;
266  while( numExamplesRemoved < numExamplesToRemove ){
267  if( data[i].getClassLabel() == classLabel ){
268  data.erase(data.begin()+i);
269  numExamplesRemoved++;
270  }else if( ++i == data.size() ) break;
271  }
272  }
273 
274  //Update the time series position tracker
275  Vector< TimeSeriesPositionTracker >::iterator iter = timeSeriesPositionTracker.begin();
276 
277  while( iter != timeSeriesPositionTracker.end() ){
278  if( iter->getClassLabel() == classLabel ){
279  UINT length = iter->getLength();
280  //Update the start and end positions of all the following position trackers
281  Vector< TimeSeriesPositionTracker >::iterator updateIter = iter + 1;
282 
283  while( updateIter != timeSeriesPositionTracker.end() ){
284  updateIter->setStartIndex( updateIter->getStartIndex() - length );
285  updateIter->setEndIndex( updateIter->getEndIndex() - length );
286  updateIter++;
287  }
288 
289  //Erase the current position tracker
290  iter = timeSeriesPositionTracker.erase( iter );
291  }else iter++;
292  }
293 
294  totalNumSamples = (UINT)data.size();
295 
296  return numExamplesRemoved;
297 }
298 
299 bool ClassificationDataStream::relabelAllSamplesWithClassLabel(const UINT oldClassLabel,const UINT newClassLabel){
300  bool oldClassLabelFound = false;
301  bool newClassLabelAllReadyExists = false;
302  UINT indexOfOldClassLabel = 0;
303  UINT indexOfNewClassLabel = 0;
304 
305  //Find out how many training examples we need to relabel
306  for(UINT i=0; i<classTracker.size(); i++){
307  if( classTracker[i].classLabel == oldClassLabel ){
308  indexOfOldClassLabel = i;
309  oldClassLabelFound = true;
310  }
311  if( classTracker[i].classLabel == newClassLabel ){
312  indexOfNewClassLabel = i;
313  newClassLabelAllReadyExists = true;
314  }
315  }
316 
317  //If the old class label was not found then we can't do anything
318  if( !oldClassLabelFound ){
319  return false;
320  }
321 
322  //Relabel the old class labels
323  for(UINT i=0; i<totalNumSamples; i++){
324  if( data[i].getClassLabel() == oldClassLabel ){
325  data[i].set(newClassLabel, data[i].getSample());
326  }
327  }
328 
329  //Update the class label counters
330  if( newClassLabelAllReadyExists ){
331  //Add the old sample count to the new sample count
332  classTracker[ indexOfNewClassLabel ].counter += classTracker[ indexOfOldClassLabel ].counter;
333 
334  //Erase the old class tracker
335  classTracker.erase( classTracker.begin() + indexOfOldClassLabel );
336  }else{
337  //Create a new class tracker
338  classTracker.push_back( ClassTracker(newClassLabel,classTracker[ indexOfOldClassLabel ].counter,classTracker[ indexOfOldClassLabel ].className) );
339  }
340 
341  //Update the timeseries position tracker
342  for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
343  if( timeSeriesPositionTracker[i].getClassLabel() == oldClassLabel ){
344  timeSeriesPositionTracker[i].setClassLabel( newClassLabel );
345  }
346  }
347 
348  return true;
349 }
350 
351 bool ClassificationDataStream::setExternalRanges(const Vector< MinMax > &externalRanges, const bool useExternalRanges){
352 
353  if( externalRanges.size() != numDimensions ) return false;
354 
355  this->externalRanges = externalRanges;
356  this->useExternalRanges = useExternalRanges;
357 
358  return true;
359 }
360 
361 bool ClassificationDataStream::enableExternalRangeScaling(const bool useExternalRanges){
362  if( externalRanges.size() == numDimensions ){
363  this->useExternalRanges = useExternalRanges;
364  return true;
365  }
366  return false;
367 }
368 
369 bool ClassificationDataStream::scale(const Float minTarget,const Float maxTarget){
370  Vector< MinMax > ranges = getRanges();
371  return scale(ranges,minTarget,maxTarget);
372 }
373 
374 bool ClassificationDataStream::scale(const Vector<MinMax> &ranges,const Float minTarget,const Float maxTarget){
375  if( ranges.size() != numDimensions ) return false;
376 
377  //Scale the training data
378  for(UINT i=0; i<totalNumSamples; i++){
379  for(UINT j=0; j<numDimensions; j++){
380  data[i][j] = Util::scale(data[i][j],ranges[j].minValue,ranges[j].maxValue,minTarget,maxTarget);
381  }
382  }
383  return true;
384 }
385 
386 bool ClassificationDataStream::resetPlaybackIndex(const UINT playbackIndex){
387  if( playbackIndex < totalNumSamples ){
388  this->playbackIndex = playbackIndex;
389  return true;
390  }
391  return false;
392 }
393 
395  if( totalNumSamples == 0 ) return ClassificationSample();
396 
397  UINT index = playbackIndex++ % totalNumSamples;
398  return data[ index ];
399 }
400 
403  for(UINT x=0; x<timeSeriesPositionTracker.size(); x++){
404  if( timeSeriesPositionTracker[x].getClassLabel() == classLabel && timeSeriesPositionTracker[x].getEndIndex() > 0){
405  Matrix<Float> timeSeries;
406  for(UINT i=timeSeriesPositionTracker[x].getStartIndex(); i<timeSeriesPositionTracker[x].getEndIndex(); i++){
407  timeSeries.push_back( data[ i ].getSample() );
408  }
409  classData.addSample(classLabel,timeSeries);
410  }
411  }
412  return classData;
413 }
414 
416  UINT minClassLabel = 99999;
417 
418  for(UINT i=0; i<classTracker.size(); i++){
419  if( classTracker[i].classLabel < minClassLabel ){
420  minClassLabel = classTracker[i].classLabel;
421  }
422  }
423 
424  return minClassLabel;
425 }
426 
427 
429  UINT maxClassLabel = 0;
430 
431  for(UINT i=0; i<classTracker.size(); i++){
432  if( classTracker[i].classLabel > maxClassLabel ){
433  maxClassLabel = classTracker[i].classLabel;
434  }
435  }
436 
437  return maxClassLabel;
438 }
439 
440 UINT ClassificationDataStream::getClassLabelIndexValue(const UINT classLabel) const {
441  for(UINT k=0; k<classTracker.size(); k++){
442  if( classTracker[k].classLabel == classLabel ){
443  return k;
444  }
445  }
446  warningLog << "getClassLabelIndexValue(const UINT classLabel) - Failed to find class label: " << classLabel << " in class tracker!" << std::endl;
447  return 0;
448 }
449 
451 
452  for(UINT i=0; i<classTracker.size(); i++){
453  if( classTracker[i].classLabel == classLabel ){
454  return classTracker[i].className;
455  }
456  }
457  return "CLASS_LABEL_NOT_FOUND";
458 }
459 
461 
463 
464  //If the dataset should be scaled using the external ranges then return the external ranges
465  if( useExternalRanges ) return externalRanges;
466 
467  //Otherwise return the min and max values for each column in the dataset
468  if( totalNumSamples > 0 ){
469  for(UINT j=0; j<numDimensions; j++){
470  ranges[j].minValue = data[0][0];
471  ranges[j].maxValue = data[0][0];
472  for(UINT i=0; i<totalNumSamples; i++){
473  if( data[i][j] < ranges[j].minValue ){ ranges[j].minValue = data[i][j]; } //Search for the min value
474  else if( data[i][j] > ranges[j].maxValue ){ ranges[j].maxValue = data[i][j]; } //Search for the max value
475  }
476  }
477  }
478  return ranges;
479 }
480 
481 bool ClassificationDataStream::save(const std::string &filename){
482 
483  //Check if the file should be saved as a csv file
484  if( Util::stringEndsWith( filename, ".csv" ) ){
485  return saveDatasetToCSVFile( filename );
486  }
487 
488  //Otherwise save it as a custom GRT file
489  return saveDatasetToFile( filename );
490 }
491 
492 bool ClassificationDataStream::load(const std::string &filename){
493 
494  //Check if the file should be loaded as a csv file
495  if( Util::stringEndsWith( filename, ".csv" ) ){
496  return loadDatasetFromCSVFile( filename );
497  }
498 
499  //Otherwise save it as a custom GRT file
500  return loadDatasetFromFile( filename );
501 }
502 
503 bool ClassificationDataStream::saveDatasetToFile(const std::string &filename) {
504 
505  std::fstream file;
506  file.open(filename.c_str(), std::ios::out);
507 
508  if( !file.is_open() ){
509  errorLog << "saveDatasetToFile(const std::string &filename) - Failed to open file!" << std::endl;
510  return false;
511  }
512 
513  if( trackingClass ){
514  //The class tracker was not stopped so assume the last sample is the end
515  trackingClass = false;
516  timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].setEndIndex( totalNumSamples-1 );
517  }
518 
519  file << "GRT_LABELLED_CONTINUOUS_TIME_SERIES_CLASSIFICATION_FILE_V1.0\n";
520  file << "DatasetName: " << datasetName << std::endl;
521  file << "InfoText: " << infoText << std::endl;
522  file << "NumDimensions: " << numDimensions << std::endl;
523  file << "TotalNumSamples: " << totalNumSamples << std::endl;
524  file << "NumberOfClasses: " << classTracker.size() << std::endl;
525  file << "ClassIDsAndCounters: " << std::endl;
526  for(UINT i=0; i<classTracker.size(); i++){
527  file << classTracker[i].classLabel << "\t" << classTracker[i].counter << std::endl;
528  }
529 
530  file << "NumberOfPositionTrackers: " << timeSeriesPositionTracker.size() << std::endl;
531  file << "TimeSeriesPositionTrackers: " << std::endl;
532  for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
533  file << timeSeriesPositionTracker[i].getClassLabel() << "\t" << timeSeriesPositionTracker[i].getStartIndex() << "\t" << timeSeriesPositionTracker[i].getEndIndex() << std::endl;
534  }
535 
536  file << "UseExternalRanges: " << useExternalRanges << std::endl;
537 
538  if( useExternalRanges ){
539  for(UINT i=0; i<externalRanges.size(); i++){
540  file << externalRanges[i].minValue << "\t" << externalRanges[i].maxValue << std::endl;
541  }
542  }
543 
544  file << "LabelledContinuousTimeSeriesClassificationData:\n";
545  for(UINT i=0; i<totalNumSamples; i++){
546  file << data[i].getClassLabel();
547  for(UINT j=0; j<numDimensions; j++){
548  file << "\t" << data[i][j];
549  }
550  file << std::endl;
551  }
552 
553  file.close();
554  return true;
555 }
556 
557 bool ClassificationDataStream::loadDatasetFromFile(const std::string &filename){
558 
559  std::fstream file;
560  file.open(filename.c_str(), std::ios::in);
561  UINT numClasses = 0;
562  UINT numTrackingPoints = 0;
563  clear();
564 
565  if( !file.is_open() ){
566  errorLog<< "loadDatasetFromFile(string fileName) - Failed to open file!" << std::endl;
567  return false;
568  }
569 
570  std::string word;
571 
572  //Check to make sure this is a file with the Training File Format
573  file >> word;
574  if(word != "GRT_LABELLED_CONTINUOUS_TIME_SERIES_CLASSIFICATION_FILE_V1.0"){
575  file.close();
576  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find file header!" << std::endl;
577  return false;
578  }
579 
580  //Get the name of the dataset
581  file >> word;
582  if(word != "DatasetName:"){
583  errorLog << "loadDatasetFromFile(string filename) - failed to find DatasetName!" << std::endl;
584  file.close();
585  return false;
586  }
587  file >> datasetName;
588 
589  file >> word;
590  if(word != "InfoText:"){
591  errorLog << "loadDatasetFromFile(string filename) - failed to find InfoText!" << std::endl;
592  file.close();
593  return false;
594  }
595 
596  //Load the info text
597  file >> word;
598  infoText = "";
599  while( word != "NumDimensions:" ){
600  infoText += word + " ";
601  file >> word;
602  }
603 
604  //Get the number of dimensions in the training data
605  if(word != "NumDimensions:"){
606  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find NumDimensions!" << std::endl;
607  file.close();
608  return false;
609  }
610  file >> numDimensions;
611 
612  //Get the total number of training examples in the training data
613  file >> word;
614  if(word != "TotalNumSamples:"){
615  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find TotalNumSamples!" << std::endl;
616  file.close();
617  return false;
618  }
619  file >> totalNumSamples;
620 
621  //Get the total number of classes in the training data
622  file >> word;
623  if(word != "NumberOfClasses:"){
624  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find NumberOfClasses!" << std::endl;
625  file.close();
626  return false;
627  }
628  file >> numClasses;
629 
630  //Resize the class counter buffer and load the counters
631  classTracker.resize(numClasses);
632 
633  //Get the total number of classes in the training data
634  file >> word;
635  if(word != "ClassIDsAndCounters:"){
636  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find ClassIDsAndCounters!" << std::endl;
637  file.close();
638  return false;
639  }
640 
641  for(UINT i=0; i<classTracker.size(); i++){
642  file >> classTracker[i].classLabel;
643  file >> classTracker[i].counter;
644  }
645 
646  //Get the NumberOfPositionTrackers
647  file >> word;
648  if(word != "NumberOfPositionTrackers:"){
649  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find NumberOfPositionTrackers!" << std::endl;
650  file.close();
651  return false;
652  }
653  file >> numTrackingPoints;
654  timeSeriesPositionTracker.resize( numTrackingPoints );
655 
656  //Get the TimeSeriesPositionTrackers
657  file >> word;
658  if(word != "TimeSeriesPositionTrackers:"){
659  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find TimeSeriesPositionTrackers!" << std::endl;
660  file.close();
661  return false;
662  }
663 
664  for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
665  UINT classLabel;
666  UINT startIndex;
667  UINT endIndex;
668  file >> classLabel;
669  file >> startIndex;
670  file >> endIndex;
671  timeSeriesPositionTracker[i].setTracker(startIndex,endIndex,classLabel);
672  }
673 
674  //Check if the dataset should be scaled using external ranges
675  file >> word;
676  if(word != "UseExternalRanges:"){
677  errorLog << "loadDatasetFromFile(string filename) - failed to find DatasetName!" << std::endl;
678  file.close();
679  return false;
680  }
681  file >> useExternalRanges;
682 
683  //If we are using external ranges then load them
684  if( useExternalRanges ){
685  externalRanges.resize(numDimensions);
686  for(UINT i=0; i<externalRanges.size(); i++){
687  file >> externalRanges[i].minValue;
688  file >> externalRanges[i].maxValue;
689  }
690  }
691 
692  //Get the main time series data
693  file >> word;
694  if(word != "LabelledContinuousTimeSeriesClassificationData:"){
695  errorLog<< "loadDatasetFromFile(string fileName) - Failed to find LabelledContinuousTimeSeriesClassificationData!" << std::endl;
696  file.close();
697  return false;
698  }
699 
700  //Reset the memory
701  data.resize( totalNumSamples, ClassificationSample() );
702 
703  //Load each sample
704  UINT classLabel = 0;
705  VectorFloat sample(numDimensions);
706  for(UINT i=0; i<totalNumSamples; i++){
707 
708  file >> classLabel;
709  for(UINT j=0; j<numDimensions; j++){
710  file >> sample[j];
711  }
712 
713  data[i].set(classLabel,sample);
714  }
715 
716  file.close();
717  return true;
718 }
719 
720 bool ClassificationDataStream::saveDatasetToCSVFile(const std::string &filename) {
721  std::fstream file;
722  file.open(filename.c_str(), std::ios::out );
723 
724  if( !file.is_open() ){
725  return false;
726  }
727 
728  //Write the data to the CSV file
729 
730  for(UINT i=0; i<data.size(); i++){
731  file << data[i].getClassLabel();
732  for(UINT j=0; j<numDimensions; j++){
733  file << "," << data[i][j];
734  }
735  file << std::endl;
736  }
737 
738  file.close();
739 
740  return true;
741 }
742 
743 bool ClassificationDataStream::loadDatasetFromCSVFile(const std::string &filename,const UINT classLabelColumnIndex){
744 
745  datasetName = "NOT_SET";
746  infoText = "";
747 
748  //Clear any previous data
749  clear();
750 
751  //Parse the CSV file
752  FileParser parser;
753 
754  if( !parser.parseCSVFile(filename,true) ){
755  errorLog << "loadDatasetFromCSVFile(const std::string filename,const UINT classLabelColumnIndex) - Failed to parse CSV file!" << std::endl;
756  return false;
757  }
758 
759  if( !parser.getConsistentColumnSize() ){
760  errorLog << "loadDatasetFromCSVFile(const std::string filename,const UINT classLabelColumnIndex) - The CSV file does not have a consistent number of columns!" << std::endl;
761  return false;
762  }
763 
764  if( parser.getColumnSize() <= 1 ){
765  errorLog << "loadDatasetFromCSVFile(const std::string filename,const UINT classLabelColumnIndex) - The CSV file does not have enough columns! It should contain at least two columns!" << std::endl;
766  return false;
767  }
768 
769  //Set the number of dimensions
770  numDimensions = parser.getColumnSize()-1;
771  UINT classLabel = 0;
772  UINT j = 0;
773  UINT n = 0;
774  VectorFloat sample(numDimensions);
775  for(UINT i=0; i<parser.getRowSize(); i++){
776  //Get the class label
777  classLabel = Util::stringToInt( parser[i][classLabelColumnIndex] );
778 
779  //Get the sample data
780  j=0;
781  n=0;
782  while( j != numDimensions ){
783  if( n != classLabelColumnIndex ){
784  sample[j++] = Util::stringToFloat( parser[i][n] );
785  }
786  n++;
787  }
788 
789  //Add the labelled sample to the dataset
790  if( !addSample(classLabel, sample) ){
791  warningLog << "loadDatasetFromCSVFile(const std::string filename,const UINT classLabelColumnIndex) - Could not add sample " << i << " to the dataset!" << std::endl;
792  }
793  }
794 
795  return true;
796 }
797 
799 
800  std::cout << "DatasetName:\t" << datasetName << std::endl;
801  std::cout << "DatasetInfo:\t" << infoText << std::endl;
802  std::cout << "Number of Dimensions:\t" << numDimensions << std::endl;
803  std::cout << "Number of Samples:\t" << totalNumSamples << std::endl;
804  std::cout << "Number of Classes:\t" << getNumClasses() << std::endl;
805  std::cout << "ClassStats:\n";
806 
807  for(UINT k=0; k<getNumClasses(); k++){
808  std::cout << "ClassLabel:\t" << classTracker[k].classLabel;
809  std::cout << "\tNumber of Samples:\t" << classTracker[k].counter;
810  std::cout << "\tClassName:\t" << classTracker[k].className << std::endl;
811  }
812 
813  std::cout << "TimeSeriesMarkerStats:\n";
814  for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
815  std::cout << "ClassLabel: " << timeSeriesPositionTracker[i].getClassLabel();
816  std::cout << "\tStartIndex: " << timeSeriesPositionTracker[i].getStartIndex();
817  std::cout << "\tEndIndex: " << timeSeriesPositionTracker[i].getEndIndex();
818  std::cout << "\tLength: " << timeSeriesPositionTracker[i].getLength() << std::endl;
819  }
820 
821  Vector< MinMax > ranges = getRanges();
822 
823  std::cout << "Dataset Ranges:\n";
824  for(UINT j=0; j<ranges.size(); j++){
825  std::cout << "[" << j+1 << "] Min:\t" << ranges[j].minValue << "\tMax: " << ranges[j].maxValue << std::endl;
826  }
827 
828  return true;
829 }
830 
831 ClassificationDataStream ClassificationDataStream::getSubset(const UINT startIndex,const UINT endIndex) const {
832 
834 
835  if( endIndex >= totalNumSamples ){
836  warningLog << "getSubset(const UINT startIndex,const UINT endIndex) - The endIndex is greater than or equal to the number of samples in the current dataset!" << std::endl;
837  return subset;
838  }
839 
840  if( startIndex >= endIndex ){
841  warningLog << "getSubset(const UINT startIndex,const UINT endIndex) - The startIndex is greater than or equal to the endIndex!" << std::endl;
842  return subset;
843  }
844 
845  //Set the header info
847  subset.setDatasetName( getDatasetName() );
848  subset.setInfoText( getInfoText() );
849 
850  //Add the data
851  for(UINT i=startIndex; i<=endIndex; i++){
852  subset.addSample(data[i].getClassLabel(), data[i].getSample());
853  }
854 
855  return subset;
856 }
857 
859 
861 
863  tsData.setAllowNullGestureClass( includeNullGestures );
864 
865  bool addSample = false;
866  const UINT numTimeseries = (UINT)timeSeriesPositionTracker.size();
867  for(UINT i=0; i<numTimeseries; i++){
868  addSample = includeNullGestures ? true : timeSeriesPositionTracker[i].getClassLabel() != GRT_DEFAULT_NULL_CLASS_LABEL;
869  if( addSample ){
870  tsData.addSample(timeSeriesPositionTracker[i].getClassLabel(), getTimeSeriesData( timeSeriesPositionTracker[i] ) );
871  }
872  }
873 
874  return tsData;
875 }
876 
878 
879  ClassificationData classificationData;
880 
881  classificationData.setNumDimensions( getNumDimensions() );
882  classificationData.setAllowNullGestureClass( includeNullGestures );
883 
884  bool addSample = false;
885  for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
886  addSample = includeNullGestures ? true : timeSeriesPositionTracker[i].getClassLabel() != GRT_DEFAULT_NULL_CLASS_LABEL;
887  if( addSample ){
888  MatrixFloat dataSegment = getTimeSeriesData( timeSeriesPositionTracker[i] );
889  for(UINT j=0; j<dataSegment.getNumRows(); j++){
890  classificationData.addSample(timeSeriesPositionTracker[i].getClassLabel(), dataSegment.getRow(j) );
891  }
892  }
893  }
894 
895  return classificationData;
896 }
897 
899 
900  if( trackerInfo.getStartIndex() >= totalNumSamples || trackerInfo.getEndIndex() > totalNumSamples ){
901  warningLog << "getTimeSeriesData(TimeSeriesPositionTracker trackerInfo) - Invalid tracker indexs!" << std::endl;
902  return MatrixFloat();
903  }
904 
905  UINT startIndex = trackerInfo.getStartIndex();
906  UINT endIndex = trackerInfo.getEndIndex();
907  UINT M = endIndex > 0 ? trackerInfo.getLength() : totalNumSamples - startIndex;
908  UINT N = getNumDimensions();
909 
910  MatrixFloat tsData(M,N);
911  for(UINT i=0; i<M; i++){
912  for(UINT j=0; j<N; j++){
913  tsData[i][j] = data[ i+startIndex ][j];
914  }
915  }
916  return tsData;
917 }
918 
920  UINT M = getNumSamples();
921  UINT N = getNumDimensions();
922  MatrixFloat matrixData(M,N);
923  for(UINT i=0; i<M; i++){
924  for(UINT j=0; j<N; j++){
925  matrixData[i][j] = data[i][j];
926  }
927  }
928  return matrixData;
929 }
930 
932  const UINT K = (UINT)classTracker.size();
933  Vector< UINT > classLabels( K );
934 
935  for(UINT i=0; i<K; i++){
936  classLabels[i] = classTracker[i].classLabel;
937  }
938 
939  return classLabels;
940 }
941 
942 GRT_END_NAMESPACE
943 
bool loadDatasetFromCSVFile(const std::string &filename, const UINT classLabelColumnIndex=0)
TimeSeriesClassificationData getTimeSeriesClassificationData(const bool includeNullGestures=false) const
UINT eraseAllSamplesWithClassLabel(const UINT classLabel)
bool enableExternalRangeScaling(const bool useExternalRanges)
bool addSample(UINT classLabel, const VectorFloat &sample)
bool save(const std::string &filename)
static Float scale(const Float &x, const Float &minSource, const Float &maxSource, const Float &minTarget, const Float &maxTarget, const bool constrain=false)
Definition: Util.cpp:52
bool setDatasetName(const std::string datasetName)
static Float stringToFloat(const std::string &s)
Definition: Util.cpp:132
ClassificationDataStream(const UINT numDimensions=0, const std::string datasetName="NOT_SET", const std::string infoText="")
bool setNumDimensions(const UINT numDimensions)
ClassificationDataStream & operator=(const ClassificationDataStream &rhs)
bool resetPlaybackIndex(const UINT playbackIndex)
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
bool setNumDimensions(UINT numDimensions)
bool setExternalRanges(const Vector< MinMax > &externalRanges, const bool useExternalRanges=false)
std::string getClassNameForCorrespondingClassLabel(const UINT classLabel)
bool saveDatasetToCSVFile(const std::string &filename)
bool scale(const Float minTarget, const Float maxTarget)
bool load(const std::string &filename)
DebugLog debugLog
Default debugging log.
bool setAllowNullGestureClass(bool allowNullGestureClass)
ClassificationData getClassificationData(const bool includeNullGestures=false) const
Vector< UINT > getClassLabels() const
ErrorLog errorLog
Default error log.
bool useExternalRanges
A flag to show if the dataset should be scaled using the externalRanges values.
ClassificationDataStream getSubset(const UINT startIndex, const UINT endIndex) const
bool loadDatasetFromFile(const std::string &filename)
Vector< MinMax > externalRanges
A Vector containing a set of externalRanges set by the user.
ClassificationSample getNextSample()
std::string infoText
Some infoText about the dataset.
bool setInfoText(const std::string infoText)
UINT numDimensions
The number of dimensions in the dataset.
std::string datasetName
The name of the dataset.
The ClassificationDataStream is the main data structure for recording, labeling, managing, saving, and loading datasets that can be used to test the continuous classification abilities of the GRT supervised learning algorithms.
bool relabelAllSamplesWithClassLabel(const UINT oldClassLabel, const UINT newClassLabel)
unsigned int getNumRows() const
Definition: Matrix.h:542
MatrixFloat getDataAsMatrixFloat() const
unsigned int getNumCols() const
Definition: Matrix.h:549
bool addSample(const UINT classLabel, const VectorFloat &sample)
bool saveDatasetToFile(const std::string &filename)
bool addSample(const UINT classLabel, const MatrixFloat &trainingSample)
VectorFloat getRow(const unsigned int r) const
Definition: MatrixFloat.h:100
bool setNumDimensions(const UINT numDimensions)
static bool stringEndsWith(const std::string &str, const std::string &ending)
Definition: Util.cpp:156
MatrixFloat getTimeSeriesData(const TimeSeriesPositionTracker &trackerInfo) const
Vector< MinMax > getRanges() const
Definition: Vector.h:41
static int stringToInt(const std::string &s)
Definition: Util.cpp:125
WarningLog warningLog
Default warning log.
TimeSeriesClassificationData getAllTrainingExamplesWithClassLabel(const UINT classLabel) const
bool push_back(const Vector< T > &sample)
Definition: Matrix.h:401
UINT getClassLabelIndexValue(const UINT classLabel) const
bool setClassNameForCorrespondingClassLabel(const std::string className, const UINT classLabel)
bool setAllowNullGestureClass(const bool allowNullGestureClass)
std::string getDatasetName() const