// DecisionTreeThresholdNode(): the default constructor tags the node type
nodeType = "DecisionTreeThresholdNode";
// predict(const VectorFloat &x): route the sample to the right child if the split feature is at or above the threshold
if( x[ featureIndex ] >= threshold ) return true;
return false;
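The test above is the only decision this node makes at prediction time; the DecisionTree that owns the node descends into the right child when predict() returns true and into the left child otherwise, until a leaf is reached. A standalone sketch of that traversal (SketchNode and traverse are illustrative names, not GRT types):

#include <vector>

// Sketch only: walk a threshold tree until a leaf, using the same
// ">= threshold goes right" rule as predict() above.
struct SketchNode {
    bool isLeafNode = false;
    unsigned int featureIndex = 0;
    double threshold = 0.0;
    unsigned int predictedClass = 0;      // only meaningful for leaf nodes
    SketchNode *leftChild = nullptr;
    SketchNode *rightChild = nullptr;
};

unsigned int traverse( const SketchNode *node, const std::vector<double> &x ){
    while( !node->isLeafNode ){
        node = ( x[ node->featureIndex ] >= node->threshold ) ? node->rightChild : node->leftChild;
    }
    return node->predictedClass;
}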
// print(): format the model with getModel() and write it to std::cout
std::ostringstream stream;
if( getModel( stream ) ) std::cout << stream.str();
// getModel(std::ostream &stream): recursively write this node and its children to the stream
std::string tab = "";
for(UINT i=0; i<depth; i++) tab += "\t";

stream << tab << "depth: " << depth << " nodeSize: " << nodeSize << " featureIndex: " << featureIndex << " threshold: " << threshold << " isLeafNode: " << isLeafNode << std::endl;
stream << tab << "ClassProbabilities: ";
for(UINT i=0; i<classProbabilities.size(); i++){
    stream << classProbabilities[i] << "\t";
}
stream << std::endl;

if( leftChild != NULL ){
    stream << tab << "LeftChild: " << std::endl;
    leftChild->getModel( stream );
}

if( rightChild != NULL ){
    stream << tab << "RightChild: " << std::endl;
    rightChild->getModel( stream );
}
// deepCopyNode(): clone this node, then recursively clone and re-parent its children
DecisionTreeThresholdNode *node = new DecisionTreeThresholdNode;
node->isLeafNode = isLeafNode;
node->nodeID = nodeID;
node->predictedNodeID = predictedNodeID;
node->nodeSize = nodeSize;
node->featureIndex = featureIndex;
node->threshold = threshold;
node->classProbabilities = classProbabilities;

if( leftChild != NULL ){
    node->leftChild = leftChild->deepCopyNode();
    node->leftChild->setParent( node );
}
if( rightChild != NULL ){
    node->rightChild = rightChild->deepCopyNode();
    node->rightChild->setParent( node );
}

return dynamic_cast< Node* >( node );
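Because deepCopyNode() copies both children recursively and re-parents each copy with setParent(), copying the root of a tree produces a fully independent tree. The dynamic_cast back to Node* lets calling code hold the copy through the base Node interface without knowing the concrete node type.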
// set(nodeSize, featureIndex, threshold, classProbabilities): store the node's parameters
this->nodeSize = nodeSize;
this->featureIndex = featureIndex;
this->threshold = threshold;
this->classProbabilities = classProbabilities;
return true;
bool DecisionTreeThresholdNode::computeBestSpiltBestIterativeSpilt( const UINT &numSplittingSteps, const ClassificationData &trainingData,
    const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){

    const UINT M = trainingData.getNumSamples();
    const UINT N = features.getSize();
    const UINT K = classLabels.getSize();

    if( N == 0 ) return false;

    minError = grt_numeric_limits< Float >::max();
    UINT bestFeatureIndex = 0;
    Float bestThreshold = 0;
    Float error = 0;
    Float minRange = 0;
    Float maxRange = 0;
    Float step = 0;
    Float giniIndexL = 0;
    Float giniIndexR = 0;
    Float weightL = 0;
    Float weightR = 0;
    Vector< UINT > groupIndex(M);
    VectorFloat groupCounter(2,0);
    Vector< MinMax > ranges = trainingData.getRanges();
    MatrixFloat classProbabilities(K,2);

    //Loop over each feature and sweep numSplittingSteps evenly spaced thresholds across its range
    for(UINT n=0; n<N; n++){
        minRange = ranges[n].minValue;
        maxRange = ranges[n].maxValue;
        step = (maxRange-minRange)/Float(numSplittingSteps);
        threshold = minRange;
        featureIndex = features[n];
        while( threshold <= maxRange ){

            //Work out which group each sample falls into: left (0) or right (1)
            groupCounter[0] = groupCounter[1] = 0;
            classProbabilities.setAllValues(0);
            for(UINT i=0; i<M; i++){
                groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
                groupCounter[ groupIndex[i] ]++;
                classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
            }

            //Normalize the per-group class counts into class probabilities
            for(UINT k=0; k<K; k++){
                classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
                classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
            }

            //Compute the Gini index of each group and the size-weighted split error
            giniIndexL = giniIndexR = 0;
            for(UINT k=0; k<K; k++){
                giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
                giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
            }
            weightL = groupCounter[0]/M;
            weightR = groupCounter[1]/M;
            error = (giniIndexL*weightL) + (giniIndexR*weightR);

            //Keep the feature and threshold with the lowest error found so far
            if( error < minError ){
                minError = error;
                bestThreshold = threshold;
                bestFeatureIndex = featureIndex;
            }

            //Move on to the next candidate threshold
            threshold += step;
        }
    }

    //Return the best feature index and store this node's parameters
    featureIndex = bestFeatureIndex;
    set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));

    return true;
}
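The quantity being minimized is the size-weighted Gini impurity of the two groups a candidate threshold creates. A self-contained sketch of that calculation, using plain std::vector instead of GRT's containers (splitError is an illustrative name):

#include <cstddef>
#include <vector>

// Illustrative sketch of the split error used above:
//   error = (N_L/N) * sum_k pL_k*(1-pL_k)  +  (N_R/N) * sum_k pR_k*(1-pR_k)
// where pL_k / pR_k are the class probabilities in the left/right group.
double splitError( const std::vector<double> &leftClassCounts,
                   const std::vector<double> &rightClassCounts ){
    double nL = 0, nR = 0;
    for( double c : leftClassCounts )  nL += c;
    for( double c : rightClassCounts ) nR += c;
    const double n = nL + nR;
    double giniL = 0, giniR = 0;
    for( std::size_t k = 0; k < leftClassCounts.size(); k++ ){
        const double pL = nL > 0 ? leftClassCounts[k] / nL : 0;
        const double pR = nR > 0 ? rightClassCounts[k] / nR : 0;
        giniL += pL * ( 1.0 - pL );
        giniR += pR * ( 1.0 - pR );
    }
    return ( nL / n ) * giniL + ( nR / n ) * giniR;
}

A perfectly separating split scores 0, while a split that leaves both groups an even two-class mixture scores 0.5.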
bool DecisionTreeThresholdNode::computeBestSpiltBestRandomSpilt( const UINT &numSplittingSteps, const ClassificationData &trainingData,
    const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){

    const UINT M = trainingData.getNumSamples();
    const UINT N = (UINT)features.size();
    const UINT K = (UINT)classLabels.size();

    if( N == 0 ) return false;

    minError = grt_numeric_limits< Float >::max();
    UINT bestFeatureIndex = 0;
    Float bestThreshold = 0;
    Float error = 0;
    Float giniIndexL = 0;
    Float giniIndexR = 0;
    Float weightL = 0;
    Float weightR = 0;
    Random random;
    Vector< UINT > groupIndex(M);
    VectorFloat groupCounter(2,0);
    Vector< MinMax > ranges = trainingData.getRanges();
    MatrixFloat classProbabilities(K,2);

    //Loop over each feature and evaluate numSplittingSteps randomly chosen thresholds
    for(UINT n=0; n<N; n++){
        featureIndex = features[n];
        for(UINT m=0; m<numSplittingSteps; m++){

            //Pick a random threshold inside this feature's observed range
            threshold = random.getRandomNumberUniform(ranges[n].minValue,ranges[n].maxValue);

            //Work out which group each sample falls into: left (0) or right (1)
            groupCounter[0] = groupCounter[1] = 0;
            classProbabilities.setAllValues(0);
            for(UINT i=0; i<M; i++){
                groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
                groupCounter[ groupIndex[i] ]++;
                classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
            }

            //Normalize the per-group class counts into class probabilities
            for(UINT k=0; k<K; k++){
                classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
                classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
            }

            //Compute the Gini index of each group and the size-weighted split error
            giniIndexL = giniIndexR = 0;
            for(UINT k=0; k<K; k++){
                giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
                giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
            }
            weightL = groupCounter[0]/M;
            weightR = groupCounter[1]/M;
            error = (giniIndexL*weightL) + (giniIndexR*weightR);

            //Keep the feature and threshold with the lowest error found so far
            if( error < minError ){
                minError = error;
                bestThreshold = threshold;
                bestFeatureIndex = featureIndex;
            }
        }
    }

    //Return the best feature index and store this node's parameters
    featureIndex = bestFeatureIndex;
    set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));

    return true;
}
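The random variant differs from the iterative one only in how candidate thresholds are produced: rather than sweeping evenly spaced thresholds across a feature's range, it evaluates numSplittingSteps thresholds drawn uniformly at random from that range. A minimal sketch of such a draw using the standard library (GRT itself uses its own Random utilities, e.g. getRandomNumberUniform):

#include <random>

// Sketch only: draw one candidate threshold uniformly from a feature's observed range.
double randomThreshold( double minValue, double maxValue, std::mt19937 &rng ){
    std::uniform_real_distribution<double> dist( minValue, maxValue );
    return dist( rng );
}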
// saveParametersToFile(std::fstream &file): write this node's parameters after the base DecisionTreeNode parameters
if( !file.is_open() ){
    errorLog << "saveParametersToFile(fstream &file) - File is not open!" << std::endl;
    return false;
}
if( !DecisionTreeNode::saveParametersToFile( file ) ){
    errorLog << "saveParametersToFile(fstream &file) - Failed to save DecisionTreeNode parameters to file!" << std::endl;
    return false;
}
file << "FeatureIndex: " << featureIndex << std::endl;
file << "Threshold: " << threshold << std::endl;
return true;
// loadParametersFromFile(std::fstream &file): read the base DecisionTreeNode parameters, then the FeatureIndex and Threshold values
if( !file.is_open() ){
    errorLog << "loadParametersFromFile(fstream &file) - File is not open!" << std::endl;
    return false;
}
if( !DecisionTreeNode::loadParametersFromFile( file ) ){
    errorLog << "loadParametersFromFile(fstream &file) - Failed to load DecisionTreeNode parameters from file!" << std::endl;
    return false;
}
std::string word;
file >> word;
if( word != "FeatureIndex:" ){
    errorLog << "loadParametersFromFile(fstream &file) - Failed to find FeatureIndex header!" << std::endl;
    return false;
}
file >> featureIndex;
file >> word;
if( word != "Threshold:" ){
    errorLog << "loadParametersFromFile(fstream &file) - Failed to find Threshold header!" << std::endl;
    return false;
}
file >> threshold;
return true;
Member functions and helpers referenced in this file:
UINT getFeatureIndex() const
virtual bool print() const
DecisionTreeThresholdNode * deepCopy() const
DecisionTreeThresholdNode()
virtual bool getModel(std::ostream &stream) const
unsigned int getSize() const
virtual Node * deepCopyNode() const
virtual bool saveParametersToFile(std::fstream &file) const
UINT getNumSamples() const
virtual bool loadParametersFromFile(std::fstream &file)
virtual bool predict(const VectorFloat &x)
Float getThreshold() const
virtual ~DecisionTreeThresholdNode()
Vector< MinMax > getRanges() const
Float getRandomNumberUniform(Float minRange=0.0, Float maxRange=1.0)
bool set(const UINT nodeSize, const UINT featureIndex, const Float threshold, const VectorFloat &classProbabilities)
This file implements a DecisionTreeThresholdNode, which is a specific type of node used for a DecisionTree.
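A minimal usage sketch of the node on its own, outside a trained DecisionTree. It assumes the GRT headers are installed as GRT/GRT.h; the feature index, threshold, and probabilities below are made-up values for illustration:

#include <GRT/GRT.h>
using namespace GRT;

int main(){
    // Configure a single threshold node by hand: 100 training samples reached it,
    // it splits on feature 1 at a threshold of 0.5, with made-up class probabilities.
    DecisionTreeThresholdNode node;
    VectorFloat classProbabilities(2);
    classProbabilities[0] = 0.3;
    classProbabilities[1] = 0.7;
    node.set( 100, 1, 0.5, classProbabilities );

    // predict() returns true if the sample should be routed to the right child
    VectorFloat x(3);
    x[0] = 0.2; x[1] = 0.8; x[2] = 0.1;
    bool goRight = node.predict( x );   // true here, because x[1] >= 0.5

    node.print();                       // writes the node's model to std::cout
    return goRight ? 0 : 1;
}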