Nabla  1.0
Nabla - a DSL for Automatic differentiation
Operators.h
Go to the documentation of this file.
1 #ifndef OP_H
2 #define OP_H
3 #include "Node.h"
4 #include <vector>
5 
6 namespace nb{
7 
8 class Operator: public Node{
9  public:
10  virtual void backward(){};
11  Operator(){
12  inputs.resize(0);
13  }
14 };
15 
16 class Add : public Operator{
17  public:
18  void backward();
19  Add(Node* a , Node* b, int count);
20  Node* forward(const Node* a, const Node* b);
21 };
22 
23 class Sub : public Operator{
24  public:
25  void backward();
26  Sub(Node* a , Node* b, int count);
27  Node* forward(const Node* a, const Node* b);
28 };
29 
30 class Multiply : public Operator{
31  public:
32  void backward();
33  Multiply(Node* a , Node* b , int count);
34  Node* forward(const Node* a, const Node* b);
35  // std::pair<Tensor, Tensor> backward(Tensor dout);
36 };
37 
38 class Mul : public Operator{
39  public:
40  void backward();
41  Mul(Node* a , Node* b , int count);
42  Node* forward(const Node* a, const Node* b);
43  // std::pair<Tensor, Tensor> backward(Tensor dout);
44 };
45 
46 class Transpose : public Operator{
47  public:
48  void backward();
49  Transpose(Node* a, int count);
50  Node* forward(const Node* a);
51 };
52 
53 class Negative : public Operator{
54  public:
55  void backward();
56  Negative(Node* a, int count);
57  Node* forward(const Node* a);
58 };
59 
60 class Exponential : public Operator{
61  public:
62  void backward();
63  Exponential(Node* a, int count);
64  Node* forward(const Node* a);
65 };
66 
67 class Division : public Operator{
68  public:
69  void backward();
70  Division(Node* a, Node* b, int count);
71  Node* forward(const Node* a, const Node* b);
72 };
73 
74 class Sin : public Operator{
75  public:
76  void backward();
77  Sin(Node* a, int count);
78  Node* forward(const Node* a);
79 };
80 
81 class Cos : public Operator{
82  public:
83  void backward();
84  Cos(Node* a, int count);
85  Node* forward(const Node* a);
86 };
87 
88 class Tan : public Operator{
89  public:
90  void backward();
91  Tan(Node* a, int count);
92  Node* forward(const Node* a);
93 };
94 
95 };
96 
97 #endif
nb::Sub::forward
Node * forward(const Node *a, const Node *b)
Definition: Operators.cpp:45
nb::Transpose
Definition: Operators.h:46
nb::Negative::Negative
Negative(Node *a, int count)
Definition: Operators.cpp:21
nb
Definition: Dtypes.h:8
nb::Negative::forward
Node * forward(const Node *a)
Definition: Operators.cpp:28
nb::Cos::Cos
Cos(Node *a, int count)
Definition: Operators.cpp:320
nb::Tan::backward
void backward()
Definition: Operators.cpp:383
nb::Mul
Definition: Operators.h:38
nb::Cos::backward
void backward()
Definition: Operators.cpp:344
nb::Exponential::forward
Node * forward(const Node *a)
Definition: Operators.cpp:256
nb::Tan::forward
Node * forward(const Node *a)
Definition: Operators.cpp:366
nb::Sin::forward
Node * forward(const Node *a)
Definition: Operators.cpp:288
nb::Mul::Mul
Mul(Node *a, Node *b, int count)
Definition: Operators.cpp:147
nb::Node::inputs
std::vector< Node * > inputs
Definition: Node.h:17
nb::Sin::Sin
Sin(Node *a, int count)
Definition: Operators.cpp:281
nb::Add::backward
void backward()
Definition: Operators.cpp:102
nb::Add::forward
Node * forward(const Node *a, const Node *b)
Definition: Operators.cpp:83
nb::Exponential::backward
void backward()
Definition: Operators.cpp:273
nb::Operator::backward
virtual void backward()
Definition: Operators.h:10
nb::Sub
Definition: Operators.h:23
nb::Multiply
Definition: Operators.h:30
nb::Exponential::Exponential
Exponential(Node *a, int count)
Definition: Operators.cpp:249
nb::Add
Definition: Operators.h:16
nb::Mul::forward
Node * forward(const Node *a, const Node *b)
Definition: Operators.cpp:155
nb::Sin::backward
void backward()
Definition: Operators.cpp:305
nb::Division::Division
Division(Node *a, Node *b, int count)
Definition: Operators.cpp:219
nb::Add::Add
Add(Node *a, Node *b, int count)
Definition: Operators.cpp:74
nb::Sub::Sub
Sub(Node *a, Node *b, int count)
Definition: Operators.cpp:37
nb::Sin
Definition: Operators.h:74
nb::Node
Definition: Node.h:9
Node.h
nb::Division::forward
Node * forward(const Node *a, const Node *b)
Definition: Operators.cpp:227
nb::Sub::backward
void backward()
Definition: Operators.cpp:63
nb::Mul::backward
void backward()
Definition: Operators.cpp:195
nb::Negative
Definition: Operators.h:53
nb::Division::backward
void backward()
Definition: Operators.cpp:239
nb::Transpose::forward
Node * forward(const Node *a)
Definition: Operators.cpp:12
nb::Division
Definition: Operators.h:67
nb::Multiply::forward
Node * forward(const Node *a, const Node *b)
Definition: Operators.cpp:124
nb::Multiply::Multiply
Multiply(Node *a, Node *b, int count)
Definition: Operators.cpp:116
nb::Operator
Definition: Operators.h:8
nb::Node::count
int count
Definition: Node.h:13
nb::Cos
Definition: Operators.h:81
nb::Tan::Tan
Tan(Node *a, int count)
Definition: Operators.cpp:359
nb::Operator::Operator
Operator()
Definition: Operators.h:11
nb::Negative::backward
void backward()
Definition: Operators.cpp:33
nb::Tan
Definition: Operators.h:88
nb::Exponential
Definition: Operators.h:60
nb::Transpose::Transpose
Transpose(Node *a, int count)
Definition: Operators.cpp:5
nb::Transpose::backward
void backward()
Definition: Operators.cpp:17
nb::Cos::forward
Node * forward(const Node *a)
Definition: Operators.cpp:327
nb::Multiply::backward
void backward()
Definition: Operators.cpp:141