Skip to content

Instantly share code, notes, and snippets.

@mizchi
Last active January 19, 2026 16:34
Show Gist options
  • Select an option

  • Save mizchi/8a997f35bf3983657bc94afdf5afac21 to your computer and use it in GitHub Desktop.

Select an option

Save mizchi/8a997f35bf3983657bc94afdf5afac21 to your computer and use it in GitHub Desktop.
WASM JIT compiler in C and Zig, by Claude Code
// Minimal Baseline JIT for WebAssembly (ARM64)
// Zig implementation
const std = @import("std");
const fs = std.fs;
const mem = std.mem;
const Allocator = std.mem.Allocator;
// ============================================================================
// Code buffer for JIT compilation
// ============================================================================
// An mmap-backed RWX buffer that JIT-compiled ARM64 instructions are
// emitted into. `size` counts 32-bit instructions; `capacity` is the
// mapping length in BYTES (hence the /4 in emit's overflow check).
const CodeBuffer = struct {
    code: []align(4096) u8, // page-aligned executable mapping
    size: usize, // number of 4-byte instructions emitted so far
    capacity: usize, // mapping length in bytes
    const Self = @This();

    // Maps `capacity` bytes of read/write/execute memory. On macOS the
    // mapping additionally needs MAP_JIT so the W^X protection can be
    // toggled per-thread (see beginWrite/endWrite).
    fn init(capacity: usize) !Self {
        const c = @cImport({
            @cInclude("sys/mman.h");
        });
        // PROT_READ | PROT_WRITE | PROT_EXEC
        const prot: u32 = c.PROT_READ | c.PROT_WRITE | c.PROT_EXEC;
        // MAP_PRIVATE | MAP_ANONYMOUS (MAP_ANON on macOS) | MAP_JIT
        var flags: i32 = c.MAP_PRIVATE | c.MAP_ANON;
        if (comptime @import("builtin").os.tag == .macos) {
            flags |= c.MAP_JIT;
        }
        const result = c.mmap(null, capacity, @intCast(prot), flags, -1, 0);
        if (result == c.MAP_FAILED) {
            return error.MmapFailed;
        }
        const mapped: [*]align(4096) u8 = @ptrCast(@alignCast(result));
        return Self{
            .code = mapped[0..capacity],
            .size = 0,
            .capacity = capacity,
        };
    }

    // Releases the mapping. Must not be called while the generated
    // function may still be executing.
    fn deinit(self: *Self) void {
        const c = @cImport(@cInclude("sys/mman.h"));
        _ = c.munmap(self.code.ptr, self.code.len);
    }

    // Makes the JIT region writable for the current thread (macOS
    // hardened-runtime W^X toggle; no-op elsewhere).
    fn beginWrite(self: *Self) void {
        _ = self;
        if (comptime @import("builtin").os.tag == .macos) {
            // pthread_jit_write_protect_np(0)
            const c = @cImport(@cInclude("pthread.h"));
            c.pthread_jit_write_protect_np(0);
        }
    }

    // Re-protects the region as executable and invalidates the
    // instruction cache for the emitted range so stale lines are not
    // fetched by the CPU.
    fn endWrite(self: *Self) void {
        if (comptime @import("builtin").os.tag == .macos) {
            const c = @cImport(@cInclude("pthread.h"));
            c.pthread_jit_write_protect_np(1);
            // sys_icache_invalidate
            const cache = @cImport(@cInclude("libkern/OSCacheControl.h"));
            cache.sys_icache_invalidate(self.code.ptr, self.size * 4);
        }
    }

    // Appends one 32-bit instruction word.
    fn emit(self: *Self, inst: u32) void {
        if (self.size >= self.capacity / 4) {
            @panic("Code buffer overflow");
        }
        const ptr: *u32 = @ptrCast(@alignCast(self.code.ptr + self.size * 4));
        ptr.* = inst;
        self.size += 1;
    }

    // Current emit position, in instructions (used as a branch label).
    fn pos(self: *const Self) usize {
        return self.size;
    }

    // Overwrites a previously emitted instruction (branch backfilling).
    fn patch(self: *Self, position: usize, inst: u32) void {
        const ptr: *u32 = @ptrCast(@alignCast(self.code.ptr + position * 4));
        ptr.* = inst;
    }

    // Reinterprets the buffer start as a callable function pointer.
    // Only valid after endWrite() has run.
    fn getFunction(self: *const Self, comptime T: type) T {
        return @ptrCast(self.code.ptr);
    }
};
// ============================================================================
// ARM64 instruction encoding
// ============================================================================
// A64 instruction encoders. Each helper returns one 32-bit instruction
// word; register numbers are 0..31 (31 means SP or the zero register,
// depending on the instruction).
const Arm64 = struct {
    // MOV immediate (MOVZ): rd = imm << (16 * shift), other bits zeroed.
    fn movz(rd: u5, imm: u16, shift: u2) u32 {
        return 0xD2800000 | (@as(u32, shift) << 21) | (@as(u32, imm) << 5) | rd;
    }
    // MOV immediate (MOVK): insert imm at lane `shift`, keep other bits.
    fn movk(rd: u5, imm: u16, shift: u2) u32 {
        return 0xF2800000 | (@as(u32, shift) << 21) | (@as(u32, imm) << 5) | rd;
    }
    // MOV register (64-bit) - ORR rd, xzr, rm
    fn mov(rd: u5, rm: u5) u32 {
        return 0xAA0003E0 | (@as(u32, rm) << 16) | rd;
    }
    // MOV register (32-bit) - ORR wd, wzr, wm
    fn movW(rd: u5, rm: u5) u32 {
        return 0x2A0003E0 | (@as(u32, rm) << 16) | rd;
    }
    // ADD (64-bit register)
    fn add(rd: u5, rn: u5, rm: u5) u32 {
        return 0x8B000000 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd;
    }
    // ADD (32-bit register)
    fn addW(rd: u5, rn: u5, rm: u5) u32 {
        return 0x0B000000 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd;
    }
    // ADD immediate (64-bit); rd/rn = 31 addresses SP here.
    fn addImm(rd: u5, rn: u5, imm: u12) u32 {
        return 0x91000000 | (@as(u32, imm) << 10) | (@as(u32, rn) << 5) | rd;
    }
    // SUB (32-bit register)
    fn subW(rd: u5, rn: u5, rm: u5) u32 {
        return 0x4B000000 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd;
    }
    // SUB immediate (64-bit); rd/rn = 31 addresses SP here.
    fn subImm(rd: u5, rn: u5, imm: u12) u32 {
        return 0xD1000000 | (@as(u32, imm) << 10) | (@as(u32, rn) << 5) | rd;
    }
    // MUL (32-bit) - MADD rd, rn, rm, wzr
    fn mulW(rd: u5, rn: u5, rm: u5) u32 {
        return 0x1B007C00 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd;
    }
    // LDR (64-bit, unsigned offset); byte offset, scaled by 8 (imm12).
    fn ldr(rt: u5, rn: u5, offset: u12) u32 {
        return 0xF9400000 | (@as(u32, offset / 8) << 10) | (@as(u32, rn) << 5) | rt;
    }
    // LDR (32-bit, unsigned offset); byte offset, scaled by 4.
    fn ldrW(rt: u5, rn: u5, offset: u12) u32 {
        return 0xB9400000 | (@as(u32, offset / 4) << 10) | (@as(u32, rn) << 5) | rt;
    }
    // LDR (32-bit, register offset, LSL #0): wt = [xn + xm]
    fn ldrWReg(rt: u5, rn: u5, rm: u5) u32 {
        return 0xB8606800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rt;
    }
    // STR (64-bit, unsigned offset); byte offset, scaled by 8.
    fn str(rt: u5, rn: u5, offset: u12) u32 {
        return 0xF9000000 | (@as(u32, offset / 8) << 10) | (@as(u32, rn) << 5) | rt;
    }
    // STR (32-bit, unsigned offset); byte offset, scaled by 4.
    fn strW(rt: u5, rn: u5, offset: u12) u32 {
        return 0xB9000000 | (@as(u32, offset / 4) << 10) | (@as(u32, rn) << 5) | rt;
    }
    // STR (32-bit, register offset, LSL #0)
    fn strWReg(rt: u5, rn: u5, rm: u5) u32 {
        return 0xB8206800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rt;
    }
    // STP (64-bit); byte offset, encoded as scaled imm7 at bit 15.
    fn stp(rt1: u5, rt2: u5, rn: u5, offset: i7) u32 {
        const imm: u7 = @bitCast(@as(i7, @divTrunc(offset, 8)));
        return 0xA9000000 | (@as(u32, imm) << 15) | (@as(u32, rt2) << 10) | (@as(u32, rn) << 5) | rt1;
    }
    // LDP (64-bit); byte offset, encoded as scaled imm7 at bit 15.
    fn ldp(rt1: u5, rt2: u5, rn: u5, offset: i7) u32 {
        const imm: u7 = @bitCast(@as(i7, @divTrunc(offset, 8)));
        return 0xA9400000 | (@as(u32, imm) << 15) | (@as(u32, rt2) << 10) | (@as(u32, rn) << 5) | rt1;
    }
    // LDR float single (unsigned offset)
    fn ldrS(rt: u5, rn: u5, offset: u12) u32 {
        return 0xBD400000 | (@as(u32, offset / 4) << 10) | (@as(u32, rn) << 5) | rt;
    }
    // STR float single (unsigned offset)
    fn strS(rt: u5, rn: u5, offset: u12) u32 {
        return 0xBD000000 | (@as(u32, offset / 4) << 10) | (@as(u32, rn) << 5) | rt;
    }
    // LDR float single (register offset)
    fn ldrSReg(rt: u5, rn: u5, rm: u5) u32 {
        return 0xBC606800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rt;
    }
    // STR float single (register offset)
    fn strSReg(rt: u5, rn: u5, rm: u5) u32 {
        return 0xBC206800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rt;
    }
    // FMOV (GP to float): sd = raw bits of wn
    fn fmovSFromW(sd: u5, wn: u5) u32 {
        return 0x1E270000 | (@as(u32, wn) << 5) | sd;
    }
    // FADD single
    fn faddS(rd: u5, rn: u5, rm: u5) u32 {
        return 0x1E202800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd;
    }
    // FMUL single
    fn fmulS(rd: u5, rn: u5, rm: u5) u32 {
        return 0x1E200800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd;
    }
    // CMP (32-bit register) - SUBS wzr, wn, wm
    fn cmpW(rn: u5, rm: u5) u32 {
        return 0x6B00001F | (@as(u32, rm) << 16) | (@as(u32, rn) << 5);
    }
    // CSET (32-bit): rd = cond ? 1 : 0. Encoded as CSINC rd, wzr, wzr,
    // invert(cond); flipping bit 0 inverts an A64 condition code.
    fn csetW(rd: u5, cond: u4) u32 {
        const inv_cond = cond ^ 1;
        return 0x1A9F07E0 | (@as(u32, inv_cond) << 12) | rd;
    }
    // B (unconditional); offset in instructions, relative to this B.
    fn b(offset: i26) u32 {
        const imm: u26 = @bitCast(offset);
        return 0x14000000 | @as(u32, imm);
    }
    // CBNZ (32-bit); offset in instructions, relative to this CBNZ.
    fn cbnzW(rt: u5, offset: i19) u32 {
        const imm: u19 = @bitCast(offset);
        return 0x35000000 | (@as(u32, imm) << 5) | rt;
    }
    // RET (via x30)
    fn ret() u32 {
        return 0xD65F03C0;
    }
    // Condition codes
    const COND_HS: u4 = 2; // unsigned >=
    const COND_LS: u4 = 9; // unsigned <=
};
// ============================================================================
// Wasm parser
// ============================================================================
// Sequential reader over a raw wasm byte buffer: single bytes, LEB128
// varints, and little-endian f32. Bounds are not validated; malformed
// or truncated input panics in safe builds.
const WasmReader = struct {
    data: []const u8,
    pos: usize,

    const Self = @This();

    fn init(data: []const u8) Self {
        return .{ .data = data, .pos = 0 };
    }

    // Reads one byte and advances the cursor.
    fn readU8(self: *Self) u8 {
        defer self.pos += 1;
        return self.data[self.pos];
    }

    // Unsigned LEB128 (up to 5 groups for a u32).
    fn readU32Leb128(self: *Self) u32 {
        var value: u32 = 0;
        var bit: u5 = 0;
        while (true) : (bit += 7) {
            const group = self.readU8();
            value |= @as(u32, group & 0x7F) << bit;
            if (group & 0x80 == 0) return value;
        }
    }

    // Signed LEB128 (i32); sign-extends from bit 0x40 of the last group.
    fn readI32Leb128(self: *Self) i32 {
        var value: i32 = 0;
        var bit: u5 = 0;
        var group: u8 = 0;
        while (true) {
            group = self.readU8();
            value |= @as(i32, @intCast(group & 0x7F)) << bit;
            bit += 7;
            if (group & 0x80 == 0) break;
        }
        const negative = (group & 0x40) != 0;
        if (bit < 32 and negative) {
            value |= @as(i32, -1) << bit;
        }
        return value;
    }

    // Signed LEB128 (i64); same scheme with a wider shift counter.
    fn readI64Leb128(self: *Self) i64 {
        var value: i64 = 0;
        var bit: u6 = 0;
        var group: u8 = 0;
        while (true) {
            group = self.readU8();
            value |= @as(i64, @intCast(group & 0x7F)) << bit;
            bit += 7;
            if (group & 0x80 == 0) break;
        }
        const negative = (group & 0x40) != 0;
        if (bit < 64 and negative) {
            value |= @as(i64, -1) << bit;
        }
        return value;
    }

    // Reads 4 bytes as a little-endian IEEE-754 single.
    fn readF32(self: *Self) f32 {
        defer self.pos += 4;
        return @bitCast(self.data[self.pos..][0..4].*);
    }

    // Advances the cursor without reading.
    fn skip(self: *Self, n: usize) void {
        self.pos += n;
    }
};
// ============================================================================
// Wasm module structures
// ============================================================================
// A function signature reduced to its arities; the parser skips the
// actual value types, which this baseline JIT does not validate.
const WasmFuncType = struct {
    param_count: u32,
    result_count: u32,
};
// One parsed function body. `code` borrows from the buffer passed to
// WasmModule.parse and is valid only while that buffer stays alive.
const WasmFunc = struct {
    type_idx: u32, // index into WasmModule.types
    local_count: u32, // declared locals (excludes parameters)
    code: []const u8, // body bytes from first opcode through `end`
};
// A global's constant initializer (only i32.const / i64.const
// initializer expressions are supported by the parser).
const WasmGlobal = struct {
    init_val: i64,
};
// A minimally parsed wasm module: just enough structure for the JIT
// (signature arities, function bodies, globals, minimum memory size,
// and the index of the exported "run" function).
//
// Ownership: `types`/`funcs`/`globals` are owned via `allocator` and
// released by deinit(); each `funcs[i].code` BORROWS from the `data`
// buffer given to parse(), so the caller must keep that buffer alive
// for the module's lifetime.
const WasmModule = struct {
    types: []WasmFuncType,
    funcs: []WasmFunc,
    globals: []WasmGlobal,
    run_func_idx: ?u32,
    min_memory_pages: u32,
    allocator: Allocator,

    const Self = @This();

    // Parses the sections this JIT cares about (type, function, memory,
    // global, export, code); all other sections are skipped by size.
    // Errors: InvalidMagic, UnsupportedVersion, UnsupportedGlobalInit,
    // plus allocator failures. Malformed section payloads may panic.
    fn parse(allocator: Allocator, data: []const u8) !Self {
        // Robustness fix: reject buffers too short to hold the 8-byte
        // header instead of panicking on an out-of-bounds read.
        if (data.len < 8) return error.InvalidMagic;
        var r = WasmReader.init(data);
        // Check magic ("\0asm") and version (1), both little-endian.
        const magic = r.readU8() | (@as(u32, r.readU8()) << 8) | (@as(u32, r.readU8()) << 16) | (@as(u32, r.readU8()) << 24);
        const version = r.readU8() | (@as(u32, r.readU8()) << 8) | (@as(u32, r.readU8()) << 16) | (@as(u32, r.readU8()) << 24);
        if (magic != 0x6D736100) return error.InvalidMagic;
        if (version != 1) return error.UnsupportedVersion;
        // Section payloads accumulate here, then get duped into
        // caller-owned slices at the end.
        var types: std.ArrayListUnmanaged(WasmFuncType) = .{};
        defer types.deinit(allocator);
        var func_type_indices: std.ArrayListUnmanaged(u32) = .{};
        defer func_type_indices.deinit(allocator);
        var funcs: std.ArrayListUnmanaged(WasmFunc) = .{};
        defer funcs.deinit(allocator);
        var globals: std.ArrayListUnmanaged(WasmGlobal) = .{};
        defer globals.deinit(allocator);
        var run_func_idx: ?u32 = null;
        var min_memory_pages: u32 = 0;
        // Each section is (id: u8, size: leb128, payload).
        while (r.pos < r.data.len) {
            const section_id = r.readU8();
            const section_size = r.readU32Leb128();
            const section_end = r.pos + section_size;
            switch (section_id) {
                1 => { // Type section: keep only the arities
                    const count = r.readU32Leb128();
                    try types.ensureTotalCapacity(allocator, count);
                    for (0..count) |_| {
                        _ = r.readU8(); // 0x60 func-type tag
                        const param_count = r.readU32Leb128();
                        r.skip(param_count); // param value types, 1 byte each
                        const result_count = r.readU32Leb128();
                        r.skip(result_count); // result value types
                        try types.append(allocator, .{
                            .param_count = param_count,
                            .result_count = result_count,
                        });
                    }
                },
                3 => { // Function section: type index per function
                    const count = r.readU32Leb128();
                    try func_type_indices.ensureTotalCapacity(allocator, count);
                    for (0..count) |_| {
                        try func_type_indices.append(allocator, r.readU32Leb128());
                    }
                },
                5 => { // Memory section: only the minimum page count is kept
                    const count = r.readU32Leb128();
                    if (count > 0) {
                        const flags = r.readU8();
                        min_memory_pages = r.readU32Leb128();
                        if (flags & 1 != 0) _ = r.readU32Leb128(); // max (ignored)
                    }
                },
                6 => { // Global section: constant i32/i64 initializers only
                    const count = r.readU32Leb128();
                    try globals.ensureTotalCapacity(allocator, count);
                    for (0..count) |_| {
                        _ = r.readU8(); // value type
                        _ = r.readU8(); // mutability flag
                        const opcode = r.readU8();
                        const init_val: i64 = switch (opcode) {
                            0x41 => r.readI32Leb128(), // i32.const
                            0x42 => r.readI64Leb128(), // i64.const
                            else => return error.UnsupportedGlobalInit,
                        };
                        _ = r.readU8(); // end opcode of the init expression
                        try globals.append(allocator, .{ .init_val = init_val });
                    }
                },
                7 => { // Export section: find the function exported as "run"
                    const count = r.readU32Leb128();
                    for (0..count) |_| {
                        const name_len = r.readU32Leb128();
                        const name = r.data[r.pos..][0..name_len];
                        r.skip(name_len);
                        const kind = r.readU8();
                        const idx = r.readU32Leb128();
                        if (kind == 0 and mem.eql(u8, name, "run")) { // kind 0 = func
                            run_func_idx = idx;
                        }
                    }
                },
                10 => { // Code section: local totals plus a borrowed body slice
                    const count = r.readU32Leb128();
                    try funcs.ensureTotalCapacity(allocator, count);
                    for (0..count) |i| {
                        const func_size = r.readU32Leb128();
                        const func_end = r.pos + func_size;
                        // Locals are (count, type) groups; only the total matters here.
                        const local_group_count = r.readU32Leb128();
                        var total_locals: u32 = 0;
                        for (0..local_group_count) |_| {
                            total_locals += r.readU32Leb128();
                            _ = r.readU8(); // value type (ignored)
                        }
                        const code_start = r.pos;
                        const code = r.data[code_start..func_end]; // borrows `data`
                        try funcs.append(allocator, .{
                            .type_idx = func_type_indices.items[i],
                            .local_count = total_locals,
                            .code = code,
                        });
                        r.pos = func_end;
                    }
                },
                else => r.skip(section_size),
            }
            // Always resynchronize to the declared section boundary.
            r.pos = section_end;
        }
        // Transfer ownership to caller-owned slices. errdefer added so
        // an OOM on a later dupe does not leak the earlier ones.
        const types_slice = try allocator.dupe(WasmFuncType, types.items);
        errdefer allocator.free(types_slice);
        const funcs_slice = try allocator.dupe(WasmFunc, funcs.items);
        errdefer allocator.free(funcs_slice);
        const globals_slice = try allocator.dupe(WasmGlobal, globals.items);
        return Self{
            .types = types_slice,
            .funcs = funcs_slice,
            .globals = globals_slice,
            .run_func_idx = run_func_idx,
            .min_memory_pages = min_memory_pages,
            .allocator = allocator,
        };
    }

    // Frees the owned slices; the borrowed funcs[i].code bytes are not
    // touched (they belong to the caller's input buffer).
    fn deinit(self: *Self) void {
        self.allocator.free(self.types);
        self.allocator.free(self.funcs);
        self.allocator.free(self.globals);
    }
};
// ============================================================================
// JIT Compiler
// ============================================================================
// Compile-time limits for the compiler's control-flow bookkeeping.
const MAX_BLOCKS = 64; // max nesting depth of block/loop during compilation
const MAX_STACK = 64; // NOTE(review): currently unreferenced in this file
// Bookkeeping for one wasm block/loop while it is being compiled.
const BlockInfo = struct {
    start_pos: usize, // instruction index where the block begins
    patch_sites: [256]usize, // forward branches waiting for the end label
    patch_count: usize,
    is_loop: bool, // loops branch backward to start_pos instead of patching
};
// Single-pass baseline compiler: wasm operand-stack slots map directly
// to ARM64 registers (x9..x15 for ints, s0..s7 for floats) and locals
// live in the native stack frame.
//
// Runtime register contract (must match the JitFunc type in main):
//   x0  = first wasm parameter / integer return value
//   x19 = linear-memory base (2nd native argument)
//   x20 = globals base (3rd native argument)
//   x8  = scratch for immediates
const JitCompiler = struct {
    code: *CodeBuffer,
    mod: *const WasmModule,
    // Current depth of the simulated wasm value stack.
    stack_depth: usize,
    // Control-flow stack; blocks[0] is the implicit function block.
    blocks: [MAX_BLOCKS]BlockInfo,
    block_depth: usize,
    // Byte offset of local slot 0 from SP (first 16 bytes hold x19/x20).
    locals_offset: u12,
    param_count: u32,
    local_count: u32,

    const Self = @This();

    fn init(code: *CodeBuffer, mod: *const WasmModule) Self {
        return Self{
            .code = code,
            .mod = mod,
            .stack_depth = 0,
            .blocks = undefined,
            .block_depth = 0,
            .locals_offset = 16,
            .param_count = 0,
            .local_count = 0,
        };
    }

    // Integer stack slot -> x9..x15. NOTE: depths that differ by 7 alias
    // the same register; expressions deeper than 7 would be miscompiled.
    fn getIntReg(stack_pos: usize) u5 {
        return @intCast(9 + (stack_pos % 7));
    }

    // Float stack slot -> s0..s7 (same aliasing caveat at depth 8).
    fn getFloatReg(stack_pos: usize) u5 {
        return @intCast(stack_pos % 8);
    }

    // Materializes a 32-bit constant via MOVZ (+ MOVK for the high half).
    fn emitMovImm32(self: *Self, rd: u5, val: i32) void {
        const uval: u32 = @bitCast(val);
        self.code.emit(Arm64.movz(rd, @truncate(uval), 0));
        if (uval > 0xFFFF) {
            self.code.emit(Arm64.movk(rd, @truncate(uval >> 16), 1));
        }
    }

    // Materializes a 64-bit constant, emitting MOVK only for nonzero lanes.
    fn emitMovImm64(self: *Self, rd: u5, val: i64) void {
        const uval: u64 = @bitCast(val);
        self.code.emit(Arm64.movz(rd, @truncate(uval), 0));
        if ((uval >> 16) & 0xFFFF != 0) {
            self.code.emit(Arm64.movk(rd, @truncate(uval >> 16), 1));
        }
        if ((uval >> 32) & 0xFFFF != 0) {
            self.code.emit(Arm64.movk(rd, @truncate(uval >> 32), 2));
        }
        if ((uval >> 48) & 0xFFFF != 0) {
            self.code.emit(Arm64.movk(rd, @truncate(uval >> 48), 3));
        }
    }

    // Compiles one function body into self.code. Only the FIRST wasm
    // parameter is spilled to its local slot (single-param entry points).
    fn compileFunction(self: *Self, func: *const WasmFunc) void {
        const ftype = &self.mod.types[func.type_idx];
        self.param_count = ftype.param_count;
        self.local_count = func.local_count;
        const total_locals = self.param_count + self.local_count;
        // 16 bytes for saved x19/x20 + 8 bytes per local, rounded to 16.
        const frame_size: u12 = @intCast(((16 + total_locals * 8) + 15) & ~@as(u32, 15));
        // Prologue
        self.code.emit(0xA9BF7BFD); // stp x29, x30, [sp, #-16]!
        self.code.emit(Arm64.addImm(29, 31, 0)); // mov x29, sp
        self.code.emit(Arm64.subImm(31, 31, frame_size));
        self.code.emit(Arm64.stp(19, 20, 31, 0)); // save callee-saved x19/x20
        // Pin the runtime bases in callee-saved registers.
        self.code.emit(Arm64.mov(19, 1)); // x19 = memory base
        self.code.emit(Arm64.mov(20, 2)); // x20 = globals base
        // Spill the first parameter into local slot 0.
        if (self.param_count > 0) {
            self.code.emit(Arm64.str(0, 31, self.locals_offset));
        }
        // Zero-initialize declared locals (wasm locals start at zero).
        for (self.param_count..total_locals) |i| {
            self.code.emit(Arm64.movz(8, 0, 0));
            self.code.emit(Arm64.str(8, 31, self.locals_offset + @as(u12, @intCast(i * 8))));
        }
        // Parse and compile the body.
        var r = WasmReader.init(func.code);
        self.stack_depth = 0;
        self.block_depth = 0;
        // Push the implicit function block.
        self.blocks[0] = .{
            .start_pos = self.code.pos(),
            .patch_sites = undefined,
            .patch_count = 0,
            .is_loop = false,
        };
        self.block_depth = 1;
        while (r.pos < r.data.len) {
            const opcode = r.readU8();
            switch (opcode) {
                // NOTE(review): `unreachable` (0x00) is compiled as a nop;
                // emitting a trap would be more faithful to the spec.
                0x00, 0x01 => {}, // unreachable, nop
                0x02 => { // block
                    _ = r.readI32Leb128(); // block type (ignored)
                    self.blocks[self.block_depth] = .{
                        .start_pos = self.code.pos(),
                        .patch_sites = undefined,
                        .patch_count = 0,
                        .is_loop = false,
                    };
                    self.block_depth += 1;
                },
                0x03 => { // loop
                    _ = r.readI32Leb128(); // block type (ignored)
                    self.blocks[self.block_depth] = .{
                        .start_pos = self.code.pos(),
                        .patch_sites = undefined,
                        .patch_count = 0,
                        .is_loop = true,
                    };
                    self.block_depth += 1;
                },
                0x0B => { // end
                    self.block_depth -= 1;
                    if (self.block_depth == 0) break; // Function end
                    const block = &self.blocks[self.block_depth];
                    const end_pos = self.code.pos();
                    // Backfill forward branches now that the label is known.
                    for (0..block.patch_count) |i| {
                        const patch_pos = block.patch_sites[i];
                        const offset: i32 = @intCast(@as(isize, @intCast(end_pos)) - @as(isize, @intCast(patch_pos)));
                        // Re-read the placeholder (little-endian word) to
                        // see which branch kind was emitted.
                        const old_inst = self.code.code[patch_pos * 4 ..][0..4];
                        const old: u32 = @bitCast(old_inst.*);
                        if ((old & 0xFC000000) == 0x14000000) { // B
                            self.code.patch(patch_pos, Arm64.b(@truncate(offset)));
                        } else if ((old & 0xFF000000) == 0x35000000) { // CBNZ
                            const rt: u5 = @truncate(old);
                            self.code.patch(patch_pos, Arm64.cbnzW(rt, @truncate(offset)));
                        }
                    }
                },
                0x0C => { // br
                    const depth = r.readU32Leb128();
                    const target_block = self.block_depth - 1 - depth;
                    const block = &self.blocks[target_block];
                    if (block.is_loop) {
                        // Backward branch to the loop header; B's offset is
                        // relative to the branch instruction itself.
                        const offset: i26 = @intCast(@as(isize, @intCast(block.start_pos)) - @as(isize, @intCast(self.code.pos())));
                        self.code.emit(Arm64.b(offset));
                    } else {
                        // Forward branch: emit a placeholder, patch at `end`.
                        block.patch_sites[block.patch_count] = self.code.pos();
                        block.patch_count += 1;
                        self.code.emit(Arm64.b(0));
                    }
                },
                0x0D => { // br_if (pops the condition)
                    const depth = r.readU32Leb128();
                    self.stack_depth -= 1;
                    const cond_reg = getIntReg(self.stack_depth);
                    const target_block = self.block_depth - 1 - depth;
                    const block = &self.blocks[target_block];
                    if (block.is_loop) {
                        // BUGFIX: CBNZ's imm19 is relative to the CBNZ
                        // instruction itself (about to be emitted at pos()),
                        // matching the `br` case above. The previous extra
                        // `- 1` made backward conditional branches land one
                        // instruction BEFORE the loop header.
                        const offset: i19 = @intCast(@as(isize, @intCast(block.start_pos)) - @as(isize, @intCast(self.code.pos())));
                        self.code.emit(Arm64.cbnzW(cond_reg, offset));
                    } else {
                        block.patch_sites[block.patch_count] = self.code.pos();
                        block.patch_count += 1;
                        self.code.emit(Arm64.cbnzW(cond_reg, 0));
                    }
                },
                0x20 => { // local.get: push locals[idx]
                    const idx = r.readU32Leb128();
                    const target = getIntReg(self.stack_depth);
                    self.code.emit(Arm64.ldr(target, 31, self.locals_offset + @as(u12, @intCast(idx * 8))));
                    self.stack_depth += 1;
                },
                0x21 => { // local.set: pop into locals[idx]
                    const idx = r.readU32Leb128();
                    self.stack_depth -= 1;
                    const src = getIntReg(self.stack_depth);
                    self.code.emit(Arm64.str(src, 31, self.locals_offset + @as(u12, @intCast(idx * 8))));
                },
                0x22 => { // local.tee: store the top of stack without popping
                    const idx = r.readU32Leb128();
                    const src = getIntReg(self.stack_depth - 1);
                    self.code.emit(Arm64.str(src, 31, self.locals_offset + @as(u12, @intCast(idx * 8))));
                },
                0x23 => { // global.get: globals are 32-bit slots at x20 + idx*4
                    const idx = r.readU32Leb128();
                    const target = getIntReg(self.stack_depth);
                    self.code.emit(Arm64.ldrW(target, 20, @intCast(idx * 4)));
                    self.stack_depth += 1;
                },
                0x28 => { // i32.load: address is a byte offset into memory (x19)
                    _ = r.readU32Leb128(); // align hint (ignored)
                    const offset = r.readU32Leb128();
                    self.stack_depth -= 1;
                    const addr = getIntReg(self.stack_depth);
                    const target = getIntReg(self.stack_depth);
                    if (offset > 0) {
                        // Fold the static offset into the (just-popped) address reg.
                        self.emitMovImm32(8, @intCast(offset));
                        self.code.emit(Arm64.addW(addr, addr, 8));
                    }
                    self.code.emit(Arm64.ldrWReg(target, 19, addr));
                    self.stack_depth += 1;
                },
                0x36 => { // i32.store (pops value, then address)
                    _ = r.readU32Leb128(); // align hint (ignored)
                    const offset = r.readU32Leb128();
                    self.stack_depth -= 1;
                    const value = getIntReg(self.stack_depth);
                    self.stack_depth -= 1;
                    const addr = getIntReg(self.stack_depth);
                    if (offset > 0) {
                        self.emitMovImm32(8, @intCast(offset));
                        self.code.emit(Arm64.addW(addr, addr, 8));
                    }
                    self.code.emit(Arm64.strWReg(value, 19, addr));
                },
                0x2A => { // f32.load (int address slot -> float result slot)
                    _ = r.readU32Leb128(); // align hint (ignored)
                    const offset = r.readU32Leb128();
                    self.stack_depth -= 1;
                    const addr = getIntReg(self.stack_depth);
                    const target = getFloatReg(self.stack_depth);
                    if (offset > 0) {
                        self.emitMovImm32(8, @intCast(offset));
                        self.code.emit(Arm64.addW(addr, addr, 8));
                    }
                    self.code.emit(Arm64.ldrSReg(target, 19, addr));
                    self.stack_depth += 1;
                },
                0x38 => { // f32.store (pops float value, then int address)
                    _ = r.readU32Leb128(); // align hint (ignored)
                    const offset = r.readU32Leb128();
                    self.stack_depth -= 1;
                    const value = getFloatReg(self.stack_depth);
                    self.stack_depth -= 1;
                    const addr = getIntReg(self.stack_depth);
                    if (offset > 0) {
                        self.emitMovImm32(8, @intCast(offset));
                        self.code.emit(Arm64.addW(addr, addr, 8));
                    }
                    self.code.emit(Arm64.strSReg(value, 19, addr));
                },
                0x41 => { // i32.const
                    const val = r.readI32Leb128();
                    const target = getIntReg(self.stack_depth);
                    self.emitMovImm32(target, val);
                    self.stack_depth += 1;
                },
                0x42 => { // i64.const
                    const val = r.readI64Leb128();
                    const target = getIntReg(self.stack_depth);
                    self.emitMovImm64(target, val);
                    self.stack_depth += 1;
                },
                0x43 => { // f32.const: build the bits in x8, move into a float reg
                    const val = r.readF32();
                    const target = getFloatReg(self.stack_depth);
                    const bits: u32 = @bitCast(val);
                    self.emitMovImm32(8, @bitCast(bits));
                    self.code.emit(Arm64.fmovSFromW(target, 8));
                    self.stack_depth += 1;
                },
                0x6A => { // i32.add
                    self.stack_depth -= 1;
                    const b_reg = getIntReg(self.stack_depth);
                    self.stack_depth -= 1;
                    const a_reg = getIntReg(self.stack_depth);
                    const target = getIntReg(self.stack_depth);
                    self.code.emit(Arm64.addW(target, a_reg, b_reg));
                    self.stack_depth += 1;
                },
                0x6B => { // i32.sub
                    self.stack_depth -= 1;
                    const b_reg = getIntReg(self.stack_depth);
                    self.stack_depth -= 1;
                    const a_reg = getIntReg(self.stack_depth);
                    const target = getIntReg(self.stack_depth);
                    self.code.emit(Arm64.subW(target, a_reg, b_reg));
                    self.stack_depth += 1;
                },
                0x6C => { // i32.mul
                    self.stack_depth -= 1;
                    const b_reg = getIntReg(self.stack_depth);
                    self.stack_depth -= 1;
                    const a_reg = getIntReg(self.stack_depth);
                    const target = getIntReg(self.stack_depth);
                    self.code.emit(Arm64.mulW(target, a_reg, b_reg));
                    self.stack_depth += 1;
                },
                0x4D => { // i32.le_u (CMP + CSET LS)
                    self.stack_depth -= 1;
                    const b_reg = getIntReg(self.stack_depth);
                    self.stack_depth -= 1;
                    const a_reg = getIntReg(self.stack_depth);
                    const target = getIntReg(self.stack_depth);
                    self.code.emit(Arm64.cmpW(a_reg, b_reg));
                    self.code.emit(Arm64.csetW(target, Arm64.COND_LS));
                    self.stack_depth += 1;
                },
                0x4F => { // i32.ge_u (CMP + CSET HS)
                    self.stack_depth -= 1;
                    const b_reg = getIntReg(self.stack_depth);
                    self.stack_depth -= 1;
                    const a_reg = getIntReg(self.stack_depth);
                    const target = getIntReg(self.stack_depth);
                    self.code.emit(Arm64.cmpW(a_reg, b_reg));
                    self.code.emit(Arm64.csetW(target, Arm64.COND_HS));
                    self.stack_depth += 1;
                },
                0xA7 => { // i32.wrap_i64: same slot, so the MOV is usually elided
                    self.stack_depth -= 1;
                    const src = getIntReg(self.stack_depth);
                    const target = getIntReg(self.stack_depth);
                    if (target != src) {
                        self.code.emit(Arm64.movW(target, src));
                    }
                    self.stack_depth += 1;
                },
                0x92 => { // f32.add
                    self.stack_depth -= 1;
                    const b_reg = getFloatReg(self.stack_depth);
                    self.stack_depth -= 1;
                    const a_reg = getFloatReg(self.stack_depth);
                    const target = getFloatReg(self.stack_depth);
                    self.code.emit(Arm64.faddS(target, a_reg, b_reg));
                    self.stack_depth += 1;
                },
                0x94 => { // f32.mul
                    self.stack_depth -= 1;
                    const b_reg = getFloatReg(self.stack_depth);
                    self.stack_depth -= 1;
                    const a_reg = getFloatReg(self.stack_depth);
                    const target = getFloatReg(self.stack_depth);
                    self.code.emit(Arm64.fmulS(target, a_reg, b_reg));
                    self.stack_depth += 1;
                },
                else => {
                    std.debug.print("Unsupported opcode: 0x{X:0>2}\n", .{opcode});
                    @panic("Unsupported opcode");
                },
            }
        }
        // Epilogue: move the (integer) result into x0. f32 results are
        // not supported by this return path; the benchmark returns i64.
        if (self.stack_depth > 0) {
            const result_reg = getIntReg(self.stack_depth - 1);
            if (result_reg != 0) {
                self.code.emit(Arm64.mov(0, result_reg));
            }
        } else {
            self.code.emit(Arm64.movz(0, 0, 0));
        }
        self.code.emit(Arm64.ldp(19, 20, 31, 0));
        self.code.emit(Arm64.addImm(31, 31, frame_size));
        self.code.emit(0xA8C17BFD); // ldp x29, x30, [sp], #16
        self.code.emit(Arm64.ret());
    }
};
// ============================================================================
// Main
// ============================================================================
// CLI entry point: load a .wasm file, JIT-compile its exported "run"
// function, and execute it with an optional integer argument.
//
// Usage: prog <wasm-file> [N]   (N defaults to 10)
pub fn main() !void {
    // Process-wide allocator, leak-checked on exit.
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const alloc = gpa.allocator();

    const argv = try std.process.argsAlloc(alloc);
    defer std.process.argsFree(alloc, argv);
    if (argv.len < 2) {
        std.debug.print("Usage: {s} <wasm-file> [N]\n", .{argv[0]});
        return;
    }

    // Slurp the module bytes (1 MiB cap).
    const wasm_file = try fs.cwd().openFile(argv[1], .{});
    defer wasm_file.close();
    const wasm_bytes = try wasm_file.readToEndAlloc(alloc, 1024 * 1024);
    defer alloc.free(wasm_bytes);
    std.debug.print("Loaded {d} bytes of wasm\n", .{wasm_bytes.len});

    // Parse; the module borrows wasm_bytes, which outlives it here.
    var module = try WasmModule.parse(alloc, wasm_bytes);
    defer module.deinit();
    std.debug.print("Parsed: {d} types, {d} funcs, {d} globals\n", .{ module.types.len, module.funcs.len, module.globals.len });

    const entry_idx = module.run_func_idx orelse {
        std.debug.print("No 'run' export found\n", .{});
        return;
    };
    std.debug.print("Found export 'run' -> func[{d}]\n", .{entry_idx});

    // Linear memory: at least one 64 KiB wasm page, zero-filled.
    const linear_len = if (module.min_memory_pages > 0) module.min_memory_pages * 65536 else 65536;
    const linear_mem = try alloc.alloc(u8, linear_len);
    defer alloc.free(linear_mem);
    @memset(linear_mem, 0);

    // Globals as a flat i32 array (the JIT reads 32-bit slots);
    // +1 keeps the allocation non-empty when there are no globals.
    const global_vals = try alloc.alloc(i32, module.globals.len + 1);
    defer alloc.free(global_vals);
    for (module.globals, 0..) |g, i| {
        global_vals[i] = @truncate(g.init_val);
        std.debug.print("Global[{d}] = {d}\n", .{ i, global_vals[i] });
    }

    // Emit machine code for the entry function.
    var codebuf = try CodeBuffer.init(65536);
    defer codebuf.deinit();
    codebuf.beginWrite();
    var compiler = JitCompiler.init(&codebuf, &module);
    compiler.compileFunction(&module.funcs[entry_idx]);
    codebuf.endWrite();
    std.debug.print("Generated {d} instructions\n", .{codebuf.size});

    const n_arg: i64 = if (argv.len > 2) try std.fmt.parseInt(i64, argv[2], 10) else 10;
    std.debug.print("Running with N={d}\n", .{n_arg});

    // ABI: (N, memory base, globals base) -> result, matching the
    // register setup in JitCompiler.compileFunction.
    const JitFunc = *const fn (i64, [*]u8, [*]i32) callconv(.c) i64;
    const entry = codebuf.getFunction(JitFunc);
    const ret = entry(n_arg, linear_mem.ptr, global_vals.ptr);
    std.debug.print("JIT result: {d}\n", .{ret});
}
;; benchmark target
;; (WebAssembly text format. The JITs above consume the assembled
;; binary form of this module and call its exported "run" function
;; with (N, memory base, globals base) in registers.)
(module
;; 8 pages satisfy the memory requirement for roughly N <= 200.
(memory 8 8)
(global $size_of_f32 i32 (i32.const 4))
;; Matrix multiplication of 2 NxN matrices.
;; The function returns 0 to keep it aligned to other benchmark functions.
;;
;; This function uses the linear memory in the following way:
;;
;; - mem[0..N*N): `lhs` matrix
;; - mem[N*N..N*N*2): `rhs` matrix
;; - mem[N*N*2..N*N*3): `result` matrix
;;
;; - There need to be enough linear memory pages to provide
;; at least N*N*3 elements to operate on to run this function
;; - Each element is a `f32` and occupies 4 bytes.
;; - There is no padding or alignment for the matrices to keep it simple.
;;
;; Implements the following pseudo-code:
;;
;; fn matmul(n: i64) -> i64 {
;; offset_lhs = 0
;; offset_rhs = n*n
;; offset_res = rhs*2;
;; for i in 0..n {
;; for j in 0..n {
;; mem[offset_res + (i * n) + j] = 0
;; for k in 0..n {
;; mem[offset_res + (i * n) + j] += mem[offset_lhs + (i * n) + k] * mem[offset_rhs + (k * n) + j]
;; }
;; }
;; }
;; }
(func (export "run") (param $N i64) (result i64)
(local $offset_lhs i32) ;; offset in bytes to `lhs` matrix
(local $offset_rhs i32) ;; offset in bytes to `rhs` matrix
(local $offset_res i32) ;; offset in bytes to result matrix
(local $n i32)
(local $i i32)
(local $j i32)
(local $k i32)
(local $tmp i32)
;; n = N as i32
(local.set $n (i32.wrap_i64 (local.get $N)))
;; offset_lhs = 0
(local.set $offset_lhs (i32.const 0))
;; offset_rhs = N * N
;; (element offsets; converted to byte offsets via $size_of_f32 below)
(local.set $offset_rhs (i32.mul (local.get $n) (local.get $n)))
;; offset_res = offset_rhs * 2
(local.set $offset_res (i32.mul (local.get $offset_rhs) (i32.const 2)))
(block $break_i
;; i = 0
(local.set $i (i32.const 0))
(loop $continue_i
;; if i >= n: break
(br_if $break_i (i32.ge_u (local.get $i) (local.get $n)))
(block $break_j
;; j = 0
(local.set $j (i32.const 0))
(loop $continue_j
;; if j >= n: break
(br_if $break_j (i32.ge_u (local.get $j) (local.get $n)))
;; tmp = offset_res + (i * n) + j
;; (scaled to a byte address by $size_of_f32)
(local.set $tmp
(i32.mul
(i32.add
(local.get $offset_res)
(i32.add
(i32.mul (local.get $i) (local.get $n))
(local.get $j)
)
)
(global.get $size_of_f32)
)
)
;; mem[tmp] = 0
(f32.store (local.get $tmp) (f32.const 0.0))
(block $break_k
;; k = 0
(local.set $k (i32.const 0))
(loop $continue_k
;; if k >= n: break
(br_if $break_k (i32.ge_u (local.get $k) (local.get $n)))
;; mem[tmp] += mem[offset_lhs + (i * n) + k] * mem[offset_rhs + (k * n) + j]
(f32.store
(local.get $tmp)
(f32.add
(f32.load (local.get $tmp))
(f32.mul
(f32.load
(i32.mul
;; offset_lhs + (i * n) + k
(i32.add
(local.get $offset_lhs)
(i32.add
(i32.mul (local.get $i) (local.get $n))
(local.get $k)
)
)
(global.get $size_of_f32)
)
)
(f32.load
(i32.mul
;; offset_rhs + (k * n) + j
(i32.add
(local.get $offset_rhs)
(i32.add
(i32.mul (local.get $k) (local.get $n))
(local.get $j)
)
)
(global.get $size_of_f32)
)
)
)
)
)
;; k += 1
(local.set $k (i32.add (local.get $k) (i32.const 1)))
(br $continue_k)
)
)
;; j += 1
(local.set $j (i32.add (local.get $j) (i32.const 1)))
(br $continue_j)
)
)
;; i += 1
(local.set $i (i32.add (local.get $i) (i32.const 1)))
(br $continue_i)
)
)
(i64.const 0)
)
)
// Minimal Baseline JIT for WebAssembly (ARM64)
// Supports matmul benchmark subset of instructions
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <sys/mman.h>
#ifdef __APPLE__
#include <pthread.h>
#include <libkern/OSCacheControl.h>
#endif
// ============================================================================
// Debug macros
// ============================================================================
// Set DEBUG to 1 to trace compilation to stderr via DBG(...).
#define DEBUG 0
#if DEBUG
#define DBG(...) fprintf(stderr, __VA_ARGS__)
#else
// Expands to nothing when DEBUG is 0, so tracing has zero cost.
#define DBG(...)
#endif
// ============================================================================
// Code buffer for JIT compilation
// ============================================================================
// Executable code buffer. `code` points at an RWX mmap region;
// `size` counts emitted 32-bit instructions and `capacity` is the
// maximum number of INSTRUCTIONS (not bytes).
typedef struct {
uint32_t *code;
size_t capacity;
size_t size;
} CodeBuffer;
// Allocates a CodeBuffer backed by `capacity` instructions of RWX
// memory. Exits the process on failure (the JIT cannot proceed).
// Fix: the malloc() result was previously used unchecked.
CodeBuffer *codebuf_new(size_t capacity) {
    CodeBuffer *buf = malloc(sizeof(CodeBuffer));
    if (buf == NULL) {
        perror("malloc");
        exit(1);
    }
#ifdef __APPLE__
    // macOS hardened runtime: RWX pages must be tagged MAP_JIT and
    // write-protection toggled around writes (see codebuf_begin_write).
    buf->code = mmap(NULL, capacity * sizeof(uint32_t),
                     PROT_READ | PROT_WRITE | PROT_EXEC,
                     MAP_PRIVATE | MAP_ANONYMOUS | MAP_JIT, -1, 0);
#else
    buf->code = mmap(NULL, capacity * sizeof(uint32_t),
                     PROT_READ | PROT_WRITE | PROT_EXEC,
                     MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
#endif
    if (buf->code == MAP_FAILED) {
        perror("mmap");
        free(buf);
        exit(1);
    }
    buf->capacity = capacity;
    buf->size = 0;
    return buf;
}
// Flips the JIT region writable for this thread (macOS W^X toggle;
// a no-op on other platforms).
void codebuf_begin_write(CodeBuffer *buf) {
    (void)buf;
#ifdef __APPLE__
    pthread_jit_write_protect_np(0);
#endif
}
// Re-protects the JIT region as execute-only (macOS) and flushes the
// instruction cache for the emitted range so the CPU fetches the
// freshly written instructions.
void codebuf_end_write(CodeBuffer *buf) {
#ifdef __APPLE__
pthread_jit_write_protect_np(1);
sys_icache_invalidate(buf->code, buf->size * sizeof(uint32_t));
#else
// Generic fallback: GCC/Clang builtin icache flush over [code, code+size).
__builtin___clear_cache((char*)buf->code, (char*)(buf->code + buf->size));
#endif
}
// Appends one 32-bit instruction; aborts the process on overflow.
void codebuf_emit(CodeBuffer *buf, uint32_t inst) {
    if (buf->size >= buf->capacity) {
        fprintf(stderr, "Code buffer overflow\n");
        exit(1);
    }
    buf->code[buf->size] = inst;
    buf->size += 1;
}
// Current emit position, in instructions (used as a branch label).
size_t codebuf_pos(CodeBuffer *buf) { return buf->size; }
// Overwrites a previously emitted instruction (branch backfilling).
void codebuf_patch(CodeBuffer *buf, size_t pos, uint32_t inst) {
    *(buf->code + pos) = inst;
}
// Unmaps the code region and releases the handle. Must not be called
// while the generated code may still execute.
void codebuf_free(CodeBuffer *buf) {
    size_t mapped_bytes = buf->capacity * sizeof(uint32_t);
    munmap(buf->code, mapped_bytes);
    free(buf);
}
// ============================================================================
// ARM64 instruction encoding helpers
// ============================================================================
// Registers:
// x0-x7: argument/return registers
// x8: indirect result
// x9-x15: caller-saved temps
// x16-x17: intra-procedure-call
// x18: platform register
// x19-x28: callee-saved
// x29: frame pointer
// x30: link register
// sp: stack pointer
// For our JIT:
// x0: first argument (N for matmul)
// x19: memory base
// x20: globals base
// x21: locals base (on native stack)
// x9-x15: wasm value stack (mapped dynamically)
// s0-s15: float registers for wasm stack
// MOVZ: load a 16-bit immediate into lane `shift` (0..3 selecting
// bits [15:0] .. [63:48]) of rd, zeroing the rest of the register.
uint32_t arm64_movz(int rd, uint16_t imm, int shift) {
    uint32_t inst = 0xD2800000;
    inst |= (uint32_t)shift << 21;
    inst |= (uint32_t)imm << 5;
    return inst | (uint32_t)rd;
}
// MOVK: insert a 16-bit immediate at lane `shift`, keeping other bits.
uint32_t arm64_movk(int rd, uint16_t imm, int shift) {
    uint32_t inst = 0xF2800000;
    inst |= (uint32_t)shift << 21;
    inst |= (uint32_t)imm << 5;
    return inst | (uint32_t)rd;
}
// MOV (64-bit register): alias of ORR rd, xzr, rm.
uint32_t arm64_mov(int rd, int rm) {
    const uint32_t orr_xzr = 0xAA0003E0;
    return orr_xzr | ((uint32_t)rm << 16) | (uint32_t)rd;
}
// MOV register (32-bit)
uint32_t arm64_mov_w(int rd, int rm) {
// ORR wd, wzr, wm
return 0x2A0003E0 | (rm << 16) | rd;
}
// ---- Integer arithmetic encoders ----

// Compose a three-register data-processing instruction: base | Rm | Rn | Rd.
static uint32_t enc_rrr(uint32_t base, int rd, int rn, int rm) {
    return base | ((uint32_t)rm << 16) | ((uint32_t)rn << 5) | (uint32_t)rd;
}

// Compose an unsigned-immediate instruction: base | imm12 | Rn | Rd.
static uint32_t enc_ri(uint32_t base, int rd, int rn, uint16_t imm) {
    return base | ((uint32_t)imm << 10) | ((uint32_t)rn << 5) | (uint32_t)rd;
}

// ADD Xd, Xn, Xm
uint32_t arm64_add(int rd, int rn, int rm) { return enc_rrr(0x8B000000u, rd, rn, rm); }
// ADD Wd, Wn, Wm
uint32_t arm64_add_w(int rd, int rn, int rm) { return enc_rrr(0x0B000000u, rd, rn, rm); }
// ADD Xd, Xn, #imm12
uint32_t arm64_add_imm(int rd, int rn, uint16_t imm) { return enc_ri(0x91000000u, rd, rn, imm); }
// SUB Xd, Xn, Xm
uint32_t arm64_sub(int rd, int rn, int rm) { return enc_rrr(0xCB000000u, rd, rn, rm); }
// SUB Wd, Wn, Wm
uint32_t arm64_sub_w(int rd, int rn, int rm) { return enc_rrr(0x4B000000u, rd, rn, rm); }
// SUB Xd, Xn, #imm12
uint32_t arm64_sub_imm(int rd, int rn, uint16_t imm) { return enc_ri(0xD1000000u, rd, rn, imm); }
// MUL Xd, Xn, Xm — MADD with XZR as the addend
uint32_t arm64_mul(int rd, int rn, int rm) { return enc_rrr(0x9B007C00u, rd, rn, rm); }
// MUL Wd, Wn, Wm — MADD with WZR as the addend
uint32_t arm64_mul_w(int rd, int rn, int rm) { return enc_rrr(0x1B007C00u, rd, rn, rm); }
// ---- Load/store encoders ----

// Unsigned scaled-offset form: base | (offset/scale) << 10 | Rn | Rt.
// The caller must pass an offset aligned to `scale` bytes.
static uint32_t enc_mem_imm(uint32_t base, int rt, int rn, int offset, int scale) {
    return base | ((uint32_t)(offset / scale) << 10) | ((uint32_t)rn << 5) | (uint32_t)rt;
}

// Register-offset form (LSL #0, no scaling): base | Rm | Rn | Rt.
static uint32_t enc_mem_reg(uint32_t base, int rt, int rn, int rm) {
    return base | ((uint32_t)rm << 16) | ((uint32_t)rn << 5) | (uint32_t)rt;
}

// Pair form: base | imm7 | Rt2 | Rn | Rt1. Offset is in bytes, scaled by 8,
// stored as a signed 7-bit field.
static uint32_t enc_mem_pair(uint32_t base, int rt1, int rt2, int rn, int offset) {
    uint32_t imm7 = (uint32_t)((offset / 8) & 0x7F);
    return base | (imm7 << 15) | ((uint32_t)rt2 << 10) | ((uint32_t)rn << 5) | (uint32_t)rt1;
}

// LDR Xt, [Xn, #off] — 8-byte aligned offset
uint32_t arm64_ldr(int rt, int rn, int offset) { return enc_mem_imm(0xF9400000u, rt, rn, offset, 8); }
// LDR Wt, [Xn, #off] — 4-byte aligned offset
uint32_t arm64_ldr_w(int rt, int rn, int offset) { return enc_mem_imm(0xB9400000u, rt, rn, offset, 4); }
// LDR Wt, [Xn, Xm]
uint32_t arm64_ldr_w_reg(int rt, int rn, int rm) { return enc_mem_reg(0xB8606800u, rt, rn, rm); }
// STR Xt, [Xn, #off]
uint32_t arm64_str(int rt, int rn, int offset) { return enc_mem_imm(0xF9000000u, rt, rn, offset, 8); }
// STR Wt, [Xn, #off]
uint32_t arm64_str_w(int rt, int rn, int offset) { return enc_mem_imm(0xB9000000u, rt, rn, offset, 4); }
// STR Wt, [Xn, Xm]
uint32_t arm64_str_w_reg(int rt, int rn, int rm) { return enc_mem_reg(0xB8206800u, rt, rn, rm); }
// STP Xt1, Xt2, [Xn, #off]
uint32_t arm64_stp(int rt1, int rt2, int rn, int offset) { return enc_mem_pair(0xA9000000u, rt1, rt2, rn, offset); }
// LDP Xt1, Xt2, [Xn, #off]
uint32_t arm64_ldp(int rt1, int rt2, int rn, int offset) { return enc_mem_pair(0xA9400000u, rt1, rt2, rn, offset); }
// LDR St, [Xn, #off] — single-precision float
uint32_t arm64_ldr_s(int rt, int rn, int offset) { return enc_mem_imm(0xBD400000u, rt, rn, offset, 4); }
// STR St, [Xn, #off]
uint32_t arm64_str_s(int rt, int rn, int offset) { return enc_mem_imm(0xBD000000u, rt, rn, offset, 4); }
// LDR St, [Xn, Xm]
uint32_t arm64_ldr_s_reg(int rt, int rn, int rm) { return enc_mem_reg(0xBC606800u, rt, rn, rm); }
// STR St, [Xn, Xm]
uint32_t arm64_str_s_reg(int rt, int rn, int rm) { return enc_mem_reg(0xBC206800u, rt, rn, rm); }
// ---- FP moves/arithmetic, compares, and branch encoders ----

// FMOV Sd, Wn — copy GP register bits into a float register.
uint32_t arm64_fmov_s_from_w(int sd, int wn) {
    uint32_t inst = 0x1E270000;
    return inst | ((uint32_t)wn << 5) | (uint32_t)sd;
}
// FMOV Wd, Sn — copy float register bits into a GP register.
uint32_t arm64_fmov_w_from_s(int wd, int sn) {
    uint32_t inst = 0x1E260000;
    return inst | ((uint32_t)sn << 5) | (uint32_t)wd;
}
// FADD Sd, Sn, Sm (single precision).
uint32_t arm64_fadd_s(int rd, int rn, int rm) {
    return 0x1E202800u | ((uint32_t)rm << 16) | ((uint32_t)rn << 5) | (uint32_t)rd;
}
// FMUL Sd, Sn, Sm (single precision).
uint32_t arm64_fmul_s(int rd, int rn, int rm) {
    return 0x1E200800u | ((uint32_t)rm << 16) | ((uint32_t)rn << 5) | (uint32_t)rd;
}
// CMP Xn, Xm — alias of SUBS XZR, Xn, Xm (sets flags, discards result).
uint32_t arm64_cmp(int rn, int rm) {
    return 0xEB00001Fu | ((uint32_t)rm << 16) | ((uint32_t)rn << 5);
}
// CMP Wn, Wm — alias of SUBS WZR, Wn, Wm.
uint32_t arm64_cmp_w(int rn, int rm) {
    return 0x6B00001Fu | ((uint32_t)rm << 16) | ((uint32_t)rn << 5);
}
// CSET Wd, cond — alias of CSINC Wd, WZR, WZR, invert(cond).
uint32_t arm64_cset_w(int rd, int cond) {
    int flipped = cond ^ 1; // inverting an ARM condition flips its low bit
    return 0x1A9F07E0u | ((uint32_t)flipped << 12) | (uint32_t)rd;
}
// B label — PC-relative, offset counted in 4-byte instructions (imm26).
uint32_t arm64_b(int offset) {
    return 0x14000000u | ((uint32_t)offset & 0x3FFFFFF);
}
// B.cond label — PC-relative, offset in instructions (imm19).
uint32_t arm64_bcond(int cond, int offset) {
    return 0x54000000u | (((uint32_t)offset & 0x7FFFF) << 5) | (uint32_t)cond;
}
// ARM64 condition codes (the 4-bit `cond` field used by B.cond and CSINC).
#define COND_EQ 0  // equal (Z set)
#define COND_NE 1  // not equal (Z clear)
#define COND_HS 2 // unsigned >=
#define COND_LO 3 // unsigned <
#define COND_HI 8 // unsigned >
#define COND_LS 9 // unsigned <=
#define COND_GE 10 // signed >=
#define COND_LT 11 // signed <
#define COND_GT 12 // signed >
#define COND_LE 13 // signed <=
// CBZ Wt, label — branch if the register is zero.
// Offset is PC-relative, counted in instructions (imm19).
uint32_t arm64_cbz_w(int rt, int offset) {
    uint32_t imm19 = (uint32_t)offset & 0x7FFFF;
    return 0x34000000u | (imm19 << 5) | (uint32_t)rt;
}
// CBNZ Wt, label — branch if the register is non-zero.
uint32_t arm64_cbnz_w(int rt, int offset) {
    uint32_t imm19 = (uint32_t)offset & 0x7FFFF;
    return 0x35000000u | (imm19 << 5) | (uint32_t)rt;
}
// RET — return through x30.
uint32_t arm64_ret(void) {
    return 0xD65F03C0u;
}
// BL label — call, linking the return address into x30 (imm26, in instructions).
uint32_t arm64_bl(int offset) {
    return 0x94000000u | ((uint32_t)offset & 0x3FFFFFF);
}
// NOP.
uint32_t arm64_nop(void) {
    return 0xD503201Fu;
}
// ============================================================================
// Wasm parser
// ============================================================================
// Byte-stream cursor over the raw .wasm image.
typedef struct {
    uint8_t *data;  // start of the wasm binary (not owned)
    size_t size;    // total byte length
    size_t pos;     // current read offset
} WasmReader;
// Read one byte from the stream; aborts on a truncated module.
uint8_t read_u8(WasmReader *r) {
    if (r->pos < r->size)
        return r->data[r->pos++];
    fprintf(stderr, "Unexpected end of wasm\n");
    exit(1);
}
// Read an unsigned 32-bit LEB128 value.
//
// Fix over the original: the shift amount was unbounded, so a malformed
// encoding with six or more continuation bytes produced an undefined
// left shift (shift >= 32). Payload bits beyond 32 are now ignored;
// termination is still guaranteed because read_u8 aborts at end of input.
uint32_t read_u32_leb128(WasmReader *r) {
    uint32_t result = 0;
    int shift = 0;
    while (1) {
        uint8_t byte = read_u8(r);
        if (shift < 32) {
            result |= (uint32_t)(byte & 0x7F) << shift;
        }
        if ((byte & 0x80) == 0) break;
        shift += 7;
    }
    return result;
}
// Read a signed 32-bit LEB128 value.
//
// Fix over the original: it accumulated in a signed int32 and sign-extended
// with `~0 << shift`, both of which left-shift signed values — undefined
// behavior in C when bits reach or cross the sign bit (e.g. a 5-byte
// encoding shifts payload by 28). All bit manipulation is now done on an
// unsigned accumulator and converted once at the end.
int32_t read_i32_leb128(WasmReader *r) {
    uint32_t result = 0;
    int shift = 0;
    uint8_t byte;
    do {
        byte = read_u8(r);
        if (shift < 32) {
            result |= (uint32_t)(byte & 0x7F) << shift;
        }
        shift += 7;
    } while (byte & 0x80);
    // Sign-extend when the final byte's sign bit (0x40) is set.
    if (shift < 32 && (byte & 0x40)) {
        result |= ~(uint32_t)0 << shift;
    }
    return (int32_t)result;
}
// Read a signed 64-bit LEB128 value.
//
// Fix over the original: same signed-shift undefined behavior as
// read_i32_leb128 (shifting `int64_t` payload into the sign bit and
// `~0LL << shift`), plus an unguarded shift that could reach 70 on
// malformed input with 11+ bytes. Accumulate unsigned, guard the shift.
int64_t read_i64_leb128(WasmReader *r) {
    uint64_t result = 0;
    int shift = 0;
    uint8_t byte;
    do {
        byte = read_u8(r);
        if (shift < 64) {
            result |= (uint64_t)(byte & 0x7F) << shift;
        }
        shift += 7;
    } while (byte & 0x80);
    // Sign-extend when the final byte's sign bit (0x40) is set.
    if (shift < 64 && (byte & 0x40)) {
        result |= ~(uint64_t)0 << shift;
    }
    return (int64_t)result;
}
// Read a little-endian IEEE-754 single-precision float.
//
// Fix over the original: there was no bounds check, so a truncated module
// read up to 4 bytes past the buffer. Now aborts like read_u8 does.
float read_f32(WasmReader *r) {
    if (r->size - r->pos < 4) {
        fprintf(stderr, "Unexpected end of wasm\n");
        exit(1);
    }
    float f;
    memcpy(&f, &r->data[r->pos], 4); // memcpy avoids alignment/aliasing issues
    r->pos += 4;
    return f;
}
// Advance the cursor by n bytes (used for skipping names and unknown
// sections).
//
// Fix over the original: the advance was unchecked, so a malformed length
// field could push `pos` past `size` (or wrap it). Now aborts like read_u8.
void skip_bytes(WasmReader *r, size_t n) {
    if (n > r->size - r->pos) {
        fprintf(stderr, "Unexpected end of wasm\n");
        exit(1);
    }
    r->pos += n;
}
// ============================================================================
// Wasm module structures
// ============================================================================
// A parsed wasm function signature.
typedef struct {
    int param_count;
    int result_count;
    uint8_t *param_types; // array of value types
    uint8_t *result_types;
} WasmFuncType;
// A function from the code section. `code` points into the original wasm
// image (not owned); `local_types` is the locals vector with run-length
// groups expanded to one entry per local.
typedef struct {
    uint32_t type_idx;
    uint32_t local_count;
    uint8_t *local_types; // expanded - one per local
    uint8_t *code;
    size_t code_size;
} WasmFunc;
// A module global with its (constant-expression) initial value.
typedef struct {
    uint8_t type; // 0x7F = i32, 0x7E = i64, 0x7D = f32, 0x7C = f64
    uint8_t mutable;
    int64_t init_val; // initial value (simplified - assume i32.const or i64.const)
} WasmGlobal;
// Linear-memory limits, in 64 KiB wasm pages.
typedef struct {
    uint32_t min_pages;
    uint32_t max_pages;
    int has_max;
} WasmMemory;
// The parsed module: only the pieces this JIT consumes.
typedef struct {
    WasmFuncType *types;
    uint32_t type_count;
    WasmFunc *funcs;          // local (non-imported) functions only
    uint32_t func_count;
    uint32_t import_func_count; // imported functions precede locals in index space
    WasmGlobal *globals;
    uint32_t global_count;
    WasmMemory memory;
    int has_memory;
    // Export info
    char *run_export_name;  // non-NULL iff a function export named "run" exists
    uint32_t run_func_idx;  // index in the combined (imports-first) function index space
} WasmModule;
// Parse wasm module
// Parse a wasm binary image into a WasmModule.
//
// Handles the sections this JIT needs (type, import, function, memory,
// global, export, code); everything else is skipped. Exits the process on
// malformed or unsupported input. The returned module keeps pointers into
// `data` for function bodies, so `data` must outlive the module.
//
// NOTE(review): calloc/malloc results are unchecked, and declared section
// sizes are trusted — acceptable for a demo, not for untrusted input.
WasmModule *parse_wasm(uint8_t *data, size_t size) {
    WasmReader r = { data, size, 0 };
    // Check magic and version
    // NOTE(review): C does not specify the evaluation order of the `|`
    // operands, so which byte lands on which shift is compiler-dependent;
    // confirm this reads little-endian on the target toolchain.
    uint32_t magic = read_u8(&r) | (read_u8(&r) << 8) | (read_u8(&r) << 16) | (read_u8(&r) << 24);
    uint32_t version = read_u8(&r) | (read_u8(&r) << 8) | (read_u8(&r) << 16) | (read_u8(&r) << 24);
    if (magic != 0x6D736100) {
        fprintf(stderr, "Invalid wasm magic\n");
        exit(1);
    }
    if (version != 1) {
        fprintf(stderr, "Unsupported wasm version: %u\n", version);
        exit(1);
    }
    WasmModule *mod = calloc(1, sizeof(WasmModule));
    // Temporary array to hold function type indices
    // (from the function section; consumed by the code section below).
    uint32_t *func_type_indices = NULL;
    uint32_t func_type_count = 0;
    // Parse sections
    while (r.pos < r.size) {
        uint8_t section_id = read_u8(&r);
        uint32_t section_size = read_u32_leb128(&r);
        size_t section_end = r.pos + section_size;
        DBG("Section %d, size %u\n", section_id, section_size);
        switch (section_id) {
        case 1: { // Type section
            mod->type_count = read_u32_leb128(&r);
            mod->types = calloc(mod->type_count, sizeof(WasmFuncType));
            for (uint32_t i = 0; i < mod->type_count; i++) {
                uint8_t form = read_u8(&r);
                if (form != 0x60) { // 0x60 = "func" type constructor
                    fprintf(stderr, "Expected function type\n");
                    exit(1);
                }
                mod->types[i].param_count = read_u32_leb128(&r);
                mod->types[i].param_types = malloc(mod->types[i].param_count);
                for (int j = 0; j < mod->types[i].param_count; j++) {
                    mod->types[i].param_types[j] = read_u8(&r);
                }
                mod->types[i].result_count = read_u32_leb128(&r);
                mod->types[i].result_types = malloc(mod->types[i].result_count);
                for (int j = 0; j < mod->types[i].result_count; j++) {
                    mod->types[i].result_types[j] = read_u8(&r);
                }
            }
            break;
        }
        case 2: { // Import section
            // Imports are only counted (functions) or skipped (others);
            // this JIT does not link any imported entities.
            uint32_t import_count = read_u32_leb128(&r);
            for (uint32_t i = 0; i < import_count; i++) {
                uint32_t mod_len = read_u32_leb128(&r);
                skip_bytes(&r, mod_len);
                uint32_t name_len = read_u32_leb128(&r);
                skip_bytes(&r, name_len);
                uint8_t kind = read_u8(&r);
                if (kind == 0) { // function import
                    read_u32_leb128(&r); // type index
                    mod->import_func_count++;
                } else if (kind == 1) { // table
                    read_u8(&r); // reftype
                    uint8_t flags = read_u8(&r);
                    read_u32_leb128(&r); // min
                    if (flags & 1) read_u32_leb128(&r); // max
                } else if (kind == 2) { // memory
                    uint8_t flags = read_u8(&r);
                    read_u32_leb128(&r); // min
                    if (flags & 1) read_u32_leb128(&r); // max
                } else if (kind == 3) { // global
                    read_u8(&r); // type
                    read_u8(&r); // mutable
                }
            }
            break;
        }
        case 3: { // Function section
            // Records one type index per local function; paired with the
            // code section entries by position.
            func_type_count = read_u32_leb128(&r);
            func_type_indices = malloc(func_type_count * sizeof(uint32_t));
            for (uint32_t i = 0; i < func_type_count; i++) {
                func_type_indices[i] = read_u32_leb128(&r);
            }
            break;
        }
        case 5: { // Memory section
            uint32_t count = read_u32_leb128(&r);
            if (count > 0) {
                mod->has_memory = 1;
                uint8_t flags = read_u8(&r);
                mod->memory.min_pages = read_u32_leb128(&r);
                if (flags & 1) {
                    mod->memory.has_max = 1;
                    mod->memory.max_pages = read_u32_leb128(&r);
                }
            }
            break;
        }
        case 6: { // Global section
            mod->global_count = read_u32_leb128(&r);
            mod->globals = calloc(mod->global_count, sizeof(WasmGlobal));
            for (uint32_t i = 0; i < mod->global_count; i++) {
                mod->globals[i].type = read_u8(&r);
                mod->globals[i].mutable = read_u8(&r);
                // Parse init expression (simplified: only a single
                // i32.const / i64.const followed by `end` is accepted)
                uint8_t opcode = read_u8(&r);
                if (opcode == 0x41) { // i32.const
                    mod->globals[i].init_val = read_i32_leb128(&r);
                } else if (opcode == 0x42) { // i64.const
                    mod->globals[i].init_val = read_i64_leb128(&r);
                } else {
                    fprintf(stderr, "Unsupported global init: 0x%02X\n", opcode);
                    exit(1);
                }
                read_u8(&r); // end opcode (0x0B)
            }
            break;
        }
        case 7: { // Export section
            // Only the function export named "run" is retained.
            uint32_t export_count = read_u32_leb128(&r);
            for (uint32_t i = 0; i < export_count; i++) {
                uint32_t name_len = read_u32_leb128(&r);
                char *name = malloc(name_len + 1);
                for (uint32_t j = 0; j < name_len; j++) {
                    name[j] = read_u8(&r);
                }
                name[name_len] = '\0';
                uint8_t kind = read_u8(&r);
                uint32_t idx = read_u32_leb128(&r);
                if (kind == 0 && strcmp(name, "run") == 0) {
                    mod->run_export_name = name;
                    mod->run_func_idx = idx;
                } else {
                    free(name);
                }
            }
            break;
        }
        case 10: { // Code section
            // NOTE(review): assumes the function section (case 3) was
            // already parsed — valid modules order sections by id, but
            // func_type_indices is dereferenced without a NULL check.
            mod->func_count = read_u32_leb128(&r);
            mod->funcs = calloc(mod->func_count, sizeof(WasmFunc));
            for (uint32_t i = 0; i < mod->func_count; i++) {
                uint32_t func_size = read_u32_leb128(&r);
                size_t func_end = r.pos + func_size;
                mod->funcs[i].type_idx = func_type_indices[i];
                // Parse locals
                uint32_t local_group_count = read_u32_leb128(&r);
                uint32_t total_locals = 0;
                // First pass: count total locals
                size_t locals_start = r.pos;
                for (uint32_t j = 0; j < local_group_count; j++) {
                    uint32_t count = read_u32_leb128(&r);
                    read_u8(&r); // type
                    total_locals += count;
                }
                mod->funcs[i].local_count = total_locals;
                mod->funcs[i].local_types = malloc(total_locals);
                // Second pass: expand run-length groups to one type per local
                r.pos = locals_start;
                uint32_t local_idx = 0;
                for (uint32_t j = 0; j < local_group_count; j++) {
                    uint32_t count = read_u32_leb128(&r);
                    uint8_t type = read_u8(&r);
                    for (uint32_t k = 0; k < count; k++) {
                        mod->funcs[i].local_types[local_idx++] = type;
                    }
                }
                // Store code (points into the caller's buffer, not copied)
                mod->funcs[i].code = &r.data[r.pos];
                mod->funcs[i].code_size = func_end - r.pos;
                r.pos = func_end;
            }
            break;
        }
        default:
            skip_bytes(&r, section_size);
            break;
        }
        // Resynchronize on the declared section boundary regardless of how
        // much the section handler actually consumed.
        r.pos = section_end;
    }
    if (func_type_indices) free(func_type_indices);
    return mod;
}
// Release everything owned by the module. Function code pointers are not
// freed here: they point into the caller's wasm image.
void free_wasm_module(WasmModule *mod) {
    uint32_t i;
    for (i = 0; i < mod->type_count; i++) {
        free(mod->types[i].param_types);
        free(mod->types[i].result_types);
    }
    free(mod->types);
    for (i = 0; i < mod->func_count; i++)
        free(mod->funcs[i].local_types);
    free(mod->funcs);
    free(mod->globals);
    free(mod->run_export_name); // free(NULL) is a no-op
    free(mod);
}
// ============================================================================
// JIT Compiler
// ============================================================================
#define MAX_BLOCKS 64  // max nesting depth of wasm block/loop constructs
#define MAX_STACK 64   // max tracked wasm operand-stack depth
// Per-block bookkeeping for resolving branches: loops branch backward to
// start_pos; plain blocks record forward-branch sites to patch at `end`.
typedef struct {
    size_t start_pos; // instruction position at block start
    size_t *patch_sites; // positions needing patching for br
    int patch_count;
    int is_loop;
} BlockInfo;
// Compilation state for one function.
typedef struct {
    CodeBuffer *code;
    WasmModule *mod;
    WasmFunc *func;
    // Runtime pointers (passed as arguments)
    // x19 = memory base
    // x20 = globals base
    // stack based locals (frame pointer relative)
    // Virtual stack tracking
    int stack_depth; // current wasm stack depth
    // Stack values are stored in x9-x15 (for i32) or s0-s7 (for f32)
    // We'll use a simple scheme: even positions = int regs, track types
    uint8_t stack_types[MAX_STACK]; // 0x7F=i32, 0x7D=f32
    // Block management
    BlockInfo blocks[MAX_BLOCKS];
    int block_depth;
    // Locals frame offset
    int locals_offset; // offset from sp for first local
    int param_count;
    int local_count;
} JitCompiler;
// Map a wasm operand-stack position to a scratch integer register.
// Positions 0-6 use x9-x15; deeper positions wrap around and alias
// (spilling is not implemented in this baseline compiler).
int jit_get_int_reg(int stack_pos) {
    const int first_scratch = 9; // x9
    const int scratch_count = 7; // x9..x15
    return first_scratch + stack_pos % scratch_count;
}
// Map a wasm operand-stack position to a float register (s0-s7, wrapping).
int jit_get_float_reg(int stack_pos) {
    int lane = stack_pos % 8;
    return lane;
}
// Materialize a 32-bit immediate into rd: MOVZ for the low halfword,
// plus a MOVK for the high halfword when it is needed.
void jit_emit_mov_imm32(CodeBuffer *c, int rd, int32_t val) {
    uint32_t bits = (uint32_t)val;
    uint16_t lo = bits & 0xFFFF;
    uint16_t hi = (bits >> 16) & 0xFFFF;
    codebuf_emit(c, arm64_movz(rd, lo, 0));
    if (bits > 0xFFFF)
        codebuf_emit(c, arm64_movk(rd, hi, 1));
}
// Materialize a 64-bit immediate into rd: MOVZ zeroes the register and
// sets the low halfword, then one MOVK per non-zero upper halfword.
void jit_emit_mov_imm64(CodeBuffer *c, int rd, int64_t val) {
    uint64_t bits = (uint64_t)val;
    codebuf_emit(c, arm64_movz(rd, bits & 0xFFFF, 0));
    for (int lane = 1; lane < 4; lane++) {
        uint16_t chunk = (bits >> (lane * 16)) & 0xFFFF;
        if (chunk)
            codebuf_emit(c, arm64_movk(rd, chunk, lane));
    }
}
// Push an i32 value onto the virtual wasm stack. If the value is not
// already in the register assigned to this stack slot, emit a move.
void jit_push_i32(JitCompiler *jit, int reg) {
    int slot = jit_get_int_reg(jit->stack_depth);
    jit->stack_types[jit->stack_depth] = 0x7F; // i32
    if (slot != reg)
        codebuf_emit(jit->code, arm64_mov_w(slot, reg));
    jit->stack_depth++;
}
// Push an f32 value onto the virtual wasm stack.
//
// No move is emitted: producers are expected to have already placed the
// value in the s-register assigned to this stack slot.
//
// Fix over the original: `sreg` was accepted but never referenced, which
// triggers -Wunused-parameter; the discard is now explicit.
void jit_push_f32(JitCompiler *jit, int sreg) {
    (void)sreg; // value already lives in the slot's s-register
    jit->stack_types[jit->stack_depth] = 0x7D; // f32
    jit->stack_depth++;
}
// Pop the top i32 and return the integer register that holds it.
int jit_pop_i32(JitCompiler *jit) {
    int top = --jit->stack_depth;
    return jit_get_int_reg(top);
}
// Pop the top f32 and return the s-register that holds it.
int jit_pop_f32(JitCompiler *jit) {
    int top = --jit->stack_depth;
    return jit_get_float_reg(top);
}
// Return the integer register holding the value `offset` slots below the
// top of the stack (0 = top) without popping.
int jit_peek_i32(JitCompiler *jit, int offset) {
    int pos = jit->stack_depth - 1 - offset;
    return jit_get_int_reg(pos);
}
// Compile one wasm function body into ARM64 machine code in jit->code.
//
// Generated-code calling convention (matches the call in main):
//   x0 = first wasm parameter, x1 = linear memory base, x2 = globals base.
// Inside the function: x19 = memory base, x20 = globals base, locals live
// on the native stack at [sp + locals_offset + idx*8], x8 is a scratch
// register, and the wasm value stack is mapped onto x9-x15 / s0-s7 by
// stack depth (see jit_get_int_reg / jit_get_float_reg).
//
// NOTE(review): stack depths greater than 7 silently alias registers via
// the modulo register mapping; deep expressions would miscompile.
void jit_compile_function(JitCompiler *jit, WasmFunc *func) {
    CodeBuffer *c = jit->code;
    jit->func = func;
    WasmFuncType *ftype = &jit->mod->types[func->type_idx];
    jit->param_count = ftype->param_count;
    jit->local_count = func->local_count;
    int total_locals = jit->param_count + jit->local_count;
    // Frame layout: [x19,x20 (16 bytes)] [locals (total_locals * 8 bytes)]
    int frame_size = (16 + total_locals * 8 + 15) & ~15; // 16-byte aligned
    // Prologue
    // stp x29, x30, [sp, #-16]! (push frame pointer and link register)
    codebuf_emit(c, 0xA9BF7BFD); // stp x29, x30, [sp, #-16]!
    // mov x29, sp (must use ADD since ORR x31 is XZR, not SP)
    codebuf_emit(c, arm64_add_imm(29, 31, 0)); // add x29, sp, #0
    // Allocate locals on stack
    if (frame_size > 0) {
        codebuf_emit(c, arm64_sub_imm(31, 31, frame_size));
    }
    // Save callee-saved registers we use (x19, x20)
    codebuf_emit(c, arm64_stp(19, 20, 31, 0));
    // x0 = first parameter (N for matmul)
    // x1 = memory base (passed by caller)
    // x2 = globals base (passed by caller)
    // Setup our register conventions
    codebuf_emit(c, arm64_mov(19, 1)); // x19 = memory base
    codebuf_emit(c, arm64_mov(20, 2)); // x20 = globals base
    // Store first parameter as local[0]
    // Parameters go to stack-based locals
    // For matmul: param is i64, local[0] needs it
    // We store params at [sp + 16 + idx*8]
    jit->locals_offset = 16; // after saved x19, x20
    // NOTE(review): only the first parameter is spilled; functions with two
    // or more parameters would see garbage in local[1..] — confirm intended.
    if (jit->param_count > 0) {
        // x0 has the i64 parameter
        codebuf_emit(c, arm64_str(0, 31, jit->locals_offset));
    }
    // Initialize other locals to 0 (wasm requires zero-initialized locals)
    for (int i = jit->param_count; i < total_locals; i++) {
        codebuf_emit(c, arm64_movz(8, 0, 0)); // x8 = 0
        codebuf_emit(c, arm64_str(8, 31, jit->locals_offset + i * 8));
    }
    // Parse and compile function body
    WasmReader r = { func->code, func->code_size, 0 };
    jit->stack_depth = 0;
    jit->block_depth = 0;
    // Push implicit function block (its `end` terminates compilation)
    // NOTE(review): patch_sites has a fixed capacity of 256 with no bound
    // check on patch_count — confirm acceptable for expected inputs.
    jit->blocks[jit->block_depth].start_pos = codebuf_pos(c);
    jit->blocks[jit->block_depth].patch_sites = calloc(256, sizeof(size_t));
    jit->blocks[jit->block_depth].patch_count = 0;
    jit->blocks[jit->block_depth].is_loop = 0;
    jit->block_depth++;
    while (r.pos < r.size) {
        uint8_t opcode = read_u8(&r);
        DBG(" opcode: 0x%02X at pos %zu, stack=%d, blocks=%d\n",
            opcode, r.pos - 1, jit->stack_depth, jit->block_depth);
        switch (opcode) {
        case 0x00: // unreachable
            // Could emit a trap instruction
            break;
        case 0x01: // nop
            break;
        case 0x02: { // block
            read_i32_leb128(&r); // block type (ignore)
            jit->blocks[jit->block_depth].start_pos = codebuf_pos(c);
            jit->blocks[jit->block_depth].patch_sites = calloc(256, sizeof(size_t));
            jit->blocks[jit->block_depth].patch_count = 0;
            jit->blocks[jit->block_depth].is_loop = 0;
            jit->block_depth++;
            break;
        }
        case 0x03: { // loop
            read_i32_leb128(&r); // block type
            jit->blocks[jit->block_depth].start_pos = codebuf_pos(c);
            jit->blocks[jit->block_depth].patch_sites = calloc(256, sizeof(size_t));
            jit->blocks[jit->block_depth].patch_count = 0;
            jit->blocks[jit->block_depth].is_loop = 1;
            jit->block_depth++;
            break;
        }
        case 0x0B: { // end
            jit->block_depth--;
            BlockInfo *block = &jit->blocks[jit->block_depth];
            if (jit->block_depth == 0) {
                // Function end - return
                goto done_compiling;
            }
            // Patch forward branches: rewrite each recorded placeholder with
            // the now-known offset to this end position. The branch kind is
            // recovered from the placeholder's opcode bits.
            size_t end_pos = codebuf_pos(c);
            for (int i = 0; i < block->patch_count; i++) {
                size_t patch_pos = block->patch_sites[i];
                int offset = (int)(end_pos - patch_pos);
                uint32_t old_inst = c->code[patch_pos];
                // Check instruction type by opcode pattern
                if ((old_inst & 0xFC000000) == 0x14000000) {
                    // B (unconditional branch)
                    codebuf_patch(c, patch_pos, arm64_b(offset));
                } else if ((old_inst & 0xFF000000) == 0x35000000) {
                    // CBNZ (Wn)
                    int rt = old_inst & 0x1F;
                    codebuf_patch(c, patch_pos, arm64_cbnz_w(rt, offset));
                } else if ((old_inst & 0xFF000000) == 0x34000000) {
                    // CBZ (Wn)
                    int rt = old_inst & 0x1F;
                    codebuf_patch(c, patch_pos, arm64_cbz_w(rt, offset));
                } else if ((old_inst & 0xFF000000) == 0x54000000) {
                    // B.cond
                    int cond = old_inst & 0xF;
                    codebuf_patch(c, patch_pos, arm64_bcond(cond, offset));
                }
            }
            free(block->patch_sites);
            break;
        }
        case 0x0C: { // br
            uint32_t depth = read_u32_leb128(&r);
            int target_block = jit->block_depth - 1 - depth;
            BlockInfo *block = &jit->blocks[target_block];
            if (block->is_loop) {
                // Branch back to loop start
                int offset = (int)block->start_pos - (int)codebuf_pos(c);
                codebuf_emit(c, arm64_b(offset));
            } else {
                // Forward branch - patch later
                block->patch_sites[block->patch_count++] = codebuf_pos(c);
                codebuf_emit(c, arm64_b(0)); // placeholder
            }
            break;
        }
        case 0x0D: { // br_if
            uint32_t depth = read_u32_leb128(&r);
            int cond_reg = jit_pop_i32(jit);
            int target_block = jit->block_depth - 1 - depth;
            BlockInfo *block = &jit->blocks[target_block];
            if (block->is_loop) {
                // Branch back if non-zero
                // NOTE(review): the extra "- 1" here is inconsistent with the
                // br case above; ARM64 branch offsets are relative to the
                // branch instruction itself, so this targets the instruction
                // BEFORE the loop start — confirm which variant is correct.
                int offset = (int)block->start_pos - (int)codebuf_pos(c) - 1;
                codebuf_emit(c, arm64_cbnz_w(cond_reg, offset));
            } else {
                // Forward branch - patch later
                block->patch_sites[block->patch_count++] = codebuf_pos(c);
                codebuf_emit(c, arm64_cbnz_w(cond_reg, 0)); // placeholder
            }
            break;
        }
        case 0x20: { // local.get
            uint32_t idx = read_u32_leb128(&r);
            int target = jit_get_int_reg(jit->stack_depth);
            // Load the full 64-bit local slot from the native stack
            codebuf_emit(c, arm64_ldr(target, 31, jit->locals_offset + idx * 8));
            jit->stack_types[jit->stack_depth] = 0x7F; // assume i32 for now
            jit->stack_depth++;
            break;
        }
        case 0x21: { // local.set
            uint32_t idx = read_u32_leb128(&r);
            int src = jit_pop_i32(jit);
            codebuf_emit(c, arm64_str(src, 31, jit->locals_offset + idx * 8));
            break;
        }
        case 0x22: { // local.tee
            // Like local.set but the value stays on the wasm stack.
            uint32_t idx = read_u32_leb128(&r);
            int src = jit_get_int_reg(jit->stack_depth - 1);
            codebuf_emit(c, arm64_str(src, 31, jit->locals_offset + idx * 8));
            break;
        }
        case 0x23: { // global.get
            // Globals live as 32-bit slots at x20 + idx*4 (see main).
            uint32_t idx = read_u32_leb128(&r);
            int target = jit_get_int_reg(jit->stack_depth);
            codebuf_emit(c, arm64_ldr_w(target, 20, idx * 4));
            jit->stack_types[jit->stack_depth] = 0x7F;
            jit->stack_depth++;
            break;
        }
        case 0x24: { // global.set
            uint32_t idx = read_u32_leb128(&r);
            int src = jit_pop_i32(jit);
            codebuf_emit(c, arm64_str_w(src, 20, idx * 4));
            break;
        }
        case 0x28: { // i32.load
            uint32_t align = read_u32_leb128(&r);
            uint32_t offset = read_u32_leb128(&r);
            (void)align;
            int addr = jit_pop_i32(jit);
            int target = jit_get_int_reg(jit->stack_depth);
            // Compute effective address: memory_base + addr + offset
            // (the static offset is folded into the address register;
            // no bounds checking is performed)
            if (offset > 0) {
                jit_emit_mov_imm32(c, 8, offset);
                codebuf_emit(c, arm64_add_w(addr, addr, 8));
            }
            // target = *(memory_base + addr)
            codebuf_emit(c, arm64_ldr_w_reg(target, 19, addr));
            jit->stack_types[jit->stack_depth] = 0x7F;
            jit->stack_depth++;
            break;
        }
        case 0x36: { // i32.store
            uint32_t align = read_u32_leb128(&r);
            uint32_t offset = read_u32_leb128(&r);
            (void)align;
            int value = jit_pop_i32(jit);
            int addr = jit_pop_i32(jit);
            if (offset > 0) {
                jit_emit_mov_imm32(c, 8, offset);
                codebuf_emit(c, arm64_add_w(addr, addr, 8));
            }
            codebuf_emit(c, arm64_str_w_reg(value, 19, addr));
            break;
        }
        case 0x2A: { // f32.load
            uint32_t align = read_u32_leb128(&r);
            uint32_t offset = read_u32_leb128(&r);
            (void)align;
            int addr = jit_pop_i32(jit);
            int target = jit_get_float_reg(jit->stack_depth);
            if (offset > 0) {
                jit_emit_mov_imm32(c, 8, offset);
                codebuf_emit(c, arm64_add_w(addr, addr, 8));
            }
            codebuf_emit(c, arm64_ldr_s_reg(target, 19, addr));
            jit->stack_types[jit->stack_depth] = 0x7D;
            jit->stack_depth++;
            break;
        }
        case 0x38: { // f32.store
            uint32_t align = read_u32_leb128(&r);
            uint32_t offset = read_u32_leb128(&r);
            (void)align;
            int value = jit_pop_f32(jit);
            int addr = jit_pop_i32(jit);
            if (offset > 0) {
                jit_emit_mov_imm32(c, 8, offset);
                codebuf_emit(c, arm64_add_w(addr, addr, 8));
            }
            codebuf_emit(c, arm64_str_s_reg(value, 19, addr));
            break;
        }
        case 0x41: { // i32.const
            int32_t val = read_i32_leb128(&r);
            int target = jit_get_int_reg(jit->stack_depth);
            jit_emit_mov_imm32(c, target, val);
            jit->stack_types[jit->stack_depth] = 0x7F;
            jit->stack_depth++;
            break;
        }
        case 0x42: { // i64.const
            int64_t val = read_i64_leb128(&r);
            int target = jit_get_int_reg(jit->stack_depth);
            jit_emit_mov_imm64(c, target, val);
            jit->stack_types[jit->stack_depth] = 0x7E; // i64
            jit->stack_depth++;
            break;
        }
        case 0x43: { // f32.const
            float val = read_f32(&r);
            int target = jit_get_float_reg(jit->stack_depth);
            // Move float constant via integer register
            uint32_t bits;
            memcpy(&bits, &val, 4);
            jit_emit_mov_imm32(c, 8, bits);
            codebuf_emit(c, arm64_fmov_s_from_w(target, 8));
            jit->stack_types[jit->stack_depth] = 0x7D;
            jit->stack_depth++;
            break;
        }
        case 0x6A: { // i32.add
            int b = jit_pop_i32(jit);
            int a = jit_pop_i32(jit);
            int target = jit_get_int_reg(jit->stack_depth);
            codebuf_emit(c, arm64_add_w(target, a, b));
            jit->stack_types[jit->stack_depth] = 0x7F;
            jit->stack_depth++;
            break;
        }
        case 0x6B: { // i32.sub
            int b = jit_pop_i32(jit);
            int a = jit_pop_i32(jit);
            int target = jit_get_int_reg(jit->stack_depth);
            codebuf_emit(c, arm64_sub_w(target, a, b));
            jit->stack_types[jit->stack_depth] = 0x7F;
            jit->stack_depth++;
            break;
        }
        case 0x6C: { // i32.mul
            int b = jit_pop_i32(jit);
            int a = jit_pop_i32(jit);
            int target = jit_get_int_reg(jit->stack_depth);
            codebuf_emit(c, arm64_mul_w(target, a, b));
            jit->stack_types[jit->stack_depth] = 0x7F;
            jit->stack_depth++;
            break;
        }
        case 0x4D: { // i32.le_u
            int b = jit_pop_i32(jit);
            int a = jit_pop_i32(jit);
            int target = jit_get_int_reg(jit->stack_depth);
            codebuf_emit(c, arm64_cmp_w(a, b));
            codebuf_emit(c, arm64_cset_w(target, COND_LS)); // unsigned <=
            jit->stack_types[jit->stack_depth] = 0x7F;
            jit->stack_depth++;
            break;
        }
        case 0x4F: { // i32.ge_u
            int b = jit_pop_i32(jit);
            int a = jit_pop_i32(jit);
            int target = jit_get_int_reg(jit->stack_depth);
            codebuf_emit(c, arm64_cmp_w(a, b));
            codebuf_emit(c, arm64_cset_w(target, COND_HS)); // unsigned >=
            jit->stack_types[jit->stack_depth] = 0x7F;
            jit->stack_depth++;
            break;
        }
        case 0xA7: { // i32.wrap_i64
            // Just truncate - value is already in 32-bit form in our regs
            // Pop and push with same register (effectively no-op)
            int src = jit_pop_i32(jit); // Actually i64 but we only use lower 32
            int target = jit_get_int_reg(jit->stack_depth);
            if (target != src) {
                codebuf_emit(c, arm64_mov_w(target, src));
            }
            jit->stack_types[jit->stack_depth] = 0x7F;
            jit->stack_depth++;
            break;
        }
        case 0x92: { // f32.add
            int b = jit_pop_f32(jit);
            int a = jit_pop_f32(jit);
            int target = jit_get_float_reg(jit->stack_depth);
            codebuf_emit(c, arm64_fadd_s(target, a, b));
            jit->stack_types[jit->stack_depth] = 0x7D;
            jit->stack_depth++;
            break;
        }
        case 0x94: { // f32.mul
            int b = jit_pop_f32(jit);
            int a = jit_pop_f32(jit);
            int target = jit_get_float_reg(jit->stack_depth);
            codebuf_emit(c, arm64_fmul_s(target, a, b));
            jit->stack_types[jit->stack_depth] = 0x7D;
            jit->stack_depth++;
            break;
        }
        default:
            fprintf(stderr, "Unsupported opcode: 0x%02X\n", opcode);
            exit(1);
        }
    }
done_compiling:
    // Epilogue
    // Return value should be in x0 (or s0 for float)
    // If stack has a value, move it to x0 for return
    // NOTE(review): f32 results would also go through the integer path
    // here — confirm only integer-returning functions are compiled.
    if (jit->stack_depth > 0) {
        int result_reg = jit_get_int_reg(jit->stack_depth - 1);
        if (result_reg != 0) {
            codebuf_emit(c, arm64_mov(0, result_reg));
        }
    } else {
        // No return value - return 0
        codebuf_emit(c, arm64_movz(0, 0, 0));
    }
    // Restore callee-saved
    codebuf_emit(c, arm64_ldp(19, 20, 31, 0));
    // Deallocate frame
    if (frame_size > 0) {
        codebuf_emit(c, arm64_add_imm(31, 31, frame_size));
    }
    // ldp x29, x30, [sp], #16
    codebuf_emit(c, 0xA8C17BFD);
    codebuf_emit(c, arm64_ret());
}
// ============================================================================
// Main
// ============================================================================
// Entry point: load a .wasm file, JIT-compile its exported "run" function
// to ARM64, execute it with an optional integer argument, and print the
// result.
//
// Usage: <prog> <wasm-file> [N]
//
// Fixes over the original:
//  - fread/ftell/malloc/calloc results are checked (a short read would
//    previously hand garbage to the compiler silently);
//  - the run-export index is validated against the local function range
//    before indexing mod->funcs;
//  - int64_t values are cast to long long for %lld (int64_t is `long` on
//    LP64 Linux, so the bare format was a mismatch there).
int main(int argc, char **argv) {
    if (argc < 2) {
        fprintf(stderr, "Usage: %s <wasm-file> [args...]\n", argv[0]);
        return 1;
    }
    // ---- Read the wasm file into memory ----
    FILE *f = fopen(argv[1], "rb");
    if (!f) {
        perror("fopen");
        return 1;
    }
    if (fseek(f, 0, SEEK_END) != 0) {
        perror("fseek");
        fclose(f);
        return 1;
    }
    long file_len = ftell(f);
    if (file_len < 0) {
        perror("ftell");
        fclose(f);
        return 1;
    }
    size_t size = (size_t)file_len;
    fseek(f, 0, SEEK_SET);
    uint8_t *data = malloc(size);
    if (!data) {
        fprintf(stderr, "Out of memory\n");
        fclose(f);
        return 1;
    }
    if (fread(data, 1, size, f) != size) {
        fprintf(stderr, "Failed to read %s\n", argv[1]);
        fclose(f);
        free(data);
        return 1;
    }
    fclose(f);
    printf("Loaded %zu bytes of wasm\n", size);
    // ---- Parse the module ----
    WasmModule *mod = parse_wasm(data, size);
    printf("Parsed: %u types, %u funcs, %u globals, memory=%d\n",
           mod->type_count, mod->func_count, mod->global_count, mod->has_memory);
    if (!mod->run_export_name) {
        fprintf(stderr, "No 'run' export found\n");
        return 1;
    }
    printf("Found export 'run' -> func[%u]\n", mod->run_func_idx);
    // The export index counts imported functions first; translate it into
    // an index into the module's own code-section functions and check it.
    if (mod->run_func_idx < mod->import_func_count ||
        mod->run_func_idx - mod->import_func_count >= mod->func_count) {
        fprintf(stderr, "'run' export is not a local function\n");
        return 1;
    }
    uint32_t func_idx = mod->run_func_idx - mod->import_func_count;
    // ---- Allocate linear memory and globals ----
    size_t mem_size = mod->has_memory ? (size_t)mod->memory.min_pages * 65536 : 65536;
    uint8_t *memory = calloc(1, mem_size);
    // Globals are stored as i32 slots; +1 avoids a zero-size allocation.
    int32_t *globals = calloc(mod->global_count + 1, sizeof(int32_t));
    if (!memory || !globals) {
        fprintf(stderr, "Out of memory\n");
        return 1;
    }
    for (uint32_t i = 0; i < mod->global_count; i++) {
        globals[i] = (int32_t)mod->globals[i].init_val;
        printf("Global[%u] = %d\n", i, globals[i]);
    }
    // ---- Compile the run function ----
    CodeBuffer *code = codebuf_new(16384);
    JitCompiler jit = {
        .code = code,
        .mod = mod,
    };
    codebuf_begin_write(code);
    jit_compile_function(&jit, &mod->funcs[func_idx]);
    codebuf_end_write(code);
    printf("Generated %zu instructions\n", code->size);
    // ---- Execute ----
    int64_t arg = 10; // default N
    if (argc > 2) {
        arg = atoll(argv[2]);
    }
    printf("Running with N=%lld\n", (long long)arg);
    // Compiled signature: i64 run(i64 N, void *memory, void *globals)
    typedef int64_t (*JitFunc)(int64_t, void*, void*);
    JitFunc fn = (JitFunc)code->code;
    int64_t result = fn(arg, memory, globals);
    printf("JIT result: %lld\n", (long long)result);
    // ---- Cleanup ----
    codebuf_free(code);
    free(memory);
    free(globals);
    free_wasm_module(mod);
    free(data);
    return 0;
}
@mizchi
Copy link
Author

mizchi commented Jan 19, 2026

Result

(matmul N=200):                              
  ┌──────────────┬─────────┬──────────────┐                     
  │   Runtime    │  Time   │   Relative   │                     
  ├──────────────┼─────────┼──────────────┤                     
  │ baseline-jit │ 13.7ms  │ 1.00x        │                     
  ├──────────────┼─────────┼──────────────┤                     
  │ wasmtime     │ 19.4ms  │ 1.41x slower │                     
  ├──────────────┼─────────┼──────────────┤                     
  │ wasmi        │ 111.9ms │ 8.15x slower │                     
  └──────────────┴─────────┴──────────────┘    

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment