WASM JIT compiler (C/Zig) by Claude Code
| // Minimal Baseline JIT for WebAssembly (ARM64) | |
| // Zig implementation | |
| const std = @import("std"); | |
| const fs = std.fs; | |
| const mem = std.mem; | |
| const Allocator = std.mem.Allocator; | |
| // ============================================================================ | |
| // Code buffer for JIT compilation | |
| // ============================================================================ | |
| const CodeBuffer = struct { | |
| code: []align(4096) u8, | |
| size: usize, | |
| capacity: usize, | |
| const Self = @This(); | |
| fn init(capacity: usize) !Self { | |
| const c = @cImport({ | |
| @cInclude("sys/mman.h"); | |
| }); | |
| // PROT_READ | PROT_WRITE | PROT_EXEC | |
| const prot: u32 = c.PROT_READ | c.PROT_WRITE | c.PROT_EXEC; | |
| // MAP_PRIVATE | MAP_ANONYMOUS (MAP_ANON on macOS) | MAP_JIT | |
| var flags: i32 = c.MAP_PRIVATE | c.MAP_ANON; | |
| if (comptime @import("builtin").os.tag == .macos) { | |
| flags |= c.MAP_JIT; | |
| } | |
| const result = c.mmap(null, capacity, @intCast(prot), flags, -1, 0); | |
| if (result == c.MAP_FAILED) { | |
| return error.MmapFailed; | |
| } | |
| const mapped: [*]align(4096) u8 = @ptrCast(@alignCast(result)); | |
| return Self{ | |
| .code = mapped[0..capacity], | |
| .size = 0, | |
| .capacity = capacity, | |
| }; | |
| } | |
| fn deinit(self: *Self) void { | |
| const c = @cImport(@cInclude("sys/mman.h")); | |
| _ = c.munmap(self.code.ptr, self.code.len); | |
| } | |
| fn beginWrite(self: *Self) void { | |
| _ = self; | |
| if (comptime @import("builtin").os.tag == .macos) { | |
| // pthread_jit_write_protect_np(0) | |
| const c = @cImport(@cInclude("pthread.h")); | |
| c.pthread_jit_write_protect_np(0); | |
| } | |
| } | |
| fn endWrite(self: *Self) void { | |
| if (comptime @import("builtin").os.tag == .macos) { | |
| const c = @cImport(@cInclude("pthread.h")); | |
| c.pthread_jit_write_protect_np(1); | |
| // sys_icache_invalidate | |
| const cache = @cImport(@cInclude("libkern/OSCacheControl.h")); | |
| cache.sys_icache_invalidate(self.code.ptr, self.size * 4); | |
| } | |
| } | |
| fn emit(self: *Self, inst: u32) void { | |
| if (self.size >= self.capacity / 4) { | |
| @panic("Code buffer overflow"); | |
| } | |
| const ptr: *u32 = @ptrCast(@alignCast(self.code.ptr + self.size * 4)); | |
| ptr.* = inst; | |
| self.size += 1; | |
| } | |
| fn pos(self: *const Self) usize { | |
| return self.size; | |
| } | |
| fn patch(self: *Self, position: usize, inst: u32) void { | |
| const ptr: *u32 = @ptrCast(@alignCast(self.code.ptr + position * 4)); | |
| ptr.* = inst; | |
| } | |
| fn getFunction(self: *const Self, comptime T: type) T { | |
| return @ptrCast(self.code.ptr); | |
| } | |
| }; | |
| // ============================================================================ | |
| // ARM64 instruction encoding | |
| // ============================================================================ | |
| const Arm64 = struct { | |
| // MOV immediate (MOVZ) | |
| fn movz(rd: u5, imm: u16, shift: u2) u32 { | |
| return 0xD2800000 | (@as(u32, shift) << 21) | (@as(u32, imm) << 5) | rd; | |
| } | |
| // MOV immediate (MOVK) | |
| fn movk(rd: u5, imm: u16, shift: u2) u32 { | |
| return 0xF2800000 | (@as(u32, shift) << 21) | (@as(u32, imm) << 5) | rd; | |
| } | |
| // MOV register (64-bit) - ORR rd, xzr, rm | |
| fn mov(rd: u5, rm: u5) u32 { | |
| return 0xAA0003E0 | (@as(u32, rm) << 16) | rd; | |
| } | |
| // MOV register (32-bit) | |
| fn movW(rd: u5, rm: u5) u32 { | |
| return 0x2A0003E0 | (@as(u32, rm) << 16) | rd; | |
| } | |
| // ADD (64-bit register) | |
| fn add(rd: u5, rn: u5, rm: u5) u32 { | |
| return 0x8B000000 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd; | |
| } | |
| // ADD (32-bit register) | |
| fn addW(rd: u5, rn: u5, rm: u5) u32 { | |
| return 0x0B000000 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd; | |
| } | |
| // ADD immediate (64-bit) | |
| fn addImm(rd: u5, rn: u5, imm: u12) u32 { | |
| return 0x91000000 | (@as(u32, imm) << 10) | (@as(u32, rn) << 5) | rd; | |
| } | |
| // SUB (32-bit register) | |
| fn subW(rd: u5, rn: u5, rm: u5) u32 { | |
| return 0x4B000000 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd; | |
| } | |
| // SUB immediate (64-bit) | |
| fn subImm(rd: u5, rn: u5, imm: u12) u32 { | |
| return 0xD1000000 | (@as(u32, imm) << 10) | (@as(u32, rn) << 5) | rd; | |
| } | |
| // MUL (32-bit) | |
| fn mulW(rd: u5, rn: u5, rm: u5) u32 { | |
| return 0x1B007C00 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd; | |
| } | |
| // LDR (64-bit, unsigned offset) | |
| fn ldr(rt: u5, rn: u5, offset: u12) u32 { | |
| return 0xF9400000 | (@as(u32, offset / 8) << 10) | (@as(u32, rn) << 5) | rt; | |
| } | |
| // LDR (32-bit, unsigned offset) | |
| fn ldrW(rt: u5, rn: u5, offset: u12) u32 { | |
| return 0xB9400000 | (@as(u32, offset / 4) << 10) | (@as(u32, rn) << 5) | rt; | |
| } | |
| // LDR (32-bit, register offset) | |
| fn ldrWReg(rt: u5, rn: u5, rm: u5) u32 { | |
| return 0xB8606800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rt; | |
| } | |
| // STR (64-bit, unsigned offset) | |
| fn str(rt: u5, rn: u5, offset: u12) u32 { | |
| return 0xF9000000 | (@as(u32, offset / 8) << 10) | (@as(u32, rn) << 5) | rt; | |
| } | |
| // STR (32-bit, unsigned offset) | |
| fn strW(rt: u5, rn: u5, offset: u12) u32 { | |
| return 0xB9000000 | (@as(u32, offset / 4) << 10) | (@as(u32, rn) << 5) | rt; | |
| } | |
| // STR (32-bit, register offset) | |
| fn strWReg(rt: u5, rn: u5, rm: u5) u32 { | |
| return 0xB8206800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rt; | |
| } | |
| // STP (64-bit) | |
| fn stp(rt1: u5, rt2: u5, rn: u5, offset: i7) u32 { | |
| const imm: u7 = @bitCast(@as(i7, @divTrunc(offset, 8))); | |
| return 0xA9000000 | (@as(u32, imm) << 15) | (@as(u32, rt2) << 10) | (@as(u32, rn) << 5) | rt1; | |
| } | |
| // LDP (64-bit) | |
| fn ldp(rt1: u5, rt2: u5, rn: u5, offset: i7) u32 { | |
| const imm: u7 = @bitCast(@as(i7, @divTrunc(offset, 8))); | |
| return 0xA9400000 | (@as(u32, imm) << 15) | (@as(u32, rt2) << 10) | (@as(u32, rn) << 5) | rt1; | |
| } | |
| // LDR float single (unsigned offset) | |
| fn ldrS(rt: u5, rn: u5, offset: u12) u32 { | |
| return 0xBD400000 | (@as(u32, offset / 4) << 10) | (@as(u32, rn) << 5) | rt; | |
| } | |
| // STR float single (unsigned offset) | |
| fn strS(rt: u5, rn: u5, offset: u12) u32 { | |
| return 0xBD000000 | (@as(u32, offset / 4) << 10) | (@as(u32, rn) << 5) | rt; | |
| } | |
| // LDR float single (register offset) | |
| fn ldrSReg(rt: u5, rn: u5, rm: u5) u32 { | |
| return 0xBC606800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rt; | |
| } | |
| // STR float single (register offset) | |
| fn strSReg(rt: u5, rn: u5, rm: u5) u32 { | |
| return 0xBC206800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rt; | |
| } | |
| // FMOV (GP to float) | |
| fn fmovSFromW(sd: u5, wn: u5) u32 { | |
| return 0x1E270000 | (@as(u32, wn) << 5) | sd; | |
| } | |
| // FADD single | |
| fn faddS(rd: u5, rn: u5, rm: u5) u32 { | |
| return 0x1E202800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd; | |
| } | |
| // FMUL single | |
| fn fmulS(rd: u5, rn: u5, rm: u5) u32 { | |
| return 0x1E200800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd; | |
| } | |
| // CMP (32-bit register) | |
| fn cmpW(rn: u5, rm: u5) u32 { | |
| return 0x6B00001F | (@as(u32, rm) << 16) | (@as(u32, rn) << 5); | |
| } | |
| // CSET (32-bit) | |
| fn csetW(rd: u5, cond: u4) u32 { | |
| const inv_cond = cond ^ 1; | |
| return 0x1A9F07E0 | (@as(u32, inv_cond) << 12) | rd; | |
| } | |
| // B (unconditional) | |
| fn b(offset: i26) u32 { | |
| const imm: u26 = @bitCast(offset); | |
| return 0x14000000 | @as(u32, imm); | |
| } | |
| // CBNZ (32-bit) | |
| fn cbnzW(rt: u5, offset: i19) u32 { | |
| const imm: u19 = @bitCast(offset); | |
| return 0x35000000 | (@as(u32, imm) << 5) | rt; | |
| } | |
| // RET | |
| fn ret() u32 { | |
| return 0xD65F03C0; | |
| } | |
| // Condition codes | |
| const COND_HS: u4 = 2; // unsigned >= | |
| const COND_LS: u4 = 9; // unsigned <= | |
| }; | |
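The encoders above are plain bit-packing helpers, so they are easy to spot-check against known instruction words. A minimal hedged sketch, assuming it sits in the same file as the `Arm64` struct and is run with `zig test`:

```zig
test "arm64 encoding spot-check" {
    // mov x0, #42 / add w0, w1, w2 / ret
    try std.testing.expectEqual(@as(u32, 0xD2800540), Arm64.movz(0, 42, 0));
    try std.testing.expectEqual(@as(u32, 0x0B020020), Arm64.addW(0, 1, 2));
    try std.testing.expectEqual(@as(u32, 0xD65F03C0), Arm64.ret());
}
```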
| // ============================================================================ | |
| // Wasm parser | |
| // ============================================================================ | |
| const WasmReader = struct { | |
| data: []const u8, | |
| pos: usize, | |
| const Self = @This(); | |
| fn init(data: []const u8) Self { | |
| return Self{ .data = data, .pos = 0 }; | |
| } | |
| fn readU8(self: *Self) u8 { | |
| const b = self.data[self.pos]; | |
| self.pos += 1; | |
| return b; | |
| } | |
| fn readU32Leb128(self: *Self) u32 { | |
| var result: u32 = 0; | |
| var shift: u5 = 0; | |
| while (true) { | |
| const byte = self.readU8(); | |
| result |= @as(u32, byte & 0x7F) << shift; | |
| if (byte & 0x80 == 0) break; | |
| shift += 7; | |
| } | |
| return result; | |
| } | |
| fn readI32Leb128(self: *Self) i32 { | |
| var result: i32 = 0; | |
| var shift: u5 = 0; | |
| var byte: u8 = 0; | |
| while (true) { | |
| byte = self.readU8(); | |
| result |= @as(i32, @intCast(byte & 0x7F)) << shift; | |
| shift += 7; | |
| if (byte & 0x80 == 0) break; | |
| } | |
| if (shift < 32 and (byte & 0x40) != 0) { | |
| result |= @as(i32, -1) << shift; | |
| } | |
| return result; | |
| } | |
| fn readI64Leb128(self: *Self) i64 { | |
| var result: i64 = 0; | |
| var shift: u6 = 0; | |
| var byte: u8 = 0; | |
| while (true) { | |
| byte = self.readU8(); | |
| result |= @as(i64, @intCast(byte & 0x7F)) << shift; | |
| shift += 7; | |
| if (byte & 0x80 == 0) break; | |
| } | |
| if (shift < 64 and (byte & 0x40) != 0) { | |
| result |= @as(i64, -1) << shift; | |
| } | |
| return result; | |
| } | |
| fn readF32(self: *Self) f32 { | |
| const bytes = self.data[self.pos..][0..4]; | |
| self.pos += 4; | |
| return @bitCast(bytes.*); | |
| } | |
| fn skip(self: *Self, n: usize) void { | |
| self.pos += n; | |
| } | |
| }; | |
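The two LEB128 readers implement the standard variable-length decoding; the textbook example values decode as expected. A small sketch under the same `zig test` assumption:

```zig
test "leb128 decoding" {
    // Classic LEB128 examples: unsigned 624485 and signed -624485.
    var unsigned_reader = WasmReader.init(&[_]u8{ 0xE5, 0x8E, 0x26 });
    try std.testing.expectEqual(@as(u32, 624485), unsigned_reader.readU32Leb128());
    var signed_reader = WasmReader.init(&[_]u8{ 0x9B, 0xF1, 0x59 });
    try std.testing.expectEqual(@as(i32, -624485), signed_reader.readI32Leb128());
}
```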
| // ============================================================================ | |
| // Wasm module structures | |
| // ============================================================================ | |
| const WasmFuncType = struct { | |
| param_count: u32, | |
| result_count: u32, | |
| }; | |
| const WasmFunc = struct { | |
| type_idx: u32, | |
| local_count: u32, | |
| code: []const u8, | |
| }; | |
| const WasmGlobal = struct { | |
| init_val: i64, | |
| }; | |
| const WasmModule = struct { | |
| types: []WasmFuncType, | |
| funcs: []WasmFunc, | |
| globals: []WasmGlobal, | |
| run_func_idx: ?u32, | |
| min_memory_pages: u32, | |
| allocator: Allocator, | |
| const Self = @This(); | |
| fn parse(allocator: Allocator, data: []const u8) !Self { | |
| var r = WasmReader.init(data); | |
| // Check magic and version | |
| const magic = r.readU8() | (@as(u32, r.readU8()) << 8) | (@as(u32, r.readU8()) << 16) | (@as(u32, r.readU8()) << 24); | |
| const version = r.readU8() | (@as(u32, r.readU8()) << 8) | (@as(u32, r.readU8()) << 16) | (@as(u32, r.readU8()) << 24); | |
| if (magic != 0x6D736100) return error.InvalidMagic; | |
| if (version != 1) return error.UnsupportedVersion; | |
| var types: std.ArrayListUnmanaged(WasmFuncType) = .{}; | |
| defer types.deinit(allocator); | |
| var func_type_indices: std.ArrayListUnmanaged(u32) = .{}; | |
| defer func_type_indices.deinit(allocator); | |
| var funcs: std.ArrayListUnmanaged(WasmFunc) = .{}; | |
| defer funcs.deinit(allocator); | |
| var globals: std.ArrayListUnmanaged(WasmGlobal) = .{}; | |
| defer globals.deinit(allocator); | |
| var run_func_idx: ?u32 = null; | |
| var min_memory_pages: u32 = 0; | |
| // Parse sections | |
| while (r.pos < r.data.len) { | |
| const section_id = r.readU8(); | |
| const section_size = r.readU32Leb128(); | |
| const section_end = r.pos + section_size; | |
| switch (section_id) { | |
| 1 => { // Type section | |
| const count = r.readU32Leb128(); | |
| try types.ensureTotalCapacity(allocator, count); | |
| for (0..count) |_| { | |
| _ = r.readU8(); // 0x60 | |
| const param_count = r.readU32Leb128(); | |
| r.skip(param_count); | |
| const result_count = r.readU32Leb128(); | |
| r.skip(result_count); | |
| try types.append(allocator, .{ | |
| .param_count = param_count, | |
| .result_count = result_count, | |
| }); | |
| } | |
| }, | |
| 3 => { // Function section | |
| const count = r.readU32Leb128(); | |
| try func_type_indices.ensureTotalCapacity(allocator, count); | |
| for (0..count) |_| { | |
| try func_type_indices.append(allocator, r.readU32Leb128()); | |
| } | |
| }, | |
| 5 => { // Memory section | |
| const count = r.readU32Leb128(); | |
| if (count > 0) { | |
| const flags = r.readU8(); | |
| min_memory_pages = r.readU32Leb128(); | |
| if (flags & 1 != 0) _ = r.readU32Leb128(); // max | |
| } | |
| }, | |
| 6 => { // Global section | |
| const count = r.readU32Leb128(); | |
| try globals.ensureTotalCapacity(allocator, count); | |
| for (0..count) |_| { | |
| _ = r.readU8(); // type | |
| _ = r.readU8(); // mutable | |
| const opcode = r.readU8(); | |
| const init_val: i64 = switch (opcode) { | |
| 0x41 => r.readI32Leb128(), | |
| 0x42 => r.readI64Leb128(), | |
| else => return error.UnsupportedGlobalInit, | |
| }; | |
| _ = r.readU8(); // end | |
| try globals.append(allocator, .{ .init_val = init_val }); | |
| } | |
| }, | |
| 7 => { // Export section | |
| const count = r.readU32Leb128(); | |
| for (0..count) |_| { | |
| const name_len = r.readU32Leb128(); | |
| const name = r.data[r.pos..][0..name_len]; | |
| r.skip(name_len); | |
| const kind = r.readU8(); | |
| const idx = r.readU32Leb128(); | |
| if (kind == 0 and mem.eql(u8, name, "run")) { | |
| run_func_idx = idx; | |
| } | |
| } | |
| }, | |
| 10 => { // Code section | |
| const count = r.readU32Leb128(); | |
| try funcs.ensureTotalCapacity(allocator, count); | |
| for (0..count) |i| { | |
| const func_size = r.readU32Leb128(); | |
| const func_end = r.pos + func_size; | |
| // Parse locals | |
| const local_group_count = r.readU32Leb128(); | |
| var total_locals: u32 = 0; | |
| for (0..local_group_count) |_| { | |
| total_locals += r.readU32Leb128(); | |
| _ = r.readU8(); // type | |
| } | |
| const code_start = r.pos; | |
| const code = r.data[code_start..func_end]; | |
| try funcs.append(allocator, .{ | |
| .type_idx = func_type_indices.items[i], | |
| .local_count = total_locals, | |
| .code = code, | |
| }); | |
| r.pos = func_end; | |
| } | |
| }, | |
| else => r.skip(section_size), | |
| } | |
| r.pos = section_end; | |
| } | |
| // Transfer ownership - don't deinit these | |
| const types_slice = try allocator.dupe(WasmFuncType, types.items); | |
| const funcs_slice = try allocator.dupe(WasmFunc, funcs.items); | |
| const globals_slice = try allocator.dupe(WasmGlobal, globals.items); | |
| return Self{ | |
| .types = types_slice, | |
| .funcs = funcs_slice, | |
| .globals = globals_slice, | |
| .run_func_idx = run_func_idx, | |
| .min_memory_pages = min_memory_pages, | |
| .allocator = allocator, | |
| }; | |
| } | |
| fn deinit(self: *Self) void { | |
| self.allocator.free(self.types); | |
| self.allocator.free(self.funcs); | |
| self.allocator.free(self.globals); | |
| } | |
| }; | |
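Parsing begins with the 8-byte preamble: the `\0asm` magic followed by version 1, both little-endian. A module consisting of only the preamble is valid and parses to an empty module; a hedged sanity check (same in-file `zig test` assumption):

```zig
test "parse empty module" {
    const preamble = [_]u8{ 0x00, 0x61, 0x73, 0x6D, 0x01, 0x00, 0x00, 0x00 };
    var empty = try WasmModule.parse(std.testing.allocator, &preamble);
    defer empty.deinit();
    try std.testing.expectEqual(@as(usize, 0), empty.funcs.len);
    try std.testing.expect(empty.run_func_idx == null);
}
```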
| // ============================================================================ | |
| // JIT Compiler | |
| // ============================================================================ | |
| const MAX_BLOCKS = 64; | |
| const MAX_STACK = 64; | |
| const BlockInfo = struct { | |
| start_pos: usize, | |
| patch_sites: [256]usize, | |
| patch_count: usize, | |
| is_loop: bool, | |
| }; | |
| const JitCompiler = struct { | |
| code: *CodeBuffer, | |
| mod: *const WasmModule, | |
| stack_depth: usize, | |
| blocks: [MAX_BLOCKS]BlockInfo, | |
| block_depth: usize, | |
| locals_offset: u12, | |
| param_count: u32, | |
| local_count: u32, | |
| const Self = @This(); | |
| fn init(code: *CodeBuffer, mod: *const WasmModule) Self { | |
| return Self{ | |
| .code = code, | |
| .mod = mod, | |
| .stack_depth = 0, | |
| .blocks = undefined, | |
| .block_depth = 0, | |
| .locals_offset = 16, | |
| .param_count = 0, | |
| .local_count = 0, | |
| }; | |
| } | |
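| // Register allocation for the wasm value stack (no spilling): integer slots | |
| // cycle through x9..x15 and float slots through s0..s7, indexed by stack depth. | |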
| fn getIntReg(stack_pos: usize) u5 { | |
| return @intCast(9 + (stack_pos % 7)); | |
| } | |
| fn getFloatReg(stack_pos: usize) u5 { | |
| return @intCast(stack_pos % 8); | |
| } | |
| fn emitMovImm32(self: *Self, rd: u5, val: i32) void { | |
| const uval: u32 = @bitCast(val); | |
| self.code.emit(Arm64.movz(rd, @truncate(uval), 0)); | |
| if (uval > 0xFFFF) { | |
| self.code.emit(Arm64.movk(rd, @truncate(uval >> 16), 1)); | |
| } | |
| } | |
| fn emitMovImm64(self: *Self, rd: u5, val: i64) void { | |
| const uval: u64 = @bitCast(val); | |
| self.code.emit(Arm64.movz(rd, @truncate(uval), 0)); | |
| if ((uval >> 16) & 0xFFFF != 0) { | |
| self.code.emit(Arm64.movk(rd, @truncate(uval >> 16), 1)); | |
| } | |
| if ((uval >> 32) & 0xFFFF != 0) { | |
| self.code.emit(Arm64.movk(rd, @truncate(uval >> 32), 2)); | |
| } | |
| if ((uval >> 48) & 0xFFFF != 0) { | |
| self.code.emit(Arm64.movk(rd, @truncate(uval >> 48), 3)); | |
| } | |
| } | |
| fn compileFunction(self: *Self, func: *const WasmFunc) void { | |
| const ftype = &self.mod.types[func.type_idx]; | |
| self.param_count = ftype.param_count; | |
| self.local_count = func.local_count; | |
| const total_locals = self.param_count + self.local_count; | |
| const frame_size: u12 = @intCast(((16 + total_locals * 8) + 15) & ~@as(u32, 15)); | |
| // Prologue | |
| self.code.emit(0xA9BF7BFD); // stp x29, x30, [sp, #-16]! | |
| self.code.emit(Arm64.addImm(29, 31, 0)); // mov x29, sp | |
| self.code.emit(Arm64.subImm(31, 31, frame_size)); | |
| self.code.emit(Arm64.stp(19, 20, 31, 0)); | |
| // Setup registers | |
| self.code.emit(Arm64.mov(19, 1)); // x19 = memory base | |
| self.code.emit(Arm64.mov(20, 2)); // x20 = globals base | |
| // Store first parameter | |
| if (self.param_count > 0) { | |
| self.code.emit(Arm64.str(0, 31, self.locals_offset)); | |
| } | |
| // Initialize locals to 0 | |
| for (self.param_count..total_locals) |i| { | |
| self.code.emit(Arm64.movz(8, 0, 0)); | |
| self.code.emit(Arm64.str(8, 31, self.locals_offset + @as(u12, @intCast(i * 8)))); | |
| } | |
| // Parse and compile | |
| var r = WasmReader.init(func.code); | |
| self.stack_depth = 0; | |
| self.block_depth = 0; | |
| // Push implicit function block | |
| self.blocks[0] = .{ | |
| .start_pos = self.code.pos(), | |
| .patch_sites = undefined, | |
| .patch_count = 0, | |
| .is_loop = false, | |
| }; | |
| self.block_depth = 1; | |
| while (r.pos < r.data.len) { | |
| const opcode = r.readU8(); | |
| switch (opcode) { | |
| 0x00, 0x01 => {}, // unreachable, nop | |
| 0x02 => { // block | |
| _ = r.readI32Leb128(); | |
| self.blocks[self.block_depth] = .{ | |
| .start_pos = self.code.pos(), | |
| .patch_sites = undefined, | |
| .patch_count = 0, | |
| .is_loop = false, | |
| }; | |
| self.block_depth += 1; | |
| }, | |
| 0x03 => { // loop | |
| _ = r.readI32Leb128(); | |
| self.blocks[self.block_depth] = .{ | |
| .start_pos = self.code.pos(), | |
| .patch_sites = undefined, | |
| .patch_count = 0, | |
| .is_loop = true, | |
| }; | |
| self.block_depth += 1; | |
| }, | |
| 0x0B => { // end | |
| self.block_depth -= 1; | |
| if (self.block_depth == 0) break; // Function end | |
| const block = &self.blocks[self.block_depth]; | |
| const end_pos = self.code.pos(); | |
| // Patch forward branches | |
| for (0..block.patch_count) |i| { | |
| const patch_pos = block.patch_sites[i]; | |
| const offset: i32 = @intCast(@as(isize, @intCast(end_pos)) - @as(isize, @intCast(patch_pos))); | |
| const old_inst = self.code.code[patch_pos * 4 ..][0..4]; | |
| const old: u32 = @bitCast(old_inst.*); | |
| if ((old & 0xFC000000) == 0x14000000) { | |
| self.code.patch(patch_pos, Arm64.b(@truncate(offset))); | |
| } else if ((old & 0xFF000000) == 0x35000000) { | |
| const rt: u5 = @truncate(old); | |
| self.code.patch(patch_pos, Arm64.cbnzW(rt, @truncate(offset))); | |
| } | |
| } | |
| }, | |
| 0x0C => { // br | |
| const depth = r.readU32Leb128(); | |
| const target_block = self.block_depth - 1 - depth; | |
| const block = &self.blocks[target_block]; | |
| if (block.is_loop) { | |
| const offset: i26 = @intCast(@as(isize, @intCast(block.start_pos)) - @as(isize, @intCast(self.code.pos()))); | |
| self.code.emit(Arm64.b(offset)); | |
| } else { | |
| block.patch_sites[block.patch_count] = self.code.pos(); | |
| block.patch_count += 1; | |
| self.code.emit(Arm64.b(0)); | |
| } | |
| }, | |
| 0x0D => { // br_if | |
| const depth = r.readU32Leb128(); | |
| self.stack_depth -= 1; | |
| const cond_reg = getIntReg(self.stack_depth); | |
| const target_block = self.block_depth - 1 - depth; | |
| const block = &self.blocks[target_block]; | |
| if (block.is_loop) { | |
| const offset: i19 = @intCast(@as(isize, @intCast(block.start_pos)) - @as(isize, @intCast(self.code.pos())) - 1); | |
| self.code.emit(Arm64.cbnzW(cond_reg, offset)); | |
| } else { | |
| block.patch_sites[block.patch_count] = self.code.pos(); | |
| block.patch_count += 1; | |
| self.code.emit(Arm64.cbnzW(cond_reg, 0)); | |
| } | |
| }, | |
| 0x20 => { // local.get | |
| const idx = r.readU32Leb128(); | |
| const target = getIntReg(self.stack_depth); | |
| self.code.emit(Arm64.ldr(target, 31, self.locals_offset + @as(u12, @intCast(idx * 8)))); | |
| self.stack_depth += 1; | |
| }, | |
| 0x21 => { // local.set | |
| const idx = r.readU32Leb128(); | |
| self.stack_depth -= 1; | |
| const src = getIntReg(self.stack_depth); | |
| self.code.emit(Arm64.str(src, 31, self.locals_offset + @as(u12, @intCast(idx * 8)))); | |
| }, | |
| 0x22 => { // local.tee | |
| const idx = r.readU32Leb128(); | |
| const src = getIntReg(self.stack_depth - 1); | |
| self.code.emit(Arm64.str(src, 31, self.locals_offset + @as(u12, @intCast(idx * 8)))); | |
| }, | |
| 0x23 => { // global.get | |
| const idx = r.readU32Leb128(); | |
| const target = getIntReg(self.stack_depth); | |
| self.code.emit(Arm64.ldrW(target, 20, @intCast(idx * 4))); | |
| self.stack_depth += 1; | |
| }, | |
| 0x28 => { // i32.load | |
| _ = r.readU32Leb128(); // align | |
| const offset = r.readU32Leb128(); | |
| self.stack_depth -= 1; | |
| const addr = getIntReg(self.stack_depth); | |
| const target = getIntReg(self.stack_depth); | |
| if (offset > 0) { | |
| self.emitMovImm32(8, @intCast(offset)); | |
| self.code.emit(Arm64.addW(addr, addr, 8)); | |
| } | |
| self.code.emit(Arm64.ldrWReg(target, 19, addr)); | |
| self.stack_depth += 1; | |
| }, | |
| 0x36 => { // i32.store | |
| _ = r.readU32Leb128(); | |
| const offset = r.readU32Leb128(); | |
| self.stack_depth -= 1; | |
| const value = getIntReg(self.stack_depth); | |
| self.stack_depth -= 1; | |
| const addr = getIntReg(self.stack_depth); | |
| if (offset > 0) { | |
| self.emitMovImm32(8, @intCast(offset)); | |
| self.code.emit(Arm64.addW(addr, addr, 8)); | |
| } | |
| self.code.emit(Arm64.strWReg(value, 19, addr)); | |
| }, | |
| 0x2A => { // f32.load | |
| _ = r.readU32Leb128(); | |
| const offset = r.readU32Leb128(); | |
| self.stack_depth -= 1; | |
| const addr = getIntReg(self.stack_depth); | |
| const target = getFloatReg(self.stack_depth); | |
| if (offset > 0) { | |
| self.emitMovImm32(8, @intCast(offset)); | |
| self.code.emit(Arm64.addW(addr, addr, 8)); | |
| } | |
| self.code.emit(Arm64.ldrSReg(target, 19, addr)); | |
| self.stack_depth += 1; | |
| }, | |
| 0x38 => { // f32.store | |
| _ = r.readU32Leb128(); | |
| const offset = r.readU32Leb128(); | |
| self.stack_depth -= 1; | |
| const value = getFloatReg(self.stack_depth); | |
| self.stack_depth -= 1; | |
| const addr = getIntReg(self.stack_depth); | |
| if (offset > 0) { | |
| self.emitMovImm32(8, @intCast(offset)); | |
| self.code.emit(Arm64.addW(addr, addr, 8)); | |
| } | |
| self.code.emit(Arm64.strSReg(value, 19, addr)); | |
| }, | |
| 0x41 => { // i32.const | |
| const val = r.readI32Leb128(); | |
| const target = getIntReg(self.stack_depth); | |
| self.emitMovImm32(target, val); | |
| self.stack_depth += 1; | |
| }, | |
| 0x42 => { // i64.const | |
| const val = r.readI64Leb128(); | |
| const target = getIntReg(self.stack_depth); | |
| self.emitMovImm64(target, val); | |
| self.stack_depth += 1; | |
| }, | |
| 0x43 => { // f32.const | |
| const val = r.readF32(); | |
| const target = getFloatReg(self.stack_depth); | |
| const bits: u32 = @bitCast(val); | |
| self.emitMovImm32(8, @bitCast(bits)); | |
| self.code.emit(Arm64.fmovSFromW(target, 8)); | |
| self.stack_depth += 1; | |
| }, | |
| 0x6A => { // i32.add | |
| self.stack_depth -= 1; | |
| const b_reg = getIntReg(self.stack_depth); | |
| self.stack_depth -= 1; | |
| const a_reg = getIntReg(self.stack_depth); | |
| const target = getIntReg(self.stack_depth); | |
| self.code.emit(Arm64.addW(target, a_reg, b_reg)); | |
| self.stack_depth += 1; | |
| }, | |
| 0x6B => { // i32.sub | |
| self.stack_depth -= 1; | |
| const b_reg = getIntReg(self.stack_depth); | |
| self.stack_depth -= 1; | |
| const a_reg = getIntReg(self.stack_depth); | |
| const target = getIntReg(self.stack_depth); | |
| self.code.emit(Arm64.subW(target, a_reg, b_reg)); | |
| self.stack_depth += 1; | |
| }, | |
| 0x6C => { // i32.mul | |
| self.stack_depth -= 1; | |
| const b_reg = getIntReg(self.stack_depth); | |
| self.stack_depth -= 1; | |
| const a_reg = getIntReg(self.stack_depth); | |
| const target = getIntReg(self.stack_depth); | |
| self.code.emit(Arm64.mulW(target, a_reg, b_reg)); | |
| self.stack_depth += 1; | |
| }, | |
| 0x4D => { // i32.le_u | |
| self.stack_depth -= 1; | |
| const b_reg = getIntReg(self.stack_depth); | |
| self.stack_depth -= 1; | |
| const a_reg = getIntReg(self.stack_depth); | |
| const target = getIntReg(self.stack_depth); | |
| self.code.emit(Arm64.cmpW(a_reg, b_reg)); | |
| self.code.emit(Arm64.csetW(target, Arm64.COND_LS)); | |
| self.stack_depth += 1; | |
| }, | |
| 0x4F => { // i32.ge_u | |
| self.stack_depth -= 1; | |
| const b_reg = getIntReg(self.stack_depth); | |
| self.stack_depth -= 1; | |
| const a_reg = getIntReg(self.stack_depth); | |
| const target = getIntReg(self.stack_depth); | |
| self.code.emit(Arm64.cmpW(a_reg, b_reg)); | |
| self.code.emit(Arm64.csetW(target, Arm64.COND_HS)); | |
| self.stack_depth += 1; | |
| }, | |
| 0xA7 => { // i32.wrap_i64 | |
| self.stack_depth -= 1; | |
| const src = getIntReg(self.stack_depth); | |
| const target = getIntReg(self.stack_depth); | |
| if (target != src) { | |
| self.code.emit(Arm64.movW(target, src)); | |
| } | |
| self.stack_depth += 1; | |
| }, | |
| 0x92 => { // f32.add | |
| self.stack_depth -= 1; | |
| const b_reg = getFloatReg(self.stack_depth); | |
| self.stack_depth -= 1; | |
| const a_reg = getFloatReg(self.stack_depth); | |
| const target = getFloatReg(self.stack_depth); | |
| self.code.emit(Arm64.faddS(target, a_reg, b_reg)); | |
| self.stack_depth += 1; | |
| }, | |
| 0x94 => { // f32.mul | |
| self.stack_depth -= 1; | |
| const b_reg = getFloatReg(self.stack_depth); | |
| self.stack_depth -= 1; | |
| const a_reg = getFloatReg(self.stack_depth); | |
| const target = getFloatReg(self.stack_depth); | |
| self.code.emit(Arm64.fmulS(target, a_reg, b_reg)); | |
| self.stack_depth += 1; | |
| }, | |
| else => { | |
| std.debug.print("Unsupported opcode: 0x{X:0>2}\n", .{opcode}); | |
| @panic("Unsupported opcode"); | |
| }, | |
| } | |
| } | |
| // Epilogue | |
| if (self.stack_depth > 0) { | |
| const result_reg = getIntReg(self.stack_depth - 1); | |
| if (result_reg != 0) { | |
| self.code.emit(Arm64.mov(0, result_reg)); | |
| } | |
| } else { | |
| self.code.emit(Arm64.movz(0, 0, 0)); | |
| } | |
| self.code.emit(Arm64.ldp(19, 20, 31, 0)); | |
| self.code.emit(Arm64.addImm(31, 31, frame_size)); | |
| self.code.emit(0xA8C17BFD); // ldp x29, x30, [sp], #16 | |
| self.code.emit(Arm64.ret()); | |
| } | |
| }; | |
| // ============================================================================ | |
| // Main | |
| // ============================================================================ | |
| pub fn main() !void { | |
| var gpa = std.heap.GeneralPurposeAllocator(.{}){}; | |
| defer _ = gpa.deinit(); | |
| const allocator = gpa.allocator(); | |
| const args = try std.process.argsAlloc(allocator); | |
| defer std.process.argsFree(allocator, args); | |
| if (args.len < 2) { | |
| std.debug.print("Usage: {s} <wasm-file> [N]\n", .{args[0]}); | |
| return; | |
| } | |
| // Read wasm file | |
| const file = try fs.cwd().openFile(args[1], .{}); | |
| defer file.close(); | |
| const data = try file.readToEndAlloc(allocator, 1024 * 1024); | |
| defer allocator.free(data); | |
| std.debug.print("Loaded {d} bytes of wasm\n", .{data.len}); | |
| // Parse module | |
| var mod = try WasmModule.parse(allocator, data); | |
| defer mod.deinit(); | |
| std.debug.print("Parsed: {d} types, {d} funcs, {d} globals\n", .{ mod.types.len, mod.funcs.len, mod.globals.len }); | |
| const run_idx = mod.run_func_idx orelse { | |
| std.debug.print("No 'run' export found\n", .{}); | |
| return; | |
| }; | |
| std.debug.print("Found export 'run' -> func[{d}]\n", .{run_idx}); | |
| // Allocate memory | |
| const mem_size = if (mod.min_memory_pages > 0) mod.min_memory_pages * 65536 else 65536; | |
| const memory = try allocator.alloc(u8, mem_size); | |
| defer allocator.free(memory); | |
| @memset(memory, 0); | |
| // Allocate globals | |
| const globals = try allocator.alloc(i32, mod.globals.len + 1); | |
| defer allocator.free(globals); | |
| for (mod.globals, 0..) |g, i| { | |
| globals[i] = @truncate(g.init_val); | |
| std.debug.print("Global[{d}] = {d}\n", .{ i, globals[i] }); | |
| } | |
| // Create JIT | |
| var code = try CodeBuffer.init(65536); | |
| defer code.deinit(); | |
| code.beginWrite(); | |
| var jit = JitCompiler.init(&code, &mod); | |
| jit.compileFunction(&mod.funcs[run_idx]); | |
| code.endWrite(); | |
| std.debug.print("Generated {d} instructions\n", .{code.size}); | |
| // Get argument | |
| const arg: i64 = if (args.len > 2) try std.fmt.parseInt(i64, args[2], 10) else 10; | |
| std.debug.print("Running with N={d}\n", .{arg}); | |
| // Call JIT function | |
| const JitFunc = *const fn (i64, [*]u8, [*]i32) callconv(.c) i64; | |
| const func = code.getFunction(JitFunc); | |
| const result = func(arg, memory.ptr, globals.ptr); | |
| std.debug.print("JIT result: {d}\n", .{result}); | |
| } |
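Before feeding the compiler a real module, the buffer/encoder plumbing can be smoke-tested in isolation by JIT-ing a function that just returns a constant. A hedged sketch, assuming an ARM64 host and that it lives in the same file as the definitions above:

```zig
test "jit returns a constant" {
    var buf = try CodeBuffer.init(4096);
    defer buf.deinit();
    buf.beginWrite();
    buf.emit(Arm64.movz(0, 42, 0)); // mov x0, #42
    buf.emit(Arm64.ret());
    buf.endWrite();
    const Fn = *const fn () callconv(.c) i64;
    const f = buf.getFunction(Fn);
    try std.testing.expectEqual(@as(i64, 42), f());
}
```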
| ;; benchmark target | |
| (module | |
| ;; 8 pages satisfy the memory requirement for roughly N <= 200. | |
| (memory 8 8) | |
| (global $size_of_f32 i32 (i32.const 4)) | |
| ;; Matrix multiplication of 2 NxN matrices. | |
| ;; The function returns 0 to keep it aligned to other benchmark functions. | |
| ;; | |
| ;; This function uses the linear memory in the following way: | |
| ;; | |
| ;; - mem[0..N*N): `lhs` matrix | |
| ;; - mem[N*N..N*N*2): `rhs` matrix | |
| ;; - mem[N*N*2..N*N*3): `result` matrix | |
| ;; | |
| ;; - There need to be enough linear memory pages to provide | |
| ;; at least N*N*3 elements to operate on to run this function | |
| ;; - Each element is a `f32` and occupies 4 bytes. | |
| ;; - There is no padding or alignment for the matrices to keep it simple. | |
| ;; | |
| ;; Implements the following pseudo-code: | |
| ;; | |
| ;; fn matmul(n: i64) -> i64 { | |
| ;; offset_lhs = 0 | |
| ;; offset_rhs = n*n | |
| ;; offset_res = offset_rhs * 2 | |
| ;; for i in 0..n { | |
| ;; for j in 0..n { | |
| ;; mem[offset_res + (i * n) + j] = 0 | |
| ;; for k in 0..n { | |
| ;; mem[offset_res + (i * n) + j] += mem[offset_lhs + (i * n) + k] * mem[offset_rhs + (k * n) + j] | |
| ;; } | |
| ;; } | |
| ;; } | |
| ;; } | |
| (func (export "run") (param $N i64) (result i64) | |
| (local $offset_lhs i32) ;; offset in bytes to `lhs` matrix | |
| (local $offset_rhs i32) ;; offset in bytes to `rhs` matrix | |
| (local $offset_res i32) ;; offset in bytes to result matrix | |
| (local $n i32) | |
| (local $i i32) | |
| (local $j i32) | |
| (local $k i32) | |
| (local $tmp i32) | |
| ;; n = N as i32 | |
| (local.set $n (i32.wrap_i64 (local.get $N))) | |
| ;; offset_lhs = 0 | |
| (local.set $offset_lhs (i32.const 0)) | |
| ;; offset_rhs = N * N | |
| (local.set $offset_rhs (i32.mul (local.get $n) (local.get $n))) | |
| ;; offset_res = offset_rhs * 2 | |
| (local.set $offset_res (i32.mul (local.get $offset_rhs) (i32.const 2))) | |
| (block $break_i | |
| ;; i = 0 | |
| (local.set $i (i32.const 0)) | |
| (loop $continue_i | |
| ;; if i >= n: break | |
| (br_if $break_i (i32.ge_u (local.get $i) (local.get $n))) | |
| (block $break_j | |
| ;; j = 0 | |
| (local.set $j (i32.const 0)) | |
| (loop $continue_j | |
| ;; if j >= n: break | |
| (br_if $break_j (i32.ge_u (local.get $j) (local.get $n))) | |
| ;; tmp = offset_res + (i * n) + j | |
| (local.set $tmp | |
| (i32.mul | |
| (i32.add | |
| (local.get $offset_res) | |
| (i32.add | |
| (i32.mul (local.get $i) (local.get $n)) | |
| (local.get $j) | |
| ) | |
| ) | |
| (global.get $size_of_f32) | |
| ) | |
| ) | |
| ;; mem[tmp] = 0 | |
| (f32.store (local.get $tmp) (f32.const 0.0)) | |
| (block $break_k | |
| ;; k = 0 | |
| (local.set $k (i32.const 0)) | |
| (loop $continue_k | |
| ;; if k >= n: break | |
| (br_if $break_k (i32.ge_u (local.get $k) (local.get $n))) | |
| ;; mem[tmp] += mem[offset_lhs + (i * n) + k] * mem[offset_rhs + (k * n) + j] | |
| (f32.store | |
| (local.get $tmp) | |
| (f32.add | |
| (f32.load (local.get $tmp)) | |
| (f32.mul | |
| (f32.load | |
| (i32.mul | |
| ;; offset_lhs + (i * n) + k | |
| (i32.add | |
| (local.get $offset_lhs) | |
| (i32.add | |
| (i32.mul (local.get $i) (local.get $n)) | |
| (local.get $k) | |
| ) | |
| ) | |
| (global.get $size_of_f32) | |
| ) | |
| ) | |
| (f32.load | |
| (i32.mul | |
| ;; offset_rhs + (k * n) + j | |
| (i32.add | |
| (local.get $offset_rhs) | |
| (i32.add | |
| (i32.mul (local.get $k) (local.get $n)) | |
| (local.get $j) | |
| ) | |
| ) | |
| (global.get $size_of_f32) | |
| ) | |
| ) | |
| ) | |
| ) | |
| ) | |
| ;; k += 1 | |
| (local.set $k (i32.add (local.get $k) (i32.const 1))) | |
| (br $continue_k) | |
| ) | |
| ) | |
| ;; j += 1 | |
| (local.set $j (i32.add (local.get $j) (i32.const 1))) | |
| (br $continue_j) | |
| ) | |
| ) | |
| ;; i += 1 | |
| (local.set $i (i32.add (local.get $i) (i32.const 1))) | |
| (br $continue_i) | |
| ) | |
| ) | |
| (i64.const 0) | |
| ) | |
| ) |
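For reference, the pseudo-code in the comments above corresponds to this plain Zig sketch over a flat f32 slice laid out like the module's linear memory (lhs, rhs, result back to back). It is only an illustration of the intended semantics, not part of the JIT:

```zig
fn matmulReference(data: []f32, n: usize) void {
    const off_lhs: usize = 0;
    const off_rhs: usize = n * n;
    const off_res: usize = 2 * n * n;
    for (0..n) |i| {
        for (0..n) |j| {
            data[off_res + i * n + j] = 0;
            for (0..n) |k| {
                data[off_res + i * n + j] +=
                    data[off_lhs + i * n + k] * data[off_rhs + k * n + j];
            }
        }
    }
}
```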
| // Minimal Baseline JIT for WebAssembly (ARM64) | |
| // Supports matmul benchmark subset of instructions | |
| #include <stdio.h> | |
| #include <stdlib.h> | |
| #include <stdint.h> | |
| #include <string.h> | |
| #include <sys/mman.h> | |
| #ifdef __APPLE__ | |
| #include <pthread.h> | |
| #include <libkern/OSCacheControl.h> | |
| #endif | |
| // ============================================================================ | |
| // Debug macros | |
| // ============================================================================ | |
| #define DEBUG 0 | |
| #if DEBUG | |
| #define DBG(...) fprintf(stderr, __VA_ARGS__) | |
| #else | |
| #define DBG(...) | |
| #endif | |
| // ============================================================================ | |
| // Code buffer for JIT compilation | |
| // ============================================================================ | |
| typedef struct { | |
| uint32_t *code; | |
| size_t capacity; | |
| size_t size; | |
| } CodeBuffer; | |
| CodeBuffer *codebuf_new(size_t capacity) { | |
| CodeBuffer *buf = malloc(sizeof(CodeBuffer)); | |
| #ifdef __APPLE__ | |
| buf->code = mmap(NULL, capacity * sizeof(uint32_t), | |
| PROT_READ | PROT_WRITE | PROT_EXEC, | |
| MAP_PRIVATE | MAP_ANONYMOUS | MAP_JIT, -1, 0); | |
| #else | |
| buf->code = mmap(NULL, capacity * sizeof(uint32_t), | |
| PROT_READ | PROT_WRITE | PROT_EXEC, | |
| MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); | |
| #endif | |
| if (buf->code == MAP_FAILED) { | |
| perror("mmap"); | |
| exit(1); | |
| } | |
| buf->capacity = capacity; | |
| buf->size = 0; | |
| return buf; | |
| } | |
| void codebuf_begin_write(CodeBuffer *buf) { | |
| #ifdef __APPLE__ | |
| pthread_jit_write_protect_np(0); | |
| #endif | |
| (void)buf; | |
| } | |
| void codebuf_end_write(CodeBuffer *buf) { | |
| #ifdef __APPLE__ | |
| pthread_jit_write_protect_np(1); | |
| sys_icache_invalidate(buf->code, buf->size * sizeof(uint32_t)); | |
| #else | |
| __builtin___clear_cache((char*)buf->code, (char*)(buf->code + buf->size)); | |
| #endif | |
| } | |
| void codebuf_emit(CodeBuffer *buf, uint32_t inst) { | |
| if (buf->size >= buf->capacity) { | |
| fprintf(stderr, "Code buffer overflow\n"); | |
| exit(1); | |
| } | |
| buf->code[buf->size++] = inst; | |
| } | |
| size_t codebuf_pos(CodeBuffer *buf) { | |
| return buf->size; | |
| } | |
| void codebuf_patch(CodeBuffer *buf, size_t pos, uint32_t inst) { | |
| buf->code[pos] = inst; | |
| } | |
| void codebuf_free(CodeBuffer *buf) { | |
| munmap(buf->code, buf->capacity * sizeof(uint32_t)); | |
| free(buf); | |
| } | |
| // ============================================================================ | |
| // ARM64 instruction encoding helpers | |
| // ============================================================================ | |
| // Registers: | |
| // x0-x7: argument/return registers | |
| // x8: indirect result | |
| // x9-x15: caller-saved temps | |
| // x16-x17: intra-procedure-call | |
| // x18: platform register | |
| // x19-x28: callee-saved | |
| // x29: frame pointer | |
| // x30: link register | |
| // sp: stack pointer | |
| // For our JIT: | |
| // x0: first argument (N for matmul) | |
| // x19: memory base | |
| // x20: globals base | |
| // x21: locals base (on native stack) | |
| // x9-x15: wasm value stack (mapped dynamically) | |
| // s0-s15: float registers for wasm stack | |
| // MOV immediate (MOVZ - zero and move) | |
| uint32_t arm64_movz(int rd, uint16_t imm, int shift) { | |
| // shift: 0=bits[15:0], 1=bits[31:16], 2=bits[47:32], 3=bits[63:48] | |
| return 0xD2800000 | (shift << 21) | ((uint32_t)imm << 5) | rd; | |
| } | |
| // MOV immediate (MOVK - keep other bits) | |
| uint32_t arm64_movk(int rd, uint16_t imm, int shift) { | |
| return 0xF2800000 | (shift << 21) | ((uint32_t)imm << 5) | rd; | |
| } | |
| // MOV register (64-bit) | |
| uint32_t arm64_mov(int rd, int rm) { | |
| // ORR rd, xzr, rm | |
| return 0xAA0003E0 | (rm << 16) | rd; | |
| } | |
| // MOV register (32-bit) | |
| uint32_t arm64_mov_w(int rd, int rm) { | |
| // ORR wd, wzr, wm | |
| return 0x2A0003E0 | (rm << 16) | rd; | |
| } | |
| // ADD (64-bit register) | |
| uint32_t arm64_add(int rd, int rn, int rm) { | |
| return 0x8B000000 | (rm << 16) | (rn << 5) | rd; | |
| } | |
| // ADD (32-bit register) | |
| uint32_t arm64_add_w(int rd, int rn, int rm) { | |
| return 0x0B000000 | (rm << 16) | (rn << 5) | rd; | |
| } | |
| // ADD immediate (64-bit) | |
| uint32_t arm64_add_imm(int rd, int rn, uint16_t imm) { | |
| return 0x91000000 | ((uint32_t)imm << 10) | (rn << 5) | rd; | |
| } | |
| // SUB (64-bit register) | |
| uint32_t arm64_sub(int rd, int rn, int rm) { | |
| return 0xCB000000 | (rm << 16) | (rn << 5) | rd; | |
| } | |
| // SUB (32-bit register) | |
| uint32_t arm64_sub_w(int rd, int rn, int rm) { | |
| return 0x4B000000 | (rm << 16) | (rn << 5) | rd; | |
| } | |
| // SUB immediate (64-bit) | |
| uint32_t arm64_sub_imm(int rd, int rn, uint16_t imm) { | |
| return 0xD1000000 | ((uint32_t)imm << 10) | (rn << 5) | rd; | |
| } | |
| // MUL (64-bit) | |
| uint32_t arm64_mul(int rd, int rn, int rm) { | |
| return 0x9B007C00 | (rm << 16) | (rn << 5) | rd; | |
| } | |
| // MUL (32-bit) | |
| uint32_t arm64_mul_w(int rd, int rn, int rm) { | |
| return 0x1B007C00 | (rm << 16) | (rn << 5) | rd; | |
| } | |
| // LDR (64-bit from [base + unsigned offset]) | |
| uint32_t arm64_ldr(int rt, int rn, int offset) { | |
| // offset must be 8-byte aligned | |
| return 0xF9400000 | ((offset / 8) << 10) | (rn << 5) | rt; | |
| } | |
| // LDR (32-bit from [base + unsigned offset]) | |
| uint32_t arm64_ldr_w(int rt, int rn, int offset) { | |
| // offset must be 4-byte aligned | |
| return 0xB9400000 | ((offset / 4) << 10) | (rn << 5) | rt; | |
| } | |
| // LDR (32-bit from [base + register]) | |
| uint32_t arm64_ldr_w_reg(int rt, int rn, int rm) { | |
| // LSL #0 (no scale) | |
| return 0xB8606800 | (rm << 16) | (rn << 5) | rt; | |
| } | |
| // STR (64-bit to [base + unsigned offset]) | |
| uint32_t arm64_str(int rt, int rn, int offset) { | |
| return 0xF9000000 | ((offset / 8) << 10) | (rn << 5) | rt; | |
| } | |
| // STR (32-bit to [base + unsigned offset]) | |
| uint32_t arm64_str_w(int rt, int rn, int offset) { | |
| return 0xB9000000 | ((offset / 4) << 10) | (rn << 5) | rt; | |
| } | |
| // STR (32-bit to [base + register]) | |
| uint32_t arm64_str_w_reg(int rt, int rn, int rm) { | |
| return 0xB8206800 | (rm << 16) | (rn << 5) | rt; | |
| } | |
| // STP (store pair, 64-bit) | |
| uint32_t arm64_stp(int rt1, int rt2, int rn, int offset) { | |
| // offset in units of 8 bytes, signed 7-bit | |
| int imm7 = (offset / 8) & 0x7F; | |
| return 0xA9000000 | (imm7 << 15) | (rt2 << 10) | (rn << 5) | rt1; | |
| } | |
| // LDP (load pair, 64-bit) | |
| uint32_t arm64_ldp(int rt1, int rt2, int rn, int offset) { | |
| int imm7 = (offset / 8) & 0x7F; | |
| return 0xA9400000 | (imm7 << 15) | (rt2 << 10) | (rn << 5) | rt1; | |
| } | |
| // LDR single float (s-register from [base + unsigned offset]) | |
| uint32_t arm64_ldr_s(int rt, int rn, int offset) { | |
| return 0xBD400000 | ((offset / 4) << 10) | (rn << 5) | rt; | |
| } | |
| // STR single float (s-register to [base + unsigned offset]) | |
| uint32_t arm64_str_s(int rt, int rn, int offset) { | |
| return 0xBD000000 | ((offset / 4) << 10) | (rn << 5) | rt; | |
| } | |
| // LDR single float (s-register from [base + register]) | |
| uint32_t arm64_ldr_s_reg(int rt, int rn, int rm) { | |
| return 0xBC606800 | (rm << 16) | (rn << 5) | rt; | |
| } | |
| // STR single float (s-register to [base + register]) | |
| uint32_t arm64_str_s_reg(int rt, int rn, int rm) { | |
| return 0xBC206800 | (rm << 16) | (rn << 5) | rt; | |
| } | |
| // FMOV (move GP to float register) | |
| uint32_t arm64_fmov_s_from_w(int sd, int wn) { | |
| return 0x1E270000 | (wn << 5) | sd; | |
| } | |
| // FMOV (move float to GP register) | |
| uint32_t arm64_fmov_w_from_s(int wd, int sn) { | |
| return 0x1E260000 | (sn << 5) | wd; | |
| } | |
| // FADD (single) | |
| uint32_t arm64_fadd_s(int rd, int rn, int rm) { | |
| return 0x1E202800 | (rm << 16) | (rn << 5) | rd; | |
| } | |
| // FMUL (single) | |
| uint32_t arm64_fmul_s(int rd, int rn, int rm) { | |
| return 0x1E200800 | (rm << 16) | (rn << 5) | rd; | |
| } | |
| // CMP (64-bit register) - actually SUBS xzr, rn, rm | |
| uint32_t arm64_cmp(int rn, int rm) { | |
| return 0xEB00001F | (rm << 16) | (rn << 5); | |
| } | |
| // CMP (32-bit register) | |
| uint32_t arm64_cmp_w(int rn, int rm) { | |
| return 0x6B00001F | (rm << 16) | (rn << 5); | |
| } | |
| // CSET (set register to 1 if condition true) | |
| // CSET rd, cond == CSINC rd, xzr, xzr, invert(cond) | |
| uint32_t arm64_cset_w(int rd, int cond) { | |
| int inv_cond = cond ^ 1; // invert condition | |
| return 0x1A9F07E0 | (inv_cond << 12) | rd; | |
| } | |
| // B (unconditional branch, PC-relative) | |
| uint32_t arm64_b(int offset) { | |
| // offset is in instructions (4 bytes each) | |
| return 0x14000000 | (offset & 0x3FFFFFF); | |
| } | |
| // B.cond (conditional branch) | |
| uint32_t arm64_bcond(int cond, int offset) { | |
| return 0x54000000 | ((offset & 0x7FFFF) << 5) | cond; | |
| } | |
| // Condition codes | |
| #define COND_EQ 0 | |
| #define COND_NE 1 | |
| #define COND_HS 2 // unsigned >= | |
| #define COND_LO 3 // unsigned < | |
| #define COND_HI 8 // unsigned > | |
| #define COND_LS 9 // unsigned <= | |
| #define COND_GE 10 // signed >= | |
| #define COND_LT 11 // signed < | |
| #define COND_GT 12 // signed > | |
| #define COND_LE 13 // signed <= | |
| // CBZ (compare and branch if zero) | |
| uint32_t arm64_cbz_w(int rt, int offset) { | |
| return 0x34000000 | ((offset & 0x7FFFF) << 5) | rt; | |
| } | |
| // CBNZ (compare and branch if non-zero) | |
| uint32_t arm64_cbnz_w(int rt, int offset) { | |
| return 0x35000000 | ((offset & 0x7FFFF) << 5) | rt; | |
| } | |
| // RET | |
| uint32_t arm64_ret(void) { | |
| return 0xD65F03C0; | |
| } | |
| // BL (branch with link - function call) | |
| uint32_t arm64_bl(int offset) { | |
| return 0x94000000 | (offset & 0x3FFFFFF); | |
| } | |
| // NOP | |
| uint32_t arm64_nop(void) { | |
| return 0xD503201F; | |
| } | |
| // ============================================================================ | |
| // Wasm parser | |
| // ============================================================================ | |
| typedef struct { | |
| uint8_t *data; | |
| size_t size; | |
| size_t pos; | |
| } WasmReader; | |
| uint8_t read_u8(WasmReader *r) { | |
| if (r->pos >= r->size) { | |
| fprintf(stderr, "Unexpected end of wasm\n"); | |
| exit(1); | |
| } | |
| return r->data[r->pos++]; | |
| } | |
| uint32_t read_u32_leb128(WasmReader *r) { | |
| uint32_t result = 0; | |
| int shift = 0; | |
| while (1) { | |
| uint8_t byte = read_u8(r); | |
| result |= (uint32_t)(byte & 0x7F) << shift; | |
| if ((byte & 0x80) == 0) break; | |
| shift += 7; | |
| } | |
| return result; | |
| } | |
| int32_t read_i32_leb128(WasmReader *r) { | |
| int32_t result = 0; | |
| int shift = 0; | |
| uint8_t byte; | |
| do { | |
| byte = read_u8(r); | |
| result |= (int32_t)(byte & 0x7F) << shift; | |
| shift += 7; | |
| } while (byte & 0x80); | |
| if (shift < 32 && (byte & 0x40)) { | |
| result |= (int32_t)(~0u << shift); // sign-extend; avoids UB of left-shifting a negative int | |
| } | |
| return result; | |
| } | |
| int64_t read_i64_leb128(WasmReader *r) { | |
| int64_t result = 0; | |
| int shift = 0; | |
| uint8_t byte; | |
| do { | |
| byte = read_u8(r); | |
| result |= ((int64_t)(byte & 0x7F)) << shift; | |
| shift += 7; | |
| } while (byte & 0x80); | |
| if (shift < 64 && (byte & 0x40)) { | |
| result |= (int64_t)(~0ull << shift); // sign-extend; avoids UB of left-shifting a negative int | |
| } | |
| return result; | |
| } | |
| float read_f32(WasmReader *r) { | |
| float f; | |
| memcpy(&f, &r->data[r->pos], 4); | |
| r->pos += 4; | |
| return f; | |
| } | |
| void skip_bytes(WasmReader *r, size_t n) { | |
| r->pos += n; | |
| } | |
| // ============================================================================ | |
| // Wasm module structures | |
| // ============================================================================ | |
| typedef struct { | |
| int param_count; | |
| int result_count; | |
| uint8_t *param_types; // array of value types | |
| uint8_t *result_types; | |
| } WasmFuncType; | |
| typedef struct { | |
| uint32_t type_idx; | |
| uint32_t local_count; | |
| uint8_t *local_types; // expanded - one per local | |
| uint8_t *code; | |
| size_t code_size; | |
| } WasmFunc; | |
| typedef struct { | |
| uint8_t type; // 0x7F = i32, 0x7E = i64, 0x7D = f32, 0x7C = f64 | |
| uint8_t mutable; | |
| int64_t init_val; // initial value (simplified - assume i32.const or i64.const) | |
| } WasmGlobal; | |
| typedef struct { | |
| uint32_t min_pages; | |
| uint32_t max_pages; | |
| int has_max; | |
| } WasmMemory; | |
| typedef struct { | |
| WasmFuncType *types; | |
| uint32_t type_count; | |
| WasmFunc *funcs; | |
| uint32_t func_count; | |
| uint32_t import_func_count; | |
| WasmGlobal *globals; | |
| uint32_t global_count; | |
| WasmMemory memory; | |
| int has_memory; | |
| // Export info | |
| char *run_export_name; | |
| uint32_t run_func_idx; | |
| } WasmModule; | |
| // Parse wasm module | |
| WasmModule *parse_wasm(uint8_t *data, size_t size) { | |
| WasmReader r = { data, size, 0 }; | |
| // Check magic and version | |
| uint32_t magic = read_u8(&r) | (read_u8(&r) << 8) | (read_u8(&r) << 16) | (read_u8(&r) << 24); | |
| uint32_t version = read_u8(&r) | (read_u8(&r) << 8) | (read_u8(&r) << 16) | (read_u8(&r) << 24); | |
| if (magic != 0x6D736100) { | |
| fprintf(stderr, "Invalid wasm magic\n"); | |
| exit(1); | |
| } | |
| if (version != 1) { | |
| fprintf(stderr, "Unsupported wasm version: %u\n", version); | |
| exit(1); | |
| } | |
| WasmModule *mod = calloc(1, sizeof(WasmModule)); | |
| // Temporary array to hold function type indices | |
| uint32_t *func_type_indices = NULL; | |
| uint32_t func_type_count = 0; | |
| // Parse sections | |
| while (r.pos < r.size) { | |
| uint8_t section_id = read_u8(&r); | |
| uint32_t section_size = read_u32_leb128(&r); | |
| size_t section_end = r.pos + section_size; | |
| DBG("Section %d, size %u\n", section_id, section_size); | |
| switch (section_id) { | |
| case 1: { // Type section | |
| mod->type_count = read_u32_leb128(&r); | |
| mod->types = calloc(mod->type_count, sizeof(WasmFuncType)); | |
| for (uint32_t i = 0; i < mod->type_count; i++) { | |
| uint8_t form = read_u8(&r); | |
| if (form != 0x60) { | |
| fprintf(stderr, "Expected function type\n"); | |
| exit(1); | |
| } | |
| mod->types[i].param_count = read_u32_leb128(&r); | |
| mod->types[i].param_types = malloc(mod->types[i].param_count); | |
| for (int j = 0; j < mod->types[i].param_count; j++) { | |
| mod->types[i].param_types[j] = read_u8(&r); | |
| } | |
| mod->types[i].result_count = read_u32_leb128(&r); | |
| mod->types[i].result_types = malloc(mod->types[i].result_count); | |
| for (int j = 0; j < mod->types[i].result_count; j++) { | |
| mod->types[i].result_types[j] = read_u8(&r); | |
| } | |
| } | |
| break; | |
| } | |
| case 2: { // Import section | |
| uint32_t import_count = read_u32_leb128(&r); | |
| for (uint32_t i = 0; i < import_count; i++) { | |
| uint32_t mod_len = read_u32_leb128(&r); | |
| skip_bytes(&r, mod_len); | |
| uint32_t name_len = read_u32_leb128(&r); | |
| skip_bytes(&r, name_len); | |
| uint8_t kind = read_u8(&r); | |
| if (kind == 0) { // function import | |
| read_u32_leb128(&r); // type index | |
| mod->import_func_count++; | |
| } else if (kind == 1) { // table | |
| read_u8(&r); // reftype | |
| uint8_t flags = read_u8(&r); | |
| read_u32_leb128(&r); // min | |
| if (flags & 1) read_u32_leb128(&r); // max | |
| } else if (kind == 2) { // memory | |
| uint8_t flags = read_u8(&r); | |
| read_u32_leb128(&r); // min | |
| if (flags & 1) read_u32_leb128(&r); // max | |
| } else if (kind == 3) { // global | |
| read_u8(&r); // type | |
| read_u8(&r); // mutable | |
| } | |
| } | |
| break; | |
| } | |
| case 3: { // Function section | |
| func_type_count = read_u32_leb128(&r); | |
| func_type_indices = malloc(func_type_count * sizeof(uint32_t)); | |
| for (uint32_t i = 0; i < func_type_count; i++) { | |
| func_type_indices[i] = read_u32_leb128(&r); | |
| } | |
| break; | |
| } | |
| case 5: { // Memory section | |
| uint32_t count = read_u32_leb128(&r); | |
| if (count > 0) { | |
| mod->has_memory = 1; | |
| uint8_t flags = read_u8(&r); | |
| mod->memory.min_pages = read_u32_leb128(&r); | |
| if (flags & 1) { | |
| mod->memory.has_max = 1; | |
| mod->memory.max_pages = read_u32_leb128(&r); | |
| } | |
| } | |
| break; | |
| } | |
| case 6: { // Global section | |
| mod->global_count = read_u32_leb128(&r); | |
| mod->globals = calloc(mod->global_count, sizeof(WasmGlobal)); | |
| for (uint32_t i = 0; i < mod->global_count; i++) { | |
| mod->globals[i].type = read_u8(&r); | |
| mod->globals[i].mutable = read_u8(&r); | |
| // Parse init expression (simplified) | |
| uint8_t opcode = read_u8(&r); | |
| if (opcode == 0x41) { // i32.const | |
| mod->globals[i].init_val = read_i32_leb128(&r); | |
| } else if (opcode == 0x42) { // i64.const | |
| mod->globals[i].init_val = read_i64_leb128(&r); | |
| } else { | |
| fprintf(stderr, "Unsupported global init: 0x%02X\n", opcode); | |
| exit(1); | |
| } | |
| read_u8(&r); // end opcode (0x0B) | |
| } | |
| break; | |
| } | |
| case 7: { // Export section | |
| uint32_t export_count = read_u32_leb128(&r); | |
| for (uint32_t i = 0; i < export_count; i++) { | |
| uint32_t name_len = read_u32_leb128(&r); | |
| char *name = malloc(name_len + 1); | |
| for (uint32_t j = 0; j < name_len; j++) { | |
| name[j] = read_u8(&r); | |
| } | |
| name[name_len] = '\0'; | |
| uint8_t kind = read_u8(&r); | |
| uint32_t idx = read_u32_leb128(&r); | |
| if (kind == 0 && strcmp(name, "run") == 0) { | |
| mod->run_export_name = name; | |
| mod->run_func_idx = idx; | |
| } else { | |
| free(name); | |
| } | |
| } | |
| break; | |
| } | |
| case 10: { // Code section | |
| mod->func_count = read_u32_leb128(&r); | |
| mod->funcs = calloc(mod->func_count, sizeof(WasmFunc)); | |
| for (uint32_t i = 0; i < mod->func_count; i++) { | |
| uint32_t func_size = read_u32_leb128(&r); | |
| size_t func_end = r.pos + func_size; | |
| mod->funcs[i].type_idx = func_type_indices[i]; | |
| // Parse locals | |
| uint32_t local_group_count = read_u32_leb128(&r); | |
| uint32_t total_locals = 0; | |
| // First pass: count total locals | |
| size_t locals_start = r.pos; | |
| for (uint32_t j = 0; j < local_group_count; j++) { | |
| uint32_t count = read_u32_leb128(&r); | |
| read_u8(&r); // type | |
| total_locals += count; | |
| } | |
| mod->funcs[i].local_count = total_locals; | |
| mod->funcs[i].local_types = malloc(total_locals); | |
| // Second pass: expand types | |
| r.pos = locals_start; | |
| uint32_t local_idx = 0; | |
| for (uint32_t j = 0; j < local_group_count; j++) { | |
| uint32_t count = read_u32_leb128(&r); | |
| uint8_t type = read_u8(&r); | |
| for (uint32_t k = 0; k < count; k++) { | |
| mod->funcs[i].local_types[local_idx++] = type; | |
| } | |
| } | |
| // Store code | |
| mod->funcs[i].code = &r.data[r.pos]; | |
| mod->funcs[i].code_size = func_end - r.pos; | |
| r.pos = func_end; | |
| } | |
| break; | |
| } | |
| default: | |
| skip_bytes(&r, section_size); | |
| break; | |
| } | |
| r.pos = section_end; | |
| } | |
| if (func_type_indices) free(func_type_indices); | |
| return mod; | |
| } | |
| void free_wasm_module(WasmModule *mod) { | |
| for (uint32_t i = 0; i < mod->type_count; i++) { | |
| free(mod->types[i].param_types); | |
| free(mod->types[i].result_types); | |
| } | |
| free(mod->types); | |
| for (uint32_t i = 0; i < mod->func_count; i++) { | |
| free(mod->funcs[i].local_types); | |
| } | |
| free(mod->funcs); | |
| free(mod->globals); | |
| if (mod->run_export_name) free(mod->run_export_name); | |
| free(mod); | |
| } | |
| // ============================================================================ | |
| // JIT Compiler | |
| // ============================================================================ | |
| #define MAX_BLOCKS 64 | |
| #define MAX_STACK 64 | |
| typedef struct { | |
| size_t start_pos; // instruction position at block start | |
| size_t *patch_sites; // positions needing patching for br | |
| int patch_count; | |
| int is_loop; | |
| } BlockInfo; | |
| typedef struct { | |
| CodeBuffer *code; | |
| WasmModule *mod; | |
| WasmFunc *func; | |
| // Runtime pointers (passed as arguments) | |
| // x19 = memory base | |
| // x20 = globals base | |
| // stack based locals (frame pointer relative) | |
| // Virtual stack tracking | |
| int stack_depth; // current wasm stack depth | |
| // Stack values are stored in x9-x15 (for i32) or s0-s7 (for f32) | |
| // We'll use a simple scheme: even positions = int regs, track types | |
| uint8_t stack_types[MAX_STACK]; // 0x7F=i32, 0x7D=f32 | |
| // Block management | |
| BlockInfo blocks[MAX_BLOCKS]; | |
| int block_depth; | |
| // Locals frame offset | |
| int locals_offset; // offset from sp for first local | |
| int param_count; | |
| int local_count; | |
| } JitCompiler; | |
| // Get integer register for stack position | |
| int jit_get_int_reg(int stack_pos) { | |
| // Use x9-x15 for stack positions 0-6 | |
| // For deeper stacks, we'd need to spill - simplified here | |
| return 9 + (stack_pos % 7); | |
| } | |
| // Get float register for stack position | |
| int jit_get_float_reg(int stack_pos) { | |
| return stack_pos % 8; // s0-s7 | |
| } | |
| // Emit code to load a 32-bit immediate | |
| void jit_emit_mov_imm32(CodeBuffer *c, int rd, int32_t val) { | |
| uint32_t uval = (uint32_t)val; | |
| codebuf_emit(c, arm64_movz(rd, uval & 0xFFFF, 0)); | |
| if (uval > 0xFFFF) { | |
| codebuf_emit(c, arm64_movk(rd, (uval >> 16) & 0xFFFF, 1)); | |
| } | |
| } | |
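| // Worked example (assuming the arm64_movz/arm64_movk helpers encode as named): | |
| //   jit_emit_mov_imm32(c, 10, 0x12345678) emits | |
| //     movz w10, #0x5678            ; low halfword, upper bits cleared | |
| //     movk w10, #0x1234, lsl #16   ; patch in the high halfword | |
| //   Values <= 0xFFFF need only the single movz. | |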
| // Emit code to load a 64-bit immediate | |
| void jit_emit_mov_imm64(CodeBuffer *c, int rd, int64_t val) { | |
| uint64_t uval = (uint64_t)val; | |
| codebuf_emit(c, arm64_movz(rd, uval & 0xFFFF, 0)); | |
| if ((uval >> 16) & 0xFFFF) { | |
| codebuf_emit(c, arm64_movk(rd, (uval >> 16) & 0xFFFF, 1)); | |
| } | |
| if ((uval >> 32) & 0xFFFF) { | |
| codebuf_emit(c, arm64_movk(rd, (uval >> 32) & 0xFFFF, 2)); | |
| } | |
| if ((uval >> 48) & 0xFFFF) { | |
| codebuf_emit(c, arm64_movk(rd, (uval >> 48) & 0xFFFF, 3)); | |
| } | |
| } | |
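| // Skipping all-zero halfwords above is safe because the leading movz already | |
| // clears the other 48 bits; e.g. 0x0000000100000000 emits just | |
| //   movz x8, #0  /  movk x8, #0x1, lsl #32 | |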
| void jit_push_i32(JitCompiler *jit, int reg) { | |
| jit->stack_types[jit->stack_depth] = 0x7F; | |
| int target = jit_get_int_reg(jit->stack_depth); | |
| if (target != reg) { | |
| codebuf_emit(jit->code, arm64_mov_w(target, reg)); | |
| } | |
| jit->stack_depth++; | |
| } | |
| void jit_push_f32(JitCompiler *jit, int sreg) { | |
| (void)sreg; // value is already in the s-register for this stack position | |
| jit->stack_types[jit->stack_depth] = 0x7D; | |
| jit->stack_depth++; | |
| } | |
| int jit_pop_i32(JitCompiler *jit) { | |
| jit->stack_depth--; | |
| return jit_get_int_reg(jit->stack_depth); | |
| } | |
| int jit_pop_f32(JitCompiler *jit) { | |
| jit->stack_depth--; | |
| return jit_get_float_reg(jit->stack_depth); | |
| } | |
| int jit_peek_i32(JitCompiler *jit, int offset) { | |
| return jit_get_int_reg(jit->stack_depth - 1 - offset); | |
| } | |
| void jit_compile_function(JitCompiler *jit, WasmFunc *func) { | |
| CodeBuffer *c = jit->code; | |
| jit->func = func; | |
| WasmFuncType *ftype = &jit->mod->types[func->type_idx]; | |
| jit->param_count = ftype->param_count; | |
| jit->local_count = func->local_count; | |
| int total_locals = jit->param_count + jit->local_count; | |
| // Frame layout: [x19,x20 (16 bytes)] [locals (total_locals * 8 bytes)] | |
| int frame_size = (16 + total_locals * 8 + 15) & ~15; // 16-byte aligned | |
| // Prologue | |
| // stp x29, x30, [sp, #-16]! (push frame pointer and link register) | |
| codebuf_emit(c, 0xA9BF7BFD); // stp x29, x30, [sp, #-16]! | |
| // mov x29, sp (encoded as ADD: register 31 means SP there, while the ORR-based MOV form reads it as XZR) | |
| codebuf_emit(c, arm64_add_imm(29, 31, 0)); // add x29, sp, #0 | |
| // Allocate locals on stack | |
| if (frame_size > 0) { | |
| codebuf_emit(c, arm64_sub_imm(31, 31, frame_size)); | |
| } | |
| // Save callee-saved registers we use (x19, x20) | |
| codebuf_emit(c, arm64_stp(19, 20, 31, 0)); | |
| // x0 = first parameter (N for matmul) | |
| // x1 = memory base (passed by caller) | |
| // x2 = globals base (passed by caller) | |
| // Setup our register conventions | |
| codebuf_emit(c, arm64_mov(19, 1)); // x19 = memory base | |
| codebuf_emit(c, arm64_mov(20, 2)); // x20 = globals base | |
| // Store first parameter as local[0] | |
| // Parameters go to stack-based locals | |
| // For matmul: param is i64, local[0] needs it | |
| // We store params at [sp + 16 + idx*8] | |
| jit->locals_offset = 16; // after saved x19, x20 | |
| if (jit->param_count > 0) { | |
| // x0 has the i64 parameter | |
| codebuf_emit(c, arm64_str(0, 31, jit->locals_offset)); | |
| } | |
| // Initialize other locals to 0 | |
| for (int i = jit->param_count; i < total_locals; i++) { | |
| codebuf_emit(c, arm64_movz(8, 0, 0)); // x8 = 0 | |
| codebuf_emit(c, arm64_str(8, 31, jit->locals_offset + i * 8)); | |
| } | |
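| // Frame sketch after the code above (offsets follow the emitted instructions): | |
| //   [x29 + 0]          saved x29, x30   (pushed by the stp at function entry) | |
| //   [sp  + 16 + i*8]   local i          (params first, then zero-initialized locals) | |
| //   [sp  + 0]          saved x19, x20 | |
| // with sp == x29 - frame_size while the body runs. | |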
| // Parse and compile function body | |
| WasmReader r = { func->code, func->code_size, 0 }; | |
| jit->stack_depth = 0; | |
| jit->block_depth = 0; | |
| // Push implicit function block | |
| jit->blocks[jit->block_depth].start_pos = codebuf_pos(c); | |
| jit->blocks[jit->block_depth].patch_sites = calloc(256, sizeof(size_t)); | |
| jit->blocks[jit->block_depth].patch_count = 0; | |
| jit->blocks[jit->block_depth].is_loop = 0; | |
| jit->block_depth++; | |
| while (r.pos < r.size) { | |
| uint8_t opcode = read_u8(&r); | |
| DBG(" opcode: 0x%02X at pos %zu, stack=%d, blocks=%d\n", | |
| opcode, r.pos - 1, jit->stack_depth, jit->block_depth); | |
| switch (opcode) { | |
| case 0x00: // unreachable | |
| codebuf_emit(c, 0xD4200000); // brk #0 - trap if this point is ever reached | |
| break; | |
| case 0x01: // nop | |
| break; | |
| case 0x02: { // block | |
| read_i32_leb128(&r); // block type (ignore) | |
| jit->blocks[jit->block_depth].start_pos = codebuf_pos(c); | |
| jit->blocks[jit->block_depth].patch_sites = calloc(256, sizeof(size_t)); | |
| jit->blocks[jit->block_depth].patch_count = 0; | |
| jit->blocks[jit->block_depth].is_loop = 0; | |
| jit->block_depth++; | |
| break; | |
| } | |
| case 0x03: { // loop | |
| read_i32_leb128(&r); // block type | |
| jit->blocks[jit->block_depth].start_pos = codebuf_pos(c); | |
| jit->blocks[jit->block_depth].patch_sites = calloc(256, sizeof(size_t)); | |
| jit->blocks[jit->block_depth].patch_count = 0; | |
| jit->blocks[jit->block_depth].is_loop = 1; | |
| jit->block_depth++; | |
| break; | |
| } | |
| case 0x0B: { // end | |
| jit->block_depth--; | |
| BlockInfo *block = &jit->blocks[jit->block_depth]; | |
| if (jit->block_depth == 0) { | |
| // Function end - return | |
| goto done_compiling; | |
| } | |
| // Patch forward branches | |
| size_t end_pos = codebuf_pos(c); | |
| for (int i = 0; i < block->patch_count; i++) { | |
| size_t patch_pos = block->patch_sites[i]; | |
| int offset = (int)(end_pos - patch_pos); | |
| uint32_t old_inst = c->code[patch_pos]; | |
| // Check instruction type by opcode pattern | |
| if ((old_inst & 0xFC000000) == 0x14000000) { | |
| // B (unconditional branch) | |
| codebuf_patch(c, patch_pos, arm64_b(offset)); | |
| } else if ((old_inst & 0xFF000000) == 0x35000000) { | |
| // CBNZ (Wn) | |
| int rt = old_inst & 0x1F; | |
| codebuf_patch(c, patch_pos, arm64_cbnz_w(rt, offset)); | |
| } else if ((old_inst & 0xFF000000) == 0x34000000) { | |
| // CBZ (Wn) | |
| int rt = old_inst & 0x1F; | |
| codebuf_patch(c, patch_pos, arm64_cbz_w(rt, offset)); | |
| } else if ((old_inst & 0xFF000000) == 0x54000000) { | |
| // B.cond | |
| int cond = old_inst & 0xF; | |
| codebuf_patch(c, patch_pos, arm64_bcond(cond, offset)); | |
| } | |
| } | |
| free(block->patch_sites); | |
| break; | |
| } | |
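| // Patching example (illustrative, assuming arm64_b takes a word offset relative | |
| // to the branch itself, as the loop case below implies): a `br 0` inside a plain | |
| // block emits a placeholder `b 0` at instruction index P; when the matching `end` | |
| // is reached at index E, the placeholder is rewritten to `b (E - P)`, i.e. a jump | |
| // to the first instruction after the block. | |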
| case 0x0C: { // br | |
| uint32_t depth = read_u32_leb128(&r); | |
| int target_block = jit->block_depth - 1 - depth; | |
| BlockInfo *block = &jit->blocks[target_block]; | |
| if (block->is_loop) { | |
| // Branch back to loop start | |
| int offset = (int)block->start_pos - (int)codebuf_pos(c); | |
| codebuf_emit(c, arm64_b(offset)); | |
| } else { | |
| // Forward branch - patch later | |
| block->patch_sites[block->patch_count++] = codebuf_pos(c); | |
| codebuf_emit(c, arm64_b(0)); // placeholder | |
| } | |
| break; | |
| } | |
| case 0x0D: { // br_if | |
| uint32_t depth = read_u32_leb128(&r); | |
| int cond_reg = jit_pop_i32(jit); | |
| int target_block = jit->block_depth - 1 - depth; | |
| BlockInfo *block = &jit->blocks[target_block]; | |
| if (block->is_loop) { | |
| // Branch back if non-zero | |
| int offset = (int)block->start_pos - (int)codebuf_pos(c) - 1; | |
| codebuf_emit(c, arm64_cbnz_w(cond_reg, offset)); | |
| } else { | |
| // Forward branch - patch later | |
| block->patch_sites[block->patch_count++] = codebuf_pos(c); | |
| codebuf_emit(c, arm64_cbnz_w(cond_reg, 0)); // placeholder | |
| } | |
| break; | |
| } | |
| case 0x20: { // local.get | |
| uint32_t idx = read_u32_leb128(&r); | |
| int target = jit_get_int_reg(jit->stack_depth); | |
| // Load from stack | |
| codebuf_emit(c, arm64_ldr(target, 31, jit->locals_offset + idx * 8)); | |
| jit->stack_types[jit->stack_depth] = 0x7F; // assume i32 for now | |
| jit->stack_depth++; | |
| break; | |
| } | |
| case 0x21: { // local.set | |
| uint32_t idx = read_u32_leb128(&r); | |
| int src = jit_pop_i32(jit); | |
| codebuf_emit(c, arm64_str(src, 31, jit->locals_offset + idx * 8)); | |
| break; | |
| } | |
| case 0x22: { // local.tee | |
| uint32_t idx = read_u32_leb128(&r); | |
| int src = jit_get_int_reg(jit->stack_depth - 1); | |
| codebuf_emit(c, arm64_str(src, 31, jit->locals_offset + idx * 8)); | |
| break; | |
| } | |
| case 0x23: { // global.get | |
| uint32_t idx = read_u32_leb128(&r); | |
| int target = jit_get_int_reg(jit->stack_depth); | |
| codebuf_emit(c, arm64_ldr_w(target, 20, idx * 4)); | |
| jit->stack_types[jit->stack_depth] = 0x7F; | |
| jit->stack_depth++; | |
| break; | |
| } | |
| case 0x24: { // global.set | |
| uint32_t idx = read_u32_leb128(&r); | |
| int src = jit_pop_i32(jit); | |
| codebuf_emit(c, arm64_str_w(src, 20, idx * 4)); | |
| break; | |
| } | |
| case 0x28: { // i32.load | |
| uint32_t align = read_u32_leb128(&r); | |
| uint32_t offset = read_u32_leb128(&r); | |
| (void)align; | |
| int addr = jit_pop_i32(jit); | |
| int target = jit_get_int_reg(jit->stack_depth); | |
| // Compute effective address: memory_base + addr + offset | |
| if (offset > 0) { | |
| jit_emit_mov_imm32(c, 8, offset); | |
| codebuf_emit(c, arm64_add_w(addr, addr, 8)); | |
| } | |
| // target = *(memory_base + addr) | |
| codebuf_emit(c, arm64_ldr_w_reg(target, 19, addr)); | |
| jit->stack_types[jit->stack_depth] = 0x7F; | |
| jit->stack_depth++; | |
| break; | |
| } | |
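| // Example (illustrative): `i32.load offset=8` with the wasm address in w10 emits | |
| //   movz w8, #8 ; add w10, w10, w8   -- fold the static offset into the address | |
| // followed by a register-indexed load of the target from x19 (linear memory base) | |
| // plus that address. Note: this baseline JIT emits no bounds check. | |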
| case 0x36: { // i32.store | |
| uint32_t align = read_u32_leb128(&r); | |
| uint32_t offset = read_u32_leb128(&r); | |
| (void)align; | |
| int value = jit_pop_i32(jit); | |
| int addr = jit_pop_i32(jit); | |
| if (offset > 0) { | |
| jit_emit_mov_imm32(c, 8, offset); | |
| codebuf_emit(c, arm64_add_w(addr, addr, 8)); | |
| } | |
| codebuf_emit(c, arm64_str_w_reg(value, 19, addr)); | |
| break; | |
| } | |
| case 0x2A: { // f32.load | |
| uint32_t align = read_u32_leb128(&r); | |
| uint32_t offset = read_u32_leb128(&r); | |
| (void)align; | |
| int addr = jit_pop_i32(jit); | |
| int target = jit_get_float_reg(jit->stack_depth); | |
| if (offset > 0) { | |
| jit_emit_mov_imm32(c, 8, offset); | |
| codebuf_emit(c, arm64_add_w(addr, addr, 8)); | |
| } | |
| codebuf_emit(c, arm64_ldr_s_reg(target, 19, addr)); | |
| jit->stack_types[jit->stack_depth] = 0x7D; | |
| jit->stack_depth++; | |
| break; | |
| } | |
| case 0x38: { // f32.store | |
| uint32_t align = read_u32_leb128(&r); | |
| uint32_t offset = read_u32_leb128(&r); | |
| (void)align; | |
| int value = jit_pop_f32(jit); | |
| int addr = jit_pop_i32(jit); | |
| if (offset > 0) { | |
| jit_emit_mov_imm32(c, 8, offset); | |
| codebuf_emit(c, arm64_add_w(addr, addr, 8)); | |
| } | |
| codebuf_emit(c, arm64_str_s_reg(value, 19, addr)); | |
| break; | |
| } | |
| case 0x41: { // i32.const | |
| int32_t val = read_i32_leb128(&r); | |
| int target = jit_get_int_reg(jit->stack_depth); | |
| jit_emit_mov_imm32(c, target, val); | |
| jit->stack_types[jit->stack_depth] = 0x7F; | |
| jit->stack_depth++; | |
| break; | |
| } | |
| case 0x42: { // i64.const | |
| int64_t val = read_i64_leb128(&r); | |
| int target = jit_get_int_reg(jit->stack_depth); | |
| jit_emit_mov_imm64(c, target, val); | |
| jit->stack_types[jit->stack_depth] = 0x7E; // i64 | |
| jit->stack_depth++; | |
| break; | |
| } | |
| case 0x43: { // f32.const | |
| float val = read_f32(&r); | |
| int target = jit_get_float_reg(jit->stack_depth); | |
| // Move float constant via integer register | |
| uint32_t bits; | |
| memcpy(&bits, &val, 4); | |
| jit_emit_mov_imm32(c, 8, bits); | |
| codebuf_emit(c, arm64_fmov_s_from_w(target, 8)); | |
| jit->stack_types[jit->stack_depth] = 0x7D; | |
| jit->stack_depth++; | |
| break; | |
| } | |
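| // Example (illustrative): `f32.const 1.5` has bit pattern 0x3FC00000, so the code | |
| // builds it in w8 with movz/movk and then `fmov s<target>, w8`; routing constants | |
| // through an integer register avoids needing a literal pool. | |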
| case 0x6A: { // i32.add | |
| int b = jit_pop_i32(jit); | |
| int a = jit_pop_i32(jit); | |
| int target = jit_get_int_reg(jit->stack_depth); | |
| codebuf_emit(c, arm64_add_w(target, a, b)); | |
| jit->stack_types[jit->stack_depth] = 0x7F; | |
| jit->stack_depth++; | |
| break; | |
| } | |
| case 0x6B: { // i32.sub | |
| int b = jit_pop_i32(jit); | |
| int a = jit_pop_i32(jit); | |
| int target = jit_get_int_reg(jit->stack_depth); | |
| codebuf_emit(c, arm64_sub_w(target, a, b)); | |
| jit->stack_types[jit->stack_depth] = 0x7F; | |
| jit->stack_depth++; | |
| break; | |
| } | |
| case 0x6C: { // i32.mul | |
| int b = jit_pop_i32(jit); | |
| int a = jit_pop_i32(jit); | |
| int target = jit_get_int_reg(jit->stack_depth); | |
| codebuf_emit(c, arm64_mul_w(target, a, b)); | |
| jit->stack_types[jit->stack_depth] = 0x7F; | |
| jit->stack_depth++; | |
| break; | |
| } | |
| case 0x4D: { // i32.le_u | |
| int b = jit_pop_i32(jit); | |
| int a = jit_pop_i32(jit); | |
| int target = jit_get_int_reg(jit->stack_depth); | |
| codebuf_emit(c, arm64_cmp_w(a, b)); | |
| codebuf_emit(c, arm64_cset_w(target, COND_LS)); // unsigned <= | |
| jit->stack_types[jit->stack_depth] = 0x7F; | |
| jit->stack_depth++; | |
| break; | |
| } | |
| case 0x4F: { // i32.ge_u | |
| int b = jit_pop_i32(jit); | |
| int a = jit_pop_i32(jit); | |
| int target = jit_get_int_reg(jit->stack_depth); | |
| codebuf_emit(c, arm64_cmp_w(a, b)); | |
| codebuf_emit(c, arm64_cset_w(target, COND_HS)); // unsigned >= | |
| jit->stack_types[jit->stack_depth] = 0x7F; | |
| jit->stack_depth++; | |
| break; | |
| } | |
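| // Both comparisons share one pattern: `cmp wA, wB` sets the flags, then | |
| // `cset wT, <cond>` materializes the wasm i32 result as 0 or 1 | |
| // (LS = unsigned <=, HS = unsigned >=). | |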
| case 0xA7: { // i32.wrap_i64 | |
| // Just truncate - value is already in 32-bit form in our regs | |
| // Pop and push with same register (effectively no-op) | |
| int src = jit_pop_i32(jit); // Actually i64 but we only use lower 32 | |
| int target = jit_get_int_reg(jit->stack_depth); | |
| if (target != src) { | |
| codebuf_emit(c, arm64_mov_w(target, src)); | |
| } | |
| jit->stack_types[jit->stack_depth] = 0x7F; | |
| jit->stack_depth++; | |
| break; | |
| } | |
| case 0x92: { // f32.add | |
| int b = jit_pop_f32(jit); | |
| int a = jit_pop_f32(jit); | |
| int target = jit_get_float_reg(jit->stack_depth); | |
| codebuf_emit(c, arm64_fadd_s(target, a, b)); | |
| jit->stack_types[jit->stack_depth] = 0x7D; | |
| jit->stack_depth++; | |
| break; | |
| } | |
| case 0x94: { // f32.mul | |
| int b = jit_pop_f32(jit); | |
| int a = jit_pop_f32(jit); | |
| int target = jit_get_float_reg(jit->stack_depth); | |
| codebuf_emit(c, arm64_fmul_s(target, a, b)); | |
| jit->stack_types[jit->stack_depth] = 0x7D; | |
| jit->stack_depth++; | |
| break; | |
| } | |
| default: | |
| fprintf(stderr, "Unsupported opcode: 0x%02X\n", opcode); | |
| exit(1); | |
| } | |
| } | |
| done_compiling: | |
| // Epilogue | |
| // Return value should be in x0 (or s0 for float) | |
| // If stack has a value, move it to x0 for return | |
| if (jit->stack_depth > 0) { | |
| int result_reg = jit_get_int_reg(jit->stack_depth - 1); | |
| if (result_reg != 0) { | |
| codebuf_emit(c, arm64_mov(0, result_reg)); | |
| } | |
| } else { | |
| // No return value - return 0 | |
| codebuf_emit(c, arm64_movz(0, 0, 0)); | |
| } | |
| // Restore callee-saved | |
| codebuf_emit(c, arm64_ldp(19, 20, 31, 0)); | |
| // Deallocate frame | |
| if (frame_size > 0) { | |
| codebuf_emit(c, arm64_add_imm(31, 31, frame_size)); | |
| } | |
| // ldp x29, x30, [sp], #16 | |
| codebuf_emit(c, 0xA8C17BFD); | |
| codebuf_emit(c, arm64_ret()); | |
| } | |
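| // The emitted code implements the native signature used by main() below: | |
| //   int64_t fn(int64_t N, void *memory, void *globals) | |
| // with N spilled to local[0], memory base in x19, globals base in x20, | |
| // and the result returned in x0/w0. | |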
| // ============================================================================ | |
| // Main | |
| // ============================================================================ | |
| int main(int argc, char **argv) { | |
| if (argc < 2) { | |
| fprintf(stderr, "Usage: %s <wasm-file> [args...]\n", argv[0]); | |
| return 1; | |
| } | |
| // Read wasm file | |
| FILE *f = fopen(argv[1], "rb"); | |
| if (!f) { | |
| perror("fopen"); | |
| return 1; | |
| } | |
| fseek(f, 0, SEEK_END); | |
| size_t size = ftell(f); | |
| fseek(f, 0, SEEK_SET); | |
| uint8_t *data = malloc(size); | |
| if (!data || fread(data, 1, size, f) != size) { | |
| fprintf(stderr, "Failed to read %s\n", argv[1]); | |
| return 1; | |
| } | |
| fclose(f); | |
| printf("Loaded %zu bytes of wasm\n", size); | |
| // Parse wasm module | |
| WasmModule *mod = parse_wasm(data, size); | |
| printf("Parsed: %u types, %u funcs, %u globals, memory=%d\n", | |
| mod->type_count, mod->func_count, mod->global_count, mod->has_memory); | |
| if (!mod->run_export_name) { | |
| fprintf(stderr, "No 'run' export found\n"); | |
| return 1; | |
| } | |
| printf("Found export 'run' -> func[%u]\n", mod->run_func_idx); | |
| // Allocate memory | |
| size_t mem_size = mod->has_memory ? mod->memory.min_pages * 65536 : 65536; | |
| uint8_t *memory = calloc(1, mem_size); | |
| // Allocate globals | |
| int32_t *globals = calloc(mod->global_count + 1, sizeof(int32_t)); | |
| for (uint32_t i = 0; i < mod->global_count; i++) { | |
| globals[i] = (int32_t)mod->globals[i].init_val; | |
| printf("Global[%u] = %d\n", i, globals[i]); | |
| } | |
| // Create JIT compiler | |
| CodeBuffer *code = codebuf_new(16384); | |
| JitCompiler jit = { | |
| .code = code, | |
| .mod = mod, | |
| }; | |
| codebuf_begin_write(code); | |
| // Compile the run function | |
| uint32_t func_idx = mod->run_func_idx - mod->import_func_count; | |
| jit_compile_function(&jit, &mod->funcs[func_idx]); | |
| codebuf_end_write(code); | |
| printf("Generated %zu instructions\n", code->size); | |
| // Get argument | |
| int64_t arg = 10; // default | |
| if (argc > 2) { | |
| arg = atoll(argv[2]); | |
| } | |
| printf("Running with N=%lld\n", arg); | |
| // Call JIT function | |
| // Function signature: i64 run(i64 N, void* memory, void* globals) | |
| typedef int64_t (*JitFunc)(int64_t, void*, void*); | |
| JitFunc fn = (JitFunc)code->code; | |
| int64_t result = fn(arg, memory, globals); | |
| printf("JIT result: %lld\n", result); | |
| // Cleanup | |
| codebuf_free(code); | |
| free(memory); | |
| free(globals); | |
| free_wasm_module(mod); | |
| free(data); | |
| return 0; | |
| } |
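A minimal build-and-run sketch (wasm_jit.c and matmul.wasm are placeholder names, not part of the gist):

clang -O2 -o wasm_jit wasm_jit.c
./wasm_jit matmul.wasm 64

The loader expects an export named "run" and calls it with a single integer argument (default 10); codebuf_begin_write/codebuf_end_write toggle the JIT write protection that Apple Silicon requires before generated code may be executed.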