GPU Acceleration
Overview
stwo-ml uses CUDA for all compute-intensive operations in the proving pipeline. GPU acceleration spans four layers:
- Sumcheck kernels — Round polynomial evaluation + MLE folding (inner loop)
- Fused MLE restrict — Direct matrix-to-restricted-vector without intermediate allocation
- Forward pass ops — MatMul, element-wise add/mul/relu for model execution
- Multi-GPU — Distributed chunk proving across multiple GPUs
All GPU code is gated behind feature = "cuda-runtime" (requires CUDA 12.4+).
CUDA Kernel Architecture
Kernel Groups
| Group | Kernels | Used By |
|---|---|---|
| Sumcheck | sumcheck_round_kernel, sumcheck_reduce_kernel, mle_fold_kernel | MatMul sumcheck proving |
| MLE Restrict | m31_restrict_rows_kernel, m31_restrict_cols_kernel | MatMul sumcheck setup |
| Forward Pass | m31_gemv_kernel, m31_gemm_kernel, m31_add_kernel, m31_mul_kernel, m31_relu_kernel | Model execution |
| LogUp | logup_denominator_kernel, logup_3way_round_kernel, logup_4way_reduce_kernel, logup_3way_fold_kernel | Activation/dual-operand sumcheck |
| GKR | combine_blocks_kernel, evaluate_mle_kernel | SIMD block batching |
Kernels are compiled once via NVRTC and cached globally. Core sumcheck kernels compile eagerly at init; forward, restrict, and LogUp kernels compile lazily on first use.
QM31 Field Arithmetic in CUDA
All kernels operate on M31 (p = 2^31 - 1) and QM31 (degree-4 extension) field elements. QM31 is represented as 4 × u32:
[a0, a1, a2, a3] ↔ (a0 + u·a1) + i·(a2 + u·a3)
where u² = 2, i² = u + 2, all arithmetic mod 2^31 - 1
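The tower arithmetic above can be sketched in plain Rust. This is an illustrative reimplementation over `u64` residues, not the stwo `SecureField` type; only the reduction rules u² = 2 and i² = u + 2 are taken from the text:

```rust
// Illustrative M31/CM31/QM31 multiplication (not the stwo types).
const P: u64 = (1 << 31) - 1;

fn m31_mul(a: u64, b: u64) -> u64 { (a * b) % P }
fn m31_add(a: u64, b: u64) -> u64 { (a + b) % P }

/// CM31 = M31[u] / (u^2 - 2): (a0 + u*a1)(b0 + u*b1) = (a0b0 + 2*a1b1) + u*(a0b1 + a1b0)
fn cm31_mul(a: [u64; 2], b: [u64; 2]) -> [u64; 2] {
    [
        m31_add(m31_mul(a[0], b[0]), m31_mul(2, m31_mul(a[1], b[1]))),
        m31_add(m31_mul(a[0], b[1]), m31_mul(a[1], b[0])),
    ]
}

fn cm31_add(a: [u64; 2], b: [u64; 2]) -> [u64; 2] {
    [m31_add(a[0], b[0]), m31_add(a[1], b[1])]
}

/// QM31 = CM31[i] / (i^2 - (u + 2)): (x0 + i*x1)(y0 + i*y1)
/// = (x0*y0 + (u+2)*x1*y1) + i*(x0*y1 + x1*y0)
fn qm31_mul(x: [u64; 4], y: [u64; 4]) -> [u64; 4] {
    let (x0, x1) = ([x[0], x[1]], [x[2], x[3]]);
    let (y0, y1) = ([y[0], y[1]], [y[2], y[3]]);
    let t = cm31_mul([2, 1], cm31_mul(x1, y1)); // (u + 2) * x1 * y1
    let lo = cm31_add(cm31_mul(x0, y0), t);
    let hi = cm31_add(cm31_mul(x0, y1), cm31_mul(x1, y0));
    [lo[0], lo[1], hi[0], hi[1]]
}
```

A quick sanity check: squaring u = [0,1,0,0] yields [2,0,0,0], and squaring i = [0,0,1,0] yields [2,1,0,0] (i.e., u + 2), matching the reduction rules.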
Conversion between Rust SecureField and CUDA u32 arrays:
fn secure_field_to_u32s(sf: SecureField) -> [u32; 4] {
[sf.0.0.0, sf.0.1.0, sf.1.0.0, sf.1.1.0]
}
fn u32s_to_secure_field(u: &[u32; 4]) -> SecureField {
QM31(CM31(M31(u[0]), M31(u[1])), CM31(M31(u[2]), M31(u[3])))
}
CUDA kernels use _r suffix variants of QM31 arithmetic to avoid name collisions between kernel groups:
// QM31 multiply: (a0 + a1·u) + (a2 + a3·u)·i
// where u² = 2, i² = u + 2
__device__ void qm31_mul_r(uint32_t* out, const uint32_t* a, const uint32_t* b) {
// Karatsuba over CM31 components
// ...
}
Sumcheck Round Kernel
The core proving kernel computes three QM31 sums in parallel for one round of the degree-2 sumcheck.
extern "C" __global__ void sumcheck_round_kernel(
const uint32_t* __restrict__ f_a, // QM31 array, 4 u32 per element
const uint32_t* __restrict__ f_b, // QM31 array, 4 u32 per element
uint32_t* __restrict__ block_s0, // Per-block partial sum at t=0
uint32_t* __restrict__ block_s1, // Per-block partial sum at t=1
uint32_t* __restrict__ block_s2, // Per-block partial sum at t=2
uint32_t mid // Half-length (n_points / 2)
)
Thread/Block configuration:
- Block size: 256 threads (fixed)
- Grid size: ceil(mid / 256) blocks
- Shared memory: 3 × 256 × 4 × 4 = 12 KB per block
Computation (for each pair i ∈ [0, mid)):
s0 += f_a[i] × f_b[i] // p(0)
s1 += f_a[mid + i] × f_b[mid + i] // p(1)
s2 += (2·f_a[mid+i] - f_a[i]) × (2·f_b[mid+i] - f_b[i]) // p(2)
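A CPU reference for this computation, written over plain M31 scalars for brevity (the kernel does the same arithmetic component-wise in QM31):

```rust
// Reference for one sumcheck round: p(t) = Σ_i f_a(t,i) * f_b(t,i), where
// f(t,i) = f[i] + t*(f[mid+i] - f[i]) interpolates the round variable.
// At t = 2 this gives 2*f[mid+i] - f[i], matching the kernel's s2 term.
const P: u64 = (1 << 31) - 1;

fn round_poly(f_a: &[u64], f_b: &[u64]) -> (u64, u64, u64) {
    let mid = f_a.len() / 2;
    let (mut s0, mut s1, mut s2) = (0u64, 0u64, 0u64);
    for i in 0..mid {
        let (a0, a1) = (f_a[i], f_a[mid + i]);
        let (b0, b1) = (f_b[i], f_b[mid + i]);
        s0 = (s0 + a0 * b0) % P;              // p(0)
        s1 = (s1 + a1 * b1) % P;              // p(1)
        let a2 = (2 * a1 + P - a0) % P;       // f_a at t = 2
        let b2 = (2 * b1 + P - b0) % P;       // f_b at t = 2
        s2 = (s2 + a2 * b2) % P;              // p(2)
    }
    (s0, s1, s2)
}
```

The sumcheck consistency condition p(0) + p(1) = Σ_x f_a(x)·f_b(x) holds by construction, which is what the verifier checks each round.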
Reduction pattern:
- Each thread accumulates its subset of pairs into thread-local QM31 accumulators
- Block-level tree reduction via shared memory (stride halving, sync barriers)
- Thread 0 writes final block result to global memory
- If grid_dim > 1: cross-block reduction via sumcheck_reduce_kernel
Cross-Block Reduction
extern "C" __global__ void sumcheck_reduce_kernel(
const uint32_t* __restrict__ partials, // [s0_blocks | s1_blocks | s2_blocks]
uint32_t* __restrict__ output, // Final [s0, s1, s2] (12 u32)
uint32_t n_blocks
)
Grid: 3 blocks (one per channel s0, s1, s2), each with 256 threads. Strided accumulation when n_blocks > 256.
MLE Fold Kernel (GPU-Resident)
After each sumcheck round, the challenge r folds both MLEs in-place on GPU. This eliminates the CPU→GPU→GPU round-trip that was the original bottleneck.
extern "C" __global__ void mle_fold_kernel(
const uint32_t* __restrict__ input, // QM31 array, n_points elements
const uint32_t* __restrict__ alpha, // QM31 challenge (4 u32)
uint32_t* __restrict__ output, // QM31 array, half_n elements
uint32_t half_n
)
Computation: For each i ∈ [0, half_n):
output[i] = input[i] + alpha × (input[half_n + i] - input[i])
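The fold formula in Rust, again over M31 scalars for brevity (the kernel applies it component-wise in QM31):

```rust
// Reference for the fold step: output[i] = input[i] + alpha*(input[half+i] - input[i]).
const P: u64 = (1 << 31) - 1;

fn mle_fold(input: &[u64], alpha: u64) -> Vec<u64> {
    let half = input.len() / 2;
    (0..half)
        .map(|i| {
            let (lo, hi) = (input[i], input[half + i]);
            (lo + alpha * ((hi + P - lo) % P)) % P
        })
        .collect()
}
```

At alpha = 0 the fold keeps the low half; at alpha = 1 it keeps the high half, which is the multilinear interpolation property the sumcheck relies on.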
Key property: Data stays on GPU across all sumcheck rounds. Per-round transfer is only 16 bytes (challenge) uploaded + 48 bytes (s0, s1, s2) downloaded.
GPU Sumcheck Proving Pipeline
High-Level Flow
Input: M31 matrices A(m×k), B(k×n), C(m×n)
│
[CPU] Fiat-Shamir: draw r_i, r_j
│
[GPU] Fused restrict: A → f_a (QM31, k elements)
B → f_b (QM31, k elements)
│
[GPU] Sumcheck rounds (log_k iterations):
┌──────────────────────────────────────────────┐
│ 1. [GPU] compute_round_poly: s0, s1, s2 │ 48 bytes ↓
│ 2. [CPU] c₀,c₁,c₂ interpolation + Poseidon │
│ 3. [GPU] mle_fold: f_a, f_b with challenge │ 16 bytes ↑
└──────────────────────────────────────────────┘
│
[GPU] Download final evals: f_a(r_k), f_b(r_k) 32 bytes ↓
│
[CPU] MLE opening proofs (Poseidon Merkle paths)
Data Transfer Analysis
For a matmul with k = 2^14 (16,384 elements):
| Operation | Direction | Size | Notes |
|---|---|---|---|
| Restrict A + B matrices | H → D | ~128 KB | One-time upload |
| Lagrange basis | H → D | ~128 KB | One-time upload |
| Sumcheck rounds (14) | D → H | 672 B | 14 × 48 bytes (s0,s1,s2) |
| Challenges (14) | H → D | 224 B | 14 × 16 bytes (QM31) |
| Final evaluations | D → H | 32 B | 2 × QM31 |
| Total | — | ~256 KB | |
The original CPU pipeline required ~456 MB allocation (matrix → MLE → restrict) and ~24 seconds of sequential folding per model.
Real Latency (H100)
| Component | Latency | Notes |
|---|---|---|
| GPU fused restrict | ~15 ms | Includes H→D transfer |
| 14 sumcheck rounds | ~20 ms | GPU-resident |
| 14 MLE folds | ~8 ms | In-place on GPU |
| Final download + commitments | ~5 ms | Merkle root computation |
| Total per matmul | ~48 ms | |
| CPU path (same matmul) | ~800 ms | 16× slower |
| 160 matmuls (Qwen3-14B) | ~7.7s | Was ~128s on CPU |
Fused MLE Restrict
Problem
The standard matmul sumcheck setup requires three steps:
1. pad_matrix_pow2(A) — copy + zero-pad to power-of-2 dimensions
2. matrix_to_mle(A_padded) — convert to MLE evaluation table (M31 → SecureField)
3. restrict_mle(A_mle, r_i) — fold variables to get restricted vector
For a 5120×5120 weight matrix padded to 8192×8192, this allocates 67M SecureField elements (~1 GB per matrix).
Solution
Fused GPU kernels take the original M31 matrix and the QM31 Lagrange basis and produce the restricted vector directly:
extern "C" __global__ void m31_restrict_rows_kernel(
const uint32_t* __restrict__ matrix, // M31 matrix, m_orig × k_orig
const uint32_t* __restrict__ lagrange, // QM31 Lagrange basis, m_padded entries
uint32_t* __restrict__ output, // QM31 output, k_padded entries
uint32_t m_orig,
uint32_t k_orig,
uint32_t m_padded
)
Computation: For each output element k:
f_a[k] = Σ_{i=0}^{m_orig-1} M31_to_QM31(matrix[i × k_orig + k]) × lagrange[i]
| Metric | Before (CPU) | After (GPU fused) |
|---|---|---|
| Memory per matrix | ~1 GB | O(k + n) |
| Compute ops (5120→8192) | 67M | 26M |
| Wall time (160 matmuls) | ~18s | eliminated |
The restrict_cols variant transposes the access pattern for column-major restriction of the B matrix.
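The fused computation can be sketched as a CPU reference, with M31 scalars standing in for the QM31 Lagrange weights. Padding rows (i ≥ m_orig) contribute zero, which is why the kernel never needs the padded matrix:

```rust
// Reference for the fused row-restrict: f_a[k] = Σ_i matrix[i][k] * lagrange[i],
// summing only over the m_orig real rows of the row-major matrix.
const P: u64 = (1 << 31) - 1;

fn restrict_rows(matrix: &[u64], m_orig: usize, k_orig: usize, lagrange: &[u64]) -> Vec<u64> {
    (0..k_orig)
        .map(|k| {
            (0..m_orig).fold(0u64, |acc, i| {
                (acc + matrix[i * k_orig + k] * lagrange[i]) % P
            })
        })
        .collect()
}
```

With a one-hot Lagrange vector the output is exactly one matrix row, which is a convenient correctness check for the kernel.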
Lagrange Basis
compute_lagrange_basis(challenges) builds the multilinear Lagrange basis weights via tensor product in O(n log n):
For challenges (r₀, r₁, ..., r_{d-1}):
L[b₀b₁...b_{d-1}] = Π_i ((1-r_i)(1-b_i) + r_i·b_i)
Computed on CPU (~100 μs for typical dimensions), uploaded to GPU once.
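The tensor-product construction can be sketched as follows, with M31 scalar challenges for illustration (the real basis is over QM31, and the bit ordering here is an assumption):

```rust
// Tensor-product Lagrange basis: each challenge doubles the table, extending
// entry b with factors (1 - r_i) for b_i = 0 and r_i for b_i = 1.
const P: u64 = (1 << 31) - 1;

fn lagrange_basis(challenges: &[u64]) -> Vec<u64> {
    let mut basis = vec![1u64];
    for &r in challenges {
        let mut next = Vec::with_capacity(basis.len() * 2);
        for &w in &basis {
            next.push(w * ((1 + P - r) % P) % P); // b_i = 0: weight (1 - r)
            next.push(w * r % P);                 // b_i = 1: weight r
        }
        basis = next;
    }
    basis
}
```

Since each factor pair sums to (1 - r_i) + r_i = 1, the full basis always sums to 1 in the field — a cheap invariant to assert after construction.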
LogUp GPU Kernels
Activation, LayerNorm, and dual-operand matmuls use degree-3 eq-sumcheck with three factors (eq, w, d). Four specialized kernels accelerate this:
Denominator Kernel
extern "C" __global__ void logup_denominator_kernel(
const uint32_t* __restrict__ input,
const uint32_t* __restrict__ output,
const uint32_t* __restrict__ gamma,
const uint32_t* __restrict__ beta,
uint32_t* __restrict__ d_out,
uint32_t n
)
Computes d[i] = gamma - input[i] - beta × output[i] (QM31 arithmetic).
3-Way Round Kernel
extern "C" __global__ void logup_3way_round_kernel(
const uint32_t* eq, const uint32_t* w, const uint32_t* d,
uint32_t* block_s0, uint32_t* block_s1,
uint32_t* block_s2, uint32_t* block_s3,
uint32_t mid
)
Evaluates the degree-3 round polynomial at t = 0, 1, 2, 3 by computing products of three interpolated factors per thread.
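A CPU reference for the degree-3 round polynomial, over M31 scalars for brevity (the kernel does the same in QM31): each of the three factors is linearly interpolated in t, then the products are accumulated at four points.

```rust
// Degree-3 round polynomial: p(t) = Σ_i eq(t,i) * w(t,i) * d(t,i), t ∈ {0,1,2,3}.
const P: u64 = (1 << 31) - 1;

/// Linear interpolation of one factor in the round variable: lo + t*(hi - lo).
fn interp(lo: u64, hi: u64, t: u64) -> u64 {
    (lo + t * ((hi + P - lo) % P)) % P
}

fn round_poly_3way(eq: &[u64], w: &[u64], d: &[u64]) -> [u64; 4] {
    let mid = eq.len() / 2;
    let mut s = [0u64; 4];
    for i in 0..mid {
        for t in 0..4u64 {
            let prod = interp(eq[i], eq[mid + i], t)
                * interp(w[i], w[mid + i], t) % P
                * interp(d[i], d[mid + i], t) % P;
            s[t as usize] = (s[t as usize] + prod) % P;
        }
    }
    s
}
```

Four evaluation points are needed because the product of three linear factors has degree 3; the verifier again checks p(0) + p(1) against the running claim.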
3-Way Fold Kernel
extern "C" __global__ void logup_3way_fold_kernel(
const uint32_t* eq_in, const uint32_t* w_in, const uint32_t* d_in,
const uint32_t* alpha,
uint32_t* eq_out, uint32_t* w_out, uint32_t* d_out,
uint32_t half_n
)
Folds all three MLEs simultaneously with the same challenge — batch optimization that keeps all three arrays GPU-resident.
GpuSumcheckExecutor
Architecture
pub struct GpuSumcheckExecutor {
pub device: Arc<CudaDevice>,
sumcheck_round_fn: CudaFunction, // Eager: compiled at init
sumcheck_reduce_fn: CudaFunction, // Eager: compiled at init
mle_fold_fn: CudaFunction, // Eager: compiled at init
forward_fns: Mutex<Option<ForwardKernels>>, // Lazy
restrict_fns: Mutex<Option<RestrictKernels>>, // Lazy
logup_fns: Mutex<Option<LogupKernels>>, // Lazy
}
Global Caching
pub fn cached() -> Result<Arc<Self>, CudaFftError> {
static EXECUTOR: OnceLock<Arc<GpuSumcheckExecutor>> = OnceLock::new();
EXECUTOR.get_or_try_init(|| {
eprintln!("[GPU] Compiling sumcheck CUDA kernels (one-time)...");
let executor = GpuSumcheckExecutor::new()?;
Ok(Arc::new(executor))
}).cloned()
}
Impact: NVRTC kernel compilation takes ~200ms. Without caching: 160 matmuls × 200ms = 32s wasted. With caching: 200ms once + 160 × 0.5ms dispatch = 280ms total.
GPU Auto-Dispatch
pub fn prove_matmul_sumcheck_onchain_auto(
a: &M31Matrix, b: &M31Matrix, c: &M31Matrix,
) -> Result<MatMulSumcheckProofOnChain, MatMulError> {
#[cfg(feature = "cuda-runtime")]
{
if gpu_is_available() {
match prove_matmul_sumcheck_onchain_gpu(a, b, c) {
Ok(proof) => return Ok(proof),
Err(e) => {
tracing::warn!("GPU failed, falling back to CPU: {e}");
}
}
}
}
prove_matmul_sumcheck_onchain(a, b, c) // CPU fallback
}
Critical lesson: The original k.is_power_of_two() gate blocked GPU dispatch for all real models (Qwen3-14B uses k=5120, not a power of 2). After removing this gate and adding automatic pad_matrix_pow2 inside the GPU path, all 160 matmuls now run on GPU.
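A minimal sketch of the padding now done inside the GPU path (the signature is illustrative, not the crate's actual `pad_matrix_pow2`): each dimension is rounded up to the next power of two and new cells are zero-filled.

```rust
/// Zero-pad an m×k row-major matrix to power-of-two dimensions (sketch).
/// Qwen3-14B's k = 5120 rounds up to 8192, so it now dispatches to the GPU.
fn pad_matrix_pow2(data: &[u32], m: usize, k: usize) -> (Vec<u32>, usize, usize) {
    let (mp, kp) = (m.next_power_of_two(), k.next_power_of_two());
    let mut out = vec![0u32; mp * kp];
    for i in 0..m {
        // Copy each original row into the wider padded row; the tail stays zero.
        out[i * kp..i * kp + k].copy_from_slice(&data[i * k..(i + 1) * k]);
    }
    (out, mp, kp)
}
```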
GPU GEMM
gpu_matmul_m31_full() provides M31 matrix multiplication on GPU:
- GEMV (m=1): Single-row dot product, 1D grid
- GEMM (m>1): Full matrix multiply, 2D grid with 16×16 thread blocks
Used in model forward pass execution and weight commitment computation.
GPU Element-wise Operations
| Kernel | Function | Threshold |
|---|---|---|
| m31_add_kernel | gpu_elementwise_add() | len >= 4096 |
| m31_mul_kernel | gpu_elementwise_mul() | len >= 4096 |
| m31_relu_kernel | gpu_relu() | len >= 4096 |
All operations fall back to CPU below threshold. Error return (not panic) on dimension mismatch enables graceful fallback.
GKR GPU Operations
| Method | Description |
|---|---|
| evaluate_mle_gpu(mle, point) | MLE evaluation at a point on GPU |
| combine_blocks(block_mles, weights) | Σ_b w_b · MLE_b — SIMD block combination |
| reduce_matmul_layer_gpu(claim, A, B, m, k, n, channel) | Full matmul reduction on GPU |
| restrict_rows(A, r_i, pk) | Fused row restriction |
| restrict_cols(B, r_j, pk) | Fused column restriction |
| sumcheck_3way(ext_w, ext_a, ext_b, n, channel) | 3-factor GPU sumcheck (reuses LogUp kernels) |
Multi-GPU Distributed Proving
Architecture
┌─────────────────────────────────┐
│ MultiGpuExecutor │
│ ┌─────────┐ ┌─────────┐ │
│ │ GPU 0 │ │ GPU 1 │ ... │
│ │ chunk 0 │ │ chunk 1 │ │
│ │ chunk 3 │ │ chunk 2 │ │
│ └─────────┘ └─────────┘ │
└─────────────────────────────────┘
Thread-Local Device Affinity
thread_local! {
    static CURRENT_DEVICE: Cell<Option<usize>> = Cell::new(None);
}
- set_thread_device(id) / get_thread_device() — set/query affinity
- DeviceGuard::new(id) — RAII guard, restores previous device on drop (panic-safe)
- propagate_device(parent) — for rayon worker threads (they don't inherit thread_local!)
GpuSumcheckExecutor Per-Device Pool
GpuSumcheckExecutor::cached() // uses thread-local device
GpuSumcheckExecutor::cached_for_device(id) // explicit device
Per-device singleton pool via OnceLock<Mutex<HashMap<usize, Arc<Self>>>>. One cached() call routes all downstream GPU operations to the correct device.
Chunk Partitioning
MultiGpuExecutor::partition_chunks(chunks) uses greedy bin-packing with 80% memory safety margin:
- Sort chunks by estimated memory (descending)
- Assign each chunk to the GPU with the most free memory
- Warn on oversized chunks exceeding single GPU capacity
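The greedy strategy above can be sketched as follows (hypothetical helper, simplified: the real `partition_chunks` also applies the 80% margin and the oversized-chunk warning):

```rust
/// Greedy bin-packing sketch: largest chunks first, each to the GPU with the
/// most free memory. Returns chunk index -> GPU index.
fn assign_chunks(chunk_mem: &[u64], gpu_free: &mut [u64]) -> Vec<usize> {
    let mut order: Vec<usize> = (0..chunk_mem.len()).collect();
    order.sort_by_key(|&i| std::cmp::Reverse(chunk_mem[i])); // descending by size
    let mut assignment = vec![0usize; chunk_mem.len()];
    for &c in &order {
        let g = (0..gpu_free.len()).max_by_key(|&g| gpu_free[g]).unwrap();
        assignment[c] = g;
        gpu_free[g] = gpu_free[g].saturating_sub(chunk_mem[c]);
    }
    assignment
}
```

Largest-first assignment keeps per-device load close to balanced without solving the (NP-hard) optimal partition.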
Proving Flow
prove_model_chunked_multi_gpu(graph, input, weights, executor)
- CPU forward pass → chunk inputs
- partition_chunks() → device assignments
- std::thread::scope with DeviceGuard per chunk thread
- Collect all results (ALL errors, not just first)
- Return MultiGpuProvingResult with per-device stats
CLI
prove-model --model model.onnx --gpu --multi-gpu --chunk-budget-gb 8
Pipeline Optimizations
Three-Phase Pipeline
Phase 1: Forward pass (CPU) │ 2-3s
Phase 2: MatMul proving (GPU) │ 30-35s (160 matmuls)
Phase 3: Unified STARK (GPU) │ 5-7s
────────
~40s total (Qwen3-14B)
Parallel Merkle Tree
PoseidonMerkleTree::build_parallel() + root_only_parallel():
- par_chunks(2) for layers with >= 256 pairs
- root_only_parallel: no intermediate layer storage (just the root)
- Used in commit_mle_root_only() for batch entry prep
Pipelined Weight Commitment
std::thread::scope(|s| {
s.spawn(|| commit_weights_background(...)); // Runs during proving
prove_layers(...);
});
Zero added latency — weight commitment runs on a background thread during proving.
Weight Loading
Two-phase bulk extract + parallel process:
- Phase 1: Extract raw f32 from ALL SafeTensor shards
- Phase 2: Single rayon parallel pass (160 tasks for Qwen3-14B)
- madvise(MADV_SEQUENTIAL | MADV_WILLNEED) on shard mmaps for OS prefetch
Streaming Serialization
serialize_ml_proof_to_file() writes via BufWriter (1 MB buffer) instead of Vec<String> + .join(), eliminating large intermediate allocations for proofs.
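The streaming pattern looks roughly like this (shown against a generic `Write` sink rather than the real proof serializer):

```rust
use std::io::{BufWriter, Write};

/// Stream lines through a 1 MB BufWriter instead of building Vec<String> + join,
/// so the only large buffer is the fixed-size write buffer.
fn write_lines_streaming<W: Write>(sink: W, lines: &[&str]) -> std::io::Result<()> {
    let mut w = BufWriter::with_capacity(1 << 20, sink);
    for line in lines {
        w.write_all(line.as_bytes())?;
        w.write_all(b"\n")?;
    }
    w.flush() // BufWriter flushes on drop, but surfacing errors here is safer
}
```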
Memory Management
Tiled MatMul
For matrices exceeding GPU memory, tiled proving splits the k-dimension:
pub struct TiledMatMulConfig {
pub max_tile_k: usize, // Power of 2
}
impl TiledMatMulConfig {
pub fn from_memory_budget(m: usize, k: usize, n: usize, budget: usize) -> Self {
// Halve tile_k until estimated memory fits budget
}
}
Memory budget: 64 GB (configured for H100's 80 GB HBM). This eliminates tiling for all production models — the previous 4 GB budget forced all FFN matmuls into slow tiled mode.
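A sketch of the budget-driven tile sizing. The memory estimate here (16 bytes per QM31 element across the three operand slices) is an assumption for illustration, not the crate's exact formula:

```rust
/// Halve tile_k (starting from k rounded up to a power of two) until the
/// rough per-tile memory estimate fits the budget.
fn tile_k_for_budget(m: usize, k: usize, n: usize, budget_bytes: usize) -> usize {
    let mut tile_k = k.next_power_of_two();
    // Assumed estimate: QM31 copies of the A-slice, B-slice, and C tiles.
    let est = |tk: usize| (m * tk + tk * n + m * n) * 16;
    while tile_k > 1 && est(tile_k) > budget_bytes {
        tile_k /= 2;
    }
    tile_k
}
```

With a 64 GB budget the loop never fires for production-sized matmuls, which is exactly why raising the budget from 4 GB eliminated tiling.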
GPU Dispatch Thresholds
pub struct GpuThresholds;
impl GpuThresholds {
const DEFAULT_MLE: u32 = 14; // 2^14 = 16K elements
const DEFAULT_COLUMN_OPS: u32 = 14;
const DEFAULT_MERKLE: u32 = 14;
}
Override via OBELYSK_GPU_THRESHOLD=0 to force all GPU on H100/A100.
Kernel Launch Overhead
| Kernel | Compile Time | Per-Launch |
|---|---|---|
| sumcheck_round | ~60ms (one-time) | ~5 μs |
| mle_fold | shared module | ~5 μs |
| restrict_rows/cols | ~50ms (lazy) | ~5 μs |
| logup_3way_round | ~50ms (lazy) | ~5 μs |
Total overhead for 14-round sumcheck: 14 × (5+5) μs = 140 μs (negligible vs ~48ms total).