Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save leonardoalt/42492a8abcbbc86ea4959bf9b5a1237e to your computer and use it in GitHub Desktop.

Select an option

Save leonardoalt/42492a8abcbbc86ea4959bf9b5a1237e to your computer and use it in GitHub Desktop.
Diffs from the ALU to the Const32 CUDA tracegen implementations (`---` = ALU, `+++` = Const32)
--- extensions/womir_circuit/cuda/include/womir/adapters/alu.cuh 2026-03-04 16:22:21.368217432 +0100
+++ extensions/womir_circuit/cuda/include/womir/const32.cuh 2026-03-04 16:26:53.584878211 +0100
@@ -1,9 +1,9 @@
-// Adapted from <openvm>/extensions/rv32im/circuit/cuda/include/rv32im/adapters/alu.cuh
-// Main changes: adds frame pointer (fp) field, WomirExecutionState, fp_read_aux, timestamp +1 shift
-// Diff: https://gist.github.com/leonardoalt/09fd3d60bd571851bb656dc53cec0a4b#file-diff-adapters-alu-cuh-diff
+// CUDA tracegen for Const32 chip.
+// Unlike ALU chips, Const32 has no adapter+core split: it's a single unified structure.
#pragma once
-#include "primitives/execution.h"
+#include "primitives/constants.h"
+#include "primitives/histogram.cuh"
#include "primitives/trace_access.h"
#include "system/memory/controller.cuh"
#include "system/memory/offline_checker.cuh"
@@ -11,125 +11,88 @@
using namespace riscv;
-template <typename T, size_t NUM_READ_OPS, size_t NUM_WRITE_OPS>
-struct WomirBaseAluAdapterCols {
- WomirExecutionState<T> from_state;
- T rd_ptr;
- T rs1_ptr;
- T rs2; // Pointer if rs2 was a read, immediate value otherwise
- T rs2_as; // 1 if rs2 was a read, 0 if an immediate
- MemoryReadAuxCols<T> fp_read_aux;
- MemoryReadAuxCols<T> rs1_reads_aux[NUM_READ_OPS];
- MemoryReadAuxCols<T> rs2_reads_aux[NUM_READ_OPS];
- MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS> writes_aux[NUM_WRITE_OPS];
-};
-
-template <size_t NUM_READ_OPS, size_t NUM_WRITE_OPS>
-struct WomirBaseAluAdapterRecord {
+// Record layout must match Rust Const32Record exactly (repr(C)).
+struct Const32Record {
uint32_t from_pc;
uint32_t fp;
uint32_t from_timestamp;
uint32_t rd_ptr;
- uint32_t rs1_ptr;
- uint32_t rs2; // Pointer if rs2 was a read, immediate value otherwise
- uint8_t rs2_as; // 1 if rs2 was a read, 0 if an immediate
+ uint32_t imm;
MemoryReadAuxRecord fp_read_aux;
- MemoryReadAuxRecord rs1_reads_aux[NUM_READ_OPS];
- MemoryReadAuxRecord rs2_reads_aux[NUM_READ_OPS];
- MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS> writes_aux[NUM_WRITE_OPS];
+ MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS> writes_aux;
+};
+
+// Column layout must match Rust Const32AdapterAirCol exactly.
+template <typename T, size_t NUM_LIMBS = RV32_REGISTER_NUM_LIMBS>
+struct Const32Cols {
+ T is_valid;
+ WomirExecutionState<T> from_state;
+ T rd_ptr;
+ T imm_limbs[NUM_LIMBS];
+ MemoryReadAuxCols<T> fp_read_aux;
+ MemoryWriteAuxCols<T, NUM_LIMBS> write_aux;
};
-template <size_t NUM_READ_OPS, size_t NUM_WRITE_OPS>
-struct WomirBaseAluAdapter {
+struct Const32TraceFiller {
MemoryAuxColsFactory mem_helper;
BitwiseOperationLookup bitwise_lookup;
- template <typename T>
- using Cols = WomirBaseAluAdapterCols<T, NUM_READ_OPS, NUM_WRITE_OPS>;
- using Record = WomirBaseAluAdapterRecord<NUM_READ_OPS, NUM_WRITE_OPS>;
+ template <typename T, size_t NUM_LIMBS = RV32_REGISTER_NUM_LIMBS>
+ using Cols = Const32Cols<T, NUM_LIMBS>;
- __device__ WomirBaseAluAdapter(
+ __device__ Const32TraceFiller(
VariableRangeChecker range_checker,
BitwiseOperationLookup lookup,
uint32_t timestamp_max_bits
)
: mem_helper(range_checker, timestamp_max_bits), bitwise_lookup(lookup) {}
- __device__ void fill_trace_row(RowSlice row, Record record) {
+ __device__ void fill_trace_row(RowSlice row, Const32Record const &record) {
+ // is_valid
+ COL_WRITE_VALUE(row, Cols, is_valid, 1);
+
+ // from_state
COL_WRITE_VALUE(row, Cols, from_state.pc, record.from_pc);
COL_WRITE_VALUE(row, Cols, from_state.fp, record.fp);
COL_WRITE_VALUE(row, Cols, from_state.timestamp, record.from_timestamp);
+ // rd_ptr
COL_WRITE_VALUE(row, Cols, rd_ptr, record.rd_ptr);
- COL_WRITE_VALUE(row, Cols, rs1_ptr, record.rs1_ptr);
- COL_WRITE_VALUE(row, Cols, rs2, record.rs2);
- COL_WRITE_VALUE(row, Cols, rs2_as, record.rs2_as);
- // Read auxiliary for fp (at from_timestamp + 0)
+ // imm_limbs: decompose the 32-bit immediate into 8-bit limbs
+ uint32_t imm = record.imm;
+ constexpr uint32_t mask = (1u << RV32_CELL_BITS) - 1u;
+ uint8_t imm_limbs[RV32_REGISTER_NUM_LIMBS];
+#pragma unroll
+ for (size_t i = 0; i < RV32_REGISTER_NUM_LIMBS; i++) {
+ imm_limbs[i] = (imm >> (RV32_CELL_BITS * i)) & mask;
+ }
+ COL_WRITE_ARRAY(row, Cols, imm_limbs, imm_limbs);
+
+ // Range-check imm_limbs via bitwise lookup (pairs)
+ bitwise_lookup.add_range(imm_limbs[0], imm_limbs[1]);
+ bitwise_lookup.add_range(imm_limbs[2], imm_limbs[3]);
+
+ // fp_read_aux: fill timestamp proof for FP read at from_timestamp + 0
mem_helper.fill(
row.slice_from(COL_INDEX(Cols, fp_read_aux)),
record.fp_read_aux.prev_timestamp,
record.from_timestamp
);
- // rs1 reads (at from_timestamp + 1 + r for r in 0..NUM_READ_OPS)
- constexpr size_t read_aux_elem_size = sizeof(MemoryReadAuxCols<uint8_t>);
-#pragma unroll
- for (size_t r = 0; r < NUM_READ_OPS; r++) {
- mem_helper.fill(
- row.slice_from(
- offsetof(Cols<uint8_t>, rs1_reads_aux) + r * read_aux_elem_size
- ),
- record.rs1_reads_aux[r].prev_timestamp,
- record.from_timestamp + 1 + r
- );
- }
-
- // rs2: register read when rs2_as != 0, otherwise immediate.
- if (record.rs2_as != 0) {
-#pragma unroll
- for (size_t r = 0; r < NUM_READ_OPS; r++) {
- mem_helper.fill(
- row.slice_from(
- offsetof(Cols<uint8_t>, rs2_reads_aux) + r * read_aux_elem_size
- ),
- record.rs2_reads_aux[r].prev_timestamp,
- record.from_timestamp + 1 + NUM_READ_OPS + r
- );
- }
- } else {
-#pragma unroll
- for (size_t r = 0; r < NUM_READ_OPS; r++) {
- RowSlice rs2_aux = row.slice_from(
- offsetof(Cols<uint8_t>, rs2_reads_aux) + r * read_aux_elem_size
- );
-#pragma unroll
- for (size_t i = 0; i < read_aux_elem_size; i++) {
- rs2_aux.write(i, 0);
- }
- }
- uint32_t mask = (1u << RV32_CELL_BITS) - 1u;
- bitwise_lookup.add_range(record.rs2 & mask, (record.rs2 >> RV32_CELL_BITS) & mask);
- }
-
- // Writes (at from_timestamp + 1 + 2*NUM_READ_OPS + w for w in 0..NUM_WRITE_OPS)
- // Type alias avoids commas inside offsetof/sizeof macros (preprocessor limitation).
+ // write_aux: set prev_data and fill timestamp proof
+ // Write happens at from_timestamp + 1 (after FP read at from_timestamp + 0)
using WriteAuxCols = MemoryWriteAuxCols<uint8_t, RV32_REGISTER_NUM_LIMBS>;
- constexpr size_t write_aux_elem_size = sizeof(WriteAuxCols);
- constexpr size_t prev_data_offset = offsetof(WriteAuxCols, prev_data);
-#pragma unroll
- for (size_t w = 0; w < NUM_WRITE_OPS; w++) {
- size_t base = offsetof(Cols<uint8_t>, writes_aux) + w * write_aux_elem_size;
- row.write_array(
- base + prev_data_offset,
- RV32_REGISTER_NUM_LIMBS,
- record.writes_aux[w].prev_data
- );
- mem_helper.fill(
- row.slice_from(base),
- record.writes_aux[w].prev_timestamp,
- record.from_timestamp + 1 + 2 * NUM_READ_OPS + w
- );
- }
+ size_t write_base = COL_INDEX(Cols, write_aux);
+ row.write_array(
+ write_base + offsetof(WriteAuxCols, prev_data),
+ RV32_REGISTER_NUM_LIMBS,
+ record.writes_aux.prev_data
+ );
+ mem_helper.fill(
+ row.slice_from(write_base),
+ record.writes_aux.prev_timestamp,
+ record.from_timestamp + 1
+ );
}
};
--- extensions/womir_circuit/cuda/src/alu.cu 2026-03-04 16:22:21.368217432 +0100
+++ extensions/womir_circuit/cuda/src/const32.cu 2026-03-04 16:27:04.808944518 +0100
@@ -1,31 +1,12 @@
-// Adapted from <openvm>/extensions/rv32im/circuit/cuda/src/alu.cu (namespace renames only)
-// Diff: https://gist.github.com/leonardoalt/09fd3d60bd571851bb656dc53cec0a4b#file-diff-alu-cu-diff
#include "launcher.cuh"
#include "primitives/buffer_view.cuh"
#include "primitives/trace_access.h"
-#include "womir/constants.cuh"
-#include "womir/adapters/alu.cuh"
-#include "womir/cores/alu.cuh"
+#include "womir/const32.cuh"
-// Concrete type aliases for 32-bit
-using WomirBaseAluCoreRecord = BaseAluCoreRecord<RV32_REGISTER_NUM_LIMBS>;
-using WomirBaseAluCore = BaseAluCore<RV32_REGISTER_NUM_LIMBS>;
-template <typename T> using WomirBaseAluCoreCols = BaseAluCoreCols<T, RV32_REGISTER_NUM_LIMBS>;
-
-template <typename T> struct WomirBaseAluCols {
- WomirBaseAluAdapterCols<T, W32_REG_OPS, W32_REG_OPS> adapter;
- WomirBaseAluCoreCols<T> core;
-};
-
-struct WomirBaseAluRecord {
- WomirBaseAluAdapterRecord<W32_REG_OPS, W32_REG_OPS> adapter;
- WomirBaseAluCoreRecord core;
-};
-
-__global__ void womir_alu_tracegen(
+__global__ void womir_const32_tracegen(
Fp *d_trace,
size_t height,
- DeviceBufferConstView<WomirBaseAluRecord> d_records,
+ DeviceBufferConstView<Const32Record> d_records,
uint32_t *d_range_checker_ptr,
size_t range_checker_bins,
uint32_t *d_bitwise_lookup_ptr,
@@ -37,25 +18,22 @@
if (idx < d_records.len()) {
auto const &rec = d_records[idx];
- WomirBaseAluAdapter<W32_REG_OPS, W32_REG_OPS> adapter(
+ Const32TraceFiller filler(
VariableRangeChecker(d_range_checker_ptr, range_checker_bins),
BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits),
timestamp_max_bits
);
- adapter.fill_trace_row(row, rec.adapter);
-
- WomirBaseAluCore core(BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits));
- core.fill_trace_row(row.slice_from(COL_INDEX(WomirBaseAluCols, core)), rec.core);
+ filler.fill_trace_row(row, rec);
} else {
- row.fill_zero(0, sizeof(WomirBaseAluCols<uint8_t>));
+ row.fill_zero(0, sizeof(Const32Cols<uint8_t>));
}
}
-extern "C" int _womir_alu_tracegen(
+extern "C" int _womir_const32_tracegen(
Fp *d_trace,
size_t height,
size_t width,
- DeviceBufferConstView<WomirBaseAluRecord> d_records,
+ DeviceBufferConstView<Const32Record> d_records,
uint32_t *d_range_checker_ptr,
size_t range_checker_bins,
uint32_t *d_bitwise_lookup_ptr,
@@ -64,9 +42,9 @@
) {
assert((height & (height - 1)) == 0);
assert(height >= d_records.len());
- assert(width == sizeof(WomirBaseAluCols<uint8_t>));
+ assert(width == sizeof(Const32Cols<uint8_t>));
auto [grid, block] = kernel_launch_params(height);
- womir_alu_tracegen<<<grid, block>>>(
+ womir_const32_tracegen<<<grid, block>>>(
d_trace,
height,
d_records,
--- extensions/womir_circuit/src/base_alu/cuda.rs 2026-03-04 16:22:21.369217438 +0100
+++ extensions/womir_circuit/src/const32/cuda.rs 2026-03-04 16:34:52.810614168 +0100
@@ -1,5 +1,3 @@
-// Adapted from <openvm>/extensions/rv32im/circuit/src/base_alu/cuda.rs (import paths only)
-// Diff: https://gist.github.com/leonardoalt/09fd3d60bd571851bb656dc53cec0a4b#file-diff-base-alu-cuda-rs-diff
use std::{mem::size_of, sync::Arc};
use derive_new::new;
@@ -11,91 +9,37 @@
base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
};
use openvm_cuda_common::copy::MemCopyH2D;
-use openvm_rv32im_circuit::BaseAluCoreCols;
use openvm_stark_backend::{Chip, prover::types::AirProvingContext};
use crate::{
- adapters::{
- BaseAluAdapterCols, BaseAluAdapterRecord, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
- Rv32BaseAluAdapterCols, Rv32BaseAluAdapterRecord, W64_NUM_LIMBS, W64_REG_OPS,
- },
- cuda_abi::{alu_cuda, alu64_cuda},
+ Const32Record, adapters::RV32_CELL_BITS, const32::air::Const32AdapterAirCol,
+ cuda_abi::const32_cuda,
};
-use openvm_rv32im_circuit::BaseAluCoreRecord;
-
-#[derive(new)]
-pub struct Rv32BaseAluChipGpu {
- pub range_checker: Arc<VariableRangeCheckerChipGPU>,
- pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
- pub timestamp_max_bits: usize,
-}
-
-impl Chip<DenseRecordArena, GpuBackend> for Rv32BaseAluChipGpu {
- fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
- const RECORD_SIZE: usize = size_of::<(
- Rv32BaseAluAdapterRecord,
- BaseAluCoreRecord<RV32_REGISTER_NUM_LIMBS>,
- )>();
- let records = arena.allocated();
- if records.is_empty() {
- return get_empty_air_proving_ctx::<GpuBackend>();
- }
- debug_assert_eq!(records.len() % RECORD_SIZE, 0);
-
- let trace_width = BaseAluCoreCols::<F, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>::width()
- + Rv32BaseAluAdapterCols::<F>::width();
- let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);
-
- let d_records = records.to_device().unwrap();
- let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);
-
- unsafe {
- alu_cuda::tracegen(
- d_trace.buffer(),
- trace_height,
- &d_records,
- &self.range_checker.count,
- self.range_checker.count.len(),
- &self.bitwise_lookup.count,
- RV32_CELL_BITS,
- self.timestamp_max_bits as u32,
- )
- .unwrap();
- }
-
- AirProvingContext::simple_no_pis(d_trace)
- }
-}
-
#[derive(new)]
-pub struct BaseAlu64ChipGpu {
+pub struct Const32ChipGpu {
pub range_checker: Arc<VariableRangeCheckerChipGPU>,
pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
pub timestamp_max_bits: usize,
}
-impl Chip<DenseRecordArena, GpuBackend> for BaseAlu64ChipGpu {
+impl Chip<DenseRecordArena, GpuBackend> for Const32ChipGpu {
fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
- const RECORD_SIZE: usize = size_of::<(
- BaseAluAdapterRecord<W64_REG_OPS>,
- BaseAluCoreRecord<W64_NUM_LIMBS>,
- )>();
+ const RECORD_SIZE: usize = size_of::<Const32Record>();
let records = arena.allocated();
if records.is_empty() {
return get_empty_air_proving_ctx::<GpuBackend>();
}
debug_assert_eq!(records.len() % RECORD_SIZE, 0);
- let trace_width = BaseAluCoreCols::<F, W64_NUM_LIMBS, RV32_CELL_BITS>::width()
- + BaseAluAdapterCols::<F, W64_REG_OPS>::width();
+ let trace_width = Const32AdapterAirCol::<F, 4>::width();
let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);
let d_records = records.to_device().unwrap();
let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);
unsafe {
- alu64_cuda::tracegen(
+ const32_cuda::tracegen(
d_trace.buffer(),
trace_height,
&d_records,
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment