Created March 4, 2026 15:37
-
-
Save leonardoalt/42492a8abcbbc86ea4959bf9b5a1237e to your computer and use it in GitHub Desktop.
Diffs between the ALU (base, `---`) and Const32 (modified, `+++`) CUDA tracegen implementations
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| --- extensions/womir_circuit/cuda/include/womir/adapters/alu.cuh 2026-03-04 16:22:21.368217432 +0100 | |
| +++ extensions/womir_circuit/cuda/include/womir/const32.cuh 2026-03-04 16:26:53.584878211 +0100 | |
| @@ -1,9 +1,9 @@ | |
| -// Adapted from <openvm>/extensions/rv32im/circuit/cuda/include/rv32im/adapters/alu.cuh | |
| -// Main changes: adds frame pointer (fp) field, WomirExecutionState, fp_read_aux, timestamp +1 shift | |
| -// Diff: https://gist.github.com/leonardoalt/09fd3d60bd571851bb656dc53cec0a4b#file-diff-adapters-alu-cuh-diff | |
| +// CUDA tracegen for Const32 chip. | |
| +// Unlike ALU chips, Const32 has no adapter+core split: it's a single unified structure. | |
| #pragma once | |
| -#include "primitives/execution.h" | |
| +#include "primitives/constants.h" | |
| +#include "primitives/histogram.cuh" | |
| #include "primitives/trace_access.h" | |
| #include "system/memory/controller.cuh" | |
| #include "system/memory/offline_checker.cuh" | |
| @@ -11,125 +11,88 @@ | |
| using namespace riscv; | |
| -template <typename T, size_t NUM_READ_OPS, size_t NUM_WRITE_OPS> | |
| -struct WomirBaseAluAdapterCols { | |
| - WomirExecutionState<T> from_state; | |
| - T rd_ptr; | |
| - T rs1_ptr; | |
| - T rs2; // Pointer if rs2 was a read, immediate value otherwise | |
| - T rs2_as; // 1 if rs2 was a read, 0 if an immediate | |
| - MemoryReadAuxCols<T> fp_read_aux; | |
| - MemoryReadAuxCols<T> rs1_reads_aux[NUM_READ_OPS]; | |
| - MemoryReadAuxCols<T> rs2_reads_aux[NUM_READ_OPS]; | |
| - MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS> writes_aux[NUM_WRITE_OPS]; | |
| -}; | |
| - | |
| -template <size_t NUM_READ_OPS, size_t NUM_WRITE_OPS> | |
| -struct WomirBaseAluAdapterRecord { | |
| +// Record layout must match Rust Const32Record exactly (repr(C)). | |
| +struct Const32Record { | |
| uint32_t from_pc; | |
| uint32_t fp; | |
| uint32_t from_timestamp; | |
| uint32_t rd_ptr; | |
| - uint32_t rs1_ptr; | |
| - uint32_t rs2; // Pointer if rs2 was a read, immediate value otherwise | |
| - uint8_t rs2_as; // 1 if rs2 was a read, 0 if an immediate | |
| + uint32_t imm; | |
| MemoryReadAuxRecord fp_read_aux; | |
| - MemoryReadAuxRecord rs1_reads_aux[NUM_READ_OPS]; | |
| - MemoryReadAuxRecord rs2_reads_aux[NUM_READ_OPS]; | |
| - MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS> writes_aux[NUM_WRITE_OPS]; | |
| + MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS> writes_aux; | |
| +}; | |
| + | |
| +// Column layout must match Rust Const32AdapterAirCol exactly. | |
| +template <typename T, size_t NUM_LIMBS = RV32_REGISTER_NUM_LIMBS> | |
| +struct Const32Cols { | |
| + T is_valid; | |
| + WomirExecutionState<T> from_state; | |
| + T rd_ptr; | |
| + T imm_limbs[NUM_LIMBS]; | |
| + MemoryReadAuxCols<T> fp_read_aux; | |
| + MemoryWriteAuxCols<T, NUM_LIMBS> write_aux; | |
| }; | |
| -template <size_t NUM_READ_OPS, size_t NUM_WRITE_OPS> | |
| -struct WomirBaseAluAdapter { | |
| +struct Const32TraceFiller { | |
| MemoryAuxColsFactory mem_helper; | |
| BitwiseOperationLookup bitwise_lookup; | |
| - template <typename T> | |
| - using Cols = WomirBaseAluAdapterCols<T, NUM_READ_OPS, NUM_WRITE_OPS>; | |
| - using Record = WomirBaseAluAdapterRecord<NUM_READ_OPS, NUM_WRITE_OPS>; | |
| + template <typename T, size_t NUM_LIMBS = RV32_REGISTER_NUM_LIMBS> | |
| + using Cols = Const32Cols<T, NUM_LIMBS>; | |
| - __device__ WomirBaseAluAdapter( | |
| + __device__ Const32TraceFiller( | |
| VariableRangeChecker range_checker, | |
| BitwiseOperationLookup lookup, | |
| uint32_t timestamp_max_bits | |
| ) | |
| : mem_helper(range_checker, timestamp_max_bits), bitwise_lookup(lookup) {} | |
| - __device__ void fill_trace_row(RowSlice row, Record record) { | |
| + __device__ void fill_trace_row(RowSlice row, Const32Record const &record) { | |
| + // is_valid | |
| + COL_WRITE_VALUE(row, Cols, is_valid, 1); | |
| + | |
| + // from_state | |
| COL_WRITE_VALUE(row, Cols, from_state.pc, record.from_pc); | |
| COL_WRITE_VALUE(row, Cols, from_state.fp, record.fp); | |
| COL_WRITE_VALUE(row, Cols, from_state.timestamp, record.from_timestamp); | |
| + // rd_ptr | |
| COL_WRITE_VALUE(row, Cols, rd_ptr, record.rd_ptr); | |
| - COL_WRITE_VALUE(row, Cols, rs1_ptr, record.rs1_ptr); | |
| - COL_WRITE_VALUE(row, Cols, rs2, record.rs2); | |
| - COL_WRITE_VALUE(row, Cols, rs2_as, record.rs2_as); | |
| - // Read auxiliary for fp (at from_timestamp + 0) | |
| + // imm_limbs: decompose the 32-bit immediate into 8-bit limbs | |
| + uint32_t imm = record.imm; | |
| + constexpr uint32_t mask = (1u << RV32_CELL_BITS) - 1u; | |
| + uint8_t imm_limbs[RV32_REGISTER_NUM_LIMBS]; | |
| +#pragma unroll | |
| + for (size_t i = 0; i < RV32_REGISTER_NUM_LIMBS; i++) { | |
| + imm_limbs[i] = (imm >> (RV32_CELL_BITS * i)) & mask; | |
| + } | |
| + COL_WRITE_ARRAY(row, Cols, imm_limbs, imm_limbs); | |
| + | |
| + // Range-check imm_limbs via bitwise lookup (pairs) | |
| + bitwise_lookup.add_range(imm_limbs[0], imm_limbs[1]); | |
| + bitwise_lookup.add_range(imm_limbs[2], imm_limbs[3]); | |
| + | |
| + // fp_read_aux: fill timestamp proof for FP read at from_timestamp + 0 | |
| mem_helper.fill( | |
| row.slice_from(COL_INDEX(Cols, fp_read_aux)), | |
| record.fp_read_aux.prev_timestamp, | |
| record.from_timestamp | |
| ); | |
| - // rs1 reads (at from_timestamp + 1 + r for r in 0..NUM_READ_OPS) | |
| - constexpr size_t read_aux_elem_size = sizeof(MemoryReadAuxCols<uint8_t>); | |
| -#pragma unroll | |
| - for (size_t r = 0; r < NUM_READ_OPS; r++) { | |
| - mem_helper.fill( | |
| - row.slice_from( | |
| - offsetof(Cols<uint8_t>, rs1_reads_aux) + r * read_aux_elem_size | |
| - ), | |
| - record.rs1_reads_aux[r].prev_timestamp, | |
| - record.from_timestamp + 1 + r | |
| - ); | |
| - } | |
| - | |
| - // rs2: register read when rs2_as != 0, otherwise immediate. | |
| - if (record.rs2_as != 0) { | |
| -#pragma unroll | |
| - for (size_t r = 0; r < NUM_READ_OPS; r++) { | |
| - mem_helper.fill( | |
| - row.slice_from( | |
| - offsetof(Cols<uint8_t>, rs2_reads_aux) + r * read_aux_elem_size | |
| - ), | |
| - record.rs2_reads_aux[r].prev_timestamp, | |
| - record.from_timestamp + 1 + NUM_READ_OPS + r | |
| - ); | |
| - } | |
| - } else { | |
| -#pragma unroll | |
| - for (size_t r = 0; r < NUM_READ_OPS; r++) { | |
| - RowSlice rs2_aux = row.slice_from( | |
| - offsetof(Cols<uint8_t>, rs2_reads_aux) + r * read_aux_elem_size | |
| - ); | |
| -#pragma unroll | |
| - for (size_t i = 0; i < read_aux_elem_size; i++) { | |
| - rs2_aux.write(i, 0); | |
| - } | |
| - } | |
| - uint32_t mask = (1u << RV32_CELL_BITS) - 1u; | |
| - bitwise_lookup.add_range(record.rs2 & mask, (record.rs2 >> RV32_CELL_BITS) & mask); | |
| - } | |
| - | |
| - // Writes (at from_timestamp + 1 + 2*NUM_READ_OPS + w for w in 0..NUM_WRITE_OPS) | |
| - // Type alias avoids commas inside offsetof/sizeof macros (preprocessor limitation). | |
| + // write_aux: set prev_data and fill timestamp proof | |
| + // Write happens at from_timestamp + 1 (after FP read at from_timestamp + 0) | |
| using WriteAuxCols = MemoryWriteAuxCols<uint8_t, RV32_REGISTER_NUM_LIMBS>; | |
| - constexpr size_t write_aux_elem_size = sizeof(WriteAuxCols); | |
| - constexpr size_t prev_data_offset = offsetof(WriteAuxCols, prev_data); | |
| -#pragma unroll | |
| - for (size_t w = 0; w < NUM_WRITE_OPS; w++) { | |
| - size_t base = offsetof(Cols<uint8_t>, writes_aux) + w * write_aux_elem_size; | |
| - row.write_array( | |
| - base + prev_data_offset, | |
| - RV32_REGISTER_NUM_LIMBS, | |
| - record.writes_aux[w].prev_data | |
| - ); | |
| - mem_helper.fill( | |
| - row.slice_from(base), | |
| - record.writes_aux[w].prev_timestamp, | |
| - record.from_timestamp + 1 + 2 * NUM_READ_OPS + w | |
| - ); | |
| - } | |
| + size_t write_base = COL_INDEX(Cols, write_aux); | |
| + row.write_array( | |
| + write_base + offsetof(WriteAuxCols, prev_data), | |
| + RV32_REGISTER_NUM_LIMBS, | |
| + record.writes_aux.prev_data | |
| + ); | |
| + mem_helper.fill( | |
| + row.slice_from(write_base), | |
| + record.writes_aux.prev_timestamp, | |
| + record.from_timestamp + 1 | |
| + ); | |
| } | |
| }; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| --- extensions/womir_circuit/cuda/src/alu.cu 2026-03-04 16:22:21.368217432 +0100 | |
| +++ extensions/womir_circuit/cuda/src/const32.cu 2026-03-04 16:27:04.808944518 +0100 | |
| @@ -1,31 +1,12 @@ | |
| -// Adapted from <openvm>/extensions/rv32im/circuit/cuda/src/alu.cu (namespace renames only) | |
| -// Diff: https://gist.github.com/leonardoalt/09fd3d60bd571851bb656dc53cec0a4b#file-diff-alu-cu-diff | |
| #include "launcher.cuh" | |
| #include "primitives/buffer_view.cuh" | |
| #include "primitives/trace_access.h" | |
| -#include "womir/constants.cuh" | |
| -#include "womir/adapters/alu.cuh" | |
| -#include "womir/cores/alu.cuh" | |
| +#include "womir/const32.cuh" | |
| -// Concrete type aliases for 32-bit | |
| -using WomirBaseAluCoreRecord = BaseAluCoreRecord<RV32_REGISTER_NUM_LIMBS>; | |
| -using WomirBaseAluCore = BaseAluCore<RV32_REGISTER_NUM_LIMBS>; | |
| -template <typename T> using WomirBaseAluCoreCols = BaseAluCoreCols<T, RV32_REGISTER_NUM_LIMBS>; | |
| - | |
| -template <typename T> struct WomirBaseAluCols { | |
| - WomirBaseAluAdapterCols<T, W32_REG_OPS, W32_REG_OPS> adapter; | |
| - WomirBaseAluCoreCols<T> core; | |
| -}; | |
| - | |
| -struct WomirBaseAluRecord { | |
| - WomirBaseAluAdapterRecord<W32_REG_OPS, W32_REG_OPS> adapter; | |
| - WomirBaseAluCoreRecord core; | |
| -}; | |
| - | |
| -__global__ void womir_alu_tracegen( | |
| +__global__ void womir_const32_tracegen( | |
| Fp *d_trace, | |
| size_t height, | |
| - DeviceBufferConstView<WomirBaseAluRecord> d_records, | |
| + DeviceBufferConstView<Const32Record> d_records, | |
| uint32_t *d_range_checker_ptr, | |
| size_t range_checker_bins, | |
| uint32_t *d_bitwise_lookup_ptr, | |
| @@ -37,25 +18,22 @@ | |
| if (idx < d_records.len()) { | |
| auto const &rec = d_records[idx]; | |
| - WomirBaseAluAdapter<W32_REG_OPS, W32_REG_OPS> adapter( | |
| + Const32TraceFiller filler( | |
| VariableRangeChecker(d_range_checker_ptr, range_checker_bins), | |
| BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits), | |
| timestamp_max_bits | |
| ); | |
| - adapter.fill_trace_row(row, rec.adapter); | |
| - | |
| - WomirBaseAluCore core(BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits)); | |
| - core.fill_trace_row(row.slice_from(COL_INDEX(WomirBaseAluCols, core)), rec.core); | |
| + filler.fill_trace_row(row, rec); | |
| } else { | |
| - row.fill_zero(0, sizeof(WomirBaseAluCols<uint8_t>)); | |
| + row.fill_zero(0, sizeof(Const32Cols<uint8_t>)); | |
| } | |
| } | |
| -extern "C" int _womir_alu_tracegen( | |
| +extern "C" int _womir_const32_tracegen( | |
| Fp *d_trace, | |
| size_t height, | |
| size_t width, | |
| - DeviceBufferConstView<WomirBaseAluRecord> d_records, | |
| + DeviceBufferConstView<Const32Record> d_records, | |
| uint32_t *d_range_checker_ptr, | |
| size_t range_checker_bins, | |
| uint32_t *d_bitwise_lookup_ptr, | |
| @@ -64,9 +42,9 @@ | |
| ) { | |
| assert((height & (height - 1)) == 0); | |
| assert(height >= d_records.len()); | |
| - assert(width == sizeof(WomirBaseAluCols<uint8_t>)); | |
| + assert(width == sizeof(Const32Cols<uint8_t>)); | |
| auto [grid, block] = kernel_launch_params(height); | |
| - womir_alu_tracegen<<<grid, block>>>( | |
| + womir_const32_tracegen<<<grid, block>>>( | |
| d_trace, | |
| height, | |
| d_records, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| --- extensions/womir_circuit/src/base_alu/cuda.rs 2026-03-04 16:22:21.369217438 +0100 | |
| +++ extensions/womir_circuit/src/const32/cuda.rs 2026-03-04 16:34:52.810614168 +0100 | |
| @@ -1,5 +1,3 @@ | |
| -// Adapted from <openvm>/extensions/rv32im/circuit/src/base_alu/cuda.rs (import paths only) | |
| -// Diff: https://gist.github.com/leonardoalt/09fd3d60bd571851bb656dc53cec0a4b#file-diff-base-alu-cuda-rs-diff | |
| use std::{mem::size_of, sync::Arc}; | |
| use derive_new::new; | |
| @@ -11,91 +9,37 @@ | |
| base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F, | |
| }; | |
| use openvm_cuda_common::copy::MemCopyH2D; | |
| -use openvm_rv32im_circuit::BaseAluCoreCols; | |
| use openvm_stark_backend::{Chip, prover::types::AirProvingContext}; | |
| use crate::{ | |
| - adapters::{ | |
| - BaseAluAdapterCols, BaseAluAdapterRecord, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, | |
| - Rv32BaseAluAdapterCols, Rv32BaseAluAdapterRecord, W64_NUM_LIMBS, W64_REG_OPS, | |
| - }, | |
| - cuda_abi::{alu_cuda, alu64_cuda}, | |
| + Const32Record, adapters::RV32_CELL_BITS, const32::air::Const32AdapterAirCol, | |
| + cuda_abi::const32_cuda, | |
| }; | |
| -use openvm_rv32im_circuit::BaseAluCoreRecord; | |
| - | |
| -#[derive(new)] | |
| -pub struct Rv32BaseAluChipGpu { | |
| - pub range_checker: Arc<VariableRangeCheckerChipGPU>, | |
| - pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>, | |
| - pub timestamp_max_bits: usize, | |
| -} | |
| - | |
| -impl Chip<DenseRecordArena, GpuBackend> for Rv32BaseAluChipGpu { | |
| - fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> { | |
| - const RECORD_SIZE: usize = size_of::<( | |
| - Rv32BaseAluAdapterRecord, | |
| - BaseAluCoreRecord<RV32_REGISTER_NUM_LIMBS>, | |
| - )>(); | |
| - let records = arena.allocated(); | |
| - if records.is_empty() { | |
| - return get_empty_air_proving_ctx::<GpuBackend>(); | |
| - } | |
| - debug_assert_eq!(records.len() % RECORD_SIZE, 0); | |
| - | |
| - let trace_width = BaseAluCoreCols::<F, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>::width() | |
| - + Rv32BaseAluAdapterCols::<F>::width(); | |
| - let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE); | |
| - | |
| - let d_records = records.to_device().unwrap(); | |
| - let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width); | |
| - | |
| - unsafe { | |
| - alu_cuda::tracegen( | |
| - d_trace.buffer(), | |
| - trace_height, | |
| - &d_records, | |
| - &self.range_checker.count, | |
| - self.range_checker.count.len(), | |
| - &self.bitwise_lookup.count, | |
| - RV32_CELL_BITS, | |
| - self.timestamp_max_bits as u32, | |
| - ) | |
| - .unwrap(); | |
| - } | |
| - | |
| - AirProvingContext::simple_no_pis(d_trace) | |
| - } | |
| -} | |
| - | |
| #[derive(new)] | |
| -pub struct BaseAlu64ChipGpu { | |
| +pub struct Const32ChipGpu { | |
| pub range_checker: Arc<VariableRangeCheckerChipGPU>, | |
| pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>, | |
| pub timestamp_max_bits: usize, | |
| } | |
| -impl Chip<DenseRecordArena, GpuBackend> for BaseAlu64ChipGpu { | |
| +impl Chip<DenseRecordArena, GpuBackend> for Const32ChipGpu { | |
| fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> { | |
| - const RECORD_SIZE: usize = size_of::<( | |
| - BaseAluAdapterRecord<W64_REG_OPS>, | |
| - BaseAluCoreRecord<W64_NUM_LIMBS>, | |
| - )>(); | |
| + const RECORD_SIZE: usize = size_of::<Const32Record>(); | |
| let records = arena.allocated(); | |
| if records.is_empty() { | |
| return get_empty_air_proving_ctx::<GpuBackend>(); | |
| } | |
| debug_assert_eq!(records.len() % RECORD_SIZE, 0); | |
| - let trace_width = BaseAluCoreCols::<F, W64_NUM_LIMBS, RV32_CELL_BITS>::width() | |
| - + BaseAluAdapterCols::<F, W64_REG_OPS>::width(); | |
| + let trace_width = Const32AdapterAirCol::<F, 4>::width(); | |
| let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE); | |
| let d_records = records.to_device().unwrap(); | |
| let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width); | |
| unsafe { | |
| - alu64_cuda::tracegen( | |
| + const32_cuda::tracegen( | |
| d_trace.buffer(), | |
| trace_height, | |
| &d_records, |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.