Nabla
1.0
Nabla - a DSL for automatic differentiation
Transpiler
include
Node.h
Go to the documentation of this file.
1
#ifndef NODE_H
2
#define NODE_H
3
4
#include <string>
#include <vector>

#include "Tensor.h"
6
7
namespace
nb
{
8
9
class
Node
{
10
public
:
11
std::string
name
;
12
void
print
();
13
int
count
;
14
bool
is_visited
=
false
;
15
bool
is_printed
=
false
;
16
bool
is_scalar
=
false
;
17
std::vector<Node*>
inputs
;
18
Tensor
data
;
19
double
ddata
;
//this stores the value of a scalar
20
Tensor
gradient
;
21
double
scalar_gradient
= 0;
//this stores the gradient with respect to a scalar
22
Node
();
23
Node
(
Tensor
&
data
);
24
Node
&
forward
(
const
Node
& a,
const
Node
& b);
25
virtual
void
backward
(){};
26
};
27
28
};
29
30
#endif
nb::Node::is_visited
bool is_visited
Definition:
Node.h:14
nb
Definition:
Dtypes.h:8
nb::Node::scalar_gradient
double scalar_gradient
Definition:
Node.h:21
nb::Node::Node
Node()
Definition:
Node.cpp:13
nb::Node::is_printed
bool is_printed
Definition:
Node.h:15
nb::Node::inputs
std::vector< Node * > inputs
Definition:
Node.h:17
nb::Node::print
void print()
Definition:
Node.cpp:4
nb::Node::data
Tensor data
Definition:
Node.h:18
nb::Node::is_scalar
bool is_scalar
Definition:
Node.h:16
nb::Tensor
Definition:
Tensor.h:12
nb::Node
Definition:
Node.h:9
nb::Node::ddata
double ddata
Definition:
Node.h:19
nb::Node::backward
virtual void backward()
Definition:
Node.h:25
Tensor.h
nb::Node::count
int count
Definition:
Node.h:13
nb::Node::name
std::string name
Definition:
Node.h:11
nb::Node::gradient
Tensor gradient
Definition:
Node.h:20
nb::Node::forward
Node & forward(const Node &a, const Node &b)
Generated by doxygen 1.8.17