Nabla 1.0
Nabla - a DSL for automatic differentiation
Graph.h
Go to the documentation of this file.
1 #ifndef GRAPH_H
2 #define GRAPH_H
3 
4 #include "Operators.h"
5 #include "Dtypes.h"
6 #include <algorithm>
7 #include <fstream>
8 #include <iostream>
9 #include <string>
10 #include <vector>
11 
12 
13 namespace nb{
14 
15 class Graph{ // Computation graph API; node-building methods return Node* (definitions live in Graph.cpp).
16  public:
17  int count = 0; // Per-instance counter (purpose not visible here -- TODO confirm); non-static because "making this static causes issues". Defaulted to 0 so it is never read uninitialised.
18  std::vector<Node*> operators, constants, variables, scalars; // Node pointers grouped by kind. NOTE(review): ownership/deletion of these nodes is not visible in this header -- confirm in Graph.cpp.
19  Graph();
20  Node* _variable(std::string name, int m, int n, std::vector<std::vector<double>> vals); // presumably an m x n variable initialised from vals -- TODO confirm
21  Node* _variable(std::string name, int m, int n); // same as above without explicit initial values
22  Node* _constant(std::string name, int m, int n, std::vector<std::vector<double>> vals);
23  Node* _scalar_variable(std::string name, double data=0); // data defaults to 0
24  Node* _scalar_constant(std::string name, double data);
25  // Addition overloads: every combination of Node* and double operands.
26  Node* _add(Node* a, Node* b);
27  Node* _add(Node* a, double b);
28  Node* _add(double a, Node* b);
29  Node* _add(double a, double b);
30  // Subtraction overloads: every combination of Node* and double operands.
31  Node* _sub(Node* a, Node* b);
32  Node* _sub(Node* a, double b);
33  Node* _sub(double a, Node* b);
34  Node* _sub(double a, double b);
35  // Matrix operations.
36  Node* _matmul(Node* a, Node* b);
37  Node* _trans(Node* a);
38  // Unary operations on a single node.
39  Node* _neg(Node* a);
40  Node* _exp(Node* a);
41  Node* _sin(Node* a);
42  Node* _cos(Node* a);
43  Node* _tan(Node* a);
44  // Multiplication overloads: every combination of Node* and double operands.
45  Node* _mul(Node*a , Node*b);
46  Node* _mul(double a, Node*b);
47  Node* _mul(Node*a , double b);
48  Node* _mul(double a, double b);
49  // Division: only the double/double overload is declared.
50  Node* _div(double a, double b);
51  // NOTE(review): unlike _add/_sub/_mul there are no Node* overloads of _div -- confirm whether this is intentional.
52  // Graph traversal, differentiation and output utilities.
53  std::vector<Node*> topological_sort();
54  void backward(Node* f);
55  void generate_graph(Node* f);
56  void DFS(std::ostream& out, Node* f);
57 };
58 
59 } // namespace nb
60 
61 #endif
nb::Graph
Definition: Graph.h:15
nb::Graph::Graph
Graph()
Definition: Graph.cpp:5
nb
Definition: Dtypes.h:8
nb::Graph::_neg
Node * _neg(Node *a)
Definition: Graph.cpp:143
nb::Graph::backward
void backward(Node *f)
Definition: Graph.cpp:222
nb::Graph::constants
std::vector< Node * > constants
Definition: Graph.h:18
nb::Graph::generate_graph
void generate_graph(Node *f)
Definition: Graph.cpp:267
nb::Graph::_scalar_variable
Node * _scalar_variable(std::string name, double data=0)
Definition: Graph.cpp:202
nb::Graph::_sub
Node * _sub(Node *a, Node *b)
Definition: Graph.cpp:61
nb::Graph::_mul
Node * _mul(Node *a, Node *b)
Definition: Graph.cpp:101
nb::Graph::variables
std::vector< Node * > variables
Definition: Graph.h:18
nb::Graph::DFS
void DFS(std::ostream &out, Node *f)
Definition: Graph.cpp:256
nb::Graph::count
int count
Definition: Graph.h:17
nb::Graph::_matmul
Node * _matmul(Node *a, Node *b)
Definition: Graph.cpp:14
nb::Graph::_sin
Node * _sin(Node *a)
Definition: Graph.cpp:157
nb::Graph::operators
std::vector< Node * > operators
Definition: Graph.h:18
nb::Node
Definition: Node.h:9
nb::Graph::topological_sort
std::vector< Node * > topological_sort()
Definition: Graph.cpp:216
nb::Graph::_constant
Node * _constant(std::string name, int m, int n, std::vector< std::vector< double >> vals)
Definition: Graph.cpp:194
nb::Graph::_div
Node * _div(double a, double b)
Definition: Graph.cpp:129
nb::Graph::_exp
Node * _exp(Node *a)
Definition: Graph.cpp:150
nb::Graph::_add
Node * _add(Node *a, Node *b)
Definition: Graph.cpp:21
nb::Graph::_variable
Node * _variable(std::string name, int m, int n, std::vector< std::vector< double >> vals)
Definition: Graph.cpp:178
nb::Graph::_trans
Node * _trans(Node *a)
Definition: Graph.cpp:136
nb::Graph::_scalar_constant
Node * _scalar_constant(std::string name, double data)
Definition: Graph.cpp:209
nb::Graph::scalars
std::vector< Node * > scalars
Definition: Graph.h:18
nb::Graph::_cos
Node * _cos(Node *a)
Definition: Graph.cpp:164
Dtypes.h
Operators.h
nb::Graph::_tan
Node * _tan(Node *a)
Definition: Graph.cpp:171