GestureRecognitionToolkit  Version: 0.2.0
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
GridSearch.h
1 
28 #ifndef GRT_GRID_SEARCH_HEADER
29 #define GRT_GRID_SEARCH_HEADER
30 
31 #include "../../CoreModules/MLBase.h"
32 #include "../../CoreModules/GestureRecognitionPipeline.h"
33 
34 #include <functional>
35 
36 GRT_BEGIN_NAMESPACE
37 
/**
 @brief Describes a range of values from min to max that is stepped through in increments of inc.

 The current value starts at min. Each call to next() advances it by inc, clamping it to max once the
 next step would reach or pass max. get() flags the range as expired once the current value is >= max;
 after that, next() leaves the value unchanged until reset() is called.
*/
template< class T >
class GridSearchRange{
public:
    /**
     @brief Creates a range whose current value starts at _min.
     @param _min: the first value of the range
     @param _max: the last value of the range
     @param _inc: the step size; a non-positive step jumps straight to _max on the first call to next()
    */
    GridSearchRange( const T _min = T(), const T _max = T(), const T _inc = T() ):value(_min),min(_min),max(_max),inc(_inc),expired(false){}

    /**
     @brief Advances the current value by inc, clamping it to max.
     @return the new current value (unchanged if the range has already expired)
    */
    T next(){
        if( expired ) return value;
        // A non-positive step can never reach max (the search would loop forever), so treat it, and a value
        // already at/after max, as a jump to the end of the range. Comparing inc against the remaining
        // distance (max - value) instead of value + inc against max avoids unsigned wrap-around near the
        // type's upper limit, which would otherwise reset the value and also loop forever.
        if( value < max && T() < inc && inc < max - value ) value += inc;
        else value = max;
        return value;
    }

    /**
     @brief Moves the current value back to min and clears the expired flag.
     @return true
    */
    bool reset(){
        value = min;
        expired = false;
        return true;
    }

    /**
     @brief Returns true once get() has observed a current value >= max.
    */
    bool getExpired() const { return expired; }

    /**
     @brief Returns the current value, flagging the range as expired if that value is >= max.
    */
    T get() { if( value >= max ) expired = true; return value; }

    T value;        ///< the current value of the range
    T min;          ///< the first value of the range
    T max;          ///< the last value of the range
    T inc;          ///< the step applied by next()
    bool expired;   ///< set by get() once the current value has reached max
};
74 
75 template < class T >
77 public:
78 
79  GridSearchParam( std::function< bool(T) > func = nullptr, GridSearchRange<T> range = GridSearchRange<T>() ){
80  this->func = func;
81  this->range = range;
82  }
83 
84  GridSearchParam( const GridSearchParam &rhs ){
85  this->func = rhs.func;
86  this->range = rhs.range;
87  }
88 
89  bool reset(){
90  return range.reset();
91  }
92 
93  bool set(){
94  if( !func ) return false;
95  return func( range.get() );
96  }
97 
98  bool update(){
99  if( !func ) return false;
100  return func( range.next() );
101  }
102 
103  bool getExpired() const { return range.getExpired(); }
104 
105  T get(){ return range.get(); }
106 
107  std::function< bool(T) > func;
108  GridSearchRange<T> range;
109 };
110 
111 template< class T >
112 class GridSearch : public MLBase{
113 public:
114  enum SearchType {MaxValueSearch=0,MinValueSearch};
115  GridSearch() {
116  classType = "GridSearch";
117  infoLog.setProceedingText("[GridSearch]");
118  debugLog.setProceedingText("[DEBUG GridSearch]");
119  errorLog.setProceedingText("[ERROR GridSearch");
120  trainingLog.setProceedingText("[TRAINING GridSearch]");
121  warningLog.setProceedingText("[WARNING GridSearch]");
122  }
123 
124  virtual ~GridSearch(){
125 
126  }
127 
128  bool addParameter( std::function< bool(unsigned int) > f , GridSearchRange< unsigned int > range ){
129  params.push_back( GridSearchParam<unsigned int>( f, range ) );
130  return true;
131  }
132 
133  bool search( ){
134 
135  if( params.getSize() == 0 ){
136  warningLog << "No parameters to search! Add some parameters!" << std::endl;
137  return false;
138  }
139 
140  switch( evalType ){
141  case MaxValueSearch:
142  bestResult = 0;
143  break;
144  case MinValueSearch:
145  bestResult = grt_numeric_limits< Float >::max();
146  break;
147  default:
148  errorLog << "recursive_search( unsigned int paramIndex ) - Unknown eval type!" << std::endl;
149  return false;
150  break;
151  }
152 
153  if( params.getSize() == 0 ) return false;
154  unsigned int paramIndex = 0;
155  return recursive_search( paramIndex );
156  }
157 
158  Float getBestResult() const { return bestResult; }
159 
160  T getBestModel() const { return bestModel; }
161 
162  bool setModel( const T &model ){
163  this->model = model;
164  return true;
165  }
166 
167  bool setEvaluationFunction( std::function< Float () > f, SearchType type = MaxValueSearch ){
168  evalFunc = f;
169  evalType = type;
170  return true;
171  }
172 
173 protected:
174 
175  bool recursive_search( unsigned int paramIndex ){
176 
177  const unsigned int numParams = params.getSize();
178  if( paramIndex >= numParams ){
179  errorLog << "recursive_search( unsigned int paramIndex ) - Param Index out of bounds!" << std::endl;
180  return false;
181  }
182 
183  recursive_reset( paramIndex );
184 
185  bool stopAfterNextIter = false;
186  Float result = 0.0;
187  while( true ){
188 
189  //Make sure the parameter is set
190  params[ paramIndex ].set();
191 
192  if( paramIndex+1 < numParams )
193  recursive_search( paramIndex + 1 );
194 
195  if( paramIndex == numParams-1 ){ //If we are at the final parameter, run the evaluation
196 
197  //Evaluate the function using the current parameters
198  result = evalFunc();
199 
200  switch( evalType ){
201  case MaxValueSearch:
202  if( result > bestResult ){
203  bestResult = result;
204  bestModel = model;
205  }
206  break;
207  case MinValueSearch:
208  if( result < bestResult ){
209  bestResult = result;
210  bestModel = model;
211  }
212  break;
213  default:
214  errorLog << "recursive_search( unsigned int paramIndex ) - Unknown eval type!" << std::endl;
215  return false;
216  break;
217  }
218  }
219 
220  if( stopAfterNextIter ) break;
221 
222  //Update the parameter
223  params[ paramIndex ].update();
224 
225  if( params[ paramIndex ].getExpired() ) stopAfterNextIter = true;
226  }
227 
228  return true;
229  }
230 
231  bool recursive_reset( unsigned int paramIndex ){
232  const unsigned int numParams = params.getSize();
233  if( paramIndex >= numParams ){
234  errorLog << "recursive_reset( unsigned int paramIndex ) - Param Index out of bounds!" << std::endl;
235  return false;
236  }
237 
238  if( paramIndex+1 < numParams )
239  recursive_reset( paramIndex + 1 );
240 
241  return params[ paramIndex ].reset();
242  }
243 
245  std::function< Float () > evalFunc;
246  SearchType evalType;
247  Float bestResult;
248  T model;
249  T bestModel;
250 };
251 
252 GRT_END_NAMESPACE
253 
254 #endif // header guard
UINT getSize() const
Definition: Vector.h:191
This class implements a basic grid search algorithm.
Definition: GridSearch.h:39
Definition: Vector.h:41
Definition: MLBase.h:70