GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, C++ machine learning library for real-time gesture recognition.
ClassLabelFilter.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 */
20 
21 #define GRT_DLL_EXPORTS
22 #include "ClassLabelFilter.h"
23 
24 GRT_BEGIN_NAMESPACE
25 
26 //Define the string that will be used to identify the object
27 const std::string ClassLabelFilter::id = "ClassLabelFilter";
28 std::string ClassLabelFilter::getId() { return ClassLabelFilter::id; }
29 
30 //Register the ClassLabelFilter module with the PostProcessing base class
// NOTE(review): the registration statement this comment describes (source line 31) and the constructor
// signature (source line 33) are missing from this listing; the body below is presumably
// ClassLabelFilter(const UINT minimumCount, const UINT bufferSize) -- confirm against the original source.
32
// Constructor body: sets the input and output modes to the predicted-class-label modes, then calls init(),
// which validates minimumCount/bufferSize and sets up the buffer.
// NOTE(review): init()'s return value is ignored; invalid parameters silently leave the filter uninitialized.
34 {
35  postProcessingInputMode = INPUT_MODE_PREDICTED_CLASS_LABEL;
36  postProcessingOutputMode = OUTPUT_MODE_PREDICTED_CLASS_LABEL;
37  init(minimumCount,bufferSize);
38 }
39 
// NOTE(review): the signature line (source line 40) is missing from this listing; the body copies from `rhs`,
// so this is presumably the copy constructor ClassLabelFilter(const ClassLabelFilter &rhs).
// Sets the predicted-class-label input/output modes and copies the filter settings and buffer from rhs.
41 {
42
43  postProcessingInputMode = INPUT_MODE_PREDICTED_CLASS_LABEL;
44  postProcessingOutputMode = OUTPUT_MODE_PREDICTED_CLASS_LABEL;
45
46  //Copy the ClassLabelFilter values
// NOTE(review): source line 47 is missing here; as shown, filteredClassLabel is not copied -- confirm.
48  this->minimumCount = rhs.minimumCount;
49  this->bufferSize = rhs.bufferSize;
50  this->buffer = rhs.buffer;
51
52  //Clone the post processing base variables
// NOTE(review): the statement this comment refers to (source line 53) is missing from this listing.
54 }
55 
// NOTE(review): presumably the body of ~ClassLabelFilter(); its signature (source line 56) is missing from this
// listing. The visible body is empty.
57
58 }
59 
// NOTE(review): the signature (source line 60) is missing from this listing; the body is presumably
// ClassLabelFilter& operator=(const ClassLabelFilter &rhs).
// Copies the filter settings and buffer from rhs unless rhs is this object, then returns *this.
61
62  if( this != &rhs ){
63  //Copy the ClassLabelFilter values
// NOTE(review): source line 64 is missing here; as shown, filteredClassLabel is not copied -- confirm.
65  this->minimumCount = rhs.minimumCount;
66  this->bufferSize = rhs.bufferSize;
67  this->buffer = rhs.buffer;
68
69  //Clone the post processing base variables
// NOTE(review): the statement this comment refers to (source line 70) is missing from this listing.
71  }
72
73  return *this;
74 }
75 
76 bool ClassLabelFilter::deepCopyFrom(const PostProcessing *postProcessing){
77 
78  if( postProcessing == NULL ) return false;
79 
80  if( this->getId() == postProcessing->getId() ){
81 
82  const ClassLabelFilter *ptr = dynamic_cast<const ClassLabelFilter*>(postProcessing);
83 
84  //Clone the ClassLabelFilter values
86  this->minimumCount = ptr->minimumCount;
87  this->bufferSize = ptr->bufferSize;
88  this->buffer = ptr->buffer;
89 
90  //Clone the post processing base variables
91  copyBaseVariables( postProcessing );
92  return true;
93  }
94  return false;
95 }
96 
97 bool ClassLabelFilter::process(const VectorDouble &inputVector){
98 
99  if( !initialized ){
100  errorLog << "process(const VectorDouble &inputVector) - Not initialized!" << std::endl;
101  return false;
102  }
103 
104  if( inputVector.getSize() != numInputDimensions ){
105  errorLog << "process(const VectorDouble &inputVector) - The size of the inputVector (" << inputVector.getSize() << ") does not match that of the filter (" << numInputDimensions << ")!" << std::endl;
106  return false;
107  }
108 
109  //Use only the first value (as that is the predicted class label)
110  processedData[0] = filter( (UINT)inputVector[0] );
111  return true;
112 }
113 
// NOTE(review): presumably the body of ClassLabelFilter::reset(); its signature (source line 114) is missing
// from this listing.
// Clears the last filtered label, re-creates processedData as a single zero, and refills the buffer with
// bufferSize zeros. initialized is set from the result of the buffer resize (false when bufferSize is 0).
// NOTE(review): always returns true, even when initialized ends up false; callers must check initialized.
115  filteredClassLabel = 0;
116  processedData.clear();
117  processedData.resize(1,0);
118  buffer.clear();
119  if( bufferSize > 0 ){
120  initialized = buffer.resize(bufferSize,0);
121  }else initialized = false;
122  return true;
123 }
124 
125 bool ClassLabelFilter::init(const UINT minimumCount,const UINT bufferSize){
126 
127  initialized = false;
128 
129  if( minimumCount < 1 ){
130  errorLog << "init(UINT minimumCount,UINT bufferSize) - MinimumCount must be greater than or equal to 1!" << std::endl;
131  return false;
132  }
133 
134  if( bufferSize < 1 ){
135  errorLog << "init(UINT minimumCount,UINT bufferSize) - BufferSize must be greater than or equal to 1!" << std::endl;
136  return false;
137  }
138 
139  if( bufferSize < minimumCount ){
140  errorLog << "init(UINT minimumCount,UINT bufferSize) - The buffer size must be greater than or equal to the minimum count!" << std::endl;
141  return false;
142  }
143 
144  this->minimumCount = minimumCount;
145  this->bufferSize = bufferSize;
146  numInputDimensions = 1;
147  numOutputDimensions = 1;
148  initialized = reset();
149  return true;
150 }
151 
152 UINT ClassLabelFilter::filter(const UINT predictedClassLabel){
153 
154  if( !initialized ){
155  errorLog << "filter(UINT predictedClassLabel) - The filter has not been initialized!" << std::endl;
156  filteredClassLabel = 0;
157  return 0;
158  }
159 
160  //Add the current predictedClassLabel to the buffer
161  buffer.push_back( predictedClassLabel );
162 
163  //Count the class values in the buffer, automatically start with the first value in the buffer
164  Vector< ClassTracker > classTracker( 1, ClassTracker( buffer[0], 1 ) );
165 
166  UINT maxCount = classTracker[0].counter;
167  UINT maxClass = classTracker[0].classLabel;
168  bool classLabelFound = false;
169 
170  for(UINT i=1; i<bufferSize; i++){
171  classLabelFound = false;
172  UINT currentCount = 0;
173  UINT currentClassLabel = buffer[i];
174  for(UINT k=0; k<classTracker.size(); k++){
175  if( currentClassLabel == classTracker[k].classLabel ){
176  classTracker[k].counter++;
177  classLabelFound = true;
178  currentCount = classTracker[k].counter;
179  break;
180  }
181  }
182 
183  //If we have not found the class label then we need to add this class to the classTracker
184  if( !classLabelFound ){
185  classTracker.push_back( ClassTracker(currentClassLabel,1) );
186  currentCount = 1;
187  }
188 
189  //Check to see if we should update the max count and maxClass (ignoring class label 0)
190  if( currentCount > maxCount && currentClassLabel != 0 ){
191  maxCount = currentCount;
192  maxClass = currentClassLabel;
193  }
194  }
195 
196  //printf("minimumCount: %i maxCount: %i maxClass: %i\n",minimumCount,maxCount,maxClass);
197 
198  if( maxCount >= minimumCount ){
199  filteredClassLabel = maxClass;
200  }else filteredClassLabel = 0;
201 
202  return filteredClassLabel;
203 }
204 
205 bool ClassLabelFilter::save( std::fstream &file ) const{
206 
207  if( !file.is_open() ){
208  errorLog << "save(fstream &file) - The file is not open!" << std::endl;
209  return false;
210  }
211 
212  file << "GRT_CLASS_LABEL_FILTER_FILE_V1.0" << std::endl;
213  file << "NumInputDimensions: " << numInputDimensions << std::endl;
214  file << "NumOutputDimensions: " << numOutputDimensions << std::endl;
215  file << "MinimumCount: " << minimumCount << std::endl;
216  file << "BufferSize: " << bufferSize << std::endl;
217 
218  return true;
219 }
220 
221 bool ClassLabelFilter::load( std::fstream &file ){
222 
223  if( !file.is_open() ){
224  errorLog << "load(fstream &file) - The file is not open!" << std::endl;
225  return false;
226  }
227 
228  std::string word;
229 
230  //Load the header
231  file >> word;
232 
233  if( word != "GRT_CLASS_LABEL_FILTER_FILE_V1.0" ){
234  errorLog << "load(fstream &file) - Invalid file format!" << std::endl;
235  return false;
236  }
237 
238  file >> word;
239  if( word != "NumInputDimensions:" ){
240  errorLog << "load(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
241  return false;
242  }
243  file >> numInputDimensions;
244 
245  //Load the number of output dimensions
246  file >> word;
247  if( word != "NumOutputDimensions:" ){
248  errorLog << "load(fstream &file) - Failed to read NumOutputDimensions header!" << std::endl;
249  return false;
250  }
251  file >> numOutputDimensions;
252 
253  //Load the minimumCount
254  file >> word;
255  if( word != "MinimumCount:" ){
256  errorLog << "load(fstream &file) - Failed to read MinimumCount header!" << std::endl;
257  return false;
258  }
259  file >> minimumCount;
260 
261  file >> word;
262  if( word != "BufferSize:" ){
263  errorLog << "load(fstream &file) - Failed to read BufferSize header!" << std::endl;
264  return false;
265  }
266  file >> bufferSize;
267 
268  //Init the classLabelFilter module to ensure everything is initialized correctly
269  return init(minimumCount,bufferSize);
270 }
271 
// NOTE(review): presumably the body of setMinimumCount(const UINT minimumCount); its signature (source line 272)
// is missing from this listing.
// Stores the new minimum count. If the filter is already initialized, it is reset, which clears the buffer.
// NOTE(review): unlike init(), the value is not checked for being >= 1 or <= bufferSize.
273  this->minimumCount = minimumCount;
274  if( initialized ){
275  return reset();
276  }
277  return true;
278 }
279 
// NOTE(review): presumably the body of setBufferSize(const UINT bufferSize); its signature (source line 280)
// is missing from this listing.
// Stores the new buffer size. If the filter is already initialized, it is reset so the buffer is rebuilt
// at the new size.
// NOTE(review): unlike init(), the value is not checked for being >= 1 or >= minimumCount; a size of 0
// makes the reset leave the filter uninitialized.
281  this->bufferSize = bufferSize;
282  if( initialized ){
283  return reset();
284  }
285  return true;
286 }
287 
288 GRT_END_NAMESPACE
bool push_back(const T &value)
std::string getId() const
Definition: GRTBase.cpp:85
bool setBufferSize(const UINT bufferSize)
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
virtual bool deepCopyFrom(const PostProcessing *postProcessing) override
UINT getSize() const
Definition: Vector.h:201
ClassLabelFilter(const UINT minimumCount=1, const UINT bufferSize=1)
virtual bool save(std::fstream &file) const override
This is the main base class that all GRT PostProcessing algorithms should inherit from...
bool copyBaseVariables(const PostProcessing *postProcessingModule)
bool setMinimumCount(const UINT minimumCount)
UINT minimumCount
The minimum count sets the minimum number of times a class label value must be present in the class label buffer for that class label to be output by the filter.
virtual bool process(const VectorDouble &inputVector) override
CircularBuffer< UINT > buffer
The class label filter buffer.
UINT bufferSize
The size of the Class Label Filter buffer.
virtual bool load(std::fstream &file) override
bool init()
virtual bool reset() override
UINT filter(const UINT predictedClassLabel)
UINT filteredClassLabel
The most recent filtered class label value.
The Class Label Filter is a useful post-processing module which can remove erroneous or sporadic predicted class labels from a classifier's output stream.
ClassLabelFilter & operator=(const ClassLabelFilter &rhs)
bool resize(const unsigned int newBufferSize)
virtual ~ClassLabelFilter()
static std::string getId()