GestureRecognitionToolkit Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine-learning library for real-time gesture recognition.
ClassificationData.h
Go to the documentation of this file.
1 
31 #ifndef GRT_CLASSIFICATION_DATA_HEADER
32 #define GRT_CLASSIFICATION_DATA_HEADER
33 
34 #include "VectorFloat.h"
35 #include "../Util/GRTCommon.h"
36 #include "../CoreModules/GRTBase.h"
37 #include "ClassificationSample.h"
38 #include "RegressionData.h"
39 #include "UnlabelledData.h"
40 
41 GRT_BEGIN_NAMESPACE
42 
43 class GRT_API ClassificationData : public GRTBase{
44 public:
45 
54  ClassificationData(UINT numDimensions = 0,std::string datasetName = "NOT_SET",std::string infoText = "");
55 
61 
65  virtual ~ClassificationData();
66 
73  ClassificationData& operator=(const ClassificationData &rhs);
74 
82  inline ClassificationSample& operator[] (const UINT &i){
83  return data[i];
84  }
85 
93  inline const ClassificationSample& operator[] (const UINT &i) const{
94  return data[i];
95  }
96 
100  void clear();
101 
112  bool setNumDimensions(UINT numDimensions);
113 
121  bool setDatasetName(std::string datasetName);
122 
130  bool setInfoText(std::string infoText);
131 
139  bool setClassNameForCorrespondingClassLabel(std::string className,UINT classLabel);
140 
149  bool setAllowNullGestureClass(bool allowNullGestureClass);
150 
160  bool addSample(UINT classLabel,const VectorFloat &sample);
161 
167  bool removeSample( const UINT index );
168 
174  bool removeLastSample();
175 
184  bool reserve(const UINT N);
185 
194  bool addClass(const UINT classLabel,const std::string className = "NOT_SET");
195 
202  UINT removeClass(const UINT classLabel);
203 
213  UINT eraseAllSamplesWithClassLabel(const UINT classLabel);
214 
222  bool relabelAllSamplesWithClassLabel(const UINT oldClassLabel,const UINT newClassLabel);
223 
232  bool setExternalRanges(const Vector< MinMax > &externalRanges,const bool useExternalRanges = false);
233 
241  bool enableExternalRangeScaling(const bool useExternalRanges);
242 
248  bool scale(const Float minTarget,const Float maxTarget);
249 
255  bool scale(const Vector<MinMax> &ranges,const Float minTarget,const Float maxTarget);
256 
265  bool save(const std::string &filename) const;
266 
275  bool load(const std::string &filename);
276 
283  bool saveDatasetToFile(const std::string &filename) const;
284 
291  bool loadDatasetFromFile(const std::string &filename);
292 
300  bool saveDatasetToCSVFile(const std::string &filename) const;
301 
312  bool loadDatasetFromCSVFile(const std::string &filename,const UINT classLabelColumnIndex = 0);
313 
320  bool printStats() const;
321 
327  bool sortClassLabels();
328 
337  bool merge(const ClassificationData &data);
338 
345  GRT_DEPRECATED_MSG( "partition(...) is deprecated, use split(...) instead", ClassificationData partition(const UINT partitionPercentage,const bool useStratifiedSampling = false) );
346 
355  ClassificationData split(const UINT splitPercentage,const bool useStratifiedSampling = false);
356 
364  bool spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling = false);
365 
373  ClassificationData getTrainingFoldData(const UINT foldIndex) const;
374 
382  ClassificationData getTestFoldData(const UINT foldIndex) const;
383 
391  ClassificationData getClassData(const UINT classLabel) const;
392 
402  ClassificationData getBootstrappedDataset(UINT numSamples=0, bool balanceDataset=false ) const;
403 
412  RegressionData reformatAsRegressionData() const;
413 
419  UnlabelledData reformatAsUnlabelledData() const;
420 
426  std::string getDatasetName() const{ return datasetName; }
427 
433  std::string getInfoText() const{ return infoText; }
434 
440  std::string getStatsAsString() const;
441 
447  UINT inline getNumDimensions() const{ return numDimensions; }
448 
454  UINT inline getNumSamples() const{ return totalNumSamples; }
455 
461  UINT inline getNumClasses() const{ return classTracker.getSize(); }
462 
468  UINT getMinimumClassLabel() const;
469 
475  UINT getMaximumClassLabel() const;
476 
482  UINT getClassLabelIndexValue(const UINT classLabel) const;
483 
489  std::string getClassNameForCorrespondingClassLabel(const UINT classLabel) const;
490 
496  Vector<MinMax> getRanges() const;
497 
503  Vector< UINT > getClassLabels() const;
504 
510  Vector< UINT > getNumSamplesPerClass() const;
511 
517  Vector< ClassTracker > getClassTracker() const{ return classTracker; }
518 
526  MatrixFloat getClassHistogramData(const UINT classLabel,const UINT numBins) const;
527 
535  Vector< MatrixFloat > getHistogramData(const UINT numBins) const;
536 
543 
544  VectorFloat getClassProbabilities() const;
545 
546  VectorFloat getClassProbabilities( const Vector< UINT > &classLabels ) const;
547 
553  VectorFloat getMean() const;
554 
560  VectorFloat getStdDev() const;
561 
568  MatrixFloat getClassMean() const;
569 
576  MatrixFloat getClassStdDev() const;
577 
583  MatrixFloat getCovarianceMatrix() const;
584 
591  Vector< UINT > getClassDataIndexes(const UINT classLabel) const;
592 
599  MatrixDouble getDataAsMatrixDouble() const;
600 
607  MatrixFloat getDataAsMatrixFloat() const;
608 
626  static bool generateGaussDataset( const std::string filename, const UINT numSamples = 10000, const UINT numClasses = 10, const UINT numDimensions = 3, const Float range = 10, const Float sigma = 1 );
627 
628 private:
629 
630  std::string datasetName;
631  std::string infoText;
632  UINT numDimensions;
633  UINT totalNumSamples;
634  UINT kFoldValue;
635  bool crossValidationSetup;
636  bool useExternalRanges;
637  bool allowNullGestureClass;
638  Vector< MinMax > externalRanges;
639  Vector< ClassTracker > classTracker;
641  Vector< Vector< UINT > > crossValidationIndexs;
642 };
643 
644 GRT_END_NAMESPACE
645 
646 #endif //GRT_CLASSIFICATION_DATA_HEADER
Vector< ClassTracker > getClassTracker() const
Vector< ClassificationSample > getClassificationData() const
std::string getDatasetName() const
The UnlabelledData class is the main data container for supporting unsupervised learning.
UINT getNumSamples() const
UINT getNumDimensions() const
UINT getNumClasses() const
The RegressionData class is the main data structure for recording, labeling, managing, saving, and loading datasets that can be used to train and test the GRT supervised regression algorithms.
std::string getInfoText() const
This class stores the class label and raw data for a single labelled classification sample...