28 #ifndef GRT_GRID_SEARCH_HEADER 29 #define GRT_GRID_SEARCH_HEADER 31 #include "../../CoreModules/MLBase.h" 32 #include "../../CoreModules/GestureRecognitionPipeline.h" 41 GridSearchRange(
// Constructor: _min is the starting value, _max the upper bound checked by get()/next(),
// and _inc the step added by next(). The current value starts at min and the range
// starts out not expired.
const T _min = T(),
const T _max = T(),
const T _inc = T() ):min(_min),max(_max),inc(_inc){ value = min; expired =
false; }
// Copies the range state from rhs.
// NOTE(review): original lines 45-47 are not visible in this listing; presumably
// min/max/inc are copied there as well.
44 this->value = rhs.value;
48 this->expired = rhs.expired;
// Steps the range: once expired, the current value is returned unchanged; otherwise
// value advances by inc as long as value + inc stays strictly below max.
// NOTE(review): the branch taken when the next step would reach or pass max is
// elided in this listing.
52 if( expired )
return value;
53 if( value + inc < max ) value += inc;
64 bool getExpired()
const {
return expired; }
66 T
get() {
if( value >= max ) expired =
true;
return value; }
// Copies the bound callback and its search range from rhs.
85 this->func = rhs.func;
86 this->range = rhs.range;
// Calls the bound callback with the range's current value. Returns false if no
// callback is bound; otherwise returns the callback's result.
// range.get() can also flip the range to expired when value >= max.
94 if( !func )
return false;
95 return func( range.get() );
// Steps the range forward via range.next() and calls the bound callback with the
// resulting value. Returns false if no callback is bound; otherwise returns the
// callback's result.
99 if( !func )
return false;
100 return func( range.next() );
103 bool getExpired()
const {
return range.getExpired(); }
105 T
get(){
return range.get(); }
// Callback applied to each candidate value by set()/update(); both return its bool result.
107 std::function< bool(T) > func;
/**
 Selects which comparison recursive_search() uses against bestResult.
*/
enum SearchType {
    MaxValueSearch = 0, ///< NOTE(review): presumably the `result > bestResult` branch (keep the largest)
    MinValueSearch = 1  ///< NOTE(review): presumably the `result < bestResult` branch (keep the smallest)
};
// Entry point of the grid search (the function header is not visible in this listing).
// It warns when no parameters have been added, logs an error for an unknown
// evaluation type, then starts the recursion at parameter index 0 and returns its result.
// NOTE(review): original lines 130-140 (the early return of the empty-params branch
// and, presumably, the per-type initialisation of bestResult) are elided here.
128 if( params.getSize() == 0 ){
129 warningLog << __GRT_LOG__ <<
" No parameters to search! Add some parameters!" << std::endl;
141 errorLog << __GRT_LOG__ <<
" Unknown eval type!" << std::endl;
// NOTE(review): repeats the emptiness check at line 128; redundant if that branch returns.
146 if( params.getSize() == 0 )
return false;
147 unsigned int paramIndex = 0;
148 return recursive_search( paramIndex );
151 Float getBestResult()
const {
return bestResult; }
153 T getBestModel()
const {
return bestModel; }
// setModel: takes the model by const reference and returns bool.
// NOTE(review): the body is elided in this listing; presumably it stores a copy of model.
155 bool setModel(
const T &model ){
// Takes the evaluation callback f (no arguments, returns Float) and the comparison mode
// type, which defaults to MaxValueSearch.
// NOTE(review): the body is elided in this listing; presumably it stores them in
// evalFunc/evalType.
160 bool setEvaluationFunction( std::function< Float () > f, SearchType type = MaxValueSearch ){
// Recursively walks the grid, one parameter per recursion level:
// - It resets the parameter at paramIndex and all deeper ones.
// - It then loops, doing the following on each pass:
//   - apply the current value;
//   - recurse into the next parameter;
//   - at the innermost level, compare the evaluation result with bestResult;
//   - step this parameter's range.
// - When the range reports expired after a step, exactly one more pass runs
//   before the loop breaks.
// NOTE(review): the loop header, the evaluation call and the best-result updates are
// elided in this listing.
168 bool recursive_search(
unsigned int paramIndex ){
170 const unsigned int numParams = params.getSize();
171 if( paramIndex >= numParams ){
172 errorLog << __GRT_LOG__ <<
" Param Index out of bounds!" << std::endl;
// Reset this parameter and every deeper one before sweeping (presumably back to their minimums).
176 recursive_reset( paramIndex );
178 bool stopAfterNextIter =
false;
183 params[ paramIndex ].set();
// NOTE(review): the return value of the nested recursive_search call is ignored.
185 if( paramIndex+1 < numParams )
186 recursive_search( paramIndex + 1 );
// Only the innermost parameter level compares a result for the current combination.
188 if( paramIndex == numParams-1 ){
// One branch keeps larger results, the other smaller; the selecting case labels are elided.
195 if( result > bestResult ){
201 if( result < bestResult ){
207 errorLog << __GRT_LOG__ <<
" Unknown eval type!" << std::endl;
// The break is checked before stepping, so the pass after expiry still runs set/recurse/evaluate.
213 if( stopAfterNextIter )
break;
216 params[ paramIndex ].update();
218 if( params[ paramIndex ].getExpired() ) stopAfterNextIter =
true;
// Resets the parameter at paramIndex and, recursively, every parameter after it.
// Returns the result of resetting params[paramIndex].
// NOTE(review): the out-of-bounds branch's return statement is elided; presumably false.
224 bool recursive_reset(
unsigned int paramIndex ){
225 const unsigned int numParams = params.getSize();
226 if( paramIndex >= numParams ){
227 errorLog << __GRT_LOG__ <<
" Param Index out of bounds!" << std::endl;
// Deeper parameters are reset first; their return value is not checked.
231 if( paramIndex+1 < numParams )
232 recursive_reset( paramIndex + 1 );
234 return params[ paramIndex ].reset();
// Evaluation callback: takes no arguments and returns a Float score.
// NOTE(review): presumably set by setEvaluationFunction() and invoked per grid point
// in recursive_search().
238 std::function< Float () > evalFunc;
247 #endif // header guard
This class implements a basic grid search algorithm.
This is the main base class from which all GRT machine-learning algorithms should inherit.