summaryrefslogtreecommitdiff
path: root/asm.py
blob: 3d6154371bdc3bcddb2aa427f9212a45ca928a3f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3

from mnems import mnems

# print(mnems)
import sys

import collections

# Byte address of the next word to be emitted; emit() advances it by 2.
pc = 0
# Global label name -> byte address where it was defined (":" lines).
labels = {}
# Local label name -> byte address (defined by "%" lines); cleared whenever a
# new global label is defined, so local labels are scoped to one global label.
local_labels = {}
# Pending forward references: label name -> set of (dest, ifn) pairs, where
# dest is the byte address of the placeholder word already emitted and
# ifn(target, dest) builds the final word that put_at() patches in once the
# label's address (target) is known.
label_wants = collections.defaultdict(set)
# Same as label_wants, but for local labels referenced by "b" / "j" lines.
local_label_wants = collections.defaultdict(set)
# Assembled program: two bytes per word (int.to_bytes, big-endian), written
# to stdout by main().
output = bytearray()

class Wrong(Exception):
	"""Assembly error attached to a source line number.

	Attributes:
		lineno: line number the error is reported against (main() passes
			1-based numbers).
		msg: human-readable description of the problem.
	"""

	def __init__(self, lineno, msg):
		super().__init__(lineno, msg)
		self.lineno, self.msg = lineno, msg

	def __str__(self):
		# Rendered as "<lineno>: <msg>".
		return f"{self.lineno}: {self.msg}"

def assemble_line(args, lineno):

	match args:
		case ["c", label] if label in labels:
			emit_call(labels[label])
		case ["c", label]:
			label_wants[label].add((pc,call_instr))
			emit(0x8000)
		case ['b', label] if label in local_labels:
			emit_obranch(local_labels[label])
		case ['b', label]:
			local_label_wants[label].add((pc, obranch_instr))
			emit(0x5000)
		case ['j', label] if label in local_labels:
			emit_jump(local_labels[label])
		case ['j', label]:
			local_label_wants[label].add((pc, jump_instr))
			emit(0x4000)

		case ["i", i1]:
			assemble_line(["i", i1, "nop"], lineno)
		case ["i", i1, i2]:
			opc1 = mnems.index(i1)
			opc2 = mnems.index(i2)
			emit_instrs(opc1, opc2)
		case ["l", val]:
			emit_lit(int(val,16))

		case [':', label]:
			if label in labels:
				raise Wrong(lineno, "label defined twice "+label)
			labels[label] = pc
			for (dest, ifn) in label_wants[label]:
				put_at(dest, ifn(pc, dest))
			del label_wants[label]

			local_labels.clear()
			for u,v in local_label_wants.items():
				raise Wrong(lineno, "unfulfilled local label "+u)
			local_label_wants.clear()

		case ['%', label]:
			if label in local_labels:
				raise Wrong(lineno, "local label defined twice "+label)
			local_labels[label] = pc
			for (dest, ifn) in local_label_wants[label]:
				put_at(dest, ifn(pc, dest))
			del local_label_wants[label]

		case []: pass


		case _:
			raise Wrong(lineno, "unknown wordtype")

def emit_call(addr):
	"""Emit a call word (see call_instr) targeting byte address addr."""
	word = call_instr(addr)
	emit(word)
def emit_instrs(opc1, opc2):
	"""Emit one word packing the two opcodes (see instrs_instr)."""
	word = instrs_instr(opc1, opc2)
	emit(word)
def emit_lit(val):
	"""Emit a literal word built from val (see lit_instr)."""
	word = lit_instr(val)
	emit(word)
def emit_obranch(addr):
	"""Emit an obranch word to addr, relative to this word's own address.

	NOTE(review): "obranch" is presumably a Forth-style 0branch — confirm
	against the CPU spec.
	"""
	here = pc  # address of the word about to be emitted
	emit(obranch_instr(addr, origin=here))
def emit_jump(addr):
	"""Emit a jump word to addr, relative to this word's own address."""
	here = pc  # address of the word about to be emitted
	emit(jump_instr(addr, origin=here))

