Nabla  1.0
Nabla - a DSL for Automatic differentiation
ast.h
Go to the documentation of this file.
1 #ifndef AST_H
2 #define AST_H
3 
#include <cstring>
#include <iostream>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

#include "sym.h"
12 
13 
// Current lexer position; NOTE(review): presumably maintained by the scanner.
extern int yylineno;
extern int yycolumn;

// Common base of every AST node.
class Node;

// Root of the AST and the global pointer to it
// (NOTE(review): presumably assigned by the parser).
class Start;
extern Start *root;

// Node types of the declare section.
class Decl;
25 enum class GradSpecifier
26 {
27  CNS,
28  VAR
29 };
30 
// Data type named in a declaration; also used as the inferred type of an Expr.
enum class TypeSpecifier : int
{
    CHAR = 0,
    INT = 1,
    FLOAT = 2,
    BOOL = 3,
    TENSOR = 4,
};
39 
// Remaining declare-section node types.
class ConstValue;
class InitDeclarator;
class Declarator;
class Initializer;

// Operations section: assignment statements.
class AssgnStmt;
// Built-in library functions an expression can apply (see UnaryExpr::libfunc).
enum class LibFuncs : int
{
    SIN = 0,
    COS = 1,
    LOG = 2,
    EXP = 3,
    TRANSPOSE = 4,
};
// NOTE(review): this listing of `enum class AssignmentOperator` is incomplete.
// The enum header (listing line 55) and listing lines 58-62 are missing from the
// extraction. The member index names five further enumerators: AST_ADD_ASSIGN,
// AST_SUB_ASSIGN, AST_MUL_ASSIGN, AST_DIV_ASSIGN and AST_AT_ASSIGN. Their order,
// and therefore their underlying values, cannot be recovered here — restore from
// the real ast.h.
// AST_ASSIGN is the plain `=` operator; the value is carried in AssgnStmt::op.
56 {
57  AST_ASSIGN, // only =
63 };
64 
// Expression node hierarchy (BinaryExpr and UnaryExpr derive from Expr).
class Expr;
class BinaryExpr;
class UnaryExpr;

// Gradient section.
class GradStmt;
// Kind of a gradient-section statement (stored in GradStmt::grad_type).
enum class GradType : int
{
    GRAD = 0,
    BACKWARD = 1,
    PRINT = 2,
};
77 
/// Common base class for all AST nodes.
///
/// Carries no state. The virtual destructor makes it safe to destroy a derived
/// node through a `Node *`.
class Node
{
public:
    Node();
    virtual ~Node() = default;

