36 this->numSteps = numSteps;
37 this->positiveClassificationThreshold = positiveClassificationThreshold;
38 this->minAlphaSearchRange = minAlphaSearchRange;
39 this->maxAlphaSearchRange = maxAlphaSearchRange;
45 trainingLog.setProceedingText(
"[DEBUG RadialBasisFunction]");
46 warningLog.setProceedingText(
"[WARNING RadialBasisFunction]");
47 errorLog.setProceedingText(
"[ERROR RadialBasisFunction]");
60 this->numSteps = rhs.numSteps;
61 this->alpha = rhs.alpha;
62 this->gamma = rhs.gamma;
63 this->positiveClassificationThreshold = rhs.positiveClassificationThreshold;
64 this->minAlphaSearchRange = rhs.minAlphaSearchRange;
65 this->maxAlphaSearchRange = rhs.maxAlphaSearchRange;
66 this->rbfCentre = rhs.rbfCentre;
73 if( weakClassifer == NULL )
return false;
91 errorLog <<
"train(ClassificationData &trainingData, VectorFloat &weights) - There should only be 2 classes in the training data, but there are : " << trainingData.
getNumClasses() << std::endl;
97 errorLog <<
"train(ClassificationData &trainingData, VectorFloat &weights) - There number of examples in the training data (" << trainingData.
getNumSamples() <<
") does not match the lenght of the weights vector (" << weights.size() <<
")" << std::endl;
108 for(UINT i=0; i<M; i++){
110 if( weights[i] > maxWeight ){
111 maxWeight = weights[i];
113 bestWeights.push_back(i);
114 }
else if( weights[i] == maxWeight ){
115 bestWeights.push_back( i );
121 const UINT N = (UINT)bestWeights.size();
124 errorLog <<
"train(ClassificationData &trainingData, VectorFloat &weights) - There are no positive class weigts!" << std::endl;
128 for(UINT i=0; i<N; i++){
130 rbfCentre[j] += trainingData[ bestWeights[i] ][j];
136 rbfCentre[j] /= Float(N);
140 Float step = (maxAlphaSearchRange-minAlphaSearchRange)/numSteps;
144 alpha = minAlphaSearchRange;
145 while( alpha <= maxAlphaSearchRange ){
148 gamma = -1.0/(2.0*grt_sqr(alpha));
152 for(UINT i=0; i<M; i++){
154 Float v = rbf(trainingData[ i ].getSample(),rbfCentre);
156 if( (v >= positiveClassificationThreshold && !positiveSample) || (v<positiveClassificationThreshold && positiveSample) ){
162 if( error < minError ){
176 gamma = -1.0/(2.0*grt_sqr(alpha));
179 std::cout <<
"BestAlpha: " << bestAlpha <<
" Error: " << minError << std::endl;
185 if( rbf(x,rbfCentre) >= positiveClassificationThreshold )
return 1;
190 const UINT N = (UINT)a.size();
193 for(UINT i=0; i<N; i++){
196 return exp( gamma * r );
203 errorLog <<
"saveModelToFile(fstream &file) - The file is not open!" << std::endl;
209 file <<
"Trained: "<<
trained << std::endl;
213 file <<
"NumSteps: " << numSteps << std::endl;
214 file <<
"PositiveClassificationThreshold: " << positiveClassificationThreshold << std::endl;
215 file <<
"Alpha: " << alpha << std::endl;
216 file <<
"MinAlphaSearchRange: " << minAlphaSearchRange << std::endl;
217 file <<
"MaxAlphaSearchRange: " << maxAlphaSearchRange << std::endl;
221 file << rbfCentre[i] <<
"\t";
222 }
else file << 0 <<
"\t";
234 errorLog <<
"loadModelFromFile(fstream &file) - The file is not open!" << std::endl;
241 if( word !=
"WeakClassifierType:" ){
242 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read WeakClassifierType header!" << std::endl;
248 errorLog <<
"loadModelFromFile(fstream &file) - The weakClassifierType:" << word <<
" does not match: " <<
weakClassifierType << std::endl;
253 if( word !=
"Trained:" ){
254 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read Trained header!" << std::endl;
260 if( word !=
"NumInputDimensions:" ){
261 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read NumInputDimensions header!" << std::endl;
268 if( word !=
"NumSteps:" ){
269 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read NumSteps header!" << std::endl;
275 if( word !=
"PositiveClassificationThreshold:" ){
276 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read PositiveClassificationThreshold header!" << std::endl;
279 file >> positiveClassificationThreshold;
282 if( word !=
"Alpha:" ){
283 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read Alpha header!" << std::endl;
289 if( word !=
"MinAlphaSearchRange:" ){
290 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read MinAlphaSearchRange header!" << std::endl;
293 file >> minAlphaSearchRange;
296 if( word !=
"MaxAlphaSearchRange:" ){
297 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read MaxAlphaSearchRange header!" << std::endl;
300 file >> maxAlphaSearchRange;
303 if( word !=
"RBF:" ){
304 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read RBF header!" << std::endl;
307 rbfCentre.
resize(numInputDimensions);
310 file >> rbfCentre[i];
314 gamma = -1.0/(2.0*SQR(alpha));
332 return positiveClassificationThreshold;
340 return minAlphaSearchRange;
344 return maxAlphaSearchRange;
virtual bool train(ClassificationData &trainingData, VectorFloat &weights)
virtual Float predict(const VectorFloat &x)
virtual bool loadModelFromFile(std::fstream &file)
std::string weakClassifierType
A string that represents the weak classifier type, e.g. DecisionStump.
virtual bool deepCopyFrom(const WeakClassifier *weakClassifer)
UINT numInputDimensions
The number of input dimensions to the weak classifier.
virtual bool resize(const unsigned int size)
std::string getWeakClassifierType() const
virtual ~RadialBasisFunction()
RadialBasisFunction & operator=(const RadialBasisFunction &rhs)
virtual void print() const
Float getMinAlphaSearchRange() const
unsigned int getSize() const
static RegisterWeakClassifierModule< RadialBasisFunction > registerModule
This is used to register the RadialBasisFunction with the WeakClassifier base class.
UINT getNumSamples() const
#define WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL
virtual bool saveModelToFile(std::fstream &file) const
bool copyBaseVariables(const WeakClassifier *weakClassifer)
UINT getNumDimensions() const
UINT getNumClasses() const
VectorFloat getRBFCentre() const
bool trained
A flag to show if the weak classifier model has been trained.
Float getMaxAlphaSearchRange() const
RadialBasisFunction(UINT numSteps=100, Float positiveClassificationThreshold=0.9, Float minAlphaSearchRange=0.001, Float maxAlphaSearchRange=1.0)
Float getPositiveClassificationThreshold() const
This class implements a Radial Basis Function Weak Classifier. The Radial Basis Function (RBF) class ...