From 83382cb1b46eb17f94f16fbbf05b5e471284d797 Mon Sep 17 00:00:00 2001 From: ubq323 Date: Sat, 29 Jun 2024 13:12:34 +0100 Subject: add proper tail calls --- com.c | 90 +++++++++++++++++++++++++++++++++++---------------------- dis.c | 6 +++- tests/vars8.bth | 1 + tests/vars8.out | 1 + vm.c | 33 ++++++++++++++++----- vm.h | 1 + 6 files changed, 88 insertions(+), 44 deletions(-) create mode 100644 tests/vars8.bth create mode 100644 tests/vars8.out diff --git a/com.c b/com.c index 23abeaa..ff13815 100644 --- a/com.c +++ b/com.c @@ -1,6 +1,7 @@ #include #include #include +#include #include "com.h" #include "mem.h" @@ -61,6 +62,7 @@ static int stack_effect_of(Op opcode) { // these ones depend on their argument. handle them specifically case OP_CALL: + case OP_TAILCALL: case OP_ENDSCOPE: return 0; default: @@ -114,6 +116,12 @@ static void compile_call_instr(Compiler *C, uint8_t len) { C->stack_cur -= len - 1; } +static void compile_tailcall_instr(Compiler *C, uint8_t len) { + compile_opcode(C, OP_TAILCALL); + compile_byte(C, len); + C->stack_cur -= len - 1; +} + static void compile_endscope_instr(Compiler *C, uint8_t nlocals) { if (nlocals > 0) { compile_opcode(C, OP_ENDSCOPE); @@ -135,53 +143,58 @@ static void patch(Compiler *C, size_t addr, uint16_t val) { // ---- -static void compile_node(Compiler *C, AstNode a, bool toplevel); +enum flags { + F_toplevel = 1, + F_tail = 2, +}; + +static void compile_node(Compiler *C, AstNode a, int flags); static void begin_scope(Compiler *C); static void end_scope(Compiler *C); -typedef void (*form_compiler)(Compiler *C, AstVec l, Op op, bool toplevel); +typedef void (*form_compiler)(Compiler *C, AstVec l, Op op, int flags); -static void compile_body(Compiler *C, AstVec l, int startat) { +static void compile_body(Compiler *C, AstVec l, int startat, int flags) { begin_scope(C); for (int i = startat; i < l.len - 1; i++) { - compile_node(C, l.vals[i], true); + compile_node(C, l.vals[i], F_toplevel); compile_opcode(C, OP_DROP); } - compile_node(C, l.vals[l.len - 1], true); + compile_node(C, l.vals[l.len - 1], F_toplevel | (flags & F_tail)); end_scope(C); } void single_form(Compiler *C, AstVec l, Op op) { - compile_node(C, l.vals[1], false); + compile_node(C, l.vals[1], 0); compile_opcode(C, op); } static Local *locate_local(Compiler *C, char *name); -void set_form(Compiler *C, AstVec l, Op _, bool __) { +void set_form(Compiler *C, AstVec l, Op _, int flags) { AstNode ident = l.vals[1]; CHECK(ident.ty == AST_IDENT, "set's first argument must be identifier"); char *name = ident.as.str; Local *loc = locate_local(C, name); if (loc != NULL) { - compile_node(C, l.vals[2], false); + compile_node(C, l.vals[2], 0); compile_opcode(C, OP_SETLOCAL); compile_byte(C, loc->slot); } else { // write global ObjString *o = objstring_copy_cstr(C->S, name); - compile_node(C, l.vals[2], false); + compile_node(C, l.vals[2], 0); compile_opcode(C, OP_SETGLOBAL); compile_byte(C, compile_constant(C, VAL_OBJ(o))); } } -void do_form(Compiler *C, AstVec l, Op _, bool __) { - compile_body(C, l, 1); +void do_form(Compiler *C, AstVec l, Op _, int flags) { + compile_body(C, l, 1, flags); } -void if_form(Compiler *C, AstVec l, Op _, bool __) { +void if_form(Compiler *C, AstVec l, Op _, int flags) { // (if cond if-true if-false) // cond // 0branch ->A @@ -190,12 +203,15 @@ void if_form(Compiler *C, AstVec l, Op _, bool __) { // A: if-false // B: + // never toplevel + int downflags = flags & ~F_toplevel; + int orig_stack_cur = C->stack_cur; - compile_node(C, l.vals[1], false); + compile_node(C, l.vals[1], 0); compile_opcode(C, OP_0BRANCH); size_t ph_a = placeholder(C); - compile_node(C, l.vals[2], false); + compile_node(C, l.vals[2], downflags); compile_opcode(C, OP_SKIP); size_t ph_b = placeholder(C); @@ -203,7 +219,7 @@ void if_form(Compiler *C, AstVec l, Op _, bool __) { int stack_cur_a = C->stack_cur; C->stack_cur = orig_stack_cur; - compile_node(C, l.vals[3], false); + compile_node(C, l.vals[3], downflags); size_t dest_b = BYTECODE(C).len; int stack_cur_b = C->stack_cur; @@ -215,7 +231,7 @@ void if_form(Compiler *C, AstVec l, Op _, bool __) { CHECK(stack_cur_a == orig_stack_cur + 1, "this should never happen"); } -void while_form(Compiler *C, AstVec l, Op _, bool __) { +void while_form(Compiler *C, AstVec l, Op _, int flags) { // (while cond body ...) // A: // cond @@ -225,10 +241,10 @@ void while_form(Compiler *C, AstVec l, Op _, bool __) { // B: // nil (while loop always returns nil) size_t dest_a = BYTECODE(C).len; - compile_node(C, l.vals[1], false); + compile_node(C, l.vals[1], 0); compile_opcode(C, OP_0BRANCH); size_t ph_b = placeholder(C); - compile_body(C, l, 2); + compile_body(C, l, 2, flags & ~F_tail); compile_opcode(C, OP_DROP); compile_opcode(C, OP_REDO); size_t ph_a = placeholder(C); @@ -239,9 +255,9 @@ void while_form(Compiler *C, AstVec l, Op _, bool __) { patch(C, ph_b, dest_b - ph_b - 2); } -void arith_form(Compiler *C, AstVec l, Op op, bool __) { - compile_node(C, l.vals[1], false); - compile_node(C, l.vals[2], false); +void arith_form(Compiler *C, AstVec l, Op op, int flags) { + compile_node(C, l.vals[1], 0); + compile_node(C, l.vals[2], 0); compile_opcode(C, op); } @@ -288,7 +304,7 @@ static Local *locate_local(Compiler *C, char *name) { return NULL; } -void let_form(Compiler *C, AstVec l, Op _, bool __) { +void let_form(Compiler *C, AstVec l, Op _, int flags) { CHECK(l.vals[1].ty == AST_LIST, "let's first argument must be list"); AstVec bindlist = l.vals[1].as.list; CHECK(bindlist.len % 2 == 0, "unmatched binding in let"); @@ -301,25 +317,25 @@ void let_form(Compiler *C, AstVec l, Op _, bool __) { AstNode name = bindlist.vals[ix]; AstNode expr = bindlist.vals[ix+1]; CHECK(name.ty == AST_IDENT, "binding name must be identifier"); - compile_node(C, expr, false); + compile_node(C, expr, 0); declare_local(C, name.as.str); } - compile_body(C, l, 2); + compile_body(C, l, 2, flags); end_scope(C); } -void def_form(Compiler *C, AstVec l, Op _, bool toplevel) { +void def_form(Compiler *C, AstVec l, Op _, int flags) { CHECK(l.vals[1].ty == AST_IDENT, "def's first argument must be ident"); - CHECK(toplevel, "def only allowed at top level"); - compile_node(C, l.vals[2], false); + CHECK(flags & F_toplevel, "def only allowed at top level"); + compile_node(C, l.vals[2], 0); declare_local(C, l.vals[1].as.str); // whatever is calling us will compile an OP_DROP next compile_opcode(C, OP_NIL); } -void fn_form(Compiler *C, AstVec l, Op _, bool __) { +void fn_form(Compiler *C, AstVec l, Op _, int flags) { // (fn (arg arg arg) body ...) CHECK(l.vals[1].ty == AST_LIST, "fn's first argument must be list"); AstVec arglist = l.vals[1].as.list; @@ -342,7 +358,7 @@ void fn_form(Compiler *C, AstVec l, Op _, bool __) { declare_local(SC, argname.as.str); } - compile_body(SC, l, 2); + compile_body(SC, l, 2, F_tail); end_scope(SC); compile_opcode(SC, OP_RET); @@ -393,7 +409,7 @@ static BuiltinIdent builtin_idents[] = { -static void compile_node(Compiler *C, AstNode a, bool toplevel) { +static void compile_node(Compiler *C, AstNode a, int flags) { switch (a.ty) { case AST_IDENT:; char *ident = a.as.str; @@ -432,7 +448,7 @@ static void compile_node(Compiler *C, AstNode a, bool toplevel) { compile_opcode(C, OP_ARRNEW); AstVec v = a.as.list; for (int i = 0; i < v.len; i++) { - compile_node(C, v.vals[i], false); + compile_node(C, v.vals[i], 0); compile_opcode(C, OP_ARRAPPEND); } break; @@ -464,7 +480,7 @@ static void compile_node(Compiler *C, AstNode a, bool toplevel) { CHECK(nparams == form->min_params, "%s requires exactly %d parameters", form->name, form->min_params); - form->action(C, l, form->op, toplevel); + form->action(C, l, form->op, flags); } else { // function call // (f a b c ) @@ -473,9 +489,13 @@ static void compile_node(Compiler *C, AstNode a, bool toplevel) { exit(1); } for (int i = 0; i < l.len; i++) { - compile_node(C, l.vals[i], false); + compile_node(C, l.vals[i], 0); + } + if (flags & F_tail) { + compile_tailcall_instr(C, l.len); + } else { + compile_call_instr(C, l.len); } - compile_call_instr(C, l.len); } break; @@ -538,7 +558,7 @@ int main(int argc, char **argv) { do { astnode_free(&an); rv = pcc_parse(parser, &an); - compile_node(&com, an, false); + compile_node(&com, an, 0); } while (rv != 0); pcc_destroy(parser); } diff --git a/dis.c b/dis.c index 2939384..0ca77e7 100644 --- a/dis.c +++ b/dis.c @@ -60,6 +60,11 @@ static size_t disasm_instr_h(Chunk *ch, size_t ip, int depth) { printf("call #%hhu\n",nargs); break; } + case OP_TAILCALL: { + uint8_t nargs = ch->bc.d[ip++]; + printf("\033[31mtailcall\033[0m #%hhu\n",nargs); + break; + } case OP_ENDSCOPE: { uint8_t nlocals = ch->bc.d[ip++]; printf("endscope #%hhu\n",nlocals); @@ -107,7 +112,6 @@ static size_t disasm_instr_h(Chunk *ch, size_t ip, int depth) { default: printf("unknown opcode %d\n", instr); - exit(2); } diff --git a/tests/vars8.bth b/tests/vars8.bth new file mode 100644 index 0000000..90f16d4 --- /dev/null +++ b/tests/vars8.bth @@ -0,0 +1 @@ +(if (< 2 3) (def x 100) 20) diff --git a/tests/vars8.out b/tests/vars8.out new file mode 100644 index 0000000..72995f4 --- /dev/null +++ b/tests/vars8.out @@ -0,0 +1 @@ +def only allowed at top level diff --git a/vm.c b/vm.c index 500b4bf..4a65832 100644 --- a/vm.c +++ b/vm.c @@ -4,6 +4,7 @@ #include #include #include +#include #include "val.h" #include "vm.h" @@ -174,6 +175,7 @@ int runvm(State *S) { case OP_TRUE: PUSH(VAL_TRUE); break; case OP_FALSE: PUSH(VAL_FALSE); break; + case OP_TAILCALL: case OP_CALL: { // nargs + 1 = function and args uint8_t len = RBYTE(); @@ -185,14 +187,27 @@ int runvm(State *S) { CHECK(nargs == func->arity, "func needs exactly %d args, but got %d",func->arity,nargs); CHECK(th->rsp < MAXDEPTH, "rstack overflow"); - StackFrame *sf = &th->rstack[th->rsp++]; - sf->ip = th->ip; - sf->ch = th->ch; - sf->fp = th->fp; - - th->ip = 0; - th->ch = &func->ch; - th->fp = th->sp - len; + if (instr != OP_TAILCALL) { + StackFrame *sf = &th->rstack[th->rsp++]; + sf->ip = th->ip; + sf->ch = th->ch; + sf->fp = th->fp; + + th->ip = 0; + th->ch = &func->ch; + th->fp = th->sp - len; + } else { + // xxx might invalidate open upvalues + memmove(&th->stack[th->fp], &th->stack[th->sp - len], len*sizeof(Val)); + th->sp = th->fp + len; + th->ip = 0; + th->ch = &func->ch; + + StackFrame *cur_sf = &th->rstack[th->rsp]; + cur_sf->ip = th->ip; + cur_sf->ch = th->ch; + cur_sf->fp = th->fp; + } } else if (IS_CFUNC(callee)) { Val *firstarg = &th->stack[th->sp - nargs]; Val res = AS_CFUNC(callee)(S, nargs, firstarg); @@ -247,6 +262,8 @@ int runvm(State *S) { objarr_append(S, arr, v); break; } + default: + ERROR("unknown opcode"); } diff --git a/vm.h b/vm.h index 8479623..1ca2c48 100644 --- a/vm.h +++ b/vm.h @@ -61,6 +61,7 @@ typedef enum { OP_REDO, OP_CALL, + OP_TAILCALL, OP_ENDSCOPE, OP_ARRNEW, -- cgit v1.2.3