    // Planned, not yet implemented:
    //   virtual void print() = 0;
    //   int row_num, col_num;
    //   a codegen() hook for LLVM IR generation
};
89 
90 // Start Class stores the pointers to all the declarations, expressions and gradients
91 class Start : public Node
92 {
93 public:
94  std::vector<class Decl *> *DeclList;
95  std::vector<class AssgnStmt *> *AssgnStmtList;
96  std::vector<class GradStmt *> *GradStmtList;
97  std::unordered_map<std::string, SymTabItem> *symbolTable;
98  Start(std::vector<class Decl *> *DeclList, std::vector<class AssgnStmt *> *AssgnStmtList, std::vector<class GradStmt *> *GradStmtList, std::unordered_map<std::string, SymTabItem> *symbolTable);
99  virtual ~Start() = default;
100  void transpile(std::ostream &out, int tab = 0) const;
101 };
102 
103 // Decl Class stores the declarations of a variable such as their gradient specifier, data type and the pointer to the initializer
104 // It also stores the initial value of the variable if it is initialized
105 class Decl : public Node
106 {
107 public:
112  virtual ~Decl() = default;
113  void transpile(std::ostream &out, int tab = 0) const;
114 };
115 
116 // InitDeclarator Class stores the pointer to the declarator and the pointer to the initializer
117 class InitDeclarator : public Node
118 {
119 public:
123  virtual ~InitDeclarator() = default;
124  void transpile(std::ostream &out, int tab = 0) const;
125 };
126 // Declarator Class stores the name of the variable and the dimensions of the variable
127 class Declarator : public Node
128 {
129 public:
130  std::string name;
131  std::vector<int> Dimensions;
132  Declarator(std::string);
133  virtual ~Declarator() = default;
134  void transpile(std::ostream &out, int tab = 0) const;
135 };
136 
137 class ConstValue : public Node
138 {
139 public:
141  {
142  int int_val;
143  float float_val;
144  };
146  bool isInt;
147  // ConstValue(int value, bool isInt = false);
148  ConstValue(int value);
149  ConstValue(float value);
150  virtual ~ConstValue() = default;
151 };
152 
153 // Initializer Class stores the value of the variable using a union structure
154 // It also stores the pointers to the initializers of the elements of the variable if it is an array
155 class Initializer : public Node
156 {
157 public:
159  {
161  std::vector<Initializer *> *InitializerList;
162  constexpr type_value() : cvalue(nullptr) {}
164  };
166  bool isScalar;
167  Initializer(ConstValue *value);
168  Initializer(std::vector<Initializer *> *InitializerList);
169  // Initializer(ConstValue*, std::vector<Initializer*>);
170 
171  void printInitializerList();
172  virtual ~Initializer() = default;
173  void transpile(std::ostream &out, int tab = 0) const;
174 };
175 
176 // Operations
177 class AssgnStmt : public Node
178 {
179 public:
181  // Declarator *declarator;
182  // Initializer *initializer;
183  // AssgnStmt(Declarator*, Initializer*);
184  std::string name;
185  std::optional<AssignmentOperator> op;
187  AssgnStmt(std::string, std::optional<AssignmentOperator>, Expr *);
188  virtual ~AssgnStmt() = default;
189  void transpile(std::ostream &out, int tab = 0) const;
190 };
191 
192 class Expr : public Node
193 {
194 public:
195  // below two to be initialized after ast generation and before semantic analysis of operations part
196  std::vector<int> dimensions;
198 
199  Expr();
200  virtual void printExpression();
201  virtual void initialize_expression_node_info(std::unordered_map<std::string, SymTabItem> *symbolTable);
202  virtual ~Expr() = default;
203  virtual void transpile(std::ostream &out, int tab = 0) const;
204 };
205 
206 class BinaryExpr : public Expr
207 {
208 public:
210  char op;
211  BinaryExpr(Expr *lhs, Expr *rhs, char op);
212  virtual ~BinaryExpr() = default;
213  virtual void printExpression() override;
214  virtual void initialize_expression_node_info(std::unordered_map<std::string, SymTabItem> *symbolTable) override;
215  void transpile(std::ostream &out, int tab = 0) const override;
216 };
217 
218 class UnaryExpr : public Expr
219 {
220 public:
221  Expr *expr; // expr for libfunc if present
222  std::string identifier;
224  std::optional<LibFuncs> libfunc;
225 
226  UnaryExpr(Expr *expr, std::optional<LibFuncs> libfunc, std::string identifier, ConstValue *cvalue);
227  virtual ~UnaryExpr() = default;
228  virtual void printExpression() override;
229  virtual void initialize_expression_node_info(std::unordered_map<std::string, SymTabItem> *symbolTable) override;
230  void transpile(std::ostream &out, int tab = 0) const override;
231 };
232 
233 // Gradient
234 class GradStmt : public Node
235 {
236 public:
237  // Declarator *declarator;
238  // Initializer *initializer;
239  // GradStmt(Declarator*, Initializer*);
241  std::string name;
242  GradStmt(GradType, std::string);
243  virtual ~GradStmt() = default;
244  void transpile(std::ostream &out, int tab = 0) const;
245 };
246 
247 #endif
TypeSpecifier::BOOL
@ BOOL
Decl::transpile
void transpile(std::ostream &out, int tab=0) const
Definition: ast.cpp:544
ConstValue::isInt
bool isInt
Definition: ast.h:146
AssignmentOperator
AssignmentOperator
Definition: ast.h:55
GradType::GRAD
@ GRAD
AssignmentOperator::AST_AT_ASSIGN
@ AST_AT_ASSIGN
Initializer::type_value
Definition: ast.h:158
UnaryExpr::libfunc
std::optional< LibFuncs > libfunc
Definition: ast.h:224
InitDeclarator::transpile
void transpile(std::ostream &out, int tab=0) const
Definition: ast.cpp:576
GradType::PRINT
@ PRINT
Initializer::type_value::~type_value
~type_value()
Definition: ast.h:163
AssgnStmt::col_num
int col_num
Definition: ast.h:180
UnaryExpr::UnaryExpr
UnaryExpr(Expr *expr, std::optional< LibFuncs > libfunc, std::string identifier, ConstValue *cvalue)
Definition: ast.cpp:309
GradStmt::transpile
void transpile(std::ostream &out, int tab=0) const
Definition: ast.cpp:683
AssgnStmt::op
std::optional< AssignmentOperator > op
Definition: ast.h:185
Expr::~Expr
virtual ~Expr()=default
Decl::~Decl
virtual ~Decl()=default
AssgnStmt::name
std::string name
Definition: ast.h:184
BinaryExpr::rhs
Expr * rhs
Definition: ast.h:209
sym.h
InitDeclarator::~InitDeclarator
virtual ~InitDeclarator()=default
Declarator
Definition: ast.h:127
ConstValue::value
inbuilt_type value
Definition: ast.h:145
AssgnStmt::AssgnStmt
AssgnStmt(std::string, std::optional< AssignmentOperator >, Expr *)
Definition: ast.cpp:103
Initializer
Definition: ast.h:155
AssgnStmt::transpile
void transpile(std::ostream &out, int tab=0) const
Definition: ast.cpp:658
ConstValue::inbuilt_type::int_val
int int_val
Definition: ast.h:142
Initializer::Initializer
Initializer(ConstValue *value)
Definition: ast.cpp:62
Declarator::transpile
void transpile(std::ostream &out, int tab=0) const
Definition: ast.cpp:597
Start::DeclList
std::vector< class Decl * > * DeclList
Definition: ast.h:94
InitDeclarator::initializer
Initializer * initializer
Definition: ast.h:121
GradStmt::name
std::string name
Definition: ast.h:241
Expr
Definition: ast.h:192
UnaryExpr::expr
Expr * expr
Definition: ast.h:221
root
Start * root
GradStmt::GradStmt
GradStmt(GradType, std::string)
Definition: ast.cpp:472
BinaryExpr::BinaryExpr
BinaryExpr(Expr *lhs, Expr *rhs, char op)
Definition: ast.cpp:121
TypeSpecifier::FLOAT
@ FLOAT
AssignmentOperator::AST_DIV_ASSIGN
@ AST_DIV_ASSIGN
Expr::initialize_expression_node_info
virtual void initialize_expression_node_info(std::unordered_map< std::string, SymTabItem > *symbolTable)
Definition: ast.cpp:119
ConstValue::inbuilt_type
Definition: ast.h:140
Expr::printExpression
virtual void printExpression()
Definition: ast.cpp:117
GradSpecifier::CNS
@ CNS
Initializer::val
type_value val
Definition: ast.h:165
LibFuncs::TRANSPOSE
@ TRANSPOSE
Declarator::~Declarator
virtual ~Declarator()=default
AssgnStmt::expr
Expr * expr
Definition: ast.h:186
Expr::DataType
TypeSpecifier DataType
Definition: ast.h:197
ConstValue
Definition: ast.h:137
Start::~Start
virtual ~Start()=default
UnaryExpr::identifier
std::string identifier
Definition: ast.h:222
Initializer::transpile
void transpile(std::ostream &out, int tab=0) const
Definition: ast.cpp:619
yylineno
int yylineno
LibFuncs::LOG
@ LOG
ConstValue::ConstValue
ConstValue(int value)
Definition: ast.cpp:50
Initializer::~Initializer
virtual ~Initializer()=default
Initializer::type_value::cvalue
ConstValue * cvalue
Definition: ast.h:160
BinaryExpr::op
char op
Definition: ast.h:210
GradStmt::grad_type
GradType grad_type
Definition: ast.h:240
Start::AssgnStmtList
std::vector< class AssgnStmt * > * AssgnStmtList
Definition: ast.h:95
LibFuncs
LibFuncs
Definition: ast.h:47
Decl::GradType
GradSpecifier GradType
Definition: ast.h:108
ConstValue::inbuilt_type::float_val
float float_val
Definition: ast.h:143
GradStmt::~GradStmt
virtual ~GradStmt()=default
LibFuncs::EXP
@ EXP
GradType::BACKWARD
@ BACKWARD
GradStmt
Definition: ast.h:234
Expr::transpile
virtual void transpile(std::ostream &out, int tab=0) const
Definition: ast.cpp:679
Decl
Definition: ast.h:105
UnaryExpr::~UnaryExpr
virtual ~UnaryExpr()=default
GradSpecifier
GradSpecifier
Definition: ast.h:25
Decl::DataType
TypeSpecifier DataType
Definition: ast.h:109
AssignmentOperator::AST_MUL_ASSIGN
@ AST_MUL_ASSIGN
UnaryExpr
Definition: ast.h:218
BinaryExpr
Definition: ast.h:206
TypeSpecifier::TENSOR
@ TENSOR
InitDeclarator
Definition: ast.h:117
Expr::Expr
Expr()
Definition: ast.cpp:112
Initializer::type_value::type_value
constexpr type_value()
Definition: ast.h:162
Initializer::printInitializerList
void printInitializerList()
Definition: ast.cpp:74
InitDeclarator::InitDeclarator
InitDeclarator(Declarator *, Initializer *)
Definition: ast.cpp:38
TypeSpecifier::CHAR
@ CHAR
InitDeclarator::declarator
Declarator * declarator
Definition: ast.h:120
Node
Definition: ast.h:78
AssgnStmt
Definition: ast.h:177
BinaryExpr::transpile
void transpile(std::ostream &out, int tab=0) const override
Definition: ast.cpp:299
AssgnStmt::~AssgnStmt
virtual ~AssgnStmt()=default
LibFuncs::COS
@ COS
Initializer::isScalar
bool isScalar
Definition: ast.h:166
Declarator::Declarator
Declarator(std::string)
Definition: ast.cpp:44
Node::Node
Node()
Definition: ast.cpp:18
GradType
GradType
Definition: ast.h:71
BinaryExpr::initialize_expression_node_info
virtual void initialize_expression_node_info(std::unordered_map< std::string, SymTabItem > *symbolTable) override
Definition: ast.cpp:158
Initializer::type_value::InitializerList
std::vector< Initializer * > * InitializerList
Definition: ast.h:161
Start::transpile
void transpile(std::ostream &out, int tab=0) const
Definition: ast.cpp:513
UnaryExpr::printExpression
virtual void printExpression() override
Definition: ast.cpp:317
Start
Definition: ast.h:91
AssignmentOperator::AST_ADD_ASSIGN
@ AST_ADD_ASSIGN
BinaryExpr::~BinaryExpr
virtual ~BinaryExpr()=default
AssignmentOperator::AST_ASSIGN
@ AST_ASSIGN
AssignmentOperator::AST_SUB_ASSIGN
@ AST_SUB_ASSIGN
TypeSpecifier
TypeSpecifier
Definition: ast.h:31
Decl::Decl
Decl(GradSpecifier, TypeSpecifier, InitDeclarator *)
Definition: ast.cpp:31
UnaryExpr::initialize_expression_node_info
virtual void initialize_expression_node_info(std::unordered_map< std::string, SymTabItem > *symbolTable) override
Definition: ast.cpp:357
Start::symbolTable
std::unordered_map< std::string, SymTabItem > * symbolTable
Definition: ast.h:97
Declarator::Dimensions
std::vector< int > Dimensions
Definition: ast.h:131
LibFuncs::SIN
@ SIN
GradSpecifier::VAR
@ VAR
UnaryExpr::transpile
void transpile(std::ostream &out, int tab=0) const override
Definition: ast.cpp:421
Start::Start
Start(std::vector< class Decl * > *DeclList, std::vector< class AssgnStmt * > *AssgnStmtList, std::vector< class GradStmt * > *GradStmtList, std::unordered_map< std::string, SymTabItem > *symbolTable)
Definition: ast.cpp:23
Node::~Node
virtual ~Node()=default
yycolumn
int yycolumn
Definition: ast.h:14
BinaryExpr::printExpression
virtual void printExpression() override
Definition: ast.cpp:149
Declarator::name
std::string name
Definition: ast.h:130
Start::GradStmtList
std::vector< class GradStmt * > * GradStmtList
Definition: ast.h:96
ConstValue::~ConstValue
virtual ~ConstValue()=default
BinaryExpr::lhs
Expr * lhs
Definition: ast.h:209
Expr::dimensions
std::vector< int > dimensions
Definition: ast.h:196
UnaryExpr::cvalue
ConstValue * cvalue
Definition: ast.h:223
AssgnStmt::row_num
int row_num
Definition: ast.h:180
Decl::InitDeclaratorList
InitDeclarator * InitDeclaratorList
Definition: ast.h:110
TypeSpecifier::INT
@ INT