#include #include #include #include "com.h" #include "mem.h" #include "chunk.h" #include "ast.h" #include "read.h" #include "util.h" #define BYTECODE(C) (C->ch->bc) Chunk chunk_new(State *S) { return (Chunk){ 0 }; } Compiler compiler_new(Compiler *outer, Chunk *ch) { return (Compiler){ .S = outer->S, .ch = ch, .stack_cur = 0, }; } static void compile_byte(Compiler *C, uint8_t byte); static void compile_opcode(Compiler *C, Op opcode); static size_t compile_constant(Compiler *C, Val v); static int stack_effect_of(Op opcode) { switch (opcode) { case OP_LOADK: case OP_GETGLOBAL: case OP_GETLOCAL: case OP_NIL: case OP_TRUE: case OP_FALSE: return 1; case OP_PUTS: case OP_PRINT: case OP_SKIP: case OP_REDO: case OP_SETGLOBAL: case OP_RET: case OP_HALT: return 0; case OP_DROP: case OP_0BRANCH: case OP_ADD: case OP_SUB: case OP_MUL: case OP_DIV: case OP_CMP: case OP_EQU: case OP_MOD: return -1; // these ones depend on their argument. handle them specifically case OP_CALL: case OP_ENDSCOPE: return 0; default: ERROR("unknown stack effect of opcode %d",opcode); } } // compile that byte directly static void compile_byte(Compiler *C, uint8_t byte) { Chunk *ch = C->ch; if (ch->bc.len == ch->bc.cap) { size_t newsz = (ch->bc.cap == 0 ? 8 : ch->bc.cap * 2); ch->bc.d = RENEW_ARR(C->S, ch->bc.d, uint8_t, ch->bc.cap, newsz); ch->bc.cap = newsz; } size_t ix = ch->bc.len; ch->bc.d[ix] = byte; ch->bc.len ++; } // compile an opcode, keeping track of its stack effect static void compile_opcode(Compiler *C, Op opcode) { int stack_effect = stack_effect_of(opcode); C->stack_cur += stack_effect; compile_byte(C, opcode); } // add a new constant to the constant table, and return its index // (but return the index of any existing identical constant instead of // inserting a duplicate) static size_t compile_constant(Compiler *C, Val v) { Chunk *ch = C->ch; for (int i = 0; i < ch->consts.len; i ++) if (val_equal(v, ch->consts.d[i])) return i; if (ch->consts.len == ch->consts.cap) { size_t newsz = (ch->consts.cap == 0 ? 8 : ch->consts.cap *2); ch->consts.d = RENEW_ARR(C->S, ch->consts.d, Val, ch->consts.cap, newsz); ch->consts.cap = newsz; } size_t ix = ch->consts.len; ch->consts.d[ix] = v; ch->consts.len ++; return ix; } // len is 1 + number of args static void compile_call_instr(Compiler *C, uint8_t len) { compile_opcode(C, OP_CALL); compile_byte(C, len); C->stack_cur -= len; } static void compile_endscope_instr(Compiler *C, uint8_t nlocals) { compile_opcode(C, OP_ENDSCOPE); compile_byte(C, nlocals); C->stack_cur -= nlocals; } static size_t placeholder(Compiler *C) { size_t old_ix = BYTECODE(C).len; compile_byte(C, 0x00); compile_byte(C, 0x00); return old_ix; } static void patch(Compiler *C, size_t addr, uint16_t val) { BYTECODE(C).d[addr] = val & 0xff; BYTECODE(C).d[addr+1] = (val & 0xff00) >> 8; } // ---- static void compile_node(Compiler *C, AstNode a); typedef void (*form_compiler)(Compiler *C, AstVec l, Op op); void single_form(Compiler *C, AstVec l, Op op) { compile_node(C, l.vals[1]); compile_opcode(C, op); } void set_form(Compiler *C, AstVec l, Op _) { AstNode ident = l.vals[1]; CHECK(ident.ty == AST_IDENT, "set's first argument must be identifier"); ObjString *o = objstring_copy_cstr(C->S, ident.as.str); compile_node(C, l.vals[2]); compile_opcode(C, OP_SETGLOBAL); compile_byte(C, compile_constant(C, VAL_OBJ(o))); } void do_form(Compiler *C, AstVec l, Op _) { for (int i = 1; i < l.len - 1; i++) { compile_node(C, l.vals[i]); compile_opcode(C, OP_DROP); } compile_node(C, l.vals[l.len - 1]); } void if_form(Compiler *C, AstVec l, Op _) { // (if cond if-true if-false) // cond // 0branch ->A // if-true // skip ->B // A: if-false // B: compile_node(C, l.vals[1]); compile_opcode(C, OP_0BRANCH); size_t ph_a = placeholder(C); compile_node(C, l.vals[2]); compile_opcode(C, OP_SKIP); size_t ph_b = placeholder(C); size_t dest_a = BYTECODE(C).len; compile_node(C, l.vals[3]); size_t dest_b = BYTECODE(C).len; patch(C, ph_a, dest_a - ph_a - 2); patch(C, ph_b, dest_b - ph_b - 2); } void while_form(Compiler *C, AstVec l, Op _) { // (while cond body ...) // A: // cond // 0branch ->B // body .... // redo ->A // B: // nil (while loop always returns nil) size_t dest_a = BYTECODE(C).len; compile_node(C, l.vals[1]); compile_opcode(C, OP_0BRANCH); size_t ph_b = placeholder(C); for (int i = 2; i < l.len; i++) { compile_node(C, l.vals[i]); compile_opcode(C, OP_DROP); } compile_opcode(C, OP_REDO); size_t ph_a = placeholder(C); size_t dest_b = BYTECODE(C).len; compile_opcode(C, OP_NIL); patch(C, ph_a, ph_a - dest_a + 2); patch(C, ph_b, dest_b - ph_b - 2); } void arith_form(Compiler *C, AstVec l, Op op) { compile_node(C, l.vals[1]); compile_node(C, l.vals[2]); compile_opcode(C, op); } void fn_form(Compiler *C, AstVec l, Op _) { // (fn (arg arg arg) body ...) CHECK(l.vals[1].ty == AST_LIST, "fn's first argument must be list"); AstVec arglist = l.vals[1].as.list; CHECK(arglist.len <= 255, "maximum 255 args for function"); uint8_t arity = arglist.len; ObjFunc *func = objfunc_new(C->S, arity); Compiler subcompiler = compiler_new(C, &func->ch); for (int i = 2; i < l.len - 1; i++) { compile_node(&subcompiler, l.vals[i]); compile_opcode(&subcompiler, OP_DROP); } compile_node(&subcompiler, l.vals[l.len-1]); compile_opcode(&subcompiler, OP_RET); compile_opcode(C, OP_LOADK); compile_byte(C, compile_constant(C, VAL_OBJ(func))); } static void begin_scope(Compiler *C) { Scope *sc = malloc(sizeof(Scope)); CHECK(sc != NULL, "memory fail"); memset(sc, 0, sizeof(Scope)); sc->outer = C->scope; C->scope = sc; } static void end_scope(Compiler *C) { Scope *sc = C->scope; CHECK(sc != NULL, "attempt to end nonexistent scope"); C->scope = sc->outer; // printf("ending scope with %d locals, named: \n", sc->nlocals); // for (int i = 0; i < sc->nlocals; i++) { // Local loc = sc->locals[i]; // printf("\t%3d %s\n",loc.slot, loc.name); // } compile_endscope_instr(C, sc->nlocals); free(sc); } static void declare_local(Compiler *C, char *name) { Scope *sc = C->scope; CHECK(sc != NULL, "can't declare local outside of scope"); Local *l = &sc->locals[sc->nlocals++]; l->name = name; // -1 because local is expected to be already on the stack // ie sitting just below where stack_cur points l->slot = C->stack_cur - 1; } void let_form(Compiler *C, AstVec l, Op _) { CHECK(l.vals[1].ty == AST_LIST, "let's first argument must be list"); AstVec bindlist = l.vals[1].as.list; CHECK(bindlist.len % 2 == 0, "unmatched binding in let"); int nbinds = bindlist.len / 2; begin_scope(C); for (int i = 0; i < nbinds; i++) { int ix = i * 2; AstNode name = bindlist.vals[ix]; AstNode expr = bindlist.vals[ix+1]; CHECK(name.ty == AST_IDENT, "binding name must be identifier"); compile_node(C, expr); declare_local(C, name.as.str); } for (int i = 2; i < l.len - 1; i++) { compile_node(C, l.vals[i]); compile_opcode(C, OP_DROP); } compile_node(C, l.vals[l.len-1]); end_scope(C); } typedef struct { char *name; int min_params; bool ellipsis; form_compiler action; Op op; } BuiltinForm; static BuiltinForm builtin_forms[] = { { "puts", 1, false, single_form, OP_PUTS }, { "print", 1, false, single_form, OP_PRINT }, { "set", 2, false, set_form, 0 }, { "do", 1, true, do_form, 0 }, { "if", 3, false, if_form, 0 }, { "while", 2, true, while_form, 0 }, { "fn", 2, true, fn_form, 0 }, { "let", 2, true, let_form, 0 }, #define ARITH_OP(str, op) \ { str, 2, false, arith_form, op }, ARITH_OP("+", OP_ADD) ARITH_OP("-", OP_SUB) ARITH_OP("*", OP_MUL) ARITH_OP("/", OP_DIV) ARITH_OP("=", OP_EQU) ARITH_OP("<", OP_CMP) ARITH_OP("%", OP_MOD) #undef ARITH_OP { 0 }, }; typedef struct { char *name; Op op; } BuiltinIdent; static BuiltinIdent builtin_idents[] = { { "true", OP_TRUE }, { "false", OP_FALSE }, { "nil", OP_NIL }, { 0 }, }; static void compile_node(Compiler *C, AstNode a) { switch (a.ty) { case AST_IDENT:; char *ident = a.as.str; bool found_builtin = false; for (BuiltinIdent *b = builtin_idents; b->name != NULL; b++) { if (0 == strcmp(b->name, ident)) { compile_opcode(C, b->op); found_builtin = true; break; } } if (!found_builtin) { // read local or global variable bool found_local = false; if (C->scope != NULL) { Scope *sc = C->scope; for (int i = 0; i < sc->nlocals; i++) { Local loc = sc->locals[i]; if (0 == strcmp(ident, loc.name)) { compile_opcode(C, OP_GETLOCAL); compile_byte(C, loc.slot); found_local = true; break; } } } if (!found_local) { // read global ObjString *o = objstring_copy_cstr(C->S, a.as.str); compile_opcode(C, OP_GETGLOBAL); compile_byte(C, compile_constant(C, VAL_OBJ(o))); } } break; case AST_NUM: compile_opcode(C, OP_LOADK); compile_byte(C, compile_constant(C, VAL_NUM(a.as.num))); break; case AST_STRING: { ObjString *o = objstring_copy_cstr(C->S, a.as.str); compile_opcode(C, OP_LOADK); compile_byte(C, compile_constant(C, VAL_OBJ(o))); break; } case AST_LIST: { AstVec l = a.as.list; CHECK(l.len > 0, "can't handle empty list"); BuiltinForm *form = NULL; if (l.vals[0].ty == AST_IDENT) { char *head = l.vals[0].as.str; for (BuiltinForm *b = builtin_forms; b->name != NULL; b++) { if (0 == strcmp(b->name, head)) { form = b; break; } } } if (form != NULL) { size_t nparams = l.len - 1; if (form->ellipsis) CHECK(nparams >= form->min_params, "%s requires at least %d parameters", form->name, form->min_params); else CHECK(nparams == form->min_params, "%s requires exactly %d parameters", form->name, form->min_params); form->action(C, l, form->op); } else { // function call // (f a b c ) if (l.len > 255) { fprintf(stderr, "can't have more than 255 args in a function call\n"); exit(1); } for (int i = 0; i < l.len; i++) { compile_node(C, l.vals[i]); } compile_call_instr(C, l.len); } break; } } } int main(int argc, char **argv) { State st = state_new(); State *S = &st; Chunk ch = chunk_new(S); S->do_disasm = false; S->do_trace = false; for (int i = 1; i < argc; i++) { if (argv[i][0] != '-') break; switch (argv[i][1]) { case 'D': CHECK(strlen(argv[i]) > 1, "unknown option a"); switch (argv[i][2]) { case 'l': S->do_disasm = true; break; case 't': S->do_trace = true; break; default: ERROR("unknown option b"); } break; default: ERROR("unknown option c"); } } AstNode an = read(); Compiler com = (Compiler){ .S = S, .ch = &ch, }; Compiler *C = &com; compile_node(C, an); compile_opcode(C, OP_PUTS); compile_opcode(C, OP_HALT); Thread th = thread_new(S); th.ch = &ch; S->th = &th; return runvm(S); } #undef CHECK #undef ER