21 #define GRT_DLL_EXPORTS
27 UINT Node::numNodeInstances = 0;
31 StringNodeMap::iterator iter = getMap()->find( nodeType );
32 if( iter == getMap()->end() ){
36 return iter->second();
48 debugLog.setProceedingText(
"[DEBUG Node]");
49 errorLog.setProceedingText(
"[ERROR Node]");
50 trainingLog.setProceedingText(
"[TRAINING Node]");
51 testingLog.setProceedingText(
"[TESTING Node]");
52 warningLog.setProceedingText(
"[WARNING Node]");
61 warningLog <<
"predict(const VectorFloat &x) - Base class not overwritten!" << std::endl;
66 warningLog <<
"predict(const VectorFloat &x) - Base class not overwritten!" << std::endl;
75 if( leftChild != NULL ){
84 if( rightChild != NULL ){
111 std::ostringstream stream;
113 std::cout << stream.str();
122 std::string tab =
"";
123 for(UINT i=0; i<depth; i++) tab +=
"\t";
125 stream << tab <<
"depth: " << depth <<
" isLeafNode: " << isLeafNode <<
" nodeID: " << nodeID << std::endl;
127 if( leftChild != NULL ){
128 stream << tab <<
"LeftChild: " << std::endl;
132 if( rightChild != NULL ){
133 stream << tab <<
"RightChild: " << std::endl;
144 errorLog <<
"save(fstream &file) - File is not open!" << std::endl;
148 file <<
"NodeType: " << nodeType << std::endl;
149 file <<
"Depth: " << depth << std::endl;
150 file <<
"NodeID: " << nodeID << std::endl;
151 file <<
"IsLeafNode: " << isLeafNode << std::endl;
157 file <<
"LeftChild\n";
158 if( !leftChild->
save( file ) ){
159 errorLog <<
"save(fstream &file) - Failed to save left child at depth: " << depth << std::endl;
166 file <<
"RightChild\n";
167 if( !rightChild->
save( file ) ){
168 errorLog <<
"save(fstream &file) - Failed to save right child at depth: " << depth << std::endl;
175 errorLog <<
"save(fstream &file) - Failed to save parameters to file at depth: " << depth << std::endl;
189 errorLog <<
"load(fstream &file) - File is not open!" << std::endl;
194 bool hasLeftChild =
false;
195 bool hasRightChild =
false;
198 if( word !=
"NodeType:" ){
199 errorLog <<
"load(fstream &file) - Failed to find Node header!" << std::endl;
205 if( word !=
"Depth:" ){
206 errorLog <<
"load(fstream &file) - Failed to find Depth header!" << std::endl;
212 if( word !=
"NodeID:" ){
213 errorLog <<
"load(fstream &file) - Failed to find NodeID header!" << std::endl;
219 if( word !=
"IsLeafNode:" ){
220 errorLog <<
"load(fstream &file) - Failed to find IsLeafNode header!" << std::endl;
226 if( word !=
"HasLeftChild:" ){
227 errorLog <<
"load(fstream &file) - Failed to find HasLeftChild header!" << std::endl;
230 file >> hasLeftChild;
233 if( word !=
"HasRightChild:" ){
234 errorLog <<
"load(fstream &file) - Failed to find HasRightChild header!" << std::endl;
237 file >> hasRightChild;
241 if( word !=
"LeftChild" ){
242 errorLog <<
"load(fstream &file) - Failed to find LeftChild header!" << std::endl;
246 leftChild->setParent(
this );
247 if( !leftChild->
load(file) ){
248 errorLog <<
"load(fstream &file) - Failed to load left child at depth: " << depth << std::endl;
255 if( word !=
"RightChild" ){
256 errorLog <<
"load(fstream &file) - Failed to find RightChild header!" << std::endl;
260 rightChild->setParent(
this );
261 if( !rightChild->
load( file ) ){
262 errorLog <<
"load(fstream &file) - Failed to load right child at depth: " << depth << std::endl;
269 errorLog <<
"loadParametersFromFile(fstream &file) - Failed to load parameters from file at depth: " << depth << std::endl;
285 node->setNodeID( nodeID );
286 node->setDepth( depth );
287 node->setIsLeafNode( isLeafNode );
292 node->leftChild->setParent( node );
298 node->rightChild->setParent( node );
317 return predictedNodeID;
320 UINT Node::getMaxDepth()
const {
322 UINT maxDepth = depth;
326 UINT maxLeftDepth = leftChild->getMaxDepth();
327 if( maxLeftDepth > maxDepth ){
328 maxDepth = maxLeftDepth;
334 UINT maxRightDepth = rightChild->getMaxDepth();
335 if( maxRightDepth > maxDepth ){
336 maxDepth = maxRightDepth;
348 return (parent != NULL);
352 return (leftChild != NULL);
356 return (rightChild != NULL);
359 bool Node::initNode(
Node *parent,
const UINT depth,
const UINT nodeID,
const bool isLeafNode){
360 this->parent = parent;
362 this->nodeID = nodeID;
363 this->isLeafNode = isLeafNode;
367 bool Node::setParent(
Node *parent){
368 this->parent = parent;
372 bool Node::setLeftChild(
Node *leftChild){
373 this->leftChild = leftChild;
377 bool Node::setRightChild(
Node *rightChild){
378 this->rightChild = rightChild;
382 bool Node::setDepth(
const UINT depth){
387 bool Node::setNodeID(
const UINT nodeID){
388 this->nodeID = nodeID;
392 bool Node::setIsLeafNode(
const bool isLeafNode){
393 this->isLeafNode = isLeafNode;
virtual bool print() const
std::string getNodeType() const
virtual bool getModel(std::ostream &stream) const
virtual bool save(std::fstream &file) const
virtual Node * deepCopyNode() const
bool getIsLeafNode() const
virtual bool loadParametersFromFile(std::fstream &file)
virtual bool saveParametersToFile(std::fstream &file) const
This class contains the main Node base class.
bool getHasParent() const
virtual bool computeFeatureWeights(VectorFloat &weights) const
Node * createNewInstance() const
UINT getPredictedNodeID() const
static Node * createInstanceFromString(std::string const &nodeType)
virtual bool computeLeafNodeWeights(MatrixFloat &weights) const
virtual bool load(std::fstream &file)
std::map< std::string, Node *(*)() > StringNodeMap
virtual bool predict(const VectorFloat &x)
bool getHasRightChild() const
bool getHasLeftChild() const