Nabla  1.0
Nabla - a DSL for Automatic differentiation
Node.h
Go to the documentation of this file.
1 #ifndef NODE_H
2 #define NODE_H
3 
4 #include<string>
5 #include"Tensor.h"
6 
7 namespace nb{
8 
9 class Node{
10  public:
11  std::string name;
12  void print();
13  int count;
14  bool is_visited = false;
15  bool is_printed = false;
16  bool is_scalar = false;
17 std::vector<Node*> inputs;
18  Tensor data;
19  double ddata; //this stores the value of a scalar
20  Tensor gradient;
21  double scalar_gradient = 0; //this stores the gradient with respect to a scalar
22  Node();
23  Node(Tensor& data);
24  Node& forward(const Node& a, const Node& b);
25  virtual void backward(){};
26 };
27 
28 };
29 
30 #endif
nb::Node::is_visited
bool is_visited
Definition: Node.h:14
nb
Definition: Dtypes.h:8
nb::Node::scalar_gradient
double scalar_gradient
Definition: Node.h:21
nb::Node::Node
Node()
Definition: Node.cpp:13
nb::Node::is_printed
bool is_printed
Definition: Node.h:15
nb::Node::inputs
std::vector< Node * > inputs
Definition: Node.h:17
nb::Node::print
void print()
Definition: Node.cpp:4
nb::Node::data
Tensor data
Definition: Node.h:18
nb::Node::is_scalar
bool is_scalar
Definition: Node.h:16
nb::Tensor
Definition: Tensor.h:12
nb::Node
Definition: Node.h:9
nb::Node::ddata
double ddata
Definition: Node.h:19
nb::Node::backward
virtual void backward()
Definition: Node.h:25
Tensor.h
nb::Node::count
int count
Definition: Node.h:13
nb::Node::name
std::string name
Definition: Node.h:11
nb::Node::gradient
Tensor gradient
Definition: Node.h:20
nb::Node::forward
Node & forward(const Node &a, const Node &b)