GKR Protocol — Layer-by-Layer Interactive Proof
Overview
The GKR (Goldwasser-Kalai-Rothblum) protocol replaces per-layer independent STARK proofs with a single interactive proof that walks the computation graph from output to input. Instead of generating O(L) independent proofs for L layers, GKR produces one proof whose verification cost is proportional to the circuit depth.
For a transformer with 40 layers, each containing matmul + activation + layernorm, this eliminates ~120 independent STARK proofs and replaces them with a single GKR proof verified in O(depth) time.
Architecture
Output MLE claim
│
▼
┌─────────────┐
│ Layer L-1 │ ← MatMul sumcheck (log k rounds)
└──────┬──────┘
▼
┌─────────────┐
│ Layer L-2 │ ← Activation LogUp (eq-sumcheck + lookup proof)
└──────┬──────┘
▼
┌─────────────┐
│ Layer L-3 │ ← LayerNorm (combined-product sumcheck)
└──────┬──────┘
▼
...
▼
┌─────────────┐
│ Layer 0 │ ← Input claim (verified against public input)
└─────────────┘
Module Layout
| File | Purpose |
|---|---|
| `src/gkr/types.rs` | `GKRProof`, `LayerProof`, `GKRClaim`, `RoundPolyDeg3`, `LogUpProof` |
| `src/gkr/circuit.rs` | `LayeredCircuit` compiler from `ComputationGraph` |
| `src/gkr/prover.rs` | Layer reduction protocols + GPU/SIMD variants |
| `src/gkr/verifier.rs` | Fiat-Shamir transcript replay verifier |
Mathematical Foundations
Multilinear Extensions (MLE)
Every layer's data (a matrix or vector) is represented as a multilinear extension — the unique multilinear polynomial over F^n that agrees with the data at every point of the boolean hypercube {0,1}^n.
Definition: For a function f: {0,1}^n → F, its MLE is:
f̃(x₁, ..., xₙ) = Σ_{b ∈ {0,1}^n} f(b) · Π_{i=1}^{n} [(1-xᵢ)(1-bᵢ) + xᵢ·bᵢ]
The product Π_{i}[(1-xᵢ)(1-bᵢ) + xᵢ·bᵢ] is the multilinear Lagrange basis polynomial eq(x, b), which evaluates to 1 when x = b and 0 at all other binary points.
Evaluation via in-place folding: To evaluate f̃(r₁, r₂, ..., rₙ) efficiently, use the identity:
f̃(r, x_rest) = (1 - r) · f̃(0, x_rest) + r · f̃(1, x_rest)
This allows computing the evaluation in O(n) passes, each halving the table:
```rust
fn evaluate_mle(evals: &[SecureField], point: &[SecureField]) -> SecureField {
    let mut current = evals.to_vec();
    for &r in point.iter() {
        let mid = current.len() / 2;
        for i in 0..mid {
            current[i] = current[i] + r * (current[mid + i] - current[i]);
        }
        current.truncate(mid);
    }
    current[0]
}
```
Complexity: O(2^n) time and one table-sized working buffer; after the initial copy, each fold halves the buffer in place, so no further allocation is needed.
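The folding identity can be exercised concretely. The sketch below mirrors `evaluate_mle` but uses plain u64 arithmetic over the M31 modulus instead of the crate's `SecureField`; all names here are illustrative, not the real API.

```rust
// Minimal sketch of MLE evaluation by in-place folding over the M31 prime
// field. Illustrative stand-in for the SecureField-based evaluate_mle above.
const P: u64 = (1 << 31) - 1; // M31 modulus

fn fadd(a: u64, b: u64) -> u64 { (a + b) % P }
fn fsub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn fmul(a: u64, b: u64) -> u64 { a * b % P }

/// Fold one variable per pass: f(r, x_rest) = f(0, ·) + r · (f(1, ·) − f(0, ·)).
/// The first entry of `point` fixes the most significant index bit.
pub fn evaluate_mle_m31(evals: &[u64], point: &[u64]) -> u64 {
    assert_eq!(evals.len(), 1 << point.len());
    let mut cur = evals.to_vec();
    for &r in point {
        let mid = cur.len() / 2;
        for i in 0..mid {
            cur[i] = fadd(cur[i], fmul(r, fsub(cur[mid + i], cur[i])));
        }
        cur.truncate(mid);
    }
    cur[0]
}
```

At binary points the result is just a table lookup; at a non-binary point like (2, 0) it linearly extrapolates between the two halves.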
Lagrange Basis Computation (Tensor Product)
The Lagrange basis weights L_j(c) for a challenge vector c = (c₀, c₁, ..., c_{v-1}) are computed via tensor product construction:
L_j(c) = Π_{i=0}^{v-1} [(1 - cᵢ)(1 - jᵢ) + cᵢ · jᵢ]
where jᵢ is the i-th bit of index j.
Algorithm:
```text
weights = [1]
for each cᵢ in challenges:
    new_weights = []
    for w in weights:
        new_weights.push(w × (1 - cᵢ))   // bit = 0
        new_weights.push(w × cᵢ)         // bit = 1
    weights = new_weights
```
Result: weights[j] = L_j(c) for all j ∈ {0, ..., 2^v - 1}
Complexity: O(2^v) time and space. This is used in the fused GPU restrict kernels to avoid materializing full MLE tables.
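As a concrete check, the tensor-product construction can be sketched over M31 with plain u64 arithmetic (the function name and field helpers are illustrative, not the crate's API):

```rust
// Sketch of the tensor-product Lagrange weight construction. At a binary
// challenge vector, exactly one weight is 1; at any point, the weights
// sum to 1 (since each factor pair sums to (1-cᵢ) + cᵢ = 1).
const P: u64 = (1 << 31) - 1; // M31 modulus

fn fmul(a: u64, b: u64) -> u64 { a * b % P }

/// weights[j] = L_j(c) = Π_i [(1 - cᵢ)(1 - jᵢ) + cᵢ · jᵢ], where jᵢ is the
/// i-th bit of j (first challenge = most significant bit).
pub fn lagrange_weights(challenges: &[u64]) -> Vec<u64> {
    let mut weights = vec![1u64];
    for &c in challenges {
        let one_minus_c = (1 + P - c) % P;
        let mut next = Vec::with_capacity(weights.len() * 2);
        for &w in &weights {
            next.push(fmul(w, one_minus_c)); // bit = 0
            next.push(fmul(w, c));           // bit = 1
        }
        weights = next;
    }
    weights
}
```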
Sumcheck Protocol
The sumcheck protocol reduces a multivariate sum to a single evaluation. Given the claim:
H = Σ_{x ∈ {0,1}^n} g(x)
The prover sends degree-d univariate polynomials in n rounds, each fixing one variable. The verifier checks consistency and draws random challenges.
Round i: The prover sends pᵢ(t) such that pᵢ(0) + pᵢ(1) = Hᵢ₋₁ (the running sum). The verifier draws rᵢ ← F and sets Hᵢ = pᵢ(rᵢ).
After n rounds, the verifier holds a single-point claim g(r₁, ..., rₙ) = Hₙ which can be checked directly.
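The round structure above can be demonstrated in miniature for the simplest case: a single multilinear g given by its hypercube table, so each round polynomial is degree 1 and can be sent as the pair (p(0), p(1)). This is an illustrative honest-prover sketch over u64/M31, not the crate's protocol code.

```rust
// Honest-prover sumcheck for a multilinear g given by its hypercube table.
// Returns true iff every round's consistency check p(0) + p(1) = H_{i-1}
// passes and the final single-point claim g(r) = H_n holds.
const P: u64 = (1 << 31) - 1;
fn fadd(a: u64, b: u64) -> u64 { (a + b) % P }
fn fsub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn fmul(a: u64, b: u64) -> u64 { a * b % P }

pub fn sumcheck_multilinear(table: &[u64], claimed_sum: u64, challenges: &[u64]) -> bool {
    let mut cur = table.to_vec();
    let mut running = claimed_sum;
    for &r in challenges {
        let mid = cur.len() / 2;
        // p(0) sums the half with this variable = 0; p(1) the other half.
        let p0 = cur[..mid].iter().fold(0, |a, &b| fadd(a, b));
        let p1 = cur[mid..].iter().fold(0, |a, &b| fadd(a, b));
        if fadd(p0, p1) != running { return false; }
        running = fadd(p0, fmul(r, fsub(p1, p0))); // H_i = p(r) for linear p
        for i in 0..mid {
            cur[i] = fadd(cur[i], fmul(r, fsub(cur[mid + i], cur[i])));
        }
        cur.truncate(mid);
    }
    cur[0] == running // final check: g(r₁, ..., rₙ) = Hₙ
}
```

A wrong claimed sum fails in round 1; in the real protocol the challenges come from the Fiat-Shamir channel rather than a parameter.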
Degree-2 Interpolation (MatMul Sumcheck)
For matmul, each round polynomial is degree-2. Given evaluations at t = 0, 1, 2:
p(t) = c₀ + c₁·t + c₂·t²
The coefficients are recovered via:
c₀ = s₀
c₂ = (s₂ - 2·s₁ + s₀) / 2
c₁ = s₁ - s₀ - c₂
Verifier check: s₀ + s₁ = p(0) + p(1) = c₀ + (c₀ + c₁ + c₂) = 2c₀ + c₁ + c₂ must equal the prior round's running sum.
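The coefficient recovery and the verifier identity can be checked with small integers (in the field, the division by 2 becomes multiplication by 2⁻¹; the function name is illustrative):

```rust
// Recover (c₀, c₁, c₂) of a degree-2 polynomial from its values at
// t = 0, 1, 2, using the formulas above. Plain-integer sketch.
fn coeffs_from_evals(s0: i64, s1: i64, s2: i64) -> (i64, i64, i64) {
    let c0 = s0;
    let c2 = (s2 - 2 * s1 + s0) / 2; // second difference / 2
    let c1 = s1 - s0 - c2;
    (c0, c1, c2)
}
```

For p(t) = 3 + 2t + 5t², the evaluations (3, 10, 27) recover (3, 2, 5), and s₀ + s₁ = 13 matches 2c₀ + c₁ + c₂.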
Degree-3 Interpolation (Eq-Sumcheck)
For mul/activation/layernorm, each round polynomial is degree-3. Given evaluations at t = 0, 1, 2, 3, Newton divided differences recover (c₀, c₁, c₂, c₃):
```rust
pub struct RoundPolyDeg3 {
    pub c0: SecureField, // p(0)
    pub c1: SecureField, // linear coefficient
    pub c2: SecureField, // quadratic coefficient
    pub c3: SecureField, // cubic coefficient
}

// Evaluation: p(t) = c₀ + c₁·t + c₂·t² + c₃·t³
impl RoundPolyDeg3 {
    pub fn eval(&self, t: SecureField) -> SecureField {
        self.c0 + t * (self.c1 + t * (self.c2 + t * self.c3))
    }
}
```
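The divided-difference recovery itself can be sketched over plain integers (in the field, the divisions by 2, 3, and 6 become multiplications by inverses; the function name is illustrative):

```rust
// Recover (c₀, c₁, c₂, c₃) from evaluations at t = 0, 1, 2, 3 via Newton
// divided differences, then expand the Newton form
//   p(t) = s₀ + Δ·t + Δ²·t(t-1)/2 + Δ³·t(t-1)(t-2)/6
// into monomial coefficients.
fn deg3_coeffs(s0: i64, s1: i64, s2: i64, s3: i64) -> (i64, i64, i64, i64) {
    let d1 = s1 - s0; // first differences
    let d2 = s2 - s1;
    let d3 = s3 - s2;
    let dd1 = d2 - d1; // second differences
    let dd2 = d3 - d2;
    let ddd = dd2 - dd1; // third difference = 6·c₃
    let c3 = ddd / 6;
    let c2 = dd1 / 2 - ddd / 2;
    let c1 = d1 - dd1 / 2 + ddd / 3;
    (s0, c1, c2, c3)
}
```

For p(t) = 1 + 2t + 3t² + 4t³, the evaluations (1, 10, 49, 142) recover (1, 2, 3, 4).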
Layer Types and Reduction Protocols
MatMul (degree-2 sumcheck)
For C = A × B with dimensions (m × k) × (k × n):
Mathematical Identity: The matmul inner product enables efficient proof via the MLE restriction:
MLE_C(r_i, r_j) = Σ_{x ∈ {0,1}^{log k}} MLE_A(r_i, x) · MLE_B(x, r_j)
This converts a claim about C into an inner product of two restricted MLEs over the inner dimension k.
Protocol steps:
- Fiat-Shamir setup: Mix dimensions `(m, k, n)`, draw row challenges `r_i ∈ F^{log m}` and column challenges `r_j ∈ F^{log n}`
- Compute claimed sum: `claimed_sum = MLE_C(r_i, r_j)`
- Restrict MLEs: Extract `f_a(x) = MLE_A(r_i, x)` and `f_b(x) = MLE_B(x, r_j)` for `x ∈ {0,1}^{log k}`
- Sumcheck over `log₂(k)` rounds, each producing a degree-2 polynomial:
  p(t) = Σ_{i=0}^{mid-1} [(1-t)·f_a[i] + t·f_a[mid+i]] · [(1-t)·f_b[i] + t·f_b[mid+i]]
  Evaluated at t = 0, 1, 2:
  - s₀ = Σ f_a[i] · f_b[i]
  - s₁ = Σ f_a[mid+i] · f_b[mid+i]
  - s₂ = Σ (2·f_a[mid+i] - f_a[i]) · (2·f_b[mid+i] - f_b[i])
- Final check: After all rounds, `product = final_a_eval × final_b_eval`
- Weight opening: Merkle proof that `MLE_B(assignment, r_j) = final_b_eval`
Oracle implementation:
```rust
pub struct MatMulOracle {
    pub f_a: Vec<SecureField>,
    pub f_b: Vec<SecureField>,
}

impl MultivariatePolyOracle for MatMulOracle {
    fn sum_as_poly_in_first_variable(&self, _claim: SecureField) -> UnivariatePoly {
        let half = self.f_a.len() / 2;
        let (mut s0, mut s1, mut s2) = (zero(), zero(), zero());
        for i in 0..half {
            s0 += self.f_a[i] * self.f_b[i];
            s1 += self.f_a[half + i] * self.f_b[half + i];
            s2 += (2 * self.f_a[half + i] - self.f_a[i])
                * (2 * self.f_b[half + i] - self.f_b[i]);
        }
        let c0 = s0;
        let c2 = (s2 - 2 * s1 + s0) / 2;
        let c1 = s1 - s0 - c2;
        UnivariatePoly::from_coeffs([c0, c1, c2])
    }

    fn fix_first_variable(self, challenge: SecureField) -> Self {
        let mid = self.f_a.len() / 2;
        let new_a = (0..mid).map(|i|
            self.f_a[i] + challenge * (self.f_a[mid + i] - self.f_a[i])
        ).collect();
        let new_b = (0..mid).map(|i|
            self.f_b[i] + challenge * (self.f_b[mid + i] - self.f_b[i])
        ).collect();
        MatMulOracle { f_a: new_a, f_b: new_b }
    }
}
```
Channel protocol: mix(m, k, n), mix(claimed_value), per-round mix_poly_coeffs(c₀, c₁, c₂) + draw(), then mix(final_a), mix(final_b).
Degree analysis:
- For k = 5120 (Qwen3-14B FFN): padded to 8192, log₂(8192) = 13 sumcheck rounds
- Per round: 3 QM31 coefficients (c₀, c₁, c₂) = 12 M31 values
- Traditional STARK trace: O(m × k × n) rows = 2M rows for 128×128×128
- Sumcheck witness: O(m×n + m×k + k×n) = 49K rows (42× reduction)
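The whole reduction can be exercised end-to-end in miniature. The sketch below runs prover and verifier together over u64/M31 with explicit challenges; the real protocol instead draws challenges from the Fiat-Shamir channel and commits to the round-polynomial coefficients, so this function name and shape are purely illustrative.

```rust
// End-to-end toy of the matmul inner-product sumcheck: prove
// Σ_x f_a(x)·f_b(x) with degree-2 round polynomials over M31.
const P: u64 = (1 << 31) - 1;
fn fadd(a: u64, b: u64) -> u64 { (a + b) % P }
fn fsub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn fmul(a: u64, b: u64) -> u64 { a * b % P }
fn fpow(mut b: u64, mut e: u64) -> u64 {
    let mut r = 1;
    while e > 0 { if e & 1 == 1 { r = fmul(r, b); } b = fmul(b, b); e >>= 1; }
    r
}
fn finv(a: u64) -> u64 { fpow(a, P - 2) } // Fermat inverse

pub fn prove_and_verify_inner_product(f_a: &[u64], f_b: &[u64], challenges: &[u64]) -> bool {
    let claimed = f_a.iter().zip(f_b).fold(0, |acc, (&a, &b)| fadd(acc, fmul(a, b)));
    let (mut a, mut b) = (f_a.to_vec(), f_b.to_vec());
    let mut running = claimed;
    let half_inv = finv(2);
    for &r in challenges {
        let mid = a.len() / 2;
        let (mut s0, mut s1, mut s2) = (0, 0, 0);
        for i in 0..mid {
            s0 = fadd(s0, fmul(a[i], b[i]));
            s1 = fadd(s1, fmul(a[mid + i], b[mid + i]));
            let ea = fsub(fmul(2, a[mid + i]), a[i]); // factor at t = 2
            let eb = fsub(fmul(2, b[mid + i]), b[i]);
            s2 = fadd(s2, fmul(ea, eb));
        }
        // Verifier side: p(0) + p(1) must equal the running sum.
        if fadd(s0, s1) != running { return false; }
        let c0 = s0;
        let c2 = fmul(fsub(fadd(s2, s0), fmul(2, s1)), half_inv);
        let c1 = fsub(fsub(s1, s0), c2);
        running = fadd(c0, fmul(r, fadd(c1, fmul(r, c2)))); // p(r), Horner
        for i in 0..mid { // prover folds both tables with the challenge
            a[i] = fadd(a[i], fmul(r, fsub(a[mid + i], a[i])));
            b[i] = fadd(b[i], fmul(r, fsub(b[mid + i], b[i])));
        }
        a.truncate(mid);
        b.truncate(mid);
    }
    running == fmul(a[0], b[0]) // final check: final_a_eval × final_b_eval
}
```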
Add (degree-1 split)
For C = A + B, the MLE decomposes linearly:
MLE_C(r) = MLE_A(r) + MLE_B(r)
No sumcheck needed — the claim splits into two sub-claims that are recursively verified.
DAG handling (residual connections): When A and B come from different branches:
- Main GKR walk follows the trunk (higher layer index in DAG)
- The other branch generates a deferred proof verified after the main walk
- `trunk_idx: u8` marks which input (0 = lhs, 1 = rhs) is the trunk

```rust
Add {
    lhs_eval: SecureField, // MLE_A(r)
    rhs_eval: SecureField, // MLE_B(r)
    trunk_idx: u8,         // 0 = lhs is trunk, 1 = rhs is trunk
}
```
Mul (degree-3 eq-sumcheck)
For element-wise C = A ⊙ B:
claim = Σ_{x ∈ {0,1}^n} eq(r, x) · MLE_A(x) · MLE_B(x)
Three degree-1 factors → degree-3 univariate per round. Each round evaluates at t = 0, 1, 2, 3:
- s₀ = Σ eq_lo[i] · a_lo[i] · b_lo[i]
- s₁ = Σ eq_hi[i] · a_hi[i] · b_hi[i]
- s₂ = Σ (2·eq_hi - eq_lo) · (2·a_hi - a_lo) · (2·b_hi - b_lo)
- s₃ = Σ (3·eq_hi - 2·eq_lo) · (3·a_hi - 2·a_lo) · (3·b_hi - 2·b_lo)
Newton divided differences → (c₀, c₁, c₂, c₃).
Final check: running_sum == eq(r, assignment) × a_final × b_final.
Activation (LogUp eq-sumcheck)
For y = f(x) where f is a non-linear activation (ReLU, GELU, Sigmoid):
LogUp Protocol:
- Precomputed table: `T = {(x, f(x)) : x ∈ domain}` (65K entries for ReLU, 262K for GELU)
- Encoding: `encode(x, y) = x + β · y` (collapsing the 2D pair to 1D via random β)
- Denominators: `dᵢ = γ - encode(inputᵢ, outputᵢ)` for trace rows
- Witness fractions: `wᵢ = 1 / dᵢ`
- Multiplicity check: `Σᵢ wᵢ = Σⱼ multⱼ / (γ - encode(tableⱼ))`
- Eq-sumcheck: Proves `w(x) · d(x) = 1` on the boolean hypercube:
  Σ_{x ∈ {0,1}^n} eq(r, x) · w(x) · d(x) = 1
This is a degree-3 sumcheck (three factors: eq, w, d). The verifier checks that the witness fractions are correctly formed.
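The multiplicity identity at the heart of LogUp can be checked numerically. The sketch below uses u64/M31 arithmetic with fixed stand-ins for the channel draws β and γ; all names are illustrative.

```rust
// LogUp balance: the trace's fractions 1/(γ − encode(xᵢ, yᵢ)) must sum to
// the table side Σⱼ multⱼ / (γ − encode(tableⱼ)).
const P: u64 = (1 << 31) - 1;
fn fadd(a: u64, b: u64) -> u64 { (a + b) % P }
fn fsub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn fmul(a: u64, b: u64) -> u64 { a * b % P }
fn fpow(mut b: u64, mut e: u64) -> u64 {
    let mut r = 1;
    while e > 0 { if e & 1 == 1 { r = fmul(r, b); } b = fmul(b, b); e >>= 1; }
    r
}
fn finv(a: u64) -> u64 { fpow(a, P - 2) }
fn encode(x: u64, y: u64, beta: u64) -> u64 { fadd(x, fmul(beta, y)) }

/// True iff the trace's lookups are exactly accounted for by `mults`.
pub fn logup_balanced(
    trace: &[(u64, u64)], table: &[(u64, u64)], mults: &[u64], beta: u64, gamma: u64,
) -> bool {
    let lhs = trace.iter().fold(0, |acc, &(x, y)| {
        fadd(acc, finv(fsub(gamma, encode(x, y, beta))))
    });
    let rhs = table.iter().zip(mults).fold(0, |acc, (&(x, y), &m)| {
        fadd(acc, fmul(m, finv(fsub(gamma, encode(x, y, beta)))))
    });
    lhs == rhs
}
```

A trace that looks up entry (1, 1) twice and (2, 2) once balances only against multiplicities (0, 2, 1) for a three-entry table; any other multiset fails.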
LogUpProof structure:
```rust
pub struct LogUpProof {
    pub eq_round_polys: Vec<RoundPolyDeg3>,
    pub final_evals: (SecureField, SecureField, SecureField), // (w(s), in(s), out(s))
    pub claimed_sum: SecureField,
    pub multiplicities: Vec<u32>,
}
```
LayerNorm (combined-product sumcheck)
LayerNorm is non-linear (involves mean + rsqrt), which causes cross-terms when batching across SIMD blocks. The solution decomposes into two provable stages:
Stage 1 — Linear eq-sumcheck:
output = (input - mean) × rsqrt
This is verified via degree-3 eq-sumcheck over centered × rsqrt:
Σ_{x} eq(r, x) · centered(x) · rsqrt(x) = output_claim.value
where centered(x) = input(x) - mean.
Stage 2 — rsqrt LogUp lookup: Proves (variance, rsqrt) pairs exist in a precomputed table of reciprocal square roots. Uses the same LogUp protocol as activations.
SIMD batching path: When batching identical blocks, combined_product[i] = Σ_b w_b · (centered_b[i] × rsqrt_b[i]). A constant-1 MLE serves as the second factor, giving rsqrt_final = 1 after all folds. When simd_combined = true, LogUp is skipped (QM31 sum can't decompose back to M31 table entries).
RMSNorm (LogUp rsqrt lookup)
RMSNorm (y = x / sqrt(mean(x²) + ε) · γ) is handled similarly to LayerNorm but without mean subtraction:
claim = Σ_x eq(r, x) · MLE_output(x) · MLE_rsqrt(x)
Degree-3 eq-sumcheck over the claim above (three degree-1 factors: eq, output, rsqrt), plus a degree-3 LogUp eq-sumcheck for the rsqrt table lookup.
RoPE (LogUp rotation table)
Rotary Positional Embedding applies position-dependent (cos, sin) rotations to Q/K vectors. The rotation factors are deterministic from (seq_len, head_dim, base), so the verifier reconstructs the table. LogUp proves each (cos, sin) pair used in the trace exists in the table.
Dequantize (LogUp 2D table)
For quantized models (INT4/INT8), dequantization maps quantized integer values to their M31 equivalents via a small lookup table (16 entries for INT4, 256 for INT8). LogUp proves each (quantized_input, dequantized_output) pair matches the table. Follows the Activation pattern with finalize_logup_in_pairs().
Attention (composed sub-matmuls)
Attention decomposes into 4 + 2H sub-matmuls (H = num_heads):
| Index | Sub-matmul | Type | Operands |
|---|---|---|---|
| 0 | Output projection | Shared-weight | concat × W_O |
| 1..2H | Per-head context | Dual-operand | softmax_h × V_h |
| 1..2H | Per-head score | Dual-operand | Q_h × K_h^T (unscaled) |
| 2H+1 | V projection | Shared-weight | input × W_V |
| 2H+2 | K projection | Shared-weight | input × W_K |
| 2H+3 | Q projection | Shared-weight | input × W_Q |
Important: score_matrices[h] includes the 1/√d_k scaling factor. The sumcheck must use the raw Q_h × K_h^T product (unscaled), not the stored score matrix.
Fiat-Shamir Transcript Protocol
PoseidonChannel (Poseidon2-M31)
The Fiat-Shamir channel uses the Poseidon2-M31 permutation (t=16, rate=8) for native M31 field compatibility:
```rust
pub struct PoseidonChannel {
    digest: FieldElement, // Hades permutation state
    n_draws: u32,         // Draw counter (domain separation)
}
```
Mix operation (absorb): state = [digest, felt(value), 2] → hades(state) → digest = state[0]
Draw operation (squeeze): state = [digest, n_draws, 3] → hades(state) → value = state[0]
QM31 extraction from felt252: A single felt252 is unpacked into 4 M31 components via successive floor_div(2^31):
```rust
pub fn draw_qm31(&mut self) -> SecureField {
    let felt = self.draw_felt252();
    let shift = FieldElement::from(1u64 << 31);
    let p = (1u64 << 31) - 1;
    let mut m31s = [M31::from(0u32); 4];
    let mut cur = felt;
    for m31 in m31s.iter_mut() {
        let next = cur.floor_div(shift);
        let res = cur - next * shift;
        cur = next;
        *m31 = M31::from((res.to_u64() % p) as u32);
    }
    QM31(CM31(m31s[0], m31s[1]), CM31(m31s[2], m31s[3]))
}
```
M31 packing algorithm (for mixing polynomial coefficients):
pack_m31s([a, b, c, d]) = ((((1 × 2^31 + a) × 2^31 + b) × 2^31 + c) × 2^31 + d)
The leading 1 is a sentinel that preserves leading zeros. Two QM31 values (8 M31s) pack into a single felt252. Unpacking reverses via floor_div(2^31).
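The packing rule can be sketched with u128 standing in for felt252 (four 31-bit limbs plus the sentinel need only 125 bits; function names are illustrative):

```rust
// Pack four M31 values into one integer with a leading-1 sentinel, and
// unpack by repeated floor-division by 2^31. u128 stands in for felt252.
const SHIFT: u128 = 1 << 31;

pub fn pack_m31s(vals: [u32; 4]) -> u128 {
    // ((((1·2^31 + a)·2^31 + b)·2^31 + c)·2^31 + d)
    vals.iter().fold(1u128, |acc, &v| acc * SHIFT + v as u128)
}

pub fn unpack_m31s(mut packed: u128) -> [u32; 4] {
    let mut out = [0u32; 4];
    for i in (0..4).rev() {
        out[i] = (packed % SHIFT) as u32; // remainder = lowest limb
        packed /= SHIFT;                  // floor_div(2^31)
    }
    assert_eq!(packed, 1, "leading-1 sentinel must survive");
    out
}
```

Without the sentinel, `pack_m31s([0, 0, 0, 5])` and `pack_m31s([0, 0, 5, 0])`-style leading zeros would be ambiguous; with it, the round trip is exact.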
Transcript Ordering
The matmul sumcheck transcript follows this exact sequence (any divergence between prover and verifier causes failure):
1. mix_u64(m), mix_u64(k), mix_u64(n) — dimensions
2. draw_qm31() × log_m — row challenges r_i
3. draw_qm31() × log_n — col challenges r_j
4. mix_felt(claimed_sum) — claimed MLE value
5. mix_felt(a_commitment), mix_felt(b_commitment) — Merkle roots
6. For each round:
   - `mix_poly_coeffs(c₀, c₁, c₂)` — 12 M31 → 2 felt252
   - `draw_qm31()` — sumcheck challenge
7. prove_mle_opening(f_a, assignment) — MLE opening proof
8. prove_mle_opening(f_b, assignment) — MLE opening proof
Both prover and Cairo verifier must produce identical channel digests at every step. This is the single most common source of bugs.
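The lockstep requirement can be illustrated with a toy channel: two channels fed the same operations in the same order end with the same digest, and any reordering diverges. The mixing function below is a cheap stand-in, NOT Poseidon2, and every name here is hypothetical.

```rust
// Toy Fiat-Shamir channel illustrating transcript lockstep. The "permutation"
// is a placeholder (odd-constant multiply + xor), not a real hash.
#[derive(Clone, PartialEq, Debug)]
pub struct MockChannel { digest: u64, n_draws: u64 }

impl MockChannel {
    pub fn new() -> Self { MockChannel { digest: 0, n_draws: 0 } }
    pub fn mix_u64(&mut self, v: u64) {
        self.digest = self.digest.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(v) ^ 2;
    }
    pub fn draw(&mut self) -> u64 {
        self.n_draws += 1; // draw counter for domain separation
        self.digest = self.digest.wrapping_add(self.n_draws)
            .wrapping_mul(0xBF58476D1CE4E5B9) ^ 3;
        self.digest
    }
}

/// Replays steps 1–4 of the matmul transcript ordering on two channels.
pub fn transcripts_agree(dims: (u64, u64, u64), claimed: u64) -> bool {
    let run = |ch: &mut MockChannel| {
        ch.mix_u64(dims.0); ch.mix_u64(dims.1); ch.mix_u64(dims.2);
        let _r_i = ch.draw();
        let _r_j = ch.draw();
        ch.mix_u64(claimed);
        ch.draw()
    };
    let (mut p, mut v) = (MockChannel::new(), MockChannel::new());
    run(&mut p) == run(&mut v)
}
```

Swapping even two mix operations changes every subsequent draw, which is exactly why prover and Cairo verifier must replay the identical sequence.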
Proof Types
```rust
/// A claim: "MLE evaluated at `point` equals `value`"
pub struct GKRClaim {
    pub point: Vec<SecureField>,
    pub value: SecureField,
}

/// Per-layer proof variant (field types elided)
pub enum LayerProof {
    MatMul { round_polys, final_a_eval, final_b_eval },
    MatMulDualSimd { round_polys, final_a_eval, final_b_eval, n_block_vars },
    Add { lhs_eval, rhs_eval },
    Mul { round_polys: Vec<RoundPolyDeg3>, final_a_eval, final_b_eval },
    Activation { activation_type, round_polys, final_input_eval, final_output_eval, logup_proof },
    LayerNorm { round_polys, final_input_eval, final_output_eval, logup_proof },
    RMSNorm { round_polys, final_input_eval, final_output_eval, logup_proof },
    RoPE { logup_proof, input_eval, output_eval },
    Dequantize { logup_proof, input_eval, output_eval, table_commitment },
    Attention { sub_proofs: Vec<LayerProof>, sub_claim_values },
}
```
On-Chain Proof Structure
```rust
pub struct MatMulSumcheckProofOnChain {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub num_rounds: u32,             // = log₂(k)
    pub claimed_sum: SecureField,    // MLE_C(r_i, r_j)
    pub round_polys: Vec<RoundPoly>, // log_k round polynomials
    pub final_a_eval: SecureField,   // MLE_A(r_i, assignment)
    pub final_b_eval: SecureField,   // MLE_B(assignment, r_j)
    pub a_commitment: FieldElement,  // Poseidon Merkle root
    pub b_commitment: FieldElement,  // Poseidon Merkle root
    pub a_opening: MleOpeningProof,
    pub b_opening: MleOpeningProof,
}
```
Calldata cost: ~50 + 12 × num_rounds felt252 values per matmul proof.
Circuit Compilation
LayeredCircuit::from_graph(graph) converts a ComputationGraph (topologically sorted DAG of ML operations) into a layered circuit where each layer has:
- `layer_type`: Which reduction protocol to use
- `input_shape` / `output_shape`: Matrix dimensions (for MLE sizing)
- `node_id`: Back-reference to the original graph node
- `input_layers`: Predecessor layer indices
Deferred Proofs for DAG Residuals
When an Add layer has inputs from different branches (e.g., residual skip connections in transformers), the main GKR walk follows the trunk branch. The skip branch generates a deferred proof verified after the main walk completes.
```rust
pub struct DeferredProof {
    pub claim: GKRClaim,         // Claim at the Add layer
    pub dims: (usize, usize, usize),
    pub layer_proof: LayerProof, // MatMul sumcheck for skip branch
    pub input_claim: GKRClaim,
    pub weight_commitment: FieldElement,
    pub weight_opening: MleOpeningProof,
    pub weight_claim: WeightClaim,
}
```
Fiat-Shamir ordering: main walk → weight openings → deferred proofs. The deferred claim's evaluation point is the same point from the Add layer during the main walk.
On-chain: The Cairo verifier saves Add layer claim points during the walk, then reads and verifies deferred matmul sumcheck proofs in sequence after the main loop.
Entry Points
Proving
| Function | Description |
|---|---|
| `prove_gkr(circuit, execution, weights, channel)` | CPU-only GKR proof |
| `prove_gkr_gpu(circuit, execution, weights, channel)` | GPU-accelerated (single block) |
| `prove_gkr_simd_gpu(circuit, block_executions, weights, channel)` | GPU + SIMD batching across blocks |
Verification
| Function | Description |
|---|---|
| `verify_gkr(circuit, proof, output, channel)` | Standard verification |
| `verify_gkr_with_execution(circuit, proof, execution, channel)` | Verification with intermediate checks |
| `verify_gkr_simd(circuit, proof, combined_output, channel)` | SIMD-aware verification |
On-Chain (Cairo)
| Function | Description |
|---|---|
| `verify_model_gkr(proof_data, num_layers, dims, initial_claim, channel)` | Full on-chain GKR walk |
| `verify_matmul_sumcheck(proof, channel)` | Single matmul verification |
| `verify_batched_matmul(batch_proof, channel)` | Lambda-weighted batch verification |
Input/Output Claim Verification
The GKR walk proves internal consistency but must be anchored to real data at both endpoints. See Input Claim Verification for the full design.
Output side: The verifier draws r_out from the Fiat-Shamir channel, evaluates MLE(raw_output, r_out) on-chain, and uses the result as the initial GKR claim.
Input side: After the GKR walk completes with final_claim, the verifier evaluates MLE(raw_input, final_claim.point) on-chain and asserts equality:
assert!(MLE(input, final_claim.point) == final_claim.value, "INPUT_CLAIM_MISMATCH");
Both MLEs are constructed from raw data in calldata with power-of-2 padding and row-major layout, matching the Rust-side pad_matrix_pow2 + matrix_to_mle.
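The input-side anchor can be sketched end-to-end: pad the raw matrix to power-of-2 dimensions, flatten row-major into an MLE table, evaluate at the final claim point, and compare against the walked value. The function names below mirror, but are not, the Rust-side `pad_matrix_pow2` / `matrix_to_mle`; arithmetic is u64 over M31.

```rust
// Toy input-claim check: MLE(pad(input), final_claim.point) == final_claim.value.
const P: u64 = (1 << 31) - 1;
fn fadd(a: u64, b: u64) -> u64 { (a + b) % P }
fn fsub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn fmul(a: u64, b: u64) -> u64 { a * b % P }

/// Zero-pad to power-of-2 rows/cols and flatten row-major.
pub fn pad_matrix_pow2(rows: &[Vec<u64>]) -> Vec<u64> {
    let r = rows.len().next_power_of_two();
    let c = rows.iter().map(|row| row.len()).max().unwrap_or(1).next_power_of_two();
    let mut flat = vec![0u64; r * c];
    for (i, row) in rows.iter().enumerate() {
        flat[i * c..i * c + row.len()].copy_from_slice(row);
    }
    flat
}

/// MLE evaluation by folding (first point coordinate = most significant bit).
pub fn evaluate_mle(evals: &[u64], point: &[u64]) -> u64 {
    let mut cur = evals.to_vec();
    for &r in point {
        let mid = cur.len() / 2;
        for i in 0..mid {
            cur[i] = fadd(cur[i], fmul(r, fsub(cur[mid + i], cur[i])));
        }
        cur.truncate(mid);
    }
    cur[0]
}

/// The INPUT_CLAIM_MISMATCH assertion from the text, as a boolean.
pub fn input_claim_holds(input: &[Vec<u64>], point: &[u64], value: u64) -> bool {
    evaluate_mle(&pad_matrix_pow2(input), point) == value
}
```

A 3×3 input pads to a 4×4 table (16 entries, 4 point coordinates); binary points index directly into the padded row-major layout, including the zero padding.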
Integration with Aggregation Pipeline
GKR is optional and additive — the standard STARK pipeline runs first, then GKR produces an additional proof:
```rust
// Standard pipeline
let stark_proof = prove_model_aggregated_onchain(graph, input, weights);

// With GKR (additional verification layer)
let (stark_proof, gkr_proof) = prove_model_aggregated_onchain_gkr(graph, input, weights);
```
The verifier checks both proofs independently. GKR provides a second, complementary verification path.
Performance
| Metric | Value | Notes |
|---|---|---|
| Qwen3-14B per-layer proving | 3.04s | H100 GPU, 160 matmuls |
| Qwen3-14B 40-layer total | 122s | End-to-end with streaming |
| Single matmul sumcheck | ~48ms | GPU-accelerated (k=5120) |
| On-chain verification | 18 TXs | Streaming GKR on Starknet |
| Sumcheck rounds per matmul | 13 | log₂(8192) for k=5120 |
| Proof size per matmul | ~5 KB | Round polys + openings |
| Total calldata | 86,723 felts | Via aggregated weight binding (28× reduction) |
References
- Goldwasser, Kalai, Rothblum (2015) — "Delegating Computation: Interactive Proofs for Muggles"
- Lund, Fortnow, Karloff, Nisan (1992) — "Algebraic Methods for Interactive Proof Systems"
- Thaler (2013) — "Time-Optimal Interactive Proofs for Circuit Evaluation"
- Grassi et al. (2021) — "Poseidon: A New Hash Function for Zero-Knowledge Proofs"