def emit(word):
	"""Append one 16-bit word to output and advance pc by 2.

	The word is stored as two big-endian bytes.

	Raises:
		OverflowError: word is negative or does not fit in 16 bits
			(raised by int.to_bytes).
	"""
	global pc
	# Spell out the byte order: int.to_bytes only defaults to "big" from
	# Python 3.11, while the match statement used in this file runs on 3.10.
	output.extend(word.to_bytes(2, "big"))
	pc += 2

def put_at(addr, word):
	"""Overwrite the already-emitted 16-bit word at byte address addr.

	Used to patch forward-reference placeholders. The word is stored as two
	big-endian bytes, matching emit().
	"""
	assert addr % 2 == 0, "ooeu"
	# A slice assignment past the end would silently append rather than
	# patch, so refuse addresses that do not name an existing word.
	assert 0 <= addr <= len(output) - 2, "put_at outside emitted output"
	# Explicit byte order: the "big" default of int.to_bytes is 3.11+ only.
	output[addr:addr+2] = word.to_bytes(2, "big")

def lit_instr(val):
	"""Encode a literal word: tag 0b011 in bits 15..13, val in bits 12..0.

	Bits of val above the low 13 are discarded; negative values keep their
	low 13 two's-complement bits.
	"""
	tag = 0x6000
	low13 = (1 << 13) - 1
	return tag | (val & low13)
def instrs_instr(opc1, opc2):
	"""Pack two opcodes into one word: opc1 shifted left 7 bits, OR'd with opc2.

	No masking is applied. NOTE(review): presumably every opcode index is
	small enough (opc2 < 128) not to overlap — confirm against mnems.
	"""
	upper = opc1 << 7
	return upper | opc2
def call_instr(addr,*_):
	"""Encode a call word: bit 15 set, addr >> 1 in the remaining bits.

	Extra positional arguments are ignored so the function can also serve as
	an ifn(target, dest) fix-up callback for forward references.
	"""
	word_addr = addr >> 1
	return 0x8000 | word_addr
def _rel12(target, origin):
	"""Return target - origin biased by 2048, checked to fit in 12 bits.

	Raises:
		ValueError: the displacement is outside -2048..2047. Plain masking
			would silently wrap it to a different destination.
	"""
	rel = target - origin + 2048
	if not 0 <= rel <= 0xFFF:
		raise ValueError(
			f"branch displacement {target - origin} out of range -2048..2047"
		)
	return rel

def obranch_instr(target, origin):
	"""Encode an obranch word (tag 0b0101) from origin to target.

	The low 12 bits hold target - origin + 2048. Also used as an
	ifn(target, dest) fix-up callback, so both arguments stay positional.
	"""
	return _rel12(target, origin) | 0b0101000000000000

def jump_instr(target, origin):
	"""Encode a jump word (tag 0b0100) from origin to target.

	The low 12 bits hold target - origin + 2048. Also used as an
	ifn(target, dest) fix-up callback, so both arguments stay positional.
	"""
	return _rel12(target, origin) | 0b0100000000000000

import sys
def main():
	"""Assemble the program on stdin and write the raw words to stdout.

	Each input line is split on whitespace and passed to assemble_line with
	its 1-based line number.
	- An assembly error (Wrong) is printed to stderr and aborts with exit
	  status 1, and nothing is written.
	- Labels that were referenced but never defined are reported on stderr.
	  The output is still written (their placeholder words are left
	  unpatched), but the process exits with status 1.
	"""
	try:
		for ix, line in enumerate(sys.stdin, start=1):
			assemble_line(line.split(), ix)
	except Wrong as err:
		print(err, file=sys.stderr)
		sys.exit(1)

	unresolved = False
	for u in label_wants:
		print("unfulfilled global label", u, file=sys.stderr)
		unresolved = True
	# assemble_line only checks pending local references when the next ":"
	# label arrives, so those after the last global label are caught here.
	for u in local_label_wants:
		print("unfulfilled local label", u, file=sys.stderr)
		unresolved = True
	sys.stdout.buffer.write(output)
	if unresolved:
		sys.exit(1)

# Run the assembler only when executed as a script, not when imported.
if __name__ == "__main__":
	main()