Nabla  1.0
Nabla - a DSL for Automatic differentiation
Dtypes.h
Go to the documentation of this file.
1 #ifndef DTYPE_H
2 #define DTYPE_H
3 
4 #include "Node.h"
5 #include <string>
6 #include <vector>
7 
8 namespace nb{
9 class Variable: public Node{
10  public:
11  Variable(int m , int n, std::vector<std::vector<double>> vals, int count, std::string name){
12  data = Tensor(m, n, vals);
13  gradient = Tensor(m, n);
14  this->count = count;
15  this->name = name;
16  }
17 };
18 
19 class Constant: public Node{
20  public:
21  Constant(int m, int n, std::vector<std::vector<double>> vals, int count, std::string name){
22  data = Tensor(m, n, vals);
23  this->count = count;
24  this->name = name;
25  }
26 
27 };
28 
29 class Scalar: public Node{
30  public:
31  Scalar(double d_data, int count){
32  // this->ddata = d_data;
33  // scl_count = count;
34  // name = "Scalar:" + std::to_string(scl_count);
35  }
36 };
37 
38 class Scalar_Variable : public Scalar{
39  public:
40  Scalar_Variable(std::string name, double d_data, int count): Scalar(d_data, count){
41  this->ddata = d_data;
42  this->count = count;
43  is_scalar = true;
44  this->name = name;
45  }
46 };
47 
48 class Scalar_Constant : public Scalar{
49  public:
50  Scalar_Constant(std::string name, double d_data, int count): Scalar(d_data, count){
51  this->ddata = d_data; //this innitializes a scalar constant
52  this->count = count;
53  is_scalar = true;
54  this->name = name;
55  }
56 };
57 
58 };
59 #endif
nb
Definition: Dtypes.h:8
nb::Scalar::Scalar
Scalar(double d_data, int count)
Definition: Dtypes.h:31
nb::Constant::Constant
Constant(int m, int n, std::vector< std::vector< double >> vals, int count, std::string name)
Definition: Dtypes.h:21
nb::Scalar_Variable
Definition: Dtypes.h:38
nb::Node::data
Tensor data
Definition: Node.h:18
nb::Node::is_scalar
bool is_scalar
Definition: Node.h:16
nb::Scalar_Constant
Definition: Dtypes.h:48
nb::Tensor
Definition: Tensor.h:12
nb::Node
Definition: Node.h:9
Node.h
nb::Node::ddata
double ddata
Definition: Node.h:19
nb::Scalar
Definition: Dtypes.h:29
nb::Constant
Definition: Dtypes.h:19
nb::Scalar_Variable::Scalar_Variable
Scalar_Variable(std::string name, double d_data, int count)
Definition: Dtypes.h:40
nb::Variable
Definition: Dtypes.h:9
nb::Node::count
int count
Definition: Node.h:13
nb::Variable::Variable
Variable(int m, int n, std::vector< std::vector< double >> vals, int count, std::string name)
Definition: Dtypes.h:11
nb::Node::name
std::string name
Definition: Node.h:11
nb::Node::gradient
Tensor gradient
Definition: Node.h:20
nb::Scalar_Constant::Scalar_Constant
Scalar_Constant(std::string name, double d_data, int count)
Definition: Dtypes.h:50