28 #define GRT_DLL_EXPORTS 37 this->numSteps = numSteps;
38 this->positiveClassificationThreshold = positiveClassificationThreshold;
39 this->minAlphaSearchRange = minAlphaSearchRange;
40 this->maxAlphaSearchRange = maxAlphaSearchRange;
46 trainingLog.
setKey(
"[DEBUG RadialBasisFunction]");
47 warningLog.setKey(
"[WARNING RadialBasisFunction]");
48 errorLog.
setKey(
"[ERROR RadialBasisFunction]");
61 this->numSteps = rhs.numSteps;
62 this->alpha = rhs.alpha;
63 this->gamma = rhs.gamma;
64 this->positiveClassificationThreshold = rhs.positiveClassificationThreshold;
65 this->minAlphaSearchRange = rhs.minAlphaSearchRange;
66 this->maxAlphaSearchRange = rhs.maxAlphaSearchRange;
67 this->rbfCentre = rhs.rbfCentre;
74 if( weakClassifer == NULL )
return false;
92 errorLog <<
"train(ClassificationData &trainingData, VectorFloat &weights) - There should only be 2 classes in the training data, but there are : " << trainingData.
getNumClasses() << std::endl;
98 errorLog <<
"train(ClassificationData &trainingData, VectorFloat &weights) - There number of examples in the training data (" << trainingData.
getNumSamples() <<
") does not match the lenght of the weights vector (" << weights.size() <<
")" << std::endl;
109 for(UINT i=0; i<M; i++){
111 if( weights[i] > maxWeight ){
112 maxWeight = weights[i];
114 bestWeights.push_back(i);
115 }
else if( weights[i] == maxWeight ){
116 bestWeights.push_back( i );
122 const UINT N = (UINT)bestWeights.size();
125 errorLog <<
"train(ClassificationData &trainingData, VectorFloat &weights) - There are no positive class weigts!" << std::endl;
129 for(UINT i=0; i<N; i++){
131 rbfCentre[j] += trainingData[ bestWeights[i] ][j];
137 rbfCentre[j] /= Float(N);
141 Float step = (maxAlphaSearchRange-minAlphaSearchRange)/numSteps;
145 alpha = minAlphaSearchRange;
146 while( alpha <= maxAlphaSearchRange ){
149 gamma = -1.0/(2.0*grt_sqr(alpha));
153 for(UINT i=0; i<M; i++){
155 Float v = rbf(trainingData[ i ].getSample(),rbfCentre);
157 if( (v >= positiveClassificationThreshold && !positiveSample) || (v<positiveClassificationThreshold && positiveSample) ){
163 if( error < minError ){
177 gamma = -1.0/(2.0*grt_sqr(alpha));
180 std::cout <<
"BestAlpha: " << bestAlpha <<
" Error: " << minError << std::endl;
186 if( rbf(x,rbfCentre) >= positiveClassificationThreshold )
return 1;
191 const UINT N = (UINT)a.size();
194 for(UINT i=0; i<N; i++){
197 return exp( gamma * r );
204 errorLog <<
"saveModelToFile(fstream &file) - The file is not open!" << std::endl;
210 file <<
"Trained: "<<
trained << std::endl;
214 file <<
"NumSteps: " << numSteps << std::endl;
215 file <<
"PositiveClassificationThreshold: " << positiveClassificationThreshold << std::endl;
216 file <<
"Alpha: " << alpha << std::endl;
217 file <<
"MinAlphaSearchRange: " << minAlphaSearchRange << std::endl;
218 file <<
"MaxAlphaSearchRange: " << maxAlphaSearchRange << std::endl;
222 file << rbfCentre[i] <<
"\t";
223 }
else file << 0 <<
"\t";
235 errorLog <<
"loadModelFromFile(fstream &file) - The file is not open!" << std::endl;
242 if( word !=
"WeakClassifierType:" ){
243 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read WeakClassifierType header!" << std::endl;
249 errorLog <<
"loadModelFromFile(fstream &file) - The weakClassifierType:" << word <<
" does not match: " <<
weakClassifierType << std::endl;
254 if( word !=
"Trained:" ){
255 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read Trained header!" << std::endl;
261 if( word !=
"NumInputDimensions:" ){
262 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
269 if( word !=
"NumSteps:" ){
270 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read NumSteps header!" << std::endl;
276 if( word !=
"PositiveClassificationThreshold:" ){
277 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read PositiveClassificationThreshold header!" << std::endl;
280 file >> positiveClassificationThreshold;
283 if( word !=
"Alpha:" ){
284 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read Alpha header!" << std::endl;
290 if( word !=
"MinAlphaSearchRange:" ){
291 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read MinAlphaSearchRange header!" << std::endl;
294 file >> minAlphaSearchRange;
297 if( word !=
"MaxAlphaSearchRange:" ){
298 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read MaxAlphaSearchRange header!" << std::endl;
301 file >> maxAlphaSearchRange;
304 if( word !=
"RBF:" ){
305 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read RBF header!" << std::endl;
308 rbfCentre.
resize(numInputDimensions);
311 file >> rbfCentre[i];
315 gamma = -1.0/(2.0*SQR(alpha));
333 return positiveClassificationThreshold;
341 return minAlphaSearchRange;
345 return maxAlphaSearchRange;
virtual bool train(ClassificationData &trainingData, VectorFloat &weights)
virtual Float predict(const VectorFloat &x)
virtual bool loadModelFromFile(std::fstream &file)
std::string weakClassifierType
A string that represents the weak classifier type, e.g. DecisionStump.
virtual bool deepCopyFrom(const WeakClassifier *weakClassifer)
UINT numInputDimensions
The number of input dimensions to the weak classifier.
virtual bool resize(const unsigned int size)
virtual bool setKey(const std::string &key)
Sets the key that gets written at the start of each message; this will be written in the format 'key ...
std::string getWeakClassifierType() const
virtual ~RadialBasisFunction()
RadialBasisFunction & operator=(const RadialBasisFunction &rhs)
virtual void print() const
Float getMinAlphaSearchRange() const
static RegisterWeakClassifierModule< RadialBasisFunction > registerModule
This is used to register the RadialBasisFunction with the WeakClassifier base class.
UINT getNumSamples() const
#define WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL
virtual bool saveModelToFile(std::fstream &file) const
bool copyBaseVariables(const WeakClassifier *weakClassifer)
UINT getNumDimensions() const
UINT getNumClasses() const
VectorFloat getRBFCentre() const
bool trained
A flag to show if the weak classifier model has been trained.
Float getMaxAlphaSearchRange() const
RadialBasisFunction(UINT numSteps=100, Float positiveClassificationThreshold=0.9, Float minAlphaSearchRange=0.001, Float maxAlphaSearchRange=1.0)
Float getPositiveClassificationThreshold() const
This class implements a Radial Basis Function Weak Classifier. The Radial Basis Function (RBF) class ...