GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, C++ machine learning library for real-time gesture recognition.
ClassificationData.h
Go to the documentation of this file.
1 
26 #ifndef GRT_CLASSIFICATION_DATA_HEADER
27 #define GRT_CLASSIFICATION_DATA_HEADER
28 
29 #include "VectorFloat.h"
30 #include "../Util/GRTCommon.h"
31 #include "../CoreModules/GRTBase.h"
32 #include "ClassificationSample.h"
33 #include "RegressionData.h"
34 #include "UnlabelledData.h"
35 
36 GRT_BEGIN_NAMESPACE
37 
// ClassificationData: class exported with GRT_API and publicly derived from GRTBase.
// Its visible state is a dataset name, an info string, a dimension count, a sample
// count, cross-validation bookkeeping, optional external ranges and a per-class tracker.
// NOTE(review): this text looks like a generated source listing - every code line is
// prefixed with the listing's own line number and the original doc comments are gone.
// Listing numbers 56-60 and 688 are absent; the inline operator[] bodies use a member
// named `data` that is not declared in the visible lines (presumably it sits on the
// missing listing line 688 - confirm against the real header).
43 class GRT_API ClassificationData : public GRTBase{
44 public:
45 
    // Constructor. All three parameters have defaults (0, "NOT_SET", ""), so it can be
    // called with no arguments.
54  ClassificationData(UINT numDimensions = 0,std::string datasetName = "NOT_SET",std::string infoText = "");
55 
61 
    // Virtual destructor.
65  virtual ~ClassificationData();
66 
    // Copy-assignment from rhs; returns a ClassificationData reference
    // (presumably *this - the definition is not shown here).
73  ClassificationData& operator=(const ClassificationData &rhs);
74 
    // Mutable element access: returns data[i] by reference. This body performs no
    // range check of its own on i.
82  inline ClassificationSample& operator[] (const UINT &i){
83  return data[i];
84  }
85 
    // Read-only element access for const objects: returns data[i] by const reference.
    // As above, no range check is done in this body.
93  inline const ClassificationSample& operator[] (const UINT &i) const{
94  return data[i];
95  }
96 
    // Declared only; presumably resets the dataset - confirm in the implementation file.
100  void clear();
101 
    // --- Configuration setters -------------------------------------------------------
    // Each returns bool (presumably success/failure - definitions are not visible here).
112  bool setNumDimensions(UINT numDimensions);
113 
121  bool setDatasetName(std::string datasetName);
122 
130  bool setInfoText(std::string infoText);
131 
    // Associates a name with a class label (both taken by const value).
141  bool setClassNameForCorrespondingClassLabel(const std::string className,const UINT classLabel);
142 
152  bool setAllowNullGestureClass(const bool allowNullGestureClass);
153 
    // --- Adding / removing samples and classes ---------------------------------------
    // Takes the label by value and the sample vector by const reference.
163  bool addSample(const UINT classLabel,const VectorFloat &sample);
164 
171  bool removeSample( const UINT index );
172 
178  bool removeLastSample();
179 
188  bool reserve(const UINT M);
189 
    // className defaults to "NOT_SET" when omitted.
198  bool addClass(const UINT classLabel,const std::string className = "NOT_SET");
199 
    // These two return UINT rather than bool; the meaning of the value is not visible
    // here (presumably the number of samples removed - confirm).
206  UINT removeClass(const UINT classLabel);
207 
217  UINT eraseAllSamplesWithClassLabel(const UINT classLabel);
218 
226  bool relabelAllSamplesWithClassLabel(const UINT oldClassLabel,const UINT newClassLabel);
227 
    // --- Range scaling ---------------------------------------------------------------
    // useExternalRanges defaults to false, so supplying ranges does not by itself
    // request that they be used.
236  bool setExternalRanges(const Vector< MinMax > &externalRanges,const bool useExternalRanges = false);
237 
245  bool enableExternalRangeScaling(const bool useExternalRanges);
254 
    // Two scale overloads: the first takes only the target interval, the second
    // additionally takes caller-supplied source ranges.
254  bool scale(const Float minTarget,const Float maxTarget);
255 
264  bool scale(const Vector<MinMax> &ranges,const Float minTarget,const Float maxTarget);
265 
    // --- Persistence -----------------------------------------------------------------
    // All take the filename by const reference and return bool. The save* variants are
    // const members; the load* variants are not.
274  bool save(const std::string &filename) const;
275 
284  bool load(const std::string &filename);
285 
292  bool saveDatasetToFile(const std::string &filename) const;
293 
300  bool loadDatasetFromFile(const std::string &filename);
301 
309  bool saveDatasetToCSVFile(const std::string &filename) const;
310 
    // classLabelColumnIndex defaults to 0.
321  bool loadDatasetFromCSVFile(const std::string &filename,const UINT classLabelColumnIndex = 0);
322 
    // --- Housekeeping ----------------------------------------------------------------
329  bool printStats() const;
330 
336  bool sortClassLabels();
337 
    // Takes another ClassificationData by const reference.
346  bool merge(const ClassificationData &data);
347 
    // --- Splitting and cross-validation ----------------------------------------------
    // Deprecated: the attached message directs callers to split(...) instead.
354  GRT_DEPRECATED_MSG( "partition(...) is deprecated, use split(...) instead", ClassificationData partition(const UINT partitionPercentage,const bool useStratifiedSampling = false) );
355 
    // Non-const member returning a ClassificationData by value; stratified sampling is
    // off by default.
364  ClassificationData split(const UINT splitPercentage,const bool useStratifiedSampling = false);
365 
    // NOTE: the identifier is spelled "spilt" (sic). It is part of the public
    // interface, so the spelling is left untouched.
373  bool spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling = false);
374 
    // Fold accessors: const members that return a dataset by value for a fold index.
