SIMD Block Batching

Overview

Transformer models repeat identical blocks (layers with shared weights). Instead of proving each block independently, SIMD block batching proves N identical blocks in a single GKR pass by introducing a randomized block selection dimension.

For Qwen3-14B with 8 SIMD blocks, this reduces the GKR from 8 independent proofs to 1 proof with log₂(8) = 3 extra sumcheck rounds per layer.

Complexity reduction: from O(N × depth × log(width)) to O(depth × (log(width) + log(N))) sumcheck rounds, where N = number of identical blocks, depth = layers per block, and width = layer dimension.

Configuration

```rust
pub struct SIMDBatchConfig {
    pub num_blocks: usize,            // N identical blocks
    pub template_range: Range<usize>, // layer indices for one block
    pub simd_log_size: usize,         // ceil(log₂(N))
}
```

The LayeredCircuit detects identical block structure during compilation and populates simd_config: Option<SIMDBatchConfig> automatically.

Core Idea

Given N blocks with identical circuit structure, the verifier:

  1. Draws random SIMD challenges r_simd ∈ F^{log₂(N)}
  2. Computes block weights via Lagrange basis: w_b = eq(r_simd, b) for each block index b
  3. The combined output is combined[i] = Σ_b w_b · output_b[i]

If the combined output satisfies the GKR reduction, then with overwhelming probability all individual blocks are correct (Schwartz-Zippel lemma).

Block Weights

The SIMD weights are the multilinear Lagrange basis evaluated at the random point:

w_b = Π_{i=0}^{log₂(N)-1} [(1 - r_simd[i])(1 - b_i) + r_simd[i] · b_i]

where b_i is the i-th bit of block index b. For N=2:

  • w_0 = 1 - r_simd[0]
  • w_1 = r_simd[0]
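A minimal sketch of this weight computation, using a toy M31 prime field in place of the prover's extension field (`block_weights` is a hypothetical helper for illustration, not the codebase's API):

```rust
// Toy M31 prime field standing in for the prover's SecureField.
const P: u64 = (1 << 31) - 1;

fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// w_b = Π_i [(1 - r[i])·(1 - b_i) + r[i]·b_i], with b_i the i-th bit of b.
fn block_weights(r_simd: &[u64]) -> Vec<u64> {
    let n = 1usize << r_simd.len();
    (0..n)
        .map(|b| {
            r_simd.iter().enumerate().fold(1u64, |acc, (i, &r)| {
                let factor = if (b >> i) & 1 == 1 { r } else { sub(1, r) };
                mul(acc, factor)
            })
        })
        .collect()
}
```

For any choice of r_simd the N weights sum to 1, since the multilinear Lagrange basis partitions unity over the boolean hypercube.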

Shared-Weight vs Dual-Operand Matmuls

Shared-Weight (degree-2)

When only the input A varies per block but weight B is shared:

Σ_b w_b · (A_b × B) = (Σ_b w_b · A_b) × B

Linearity allows combining A first, then running a standard degree-2 sumcheck:

  • GPU combines: combined_A = Σ_b w_b · MLE(A_b) via gpu.combine_blocks()
  • Standard matmul sumcheck: restrict(combined_A, r_i) · restrict_col(B, r_j)
  • Same number of rounds as non-SIMD (no extra overhead)

Used for: output projection, Q/K/V projections.
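The linearity identity above can be checked directly in a toy field. The helpers below (`combine_blocks`, `dot`) are illustrative stand-ins, not the GPU API; `dot` computes one entry of A × B:

```rust
// Toy M31 prime field standing in for the prover's SecureField.
const P: u64 = (1 << 31) - 1;
fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// Inner product Σ_k a[k]·b[k], i.e. one entry of the matmul A × B.
fn dot(a: &[u64], b: &[u64]) -> u64 {
    a.iter().zip(b).fold(0, |acc, (&x, &y)| add(acc, mul(x, y)))
}

/// combined[k] = Σ_b w_b · a_blocks[b][k] (what a combine-blocks step computes).
fn combine_blocks(a_blocks: &[Vec<u64>], w: &[u64]) -> Vec<u64> {
    (0..a_blocks[0].len())
        .map(|k| {
            a_blocks
                .iter()
                .zip(w)
                .fold(0, |acc, (a, &wb)| add(acc, mul(a[k], wb)))
        })
        .collect()
}
```

Combining the per-block rows first and then multiplying by the shared column gives the same value as weighting the per-block products, which is exactly why no extra sumcheck rounds are needed.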

Dual-Operand (degree-3, block-extended)

When both A and B vary per block (per-head attention matmuls):

Σ_b w_b · Σ_k A_b(r_row, k) · B_b(k, r_col)

The linearity trick fails because (Σ_b w_b · A_b) × (Σ_c w_c · B_c) includes cross-terms where b ≠ c.

Solution: Block-extended 3-factor sumcheck.

Define extended MLEs of length N × K:

ext_w[b·K + k] = w_b          (block weight, replicated K times)
ext_a[b·K + k] = f_a_b[k]     (restricted A for block b)
ext_b[b·K + k] = f_b_b[k]     (restricted B for block b)

Then: claim = Σ_i ext_w[i] · ext_a[i] · ext_b[i]

This is a standard 3-factor sumcheck over log₂(N×K) variables — exactly log₂(N) extra rounds compared to non-SIMD. The degree-3 round polynomial uses RoundPolyDeg3 with Newton interpolation at t = 0, 1, 2, 3.
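A sketch of the extended-MLE construction and the resulting claim, mirroring the ext_w/ext_a/ext_b layout above in a toy field (the function name and field helpers are illustrative assumptions):

```rust
// Toy M31 prime field standing in for the prover's SecureField.
const P: u64 = (1 << 31) - 1;
fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// Build ext_w/ext_a/ext_b of length N·K, then return Σ_i ext_w[i]·ext_a[i]·ext_b[i].
fn block_extended_claim(w: &[u64], f_a: &[Vec<u64>], f_b: &[Vec<u64>]) -> u64 {
    let (n, k) = (w.len(), f_a[0].len());
    let mut ext_w = Vec::with_capacity(n * k);
    let mut ext_a = Vec::with_capacity(n * k);
    let mut ext_b = Vec::with_capacity(n * k);
    for b in 0..n {
        for i in 0..k {
            ext_w.push(w[b]);      // block weight, replicated K times
            ext_a.push(f_a[b][i]); // restricted A for block b
            ext_b.push(f_b[b][i]); // restricted B for block b
        }
    }
    // claim = Σ_i ext_w[i] · ext_a[i] · ext_b[i]
    (0..n * k).fold(0, |acc, i| add(acc, mul(ext_w[i], mul(ext_a[i], ext_b[i]))))
}
```

The flattened triple product equals Σ_b w_b · Σ_k A_b(r_row, k) · B_b(k, r_col) by construction, so the dual-operand claim reduces to an ordinary 3-factor sumcheck over the extended arrays.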

The verifier's final check uses the eq evaluation:

running_sum == eq(r_simd, block_challenges) · final_a · final_b

where block_challenges are the first log₂(N) sumcheck challenges (corresponding to the block-index variables).

Used for: per-head score matmul (Q_h × K_h^T), per-head context matmul (softmax_h × V_h).

The proof variant LayerProof::MatMulDualSimd carries the additional n_block_vars: usize field so the verifier knows how many sumcheck rounds correspond to block selection vs inner dimension:

```rust
MatMulDualSimd {
    round_polys: Vec<RoundPolyDeg3>,  // log₂(N×K) polynomials
    final_a_eval: SecureField,
    final_b_eval: SecureField,
    n_block_vars: usize,              // = simd_log_size
}
```

Verification split: The first n_block_vars challenges select the block, the remaining log₂(K) challenges select the inner dimension position. The final check:

running_sum == eq(r_simd, block_challenges[0..n_block_vars]) × final_a × final_b
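A sketch of this final check in a toy field; `eq_eval` and `final_check` are hypothetical names, but the formula matches the check stated above:

```rust
// Toy M31 prime field standing in for the prover's SecureField.
const P: u64 = (1 << 31) - 1;
fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// eq(r, c) = Π_i [(1 - r_i)(1 - c_i) + r_i·c_i], for arbitrary field points c.
fn eq_eval(r: &[u64], c: &[u64]) -> u64 {
    r.iter().zip(c).fold(1, |acc, (&ri, &ci)| {
        mul(acc, add(mul(sub(1, ri), sub(1, ci)), mul(ri, ci)))
    })
}

/// Shape of the verifier's last step for a MatMulDualSimd proof.
fn final_check(
    running_sum: u64,
    r_simd: &[u64],
    block_challenges: &[u64],
    final_a: u64,
    final_b: u64,
) -> bool {
    running_sum == mul(eq_eval(r_simd, block_challenges), mul(final_a, final_b))
}
```

When the block challenges happen to be boolean, eq_eval collapses to the block weight w_b, which is why the same formula serves both the weights and the final check.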

Attention Layer Batching

Sub-matmul Types

| Sub-matmul | A varies? | B varies? | Protocol |
|---|---|---|---|
| Output: concat × W_O | Yes (concat differs) | No (shared weight) | Shared-weight (degree-2) |
| Context: softmax_h × V_h | Yes | Yes | Dual-operand (degree-3) |
| Score: Q_h × K_h^T | Yes | Yes | Dual-operand (degree-3) |
| V proj: input × W_V | Yes | No (shared weight) | Shared-weight (degree-2) |
| K proj: input × W_K | Yes | No (shared weight) | Shared-weight (degree-2) |
| Q proj: input × W_Q | Yes | No (shared weight) | Shared-weight (degree-2) |

Score Matrix Scaling Gotcha

score_matrices[h] in AttentionIntermediates includes the 1/√d_k scaling factor applied after Q_h × K_h^T. The sumcheck operates on the raw unscaled product. You must compute matmul_m31(Q_h, K_h^T) fresh for the combined output MLE — using score_matrices[h] directly causes an exact √d_k factor mismatch.

For d_k = 64 (typical): scale_inv = 1/8, so the mismatch would be 8×.
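A quick numeric check of that factor (the function and its inputs are illustrative, not from the codebase):

```rust
/// Ratio between a raw Q_h × K_h^T entry and the pre-scaled entry that
/// score_matrices[h] would hold; this is the factor the sumcheck claim
/// would be off by if the scaled matrix were reused directly.
fn scale_mismatch_factor(d_k: f64, raw_entry: f64) -> f64 {
    let stored = raw_entry / d_k.sqrt(); // scaled value held in score_matrices[h]
    raw_entry / stored                   // mismatch factor = sqrt(d_k)
}
```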

Prover Entry Point

reduce_attention_layer_simd_gpu(
    gpu, output_claim, block_executions,
    attn_weights, config, block_weights, r_simd, channel,
) → Result<(LayerProof::Attention, GKRClaim), GKRError>

Verifier Dispatch

verify_attention_reduction accepts r_simd: Option<&[SecureField]>:

  • None → non-SIMD path, all sub-proofs must be LayerProof::MatMul
  • Some(r_simd) → SIMD path, per-head sub-proofs can be MatMulDualSimd

Sub-proof 0 (output projection) and the Q projection (last sub-proof) must always be LayerProof::MatMul.

Activation and Non-Linear Layers

Activation Fallback

Activation functions (ReLU, GELU, softmax_exp) use LogUp lookup tables. When combining across blocks, Σ_b w_b · table_lookup(x_b) produces QM31 sums that don't correspond to valid table entries — the combined value is a linear combination of table outputs, not a table output itself.

Solution: Activation layers fall back to single-block CPU proving. The SIMD prover uses the first block's execution data for LogUp proofs, and the per-block activations are checked individually. This is sound because activations are validated per-element, not per-block.

The same fallback applies to Dequantize and RMSNorm layers (both use LogUp tables).

LayerNorm Non-Linearity

LayerNorm involves mean and rsqrt — non-linear operations that cause cross-terms when combining across blocks:

Σ_b w_b · centered_b × Σ_c w_c · rsqrt_c  ≠  Σ_b w_b · (centered_b × rsqrt_b)

Combined-Product MLE Solution

Pre-compute the per-element product before combining:

combined_product[i] = Σ_b w_b · (centered_b[i] × rsqrt_b[i])

On the boolean hypercube, this equals the combined output. The eq-sumcheck proves:

Σ eq(r, x) · combined_product(x) · 1 = output_claim.value

The constant-1 MLE as the second factor means rsqrt_final = 1 after all folds.
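Both the cross-term failure and the combined-product fix can be seen numerically for a single element index i (toy field, hypothetical helper names):

```rust
// Toy M31 prime field standing in for the prover's SecureField.
const P: u64 = (1 << 31) - 1;
fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// combined_product = Σ_b w_b · (centered_b · rsqrt_b): multiply per block first.
fn combined_product(w: &[u64], centered: &[u64], rsqrt: &[u64]) -> u64 {
    (0..w.len()).fold(0, |acc, b| add(acc, mul(w[b], mul(centered[b], rsqrt[b]))))
}

/// The broken alternative: combine each factor across blocks, then multiply.
/// This picks up b ≠ c cross-terms and does not equal the combined output.
fn combine_then_multiply(w: &[u64], centered: &[u64], rsqrt: &[u64]) -> u64 {
    let c = (0..w.len()).fold(0, |acc, b| add(acc, mul(w[b], centered[b])));
    let r = (0..w.len()).fold(0, |acc, b| add(acc, mul(w[b], rsqrt[b])));
    mul(c, r)
}
```

Only the first function matches Σ_b w_b · output_b[i] on the hypercube, which is why the prover must form the per-element product before combining blocks.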

Prover Flow

  1. Per-block CPU forward pass → compute product/mean/rsqrt/input MLEs per block
  2. GPU combine → 4× gpu.combine_blocks() calls for product, mean, rsqrt, and input
  3. GPU MLE evaluation at claim point
  4. Degree-3 eq-sumcheck over combined_product × ones
  5. LogUp skipped → logup_proof: None

The simd_combined: true flag in LayerProof::LayerNorm signals to the verifier that LogUp validation should be skipped — the QM31-combined rsqrt values don't map to table entries, but soundness is maintained because the combined-product approach verifies the overall normalization relationship.

SIMD Proving Flow

prove_gkr_simd_gpu(circuit, block_executions, weights, channel)
  1. Seed channel: mix dimensions, block count
  2. Draw SIMD challenges: r_simd = channel.draw_qm31s(log₂(n_blocks))
  3. Compute block weights: Lagrange basis evaluation
  4. Per-block forward passes: execute all blocks to get intermediates
  5. Walk layers (output → input):
    • MatMul (shared weight): reduce_matmul_layer_simd_gpu() — combine A, standard sumcheck
    • MatMul (dual operand): reduce_matmul_layer_dual_simd_gpu() — block-extended 3-factor
    • Activation: reduce_activation_layer() with combined input/output MLEs
    • LayerNorm: reduce_layernorm_layer_simd() with combined-product approach
    • Attention: reduce_attention_layer_simd_gpu() — mixed shared/dual sub-proofs
  6. Return proof with combined input claim

GKR Proof Structure

The complete SIMD GKR proof contains:

```rust
pub struct GKRProof {
    pub layer_proofs: Vec<LayerProof>,         // one per template layer
    pub output_claim: GKRClaim,                // MLE evaluation at output
    pub input_claim: GKRClaim,                 // MLE evaluation at input
    pub weight_commitments: Vec<FieldElement>, // Poseidon hash per weight matrix
    pub weight_openings: Vec<MleOpeningProof>, // Merkle proofs for weight MLEs
    pub weight_claims: Vec<WeightClaim>,
    pub io_commitment: FieldElement,           // Poseidon(inputs || outputs)
    pub deferred_proofs: Vec<DeferredProof>,   // DAG skip-connection proofs
}
```

Each LayerProof variant carries the data needed for its reduction:

| Variant | Fields | When Used |
|---|---|---|
| MatMul | round_polys, final_a, final_b | Shared-weight SIMD matmul |
| MatMulDualSimd | round_polys, final_a, final_b, n_block_vars | Both-operand-varying matmul |
| Add | lhs_eval, rhs_eval, trunk_idx | Residual connection split |
| Activation | logup_proof, input_eval, output_eval | Per-element activation |
| LayerNorm | logup_proof, linear_round_polys, simd_combined | Combined-product path |
| RMSNorm | logup_proof, linear_round_polys, simd_combined | Combined-product path |

Verification

verify_gkr_simd(circuit, proof, combined_output, channel)

Mirrors the prover's channel state exactly:

  1. Same seeding and SIMD challenge derivation: channel.mix_u64(num_blocks), then r_simd = channel.draw_qm31s(simd_log_size)
  2. Reconstructs block weights via Lagrange basis evaluation
  3. Per-layer verification with r_simd context — dispatches to appropriate verifier per LayerProof variant
  4. MatMulDualSimd sub-proofs verified via verify_matmul_dual_simd_reduction which:
    • Replays degree-3 round polynomials
    • Verifies block-challenge consistency: eq(r_simd, challenges[0..n_block_vars])
    • Checks final evaluation: running_sum == eq_eval × final_a × final_b
  5. Final input claim checked against public input MLE

Performance Analysis

Proving Cost

| Configuration | Independent Proofs | SIMD Proof | Savings |
|---|---|---|---|
| N=2 blocks, depth=9 | 18 layer reductions | 9 + 1 extra round each | 47% fewer rounds |
| N=4 blocks, depth=9 | 36 layer reductions | 9 + 2 extra rounds each | 72% fewer rounds |
| N=8 blocks, depth=9 | 72 layer reductions | 9 + 3 extra rounds each | 85% fewer rounds |
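The percentages above can be reproduced under a simplifying assumption: each layer reduction takes 16 sumcheck rounds (log₂(width) = 16) and every SIMD layer pays log₂(N) extra rounds. The helper below is purely illustrative arithmetic, not a function from the codebase:

```rust
/// Fraction of total sumcheck rounds saved by one SIMD proof versus
/// N independent proofs, assuming every layer pays log₂(N) extra rounds.
fn round_savings(n_blocks: u32, depth: u32, rounds_per_layer: u32) -> f64 {
    let log_n = n_blocks.trailing_zeros(); // N is assumed a power of two
    let independent = (n_blocks * depth * rounds_per_layer) as f64;
    let simd = (depth * (rounds_per_layer + log_n)) as f64;
    1.0 - simd / independent
}
```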

SIMD Overhead per Layer

For each layer in the template, SIMD adds:

| Layer Type | Extra Work | Extra Rounds |
|---|---|---|
| MatMul (shared weight) | GPU combine_blocks() | 0 |
| MatMul (dual operand) | Extended MLE construction | log₂(N) |
| Add | GPU combine 2 operands | 0 |
| Activation | None (CPU fallback) | 0 |
| LayerNorm | GPU combine 4 MLEs | 0 |

The dual-operand matmul is the only case with extra sumcheck rounds. For Qwen3-14B attention (40 Q heads, 8 KV heads), per-head score and context matmuls are dual-operand — adding log₂(N) rounds to each of 2 × 40 = 80 per-head matmuls.

Verifier Cost

SIMD verification is nearly identical to non-SIMD: the verifier walks the same number of template layers, checking one proof per layer. The only additional work is:

  1. Computing N block weights from log₂(N) challenges: O(N) field multiplications
  2. Computing combined output MLE: O(N × output_size) additions
  3. Verifying log₂(N) extra sumcheck rounds for dual-operand matmuls

For N=8: ~11 extra field operations per dual-operand matmul.