Nabla
1.0
Nabla - a DSL for automatic differentiation
Transpiler
include
Dtypes.h
Go to the documentation of this file.
1
#ifndef DTYPE_H
2
#define DTYPE_H
3
4
#include "
Node.h
"
5
#include <string>
#include <utility>
#include <vector>
7
8
namespace
nb
{
9
class
Variable
:
public
Node
{
10
public
:
11
Variable
(
int
m ,
int
n, std::vector<std::vector<double>> vals,
int
count
, std::string
name
){
12
data
=
Tensor
(m, n, vals);
13
gradient
=
Tensor
(m, n);
14
this->count =
count
;
15
this->name =
name
;
16
}
17
};
18
19
class
Constant
:
public
Node
{
20
public
:
21
Constant
(
int
m,
int
n, std::vector<std::vector<double>> vals,
int
count
, std::string
name
){
22
data
=
Tensor
(m, n, vals);
23
this->count =
count
;
24
this->name =
name
;
25
}
26
27
};
28
29
class
Scalar
:
public
Node
{
30
public
:
31
Scalar
(
double
d_data,
int
count
){
32
// this->ddata = d_data;
33
// scl_count = count;
34
// name = "Scalar:" + std::to_string(scl_count);
35
}
36
};
37
38
class
Scalar_Variable
:
public
Scalar
{
39
public
:
40
Scalar_Variable
(std::string
name
,
double
d_data,
int
count
):
Scalar
(d_data,
count
){
41
this->
ddata
= d_data;
42
this->count =
count
;
43
is_scalar
=
true
;
44
this->name =
name
;
45
}
46
};
47
48
class
Scalar_Constant
:
public
Scalar
{
49
public
:
50
Scalar_Constant
(std::string
name
,
double
d_data,
int
count
):
Scalar
(d_data,
count
){
51
this->
ddata
= d_data;
//this innitializes a scalar constant
52
this->count =
count
;
53
is_scalar
=
true
;
54
this->name =
name
;
55
}
56
};
57
58
};
59
#endif
nb
Definition:
Dtypes.h:8
nb::Scalar::Scalar
Scalar(double d_data, int count)
Definition:
Dtypes.h:31
nb::Constant::Constant
Constant(int m, int n, std::vector< std::vector< double >> vals, int count, std::string name)
Definition:
Dtypes.h:21
nb::Scalar_Variable
Definition:
Dtypes.h:38
nb::Node::data
Tensor data
Definition:
Node.h:18
nb::Node::is_scalar
bool is_scalar
Definition:
Node.h:16
nb::Scalar_Constant
Definition:
Dtypes.h:48
nb::Tensor
Definition:
Tensor.h:12
nb::Node
Definition:
Node.h:9
Node.h
nb::Node::ddata
double ddata
Definition:
Node.h:19
nb::Scalar
Definition:
Dtypes.h:29
nb::Constant
Definition:
Dtypes.h:19
nb::Scalar_Variable::Scalar_Variable
Scalar_Variable(std::string name, double d_data, int count)
Definition:
Dtypes.h:40
nb::Variable
Definition:
Dtypes.h:9
nb::Node::count
int count
Definition:
Node.h:13
nb::Variable::Variable
Variable(int m, int n, std::vector< std::vector< double >> vals, int count, std::string name)
Definition:
Dtypes.h:11
nb::Node::name
std::string name
Definition:
Node.h:11
nb::Node::gradient
Tensor gradient
Definition:
Node.h:20
nb::Scalar_Constant::Scalar_Constant
Scalar_Constant(std::string name, double d_data, int count)
Definition:
Dtypes.h:50
Generated by
1.8.17