diff options
| author | ubq323 <ubq323@ubq323.website> | 2024-06-29 13:12:34 +0100 | 
|---|---|---|
| committer | ubq323 <ubq323@ubq323.website> | 2024-06-29 13:12:34 +0100 | 
| commit | 83382cb1b46eb17f94f16fbbf05b5e471284d797 (patch) | |
| tree | 73e8124725984a8c883c158ae50e61693924df9f | |
| parent | b51136defc2898c868e4a1b60025d5bb57347662 (diff) | |
add proper tail calls
| -rw-r--r-- | com.c | 90 | ||||
| -rw-r--r-- | dis.c | 6 | ||||
| -rw-r--r-- | tests/vars8.bth | 1 | ||||
| -rw-r--r-- | tests/vars8.out | 1 | ||||
| -rw-r--r-- | vm.c | 33 | ||||
| -rw-r--r-- | vm.h | 1 | 
6 files changed, 88 insertions, 44 deletions
| @@ -1,6 +1,7 @@  #include <stdlib.h>  #include <string.h>  #include <stdio.h> +#include <stdbool.h>  #include "com.h"  #include "mem.h" @@ -61,6 +62,7 @@ static int stack_effect_of(Op opcode) {  	// these ones depend on their argument. handle them specifically  	case OP_CALL: +	case OP_TAILCALL:  	case OP_ENDSCOPE:  		return 0;  	default: @@ -114,6 +116,12 @@ static void compile_call_instr(Compiler *C, uint8_t len) {  	C->stack_cur -= len - 1;  } +static void compile_tailcall_instr(Compiler *C, uint8_t len) { +	compile_opcode(C, OP_TAILCALL); +	compile_byte(C, len); +	C->stack_cur -= len - 1; +} +  static void compile_endscope_instr(Compiler *C, uint8_t nlocals) {  	if (nlocals > 0) {  		compile_opcode(C, OP_ENDSCOPE); @@ -135,53 +143,58 @@ static void patch(Compiler *C, size_t addr, uint16_t val) {  // ---- -static void compile_node(Compiler *C, AstNode a, bool toplevel); +enum flags { +	F_toplevel = 1, +	F_tail = 2, +}; + +static void compile_node(Compiler *C, AstNode a, int flags);  static void begin_scope(Compiler *C);  static void end_scope(Compiler *C); -typedef void (*form_compiler)(Compiler *C, AstVec l, Op op, bool toplevel); +typedef void (*form_compiler)(Compiler *C, AstVec l, Op op, int flags); -static void compile_body(Compiler *C, AstVec l, int startat) { +static void compile_body(Compiler *C, AstVec l, int startat, int flags) {  	begin_scope(C);  	for (int i = startat; i < l.len - 1; i++) { -		compile_node(C, l.vals[i], true); +		compile_node(C, l.vals[i], F_toplevel);  		compile_opcode(C, OP_DROP);  	} -	compile_node(C, l.vals[l.len - 1], true); +	compile_node(C, l.vals[l.len - 1], F_toplevel | (flags & F_tail));  	end_scope(C);  }  void single_form(Compiler *C, AstVec l, Op op) { -	compile_node(C, l.vals[1], false); +	compile_node(C, l.vals[1], 0);  	compile_opcode(C, op);  }  static Local *locate_local(Compiler *C, char *name); -void set_form(Compiler *C, AstVec l, Op _, bool __) { +void set_form(Compiler *C, AstVec l, Op _, int flags) {  	AstNode ident = l.vals[1];  	CHECK(ident.ty == AST_IDENT, "set's first argument must be identifier");  	char *name = ident.as.str;  	Local *loc = locate_local(C, name);  	if (loc != NULL) { -		compile_node(C, l.vals[2], false); +		compile_node(C, l.vals[2], 0);  		compile_opcode(C, OP_SETLOCAL);  		compile_byte(C, loc->slot);  	} else {  		// write global  		ObjString *o = objstring_copy_cstr(C->S, name); -		compile_node(C, l.vals[2], false); +		compile_node(C, l.vals[2], 0);  		compile_opcode(C, OP_SETGLOBAL);  		compile_byte(C, compile_constant(C, VAL_OBJ(o)));  	}  } -void do_form(Compiler *C, AstVec l, Op _, bool __) { -	compile_body(C, l, 1); +void do_form(Compiler *C, AstVec l, Op _, int flags) { +	compile_body(C, l, 1, flags);  } -void if_form(Compiler *C, AstVec l, Op _, bool __) { +void if_form(Compiler *C, AstVec l, Op _, int flags) {  	// (if cond if-true if-false)  	// cond  	// 0branch ->A @@ -190,12 +203,15 @@ void if_form(Compiler *C, AstVec l, Op _, bool __) {  	// A: if-false  	// B: +	// never toplevel +	int downflags = flags & ~F_toplevel; +  	int orig_stack_cur = C->stack_cur; -	compile_node(C, l.vals[1], false); +	compile_node(C, l.vals[1], 0);  	compile_opcode(C, OP_0BRANCH);  	size_t ph_a = placeholder(C); -	compile_node(C, l.vals[2], false); +	compile_node(C, l.vals[2], downflags);  	compile_opcode(C, OP_SKIP);  	size_t ph_b = placeholder(C); @@ -203,7 +219,7 @@ void if_form(Compiler *C, AstVec l, Op _, bool __) {  	int stack_cur_a = C->stack_cur;  	C->stack_cur = orig_stack_cur; -	compile_node(C, l.vals[3], false); +	compile_node(C, l.vals[3], downflags);  	size_t dest_b = BYTECODE(C).len;  	int stack_cur_b = C->stack_cur; @@ -215,7 +231,7 @@ void if_form(Compiler *C, AstVec l, Op _, bool __) {  	CHECK(stack_cur_a == orig_stack_cur + 1, "this should never happen");  } -void while_form(Compiler *C, AstVec l, Op _, bool __) { +void while_form(Compiler *C, AstVec l, Op _, int flags) {  	// (while cond body ...)  	// A:  	// cond @@ -225,10 +241,10 @@ void while_form(Compiler *C, AstVec l, Op _, bool __) {  	// B:  	// nil   (while loop always returns nil)  	size_t dest_a = BYTECODE(C).len; -	compile_node(C, l.vals[1], false); +	compile_node(C, l.vals[1], 0);  	compile_opcode(C, OP_0BRANCH);  	size_t ph_b = placeholder(C); -	compile_body(C, l, 2); +	compile_body(C, l, 2, flags & ~F_tail);  	compile_opcode(C, OP_DROP);  	compile_opcode(C, OP_REDO);  	size_t ph_a = placeholder(C); @@ -239,9 +255,9 @@ void while_form(Compiler *C, AstVec l, Op _, bool __) {  	patch(C, ph_b, dest_b - ph_b - 2);  } -void arith_form(Compiler *C, AstVec l, Op op, bool __) { -	compile_node(C, l.vals[1], false); -	compile_node(C, l.vals[2], false); +void arith_form(Compiler *C, AstVec l, Op op, int flags) { +	compile_node(C, l.vals[1], 0); +	compile_node(C, l.vals[2], 0);  	compile_opcode(C, op);  } @@ -288,7 +304,7 @@ static Local *locate_local(Compiler *C, char *name) {  	return NULL;  } -void let_form(Compiler *C, AstVec l, Op _, bool __) { +void let_form(Compiler *C, AstVec l, Op _, int flags) {  	CHECK(l.vals[1].ty == AST_LIST, "let's first argument must be list");  	AstVec bindlist = l.vals[1].as.list;  	CHECK(bindlist.len % 2 == 0, "unmatched binding in let"); @@ -301,25 +317,25 @@ void let_form(Compiler *C, AstVec l, Op _, bool __) {  		AstNode name = bindlist.vals[ix];  		AstNode expr = bindlist.vals[ix+1];  		CHECK(name.ty == AST_IDENT, "binding name must be identifier"); -		compile_node(C, expr, false); +		compile_node(C, expr, 0);  		declare_local(C, name.as.str);  	} -	compile_body(C, l, 2); +	compile_body(C, l, 2, flags);  	end_scope(C);  } -void def_form(Compiler *C, AstVec l, Op _, bool toplevel) { +void def_form(Compiler *C, AstVec l, Op _, int flags) {  	CHECK(l.vals[1].ty == AST_IDENT, "def's first argument must be ident"); -	CHECK(toplevel, "def only allowed at top level"); -	compile_node(C, l.vals[2], false); +	CHECK(flags & F_toplevel, "def only allowed at top level"); +	compile_node(C, l.vals[2], 0);  	declare_local(C, l.vals[1].as.str);  	// whatever is calling us will compile an OP_DROP next  	compile_opcode(C, OP_NIL);  } -void fn_form(Compiler *C, AstVec l, Op _, bool __) { +void fn_form(Compiler *C, AstVec l, Op _, int flags) {  	// (fn (arg arg arg) body ...)  	CHECK(l.vals[1].ty == AST_LIST, "fn's first argument must be list");  	AstVec arglist = l.vals[1].as.list; @@ -342,7 +358,7 @@ void fn_form(Compiler *C, AstVec l, Op _, bool __) {  		declare_local(SC, argname.as.str);  	} -	compile_body(SC, l, 2); +	compile_body(SC, l, 2, F_tail);  	end_scope(SC);  	compile_opcode(SC, OP_RET); @@ -393,7 +409,7 @@ static BuiltinIdent builtin_idents[] = { -static void compile_node(Compiler *C, AstNode a, bool toplevel) { +static void compile_node(Compiler *C, AstNode a, int flags) {  	switch (a.ty) {  	case AST_IDENT:;  		char *ident = a.as.str; @@ -432,7 +448,7 @@ static void compile_node(Compiler *C, AstNode a, bool toplevel) {  		compile_opcode(C, OP_ARRNEW);  		AstVec v = a.as.list;  		for (int i = 0; i < v.len; i++) { -			compile_node(C, v.vals[i], false); +			compile_node(C, v.vals[i], 0);  			compile_opcode(C, OP_ARRAPPEND);  		}  		break; @@ -464,7 +480,7 @@ static void compile_node(Compiler *C, AstNode a, bool toplevel) {  				CHECK(nparams == form->min_params, "%s requires exactly %d parameters",  					form->name, form->min_params); -			form->action(C, l, form->op, toplevel); +			form->action(C, l, form->op, flags);  		} else {  			// function call  			// (f a b c ) @@ -473,9 +489,13 @@ static void compile_node(Compiler *C, AstNode a, bool toplevel) {  				exit(1);  			}  			for (int i = 0; i < l.len; i++) { -				compile_node(C, l.vals[i], false); +				compile_node(C, l.vals[i], 0); +			} +			if (flags & F_tail) { +				compile_tailcall_instr(C, l.len); +			} else { +				compile_call_instr(C, l.len);  			} -			compile_call_instr(C, l.len);  		}  		break; @@ -538,7 +558,7 @@ int main(int argc, char **argv) {  		do {  			astnode_free(&an);  			rv = pcc_parse(parser, &an); -			compile_node(&com, an, false); +			compile_node(&com, an, 0);  		} while (rv != 0);  		pcc_destroy(parser);  	} @@ -60,6 +60,11 @@ static size_t disasm_instr_h(Chunk *ch, size_t ip, int depth) {  		printf("call #%hhu\n",nargs);  		break;  	} +	case OP_TAILCALL: { +		uint8_t nargs = ch->bc.d[ip++]; +		printf("\033[31mtailcall\033[0m #%hhu\n",nargs); +		break; +	}  	case OP_ENDSCOPE: {  		uint8_t nlocals = ch->bc.d[ip++];  		printf("endscope #%hhu\n",nlocals); @@ -107,7 +112,6 @@ static size_t disasm_instr_h(Chunk *ch, size_t ip, int depth) {  	default:  		printf("unknown opcode %d\n", instr); -		exit(2);  	} diff --git a/tests/vars8.bth b/tests/vars8.bth new file mode 100644 index 0000000..90f16d4 --- /dev/null +++ b/tests/vars8.bth @@ -0,0 +1 @@ +(if (< 2 3) (def x 100) 20) diff --git a/tests/vars8.out b/tests/vars8.out new file mode 100644 index 0000000..72995f4 --- /dev/null +++ b/tests/vars8.out @@ -0,0 +1 @@ +def only allowed at top level @@ -4,6 +4,7 @@  #include <stdbool.h>  #include <stdio.h>  #include <math.h> +#include <string.h>  #include "val.h"  #include "vm.h" @@ -174,6 +175,7 @@ int runvm(State *S) {  		case OP_TRUE: PUSH(VAL_TRUE); break;  		case OP_FALSE: PUSH(VAL_FALSE); break; +		case OP_TAILCALL:  		case OP_CALL: {  			// nargs + 1 = function and args  			uint8_t len = RBYTE(); @@ -185,14 +187,27 @@ int runvm(State *S) {  				CHECK(nargs == func->arity, "func needs exactly %d args, but got %d",func->arity,nargs);  				CHECK(th->rsp < MAXDEPTH, "rstack overflow"); -				StackFrame *sf = &th->rstack[th->rsp++]; -				sf->ip = th->ip; -				sf->ch = th->ch; -				sf->fp = th->fp; - -				th->ip = 0; -				th->ch = &func->ch; -				th->fp = th->sp - len; +				if (instr != OP_TAILCALL) { +					StackFrame *sf = &th->rstack[th->rsp++]; +					sf->ip = th->ip; +					sf->ch = th->ch; +					sf->fp = th->fp; + +					th->ip = 0; +					th->ch = &func->ch; +					th->fp = th->sp - len; +				} else { +					// xxx might invalidate open upvalues +					memmove(&th->stack[th->fp], &th->stack[th->sp - len], len*sizeof(Val)); +					th->sp = th->fp + len; +					th->ip = 0; +					th->ch = &func->ch; + +					StackFrame *cur_sf = &th->rstack[th->rsp]; +					cur_sf->ip = th->ip; +					cur_sf->ch = th->ch; +					cur_sf->fp = th->fp; +				}  			} else if (IS_CFUNC(callee)) {  				Val *firstarg = &th->stack[th->sp - nargs];  				Val res = AS_CFUNC(callee)(S, nargs, firstarg); @@ -247,6 +262,8 @@ int runvm(State *S) {  			objarr_append(S, arr, v);  			break;  		} +		default: +			ERROR("unknown opcode");  		} @@ -61,6 +61,7 @@ typedef enum {  	OP_REDO,  	OP_CALL, +	OP_TAILCALL,  	OP_ENDSCOPE,  	OP_ARRNEW, | 
