SIMD Block Batching
Overview
Transformer models repeat identical blocks (layers with shared weights). Instead of proving each block independently, SIMD block batching proves N identical blocks in a single GKR pass by introducing a randomized block selection dimension.
For Qwen3-14B with 8 SIMD blocks, this reduces the GKR from 8 independent proofs to 1 proof with log₂(8) = 3 extra sumcheck rounds per layer.
Complexity reduction: from O(N × depth × log(width)) sumcheck rounds to O(depth × (log(width) + log(N))), where N = number of identical blocks, depth = layers per block, width = layer dimension. The log(N) term is paid per layer only where both operands vary (see dual-operand matmuls below).
Configuration
pub struct SIMDBatchConfig {
pub num_blocks: usize, // N identical blocks
pub template_range: Range<usize>, // layer indices for one block
pub simd_log_size: usize, // ceil(log₂(N))
}
The LayeredCircuit detects identical block structure during compilation and populates simd_config: Option<SIMDBatchConfig> automatically.
Core Idea
Given N blocks with identical circuit structure, the verifier:
- Draws random SIMD challenges r_simd ∈ F^{log₂(N)}
- Computes block weights via the Lagrange basis: w_b = eq(r_simd, b) for each block index b
- Forms the combined output: combined[i] = Σ_b w_b · output_b[i]
If the combined output satisfies the GKR reduction, then with overwhelming probability all individual blocks are correct (Schwartz-Zippel lemma).
Block Weights
The SIMD weights are the multilinear Lagrange basis evaluated at the random point:
w_b = Π_{i=0}^{log₂(N)-1} [(1 - r_simd[i])(1 - b_i) + r_simd[i] · b_i]
where b_i is the i-th bit of block index b. For N=2:
w_0 = 1 - r_simd[0]
w_1 = r_simd[0]
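The weight formula can be exercised directly. The sketch below is illustrative only: a toy u64 prime field stands in for the codebase's SecureField, and `block_weights` is a hypothetical helper, not the crate's API.

```rust
// Toy modulus for illustration; the real code uses QM31/SecureField arithmetic.
const P: u64 = (1 << 31) - 1;

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }
fn sub(a: u64, b: u64) -> u64 { (P + a - b) % P }

/// w_b = Π_i [(1 - r_simd[i])(1 - b_i) + r_simd[i] · b_i], b_i = i-th bit of b.
fn block_weights(r_simd: &[u64]) -> Vec<u64> {
    let n = 1usize << r_simd.len();
    (0..n)
        .map(|b| {
            r_simd.iter().enumerate().fold(1u64, |acc, (i, &r)| {
                let factor = if (b >> i) & 1 == 1 { r } else { sub(1, r) };
                mul(acc, factor)
            })
        })
        .collect()
}

fn main() {
    // N = 2 recovers the closed form above: w_0 = 1 - r_simd[0], w_1 = r_simd[0].
    assert_eq!(block_weights(&[5]), vec![sub(1, 5), 5]);
    // The eq basis is a partition of unity: the weights always sum to 1.
    let total = block_weights(&[5, 7]).iter().fold(0, |s, &x| add(s, x));
    assert_eq!(total, 1);
    println!("ok");
}
```

The partition-of-unity check is a useful sanity test when debugging weight derivation against the channel transcript.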
Shared-Weight vs Dual-Operand Matmuls
Shared-Weight (degree-2)
When only the input A varies per block but weight B is shared:
Σ_b w_b · (A_b × B) = (Σ_b w_b · A_b) × B
Linearity allows combining A first, then running a standard degree-2 sumcheck:
- GPU combines: combined_A = Σ_b w_b · MLE(A_b) via gpu.combine_blocks()
- Standard matmul sumcheck: restrict(combined_A, r_i) · restrict_col(B, r_j)
- Same number of rounds as non-SIMD (no extra overhead)
Used for: output projection, Q/K/V projections.
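The linearity identity is easy to confirm numerically. This is a minimal sketch with plain u64 matrices standing in for field-element MLEs; `matmul` is a local helper, not the crate's API.

```rust
// Naive matrix multiply over u64 (field arithmetic elided for clarity).
fn matmul(a: &[Vec<u64>], b: &[Vec<u64>]) -> Vec<Vec<u64>> {
    let (m, k, n) = (a.len(), b.len(), b[0].len());
    let mut out = vec![vec![0u64; n]; m];
    for i in 0..m {
        for l in 0..k {
            for j in 0..n {
                out[i][j] += a[i][l] * b[l][j];
            }
        }
    }
    out
}

fn main() {
    let w = [3u64, 5]; // block weights w_0, w_1
    let a0 = vec![vec![1u64, 2], vec![3, 4]]; // A for block 0
    let a1 = vec![vec![5u64, 6], vec![7, 8]]; // A for block 1
    let b = vec![vec![9u64, 1], vec![2, 4]];  // shared weight B

    // LHS: Σ_b w_b · (A_b × B)
    let (p0, p1) = (matmul(&a0, &b), matmul(&a1, &b));
    let lhs: Vec<Vec<u64>> = (0..2)
        .map(|i| (0..2).map(|j| w[0] * p0[i][j] + w[1] * p1[i][j]).collect())
        .collect();

    // RHS: (Σ_b w_b · A_b) × B — combine A first, multiply once.
    let combined_a: Vec<Vec<u64>> = (0..2)
        .map(|i| (0..2).map(|j| w[0] * a0[i][j] + w[1] * a1[i][j]).collect())
        .collect();
    let rhs = matmul(&combined_a, &b);

    assert_eq!(lhs, rhs); // linearity: the two orders agree exactly
    println!("ok");
}
```

Because the two sides agree, the prover only ever runs one sumcheck over the combined operand, which is why shared-weight matmuls cost no extra rounds.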
Dual-Operand (degree-3, block-extended)
When both A and B vary per block (per-head attention matmuls):
Σ_b w_b · Σ_k A_b(r_row, k) · B_b(k, r_col)
The linearity trick fails because (Σ_b w_b · A_b) × (Σ_c w_c · B_c) includes cross-terms where b ≠ c.
Solution: Block-extended 3-factor sumcheck.
Define extended MLEs of length N × K:
ext_w[b·K + k] = w_b (block weight, replicated K times)
ext_a[b·K + k] = f_a_b[k] (restricted A for block b)
ext_b[b·K + k] = f_b_b[k] (restricted B for block b)
Then: claim = Σ_i ext_w[i] · ext_a[i] · ext_b[i]
This is a standard 3-factor sumcheck over log₂(N×K) variables — exactly log₂(N) extra rounds compared to non-SIMD. The degree-3 round polynomial uses RoundPolyDeg3 with Newton interpolation at t = 0, 1, 2, 3.
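The extended-MLE layout can be checked with a small numeric example. This sketch computes only the claim value (plain u64 stands in for SecureField; `block_extended_claim` is an illustrative name, not the crate's API):

```rust
// Build ext_w / ext_a / ext_b of length N × K and compute the 3-factor claim.
fn block_extended_claim(
    w: &[u64],        // block weights, length N
    f_a: &[Vec<u64>], // restricted A per block, each length K
    f_b: &[Vec<u64>], // restricted B per block, each length K
) -> u64 {
    let (n, k) = (w.len(), f_a[0].len());
    let mut ext_w = vec![0u64; n * k];
    let mut ext_a = vec![0u64; n * k];
    let mut ext_b = vec![0u64; n * k];
    for b in 0..n {
        for i in 0..k {
            ext_w[b * k + i] = w[b]; // block weight replicated K times
            ext_a[b * k + i] = f_a[b][i];
            ext_b[b * k + i] = f_b[b][i];
        }
    }
    // claim = Σ_i ext_w[i] · ext_a[i] · ext_b[i]
    (0..n * k).map(|i| ext_w[i] * ext_a[i] * ext_b[i]).sum()
}

fn main() {
    // N = 2, K = 2: claim = w_0·Σ_k a_0[k]b_0[k] + w_1·Σ_k a_1[k]b_1[k].
    let claim = block_extended_claim(
        &[3, 4],
        &[vec![1, 2], vec![5, 6]],
        &[vec![7, 8], vec![9, 10]],
    );
    // 3·(1·7 + 2·8) + 4·(5·9 + 6·10) = 3·23 + 4·105 = 489
    assert_eq!(claim, 489);
    println!("ok");
}
```

Note there are no b ≠ c cross-terms here: each index i = b·K + k pairs a block's A strictly with the same block's B, which is exactly what the operand-combining shortcut loses.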
The verification final check uses the eq evaluation:
running_sum == eq(r_simd, block_challenges) · final_a · final_b
where block_challenges are the first log₂(N) sumcheck challenges (corresponding to the block-index variables).
Used for: per-head score matmul (Q_h × K_h^T), per-head context matmul (softmax_h × V_h).
The proof variant LayerProof::MatMulDualSimd carries the additional n_block_vars: usize field so the verifier knows how many sumcheck rounds correspond to block selection vs inner dimension:
MatMulDualSimd {
round_polys: Vec<RoundPolyDeg3>, // log₂(N×K) polynomials
final_a_eval: SecureField,
final_b_eval: SecureField,
n_block_vars: usize, // = simd_log_size
}
Verification split: The first n_block_vars challenges select the block, the remaining log₂(K) challenges select the inner dimension position. The final check:
running_sum == eq(r_simd, block_challenges[0..n_block_vars]) × final_a × final_b
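The split final check can be sketched as follows (toy u64 prime field in place of SecureField; `final_check` and `eq_eval` are illustrative names, not the crate's API):

```rust
const P: u64 = (1 << 31) - 1; // toy modulus, illustration only

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }
fn sub(a: u64, b: u64) -> u64 { (P + a - b) % P }

/// eq(r, c) = Π_i [(1 - r_i)(1 - c_i) + r_i · c_i]
fn eq_eval(r: &[u64], c: &[u64]) -> u64 {
    r.iter().zip(c).fold(1, |acc, (&ri, &ci)| {
        mul(acc, add(mul(sub(1, ri), sub(1, ci)), mul(ri, ci)))
    })
}

/// Verifier's final check for MatMulDualSimd: only the first n_block_vars
/// challenges are matched against r_simd; the rest index the inner dimension.
fn final_check(
    running_sum: u64,
    r_simd: &[u64],
    challenges: &[u64],
    n_block_vars: usize,
    final_a: u64,
    final_b: u64,
) -> bool {
    let eq = eq_eval(r_simd, &challenges[..n_block_vars]);
    running_sum == mul(mul(eq, final_a), final_b)
}

fn main() {
    let (r_simd, challenges) = ([11u64], [13u64, 99]);
    let (final_a, final_b) = (7u64, 9);
    let expected = mul(mul(eq_eval(&r_simd, &challenges[..1]), final_a), final_b);
    assert!(final_check(expected, &r_simd, &challenges, 1, final_a, final_b));
    assert!(!final_check(0, &r_simd, &challenges, 1, final_a, final_b));
    println!("ok");
}
```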
Attention Layer Batching
Sub-matmul Types
| Sub-matmul | A varies? | B varies? | Protocol |
|---|---|---|---|
| Output: concat × W_O | Yes (concat differs) | No (shared weight) | Shared-weight (degree-2) |
| Context: softmax_h × V_h | Yes | Yes | Dual-operand (degree-3) |
| Score: Q_h × K_h^T | Yes | Yes | Dual-operand (degree-3) |
| V proj: input × W_V | Yes | No (shared weight) | Shared-weight (degree-2) |
| K proj: input × W_K | Yes | No (shared weight) | Shared-weight (degree-2) |
| Q proj: input × W_Q | Yes | No (shared weight) | Shared-weight (degree-2) |
Score Matrix Scaling Gotcha
score_matrices[h] in AttentionIntermediates includes the 1/√d_k scaling factor applied after Q_h × K_h^T. The sumcheck operates on the raw unscaled product. You must compute matmul_m31(Q_h, K_h^T) fresh for the combined output MLE — using score_matrices[h] directly causes an exact √d_k factor mismatch.
For d_k = 64 (typical): scale_inv = 1/8, so the mismatch would be 8×.
Prover Entry Point
reduce_attention_layer_simd_gpu(
gpu, output_claim, block_executions,
attn_weights, config, block_weights, r_simd, channel,
) → Result<(LayerProof::Attention, GKRClaim), GKRError>
Verifier Dispatch
verify_attention_reduction accepts r_simd: Option<&[SecureField]>:
- None → non-SIMD path, all sub-proofs must be LayerProof::MatMul
- Some(r_simd) → SIMD path, per-head sub-proofs can be MatMulDualSimd
Sub-proof 0 (output projection) and the Q projection (last sub-proof) must always be LayerProof::MatMul.
Activation and Non-Linear Layers
Activation Fallback
Activation functions (ReLU, GELU, softmax_exp) use LogUp lookup tables. When combining across blocks, Σ_b w_b · table_lookup(x_b) produces QM31 sums that don't correspond to valid table entries — the combined value is a linear combination of table outputs, not a table output itself.
Solution: Activation layers fall back to single-block CPU proving. The SIMD prover uses the first block's execution data for LogUp proofs, and the per-block activations are checked individually. This is sound because activations are validated per-element, not per-block.
The same fallback applies to Dequantize and RMSNorm layers (both use LogUp tables).
LayerNorm Non-Linearity
LayerNorm involves mean and rsqrt — non-linear operations that cause cross-terms when combining across blocks:
Σ_b w_b · centered_b × Σ_c w_c · rsqrt_c ≠ Σ_b w_b · (centered_b × rsqrt_b)
Combined-Product MLE Solution
Pre-compute the per-element product before combining:
combined_product[i] = Σ_b w_b · (centered_b[i] × rsqrt_b[i])
On the boolean hypercube, this equals the combined output. The eq-sumcheck proves:
Σ eq(r, x) · combined_product(x) · 1 = output_claim.value
The constant-1 MLE as the second factor means rsqrt_final = 1 after all folds.
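A small numeric check makes the cross-term failure and the combined-product fix concrete (plain u64 values stand in for field elements; all values are illustrative):

```rust
fn main() {
    let w = [3u64, 5];                  // block weights w_0, w_1
    let centered = [[2u64, 4], [6, 8]]; // centered_b[i] per block
    let rsqrt = [[1u64, 3], [5, 7]];    // rsqrt_b[i] per block

    for i in 0..2 {
        // Correct: combine the per-element products across blocks.
        // combined_product[i] = Σ_b w_b · (centered_b[i] × rsqrt_b[i])
        let combined_product: u64 =
            (0..2).map(|b| w[b] * centered[b][i] * rsqrt[b][i]).sum();

        // Wrong: combining each operand first introduces b ≠ c cross-terms.
        let a: u64 = (0..2).map(|b| w[b] * centered[b][i]).sum();
        let r: u64 = (0..2).map(|b| w[b] * rsqrt[b][i]).sum();
        assert_ne!(combined_product, a * r);
    }
    println!("ok");
}
```

This is why the prover multiplies per block first and only then combines: on the hypercube the combined product equals the weighted sum of per-block outputs, with no cross-terms.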
Prover Flow
- Per-block CPU forward pass → compute product/mean/rsqrt/input MLEs per block
- GPU combine → 4× gpu.combine_blocks() calls for product, mean, rsqrt, and input
- GPU MLE evaluation at claim point
- Degree-3 eq-sumcheck over combined_product × ones
- LogUp skipped → logup_proof: None
The simd_combined: true flag in LayerProof::LayerNorm signals to the verifier that LogUp validation should be skipped — the QM31-combined rsqrt values don't map to table entries, but soundness is maintained because the combined-product approach verifies the overall normalization relationship.
SIMD Proving Flow
prove_gkr_simd_gpu(circuit, block_executions, weights, channel)
- Seed channel: mix dimensions, block count
- Draw SIMD challenges: r_simd = channel.draw_qm31s(log₂(n_blocks))
- Compute block weights: Lagrange basis evaluation
- Per-block forward passes: execute all blocks to get intermediates
- Walk layers (output → input):
  - MatMul (shared weight): reduce_matmul_layer_simd_gpu() — combine A, standard sumcheck
  - MatMul (dual operand): reduce_matmul_layer_dual_simd_gpu() — block-extended 3-factor
  - Activation: reduce_activation_layer() with combined input/output MLEs
  - LayerNorm: reduce_layernorm_layer_simd() with combined-product approach
  - Attention: reduce_attention_layer_simd_gpu() — mixed shared/dual sub-proofs
- Return proof with combined input claim
GKR Proof Structure
The complete SIMD GKR proof contains:
pub struct GKRProof {
pub layer_proofs: Vec<LayerProof>, // one per template layer
pub output_claim: GKRClaim, // MLE evaluation at output
pub input_claim: GKRClaim, // MLE evaluation at input
pub weight_commitments: Vec<FieldElement>, // Poseidon hash per weight matrix
pub weight_openings: Vec<MleOpeningProof>, // Merkle proofs for weight MLEs
pub weight_claims: Vec<WeightClaim>,
pub io_commitment: FieldElement, // Poseidon(inputs || outputs)
pub deferred_proofs: Vec<DeferredProof>, // DAG skip-connection proofs
}
Each LayerProof variant carries the data needed for its reduction:
| Variant | Fields | When Used |
|---|---|---|
MatMul | round_polys, final_a, final_b | Shared-weight SIMD matmul |
MatMulDualSimd | round_polys, final_a, final_b, n_block_vars | Both-operand-varying matmul |
Add | lhs_eval, rhs_eval, trunk_idx | Residual connection split |
Activation | logup_proof, input_eval, output_eval | Per-element activation |
LayerNorm | logup_proof, linear_round_polys, simd_combined | Combined-product path |
RMSNorm | logup_proof, linear_round_polys, simd_combined | Combined-product path |
Verification
verify_gkr_simd(circuit, proof, combined_output, channel)
Mirrors the prover's channel state exactly:
- Same seeding and SIMD challenge derivation: channel.mix_u64(num_blocks), then r_simd = channel.draw_qm31s(simd_log_size)
- Reconstructs block weights via Lagrange basis evaluation
- Per-layer verification with r_simd context — dispatches to the appropriate verifier per LayerProof variant
- MatMulDualSimd sub-proofs verified via verify_matmul_dual_simd_reduction, which:
  - Replays degree-3 round polynomials
  - Verifies block-challenge consistency: eq(r_simd, challenges[0..n_block_vars])
  - Checks final evaluation: running_sum == eq_eval × final_a × final_b
- Final input claim checked against public input MLE
Performance Analysis
Proving Cost
| Configuration | Independent Proofs | SIMD Proof | Savings |
|---|---|---|---|
| N=2 blocks, depth=9 | 18 layer reductions | 9 reductions (+1 round each) | 47% fewer rounds |
| N=4 blocks, depth=9 | 36 layer reductions | 9 reductions (+2 rounds each) | 72% fewer rounds |
| N=8 blocks, depth=9 | 72 layer reductions | 9 reductions (+3 rounds each) | 85% fewer rounds |
SIMD Overhead per Layer
For each layer in the template, SIMD adds:
| Layer Type | Extra Work | Extra Rounds |
|---|---|---|
| MatMul (shared weight) | GPU combine_blocks() | 0 |
| MatMul (dual operand) | Extended MLE construction | log₂(N) |
| Add | GPU combine 2 operands | 0 |
| Activation | None (CPU fallback) | 0 |
| LayerNorm | GPU combine 4 MLEs | 0 |
The dual-operand matmul is the only case with extra sumcheck rounds. For Qwen3-14B attention (40 Q heads, 8 KV heads), per-head score and context matmuls are dual-operand — adding log₂(N) rounds to each of 2 × 40 = 80 per-head matmuls.
Verifier Cost
SIMD verification is nearly identical to non-SIMD: the verifier walks the same number of template layers, checking one proof per layer. The only additional work is:
- Computing N block weights from log₂(N) challenges: O(N) field multiplications
- Computing the combined output MLE: O(N × output_size) additions
- Verifying log₂(N) extra sumcheck rounds for dual-operand matmuls
For N=8: ~11 extra field operations per dual-operand matmul.