summaryrefslogtreecommitdiff
path: root/com.c
diff options
context:
space:
mode:
authorubq323 <ubq323@ubq323.website>2024-06-29 13:12:34 +0100
committerubq323 <ubq323@ubq323.website>2024-06-29 13:12:34 +0100
commit83382cb1b46eb17f94f16fbbf05b5e471284d797 (patch)
tree73e8124725984a8c883c158ae50e61693924df9f /com.c
parentb51136defc2898c868e4a1b60025d5bb57347662 (diff)
add proper tail calls
Diffstat (limited to 'com.c')
-rw-r--r--com.c90
1 files changed, 55 insertions, 35 deletions
diff --git a/com.c b/com.c
index 23abeaa..ff13815 100644
--- a/com.c
+++ b/com.c
@@ -1,6 +1,7 @@
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
+#include <stdbool.h>
#include "com.h"
#include "mem.h"
@@ -61,6 +62,7 @@ static int stack_effect_of(Op opcode) {
// these ones depend on their argument. handle them specifically
case OP_CALL:
+ case OP_TAILCALL:
case OP_ENDSCOPE:
return 0;
default:
@@ -114,6 +116,12 @@ static void compile_call_instr(Compiler *C, uint8_t len) {
C->stack_cur -= len - 1;
}
+static void compile_tailcall_instr(Compiler *C, uint8_t len) {
+ compile_opcode(C, OP_TAILCALL);
+ compile_byte(C, len);
+ C->stack_cur -= len - 1;
+}
+
static void compile_endscope_instr(Compiler *C, uint8_t nlocals) {
if (nlocals > 0) {
compile_opcode(C, OP_ENDSCOPE);
@@ -135,53 +143,58 @@ static void patch(Compiler *C, size_t addr, uint16_t val) {
// ----
-static void compile_node(Compiler *C, AstNode a, bool toplevel);
+enum flags {
+ F_toplevel = 1,
+ F_tail = 2,
+};
+
+static void compile_node(Compiler *C, AstNode a, int flags);
static void begin_scope(Compiler *C);
static void end_scope(Compiler *C);
-typedef void (*form_compiler)(Compiler *C, AstVec l, Op op, bool toplevel);
+typedef void (*form_compiler)(Compiler *C, AstVec l, Op op, int flags);
-static void compile_body(Compiler *C, AstVec l, int startat) {
+static void compile_body(Compiler *C, AstVec l, int startat, int flags) {
begin_scope(C);
for (int i = startat; i < l.len - 1; i++) {
- compile_node(C, l.vals[i], true);
+ compile_node(C, l.vals[i], F_toplevel);
compile_opcode(C, OP_DROP);
}
- compile_node(C, l.vals[l.len - 1], true);
+ compile_node(C, l.vals[l.len - 1], F_toplevel | (flags & F_tail));
end_scope(C);
}
void single_form(Compiler *C, AstVec l, Op op) {
- compile_node(C, l.vals[1], false);
+ compile_node(C, l.vals[1], 0);
compile_opcode(C, op);
}
static Local *locate_local(Compiler *C, char *name);
-void set_form(Compiler *C, AstVec l, Op _, bool __) {
+void set_form(Compiler *C, AstVec l, Op _, int flags) {
AstNode ident = l.vals[1];
CHECK(ident.ty == AST_IDENT, "set's first argument must be identifier");
char *name = ident.as.str;
Local *loc = locate_local(C, name);
if (loc != NULL) {
- compile_node(C, l.vals[2], false);
+ compile_node(C, l.vals[2], 0);
compile_opcode(C, OP_SETLOCAL);
compile_byte(C, loc->slot);
} else {
// write global
ObjString *o = objstring_copy_cstr(C->S, name);
- compile_node(C, l.vals[2], false);
+ compile_node(C, l.vals[2], 0);
compile_opcode(C, OP_SETGLOBAL);
compile_byte(C, compile_constant(C, VAL_OBJ(o)));
}
}
-void do_form(Compiler *C, AstVec l, Op _, bool __) {
- compile_body(C, l, 1);
+void do_form(Compiler *C, AstVec l, Op _, int flags) {
+ compile_body(C, l, 1, flags);
}
-void if_form(Compiler *C, AstVec l, Op _, bool __) {
+void if_form(Compiler *C, AstVec l, Op _, int flags) {
// (if cond if-true if-false)
// cond
// 0branch ->A
@@ -190,12 +203,15 @@ void if_form(Compiler *C, AstVec l, Op _, bool __) {
// A: if-false
// B:
+ // never toplevel
+ int downflags = flags & ~F_toplevel;
+
int orig_stack_cur = C->stack_cur;
- compile_node(C, l.vals[1], false);
+ compile_node(C, l.vals[1], 0);
compile_opcode(C, OP_0BRANCH);
size_t ph_a = placeholder(C);
- compile_node(C, l.vals[2], false);
+ compile_node(C, l.vals[2], downflags);
compile_opcode(C, OP_SKIP);
size_t ph_b = placeholder(C);
@@ -203,7 +219,7 @@ void if_form(Compiler *C, AstVec l, Op _, bool __) {
int stack_cur_a = C->stack_cur;
C->stack_cur = orig_stack_cur;
- compile_node(C, l.vals[3], false);
+ compile_node(C, l.vals[3], downflags);
size_t dest_b = BYTECODE(C).len;
int stack_cur_b = C->stack_cur;
@@ -215,7 +231,7 @@ void if_form(Compiler *C, AstVec l, Op _, bool __) {
CHECK(stack_cur_a == orig_stack_cur + 1, "this should never happen");
}
-void while_form(Compiler *C, AstVec l, Op _, bool __) {
+void while_form(Compiler *C, AstVec l, Op _, int flags) {
// (while cond body ...)
// A:
// cond
@@ -225,10 +241,10 @@ void while_form(Compiler *C, AstVec l, Op _, bool __) {
// B:
// nil (while loop always returns nil)
size_t dest_a = BYTECODE(C).len;
- compile_node(C, l.vals[1], false);
+ compile_node(C, l.vals[1], 0);
compile_opcode(C, OP_0BRANCH);
size_t ph_b = placeholder(C);
- compile_body(C, l, 2);
+ compile_body(C, l, 2, flags & ~F_tail);
compile_opcode(C, OP_DROP);
compile_opcode(C, OP_REDO);
size_t ph_a = placeholder(C);
@@ -239,9 +255,9 @@ void while_form(Compiler *C, AstVec l, Op _, bool __) {
patch(C, ph_b, dest_b - ph_b - 2);
}
-void arith_form(Compiler *C, AstVec l, Op op, bool __) {
- compile_node(C, l.vals[1], false);
- compile_node(C, l.vals[2], false);
+void arith_form(Compiler *C, AstVec l, Op op, int flags) {
+ compile_node(C, l.vals[1], 0);
+ compile_node(C, l.vals[2], 0);
compile_opcode(C, op);
}
@@ -288,7 +304,7 @@ static Local *locate_local(Compiler *C, char *name) {
return NULL;
}
-void let_form(Compiler *C, AstVec l, Op _, bool __) {
+void let_form(Compiler *C, AstVec l, Op _, int flags) {
CHECK(l.vals[1].ty == AST_LIST, "let's first argument must be list");
AstVec bindlist = l.vals[1].as.list;
CHECK(bindlist.len % 2 == 0, "unmatched binding in let");
@@ -301,25 +317,25 @@ void let_form(Compiler *C, AstVec l, Op _, bool __) {
AstNode name = bindlist.vals[ix];
AstNode expr = bindlist.vals[ix+1];
CHECK(name.ty == AST_IDENT, "binding name must be identifier");
- compile_node(C, expr, false);
+ compile_node(C, expr, 0);
declare_local(C, name.as.str);
}
- compile_body(C, l, 2);
+ compile_body(C, l, 2, flags);
end_scope(C);
}
-void def_form(Compiler *C, AstVec l, Op _, bool toplevel) {
+void def_form(Compiler *C, AstVec l, Op _, int flags) {
CHECK(l.vals[1].ty == AST_IDENT, "def's first argument must be ident");
- CHECK(toplevel, "def only allowed at top level");
- compile_node(C, l.vals[2], false);
+ CHECK(flags & F_toplevel, "def only allowed at top level");
+ compile_node(C, l.vals[2], 0);
declare_local(C, l.vals[1].as.str);
// whatever is calling us will compile an OP_DROP next
compile_opcode(C, OP_NIL);
}
-void fn_form(Compiler *C, AstVec l, Op _, bool __) {
+void fn_form(Compiler *C, AstVec l, Op _, int flags) {
// (fn (arg arg arg) body ...)
CHECK(l.vals[1].ty == AST_LIST, "fn's first argument must be list");
AstVec arglist = l.vals[1].as.list;
@@ -342,7 +358,7 @@ void fn_form(Compiler *C, AstVec l, Op _, bool __) {
declare_local(SC, argname.as.str);
}
- compile_body(SC, l, 2);
+ compile_body(SC, l, 2, F_tail);
end_scope(SC);
compile_opcode(SC, OP_RET);
@@ -393,7 +409,7 @@ static BuiltinIdent builtin_idents[] = {
-static void compile_node(Compiler *C, AstNode a, bool toplevel) {
+static void compile_node(Compiler *C, AstNode a, int flags) {
switch (a.ty) {
case AST_IDENT:;
char *ident = a.as.str;
@@ -432,7 +448,7 @@ static void compile_node(Compiler *C, AstNode a, bool toplevel) {
compile_opcode(C, OP_ARRNEW);
AstVec v = a.as.list;
for (int i = 0; i < v.len; i++) {
- compile_node(C, v.vals[i], false);
+ compile_node(C, v.vals[i], 0);
compile_opcode(C, OP_ARRAPPEND);
}
break;
@@ -464,7 +480,7 @@ static void compile_node(Compiler *C, AstNode a, bool toplevel) {
CHECK(nparams == form->min_params, "%s requires exactly %d parameters",
form->name, form->min_params);
- form->action(C, l, form->op, toplevel);
+ form->action(C, l, form->op, flags);
} else {
// function call
// (f a b c )
@@ -473,9 +489,13 @@ static void compile_node(Compiler *C, AstNode a, bool toplevel) {
exit(1);
}
for (int i = 0; i < l.len; i++) {
- compile_node(C, l.vals[i], false);
+ compile_node(C, l.vals[i], 0);
+ }
+ if (flags & F_tail) {
+ compile_tailcall_instr(C, l.len);
+ } else {
+ compile_call_instr(C, l.len);
}
- compile_call_instr(C, l.len);
}
break;
@@ -538,7 +558,7 @@ int main(int argc, char **argv) {
do {
astnode_free(&an);
rv = pcc_parse(parser, &an);
- compile_node(&com, an, false);
+ compile_node(&com, an, 0);
} while (rv != 0);
pcc_destroy(parser);
}