summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorubq323 <ubq323@ubq323.website>2024-06-29 13:12:34 +0100
committerubq323 <ubq323@ubq323.website>2024-06-29 13:12:34 +0100
commit83382cb1b46eb17f94f16fbbf05b5e471284d797 (patch)
tree73e8124725984a8c883c158ae50e61693924df9f
parentb51136defc2898c868e4a1b60025d5bb57347662 (diff)
add proper tail calls
-rw-r--r--com.c90
-rw-r--r--dis.c6
-rw-r--r--tests/vars8.bth1
-rw-r--r--tests/vars8.out1
-rw-r--r--vm.c33
-rw-r--r--vm.h1
6 files changed, 88 insertions, 44 deletions
diff --git a/com.c b/com.c
index 23abeaa..ff13815 100644
--- a/com.c
+++ b/com.c
@@ -1,6 +1,7 @@
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
+#include <stdbool.h>
#include "com.h"
#include "mem.h"
@@ -61,6 +62,7 @@ static int stack_effect_of(Op opcode) {
// these ones depend on their argument. handle them specifically
case OP_CALL:
+ case OP_TAILCALL:
case OP_ENDSCOPE:
return 0;
default:
@@ -114,6 +116,12 @@ static void compile_call_instr(Compiler *C, uint8_t len) {
C->stack_cur -= len - 1;
}
+static void compile_tailcall_instr(Compiler *C, uint8_t len) {
+ compile_opcode(C, OP_TAILCALL);
+ compile_byte(C, len);
+ C->stack_cur -= len - 1;
+}
+
static void compile_endscope_instr(Compiler *C, uint8_t nlocals) {
if (nlocals > 0) {
compile_opcode(C, OP_ENDSCOPE);
@@ -135,53 +143,58 @@ static void patch(Compiler *C, size_t addr, uint16_t val) {
// ----
-static void compile_node(Compiler *C, AstNode a, bool toplevel);
+enum flags {
+ F_toplevel = 1,
+ F_tail = 2,
+};
+
+static void compile_node(Compiler *C, AstNode a, int flags);
static void begin_scope(Compiler *C);
static void end_scope(Compiler *C);
-typedef void (*form_compiler)(Compiler *C, AstVec l, Op op, bool toplevel);
+typedef void (*form_compiler)(Compiler *C, AstVec l, Op op, int flags);
-static void compile_body(Compiler *C, AstVec l, int startat) {
+static void compile_body(Compiler *C, AstVec l, int startat, int flags) {
begin_scope(C);
for (int i = startat; i < l.len - 1; i++) {
- compile_node(C, l.vals[i], true);
+ compile_node(C, l.vals[i], F_toplevel);
compile_opcode(C, OP_DROP);
}
- compile_node(C, l.vals[l.len - 1], true);
+ compile_node(C, l.vals[l.len - 1], F_toplevel | (flags & F_tail));
end_scope(C);
}
void single_form(Compiler *C, AstVec l, Op op) {
- compile_node(C, l.vals[1], false);
+ compile_node(C, l.vals[1], 0);
compile_opcode(C, op);
}
static Local *locate_local(Compiler *C, char *name);
-void set_form(Compiler *C, AstVec l, Op _, bool __) {
+void set_form(Compiler *C, AstVec l, Op _, int flags) {
AstNode ident = l.vals[1];
CHECK(ident.ty == AST_IDENT, "set's first argument must be identifier");
char *name = ident.as.str;
Local *loc = locate_local(C, name);
if (loc != NULL) {
- compile_node(C, l.vals[2], false);
+ compile_node(C, l.vals[2], 0);
compile_opcode(C, OP_SETLOCAL);
compile_byte(C, loc->slot);
} else {
// write global
ObjString *o = objstring_copy_cstr(C->S, name);
- compile_node(C, l.vals[2], false);
+ compile_node(C, l.vals[2], 0);
compile_opcode(C, OP_SETGLOBAL);
compile_byte(C, compile_constant(C, VAL_OBJ(o)));
}
}
-void do_form(Compiler *C, AstVec l, Op _, bool __) {
- compile_body(C, l, 1);
+void do_form(Compiler *C, AstVec l, Op _, int flags) {
+ compile_body(C, l, 1, flags);
}
-void if_form(Compiler *C, AstVec l, Op _, bool __) {
+void if_form(Compiler *C, AstVec l, Op _, int flags) {
// (if cond if-true if-false)
// cond
// 0branch ->A
@@ -190,12 +203,15 @@ void if_form(Compiler *C, AstVec l, Op _, bool __) {
// A: if-false
// B:
+ // never toplevel
+ int downflags = flags & ~F_toplevel;
+
int orig_stack_cur = C->stack_cur;
- compile_node(C, l.vals[1], false);
+ compile_node(C, l.vals[1], 0);
compile_opcode(C, OP_0BRANCH);
size_t ph_a = placeholder(C);
- compile_node(C, l.vals[2], false);
+ compile_node(C, l.vals[2], downflags);
compile_opcode(C, OP_SKIP);
size_t ph_b = placeholder(C);
@@ -203,7 +219,7 @@ void if_form(Compiler *C, AstVec l, Op _, bool __) {
int stack_cur_a = C->stack_cur;
C->stack_cur = orig_stack_cur;
- compile_node(C, l.vals[3], false);
+ compile_node(C, l.vals[3], downflags);
size_t dest_b = BYTECODE(C).len;
int stack_cur_b = C->stack_cur;
@@ -215,7 +231,7 @@ void if_form(Compiler *C, AstVec l, Op _, bool __) {
CHECK(stack_cur_a == orig_stack_cur + 1, "this should never happen");
}
-void while_form(Compiler *C, AstVec l, Op _, bool __) {
+void while_form(Compiler *C, AstVec l, Op _, int flags) {
// (while cond body ...)
// A:
// cond
@@ -225,10 +241,10 @@ void while_form(Compiler *C, AstVec l, Op _, bool __) {
// B:
// nil (while loop always returns nil)
size_t dest_a = BYTECODE(C).len;
- compile_node(C, l.vals[1], false);
+ compile_node(C, l.vals[1], 0);
compile_opcode(C, OP_0BRANCH);
size_t ph_b = placeholder(C);
- compile_body(C, l, 2);
+ compile_body(C, l, 2, flags & ~F_tail);
compile_opcode(C, OP_DROP);
compile_opcode(C, OP_REDO);
size_t ph_a = placeholder(C);
@@ -239,9 +255,9 @@ void while_form(Compiler *C, AstVec l, Op _, bool __) {
patch(C, ph_b, dest_b - ph_b - 2);
}
-void arith_form(Compiler *C, AstVec l, Op op, bool __) {
- compile_node(C, l.vals[1], false);
- compile_node(C, l.vals[2], false);
+void arith_form(Compiler *C, AstVec l, Op op, int flags) {
+ compile_node(C, l.vals[1], 0);
+ compile_node(C, l.vals[2], 0);
compile_opcode(C, op);
}
@@ -288,7 +304,7 @@ static Local *locate_local(Compiler *C, char *name) {
return NULL;
}
-void let_form(Compiler *C, AstVec l, Op _, bool __) {
+void let_form(Compiler *C, AstVec l, Op _, int flags) {
CHECK(l.vals[1].ty == AST_LIST, "let's first argument must be list");
AstVec bindlist = l.vals[1].as.list;
CHECK(bindlist.len % 2 == 0, "unmatched binding in let");
@@ -301,25 +317,25 @@ void let_form(Compiler *C, AstVec l, Op _, bool __) {
AstNode name = bindlist.vals[ix];
AstNode expr = bindlist.vals[ix+1];
CHECK(name.ty == AST_IDENT, "binding name must be identifier");
- compile_node(C, expr, false);
+ compile_node(C, expr, 0);
declare_local(C, name.as.str);
}
- compile_body(C, l, 2);
+ compile_body(C, l, 2, flags);
end_scope(C);
}
-void def_form(Compiler *C, AstVec l, Op _, bool toplevel) {
+void def_form(Compiler *C, AstVec l, Op _, int flags) {
CHECK(l.vals[1].ty == AST_IDENT, "def's first argument must be ident");
- CHECK(toplevel, "def only allowed at top level");
- compile_node(C, l.vals[2], false);
+ CHECK(flags & F_toplevel, "def only allowed at top level");
+ compile_node(C, l.vals[2], 0);
declare_local(C, l.vals[1].as.str);
// whatever is calling us will compile an OP_DROP next
compile_opcode(C, OP_NIL);
}
-void fn_form(Compiler *C, AstVec l, Op _, bool __) {
+void fn_form(Compiler *C, AstVec l, Op _, int flags) {
// (fn (arg arg arg) body ...)
CHECK(l.vals[1].ty == AST_LIST, "fn's first argument must be list");
AstVec arglist = l.vals[1].as.list;
@@ -342,7 +358,7 @@ void fn_form(Compiler *C, AstVec l, Op _, bool __) {
declare_local(SC, argname.as.str);
}
- compile_body(SC, l, 2);
+ compile_body(SC, l, 2, F_tail);
end_scope(SC);
compile_opcode(SC, OP_RET);
@@ -393,7 +409,7 @@ static BuiltinIdent builtin_idents[] = {
-static void compile_node(Compiler *C, AstNode a, bool toplevel) {
+static void compile_node(Compiler *C, AstNode a, int flags) {
switch (a.ty) {
case AST_IDENT:;
char *ident = a.as.str;
@@ -432,7 +448,7 @@ static void compile_node(Compiler *C, AstNode a, bool toplevel) {
compile_opcode(C, OP_ARRNEW);
AstVec v = a.as.list;
for (int i = 0; i < v.len; i++) {
- compile_node(C, v.vals[i], false);
+ compile_node(C, v.vals[i], 0);
compile_opcode(C, OP_ARRAPPEND);
}
break;
@@ -464,7 +480,7 @@ static void compile_node(Compiler *C, AstNode a, bool toplevel) {
CHECK(nparams == form->min_params, "%s requires exactly %d parameters",
form->name, form->min_params);
- form->action(C, l, form->op, toplevel);
+ form->action(C, l, form->op, flags);
} else {
// function call
// (f a b c )
@@ -473,9 +489,13 @@ static void compile_node(Compiler *C, AstNode a, bool toplevel) {
exit(1);
}
for (int i = 0; i < l.len; i++) {
- compile_node(C, l.vals[i], false);
+ compile_node(C, l.vals[i], 0);
+ }
+ if (flags & F_tail) {
+ compile_tailcall_instr(C, l.len);
+ } else {
+ compile_call_instr(C, l.len);
}
- compile_call_instr(C, l.len);
}
break;
@@ -538,7 +558,7 @@ int main(int argc, char **argv) {
do {
astnode_free(&an);
rv = pcc_parse(parser, &an);
- compile_node(&com, an, false);
+ compile_node(&com, an, 0);
} while (rv != 0);
pcc_destroy(parser);
}
diff --git a/dis.c b/dis.c
index 2939384..0ca77e7 100644
--- a/dis.c
+++ b/dis.c
@@ -60,6 +60,11 @@ static size_t disasm_instr_h(Chunk *ch, size_t ip, int depth) {
printf("call #%hhu\n",nargs);
break;
}
+ case OP_TAILCALL: {
+ uint8_t nargs = ch->bc.d[ip++];
+ printf("\033[31mtailcall\033[0m #%hhu\n",nargs);
+ break;
+ }
case OP_ENDSCOPE: {
uint8_t nlocals = ch->bc.d[ip++];
printf("endscope #%hhu\n",nlocals);
@@ -107,7 +112,6 @@ static size_t disasm_instr_h(Chunk *ch, size_t ip, int depth) {
default:
printf("unknown opcode %d\n", instr);
- exit(2);
}
diff --git a/tests/vars8.bth b/tests/vars8.bth
new file mode 100644
index 0000000..90f16d4
--- /dev/null
+++ b/tests/vars8.bth
@@ -0,0 +1 @@
+(if (< 2 3) (def x 100) 20)
diff --git a/tests/vars8.out b/tests/vars8.out
new file mode 100644
index 0000000..72995f4
--- /dev/null
+++ b/tests/vars8.out
@@ -0,0 +1 @@
+def only allowed at top level
diff --git a/vm.c b/vm.c
index 500b4bf..4a65832 100644
--- a/vm.c
+++ b/vm.c
@@ -4,6 +4,7 @@
#include <stdbool.h>
#include <stdio.h>
#include <math.h>
+#include <string.h>
#include "val.h"
#include "vm.h"
@@ -174,6 +175,7 @@ int runvm(State *S) {
case OP_TRUE: PUSH(VAL_TRUE); break;
case OP_FALSE: PUSH(VAL_FALSE); break;
+ case OP_TAILCALL:
case OP_CALL: {
// nargs + 1 = function and args
uint8_t len = RBYTE();
@@ -185,14 +187,27 @@ int runvm(State *S) {
CHECK(nargs == func->arity, "func needs exactly %d args, but got %d",func->arity,nargs);
CHECK(th->rsp < MAXDEPTH, "rstack overflow");
- StackFrame *sf = &th->rstack[th->rsp++];
- sf->ip = th->ip;
- sf->ch = th->ch;
- sf->fp = th->fp;
-
- th->ip = 0;
- th->ch = &func->ch;
- th->fp = th->sp - len;
+ if (instr != OP_TAILCALL) {
+ StackFrame *sf = &th->rstack[th->rsp++];
+ sf->ip = th->ip;
+ sf->ch = th->ch;
+ sf->fp = th->fp;
+
+ th->ip = 0;
+ th->ch = &func->ch;
+ th->fp = th->sp - len;
+ } else {
+ // xxx might invalidate open upvalues
+ memmove(&th->stack[th->fp], &th->stack[th->sp - len], len*sizeof(Val));
+ th->sp = th->fp + len;
+ th->ip = 0;
+ th->ch = &func->ch;
+
+ StackFrame *cur_sf = &th->rstack[th->rsp];
+ cur_sf->ip = th->ip;
+ cur_sf->ch = th->ch;
+ cur_sf->fp = th->fp;
+ }
} else if (IS_CFUNC(callee)) {
Val *firstarg = &th->stack[th->sp - nargs];
Val res = AS_CFUNC(callee)(S, nargs, firstarg);
@@ -247,6 +262,8 @@ int runvm(State *S) {
objarr_append(S, arr, v);
break;
}
+ default:
+ ERROR("unknown opcode");
}
diff --git a/vm.h b/vm.h
index 8479623..1ca2c48 100644
--- a/vm.h
+++ b/vm.h
@@ -61,6 +61,7 @@ typedef enum {
OP_REDO,
OP_CALL,
+ OP_TAILCALL,
OP_ENDSCOPE,
OP_ARRNEW,