GestureRecognitionToolkit  Version: 0.2.5
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
GridSearch.h
1 
28 #ifndef GRT_GRID_SEARCH_HEADER
29 #define GRT_GRID_SEARCH_HEADER
30 
31 #include "../../CoreModules/MLBase.h"
32 #include "../../CoreModules/GestureRecognitionPipeline.h"
33 
34 #include <functional>
35 
36 GRT_BEGIN_NAMESPACE
37 
/**
 @brief An inclusive range of candidate values, [min, max], visited in steps of inc.

 The current value starts at min. next() advances it by inc, clamping the final step to max.
 get() returns the current value and latches the range as expired once that value has reached
 max; after that, next() leaves the value unchanged until reset() is called.

 @tparam T a numeric type (GridSearch instantiates it with unsigned int)
*/
template< class T >
class GridSearchRange{
public:
    /**
     @param _min the first value of the range
     @param _max the last value of the range (inclusive)
     @param _inc the step between consecutive values; a non-positive step jumps straight from min to max
    */
    GridSearchRange( const T _min = T(), const T _max = T(), const T _inc = T() ) :
        value(_min), min(_min), max(_max), inc(_inc), expired(false) {}

    // No user-declared copy constructor: the compiler-generated copy/move operations copy
    // every member, and copy-assignment stays consistent with copy-construction.

    /**
     @brief Advances to the next value in the range.
     @return the new current value (unchanged once the range has expired)
    */
    T next(){
        if( expired ) return value;
        // Step only while the result stays strictly below max. Comparing inc against the remaining
        // distance (max - value) instead of computing value + inc avoids unsigned wrap-around near
        // the type's limit, and requiring inc > 0 stops a zero (or negative) step from stalling forever.
        if( T() < inc && value < max && inc < max - value ) value += inc;
        else value = max;
        return value;
    }

    /**
     @brief Rewinds the range to min and clears the expired flag.
     @return true
    */
    bool reset(){
        value = min;
        expired = false;
        return true;
    }

    /// @return true once get() has returned a value at or beyond max (until reset() is called)
    bool getExpired() const { return expired; }

    /**
     @brief Returns the current value.
     Reading a value at or beyond max latches the range as expired; GridSearch relies on this to stop iterating.
    */
    T get(){
        if( value >= max ) expired = true;
        return value;
    }

    T value;      ///< the current value
    T min;        ///< first value of the range
    T max;        ///< last value of the range (inclusive)
    T inc;        ///< step size
    bool expired; ///< set by get() once value >= max; cleared by reset()
};
74 
75 template < class T >
77 public:
78 
79  GridSearchParam( std::function< bool(T) > func = nullptr, GridSearchRange<T> range = GridSearchRange<T>() ){
80  this->func = func;
81  this->range = range;
82  }
83 
84  GridSearchParam( const GridSearchParam &rhs ){
85  this->func = rhs.func;
86  this->range = rhs.range;
87  }
88 
89  bool reset(){
90  return range.reset();
91  }
92 
93  bool set(){
94  if( !func ) return false;
95  return func( range.get() );
96  }
97 
98  bool update(){
99  if( !func ) return false;
100  return func( range.next() );
101  }
102 
103  bool getExpired() const { return range.getExpired(); }
104 
105  T get(){ return range.get(); }
106 
107  std::function< bool(T) > func;
108  GridSearchRange<T> range;
109 };
110 
111 template< class T >
112 class GridSearch : public MLBase{
113 public:
114  enum SearchType {MaxValueSearch=0,MinValueSearch};
115  GridSearch() : MLBase("GridSearch") {}
116 
117  virtual ~GridSearch(){
118 
119  }
120 
121  bool addParameter( std::function< bool(unsigned int) > f , GridSearchRange< unsigned int > range ){
122  params.push_back( GridSearchParam<unsigned int>( f, range ) );
123  return true;
124  }
125 
126  bool search( ){
127 
128  if( params.getSize() == 0 ){
129  warningLog << __GRT_LOG__ << " No parameters to search! Add some parameters!" << std::endl;
130  return false;
131  }
132 
133  switch( evalType ){
134  case MaxValueSearch:
135  bestResult = 0;
136  break;
137  case MinValueSearch:
138  bestResult = grt_numeric_limits< Float >::max();
139  break;
140  default:
141  errorLog << __GRT_LOG__ << " Unknown eval type!" << std::endl;
142  return false;
143  break;
144  }
145 
146  if( params.getSize() == 0 ) return false;
147  unsigned int paramIndex = 0;
148  return recursive_search( paramIndex );
149  }
150 
151  Float getBestResult() const { return bestResult; }
152 
153  T getBestModel() const { return bestModel; }
154 
155  bool setModel( const T &model ){
156  this->model = model;
157  return true;
158  }
159 
160  bool setEvaluationFunction( std::function< Float () > f, SearchType type = MaxValueSearch ){
161  evalFunc = f;
162  evalType = type;
163  return true;
164  }
165 
166 protected:
167 
168  bool recursive_search( unsigned int paramIndex ){
169 
170  const unsigned int numParams = params.getSize();
171  if( paramIndex >= numParams ){
172  errorLog << __GRT_LOG__ << " Param Index out of bounds!" << std::endl;
173  return false;
174  }
175 
176  recursive_reset( paramIndex );
177 
178  bool stopAfterNextIter = false;
179  Float result = 0.0;
180  while( true ){
181 
182  //Make sure the parameter is set
183  params[ paramIndex ].set();
184 
185  if( paramIndex+1 < numParams )
186  recursive_search( paramIndex + 1 );
187 
188  if( paramIndex == numParams-1 ){ //If we are at the final parameter, run the evaluation
189 
190  //Evaluate the function using the current parameters
191  result = evalFunc();
192 
193  switch( evalType ){
194  case MaxValueSearch:
195  if( result > bestResult ){
196  bestResult = result;
197  bestModel = model;
198  }
199  break;
200  case MinValueSearch:
201  if( result < bestResult ){
202  bestResult = result;
203  bestModel = model;
204  }
205  break;
206  default:
207  errorLog << __GRT_LOG__ << " Unknown eval type!" << std::endl;
208  return false;
209  break;
210  }
211  }
212 
213  if( stopAfterNextIter ) break;
214 
215  //Update the parameter
216  params[ paramIndex ].update();
217 
218  if( params[ paramIndex ].getExpired() ) stopAfterNextIter = true;
219  }
220 
221  return true;
222  }
223 
224  bool recursive_reset( unsigned int paramIndex ){
225  const unsigned int numParams = params.getSize();
226  if( paramIndex >= numParams ){
227  errorLog << __GRT_LOG__ << " Param Index out of bounds!" << std::endl;
228  return false;
229  }
230 
231  if( paramIndex+1 < numParams )
232  recursive_reset( paramIndex + 1 );
233 
234  return params[ paramIndex ].reset();
235  }
236 
238  std::function< Float () > evalFunc;
239  SearchType evalType;
240  Float bestResult;
241  T model;
242  T bestModel;
243 };
244 
245 GRT_END_NAMESPACE
246 
247 #endif // header guard
This class implements a basic grid search algorithm.
Definition: GridSearch.h:39
Definition: Vector.h:41
This is the main base class that all GRT machine learning algorithms should inherit from...
Definition: MLBase.h:72