GestureRecognitionToolkit  Version: 0.1.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
ClassLabelFilter.cpp
1 /*
2  GRT MIT License
3  Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5  Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6  and associated documentation files (the "Software"), to deal in the Software without restriction,
7  including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9  subject to the following conditions:
10 
11  The above copyright notice and this permission notice shall be included in all copies or substantial
12  portions of the Software.
13 
14  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15  LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16  IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17  WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18  SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19  */
20 
21 #include "ClassLabelFilter.h"
22 
23 GRT_BEGIN_NAMESPACE
24 
25 //Register the ClassLabelFilter module with the PostProcessing base class
26 RegisterPostProcessingModule< ClassLabelFilter > ClassLabelFilter::registerModule("ClassLabelFilter");
27 
28 ClassLabelFilter::ClassLabelFilter(UINT minimumCount,UINT bufferSize){
29  classType = "ClassLabelFilter";
30  postProcessingType = classType;
31  postProcessingInputMode = INPUT_MODE_PREDICTED_CLASS_LABEL;
32  postProcessingOutputMode = OUTPUT_MODE_PREDICTED_CLASS_LABEL;
33  debugLog.setProceedingText("[DEBUG ClassLabelFilter]");
34  errorLog.setProceedingText("[ERROR ClassLabelFilter]");
35  warningLog.setProceedingText("[WARNING ClassLabelFilter]");
36  init(minimumCount,bufferSize);
37 }
38 
40 
41  classType = "ClassLabelFilter";
42  postProcessingType = classType;
43  postProcessingInputMode = INPUT_MODE_PREDICTED_CLASS_LABEL;
44  postProcessingOutputMode = OUTPUT_MODE_PREDICTED_CLASS_LABEL;
45  debugLog.setProceedingText("[DEBUG ClassLabelFilter]");
46  errorLog.setProceedingText("[ERROR ClassLabelFilter]");
47  warningLog.setProceedingText("[WARNING ClassLabelFilter]");
48 
49  //Copy the ClassLabelFilter values
51  this->minimumCount = rhs.minimumCount;
52  this->bufferSize = rhs.bufferSize;
53  this->buffer = rhs.buffer;
54 
55  //Clone the post processing base variables
57 }
58 
60 
61 }
62 
64 
65  if( this != &rhs ){
66  //Copy the ClassLabelFilter values
68  this->minimumCount = rhs.minimumCount;
69  this->bufferSize = rhs.bufferSize;
70  this->buffer = rhs.buffer;
71 
72  //Clone the post processing base variables
74  }
75 
76  return *this;
77 }
78 
79 bool ClassLabelFilter::deepCopyFrom(const PostProcessing *postProcessing){
80 
81  if( postProcessing == NULL ) return false;
82 
83  if( this->getPostProcessingType() == postProcessing->getPostProcessingType() ){
84 
85  ClassLabelFilter *ptr = (ClassLabelFilter*)postProcessing;
86 
87  //Clone the ClassLabelFilter values
89  this->minimumCount = ptr->minimumCount;
90  this->bufferSize = ptr->bufferSize;
91  this->buffer = ptr->buffer;
92 
93  //Clone the post processing base variables
94  copyBaseVariables( postProcessing );
95  return true;
96  }
97  return false;
98 }
99 
100 bool ClassLabelFilter::process(const VectorDouble &inputVector){
101 
102  if( !initialized ){
103  errorLog << "process(const VectorDouble &inputVector) - Not initialized!" << std::endl;
104  return false;
105  }
106 
107  if( inputVector.getSize() != numInputDimensions ){
108  errorLog << "process(const VectorDouble &inputVector) - The size of the inputVector (" << inputVector.getSize() << ") does not match that of the filter (" << numInputDimensions << ")!" << std::endl;
109  return false;
110  }
111 
112  //Use only the first value (as that is the predicted class label)
113  processedData[0] = filter( (UINT)inputVector[0] );
114  return true;
115 }
116 
118  filteredClassLabel = 0;
119  processedData.clear();
120  processedData.resize(1,0);
121  buffer.clear();
122  if( bufferSize > 0 ){
123  initialized = buffer.resize(bufferSize,0);
124  }else initialized = false;
125  return true;
126 }
127 
128 bool ClassLabelFilter::init(UINT minimumCount,UINT bufferSize){
129 
130  initialized = false;
131 
132  if( minimumCount < 1 ){
133  errorLog << "init(UINT minimumCount,UINT bufferSize) - MinimumCount must be greater than or equal to 1!" << std::endl;
134  return false;
135  }
136 
137  if( bufferSize < 1 ){
138  errorLog << "init(UINT minimumCount,UINT bufferSize) - BufferSize must be greater than or equal to 1!" << std::endl;
139  return false;
140  }
141 
142  if( bufferSize < minimumCount ){
143  errorLog << "init(UINT minimumCount,UINT bufferSize) - The buffer size must be greater than or equal to the minimum count!" << std::endl;
144  return false;
145  }
146 
147  this->minimumCount = minimumCount;
148  this->bufferSize = bufferSize;
149  numInputDimensions = 1;
150  numOutputDimensions = 1;
151  initialized = reset();
152  return true;
153 }
154 
155 UINT ClassLabelFilter::filter(UINT predictedClassLabel){
156 
157  if( !initialized ){
158  errorLog << "filter(UINT predictedClassLabel) - The filter has not been initialized!" << std::endl;
159  filteredClassLabel = 0;
160  return 0;
161  }
162 
163  //Add the current predictedClassLabel to the buffer
164  buffer.push_back( predictedClassLabel );
165 
166  //Count the class values in the buffer, automatically start with the first value in the buffer
167  Vector< ClassTracker > classTracker( 1, ClassTracker( buffer[0], 1 ) );
168 
169  UINT maxCount = classTracker[0].counter;
170  UINT maxClass = classTracker[0].classLabel;
171  bool classLabelFound = false;
172 
173  for(UINT i=1; i<bufferSize; i++){
174  classLabelFound = false;
175  UINT currentCount = 0;
176  UINT currentClassLabel = buffer[i];
177  for(UINT k=0; k<classTracker.size(); k++){
178  if( currentClassLabel == classTracker[k].classLabel ){
179  classTracker[k].counter++;
180  classLabelFound = true;
181  currentCount = classTracker[k].counter;
182  break;
183  }
184  }
185 
186  //If we have not found the class label then we need to add this class to the classTracker
187  if( !classLabelFound ){
188  classTracker.push_back( ClassTracker(currentClassLabel,1) );
189  currentCount = 1;
190  }
191 
192  //Check to see if we should update the max count and maxClass (ignoring class label 0)
193  if( currentCount > maxCount && currentClassLabel != 0 ){
194  maxCount = currentCount;
195  maxClass = currentClassLabel;
196  }
197  }
198 
199  //printf("minimumCount: %i maxCount: %i maxClass: %i\n",minimumCount,maxCount,maxClass);
200 
201  if( maxCount >= minimumCount ){
202  filteredClassLabel = maxClass;
203  }else filteredClassLabel = 0;
204 
205  return filteredClassLabel;
206 }
207 
208 bool ClassLabelFilter::saveModelToFile( std::string filename ) const{
209 
210  if( !initialized ){
211  errorLog << "saveModelToFile(string filename) - The ClassLabelFilter has not been initialized" << std::endl;
212  return false;
213  }
214 
215  std::fstream file;
216  file.open(filename.c_str(), std::ios::out);
217 
218  if( !saveModelToFile( file ) ){
219  file.close();
220  return false;
221  }
222 
223  file.close();
224 
225  return true;
226 }
227 
228 bool ClassLabelFilter::saveModelToFile( std::fstream &file ) const{
229 
230  if( !file.is_open() ){
231  errorLog << "saveModelToFile(fstream &file) - The file is not open!" << std::endl;
232  return false;
233  }
234 
235  file << "GRT_CLASS_LABEL_FILTER_FILE_V1.0" << std::endl;
236  file << "NumInputDimensions: " << numInputDimensions << std::endl;
237  file << "NumOutputDimensions: " << numOutputDimensions << std::endl;
238  file << "MinimumCount: " << minimumCount << std::endl;
239  file << "BufferSize: " << bufferSize << std::endl;
240 
241  return true;
242 }
243 
244 bool ClassLabelFilter::loadModelFromFile( std::string filename ){
245 
246  std::fstream file;
247  file.open(filename.c_str(), std::ios::in);
248 
249  if( !loadModelFromFile( file ) ){
250  file.close();
251  initialized = false;
252  return false;
253  }
254 
255  file.close();
256 
257  return true;
258 }
259 
260 bool ClassLabelFilter::loadModelFromFile( std::fstream &file ){
261 
262  if( !file.is_open() ){
263  errorLog << "loadModelFromFile(fstream &file) - The file is not open!" << std::endl;
264  return false;
265  }
266 
267  std::string word;
268 
269  //Load the header
270  file >> word;
271 
272  if( word != "GRT_CLASS_LABEL_FILTER_FILE_V1.0" ){
273  errorLog << "loadModelFromFile(fstream &file) - Invalid file format!" << std::endl;
274  return false;
275  }
276 
277  file >> word;
278  if( word != "NumInputDimensions:" ){
279  errorLog << "loadModelFromFile(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
280  return false;
281  }
282  file >> numInputDimensions;
283 
284  //Load the number of output dimensions
285  file >> word;
286  if( word != "NumOutputDimensions:" ){
287  errorLog << "loadModelFromFile(fstream &file) - Failed to read NumOutputDimensions header!" << std::endl;
288  return false;
289  }
290  file >> numOutputDimensions;
291 
292  //Load the minimumCount
293  file >> word;
294  if( word != "MinimumCount:" ){
295  errorLog << "loadModelFromFile(fstream &file) - Failed to read MinimumCount header!" << std::endl;
296  return false;
297  }
298  file >> minimumCount;
299 
300  file >> word;
301  if( word != "BufferSize:" ){
302  errorLog << "loadModelFromFile(fstream &file) - Failed to read BufferSize header!" << std::endl;
303  return false;
304  }
305  file >> bufferSize;
306 
307  //Init the classLabelFilter module to ensure everything is initialized correctly
308  return init(minimumCount,bufferSize);
309 }
310 
311 bool ClassLabelFilter::setMinimumCount(UINT minimumCount){
312  this->minimumCount = minimumCount;
313  if( initialized ){
314  return reset();
315  }
316  return true;
317 }
318 
319 bool ClassLabelFilter::setBufferSize(UINT bufferSize){
320  this->bufferSize = bufferSize;
321  if( initialized ){
322  return reset();
323  }
324  return true;
325 }
326 
327 GRT_END_NAMESPACE
bool push_back(const T &value)
std::string getPostProcessingType() const
virtual bool reset()
ClassLabelFilter(UINT minimumCount=1, UINT bufferSize=1)
virtual bool resize(const unsigned int size)
Definition: Vector.h:133
unsigned int getSize() const
Definition: Vector.h:193
virtual bool process(const VectorDouble &inputVector)
virtual bool saveModelToFile(std::string filename) const
bool setMinimumCount(UINT minimumCount)
UINT filter(UINT predictedClassLabel)
bool copyBaseVariables(const PostProcessing *postProcessingModule)
UINT minimumCount
The minimum count sets the minimum number of class label values that must be present in the class lab...
CircularBuffer< UINT > buffer
The class label filter buffer.
virtual bool deepCopyFrom(const PostProcessing *postProcessing)
UINT bufferSize
The size of the Class Label Filter buffer.
bool init()
UINT filteredClassLabel
The most recent filtered class label value.
virtual bool loadModelFromFile(std::string filename)
bool setBufferSize(UINT bufferSize)
The Class Label Filter is a useful post-processing module which can remove erroneous or sporadic pred...
ClassLabelFilter & operator=(const ClassLabelFilter &rhs)
bool resize(const unsigned int newBufferSize)
virtual ~ClassLabelFilter()