Skip to content

Instantly share code, notes, and snippets.

@leonardoalt
Created March 4, 2026 16:22
Show Gist options
  • Select an option

  • Save leonardoalt/bce2cfe978081a6336d45e5ece23113b to your computer and use it in GitHub Desktop.

Select an option

Save leonardoalt/bce2cfe978081a6336d45e5ece23113b to your computer and use it in GitHub Desktop.
CUDA diff: Jump vs ALU implementation (womir-openvm)
--- extensions/womir_circuit/cuda/include/womir/adapters/alu.cuh 2026-03-04 17:06:33.450505710 +0100
+++ extensions/womir_circuit/cuda/include/womir/adapters/jump.cuh 2026-03-04 17:10:52.534016146 +0100
@@ -1,9 +1,8 @@
-// Adapted from <openvm>/extensions/rv32im/circuit/cuda/include/rv32im/adapters/alu.cuh
-// Main changes: adds frame pointer (fp) field, WomirExecutionState, fp_read_aux, timestamp +1 shift
-// Diff: https://gist.github.com/leonardoalt/09fd3d60bd571851bb656dc53cec0a4b#file-diff-adapters-alu-cuh-diff
+// WOMIR Jump adapter for CUDA trace generation.
+// Reads the frame pointer (FP) and one register (the condition/offset operand); performs no writes.
+// Simpler than the ALU adapter: a single register read, no rs2 operand, and no register write.
#pragma once
-#include "primitives/execution.h"
#include "primitives/trace_access.h"
#include "system/memory/controller.cuh"
#include "system/memory/offline_checker.cuh"
@@ -11,125 +10,49 @@
using namespace riscv;
-template <typename T, size_t NUM_READ_OPS, size_t NUM_WRITE_OPS>
-struct WomirBaseAluAdapterCols {
+template <typename T> struct WomirJumpAdapterCols {
WomirExecutionState<T> from_state;
- T rd_ptr;
- T rs1_ptr;
- T rs2; // Pointer if rs2 was a read, immediate value otherwise
- T rs2_as; // 1 if rs2 was a read, 0 if an immediate
+ T rs_ptr;
MemoryReadAuxCols<T> fp_read_aux;
- MemoryReadAuxCols<T> rs1_reads_aux[NUM_READ_OPS];
- MemoryReadAuxCols<T> rs2_reads_aux[NUM_READ_OPS];
- MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS> writes_aux[NUM_WRITE_OPS];
+ MemoryReadAuxCols<T> rs_read_aux;
};
-template <size_t NUM_READ_OPS, size_t NUM_WRITE_OPS>
-struct WomirBaseAluAdapterRecord {
+struct WomirJumpAdapterRecord {
uint32_t from_pc;
uint32_t fp;
uint32_t from_timestamp;
- uint32_t rd_ptr;
- uint32_t rs1_ptr;
- uint32_t rs2; // Pointer if rs2 was a read, immediate value otherwise
- uint8_t rs2_as; // 1 if rs2 was a read, 0 if an immediate
+ uint32_t rs_ptr;
MemoryReadAuxRecord fp_read_aux;
- MemoryReadAuxRecord rs1_reads_aux[NUM_READ_OPS];
- MemoryReadAuxRecord rs2_reads_aux[NUM_READ_OPS];
- MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS> writes_aux[NUM_WRITE_OPS];
+ MemoryReadAuxRecord rs_read_aux;
};
-template <size_t NUM_READ_OPS, size_t NUM_WRITE_OPS>
-struct WomirBaseAluAdapter {
+struct WomirJumpAdapter {
MemoryAuxColsFactory mem_helper;
- BitwiseOperationLookup bitwise_lookup;
- template <typename T>
- using Cols = WomirBaseAluAdapterCols<T, NUM_READ_OPS, NUM_WRITE_OPS>;
- using Record = WomirBaseAluAdapterRecord<NUM_READ_OPS, NUM_WRITE_OPS>;
-
- __device__ WomirBaseAluAdapter(
+ __device__ WomirJumpAdapter(
VariableRangeChecker range_checker,
- BitwiseOperationLookup lookup,
uint32_t timestamp_max_bits
)
- : mem_helper(range_checker, timestamp_max_bits), bitwise_lookup(lookup) {}
+ : mem_helper(range_checker, timestamp_max_bits) {}
- __device__ void fill_trace_row(RowSlice row, Record record) {
- COL_WRITE_VALUE(row, Cols, from_state.pc, record.from_pc);
- COL_WRITE_VALUE(row, Cols, from_state.fp, record.fp);
- COL_WRITE_VALUE(row, Cols, from_state.timestamp, record.from_timestamp);
-
- COL_WRITE_VALUE(row, Cols, rd_ptr, record.rd_ptr);
- COL_WRITE_VALUE(row, Cols, rs1_ptr, record.rs1_ptr);
- COL_WRITE_VALUE(row, Cols, rs2, record.rs2);
- COL_WRITE_VALUE(row, Cols, rs2_as, record.rs2_as);
+ __device__ void fill_trace_row(RowSlice row, WomirJumpAdapterRecord record) {
+ // rs read aux (timestamp from_timestamp + 1, i.e. the access immediately after the fp read)
+ mem_helper.fill(
+ row.slice_from(COL_INDEX(WomirJumpAdapterCols, rs_read_aux)),
+ record.rs_read_aux.prev_timestamp,
+ record.from_timestamp + 1
+ );
- // Read auxiliary for fp (at from_timestamp + 0)
+ // fp read (at from_timestamp + 0)
mem_helper.fill(
- row.slice_from(COL_INDEX(Cols, fp_read_aux)),
+ row.slice_from(COL_INDEX(WomirJumpAdapterCols, fp_read_aux)),
record.fp_read_aux.prev_timestamp,
record.from_timestamp
);
- // rs1 reads (at from_timestamp + 1 + r for r in 0..NUM_READ_OPS)
- constexpr size_t read_aux_elem_size = sizeof(MemoryReadAuxCols<uint8_t>);
-#pragma unroll
- for (size_t r = 0; r < NUM_READ_OPS; r++) {
- mem_helper.fill(
- row.slice_from(
- offsetof(Cols<uint8_t>, rs1_reads_aux) + r * read_aux_elem_size
- ),
- record.rs1_reads_aux[r].prev_timestamp,
- record.from_timestamp + 1 + r
- );
- }
-
- // rs2: register read when rs2_as != 0, otherwise immediate.
- if (record.rs2_as != 0) {
-#pragma unroll
- for (size_t r = 0; r < NUM_READ_OPS; r++) {
- mem_helper.fill(
- row.slice_from(
- offsetof(Cols<uint8_t>, rs2_reads_aux) + r * read_aux_elem_size
- ),
- record.rs2_reads_aux[r].prev_timestamp,
- record.from_timestamp + 1 + NUM_READ_OPS + r
- );
- }
- } else {
-#pragma unroll
- for (size_t r = 0; r < NUM_READ_OPS; r++) {
- RowSlice rs2_aux = row.slice_from(
- offsetof(Cols<uint8_t>, rs2_reads_aux) + r * read_aux_elem_size
- );
-#pragma unroll
- for (size_t i = 0; i < read_aux_elem_size; i++) {
- rs2_aux.write(i, 0);
- }
- }
- uint32_t mask = (1u << RV32_CELL_BITS) - 1u;
- bitwise_lookup.add_range(record.rs2 & mask, (record.rs2 >> RV32_CELL_BITS) & mask);
- }
-
- // Writes (at from_timestamp + 1 + 2*NUM_READ_OPS + w for w in 0..NUM_WRITE_OPS)
- // Type alias avoids commas inside offsetof/sizeof macros (preprocessor limitation).
- using WriteAuxCols = MemoryWriteAuxCols<uint8_t, RV32_REGISTER_NUM_LIMBS>;
- constexpr size_t write_aux_elem_size = sizeof(WriteAuxCols);
- constexpr size_t prev_data_offset = offsetof(WriteAuxCols, prev_data);
-#pragma unroll
- for (size_t w = 0; w < NUM_WRITE_OPS; w++) {
- size_t base = offsetof(Cols<uint8_t>, writes_aux) + w * write_aux_elem_size;
- row.write_array(
- base + prev_data_offset,
- RV32_REGISTER_NUM_LIMBS,
- record.writes_aux[w].prev_data
- );
- mem_helper.fill(
- row.slice_from(base),
- record.writes_aux[w].prev_timestamp,
- record.from_timestamp + 1 + 2 * NUM_READ_OPS + w
- );
- }
+ COL_WRITE_VALUE(row, WomirJumpAdapterCols, rs_ptr, record.rs_ptr);
+ COL_WRITE_VALUE(row, WomirJumpAdapterCols, from_state.timestamp, record.from_timestamp);
+ COL_WRITE_VALUE(row, WomirJumpAdapterCols, from_state.fp, record.fp);
+ COL_WRITE_VALUE(row, WomirJumpAdapterCols, from_state.pc, record.from_pc);
}
};
--- extensions/womir_circuit/cuda/include/womir/cores/alu.cuh 2026-03-04 17:06:33.450505710 +0100
+++ extensions/womir_circuit/cuda/include/womir/cores/jump.cuh 2026-03-04 17:11:12.450130692 +0100
@@ -1,143 +1,87 @@
-// 100% copy of <openvm>/extensions/rv32im/circuit/cuda/include/rv32im/cores/alu.cuh
+// WOMIR Jump core for CUDA tracegen.
+// Mirrors JumpCoreFiller::fill_trace_row from jump/core.rs.
#pragma once
#include "primitives/constants.h"
-#include "primitives/histogram.cuh"
#include "primitives/trace_access.h"
using namespace riscv;
-template <size_t NUM_LIMBS>
-__device__ __forceinline__ void run_add(
- const uint8_t *x,
- const uint8_t *y,
- uint8_t *out,
- uint8_t *carry
-) {
-#pragma unroll
- for (size_t i = 0; i < NUM_LIMBS; i++) {
- uint32_t res = (i > 0) ? carry[i - 1] : 0;
- res += static_cast<uint32_t>(x[i]) + static_cast<uint32_t>(y[i]);
- carry[i] = res >> RV32_CELL_BITS;
- out[i] = static_cast<uint8_t>(res & ((1u << RV32_CELL_BITS) - 1));
- }
-}
-
-template <size_t NUM_LIMBS>
-__device__ __forceinline__ void run_sub(
- const uint8_t *x,
- const uint8_t *y,
- uint8_t *out,
- uint8_t *carry
-) {
-#pragma unroll
- for (size_t i = 0; i < NUM_LIMBS; i++) {
- uint32_t rhs = static_cast<uint32_t>(y[i]) + ((i > 0) ? carry[i - 1] : 0);
- if (static_cast<uint32_t>(x[i]) >= rhs) {
- out[i] = static_cast<uint8_t>(static_cast<uint32_t>(x[i]) - rhs);
- carry[i] = 0;
- } else {
- uint32_t wrap =
- (static_cast<uint32_t>(1u << RV32_CELL_BITS) + static_cast<uint32_t>(x[i]) - rhs);
- out[i] = static_cast<uint8_t>(wrap);
- carry[i] = 1;
- }
- }
-}
-
-template <size_t NUM_LIMBS>
-__device__ __forceinline__ void run_xor(const uint8_t *x, const uint8_t *y, uint8_t *out) {
-#pragma unroll
- for (size_t i = 0; i < NUM_LIMBS; i++) {
- out[i] = x[i] ^ y[i];
- }
-}
-
-template <size_t NUM_LIMBS>
-__device__ __forceinline__ void run_or(const uint8_t *x, const uint8_t *y, uint8_t *out) {
-#pragma unroll
- for (size_t i = 0; i < NUM_LIMBS; i++) {
- out[i] = x[i] | y[i];
- }
-}
-
-template <size_t NUM_LIMBS>
-__device__ __forceinline__ void run_and(const uint8_t *x, const uint8_t *y, uint8_t *out) {
-#pragma unroll
- for (size_t i = 0; i < NUM_LIMBS; i++) {
- out[i] = x[i] & y[i];
- }
-}
-
-template <size_t NUM_LIMBS> struct BaseAluCoreRecord {
- uint8_t b[NUM_LIMBS];
- uint8_t c[NUM_LIMBS];
+struct JumpCoreRecord {
+ uint8_t rs_val[RV32_REGISTER_NUM_LIMBS];
+ uint32_t imm;
uint8_t local_opcode;
};
-template <typename T, size_t NUM_LIMBS> struct BaseAluCoreCols {
- T a[NUM_LIMBS];
- T b[NUM_LIMBS];
- T c[NUM_LIMBS];
- T opcode_add_flag;
- T opcode_sub_flag;
- T opcode_xor_flag;
- T opcode_or_flag;
- T opcode_and_flag;
+template <typename T> struct JumpCoreCols {
+ T rs_val[RV32_REGISTER_NUM_LIMBS];
+ T imm;
+ T opcode_jump_flag;
+ T opcode_skip_flag;
+ T opcode_jump_if_flag;
+ T opcode_jump_if_zero_flag;
+ T cond_is_zero;
+ T do_absolute_jump;
+ T nonzero_inv_marker[RV32_REGISTER_NUM_LIMBS];
};
-template <size_t NUM_LIMBS> struct BaseAluCore {
- BitwiseOperationLookup bitwise_lookup;
-
- template <typename T> using Cols = BaseAluCoreCols<T, NUM_LIMBS>;
+// Local opcode indices; values must match the variant order of the host-side (Rust) JumpOpcode enum
+enum JumpOpcode : uint8_t {
+ JUMP = 0,
+ SKIP = 1,
+ JUMP_IF = 2,
+ JUMP_IF_ZERO = 3,
+};
- __device__ BaseAluCore(BitwiseOperationLookup lookup) : bitwise_lookup(lookup) {}
+struct JumpCore {
+ template <typename T> using Cols = JumpCoreCols<T>;
- __device__ void fill_trace_row(RowSlice row, BaseAluCoreRecord<NUM_LIMBS> record) {
- uint8_t a[NUM_LIMBS];
- uint8_t carry_buf[NUM_LIMBS];
+ __device__ void fill_trace_row(RowSlice row, JumpCoreRecord record) {
+ // Compute cond_is_zero and nonzero_inv_marker (inverse of the first nonzero limb only; other entries stay zero)
+ bool cond_is_zero = true;
+ Fp nonzero_inv_marker[RV32_REGISTER_NUM_LIMBS];
+#pragma unroll
+ for (size_t i = 0; i < RV32_REGISTER_NUM_LIMBS; i++) {
+ nonzero_inv_marker[i] = Fp::zero();
+ }
+ for (size_t i = 0; i < RV32_REGISTER_NUM_LIMBS; i++) {
+ if (record.rs_val[i] != 0) {
+ cond_is_zero = false;
+ nonzero_inv_marker[i] = inv(Fp(record.rs_val[i]));
+ break;
+ }
+ }
+ // Compute do_absolute_jump
+ bool do_absolute_jump;
switch (record.local_opcode) {
- case 0:
- run_add<NUM_LIMBS>(record.b, record.c, a, carry_buf);
+ case JUMP:
+ do_absolute_jump = true;
break;
- case 1:
- run_sub<NUM_LIMBS>(record.b, record.c, a, carry_buf);
+ case SKIP:
+ do_absolute_jump = false;
break;
- case 2:
- run_xor<NUM_LIMBS>(record.b, record.c, a);
+ case JUMP_IF:
+ do_absolute_jump = !cond_is_zero;
break;
- case 3:
- run_or<NUM_LIMBS>(record.b, record.c, a);
- break;
- case 4:
- run_and<NUM_LIMBS>(record.b, record.c, a);
+ case JUMP_IF_ZERO:
+ do_absolute_jump = cond_is_zero;
break;
default:
-#pragma unroll
- for (size_t i = 0; i < NUM_LIMBS; i++) {
- a[i] = 0;
- carry_buf[i] = 0;
- }
+ do_absolute_jump = false;
}
- COL_WRITE_ARRAY(row, Cols, a, a);
- COL_WRITE_ARRAY(row, Cols, b, record.b);
- COL_WRITE_ARRAY(row, Cols, c, record.c);
-
- COL_WRITE_VALUE(row, Cols, opcode_add_flag, record.local_opcode == 0);
- COL_WRITE_VALUE(row, Cols, opcode_sub_flag, record.local_opcode == 1);
- COL_WRITE_VALUE(row, Cols, opcode_xor_flag, record.local_opcode == 2);
- COL_WRITE_VALUE(row, Cols, opcode_or_flag, record.local_opcode == 3);
- COL_WRITE_VALUE(row, Cols, opcode_and_flag, record.local_opcode == 4);
+ // Write columns (reverse order to match CPU filler pattern)
+ COL_WRITE_ARRAY(row, Cols, nonzero_inv_marker, nonzero_inv_marker);
+ COL_WRITE_VALUE(row, Cols, do_absolute_jump, do_absolute_jump);
+ COL_WRITE_VALUE(row, Cols, cond_is_zero, cond_is_zero);
+
+ COL_WRITE_VALUE(row, Cols, opcode_jump_if_zero_flag, record.local_opcode == JUMP_IF_ZERO);
+ COL_WRITE_VALUE(row, Cols, opcode_jump_if_flag, record.local_opcode == JUMP_IF);
+ COL_WRITE_VALUE(row, Cols, opcode_skip_flag, record.local_opcode == SKIP);
+ COL_WRITE_VALUE(row, Cols, opcode_jump_flag, record.local_opcode == JUMP);
-#pragma unroll
- for (size_t i = 0; i < NUM_LIMBS; i++) {
- if (record.local_opcode == 0 || record.local_opcode == 1) {
- bitwise_lookup.add_xor(a[i], a[i]);
- } else {
- bitwise_lookup.add_xor(record.b[i], record.c[i]);
- }
- }
+ COL_WRITE_VALUE(row, Cols, imm, record.imm);
+ COL_WRITE_ARRAY(row, Cols, rs_val, record.rs_val);
}
-};
\ No newline at end of file
+};
--- extensions/womir_circuit/cuda/src/alu.cu 2026-03-04 17:06:33.450505710 +0100
+++ extensions/womir_circuit/cuda/src/jump.cu 2026-03-04 17:11:29.698229744 +0100
@@ -1,35 +1,26 @@
-// Adapted from <openvm>/extensions/rv32im/circuit/cuda/src/alu.cu (namespace renames only)
-// Diff: https://gist.github.com/leonardoalt/09fd3d60bd571851bb656dc53cec0a4b#file-diff-alu-cu-diff
#include "launcher.cuh"
#include "primitives/buffer_view.cuh"
#include "primitives/trace_access.h"
#include "womir/constants.cuh"
-#include "womir/adapters/alu.cuh"
-#include "womir/cores/alu.cuh"
+#include "womir/adapters/jump.cuh"
+#include "womir/cores/jump.cuh"
-// Concrete type aliases for 32-bit
-using WomirBaseAluCoreRecord = BaseAluCoreRecord<RV32_REGISTER_NUM_LIMBS>;
-using WomirBaseAluCore = BaseAluCore<RV32_REGISTER_NUM_LIMBS>;
-template <typename T> using WomirBaseAluCoreCols = BaseAluCoreCols<T, RV32_REGISTER_NUM_LIMBS>;
-
-template <typename T> struct WomirBaseAluCols {
- WomirBaseAluAdapterCols<T, W32_REG_OPS, W32_REG_OPS> adapter;
- WomirBaseAluCoreCols<T> core;
+template <typename T> struct WomirJumpCols {
+ WomirJumpAdapterCols<T> adapter;
+ JumpCoreCols<T> core;
};
-struct WomirBaseAluRecord {
- WomirBaseAluAdapterRecord<W32_REG_OPS, W32_REG_OPS> adapter;
- WomirBaseAluCoreRecord core;
+struct WomirJumpRecord {
+ WomirJumpAdapterRecord adapter;
+ JumpCoreRecord core;
};
-__global__ void womir_alu_tracegen(
+__global__ void womir_jump_tracegen(
Fp *d_trace,
size_t height,
- DeviceBufferConstView<WomirBaseAluRecord> d_records,
+ DeviceBufferConstView<WomirJumpRecord> d_records,
uint32_t *d_range_checker_ptr,
size_t range_checker_bins,
- uint32_t *d_bitwise_lookup_ptr,
- size_t bitwise_num_bits,
uint32_t timestamp_max_bits
) {
uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -37,43 +28,38 @@
if (idx < d_records.len()) {
auto const &rec = d_records[idx];
- WomirBaseAluAdapter<W32_REG_OPS, W32_REG_OPS> adapter(
+ WomirJumpAdapter adapter(
VariableRangeChecker(d_range_checker_ptr, range_checker_bins),
- BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits),
timestamp_max_bits
);
adapter.fill_trace_row(row, rec.adapter);
- WomirBaseAluCore core(BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits));
- core.fill_trace_row(row.slice_from(COL_INDEX(WomirBaseAluCols, core)), rec.core);
+ JumpCore core;
+ core.fill_trace_row(row.slice_from(COL_INDEX(WomirJumpCols, core)), rec.core);
} else {
- row.fill_zero(0, sizeof(WomirBaseAluCols<uint8_t>));
+ row.fill_zero(0, sizeof(WomirJumpCols<uint8_t>));
}
}
-extern "C" int _womir_alu_tracegen(
+extern "C" int _womir_jump_tracegen(
Fp *d_trace,
size_t height,
size_t width,
- DeviceBufferConstView<WomirBaseAluRecord> d_records,
+ DeviceBufferConstView<WomirJumpRecord> d_records,
uint32_t *d_range_checker_ptr,
size_t range_checker_bins,
- uint32_t *d_bitwise_lookup_ptr,
- size_t bitwise_num_bits,
uint32_t timestamp_max_bits
) {
assert((height & (height - 1)) == 0);
assert(height >= d_records.len());
- assert(width == sizeof(WomirBaseAluCols<uint8_t>));
+ assert(width == sizeof(WomirJumpCols<uint8_t>));
auto [grid, block] = kernel_launch_params(height);
- womir_alu_tracegen<<<grid, block>>>(
+ womir_jump_tracegen<<<grid, block>>>(
d_trace,
height,
d_records,
d_range_checker_ptr,
range_checker_bins,
- d_bitwise_lookup_ptr,
- bitwise_num_bits,
timestamp_max_bits
);
return CHECK_KERNEL();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment