28 #ifndef GRT_GRID_SEARCH_HEADER
29 #define GRT_GRID_SEARCH_HEADER
31 #include "../../CoreModules/MLBase.h"
32 #include "../../CoreModules/GestureRecognitionPipeline.h"
41 GridSearchRange(
const T _min = T(),
const T _max = T(),
const T _inc = T() ):min(_min),max(_max),inc(_inc){ value = min; expired =
false; }
44 this->value = rhs.value;
48 this->expired = rhs.expired;
52 if( expired )
return value;
53 if( value + inc < max ) value += inc;
64 bool getExpired()
const {
return expired; }
66 T
get() {
if( value >= max ) expired =
true;
return value; }
85 this->func = rhs.func;
86 this->range = rhs.range;
94 if( !func )
return false;
95 return func( range.get() );
99 if( !func )
return false;
100 return func( range.next() );
103 bool getExpired()
const {
return range.getExpired(); }
105 T
get(){
return range.get(); }
107 std::function< bool(T) > func;
/**
 Selects how an evaluation result is compared against the best result so far.
*/
enum SearchType {
    MaxValueSearch = 0, ///< search for the largest evaluation result
    MinValueSearch = 1  ///< search for the smallest evaluation result
};
116 classType =
"GridSearch";
117 infoLog.setProceedingText(
"[GridSearch]");
118 debugLog.setProceedingText(
"[DEBUG GridSearch]");
119 errorLog.setProceedingText(
"[ERROR GridSearch");
120 trainingLog.setProceedingText(
"[TRAINING GridSearch]");
121 warningLog.setProceedingText(
"[WARNING GridSearch]");
136 warningLog <<
"No parameters to search! Add some parameters!" << std::endl;
148 errorLog <<
"recursive_search( unsigned int paramIndex ) - Unknown eval type!" << std::endl;
153 if( params.
getSize() == 0 )
return false;
154 unsigned int paramIndex = 0;
155 return recursive_search( paramIndex );
158 Float getBestResult()
const {
return bestResult; }
160 T getBestModel()
const {
return bestModel; }
162 bool setModel(
const T &model ){
167 bool setEvaluationFunction( std::function< Float () > f, SearchType type = MaxValueSearch ){
175 bool recursive_search(
unsigned int paramIndex ){
177 const unsigned int numParams = params.
getSize();
178 if( paramIndex >= numParams ){
179 errorLog <<
"recursive_search( unsigned int paramIndex ) - Param Index out of bounds!" << std::endl;
183 recursive_reset( paramIndex );
185 bool stopAfterNextIter =
false;
190 params[ paramIndex ].set();
192 if( paramIndex+1 < numParams )
193 recursive_search( paramIndex + 1 );
195 if( paramIndex == numParams-1 ){
202 if( result > bestResult ){
208 if( result < bestResult ){
214 errorLog <<
"recursive_search( unsigned int paramIndex ) - Unknown eval type!" << std::endl;
220 if( stopAfterNextIter )
break;
223 params[ paramIndex ].update();
225 if( params[ paramIndex ].getExpired() ) stopAfterNextIter =
true;
231 bool recursive_reset(
unsigned int paramIndex ){
232 const unsigned int numParams = params.
getSize();
233 if( paramIndex >= numParams ){
234 errorLog <<
"recursive_reset( unsigned int paramIndex ) - Param Index out of bounds!" << std::endl;
238 if( paramIndex+1 < numParams )
239 recursive_reset( paramIndex + 1 );
241 return params[ paramIndex ].reset();
245 std::function< Float () > evalFunc;
254 #endif // header guard
unsigned int getSize() const
This class implements a basic grid search algorithm.