2 #define GRT_DLL_EXPORTS 20 if( x[ featureIndex ] >= threshold )
return true;
38 std::ostringstream stream;
41 std::cout << stream.str();
51 for(UINT i=0; i<depth; i++) tab +=
"\t";
53 stream << tab <<
"depth: " << depth <<
" nodeSize: " << nodeSize <<
" featureIndex: " << featureIndex <<
" threshold " << threshold <<
" isLeafNode: " << isLeafNode << std::endl;
54 stream << tab <<
"ClassProbabilities: ";
55 for(UINT i=0; i<classProbabilities.size(); i++){
56 stream << classProbabilities[i] <<
"\t";
60 if( leftChild != NULL ){
61 stream << tab <<
"LeftChild: " << std::endl;
65 if( rightChild != NULL ){
66 stream << tab <<
"RightChild: " << std::endl;
83 node->isLeafNode = isLeafNode;
84 node->nodeID = nodeID;
85 node->predictedNodeID = predictedNodeID;
86 node->nodeSize = nodeSize;
87 node->featureIndex = featureIndex;
88 node->threshold = threshold;
89 node->classProbabilities = classProbabilities;
93 node->leftChild = leftChild->
deepCopy();
94 node->leftChild->setParent( node );
99 node->rightChild = rightChild->
deepCopy();
100 node->rightChild->setParent( node );
103 return dynamic_cast< Node*
>( node );
115 this->nodeSize = nodeSize;
116 this->featureIndex = featureIndex;
117 this->threshold = threshold;
118 this->classProbabilities = classProbabilities;
// Exhaustive threshold search: for every candidate feature in `features`, sweeps a
// threshold across that feature's [minValue, maxValue] range in `numSplittingSteps`
// equal steps, splits the training samples by `value >= threshold`, and scores the
// split with the size-weighted Gini impurity of the two groups. The best feature is
// written back through `featureIndex`, and the node is configured via set(...).
//
// NOTE(review): several source lines are not visible in this view (declarations of
// M, ranges, groupIndex, groupCounter, classProbabilities, step, threshold, error,
// weightL/weightR, the threshold increment, the minError update, closing braces and
// the final return). Comments below describe only the statements that are shown.
//
// numSplittingSteps  number of steps used to divide each feature's range
// trainingData       samples to split; indexed as trainingData[i][featureIndex]
// features           candidate feature indices to evaluate
// classLabels        class labels used to map a sample's label to a row index
// featureIndex       [out] receives the feature index of the best split found
// minError           [in/out] compared against each candidate's error; presumably
//                    updated with the best error in a line not shown — TODO confirm
// Returns false when `features` is empty.
122 bool DecisionTreeThresholdNode::computeBestSplitBestIterativeSplit(
const UINT &numSplittingSteps,
const ClassificationData &trainingData,
const Vector< UINT > &features,
const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
125 const UINT N = features.
getSize();
126 const UINT K = classLabels.
getSize();
// Nothing to search when there are no candidate features.
128 if( N == 0 )
return false;
131 UINT bestFeatureIndex = 0;
132 Float bestThreshold = 0;
137 Float giniIndexL = 0;
138 Float giniIndexR = 0;
// Outer loop: one pass per candidate feature.
148 for(UINT n=0; n<N; n++){
// NOTE(review): ranges is indexed by the loop position n, while the data below is
// indexed by features[n]. If `features` is not the identity mapping these refer to
// different dimensions — confirm whether ranges[features[n]] was intended.
149 minRange = ranges[n].minValue;
150 maxRange = ranges[n].maxValue;
// NOTE(review): numSplittingSteps == 0 divides by zero here, and maxRange == minRange
// gives step == 0; the threshold increment is not visible, so confirm the while loop
// below still terminates in that case.
151 step = (maxRange-minRange)/Float(numSplittingSteps);
152 threshold = minRange;
// The out-parameter doubles as the working feature index during the search.
153 featureIndex = features[n];
154 while( threshold <= maxRange ){
157 groupCounter[0] = groupCounter[1] = 0;
// Assign each sample to group 1 (value >= threshold) or group 0, counting the group
// sizes and per-class, per-group occurrences. These counts must start from zero for
// each threshold; the reset of classProbabilities is not visible here — TODO confirm.
159 for(UINT i=0; i<M; i++){
160 groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
161 groupCounter[ groupIndex[i] ]++;
162 classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
// Turn per-class counts into per-group proportions; an empty group yields 0.
166 for(UINT k=0; k<K; k++){
167 classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
168 classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
// Gini impurity per group: sum over classes of p * (1 - p).
172 giniIndexL = giniIndexR = 0;
173 for(UINT k=0; k<K; k++){
174 giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
175 giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
// Weight each group's impurity by its share of the M samples.
// NOTE(review): this relies on groupCounter holding a floating-point type; with
// integer operands the division would truncate — confirm against its declaration.
177 weightL = groupCounter[0]/M;
178 weightR = groupCounter[1]/M;
179 error = (giniIndexL*weightL) + (giniIndexR*weightR);
// Keep the candidate with the lowest weighted impurity seen so far.
182 if( error < minError ){
184 bestThreshold = threshold;
185 bestFeatureIndex = featureIndex;
// Publish the winning feature through the out-parameter.
194 featureIndex = bestFeatureIndex;
// Configure this node with the sample count, chosen split and the class
// probabilities reported by the training data.
197 set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));
// Sampled threshold search: evaluates `numSplittingSteps` candidate splits, each
// scored with the size-weighted Gini impurity of the two groups produced by
// `value >= threshold`, and keeps the best one. The best feature is written back
// through `featureIndex`, and the node is configured via set(...).
//
// NOTE(review): several source lines are not visible in this view (declarations of
// M, m, n, threshold, groupIndex, groupCounter, classProbabilities, error,
// weightL/weightR, how n and threshold are chosen for each step, the minError
// update, closing braces and the final return). Given the function name the feature
// and threshold are presumably drawn at random — TODO confirm against the full file.
//
// numSplittingSteps  number of candidate splits to evaluate
// trainingData       samples to split; indexed as trainingData[i][featureIndex]
// features           candidate feature indices to choose from
// classLabels        class labels used to map a sample's label to a row index
// featureIndex       [out] receives the feature index of the best split found
// minError           [in/out] compared against each candidate's error; presumably
//                    updated with the best error in a line not shown — TODO confirm
// Returns false when `features` is empty.
202 bool DecisionTreeThresholdNode::computeBestSplitBestRandomSplit(
const UINT &numSplittingSteps,
const ClassificationData &trainingData,
const Vector< UINT > &features,
const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
205 const UINT N = (UINT)features.size();
206 const UINT K = (UINT)classLabels.size();
// Nothing to search when there are no candidate features.
208 if( N == 0 )
return false;
211 UINT bestFeatureIndex = 0;
212 Float bestThreshold = 0;
214 Float giniIndexL = 0;
215 Float giniIndexR = 0;
// NOTE(review): numFeatures is read from the same container as N above; assuming
// size() and getSize() agree, the two values are identical — confirm and consider
// using one of them.
226 const UINT numFeatures = features.
getSize();
// One iteration per candidate split.
227 for(m=0; m<numSplittingSteps; m++){
// n is selected in a line not shown here; the out-parameter doubles as the working
// feature index during the search.
230 featureIndex = features[n];
236 groupCounter[0] = groupCounter[1] = 0;
// Assign each sample to group 1 (value >= threshold) or group 0, counting the group
// sizes and per-class, per-group occurrences. These counts must start from zero for
// each candidate; the reset of classProbabilities is not visible here — TODO confirm.
238 for(UINT i=0; i<M; i++){
239 groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
240 groupCounter[ groupIndex[i] ]++;
241 classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
// Turn per-class counts into per-group proportions; an empty group yields 0.
245 for(UINT k=0; k<K; k++){
246 classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
247 classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
// Gini impurity per group: sum over classes of p * (1 - p).
251 giniIndexL = giniIndexR = 0;
252 for(UINT k=0; k<K; k++){
253 giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
254 giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
// Weight each group's impurity by its share of the M samples.
// NOTE(review): this relies on groupCounter holding a floating-point type; with
// integer operands the division would truncate — confirm against its declaration.
256 weightL = groupCounter[0]/M;
257 weightR = groupCounter[1]/M;
258 error = (giniIndexL*weightL) + (giniIndexR*weightR);
// Keep the candidate with the lowest weighted impurity seen so far.
261 if( error < minError ){
263 bestThreshold = threshold;
264 bestFeatureIndex = featureIndex;
// Publish the winning feature through the out-parameter.
269 featureIndex = bestFeatureIndex;
// Configure this node with the sample count, chosen split and the class
// probabilities reported by the training data.
272 set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));
281 errorLog <<
"saveParametersToFile(fstream &file) - File is not open!" << std::endl;
287 errorLog <<
"saveParametersToFile(fstream &file) - Failed to save DecisionTreeNode parameters to file!" << std::endl;
292 file <<
"FeatureIndex: " << featureIndex << std::endl;
293 file <<
"Threshold: " << threshold << std::endl;
302 errorLog <<
"loadParametersFromFile(fstream &file) - File is not open!" << std::endl;
308 errorLog <<
"loadParametersFromFile(fstream &file) - Failed to load DecisionTreeNode parameters from file!" << std::endl;
316 if( word !=
"FeatureIndex:" ){
317 errorLog <<
"loadParametersFromFile(fstream &file) - Failed to find FeatureIndex header!" << std::endl;
320 file >> featureIndex;
323 if( word !=
"Threshold:" ){
324 errorLog <<
"loadParametersFromFile(fstream &file) - Failed to find Threshold header!" << std::endl;
UINT getFeatureIndex() const
virtual bool clear() override
virtual bool getModel(std::ostream &stream) const override
This file contains the Random class, a useful wrapper for generating cross-platform random numbers...
DecisionTreeThresholdNode()
virtual bool getModel(std::ostream &stream) const override
virtual bool print() const override
virtual bool predict_(VectorFloat &x) override
virtual bool loadParametersFromFile(std::fstream &file) override
virtual bool saveParametersToFile(std::fstream &file) const override
bool setAllValues(const T &value)
UINT getNumSamples() const
virtual Node * deepCopy() const override
Float getThreshold() const
virtual ~DecisionTreeThresholdNode()
Vector< MinMax > getRanges() const
Float getRandomNumberUniform(Float minRange=0.0, Float maxRange=1.0)
bool set(const UINT nodeSize, const UINT featureIndex, const Float threshold, const VectorFloat &classProbabilities)
virtual bool clear() override
int getRandomNumberInt(int minRange, int maxRange)
virtual bool loadParametersFromFile(std::fstream &file) override
virtual bool saveParametersToFile(std::fstream &file) const override
virtual Node * deepCopy() const