382  ClassificationData getTrainingFoldData(const UINT foldIndex) const;
383 
391  ClassificationData getTestFoldData(const UINT foldIndex) const;
392 
    // --- Derived datasets ------------------------------------------------------------
400  ClassificationData getClassData(const UINT classLabel) const;
401 
    // Defaults: numSamples=0, balanceDataset=false. What a value of 0 selects is not
    // visible here - confirm in the implementation.
411  ClassificationData getBootstrappedDataset(const UINT numSamples=0, const bool balanceDataset=false ) const;
412 
    // Conversions to the other GRT dataset types, returned by value.
421  RegressionData reformatAsRegressionData() const;
422 
428  UnlabelledData reformatAsUnlabelledData() const;
429 
    // --- Simple accessors ------------------------------------------------------------
    // Returns a copy of the datasetName member.
435  std::string getDatasetName() const{ return datasetName; }
436 
    // Returns a copy of the infoText member.
442  std::string getInfoText() const{ return infoText; }
443 
449  std::string getStatsAsString() const;
450 
    // Returns the numDimensions member.
456  UINT inline getNumDimensions() const{ return numDimensions; }
457 
    // Returns the totalNumSamples member.
463  UINT inline getNumSamples() const{ return totalNumSamples; }
464 
    // Returns the number of entries in classTracker (classTracker.getSize()).
470  UINT inline getNumClasses() const{ return classTracker.getSize(); }
471 
477  UINT getMinimumClassLabel() const;
478 
484  UINT getMaximumClassLabel() const;
485 
492  UINT getClassLabelIndexValue(const UINT classLabel) const;
493 
500  std::string getClassNameForCorrespondingClassLabel(const UINT classLabel) const;
501 
507  Vector<MinMax> getRanges() const;
508 
514  Vector< UINT > getClassLabels() const;
515 
521  Vector< UINT > getNumSamplesPerClass() const;
522 
    // Returns classTracker by value, i.e. the caller receives a copy.
528  Vector< ClassTracker > getClassTracker() const{ return classTracker; }
529 
    // --- Statistics ------------------------------------------------------------------
    // All are const members returning their result by value; definitions not shown.
