#include #include #include #include #include "com.h" #include "mem.h" #include "chunk.h" #include "util.h" #include "lib.h" #include "read.h" #include "dis.h" #define BYTECODE(C) (C->ch->bc) enum flags { F_toplevel = 1, F_tail = 2, }; Chunk chunk_new(State *S) { return (Chunk){ 0 }; } Compiler compiler_new(State *S, Compiler *outer, Chunk *ch) { return (Compiler){ .S = S, .ch = ch, .stack_cur = 0, .scope = NULL, }; } static void cpl_expr(Compiler *C, Val v, int flags); static int stack_effect_of(Op opcode) { switch (opcode) { case OP_LOADK: case OP_GETGLOBAL: case OP_GETLOCAL: case OP_NIL: case OP_TRUE: case OP_FALSE: case OP_ARRNEW: return 1; case OP_SKIP: case OP_REDO: case OP_SETGLOBAL: case OP_SETLOCAL: case OP_RET: case OP_HALT: case OP_ARRLEN: return 0; case OP_DROP: case OP_0BRANCH: case OP_ADD: case OP_SUB: case OP_MUL: case OP_DIV: case OP_CMP: case OP_EQU: case OP_MOD: case OP_ARRAPPEND: return -1; case OP_SETIDX: return -2; // these ones depend on their argument. handle them specifically case OP_CALL: case OP_TAILCALL: case OP_ENDSCOPE: ERROR("stack effect of opcode %d not constant",opcode); break; default: ERROR("unknown opcode %d",opcode); break; } } // compilation of things static void cpl_byte(Compiler *C, uint8_t byte) { Chunk *ch = C->ch; ENSURE_CAP(C->S, ch->bc, uint8_t, ch->bc.len+1); ch->bc.d[ch->bc.len++] = byte; } static void cpl_op(Compiler *C, Op op) { int stack_effect = stack_effect_of(op); C->stack_cur += stack_effect; cpl_byte(C, op); } // particular instructions static uint8_t cpl_const(Compiler *C, Val v) { Chunk *ch = C->ch; for (int i = 0; i < ch->consts.len; i++) { if (val_equal(v, ch->consts.d[i])) { return i; } } CHECK(ch->consts.len < 256, "maximum number of constants per function reached"); ENSURE_CAP(C->S, ch->consts, Val, ch->consts.len+1); uint8_t ix = ch->consts.len; ch->consts.d[ch->consts.len++] = v; return ix; } static void cpl_constop(Compiler *C, Op op, Val v) { cpl_op(C, op); cpl_byte(C, cpl_const(C, v)); } static void cpl_call(Compiler *C, uint8_t len) { cpl_byte(C, OP_CALL); cpl_byte(C, len); C->stack_cur -= len - 1; } static void cpl_tailcall(Compiler *C, uint8_t len) { cpl_byte(C, OP_TAILCALL); cpl_byte(C, len); C->stack_cur -= len - 1; } static void cpl_endscope(Compiler *C, uint8_t nlocals) { if (nlocals > 0) { cpl_byte(C, OP_ENDSCOPE); cpl_byte(C, nlocals); C->stack_cur -= nlocals; } } // jump offsets and things static size_t placeholder(Compiler *C) { size_t old_ix = BYTECODE(C).len; cpl_byte(C, 0x00); cpl_byte(C, 0x00); return old_ix; } static void patch(Compiler *C, size_t addr, uint16_t val) { BYTECODE(C).d[addr] = val & 0xff; BYTECODE(C).d[addr+1] = (val & 0xff00) >> 8; } // scopes and locals static void begin_scope(Compiler *C) { Scope *sc = malloc(sizeof(Scope)); CHECK(sc != NULL, "memory fail"); memset(sc, 0, sizeof(Scope)); sc->outer = C->scope; C->scope = sc; } static void end_scope(Compiler *C) { Scope *sc = C->scope; CHECK(sc != NULL, "attempt to end nonexistent scope"); C->scope = sc->outer; // printf("ending scope with %d locals, named: \n", sc->nlocals); // for (int i = 0; i < sc->nlocals; i++) { // Local loc = sc->locals[i]; // printf("\t%3d %s\n",loc.slot, loc.name); // } cpl_endscope(C, sc->nlocals); free(sc); } // returns slot of declared local static uint8_t declare_local(Compiler *C, char *name) { Scope *sc = C->scope; CHECK(sc != NULL, "can't declare local outside of scope"); CHECK(sc->nlocals < MAX_LOCALS, "maximum number of locals per function exceeded"); Local *l = &sc->locals[sc->nlocals++]; l->name = name; // -1 because local is expected to be already on the stack // ie sitting just below where stack_cur points uint8_t slot = C->stack_cur - 1; l->slot = slot; // printf("declaring local %s at %d, stack_cur is %d\n",l->name, l->slot, C->stack_cur); return slot; } static Local *locate_local(Compiler *C, char *name) { for (Scope *sc = C->scope; sc != NULL; sc = sc->outer) { for (int i = 0; i < sc->nlocals; i++) { Local *loc = &sc->locals[i]; if (0 == strcmp(loc->name, name)) return loc; } } return NULL; } // compiles a body, ie sequence of "toplevel" expressions in which // declarations are allowed static void cpl_body(Compiler *C, ObjArr *a, int startat, int flags) { CHECK(a->len > 0, "tried to compile empty body"); CHECK(startat < a->len, "tried to startat past end of body"); begin_scope(C); for (int i = startat; i < a->len - 1; i++) { cpl_expr(C, a->d[i], F_toplevel); cpl_op(C, OP_DROP); } cpl_expr(C, a->d[a->len - 1], F_toplevel | (flags & F_tail)); end_scope(C); } // the forms! static void set_form(Compiler *C, ObjArr *a, Op _, int flags) { Val target = a->d[1]; if (IS_STRING(target)) { // set variable: local or global char *name = AS_CSTRING(target); Local *loc = locate_local(C, name); if (loc) { // write local cpl_expr(C, a->d[2], 0); cpl_op(C, OP_SETLOCAL); cpl_byte(C, loc->slot); } else { // write global cpl_expr(C, a->d[2], 0); cpl_constop(C, OP_SETGLOBAL, target); } } else if (IS_ARR(target)) { ObjArr *pair = AS_ARR(target); CHECK(pair->len == 2, "can only set to (arr, ix) 2-pair"); // (value arr ix) <- TOS cpl_expr(C, a->d[2], 0); cpl_expr(C, pair->d[0], 0); cpl_expr(C, pair->d[1], 0); cpl_op(C, OP_SETIDX); } } static void do_form(Compiler *C, ObjArr *a, Op _, int flags) { cpl_body(C, a, 1, flags); } static void if_form(Compiler *C, ObjArr *a, Op _, int flags) { // (if cond if-true if-false) // cond // 0branch ->A // if-true // skip ->B // A: if-false // B: // never toplevel int downflags = flags & ~F_toplevel; int orig_stack_cur = C->stack_cur; cpl_expr(C, a->d[1], 0); cpl_op(C, OP_0BRANCH); size_t ph_a = placeholder(C); cpl_expr(C, a->d[2], downflags); cpl_op(C, OP_SKIP); size_t ph_b = placeholder(C); size_t dest_a = BYTECODE(C).len; int stack_cur_a = C->stack_cur; C->stack_cur = orig_stack_cur; cpl_expr(C, a->d[3], downflags); size_t dest_b = BYTECODE(C).len; int stack_cur_b = C->stack_cur; patch(C, ph_a, dest_a - ph_a - 2); patch(C, ph_b, dest_b - ph_b - 2); CHECK(stack_cur_a == stack_cur_b, "this should never happen"); CHECK(stack_cur_a == orig_stack_cur + 1, "this should never happen"); } static void when_form(Compiler *C, ObjArr *a, Op _, int flags) { // (when cond body...) // cond // 0branch -> A // body... // drop // A: // nil cpl_expr(C, a->d[1], 0); cpl_op(C, OP_0BRANCH); size_t ph = placeholder(C); cpl_body(C, a, 2, 0); cpl_op(C, OP_DROP); size_t dest = BYTECODE(C).len; patch(C, ph, dest - ph - 2); cpl_op(C, OP_NIL); } static void while_form(Compiler *C, ObjArr *a, Op _, int flags) { // (while cond body ...) // A: // cond // 0branch ->B // body .... // redo ->A // B: // nil (while loop always returns nil) size_t dest_a = BYTECODE(C).len; cpl_expr(C, a->d[1], 0); cpl_op(C, OP_0BRANCH); size_t ph_b = placeholder(C); cpl_body(C, a, 2, flags & ~F_tail); cpl_op(C, OP_DROP); cpl_op(C, OP_REDO); size_t ph_a = placeholder(C); size_t dest_b = BYTECODE(C).len; cpl_op(C, OP_NIL); patch(C, ph_a, ph_a - dest_a + 2); patch(C, ph_b, dest_b - ph_b - 2); } static void arith_form(Compiler *C, ObjArr *a, Op op, int flags) { cpl_expr(C, a->d[1], 0); cpl_expr(C, a->d[2], 0); cpl_op(C, op); } static void let_form(Compiler *C, ObjArr *a, Op _, int flags) { CHECK(IS_ARR(a->d[1]), "let's first argument must be list"); ObjArr *bindlist = AS_ARR(a->d[1]); CHECK(bindlist->len % 2 == 0, "unmatched binding in let"); int nbinds = bindlist->len / 2; begin_scope(C); for (int i = 0; i < nbinds; i++) { int ix = i * 2; Val name = bindlist->d[ix]; Val expr = bindlist->d[ix+1]; CHECK(IS_STRING(name), "binding name must be identifier"); cpl_expr(C, expr, 0); declare_local(C, AS_CSTRING(name)); } cpl_body(C, a, 2, flags); end_scope(C); } static void for_form(Compiler *C, ObjArr *a, Op _, int flags) { // (for (x n) ...) CHECK(IS_ARR(a->d[1]), "for needs binding list"); ObjArr *blist = AS_ARR(a->d[1]); CHECK(blist->len == 2, "for binding list must have length 2"); CHECK(IS_STRING(blist->d[0]), "can only bind to ident"); char *ivar = AS_CSTRING(blist->d[0]); begin_scope(C); cpl_constop(C, OP_LOADK, VAL_NUM(0)); uint8_t islot = declare_local(C, ivar); cpl_expr(C, blist->d[1], 0); uint8_t mslot = declare_local(C, "__max__"); // A // getlocal ivar // getlocal max // cmp // 0branch -> B // body ... // incr ivar // redo -> A // B: // nil size_t dest_A = BYTECODE(C).len; cpl_op(C, OP_GETLOCAL); cpl_byte(C, islot); cpl_op(C, OP_GETLOCAL); cpl_byte(C, mslot); cpl_op(C, OP_CMP); cpl_op(C, OP_0BRANCH); size_t ph_B = placeholder(C); cpl_body(C, a, 2, flags & ~F_tail); cpl_op(C, OP_DROP); cpl_op(C, OP_GETLOCAL); cpl_byte(C, islot); cpl_constop(C, OP_LOADK, VAL_NUM(1)); cpl_op(C, OP_ADD); cpl_op(C, OP_SETLOCAL); cpl_byte(C, islot); cpl_op(C, OP_DROP); cpl_op(C, OP_REDO); size_t ph_A = placeholder(C); size_t dest_B = BYTECODE(C).len; cpl_op(C, OP_NIL); patch(C, ph_A , ph_A - dest_A + 2); patch(C, ph_B ,dest_B - ph_B - 2); end_scope(C); } static void each_form(Compiler *C, ObjArr *a, Op _, int flags) { // (each (x a) ...) // returns nil, for now CHECK(IS_ARR(a->d[1]), "each needs binding list"); ObjArr *blist = AS_ARR(a->d[1]); CHECK(blist->len == 2, "each binding list must have length 2"); CHECK(IS_STRING(blist->d[0]), "can only bind to ident"); char *ivar = AS_CSTRING(blist->d[0]); begin_scope(C); cpl_constop(C, OP_LOADK, VAL_NUM(0)); uint8_t islot = declare_local(C, "__idx__"); cpl_expr(C, blist->d[1], 0); uint8_t aslot = declare_local(C, "__arr__"); cpl_op(C, OP_GETLOCAL); cpl_byte(C, aslot); cpl_op(C, OP_ARRLEN); uint8_t mslot = declare_local(C, "__max__"); cpl_op(C, OP_ARRNEW); uint8_t oslot = declare_local(C, "__out__"); cpl_op(C, OP_NIL); uint8_t vslot = declare_local(C, ivar); // A // getlocal idx // getlocal max // cmp // 0branch -> B // getlocal arr // getlocal idx // call 2 // setlocal ivar // getlocal out // body ... // arrappend // drop // incr idx // redo -> A // B: // nil size_t dest_A = BYTECODE(C).len; cpl_op(C, OP_GETLOCAL); cpl_byte(C, islot); cpl_op(C, OP_GETLOCAL); cpl_byte(C, mslot); cpl_op(C, OP_CMP); cpl_op(C, OP_0BRANCH); size_t ph_B = placeholder(C); cpl_op(C, OP_GETLOCAL); cpl_byte(C, aslot); cpl_op(C, OP_GETLOCAL); cpl_byte(C, islot); cpl_call(C, 2); cpl_op(C, OP_SETLOCAL); cpl_byte(C, vslot); cpl_op(C, OP_DROP); cpl_op(C, OP_GETLOCAL); cpl_byte(C, oslot); cpl_body(C, a, 2, flags & ~F_tail); cpl_op(C, OP_ARRAPPEND); cpl_op(C, OP_DROP); cpl_op(C, OP_GETLOCAL); cpl_byte(C, islot); cpl_constop(C, OP_LOADK, VAL_NUM(1)); cpl_op(C, OP_ADD); cpl_op(C, OP_SETLOCAL); cpl_byte(C, islot); cpl_op(C, OP_DROP); cpl_op(C, OP_REDO); size_t ph_A = placeholder(C); size_t dest_B = BYTECODE(C).len; cpl_op(C, OP_GETLOCAL); cpl_byte(C, oslot); patch(C, ph_A, ph_A - dest_A + 2); patch(C, ph_B, dest_B - ph_B - 2); end_scope(C); } static void def_form(Compiler *C, ObjArr *a, Op _, int flags) { CHECK(IS_STRING(a->d[1]), "def's first argument must be ident"); CHECK(flags & F_toplevel, "def only allowed at top level"); cpl_expr(C, a->d[2], 0); declare_local(C, AS_CSTRING(a->d[1])); // whatever is calling us will compile an OP_DROP next // or, well. not if we're in tail position. but i can't see // any circumstance where you'd want that anyway cpl_op(C, OP_NIL); } static void fn_form(Compiler *C, ObjArr *a, Op _, int flags) { // (fn (arg arg arg) body ...) CHECK(IS_ARR(a->d[1]), "fn's first argument must be list"); ObjArr *arglist = AS_ARR(a->d[1]); CHECK(arglist->len <= 255, "maximum 255 args for function"); uint8_t arity = arglist->len; ObjFunc *func = objfunc_new(C->S, arity); Compiler subcompiler = compiler_new(C->S, C, &func->ch); Compiler *SC = &subcompiler; begin_scope(SC); // when called, stack slot 0 contains the function itself, // stack slots 1..n contain passed arguments SC->stack_cur ++; declare_local(SC, "__func__"); for (int i = 0; i < arity; i++) { Val argname = arglist->d[i]; CHECK(IS_STRING(argname), "argument name must be identifier"); SC->stack_cur ++; declare_local(SC, AS_CSTRING(argname)); } cpl_body(SC, a, 2, F_tail); end_scope(SC); cpl_op(SC, OP_RET); cpl_constop(C, OP_LOADK, VAL_OBJ(func)); } static void defn_form(Compiler *C, ObjArr *a, Op _, int flags) { // todo: reduce redundancy CHECK(IS_ARR(a->d[1]), "defns first arg must be list"); ObjArr *blist = AS_ARR(a->d[1]); CHECK(blist->len > 0, "defn needs at least a function name"); CHECK(blist->len <= 256, "maximum 255 args for function"); CHECK(flags & F_toplevel, "defn only allowed at toplevel"); uint8_t arity = blist->len - 1; CHECK(IS_STRING(blist->d[0]), "func name must be ident"); char *fname = AS_CSTRING(blist->d[0]); ObjFunc *func = objfunc_new(C->S, arity); Compiler subcompiler = compiler_new(C->S, C, &func->ch); Compiler *SC = &subcompiler; begin_scope(SC); SC->stack_cur ++; declare_local(SC, fname); for (int i = 0; i < arity; i++) { Val argname = blist->d[i+1]; CHECK(IS_STRING(argname), "arg name must be identifier"); SC->stack_cur ++; declare_local(SC, AS_CSTRING(argname)); } cpl_body(SC, a, 2, F_tail); end_scope(SC); cpl_op(SC, OP_RET); cpl_constop(C, OP_LOADK, VAL_OBJ(func)); declare_local(C, fname); cpl_op(C, OP_NIL); } static void quote_form(Compiler *C, ObjArr *a, Op _, int flags) { CHECK(IS_STRING(a->d[1]), "can only quote strings for now"); cpl_constop(C, OP_LOADK, a->d[1]); } static void arrlit_form(Compiler *C, ObjArr *a, Op _, int flags) { CHECK(IS_ARR(a->d[1]), "can only arrlit a list"); ObjArr *exprs = AS_ARR(a->d[1]); cpl_op(C, OP_ARRNEW); for (int i = 0; i < exprs->len; i++) { cpl_expr(C, exprs->d[i], 0); cpl_op(C, OP_ARRAPPEND); } } typedef void (*form_compiler)(Compiler *C, ObjArr *a, Op op, int flags); typedef struct { char *name; int min_params; bool ellipsis; form_compiler action; Op op; } BuiltinForm; static BuiltinForm builtin_forms[] = { { "arrlit", 1, false, arrlit_form, 0 }, { "def", 2, false, def_form, 0 }, { "defn", 2, true, defn_form, 0 }, { "do", 1, true, do_form, 0 }, { "each", 2, true, each_form, 0 }, { "fn", 2, true, fn_form, 0 }, { "for", 2, true, for_form, 0 }, { "if", 3, false, if_form, 0 }, { "let", 2, true, let_form, 0 }, { "quote", 1, false, quote_form, 0 }, { "set!", 2, false, set_form, 0 }, { "when", 2, true, when_form, 0 }, { "while", 2, true, while_form, 0 }, #define ARITH_OP(str, op) \ { str, 2, false, arith_form, op }, ARITH_OP("+", OP_ADD) ARITH_OP("-", OP_SUB) ARITH_OP("*", OP_MUL) ARITH_OP("/", OP_DIV) ARITH_OP("=", OP_EQU) ARITH_OP("<", OP_CMP) ARITH_OP("%", OP_MOD) #undef ARITH_OP { 0 }, }; static BuiltinForm *find_builtinform(char *name) { for (BuiltinForm *b = builtin_forms; b->name != NULL; b++) if (0 == strcmp(b->name, name)) return b; return NULL; } static void cpl_expr(Compiler *C, Val v, int flags) { int stack_cur_a = C->stack_cur; int nlocals_a = 0; if (C->scope) { nlocals_a = C->scope->nlocals; } switch (val_type(v)) { case TY_NUM: case TY_NIL: case TY_BOOL: cpl_constop(C, OP_LOADK, v); break; case OTY_STRING:; Local *loc = locate_local(C, AS_CSTRING(v)); if (loc) { cpl_op(C, OP_GETLOCAL); cpl_byte(C, loc->slot); } else { cpl_constop(C, OP_GETGLOBAL, v); } break; case OTY_ARR:; ObjArr *a = AS_ARR(v); size_t len = a->len; CHECK(len > 0, "can't handle empty array"); Val first = a->d[0]; BuiltinForm *form = NULL; if (IS_STRING(first)) form = find_builtinform(AS_CSTRING(first)); if (form) { size_t nargs = len - 1; if (form->ellipsis) CHECK(nargs >= form->min_params, "%s requires at least %d parameters", form->name, form->min_params); else CHECK(nargs == form->min_params, "%s requires exactly %d parameters", form->name, form->min_params); form->action(C, a, form->op, flags); } else { // function call CHECK(len < 256, "max 255 args in a function call"); for (int i = 0; i < a->len; i++) cpl_expr(C, a->d[i], 0); if (flags & F_tail) cpl_tailcall(C, len); else cpl_call(C, len); } break; } int stack_cur_b = C->stack_cur; int nlocals_b = C->stack_cur; if (C->scope) { nlocals_b = C->scope->nlocals; } // every badthing expression returns exactly one value, // and might declare some locals as well, which also live on the stack // so (returned values) = (stack change) - (new locals) = 1 CHECK( (stack_cur_b - stack_cur_a) - (nlocals_b - nlocals_a) == 1, "stack corruption (compiler bug)"); CHECK( (flags & F_toplevel) || (nlocals_b == nlocals_a), "local declared not at top level (compiler bug)"); } static char buf[8193]; int main(int argc, char **argv) { State st = state_new(); State *S = &st; Chunk ch = chunk_new(S); S->do_disasm = false; S->do_trace = false; S->do_dumpsexpr = false; char *infile_names[16]; int infile_name_n = 0; for (int i = 1; i < argc; i++) { if (argv[i][0] != '-') { infile_names[infile_name_n++] = argv[i]; CHECK(infile_name_n < 16, "input file count exceeded"); continue; } switch (argv[i][1]) { case 'D': CHECK(strlen(argv[i]) > 1, "unknown option a"); switch (argv[i][2]) { case 'l': S->do_disasm = true; break; case 't': S->do_trace = true; break; case 's': S->do_dumpsexpr = true; break; default: ERROR("unknown option b"); } break; default: ERROR("unknown option c"); } } Compiler com = (Compiler){ 0 }; com.S = S; com.ch = &ch; if (infile_name_n == 0) infile_names[infile_name_n++] = "-"; for (int i = 0; i < infile_name_n; i++) { char *fname = infile_names[i]; FILE *infile = NULL; if (0 == strcmp(fname, "-")) infile = stdin; else infile = fopen(fname, "r"); if (infile == NULL) { perror("fopen"); exit(1); } fread(buf, 1, 8192, infile); buf[8192] = '\0'; ObjArr *top = read_exprs(S, buf); if (S->do_dumpsexpr) println_val(VAL_OBJ(top)); cpl_body(&com, top, 0, 0); } cpl_op(&com, OP_HALT); Thread th = thread_new(S); th.ch = &ch; S->th = &th; load_stdlib(S); return runvm(S); }