21 #define GRT_DLL_EXPORTS 27 UINT Node::numNodeInstances = 0;
31 StringNodeMap::iterator iter = getMap()->find( nodeType );
32 if( iter == getMap()->end() ){
36 return iter->second();
57 warningLog << __GRT_LOG__ <<
" Base class not overwritten!" << std::endl;
62 warningLog << __GRT_LOG__ <<
" Base class not overwritten!" << std::endl;
71 if( leftChild != NULL ){
80 if( rightChild != NULL ){
107 std::ostringstream stream;
109 std::cout << stream.str();
118 std::string tab =
"";
119 for(UINT i=0; i<depth; i++) tab +=
"\t";
121 stream << tab <<
"depth: " << depth <<
" isLeafNode: " << isLeafNode <<
" nodeID: " << nodeID << std::endl;
123 if( leftChild != NULL ){
124 stream << tab <<
"LeftChild: " << std::endl;
128 if( rightChild != NULL ){
129 stream << tab <<
"RightChild: " << std::endl;
140 errorLog <<
"save(fstream &file) - File is not open!" << std::endl;
144 file <<
"NodeType: " << nodeType << std::endl;
145 file <<
"Depth: " << depth << std::endl;
146 file <<
"NodeID: " << nodeID << std::endl;
147 file <<
"IsLeafNode: " << isLeafNode << std::endl;
153 file <<
"LeftChild\n";
154 if( !leftChild->
save( file ) ){
155 errorLog <<
"save(fstream &file) - Failed to save left child at depth: " << depth << std::endl;
162 file <<
"RightChild\n";
163 if( !rightChild->
save( file ) ){
164 errorLog <<
"save(fstream &file) - Failed to save right child at depth: " << depth << std::endl;
171 errorLog <<
"save(fstream &file) - Failed to save parameters to file at depth: " << depth << std::endl;
185 errorLog <<
"load(fstream &file) - File is not open!" << std::endl;
190 bool hasLeftChild =
false;
191 bool hasRightChild =
false;
194 if( word !=
"NodeType:" ){
195 errorLog <<
"load(fstream &file) - Failed to find Node header!" << std::endl;
201 if( word !=
"Depth:" ){
202 errorLog <<
"load(fstream &file) - Failed to find Depth header!" << std::endl;
208 if( word !=
"NodeID:" ){
209 errorLog <<
"load(fstream &file) - Failed to find NodeID header!" << std::endl;
215 if( word !=
"IsLeafNode:" ){
216 errorLog <<
"load(fstream &file) - Failed to find IsLeafNode header!" << std::endl;
222 if( word !=
"HasLeftChild:" ){
223 errorLog <<
"load(fstream &file) - Failed to find HasLeftChild header!" << std::endl;
226 file >> hasLeftChild;
229 if( word !=
"HasRightChild:" ){
230 errorLog <<
"load(fstream &file) - Failed to find HasRightChild header!" << std::endl;
233 file >> hasRightChild;
237 if( word !=
"LeftChild" ){
238 errorLog <<
"load(fstream &file) - Failed to find LeftChild header!" << std::endl;
242 leftChild->setParent(
this );
243 if( !leftChild->
load(file) ){
244 errorLog <<
"load(fstream &file) - Failed to load left child at depth: " << depth << std::endl;
251 if( word !=
"RightChild" ){
252 errorLog <<
"load(fstream &file) - Failed to find RightChild header!" << std::endl;
256 rightChild->setParent(
this );
257 if( !rightChild->
load( file ) ){
258 errorLog <<
"load(fstream &file) - Failed to load right child at depth: " << depth << std::endl;
265 errorLog <<
"loadParametersFromFile(fstream &file) - Failed to load parameters from file at depth: " << depth << std::endl;
281 node->setNodeID( nodeID );
282 node->setDepth( depth );
283 node->setIsLeafNode( isLeafNode );
286 if( this->leftChild ){
287 node->setLeftChild( this->leftChild->
deepCopy() );
288 node->leftChild->setParent( node );
292 if( this->rightChild ){
293 node->setRightChild( rightChild->
deepCopy() );
294 node->rightChild->setParent( node );
313 return predictedNodeID;
316 UINT Node::getMaxDepth()
const {
318 UINT maxDepth = depth;
322 UINT maxLeftDepth = leftChild->getMaxDepth();
323 if( maxLeftDepth > maxDepth ){
324 maxDepth = maxLeftDepth;
330 UINT maxRightDepth = rightChild->getMaxDepth();
331 if( maxRightDepth > maxDepth ){
332 maxDepth = maxRightDepth;
344 return (parent != NULL);
348 return (leftChild != NULL);
352 return (rightChild != NULL);
355 bool Node::initNode(
Node *parent,
const UINT depth,
const UINT nodeID,
const bool isLeafNode){
356 this->parent = parent;
358 this->nodeID = nodeID;
359 this->isLeafNode = isLeafNode;
363 bool Node::setParent(
Node *parent){
364 this->parent = parent;
368 bool Node::setLeftChild(
Node *leftChild){
369 this->leftChild = leftChild;
373 bool Node::setRightChild(
Node *rightChild){
374 this->rightChild = rightChild;
378 bool Node::setDepth(
const UINT depth){
383 bool Node::setNodeID(
const UINT nodeID){
384 this->nodeID = nodeID;
388 bool Node::setIsLeafNode(
const bool isLeafNode){
389 this->isLeafNode = isLeafNode;
virtual bool getModel(std::ostream &stream) const override
virtual bool predict_(VectorFloat &x) override
std::string getNodeType() const
Node(const std::string id="Node")
bool getIsLeafNode() const
virtual bool loadParametersFromFile(std::fstream &file)
virtual bool saveParametersToFile(std::fstream &file) const
This class contains the main Node base class.
bool getHasParent() const
virtual bool save(std::fstream &file) const override
virtual bool computeFeatureWeights(VectorFloat &weights) const
Node * createNewInstance() const
UINT getPredictedNodeID() const
virtual bool print() const override
static Node * createInstanceFromString(std::string const &nodeType)
virtual bool load(std::fstream &file) override
virtual bool computeLeafNodeWeights(MatrixFloat &weights) const
virtual bool clear() override
This is the main base class that all GRT machine learning algorithms should inherit from...
std::map< std::string, Node *(*)() > StringNodeMap
bool getHasRightChild() const
bool getHasLeftChild() const
virtual Node * deepCopy() const