GPU Acceleration

Overview

stwo-ml uses CUDA to accelerate the compute-intensive operations in the proving pipeline. GPU acceleration spans four layers:

  1. Sumcheck kernels — Round polynomial evaluation + MLE folding (inner loop)
  2. Fused MLE restrict — Direct matrix-to-restricted-vector without intermediate allocation
  3. Forward pass ops — MatMul, element-wise add/mul/relu for model execution
  4. Multi-GPU — Distributed chunk proving across multiple GPUs

All GPU code is gated behind feature = "cuda-runtime" (requires CUDA 12.4+).

CUDA Kernel Architecture

Kernel Groups

| Group | Kernels | Used By |
|---|---|---|
| Sumcheck | sumcheck_round_kernel, sumcheck_reduce_kernel, mle_fold_kernel | MatMul sumcheck proving |
| MLE Restrict | m31_restrict_rows_kernel, m31_restrict_cols_kernel | MatMul sumcheck setup |
| Forward Pass | m31_gemv_kernel, m31_gemm_kernel, m31_add_kernel, m31_mul_kernel, m31_relu_kernel | Model execution |
| LogUp | logup_denominator_kernel, logup_3way_round_kernel, logup_4way_reduce_kernel, logup_3way_fold_kernel | Activation/dual-operand sumcheck |
| GKR | combine_blocks_kernel, evaluate_mle_kernel | SIMD block batching |

Kernels are compiled once via NVRTC and cached globally. Core sumcheck kernels compile eagerly at init; forward, restrict, and LogUp kernels compile lazily on first use.

QM31 Field Arithmetic in CUDA

All kernels operate on M31 (p = 2^31 - 1) and QM31 (degree-4 extension) field elements. QM31 is represented as 4 × u32:

[a0, a1, a2, a3]  ↔  (a0 + u·a1) + i·(a2 + u·a3)
where u² = -1, i² = u + 2, all arithmetic mod 2^31 - 1

Conversion between Rust SecureField and CUDA u32 arrays:

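// Layout matches the formula above: sf.0 holds (a0, a1), sf.1 holds (a2, a3).
// M31, CM31, QM31 and SecureField (= QM31) are stwo's field tuple structs.
fn secure_field_to_u32s(sf: SecureField) -> [u32; 4] {
    [sf.0.0.0, sf.0.1.0, sf.1.0.0, sf.1.1.0]
}

// Builds the tuple structs directly, so each u32 must already be reduced mod 2^31 - 1.
fn u32s_to_secure_field(u: &[u32; 4]) -> SecureField {
    QM31(CM31(M31(u[0]), M31(u[1])), CM31(M31(u[2]), M31(u[3])))
}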

CUDA kernels use _r suffix variants of QM31 arithmetic to avoid name collisions between kernel groups:

// QM31 multiply: (a0 + a1·u) + (a2 + a3·u)·i
// where u² = -1, i² = u + 2
__device__ void qm31_mul_r(uint32_t* out, const uint32_t* a, const uint32_t* b) {
    // Karatsuba over CM31 components
    // ...
}
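As a cross-check of these relations, the same product can be written on the CPU. This is a Rust sketch with hypothetical function names, not the project's qm31_mul_r. It follows the [a0, a1, a2, a3] layout, takes u² = -1 and i² = u + 2 (stwo's CM31/QM31 tower; 2 is a square mod 2^31 - 1, so it cannot define the inner extension), and assumes all inputs are already reduced mod p.

```rust
/// M31 modulus p = 2^31 - 1.
const P: u64 = (1 << 31) - 1;

// Base-field ops on reduced inputs (< p).
fn m31_add(a: u32, b: u32) -> u32 { ((a as u64 + b as u64) % P) as u32 }
fn m31_sub(a: u32, b: u32) -> u32 { ((a as u64 + P - b as u64) % P) as u32 }
fn m31_mul(a: u32, b: u32) -> u32 { ((a as u64 * b as u64) % P) as u32 }

/// CM31 element (c0 + c1·u) with u² = -1.
type Cm31 = (u32, u32);

fn cm31_add(a: Cm31, b: Cm31) -> Cm31 { (m31_add(a.0, b.0), m31_add(a.1, b.1)) }
fn cm31_sub(a: Cm31, b: Cm31) -> Cm31 { (m31_sub(a.0, b.0), m31_sub(a.1, b.1)) }

/// Karatsuba: 3 base-field multiplications instead of 4.
fn cm31_mul(a: Cm31, b: Cm31) -> Cm31 {
    let t0 = m31_mul(a.0, b.0);
    let t1 = m31_mul(a.1, b.1);
    let t2 = m31_mul(m31_add(a.0, a.1), m31_add(b.0, b.1));
    // u² = -1: constant part = a0·b0 - a1·b1, u-part = (a0 + a1)(b0 + b1) - a0·b0 - a1·b1
    (m31_sub(t0, t1), m31_sub(m31_sub(t2, t0), t1))
}

/// Multiply by i² = u + 2: (c0 + c1·u)(2 + u) = (2·c0 - c1) + (c0 + 2·c1)·u.
fn mul_by_i_squared(c: Cm31) -> Cm31 {
    (m31_sub(m31_add(c.0, c.0), c.1), m31_add(c.0, m31_add(c.1, c.1)))
}

/// QM31 product in the [a0, a1, a2, a3] layout, i.e. x + i·y with
/// x = a0 + u·a1 and y = a2 + u·a3; Karatsuba again over the two CM31 halves.
fn qm31_mul(a: [u32; 4], b: [u32; 4]) -> [u32; 4] {
    let (x0, y0): (Cm31, Cm31) = ((a[0], a[1]), (a[2], a[3]));
    let (x1, y1): (Cm31, Cm31) = ((b[0], b[1]), (b[2], b[3]));
    let t0 = cm31_mul(x0, x1);
    let t1 = cm31_mul(y0, y1);
    let t2 = cm31_mul(cm31_add(x0, y0), cm31_add(x1, y1));
    let lo = cm31_add(t0, mul_by_i_squared(t1)); // x0·x1 + i²·y0·y1
    let hi = cm31_sub(cm31_sub(t2, t0), t1);     // x0·y1 + y0·x1
    [lo.0, lo.1, hi.0, hi.1]
}
```

Both levels use Karatsuba, so one QM31 product costs 3 CM31 products (9 base-field multiplications) instead of 16; multiplying by i² needs only additions.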

Sumcheck Round Kernel

The core proving kernel computes three QM31 sums in parallel for one round of the degree-2 sumcheck.

extern "C" __global__ void sumcheck_round_kernel(
    const uint32_t* __restrict__ f_a,       // QM31 array, 4 u32 per element
    const uint32_t* __restrict__ f_b,       // QM31 array, 4 u32 per element
    uint32_t* __restrict__ block_s0,        // Per-block partial sum at t=0
    uint32_t* __restrict__ block_s1,        // Per-block partial sum at t=1
    uint32_t* __restrict__ block_s2,        // Per-block partial sum at t=2
    uint32_t mid                            // Half-length (n_points / 2)
)

Thread/Block configuration:

  • Block size: 256 threads (fixed)
  • Grid size: ceil(mid / 256) blocks
  • Shared memory: 3 sums × 256 threads × 4 u32 × 4 bytes = 12 KB per block

Computation (for each pair i ∈ [0, mid)):

s0 += f_a[i] × f_b[i]                                        // p(0)
s1 += f_a[mid + i] × f_b[mid + i]                            // p(1)
s2 += (2·f_a[mid+i] - f_a[i]) × (2·f_b[mid+i] - f_b[i])    // p(2)
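A CPU reference for the same three sums is handy for cross-checking the kernel. This sketch works over the base field M31 with plain u64 arithmetic to stay short (the kernel uses QM31, but the round structure is identical), and the function name is hypothetical. The test checks s0 + s1 against the claimed sum, which is the verifier's per-round check.

```rust
const P: u64 = (1 << 31) - 1;

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// CPU reference for one degree-2 sumcheck round: returns (p(0), p(1), p(2)),
/// where p(t) = Σ_i a_t[i] · b_t[i] and x_t[i] = (1 - t)·x[i] + t·x[mid + i].
fn round_evals(f_a: &[u64], f_b: &[u64]) -> (u64, u64, u64) {
    assert_eq!(f_a.len(), f_b.len());
    let mid = f_a.len() / 2;
    let (mut s0, mut s1, mut s2) = (0, 0, 0);
    for i in 0..mid {
        let (a0, a1) = (f_a[i], f_a[mid + i]);
        let (b0, b1) = (f_b[i], f_b[mid + i]);
        s0 = add(s0, mul(a0, b0));
        s1 = add(s1, mul(a1, b1));
        // Linear extension to t = 2: x(2) = 2·x(1) - x(0).
        let a2 = sub(add(a1, a1), a0);
        let b2 = sub(add(b1, b1), b0);
        s2 = add(s2, mul(a2, b2));
    }
    (s0, s1, s2)
}
```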

Reduction pattern:

  1. Each thread accumulates its subset of pairs into thread-local QM31 accumulators
  2. Block-level tree reduction via shared memory (stride halving, sync barriers)
  3. Thread 0 writes final block result to global memory
  4. If grid_dim > 1: cross-block reduction via sumcheck_reduce_kernel
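The reduction pattern can be simulated on the CPU to check the indexing. In this sketch (hypothetical names, M31 values as u64, one term per thread, a single channel instead of s0, s1 and s2 together), the inner loop of block_tree_reduce plays the role of the threads active at a given stride, and the cross-block step mirrors the strided accumulation that sumcheck_reduce_kernel uses when there are more than 256 blocks.

```rust
const P: u64 = (1 << 31) - 1;
const BLOCK: usize = 256;

fn add(a: u64, b: u64) -> u64 { (a + b) % P }

/// Shared-memory style tree reduction inside one block: at each step, thread t
/// (t < stride) adds slot t + stride into slot t, then the stride halves.
fn block_tree_reduce(mut shared: [u64; BLOCK]) -> u64 {
    let mut stride = BLOCK / 2;
    while stride > 0 {
        for t in 0..stride {
            shared[t] = add(shared[t], shared[t + stride]);
        }
        // On the GPU a __syncthreads() barrier sits here.
        stride /= 2;
    }
    shared[0] // written to global memory by thread 0
}

/// Two-level reduction: one partial per 256-wide block, then a single pass in
/// which 256 "threads" stride over the block partials.
fn two_level_sum(terms: &[u64]) -> u64 {
    let partials: Vec<u64> = terms
        .chunks(BLOCK)
        .map(|chunk| {
            let mut shared = [0u64; BLOCK];
            shared[..chunk.len()].copy_from_slice(chunk);
            block_tree_reduce(shared)
        })
        .collect();
    // Cross-block pass: thread t accumulates partials t, t + 256, t + 512, ...
    let mut shared = [0u64; BLOCK];
    for (j, &p) in partials.iter().enumerate() {
        shared[j % BLOCK] = add(shared[j % BLOCK], p);
    }
    block_tree_reduce(shared)
}
```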

Cross-Block Reduction

extern "C" __global__ void sumcheck_reduce_kernel(
    const uint32_t* __restrict__ partials,  // [s0_blocks | s1_blocks | s2_blocks]
    uint32_t* __restrict__ output,          // Final [s0, s1, s2] (12 u32)
    uint32_t n_blocks
)

Grid: 3 blocks (one per channel s0, s1, s2), each with 256 threads. Strided accumulation when n_blocks > 256.

MLE Fold Kernel (GPU-Resident)

After each sumcheck round, both MLEs are folded on the GPU with the round challenge r, so they never leave device memory. This eliminates the GPU→CPU→GPU round-trip that was the original bottleneck.

extern "C" __global__ void mle_fold_kernel(
    const uint32_t* __restrict__ input,     // QM31 array, n_points elements
    const uint32_t* __restrict__ alpha,     // QM31 challenge (4 u32)
    uint32_t* __restrict__ output,          // QM31 array, half_n elements
    uint32_t half_n
)

Computation: For each i ∈ [0, half_n):

output[i] = input[i] + alpha × (input[half_n + i] - input[i])
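Each output element is a linear interpolation between the two halves evaluated at alpha: alpha = 0 keeps the first half and alpha = 1 the second. A minimal CPU sketch over M31 (hypothetical function name; the kernel does the same over QM31):

```rust
const P: u64 = (1 << 31) - 1;

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// Fix the top variable of an MLE to `alpha`:
/// output[i] = input[i] + alpha · (input[half_n + i] - input[i]).
fn mle_fold(input: &[u64], alpha: u64) -> Vec<u64> {
    let half_n = input.len() / 2;
    (0..half_n)
        .map(|i| add(input[i], mul(alpha, sub(input[half_n + i], input[i]))))
        .collect()
}
```

Folding once per variable reduces the table to a single value, the MLE evaluated at the sequence of challenges.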

Key property: Data stays on GPU across all sumcheck rounds. Per-round transfer is only 16 bytes (challenge) uploaded + 48 bytes (s0, s1, s2) downloaded.

GPU Sumcheck Proving Pipeline

High-Level Flow

Input: M31 matrices A(m×k), B(k×n), C(m×n)
                    │
  [CPU] Fiat-Shamir: draw r_i, r_j
                    │
  [GPU] Fused restrict: A → f_a (QM31, k elements)
                        B → f_b (QM31, k elements)
                    │
  [GPU] Sumcheck rounds (log₂ k iterations):
    ┌──────────────────────────────────────────────┐
    │  1. [GPU] compute_round_poly: s0, s1, s2     │  48 bytes ↓
    │  2. [CPU] c₀,c₁,c₂ interpolation + Poseidon │
    │  3. [GPU] mle_fold: f_a, f_b with challenge  │  16 bytes ↑
    └──────────────────────────────────────────────┘
                    │
  [GPU] Download final evals: f_a(r_k), f_b(r_k)     32 bytes ↓
                    │
  [CPU] MLE opening proofs (Poseidon Merkle paths)
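The round loop can be exercised end-to-end on the CPU. This sketch (M31 instead of QM31, hypothetical names, and a fixed list of challenges standing in for the Poseidon channel) computes s0, s1, s2, recovers c₀, c₁, c₂, folds with the challenge, and applies the verifier's checks. It also shows why three field elements per round are enough: a degree-2 polynomial is determined by its values at 0, 1 and 2.

```rust
const P: u64 = (1 << 31) - 1;
const INV2: u64 = 1 << 30; // 2 · 2^30 = 2^31 ≡ 1 (mod p)

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// Coefficients (c0, c1, c2) of p(t) = c0 + c1·t + c2·t² from p(0), p(1), p(2).
fn interpolate(s0: u64, s1: u64, s2: u64) -> (u64, u64, u64) {
    let c2 = mul(sub(add(s2, s0), add(s1, s1)), INV2); // (p(2) - 2p(1) + p(0)) / 2
    let c1 = sub(sub(s1, s0), c2);                     // p(1) - p(0) - c2
    (s0, c1, c2)
}

fn eval(c: (u64, u64, u64), t: u64) -> u64 {
    add(c.0, mul(t, add(c.1, mul(t, c.2)))) // Horner form
}

fn fold(v: &[u64], r: u64) -> Vec<u64> {
    let h = v.len() / 2;
    (0..h).map(|i| add(v[i], mul(r, sub(v[h + i], v[i])))).collect()
}

/// Runs the honest prover loop and checks it the way a verifier would.
fn sumcheck_check(mut f_a: Vec<u64>, mut f_b: Vec<u64>, challenges: &[u64]) -> bool {
    let mut claim = f_a.iter().zip(&f_b).fold(0, |acc, (&a, &b)| add(acc, mul(a, b)));
    for &r in challenges {
        let mid = f_a.len() / 2;
        let (mut s0, mut s1, mut s2) = (0, 0, 0);
        for i in 0..mid {
            s0 = add(s0, mul(f_a[i], f_b[i]));
            s1 = add(s1, mul(f_a[mid + i], f_b[mid + i]));
            let a2 = sub(add(f_a[mid + i], f_a[mid + i]), f_a[i]);
            let b2 = sub(add(f_b[mid + i], f_b[mid + i]), f_b[i]);
            s2 = add(s2, mul(a2, b2));
        }
        if add(s0, s1) != claim {
            return false; // round check: p(0) + p(1) = claim
        }
        claim = eval(interpolate(s0, s1, s2), r); // next claim is p(r)
        f_a = fold(&f_a, r);
        f_b = fold(&f_b, r);
    }
    // After log2(k) rounds the claim must equal f_a(r) · f_b(r).
    f_a.len() == 1 && claim == mul(f_a[0], f_b[0])
}
```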

Data Transfer Analysis

For a matmul with k = 2^14 (16,384 elements):

| Operation | Direction | Size | Notes |
|---|---|---|---|
| Restrict A + B matrices | H → D | ~128 KB | One-time upload |
| Lagrange basis | H → D | ~128 KB | One-time upload |
| Sumcheck rounds (14) | D → H | 672 B | 14 × 48 bytes (s0, s1, s2) |
| Challenges (14) | H → D | 224 B | 14 × 16 bytes (QM31) |
| Final evaluations | D → H | 32 B | 2 × QM31 |
| Total | | ~256 KB | |

The original CPU pipeline required ~456 MB allocation (matrix → MLE → restrict) and ~24 seconds of sequential folding per model.

Real Latency (H100)

| Component | Latency | Notes |
|---|---|---|
| GPU fused restrict | ~15 ms | Includes H→D transfer |
| 14 sumcheck rounds | ~20 ms | GPU-resident |
| 14 MLE folds | ~8 ms | In-place on GPU |
| Final download + commitments | ~5 ms | Merkle root computation |
| Total per matmul | ~48 ms | |
| CPU path (same matmul) | ~800 ms | 16× slower |
| 160 matmuls (Qwen3-14B) | ~7.7 s | Was ~128 s on CPU |

Fused MLE Restrict

Problem

The standard matmul sumcheck setup requires three steps:

  1. pad_matrix_pow2(A) — copy + zero-pad to power-of-2 dimensions
  2. matrix_to_mle(A_padded) — convert to MLE evaluation table (M31 → SecureField)
  3. restrict_mle(A_mle, r_i) — fold variables to get restricted vector

For a 5120×5120 weight matrix padded to 8192×8192, this allocates 67M SecureField elements (~1 GB per matrix).

Solution

Fused GPU kernels take the original M31 matrix and the QM31 Lagrange basis and produce the restricted vector directly:

extern "C" __global__ void m31_restrict_rows_kernel(
    const uint32_t* __restrict__ matrix,     // M31 matrix, m_orig × k_orig
    const uint32_t* __restrict__ lagrange,   // QM31 Lagrange basis, m_padded entries
    uint32_t* __restrict__ output,           // QM31 output, k_padded entries
    uint32_t m_orig,
    uint32_t k_orig,
    uint32_t m_padded
)

Computation: For each output element k:

f_a[k] = Σ_{i=0}^{m_orig-1} M31_to_QM31(matrix[i × k_orig + k]) × lagrange[i]

| Metric | Before (CPU) | After (GPU fused) |
|---|---|---|
| Memory per matrix | ~1 GB | O(k + n) |
| Compute ops (5120→8192) | 67M | 26M |
| Wall time (160 matmuls) | ~18 s | eliminated |

The restrict_cols variant transposes the access pattern for column-major restriction of the B matrix.
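The fused restriction only touches the m_orig × k_orig original entries, which is where the 26M vs 67M operation count comes from. A CPU sketch of both variants over M31, with hypothetical signatures (the real kernels take a QM31 Lagrange basis):

```rust
const P: u64 = (1 << 31) - 1;

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// f_a[col] = Σ_{row < m_orig} matrix[row · k_orig + col] · lagrange[row].
/// Entries k_orig..k_padded stay zero: padded columns are all-zero, so the
/// padding is never materialized.
fn restrict_rows(matrix: &[u64], m_orig: usize, k_orig: usize, lagrange: &[u64], k_padded: usize) -> Vec<u64> {
    assert_eq!(matrix.len(), m_orig * k_orig);
    assert!(lagrange.len() >= m_orig); // m_padded entries; padded rows contribute 0
    let mut out = vec![0; k_padded];
    // On the GPU each thread owns one output column and loops over rows;
    // the row-outer order here is just the cache-friendly CPU order.
    for row in 0..m_orig {
        for col in 0..k_orig {
            out[col] = add(out[col], mul(matrix[row * k_orig + col], lagrange[row]));
        }
    }
    out
}

/// Column variant for B (k_orig × n_orig, row-major):
/// f_b[row] = Σ_{col < n_orig} B[row · n_orig + col] · lagrange[col].
fn restrict_cols(matrix: &[u64], k_orig: usize, n_orig: usize, lagrange: &[u64], k_padded: usize) -> Vec<u64> {
    assert_eq!(matrix.len(), k_orig * n_orig);
    assert!(lagrange.len() >= n_orig);
    let mut out = vec![0; k_padded];
    for row in 0..k_orig {
        for col in 0..n_orig {
            out[row] = add(out[row], mul(matrix[row * n_orig + col], lagrange[col]));
        }
    }
    out
}
```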

Lagrange Basis

compute_lagrange_basis(challenges) builds the multilinear Lagrange basis weights via a tensor product in O(2^d) time for d challenges:

For challenges (r₀, r₁, ..., r_{d-1}):
  L[b₀b₁...b_{d-1}] = Π_i ((1-r_i)(1-b_i) + r_i·b_i)

Computed on CPU (~100 μs for typical dimensions), uploaded to GPU once.
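A sketch of the tensor-product construction over M31. The helper name is hypothetical, and the bit order (b₀ as the most significant index bit) is an assumption; the real ordering must match the MLE layout the restrict kernels use.

```rust
const P: u64 = (1 << 31) - 1;

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// L[b] = Π_i ((1 - r_i)(1 - b_i) + r_i·b_i), built by doubling the table once
/// per challenge: 2 + 4 + ... + 2^d = O(2^d) multiplications in total.
fn lagrange_basis(challenges: &[u64]) -> Vec<u64> {
    let mut basis = vec![1u64];
    for &r in challenges {
        let one_minus_r = sub(1, r);
        basis = basis
            .iter()
            .flat_map(|&w| [mul(w, one_minus_r), mul(w, r)]) // b_i = 0, then b_i = 1
            .collect();
    }
    basis
}
```

With Boolean challenges the basis is a one-hot vector, so restricting with it simply selects a row; with random challenges it is the weighting used in the restrict formula above.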

LogUp GPU Kernels

Activation, LayerNorm, and dual-operand matmuls use degree-3 eq-sumcheck with three factors (eq, w, d). Four specialized kernels accelerate this:

Denominator Kernel

extern "C" __global__ void logup_denominator_kernel(
    const uint32_t* __restrict__ input,
    const uint32_t* __restrict__ output,
    const uint32_t* __restrict__ gamma,
    const uint32_t* __restrict__ beta,
    uint32_t* __restrict__ d_out,
    uint32_t n
)

Computes d[i] = gamma - input[i] - beta × output[i] (QM31 arithmetic).
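Element-wise, each (input, output) pair is combined with beta and shifted by gamma; in LogUp these are verifier challenges. A CPU sketch over M31 (hypothetical name; gamma and beta are QM31 values in the real kernel):

```rust
const P: u64 = (1 << 31) - 1;

fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// d[i] = gamma - input[i] - beta · output[i].
fn logup_denominators(input: &[u64], output: &[u64], gamma: u64, beta: u64) -> Vec<u64> {
    assert_eq!(input.len(), output.len());
    input
        .iter()
        .zip(output)
        .map(|(&x, &y)| sub(sub(gamma, x), mul(beta, y)))
        .collect()
}
```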

3-Way Round Kernel

extern "C" __global__ void logup_3way_round_kernel(
    const uint32_t* eq, const uint32_t* w, const uint32_t* d,
    uint32_t* block_s0, uint32_t* block_s1,
    uint32_t* block_s2, uint32_t* block_s3,
    uint32_t mid
)

Evaluates the degree-3 round polynomial at t = 0, 1, 2, 3 by computing products of three interpolated factors per thread.
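Four evaluation points are needed because a product of three linear factors has degree 3. A CPU reference over M31 with hypothetical names (the kernel does this per thread over QM31 and writes per-block partials):

```rust
const P: u64 = (1 << 31) - 1;

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// Value at t of the line through x(0) = lo and x(1) = hi: lo + t·(hi - lo).
fn line(lo: u64, hi: u64, t: u64) -> u64 {
    add(lo, mul(t, sub(hi, lo)))
}

/// p(t) = Σ_i eq_t[i] · w_t[i] · d_t[i] evaluated at t = 0, 1, 2, 3.
fn three_way_round(eq: &[u64], w: &[u64], d: &[u64]) -> [u64; 4] {
    let mid = eq.len() / 2;
    let mut s = [0u64; 4];
    for i in 0..mid {
        for (t, acc) in s.iter_mut().enumerate() {
            let t = t as u64;
            let e = line(eq[i], eq[mid + i], t);
            let wv = line(w[i], w[mid + i], t);
            let dv = line(d[i], d[mid + i], t);
            *acc = add(*acc, mul(mul(e, wv), dv));
        }
    }
    s
}
```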

3-Way Fold Kernel

extern "C" __global__ void logup_3way_fold_kernel(
    const uint32_t* eq_in, const uint32_t* w_in, const uint32_t* d_in,
    const uint32_t* alpha,
    uint32_t* eq_out, uint32_t* w_out, uint32_t* d_out,
    uint32_t half_n
)

Folds all three MLEs in a single launch with the same challenge, a batching optimization that keeps all three arrays GPU-resident.

GpuSumcheckExecutor

Architecture

pub struct GpuSumcheckExecutor {
    pub device: Arc<CudaDevice>,
    sumcheck_round_fn: CudaFunction,     // Eager: compiled at init
    sumcheck_reduce_fn: CudaFunction,    // Eager: compiled at init
    mle_fold_fn: CudaFunction,           // Eager: compiled at init
    forward_fns: Mutex<Option<ForwardKernels>>,   // Lazy
    restrict_fns: Mutex<Option<RestrictKernels>>,  // Lazy
    logup_fns: Mutex<Option<LogupKernels>>,        // Lazy
}
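The eager/lazy split trades startup time for first-use latency: only the three core sumcheck kernels are paid for at init. A minimal sketch of the lazy Mutex<Option<...>> slot, using a hypothetical stand-in type and a counter in place of the NVRTC compile (no CUDA involved):

```rust
use std::sync::Mutex;

/// Hypothetical stand-in for a group of compiled kernels.
struct RestrictKernels {
    compile_id: usize,
}

struct Executor {
    restrict_fns: Mutex<Option<RestrictKernels>>, // Lazy
    compiles: Mutex<usize>,                        // counts compilations for the demo
}

impl Executor {
    fn new() -> Self {
        Executor { restrict_fns: Mutex::new(None), compiles: Mutex::new(0) }
    }

    /// Compile on first use, then reuse the cached kernels.
    fn with_restrict<R>(&self, f: impl FnOnce(&RestrictKernels) -> R) -> R {
        let mut slot = self.restrict_fns.lock().expect("kernel mutex poisoned");
        let kernels = slot.get_or_insert_with(|| {
            let mut n = self.compiles.lock().expect("counter mutex poisoned");
            *n += 1; // stand-in for compiling the restrict module
            RestrictKernels { compile_id: *n }
        });
        f(kernels)
    }
}
```

Holding the mutex across the compile guarantees at most one compile per kernel group; the cost is that other threads needing the same group wait until it finishes.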

Global Caching

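pub fn cached() -> Result<Arc<Self>, CudaFftError> {
    // OnceLock::get_or_try_init is unstable: this needs a nightly toolchain with
    // #![feature(once_cell_try)] at the crate root (Arc and OnceLock from std::sync).
    static EXECUTOR: OnceLock<Arc<GpuSumcheckExecutor>> = OnceLock::new();
    EXECUTOR.get_or_try_init(|| {
        eprintln!("[GPU] Compiling sumcheck CUDA kernels (one-time)...");
        let executor = GpuSumcheckExecutor::new()?;
        Ok(Arc::new(executor))
    }).cloned()
}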

Impact: NVRTC kernel compilation takes ~200ms. Without caching: 160 matmuls × 200ms = 32s wasted. With caching: 200ms once + 160 × 0.5ms dispatch = 280ms total.

GPU Auto-Dispatch

pub fn prove_matmul_sumcheck_onchain_auto(
    a: &M31Matrix, b: &M31Matrix, c: &M31Matrix,
) -> Result<MatMulSumcheckProofOnChain, MatMulError> {
    #[cfg(feature = "cuda-runtime")]
    {
        if gpu_is_available() {
            match prove_matmul_sumcheck_onchain_gpu(a, b, c) {
                Ok(proof) => return Ok(proof),
                Err(e) => {
                    tracing::warn!("GPU failed, falling back to CPU: {e}");
                }
            }
        }
    }
    prove_matmul_sumcheck_onchain(a, b, c)  // CPU fallback
}

Critical lesson: The original k.is_power_of_two() gate blocked GPU dispatch for all real models (Qwen3-14B uses k=5120, not a power of 2). After removing this gate and adding automatic pad_matrix_pow2 inside the GPU path, all 160 matmuls now run on GPU.
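Zero padding is sound because zero rows and columns contribute nothing to the product: the padded C is the original C embedded in a larger zero matrix. A hypothetical sketch of such a helper for a row-major matrix (the real pad_matrix_pow2 signature is not shown in this document):

```rust
/// Copy a row-major rows × cols matrix into a zero-filled buffer whose
/// dimensions are rounded up to powers of two. Returns (data, rows, cols).
fn pad_matrix_pow2(data: &[u32], rows: usize, cols: usize) -> (Vec<u32>, usize, usize) {
    assert_eq!(data.len(), rows * cols);
    let (pr, pc) = (rows.next_power_of_two(), cols.next_power_of_two());
    let mut padded = vec![0u32; pr * pc];
    for r in 0..rows {
        padded[r * pc..r * pc + cols].copy_from_slice(&data[r * cols..(r + 1) * cols]);
    }
    (padded, pr, pc)
}
```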

GPU GEMM

gpu_matmul_m31_full() provides M31 matrix multiplication on GPU:

  • GEMV (m=1): Single-row dot product, 1D grid
  • GEMM (m>1): Full matrix multiply, 2D grid with 16×16 thread blocks

Used in model forward pass execution and weight commitment computation.
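A CPU reference is useful for checking GPU results, and the grid size for 16×16 blocks is a ceiling division per dimension. Both helpers are hypothetical, and mapping x to output columns and y to output rows is an assumption:

```rust
const P: u64 = (1 << 31) - 1;

/// Grid for a 16×16-thread-block GEMM with one thread per output element.
fn gemm_grid(m: usize, n: usize) -> (usize, usize) {
    const TILE: usize = 16;
    ((n + TILE - 1) / TILE, (m + TILE - 1) / TILE)
}

/// CPU reference C = A · B over M31 (row-major), reducing after every product
/// so the u64 accumulator never overflows.
fn matmul_m31(a: &[u64], b: &[u64], m: usize, k: usize, n: usize) -> Vec<u64> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    let mut c = vec![0u64; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0u64;
            for l in 0..k {
                acc = (acc + a[i * k + l] * b[l * n + j] % P) % P;
            }
            c[i * n + j] = acc;
        }
    }
    c
}
```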

GPU Element-wise Operations

| Kernel | Function | Threshold |
|---|---|---|
| m31_add_kernel | gpu_elementwise_add() | len >= 4096 |
| m31_mul_kernel | gpu_elementwise_mul() | len >= 4096 |
| m31_relu_kernel | gpu_relu() | len >= 4096 |

All operations fall back to CPU below the threshold. On a dimension mismatch they return an error instead of panicking, so callers can fall back gracefully.
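A sketch of the dispatch shape with hypothetical names and the GPU launch stubbed out. Small inputs take the CPU path (presumably because launch and transfer overhead outweigh the work), and a length mismatch is reported as an Err:

```rust
const P: u64 = (1 << 31) - 1;
const GPU_MIN_LEN: usize = 4096;

#[derive(Debug, PartialEq)]
enum ExecPath { Cpu, Gpu }

/// Element-wise M31 add that reports which path it would take.
fn elementwise_add(a: &[u32], b: &[u32]) -> Result<(Vec<u32>, ExecPath), String> {
    if a.len() != b.len() {
        return Err(format!("length mismatch: {} vs {}", a.len(), b.len()));
    }
    let path = if a.len() >= GPU_MIN_LEN { ExecPath::Gpu } else { ExecPath::Cpu };
    // Both paths compute the same values; only the CPU arithmetic is shown.
    let out = a.iter().zip(b).map(|(&x, &y)| ((x as u64 + y as u64) % P) as u32).collect();
    Ok((out, path))
}
```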

GKR GPU Operations

| Method | Description |
|---|---|
| evaluate_mle_gpu(mle, point) | MLE evaluation at a point on GPU |
| combine_blocks(block_mles, weights) | Σ_b w_b · MLE_b (SIMD block combination) |
| reduce_matmul_layer_gpu(claim, A, B, m, k, n, channel) | Full matmul reduction on GPU |
| restrict_rows(A, r_i, pk) | Fused row restriction |
| restrict_cols(B, r_j, pk) | Fused column restriction |
| sumcheck_3way(ext_w, ext_a, ext_b, n, channel) | 3-factor GPU sumcheck (reuses LogUp kernels) |
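Both block combination and MLE evaluation are linear, which is what makes SIMD block batching work: evaluating the weighted combination equals the weighted sum of the block evaluations. A CPU sketch over M31 with hypothetical signatures (the variable ordering in evaluate_mle is an assumption):

```rust
const P: u64 = (1 << 31) - 1;

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// Σ_b w_b · MLE_b, entry by entry.
fn combine_blocks(block_mles: &[Vec<u64>], weights: &[u64]) -> Vec<u64> {
    assert_eq!(block_mles.len(), weights.len());
    let n = block_mles[0].len();
    let mut out = vec![0u64; n];
    for (mle, &w) in block_mles.iter().zip(weights) {
        assert_eq!(mle.len(), n);
        for (o, &x) in out.iter_mut().zip(mle) {
            *o = add(*o, mul(w, x));
        }
    }
    out
}

/// Evaluate an MLE at `point` by folding one variable per coordinate
/// (first coordinate = top half, i.e. the most significant index bit).
fn evaluate_mle(mle: &[u64], point: &[u64]) -> u64 {
    assert_eq!(mle.len(), 1 << point.len());
    let mut v = mle.to_vec();
    for &r in point {
        let h = v.len() / 2;
        v = (0..h).map(|i| add(v[i], mul(r, sub(v[h + i], v[i])))).collect();
    }
    v[0]
}
```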

Multi-GPU Distributed Proving

Architecture

┌─────────────────────────────────┐
│       MultiGpuExecutor          │
│  ┌─────────┐  ┌─────────┐       │
│  │  GPU 0  │  │  GPU 1  │  ...  │
│  │ chunk 0 │  │ chunk 1 │       │
│  │ chunk 3 │  │ chunk 2 │       │
│  └─────────┘  └─────────┘       │
└─────────────────────────────────┘

Thread-Local Device Affinity

thread_local! {
    static CURRENT_DEVICE: std::cell::Cell<Option<usize>> = const { std::cell::Cell::new(None) };
}
  • set_thread_device(id) / get_thread_device() — set/query affinity
  • DeviceGuard::new(id) — RAII guard, restores previous device on drop (panic-safe)
  • propagate_device(parent) — for rayon worker threads (they don't inherit thread_local!)
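A self-contained sketch of these three pieces. The function names follow the list above, but their bodies and exact signatures here are assumptions:

```rust
use std::cell::Cell;

thread_local! {
    // Per-thread device affinity; None means "no explicit device".
    static CURRENT_DEVICE: Cell<Option<usize>> = const { Cell::new(None) };
}

fn set_raw(id: Option<usize>) {
    CURRENT_DEVICE.with(|d| d.set(id));
}

pub fn set_thread_device(id: usize) {
    set_raw(Some(id));
}

pub fn get_thread_device() -> Option<usize> {
    CURRENT_DEVICE.with(|d| d.get())
}

/// Sets the device for the current thread and restores the previous value when
/// dropped. Drop also runs during unwinding, which is what makes it panic-safe.
pub struct DeviceGuard {
    previous: Option<usize>,
}

impl DeviceGuard {
    pub fn new(id: usize) -> Self {
        let previous = get_thread_device();
        set_thread_device(id);
        DeviceGuard { previous }
    }
}

impl Drop for DeviceGuard {
    fn drop(&mut self) {
        set_raw(self.previous);
    }
}

/// New threads (including rayon workers) start with fresh thread-locals, so the
/// parent's affinity must be captured before spawning and re-applied inside.
pub fn propagate_device(parent: Option<usize>) {
    set_raw(parent);
}
```

Because rayon reuses worker threads, a propagated device stays set on that worker after the task unless it is wrapped in a guard.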

GpuSumcheckExecutor Per-Device Pool

GpuSumcheckExecutor::cached()              // uses thread-local device
GpuSumcheckExecutor::cached_for_device(id) // explicit device

Per-device singleton pool via OnceLock<Mutex<HashMap<usize, Arc<Self>>>>. Because cached() resolves the device from the thread-local affinity, a single cached() call routes all downstream GPU operations on that thread to the correct device.
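A minimal sketch of such a pool with a hypothetical stand-in type (the real executor compiles CUDA kernels in its constructor):

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};

/// Hypothetical stand-in for a per-device executor.
struct DeviceExecutor {
    device: usize,
}

impl DeviceExecutor {
    fn new(device: usize) -> Result<Self, String> {
        // The real constructor would select the device and compile kernels.
        Ok(DeviceExecutor { device })
    }
}

/// One executor per device, created on first request and shared afterwards.
fn cached_for_device(device: usize) -> Result<Arc<DeviceExecutor>, String> {
    static POOL: OnceLock<Mutex<HashMap<usize, Arc<DeviceExecutor>>>> = OnceLock::new();
    let pool = POOL.get_or_init(|| Mutex::new(HashMap::new()));
    let mut map = pool.lock().map_err(|_| "executor pool poisoned".to_string())?;
    if let Some(exec) = map.get(&device) {
        return Ok(Arc::clone(exec));
    }
    let exec = Arc::new(DeviceExecutor::new(device)?);
    map.insert(device, Arc::clone(&exec));
    Ok(exec)
}
```

Holding the pool lock during construction keeps the code simple but serializes first-time initialization across devices; a per-entry OnceLock would let devices initialize in parallel.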

Chunk Partitioning

MultiGpuExecutor::partition_chunks(chunks) uses greedy bin-packing with an 80% memory safety margin:

  1. Sort chunks by estimated memory (descending)
  2. Assign each chunk to the GPU with the most free memory
  3. Warn on oversized chunks exceeding single GPU capacity
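A sketch of the three steps, with a hypothetical signature and plain numbers for per-chunk memory estimates and device capacities. With two equal GPUs and chunk sizes 5, 4, 3, 2 it reproduces the placement drawn in the diagram above: chunks 0 and 3 on GPU 0, chunks 1 and 2 on GPU 1.

```rust
use std::cmp::Reverse;

/// Greedy placement of chunks onto GPUs (sizes in any consistent unit).
/// Returns (device index per chunk, sorted indices of oversized chunks).
fn partition_chunks(chunk_sizes: &[u64], gpu_memory: &[u64]) -> (Vec<usize>, Vec<usize>) {
    assert!(!gpu_memory.is_empty());
    // 80% safety margin: plan against 8/10 of each device's memory.
    let budgets: Vec<i64> = gpu_memory.iter().map(|&m| (m * 8 / 10) as i64).collect();
    let max_budget = *budgets.iter().max().unwrap();
    let mut free = budgets;

    // 1. Largest chunks first (stable sort keeps input order among equal sizes).
    let mut order: Vec<usize> = (0..chunk_sizes.len()).collect();
    order.sort_by_key(|&c| Reverse(chunk_sizes[c]));

    let mut assignment = vec![0; chunk_sizes.len()];
    let mut oversized = Vec::new();
    for c in order {
        let size = chunk_sizes[c] as i64;
        // 3. Flag chunks that cannot fit on any single GPU, but still place them.
        if size > max_budget {
            eprintln!("warning: chunk {c} ({size}) exceeds the largest per-GPU budget ({max_budget})");
            oversized.push(c);
        }
        // 2. GPU with the most free memory; ties go to the lowest index.
        let g = (0..free.len()).max_by_key(|&g| (free[g], Reverse(g))).unwrap();
        free[g] -= size;
        assignment[c] = g;
    }
    oversized.sort_unstable();
    (assignment, oversized)
}
```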

Proving Flow

prove_model_chunked_multi_gpu(graph, input, weights, executor)
  1. CPU forward pass → chunk inputs
  2. partition_chunks() → device assignments
  3. std::thread::scope with DeviceGuard per chunk thread
  4. Collect all results (every error, not just the first)
  5. Return MultiGpuProvingResult with per-device stats
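Steps 3 and 4 can be sketched with std::thread::scope. Here prove_chunk is a hypothetical stand-in that fails on every third chunk, and the per-thread DeviceGuard is omitted to keep the example self-contained. Joining every handle before inspecting results is what makes it possible to report all failures:

```rust
use std::thread;

/// Hypothetical stand-in for proving one chunk on a given device.
fn prove_chunk(chunk: usize, device: usize) -> Result<String, String> {
    if chunk % 3 == 2 {
        Err(format!("chunk {chunk} failed on GPU {device}"))
    } else {
        Ok(format!("proof({chunk})@gpu{device}"))
    }
}

/// Prove each chunk on its assigned device in a scoped thread and collect
/// every failure instead of stopping at the first one.
fn prove_all(assignment: &[usize]) -> Result<Vec<String>, Vec<String>> {
    let results: Vec<Result<String, String>> = thread::scope(|s| {
        let handles: Vec<_> = assignment
            .iter()
            .enumerate()
            .map(|(chunk, &device)| s.spawn(move || prove_chunk(chunk, device)))
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().unwrap_or_else(|_| Err("prover thread panicked".to_string())))
            .collect()
    });
    let (oks, errs): (Vec<_>, Vec<_>) = results.into_iter().partition(|r| r.is_ok());
    if errs.is_empty() {
        Ok(oks.into_iter().map(Result::unwrap).collect())
    } else {
        Err(errs.into_iter().map(Result::unwrap_err).collect())
    }
}
```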

CLI

prove-model --model model.onnx --gpu --multi-gpu --chunk-budget-gb 8

Pipeline Optimizations

Three-Phase Pipeline

Phase 1: Forward pass (CPU)           │  2-3s
Phase 2: MatMul proving (GPU)         │  30-35s (160 matmuls)
Phase 3: Unified STARK (GPU)          │  5-7s
                                      ────────
                                      ~40s total (Qwen3-14B)

Parallel Merkle Tree

PoseidonMerkleTree::build_parallel() + root_only_parallel():

  • par_chunks(2) for layers with >= 256 pairs
  • root_only_parallel: no intermediate layer storage (just the root)
  • Used in commit_mle_root_only() for batch entry prep
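The difference between the full build and the root-only variant is what is kept between layers. This sketch uses a toy 2-to-1 hash (a stand-in for Poseidon, not secure) and sequential chunks(2) instead of rayon's par_chunks(2) to stay std-only; names are hypothetical and the leaf count is assumed to be a power of two.

```rust
/// Toy 2-to-1 compression: a stand-in for Poseidon, NOT cryptographically secure.
fn toy_hash(left: u64, right: u64) -> u64 {
    (left ^ right.rotate_left(17)).wrapping_mul(0x9E37_79B9_7F4A_7C15)
}

/// Hash one layer pairwise; with rayon this would be `par_chunks(2)` for large layers.
fn hash_layer(layer: &[u64]) -> Vec<u64> {
    layer.chunks(2).map(|p| toy_hash(p[0], p[1])).collect()
}

/// Full tree: keeps every layer (needed later to open authentication paths).
fn build_tree(leaves: &[u64]) -> Vec<Vec<u64>> {
    assert!(leaves.len().is_power_of_two());
    let mut layers = vec![leaves.to_vec()];
    while layers.last().unwrap().len() > 1 {
        let next = hash_layer(layers.last().unwrap());
        layers.push(next);
    }
    layers
}

/// Root only: each layer replaces the previous one, so at most two layers are alive.
fn root_only(leaves: &[u64]) -> u64 {
    assert!(leaves.len().is_power_of_two());
    let mut layer = leaves.to_vec();
    while layer.len() > 1 {
        layer = hash_layer(&layer);
    }
    layer[0]
}
```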

Pipelined Weight Commitment

std::thread::scope(|s| {
    s.spawn(|| commit_weights_background(...));  // Runs during proving
    prove_layers(...);
});

No added wall-clock latency: weight commitment runs on a background thread concurrently with proving, so it is hidden as long as it finishes before proving does.

Weight Loading

Two-phase bulk extract + parallel process:

  1. Phase 1: Extract raw f32 from ALL SafeTensor shards
  2. Phase 2: Single rayon parallel pass (160 tasks for Qwen3-14B)
  3. madvise(MADV_SEQUENTIAL) and madvise(MADV_WILLNEED) on shard mmaps for OS prefetch (two calls: madvise advice values are not bit flags, so OR-ing them is not equivalent)

Streaming Serialization

serialize_ml_proof_to_file() writes via BufWriter (1 MB buffer) instead of Vec<String> + .join(), eliminating large intermediate allocations for proofs.
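A minimal sketch of the streaming pattern with a hypothetical write_felts helper and an assumed hex-string output format (the real proof format is not shown here). Taking any W: Write lets the same code target a File or an in-memory buffer:

```rust
use std::io::{self, BufWriter, Write};

/// Stream values straight into the writer instead of building a Vec<String>
/// and joining it, so no proof-sized intermediate string is allocated.
fn write_felts<W: Write>(out: W, felts: &[u32]) -> io::Result<()> {
    let mut w = BufWriter::with_capacity(1 << 20, out); // 1 MB buffer
    w.write_all(b"[")?;
    for (i, f) in felts.iter().enumerate() {
        if i > 0 {
            w.write_all(b",")?;
        }
        write!(w, "\"{f:#x}\"")?;
    }
    w.write_all(b"]\n")?;
    w.flush()
}
```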

Memory Management

Tiled MatMul

For matrices exceeding GPU memory, tiled proving splits the k-dimension:

pub struct TiledMatMulConfig {
    pub max_tile_k: usize,  // Power of 2
}

impl TiledMatMulConfig {
    pub fn from_memory_budget(m: usize, k: usize, n: usize, budget: usize) -> Self {
        // Halve tile_k until estimated memory fits budget
    }
}

Memory budget: 64 GB (configured for H100's 80 GB HBM). This eliminates tiling for all production models — the previous 4 GB budget forced all FFN matmuls into slow tiled mode.
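The halving loop in from_memory_budget can be sketched as below. The memory model in estimate_bytes is a hypothetical placeholder (M31 tiles of A and B plus two QM31 restricted vectors) since the document does not give the real estimate, so only the control flow should be read as representative.

```rust
/// Hypothetical memory model: 4-byte M31 tiles of A (m × tile_k) and
/// B (tile_k × n), plus two 16-byte QM31 vectors of length tile_k.
fn estimate_bytes(m: usize, tile_k: usize, n: usize) -> usize {
    4 * (m * tile_k + tile_k * n) + 2 * 16 * tile_k
}

pub struct TiledMatMulConfig {
    pub max_tile_k: usize, // Power of 2
}

impl TiledMatMulConfig {
    pub fn from_memory_budget(m: usize, k: usize, n: usize, budget: usize) -> Self {
        // Start from the padded k and halve until the estimate fits (or tile_k = 1).
        let mut tile_k = k.next_power_of_two();
        while tile_k > 1 && estimate_bytes(m, tile_k, n) > budget {
            tile_k /= 2;
        }
        TiledMatMulConfig { max_tile_k: tile_k }
    }
}
```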

GPU Dispatch Thresholds

pub struct GpuThresholds;
impl GpuThresholds {
    const DEFAULT_MLE: u32 = 14;         // 2^14 = 16K elements
    const DEFAULT_COLUMN_OPS: u32 = 14;
    const DEFAULT_MERKLE: u32 = 14;
}

Set OBELYSK_GPU_THRESHOLD=0 to force every operation onto the GPU (useful on H100/A100).
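A sketch of how such an override might be parsed and applied. Reading a threshold of 14 as "use the GPU from 2^14 elements upward" follows the comment on DEFAULT_MLE; the helper names are hypothetical, and a missing or malformed value falls back to the default.

```rust
use std::env;

const DEFAULT_MLE_LOG2: u32 = 14; // 2^14 = 16K elements

/// Parse an override value, falling back to the default if missing or malformed.
fn effective_threshold(default_log2: u32, override_value: Option<&str>) -> u32 {
    override_value
        .and_then(|v| v.trim().parse::<u32>().ok())
        .unwrap_or(default_log2)
}

/// Read the override from the environment.
fn mle_threshold_from_env() -> u32 {
    let v = env::var("OBELYSK_GPU_THRESHOLD").ok();
    effective_threshold(DEFAULT_MLE_LOG2, v.as_deref())
}

/// GPU when the problem has at least 2^threshold elements; thresholds of 128
/// or more never dispatch instead of overflowing the shift.
fn use_gpu(n_elements: usize, threshold_log2: u32) -> bool {
    1u128
        .checked_shl(threshold_log2)
        .map_or(false, |min| n_elements as u128 >= min)
}
```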

Kernel Launch Overhead

| Kernel | Compile Time | Per-Launch |
|---|---|---|
| sumcheck_round | ~60 ms (one-time) | ~5 μs |
| mle_fold | shared module | ~5 μs |
| restrict_rows/cols | ~50 ms (lazy) | ~5 μs |
| logup_3way_round | ~50 ms (lazy) | ~5 μs |

Total overhead for 14-round sumcheck: 14 × (5+5) μs = 140 μs (negligible vs ~48ms total).