537  MatrixFloat getClassHistogramData(const UINT classLabel,const UINT numBins) const;
538 
546  Vector< MatrixFloat > getHistogramData(const UINT numBins) const;
547 
554 
    // Two overloads: one with no arguments, one taking an explicit list of class labels.
555  VectorFloat getClassProbabilities() const;
556 
557  VectorFloat getClassProbabilities( const Vector< UINT > &classLabels ) const;
558 
564  VectorFloat getMean() const;
565 
571  VectorFloat getStdDev() const;
572 
579  MatrixFloat getClassMean() const;
580 
587  MatrixFloat getClassStdDev() const;
588 
594  MatrixFloat getCovarianceMatrix() const;
595 
602  Vector< UINT > getClassDataIndexes(const UINT classLabel) const;
603 
    // The same kind of export in two matrix types (MatrixDouble and MatrixFloat).
610  MatrixDouble getDataAsMatrixDouble() const;
611 
618  MatrixFloat getDataAsMatrixFloat() const;
619 
    // --- Static dataset generators ---------------------------------------------------
    // Shared defaults: 10000 samples, 10 classes, 3 dimensions, range 10, sigma 1.
    // This overload takes a filename (by const value) and returns bool; presumably it
    // writes the generated data to that file - confirm.
637  static bool generateGaussDataset( const std::string filename, const UINT numSamples = 10000, const UINT numClasses = 10, const UINT numDimensions = 3, const Float range = 10, const Float sigma = 1 );
638 
    // This overload takes no filename and returns the dataset by value.
655  static ClassificationData generateGaussDataset( const UINT numSamples = 10000, const UINT numClasses = 10, const UINT numDimensions = 3, const Float range = 10, const Float sigma = 1 );
656 
    // Same parameter list and defaults as above; how "linear" differs is not visible here.
674  static ClassificationData generateGaussLinearDataset( const UINT numSamples = 10000, const UINT numClasses = 10, const UINT numDimensions = 3, const Float range = 10, const Float sigma = 1 );
675 
676 private:
677 
678  std::string datasetName;             // returned by getDatasetName()
679  std::string infoText;                // returned by getInfoText()
680  UINT numDimensions;                  // returned by getNumDimensions()
681  UINT totalNumSamples;                // returned by getNumSamples()
682  UINT kFoldValue;                     // presumably the K passed to spiltDataIntoKFolds - confirm
683  bool crossValidationSetup;           // presumably set once the folds have been built - confirm
684  bool useExternalRanges;
685  bool allowNullGestureClass;
686  Vector< MinMax > externalRanges;
687  Vector< ClassTracker > classTracker; // size reported by getNumClasses(); copied out by getClassTracker()
    // NOTE(review): listing line 688 is missing here; the `data` member used by
    // operator[] is presumably declared on it - confirm against the real header.
689  Vector< Vector< UINT > > crossValidationIndexs; // identifier spelling ("Indexs") kept as in the original
690 };
691 
692 GRT_END_NAMESPACE
693 
694 #endif //GRT_CLASSIFICATION_DATA_HEADER
Vector< ClassTracker > getClassTracker() const
Vector< ClassificationSample > getClassificationData() const
std::string getDatasetName() const
GRT_DEPRECATED_MSG("getClassType is deprecated, use getId() instead!", std::string getClassType() const )
The UnlabelledData class is the main data container for supporting unsupervised learning.
UINT getNumSamples() const
UINT getNumDimensions() const
UINT getNumClasses() const
The RegressionData is the main data structure for recording, labeling, managing, saving, and loading datasets that can be used to train and test the GRT supervised regression algorithms.
std::string getInfoText() const
This class stores the class label and raw data for a single labelled classification sample...
Float scale(const Float &x, const Float &minSource, const Float &maxSource, const Float &minTarget, const Float &maxTarget, const bool constrain=false)
Definition: GRTBase.h:184