21 #define GRT_DLL_EXPORTS 34 trackingClass =
false;
38 warningLog.setKey(
"[WARNING ClassificationDataStream]");
40 if( numDimensions > 0 ){
56 this->totalNumSamples = rhs.totalNumSamples;
57 this->lastClassID = rhs.lastClassID;
58 this->playbackIndex = rhs.playbackIndex;
59 this->trackingClass = rhs.trackingClass;
62 this->data = rhs.data;
63 this->classTracker = rhs.classTracker;
64 this->timeSeriesPositionTracker = rhs.timeSeriesPositionTracker;
76 trackingClass =
false;
79 timeSeriesPositionTracker.clear();
83 if( numDimensions > 0 ){
93 errorLog <<
"setNumDimensions(const UINT numDimensions) - The number of dimensions of the dataset must be greater than zero!" << std::endl;
101 if( datasetName.find(
" ") == std::string::npos ){
106 errorLog <<
"setDatasetName(const std::string datasetName) - The dataset name cannot contain any spaces!" << std::endl;
117 for(UINT i=0; i<classTracker.size(); i++){
118 if( classTracker[i].classLabel == classLabel ){
119 classTracker[i].className = className;
124 errorLog <<
"setClassNameForCorrespondingClassLabel(const std::string className,const UINT classLabel) - Failed to find class with label: " << classLabel << std::endl;
131 errorLog <<
"addSample(const UINT classLabel, VectorFloat sample) - the size of the new sample (" << sample.size() <<
") does not match the number of dimensions of the dataset (" <<
numDimensions <<
")" << std::endl;
135 bool searchForNewClass =
true;
137 if( classLabel != lastClassID ){
139 timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].setEndIndex( totalNumSamples-1 );
140 }
else searchForNewClass =
false;
143 if( searchForNewClass ){
144 bool newClass =
true;
146 for(UINT k=0; k<classTracker.size(); k++){
147 if( classTracker[k].classLabel == classLabel ){
149 classTracker[k].counter++;
154 classTracker.push_back( newCounter );
158 trackingClass =
true;
159 lastClassID = classLabel;
161 timeSeriesPositionTracker.push_back( newTracker );
165 data.push_back( labelledSample );
173 errorLog <<
"addSample(const UINT classLabel, const MatrixFloat &sample) - the number of columns in the sample (" << sample.
getNumCols() <<
") does not match the number of dimensions of the dataset (" <<
numDimensions <<
")" << std::endl;
177 bool searchForNewClass =
true;
179 if( classLabel != lastClassID ){
181 timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].setEndIndex( totalNumSamples-1 );
182 }
else searchForNewClass =
false;
185 if( searchForNewClass ){
186 bool newClass =
true;
188 for(UINT k=0; k<classTracker.size(); k++){
189 if( classTracker[k].classLabel == classLabel ){
191 classTracker[k].counter += sample.
getNumRows();
196 classTracker.push_back( newCounter );
200 trackingClass =
true;
201 lastClassID = classLabel;
203 timeSeriesPositionTracker.push_back( newTracker );
208 data.push_back( labelledSample );
209 data.back().setClassLabel( classLabel );
211 data.back()[j] = sample[i][j];
221 if( totalNumSamples > 0 ){
224 UINT classLabel = data[ totalNumSamples-1 ].getClassLabel();
227 data.erase( data.end()-1 );
229 totalNumSamples = (UINT)data.size();
232 for(UINT i=0; i<classTracker.size(); i++){
233 if( classTracker[i].classLabel == classLabel ){
234 classTracker[i].counter--;
240 if( !trackingClass ){
241 UINT endIndex = timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].getEndIndex();
242 timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].setEndIndex( endIndex-1 );
252 UINT numExamplesRemoved = 0;
253 UINT numExamplesToRemove = 0;
256 for(UINT i=0; i<classTracker.size(); i++){
257 if( classTracker[i].classLabel == classLabel ){
258 numExamplesToRemove = classTracker[i].counter;
259 classTracker.erase(classTracker.begin()+i);
265 if( numExamplesToRemove > 0 ){
267 while( numExamplesRemoved < numExamplesToRemove ){
268 if( data[i].getClassLabel() == classLabel ){
269 data.erase(data.begin()+i);
270 numExamplesRemoved++;
271 }
else if( ++i == data.size() )
break;
278 while( iter != timeSeriesPositionTracker.end() ){
279 if( iter->getClassLabel() == classLabel ){
280 UINT length = iter->getLength();
284 while( updateIter != timeSeriesPositionTracker.end() ){
285 updateIter->setStartIndex( updateIter->getStartIndex() - length );
286 updateIter->setEndIndex( updateIter->getEndIndex() - length );
291 iter = timeSeriesPositionTracker.erase( iter );
295 totalNumSamples = (UINT)data.size();
297 return numExamplesRemoved;
301 bool oldClassLabelFound =
false;
302 bool newClassLabelAllReadyExists =
false;
303 UINT indexOfOldClassLabel = 0;
304 UINT indexOfNewClassLabel = 0;
307 for(UINT i=0; i<classTracker.size(); i++){
308 if( classTracker[i].classLabel == oldClassLabel ){
309 indexOfOldClassLabel = i;
310 oldClassLabelFound =
true;
312 if( classTracker[i].classLabel == newClassLabel ){
313 indexOfNewClassLabel = i;
314 newClassLabelAllReadyExists =
true;
319 if( !oldClassLabelFound ){
324 for(UINT i=0; i<totalNumSamples; i++){
325 if( data[i].getClassLabel() == oldClassLabel ){
326 data[i].set(newClassLabel, data[i].getSample());
331 if( newClassLabelAllReadyExists ){
333 classTracker[ indexOfNewClassLabel ].counter += classTracker[ indexOfOldClassLabel ].counter;
336 classTracker.erase( classTracker.begin() + indexOfOldClassLabel );
339 classTracker.push_back(
ClassTracker(newClassLabel,classTracker[ indexOfOldClassLabel ].counter,classTracker[ indexOfOldClassLabel ].className) );
343 for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
344 if( timeSeriesPositionTracker[i].getClassLabel() == oldClassLabel ){
345 timeSeriesPositionTracker[i].setClassLabel( newClassLabel );
372 return scale(ranges,minTarget,maxTarget);
379 for(UINT i=0; i<totalNumSamples; i++){
381 data[i][j] =
Util::scale(data[i][j],ranges[j].minValue,ranges[j].maxValue,minTarget,maxTarget);
388 if( playbackIndex < totalNumSamples ){
389 this->playbackIndex = playbackIndex;
398 UINT index = playbackIndex++ % totalNumSamples;
399 return data[ index ];
404 for(UINT x=0; x<timeSeriesPositionTracker.size(); x++){
405 if( timeSeriesPositionTracker[x].getClassLabel() == classLabel && timeSeriesPositionTracker[x].getEndIndex() > 0){
407 for(UINT i=timeSeriesPositionTracker[x].getStartIndex(); i<timeSeriesPositionTracker[x].getEndIndex(); i++){
408 timeSeries.
push_back( data[ i ].getSample() );
410 classData.
addSample(classLabel,timeSeries);
417 UINT minClassLabel = 99999;
419 for(UINT i=0; i<classTracker.size(); i++){
420 if( classTracker[i].classLabel < minClassLabel ){
421 minClassLabel = classTracker[i].classLabel;
425 return minClassLabel;
430 UINT maxClassLabel = 0;
432 for(UINT i=0; i<classTracker.size(); i++){
433 if( classTracker[i].classLabel > maxClassLabel ){
434 maxClassLabel = classTracker[i].classLabel;
438 return maxClassLabel;
442 for(UINT k=0; k<classTracker.size(); k++){
443 if( classTracker[k].classLabel == classLabel ){
447 warningLog <<
"getClassLabelIndexValue(const UINT classLabel) - Failed to find class label: " << classLabel <<
" in class tracker!" << std::endl;
453 for(UINT i=0; i<classTracker.size(); i++){
454 if( classTracker[i].classLabel == classLabel ){
455 return classTracker[i].className;
458 return "CLASS_LABEL_NOT_FOUND";
469 if( totalNumSamples > 0 ){
471 ranges[j].minValue = data[0][0];
472 ranges[j].maxValue = data[0][0];
473 for(UINT i=0; i<totalNumSamples; i++){
474 if( data[i][j] < ranges[j].minValue ){ ranges[j].minValue = data[i][j]; }
475 else if( data[i][j] > ranges[j].maxValue ){ ranges[j].maxValue = data[i][j]; }
507 file.open(filename.c_str(), std::ios::out);
509 if( !file.is_open() ){
510 errorLog <<
"saveDatasetToFile(const std::string &filename) - Failed to open file!" << std::endl;
516 trackingClass =
false;
517 timeSeriesPositionTracker[ timeSeriesPositionTracker.size()-1 ].setEndIndex( totalNumSamples-1 );
520 file <<
"GRT_LABELLED_CONTINUOUS_TIME_SERIES_CLASSIFICATION_FILE_V1.0\n";
521 file <<
"DatasetName: " <<
datasetName << std::endl;
522 file <<
"InfoText: " <<
infoText << std::endl;
524 file <<
"TotalNumSamples: " << totalNumSamples << std::endl;
525 file <<
"NumberOfClasses: " << classTracker.size() << std::endl;
526 file <<
"ClassIDsAndCounters: " << std::endl;
527 for(UINT i=0; i<classTracker.size(); i++){
528 file << classTracker[i].classLabel <<
"\t" << classTracker[i].counter << std::endl;
531 file <<
"NumberOfPositionTrackers: " << timeSeriesPositionTracker.size() << std::endl;
532 file <<
"TimeSeriesPositionTrackers: " << std::endl;
533 for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
534 file << timeSeriesPositionTracker[i].getClassLabel() <<
"\t" << timeSeriesPositionTracker[i].getStartIndex() <<
"\t" << timeSeriesPositionTracker[i].getEndIndex() << std::endl;
545 file <<
"LabelledContinuousTimeSeriesClassificationData:\n";
546 for(UINT i=0; i<totalNumSamples; i++){
547 file << data[i].getClassLabel();
549 file <<
"\t" << data[i][j];
561 file.open(filename.c_str(), std::ios::in);
563 UINT numTrackingPoints = 0;
566 if( !file.is_open() ){
567 errorLog<<
"loadDatasetFromFile(string fileName) - Failed to open file!" << std::endl;
575 if(word !=
"GRT_LABELLED_CONTINUOUS_TIME_SERIES_CLASSIFICATION_FILE_V1.0"){
577 errorLog<<
"loadDatasetFromFile(string fileName) - Failed to find file header!" << std::endl;
583 if(word !=
"DatasetName:"){
584 errorLog <<
"loadDatasetFromFile(string filename) - failed to find DatasetName!" << std::endl;
591 if(word !=
"InfoText:"){
592 errorLog <<
"loadDatasetFromFile(string filename) - failed to find InfoText!" << std::endl;
600 while( word !=
"NumDimensions:" ){
606 if(word !=
"NumDimensions:"){
607 errorLog<<
"loadDatasetFromFile(string fileName) - Failed to find NumDimensions!" << std::endl;
615 if(word !=
"TotalNumSamples:"){
616 errorLog<<
"loadDatasetFromFile(string fileName) - Failed to find TotalNumSamples!" << std::endl;
620 file >> totalNumSamples;
624 if(word !=
"NumberOfClasses:"){
625 errorLog<<
"loadDatasetFromFile(string fileName) - Failed to find NumberOfClasses!" << std::endl;
632 classTracker.
resize(numClasses);
636 if(word !=
"ClassIDsAndCounters:"){
637 errorLog<<
"loadDatasetFromFile(string fileName) - Failed to find ClassIDsAndCounters!" << std::endl;
642 for(UINT i=0; i<classTracker.size(); i++){
643 file >> classTracker[i].classLabel;
644 file >> classTracker[i].counter;
649 if(word !=
"NumberOfPositionTrackers:"){
650 errorLog<<
"loadDatasetFromFile(string fileName) - Failed to find NumberOfPositionTrackers!" << std::endl;
654 file >> numTrackingPoints;
655 timeSeriesPositionTracker.
resize( numTrackingPoints );
659 if(word !=
"TimeSeriesPositionTrackers:"){
660 errorLog<<
"loadDatasetFromFile(string fileName) - Failed to find TimeSeriesPositionTrackers!" << std::endl;
665 for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
672 timeSeriesPositionTracker[i].setTracker(startIndex,endIndex,classLabel);
677 if(word !=
"UseExternalRanges:"){
678 errorLog <<
"loadDatasetFromFile(string filename) - failed to find DatasetName!" << std::endl;
685 if( useExternalRanges ){
695 if(word !=
"LabelledContinuousTimeSeriesClassificationData:"){
696 errorLog<<
"loadDatasetFromFile(string fileName) - Failed to find LabelledContinuousTimeSeriesClassificationData!" << std::endl;
707 for(UINT i=0; i<totalNumSamples; i++){
714 data[i].set(classLabel,sample);
723 file.open(filename.c_str(), std::ios::out );
725 if( !file.is_open() ){
731 for(UINT i=0; i<data.size(); i++){
732 file << data[i].getClassLabel();
734 file <<
"," << data[i][j];
755 if( !parser.parseCSVFile(filename,
true) ){
756 errorLog <<
"loadDatasetFromCSVFile(const std::string filename,const UINT classLabelColumnIndex) - Failed to parse CSV file!" << std::endl;
760 if( !parser.getConsistentColumnSize() ){
761 errorLog <<
"loadDatasetFromCSVFile(const std::string filename,const UINT classLabelColumnIndex) - The CSV file does not have a consistent number of columns!" << std::endl;
765 if( parser.getColumnSize() <= 1 ){
766 errorLog <<
"loadDatasetFromCSVFile(const std::string filename,const UINT classLabelColumnIndex) - The CSV file does not have enough columns! It should contain at least two columns!" << std::endl;
776 for(UINT i=0; i<parser.getRowSize(); i++){
784 if( n != classLabelColumnIndex ){
792 warningLog <<
"loadDatasetFromCSVFile(const std::string filename,const UINT classLabelColumnIndex) - Could not add sample " << i <<
" to the dataset!" << std::endl;
801 std::cout <<
"DatasetName:\t" <<
datasetName << std::endl;
802 std::cout <<
"DatasetInfo:\t" <<
infoText << std::endl;
803 std::cout <<
"Number of Dimensions:\t" <<
numDimensions << std::endl;
804 std::cout <<
"Number of Samples:\t" << totalNumSamples << std::endl;
805 std::cout <<
"Number of Classes:\t" <<
getNumClasses() << std::endl;
806 std::cout <<
"ClassStats:\n";
809 std::cout <<
"ClassLabel:\t" << classTracker[k].classLabel;
810 std::cout <<
"\tNumber of Samples:\t" << classTracker[k].counter;
811 std::cout <<
"\tClassName:\t" << classTracker[k].className << std::endl;
814 std::cout <<
"TimeSeriesMarkerStats:\n";
815 for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
816 std::cout <<
"ClassLabel: " << timeSeriesPositionTracker[i].getClassLabel();
817 std::cout <<
"\tStartIndex: " << timeSeriesPositionTracker[i].getStartIndex();
818 std::cout <<
"\tEndIndex: " << timeSeriesPositionTracker[i].getEndIndex();
819 std::cout <<
"\tLength: " << timeSeriesPositionTracker[i].getLength() << std::endl;
824 std::cout <<
"Dataset Ranges:\n";
825 for(UINT j=0; j<ranges.size(); j++){
826 std::cout <<
"[" << j+1 <<
"] Min:\t" << ranges[j].minValue <<
"\tMax: " << ranges[j].maxValue << std::endl;
836 if( endIndex >= totalNumSamples ){
837 warningLog <<
"getSubset(const UINT startIndex,const UINT endIndex) - The endIndex is greater than or equal to the number of samples in the current dataset!" << std::endl;
841 if( startIndex >= endIndex ){
842 warningLog <<
"getSubset(const UINT startIndex,const UINT endIndex) - The startIndex is greater than or equal to the endIndex!" << std::endl;
852 for(UINT i=startIndex; i<=endIndex; i++){
853 subset.
addSample(data[i].getClassLabel(), data[i].getSample());
867 const UINT numTimeseries = (UINT)timeSeriesPositionTracker.size();
868 for(UINT i=0; i<numTimeseries; i++){
869 addSample = includeNullGestures ?
true : timeSeriesPositionTracker[i].getClassLabel() != GRT_DEFAULT_NULL_CLASS_LABEL;
886 for(UINT i=0; i<timeSeriesPositionTracker.size(); i++){
887 addSample = includeNullGestures ?
true : timeSeriesPositionTracker[i].getClassLabel() != GRT_DEFAULT_NULL_CLASS_LABEL;
890 for(UINT j=0; j<dataSegment.
getNumRows(); j++){
891 classificationData.
addSample(timeSeriesPositionTracker[i].getClassLabel(), dataSegment.
getRow(j) );
896 return classificationData;
902 warningLog <<
"getTimeSeriesData(TimeSeriesPositionTracker trackerInfo) - Invalid tracker indexs!" << std::endl;
908 UINT M = endIndex > 0 ? trackerInfo.
getLength() : totalNumSamples - startIndex;
912 for(UINT i=0; i<M; i++){
913 for(UINT j=0; j<N; j++){
914 tsData[i][j] = data[ i+startIndex ][j];
924 for(UINT i=0; i<M; i++){
925 for(UINT j=0; j<N; j++){
926 matrixData[i][j] = data[i][j];
933 const UINT K = (UINT)classTracker.size();
936 for(UINT i=0; i<K; i++){
937 classLabels[i] = classTracker[i].classLabel;
bool loadDatasetFromCSVFile(const std::string &filename, const UINT classLabelColumnIndex=0)
TimeSeriesClassificationData getTimeSeriesClassificationData(const bool includeNullGestures=false) const
UINT eraseAllSamplesWithClassLabel(const UINT classLabel)
UINT getNumSamples() const
bool setAllowNullGestureClass(const bool allowNullGestureClass)
bool enableExternalRangeScaling(const bool useExternalRanges)
bool save(const std::string &filename)
bool addSample(const UINT classLabel, const VectorFloat &sample)
static Float scale(const Float &x, const Float &minSource, const Float &maxSource, const Float &minTarget, const Float &maxTarget, const bool constrain=false)
bool setDatasetName(const std::string datasetName)
static Float stringToFloat(const std::string &s)
ClassificationDataStream(const UINT numDimensions=0, const std::string datasetName="NOT_SET", const std::string infoText="")
bool setNumDimensions(const UINT numDimensions)
ClassificationDataStream & operator=(const ClassificationDataStream &rhs)
bool resetPlaybackIndex(const UINT playbackIndex)
UINT getStartIndex() const
virtual bool resize(const unsigned int size)
bool setNumDimensions(UINT numDimensions)
bool setExternalRanges(const Vector< MinMax > &externalRanges, const bool useExternalRanges=false)
virtual bool setKey(const std::string &key)
Sets the key that is written at the start of each message; it will be written in the format 'key ...
std::string getClassNameForCorrespondingClassLabel(const UINT classLabel)
bool saveDatasetToCSVFile(const std::string &filename)
bool scale(const Float minTarget, const Float maxTarget)
virtual ~ClassificationDataStream()
bool load(const std::string &filename)
DebugLog debugLog
Default debugging log.
ClassificationData getClassificationData(const bool includeNullGestures=false) const
Vector< UINT > getClassLabels() const
ErrorLog errorLog
Default error log.
bool useExternalRanges
A flag indicating whether the dataset should be scaled using the externalRanges values.
ClassificationDataStream getSubset(const UINT startIndex, const UINT endIndex) const
UINT getNumDimensions() const
bool loadDatasetFromFile(const std::string &filename)
Vector< MinMax > externalRanges
A vector containing the external ranges set by the user.
ClassificationSample getNextSample()
std::string infoText
Some infoText about the dataset.
UINT getMaximumClassLabel() const
bool setInfoText(const std::string infoText)
UINT numDimensions
The number of dimensions in the dataset.
std::string datasetName
The name of the dataset.
The ClassificationDataStream is the main data structure for recording, labeling, managing, saving, and loading datasets that can be used to test the continuous classification abilities of the GRT supervised learning algorithms.
bool relabelAllSamplesWithClassLabel(const UINT oldClassLabel, const UINT newClassLabel)
UINT getMinimumClassLabel() const
unsigned int getNumRows() const
MatrixFloat getDataAsMatrixFloat() const
unsigned int getNumCols() const
bool addSample(const UINT classLabel, const VectorFloat &sample)
bool saveDatasetToFile(const std::string &filename)
bool addSample(const UINT classLabel, const MatrixFloat &trainingSample)
VectorFloat getRow(const unsigned int r) const
UINT getNumClasses() const
bool setNumDimensions(const UINT numDimensions)
std::string getInfoText() const
static bool stringEndsWith(const std::string &str, const std::string &ending)
MatrixFloat getTimeSeriesData(const TimeSeriesPositionTracker &trackerInfo) const
Vector< MinMax > getRanges() const
This class stores the class label and raw data for a single labelled classification sample...
static int stringToInt(const std::string &s)
WarningLog warningLog
Default warning log.
TimeSeriesClassificationData getAllTrainingExamplesWithClassLabel(const UINT classLabel) const
bool push_back(const Vector< T > &sample)
UINT getClassLabelIndexValue(const UINT classLabel) const
bool setClassNameForCorrespondingClassLabel(const std::string className, const UINT classLabel)
bool setAllowNullGestureClass(const bool allowNullGestureClass)
std::string getDatasetName() const