2 #define GRT_DLL_EXPORTS
// Constructor (fragment — surrounding lines are not visible in this chunk).
// Registers this node's type string; the Node factory / serialization code
// presumably uses it to identify the subclass — confirm against Node base class.
11 nodeType =
"DecisionTreeThresholdNode";
// predict (fragment): a sample follows the right branch (true) when its
// value at the learned featureIndex is >= the learned threshold; the
// false/left path is elided from this view.
24 if( x[ featureIndex ] >= threshold )
return true;
// print (fragment): accumulates the model description in a string stream
// (the filling call — presumably getModel(stream) — is elided here) and
// writes the whole thing to stdout in one shot.
42 std::ostringstream stream;
45 std::cout << stream.str();
// getModel (fragment): pretty-prints this node and recurses into its
// children. Several closing braces / recursive getModel calls are elided
// from this view.
// Indent each line by one tab per tree depth so nesting is visible.
55 for(UINT i=0; i<depth; i++) tab +=
"\t";
// Emit this node's parameters on one line.
// NOTE(review): the "threshold" label lacks the ':' every other field has —
// cosmetic inconsistency in the emitted model text; fix would change a
// runtime string, so it is only flagged here.
57 stream << tab <<
"depth: " << depth <<
" nodeSize: " << nodeSize <<
" featureIndex: " << featureIndex <<
" threshold " << threshold <<
" isLeafNode: " << isLeafNode << std::endl;
// Emit the per-class probability vector, tab-separated.
58 stream << tab <<
"ClassProbabilities: ";
59 for(UINT i=0; i<classProbabilities.size(); i++){
60 stream << classProbabilities[i] <<
"\t";
// Recurse into the children when present (recursive calls elided from
// this view — presumably leftChild->getModel(stream) etc.).
64 if( leftChild != NULL ){
65 stream << tab <<
"LeftChild: " << std::endl;
69 if( rightChild != NULL ){
70 stream << tab <<
"RightChild: " << std::endl;
// deepCopyNode (fragment): field-by-field copy of this node into a freshly
// allocated node (allocation and child deep-copy calls are elided from this
// view), then re-parents the copied children.
87 node->isLeafNode = isLeafNode;
88 node->nodeID = nodeID;
89 node->predictedNodeID = predictedNodeID;
90 node->nodeSize = nodeSize;
91 node->featureIndex = featureIndex;
92 node->threshold = threshold;
93 node->classProbabilities = classProbabilities;
// Fix up parent pointers so the copied subtree is self-consistent.
98 node->leftChild->setParent( node );
104 node->rightChild->setParent( node );
// NOTE(review): dynamic_cast to the base Node* is an upcast — an implicit
// conversion would suffice; harmless but unnecessary.
107 return dynamic_cast< Node*
>( node );
// set (fragment): stores the learned split parameters on this node —
// sample count, split feature, split threshold, and the class probability
// distribution of the samples that reached the node.
123 this->nodeSize = nodeSize;
124 this->featureIndex = featureIndex;
125 this->threshold = threshold;
126 this->classProbabilities = classProbabilities;
// Exhaustive iterative split search ("Spilt" typo is part of the public
// interface and must be preserved). For each candidate feature, sweeps
// numSplittingSteps thresholds across the feature's observed range and
// keeps the (feature, threshold) pair with the lowest Gini-impurity
// weighted error. Fragment: declarations of M, ranges, groupIndex,
// groupCounter, threshold/step/minRange/maxRange, weightL/weightR, error,
// the threshold increment, and the minError update are all elided from
// this view.
// @param numSplittingSteps  number of thresholds tested per feature
// @param trainingData       labelled samples reaching this node
// @param features           candidate feature indices to search over
// @param classLabels        class labels present in the training data
// @param featureIndex       [out] best feature found
// @param minError           [in/out] lowest split error found
// @return false if there are no candidate features
130 bool DecisionTreeThresholdNode::computeBestSpiltBestIterativeSpilt(
const UINT &numSplittingSteps,
const ClassificationData &trainingData,
const Vector< UINT > &features,
const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
133 const UINT N = features.
getSize();
134 const UINT K = classLabels.
getSize();
// Guard: nothing to split on.
136 if( N == 0 )
return false;
139 UINT bestFeatureIndex = 0;
140 Float bestThreshold = 0;
145 Float giniIndexL = 0;
146 Float giniIndexR = 0;
// Outer loop: one candidate feature per iteration.
156 for(UINT n=0; n<N; n++){
157 minRange = ranges[n].minValue;
158 maxRange = ranges[n].maxValue;
// NOTE(review): if minValue == maxValue then step == 0 and the (elided)
// threshold increment would never advance the while loop below — confirm
// the full source guards against a constant feature.
159 step = (maxRange-minRange)/Float(numSplittingSteps);
160 threshold = minRange;
161 featureIndex = features[n];
// Inner sweep: test each threshold in [minRange, maxRange].
162 while( threshold <= maxRange ){
// Partition the samples: group 1 = value >= threshold, group 0 = below,
// accumulating per-group per-class counts.
165 groupCounter[0] = groupCounter[1] = 0;
166 classProbabilities.setAllValues(0);
167 for(UINT i=0; i<M; i++){
168 groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
169 groupCounter[ groupIndex[i] ]++;
170 classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
// Normalize counts into per-group class probabilities (guarding empty groups).
174 for(UINT k=0; k<K; k++){
175 classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
176 classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
// Gini impurity of each side: sum_k p_k * (1 - p_k).
180 giniIndexL = giniIndexR = 0;
181 for(UINT k=0; k<K; k++){
182 giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
183 giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
// Weight each side's impurity by its share of the samples.
// NOTE(review): if groupCounter/M are integral types these divisions
// truncate to 0 or 1 — presumably both are Float; confirm the elided
// declarations.
185 weightL = groupCounter[0]/M;
186 weightR = groupCounter[1]/M;
187 error = (giniIndexL*weightL) + (giniIndexR*weightR);
// Track the best split seen so far (the minError = error update is elided).
190 if( error < minError ){
192 bestThreshold = threshold;
193 bestFeatureIndex = featureIndex;
// Report the winning split back to the caller and store it on this node.
202 featureIndex = bestFeatureIndex;
205 set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));
// Randomized split search ("Spilt" typo is part of the public interface
// and must be preserved). Same scoring as the iterative variant, but
// instead of sweeping the range it evaluates numSplittingSteps thresholds
// per feature; the threshold assignment itself is elided from this view —
// presumably drawn uniformly at random from the feature's range (confirm).
// Fragment: declarations of M, ranges, groupIndex, groupCounter, threshold,
// weightL/weightR, error, and the minError update are elided.
// @param numSplittingSteps  number of random thresholds tested per feature
// @param trainingData       labelled samples reaching this node
// @param features           candidate feature indices to search over
// @param classLabels        class labels present in the training data
// @param featureIndex       [out] best feature found
// @param minError           [in/out] lowest split error found
// @return false if there are no candidate features
210 bool DecisionTreeThresholdNode::computeBestSpiltBestRandomSpilt(
const UINT &numSplittingSteps,
const ClassificationData &trainingData,
const Vector< UINT > &features,
const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
// NOTE(review): uses .size() with a cast here while the iterative variant
// uses .getSize() — equivalent, but inconsistent style between the twins.
213 const UINT N = (UINT)features.size();
214 const UINT K = (UINT)classLabels.size();
// Guard: nothing to split on.
216 if( N == 0 )
return false;
219 UINT bestFeatureIndex = 0;
220 Float bestThreshold = 0;
222 Float giniIndexL = 0;
223 Float giniIndexR = 0;
// Outer loop over candidate features; inner loop over random thresholds.
234 for(UINT n=0; n<N; n++){
235 featureIndex = features[n];
236 for(UINT m=0; m<numSplittingSteps; m++){
// Partition the samples: group 1 = value >= threshold, group 0 = below,
// accumulating per-group per-class counts.
241 groupCounter[0] = groupCounter[1] = 0;
242 classProbabilities.setAllValues(0);
243 for(UINT i=0; i<M; i++){
244 groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
245 groupCounter[ groupIndex[i] ]++;
246 classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
// Normalize counts into per-group class probabilities (guarding empty groups).
250 for(UINT k=0; k<K; k++){
251 classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
252 classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
// Gini impurity of each side: sum_k p_k * (1 - p_k).
256 giniIndexL = giniIndexR = 0;
257 for(UINT k=0; k<K; k++){
258 giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
259 giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
// Weight each side's impurity by its share of the samples.
// NOTE(review): same integral-division hazard as the iterative variant —
// presumably groupCounter and M are Float; confirm the elided declarations.
261 weightL = groupCounter[0]/M;
262 weightR = groupCounter[1]/M;
263 error = (giniIndexL*weightL) + (giniIndexR*weightR);
// Track the best split seen so far (the minError = error update is elided).
266 if( error < minError ){
268 bestThreshold = threshold;
269 bestFeatureIndex = featureIndex;
// Report the winning split back to the caller and store it on this node.
275 featureIndex = bestFeatureIndex;
278 set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));
// saveParametersToFile (fragment): writes this node's split parameters to
// an already-open fstream. Error paths log and (presumably) return false;
// the guards, base-class save call, and return statements are elided.
287 errorLog <<
"saveParametersToFile(fstream &file) - File is not open!" << std::endl;
293 errorLog <<
"saveParametersToFile(fstream &file) - Failed to save DecisionTreeNode parameters to file!" << std::endl;
// Persist the two subclass-specific parameters in "Key: value" form,
// matching what loadParametersFromFile reads back below.
298 file <<
"FeatureIndex: " << featureIndex << std::endl;
299 file <<
"Threshold: " << threshold << std::endl;
// loadParametersFromFile (fragment): reads back the parameters written by
// saveParametersToFile, validating each "Key:" header before its value.
// The guards, word extraction (file >> word), and return statements are
// elided from this view.
308 errorLog <<
"loadParametersFromFile(fstream &file) - File is not open!" << std::endl;
314 errorLog <<
"loadParametersFromFile(fstream &file) - Failed to load DecisionTreeNode parameters from file!" << std::endl;
// Expect the "FeatureIndex:" header, then its value.
322 if( word !=
"FeatureIndex:" ){
323 errorLog <<
"loadParametersFromFile(fstream &file) - Failed to find FeatureIndex header!" << std::endl;
326 file >> featureIndex;
// Expect the "Threshold:" header (the value read is elided from this view).
329 if( word !=
"Threshold:" ){
330 errorLog <<
"loadParametersFromFile(fstream &file) - Failed to find Threshold header!" << std::endl;
UINT getFeatureIndex() const
virtual bool print() const
DecisionTreeThresholdNode * deepCopy() const
DecisionTreeThresholdNode()
virtual bool getModel(std::ostream &stream) const
virtual Node * deepCopyNode() const
virtual bool saveParametersToFile(std::fstream &file) const
UINT getNumSamples() const
virtual Node * deepCopyNode() const
virtual bool loadParametersFromFile(std::fstream &file)
virtual bool getModel(std::ostream &stream) const
virtual bool saveParametersToFile(std::fstream &file) const
virtual bool predict(const VectorFloat &x)
Float getThreshold() const
virtual ~DecisionTreeThresholdNode()
Vector< MinMax > getRanges() const
Float getRandomNumberUniform(Float minRange=0.0, Float maxRange=1.0)
bool set(const UINT nodeSize, const UINT featureIndex, const Float threshold, const VectorFloat &classProbabilities)
virtual bool loadParametersFromFile(std::fstream &file)
This file implements a DecisionTreeThresholdNode, which is a specific type of node used for a DecisionTree: it splits samples by comparing a single feature value against a learned threshold.