GKR Protocol — Layer-by-Layer Interactive Proof

Overview

The GKR (Goldwasser-Kalai-Rothblum) protocol replaces per-layer independent STARK proofs with a single interactive proof that walks the computation graph from output to input. Instead of generating O(L) independent proofs for L layers, GKR produces one proof whose verification cost is proportional to the circuit depth, with each layer reduced by a sumcheck whose length is logarithmic in that layer's width.

For a transformer with 40 layers, each containing matmul + activation + layernorm, this eliminates ~120 (40 × 3) independent STARK proofs and replaces them with a single GKR proof verified in O(depth · log width) time.

Architecture

 Output MLE claim
        │
        ▼
 ┌──────────────┐
 │  Layer L-1   │  ← MatMul sumcheck (log k rounds)
 └──────┬───────┘
        ▼
 ┌──────────────┐
 │  Layer L-2   │  ← Activation LogUp (eq-sumcheck + lookup proof)
 └──────┬───────┘
        ▼
 ┌──────────────┐
 │  Layer L-3   │  ← LayerNorm (combined-product sumcheck)
 └──────┬───────┘
        ▼
       ...
        ▼
 ┌──────────────┐
 │  Layer 0     │  ← Input claim (verified against public input)
 └──────────────┘

Module Layout

| File | Purpose |
|---|---|
| src/gkr/types.rs | GKRProof, LayerProof, GKRClaim, RoundPolyDeg3, LogUpProof |
| src/gkr/circuit.rs | LayeredCircuit compiler from ComputationGraph |
| src/gkr/prover.rs | Layer reduction protocols + GPU/SIMD variants |
| src/gkr/verifier.rs | Fiat-Shamir transcript replay verifier |

Mathematical Foundations

Multilinear Extensions (MLE)

Every layer's data (a matrix or vector) is represented as a multilinear extension (MLE): the unique multilinear polynomial over F^n that agrees with the data at every point of the boolean hypercube {0,1}^n.

Definition: For a function f: {0,1}^n → F, its MLE is:

f̃(x₁, ..., xₙ) = Σ_{b ∈ {0,1}^n} f(b) · Π_{i=1}^{n} [(1-xᵢ)(1-bᵢ) + xᵢ·bᵢ]

The product Π_{i}[(1-xᵢ)(1-bᵢ) + xᵢ·bᵢ] is the multilinear Lagrange basis polynomial eq(x, b), which evaluates to 1 when x = b and to 0 at every other binary point.

Evaluation via in-place folding: To evaluate f̃(r₁, r₂, ..., rₙ) efficiently, use the identity:

f̃(r, x_rest) = (1 - r) · f̃(0, x_rest) + r · f̃(1, x_rest)

This allows computing the evaluation in n passes, each binding one variable and halving the table:

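fn evaluate_mle(evals: &[SecureField], point: &[SecureField]) -> SecureField {
    // The table must hold exactly 2^n evaluations for an n-variable point.
    assert_eq!(evals.len(), 1 << point.len());
    let mut current = evals.to_vec();
    // point[0] binds the most significant index bit: the low half of the
    // table has x₁ = 0, the high half has x₁ = 1.
    for &r in point.iter() {
        let mid = current.len() / 2;
        for i in 0..mid {
            // (1 - r)·lo + r·hi, written with a single multiplication
            current[i] = current[i] + r * (current[mid + i] - current[i]);
        }
        current.truncate(mid);
    }
    current[0]
}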

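Complexity: O(2^n) field operations in total (2^{n-1} + 2^{n-2} + ... + 1 updates). The folding itself is in place, but the initial to_vec() copy uses O(2^n) extra space; taking a mutable slice instead avoids it.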

Lagrange Basis Computation (Tensor Product)

The Lagrange basis weights L_j(c) for a challenge vector c = (c₀, c₁, ..., c_{v-1}) are computed via tensor product construction:

L_j(c) = Π_{i=0}^{v-1} [(1 - cᵢ)(1 - jᵢ) + cᵢ · jᵢ]

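where jᵢ is the bit of index j paired with cᵢ. In the construction below, c₀ pairs with the most significant bit of j, matching the fold order of evaluate_mle.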

Algorithm:

weights = [1]
for each cᵢ in challenges:
    new_weights = []
    for w in weights:
        new_weights.push(w × (1 - cᵢ))    // bit = 0
        new_weights.push(w × cᵢ)           // bit = 1
    weights = new_weights                  // length doubles each step

Result: weights[j] = L_j(c) for all j ∈ {0, ..., 2^v - 1}

Complexity: O(2^v) time and space. This is used in the fused GPU restrict kernels to avoid materializing full MLE tables.
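A sketch of the tensor-product construction, checked against the folding evaluation. It assumes M31 arithmetic (p = 2^31 - 1, values held in u64) as a stand-in for SecureField; lagrange_weights and evaluate_via_weights are illustrative names, not the crate's functions.

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus, standing in for SecureField

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { (a * b) % P }

/// Tensor-product construction of L_j(c); c[0] pairs with the MSB of j.
fn lagrange_weights(challenges: &[u64]) -> Vec<u64> {
    let mut weights = vec![1u64];
    for &c in challenges {
        let mut next = Vec::with_capacity(weights.len() * 2);
        for &w in &weights {
            next.push(mul(w, sub(1, c))); // bit = 0
            next.push(mul(w, c)); // bit = 1
        }
        weights = next;
    }
    weights
}

/// MSB-first folding evaluation, mirroring evaluate_mle.
fn evaluate_mle(evals: &[u64], point: &[u64]) -> u64 {
    assert_eq!(evals.len(), 1 << point.len());
    let mut cur = evals.to_vec();
    for &r in point {
        let mid = cur.len() / 2;
        for i in 0..mid {
            cur[i] = add(cur[i], mul(r, sub(cur[mid + i], cur[i])));
        }
        cur.truncate(mid);
    }
    cur[0]
}

/// The same evaluation as an inner product with the Lagrange weights.
fn evaluate_via_weights(evals: &[u64], point: &[u64]) -> u64 {
    lagrange_weights(point)
        .iter()
        .zip(evals)
        .fold(0, |acc, (&w, &f)| add(acc, mul(w, f)))
}
```

At a boolean point the weights form a one-hot vector, and at any point they sum to 1, since Π((1 - cᵢ) + cᵢ) = 1.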

Sumcheck Protocol

The sumcheck protocol reduces a multivariate sum to a single evaluation. Given the claim:

H = Σ_{x ∈ {0,1}^n} g(x)

The prover sends degree-d univariate polynomials in n rounds, each fixing one variable. The verifier checks consistency and draws random challenges.

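Round i: The prover sends pᵢ(t) = Σ_{x ∈ {0,1}^{n-i}} g(r₁, ..., rᵢ₋₁, t, x). The verifier checks pᵢ(0) + pᵢ(1) = Hᵢ₋₁ (the running sum, with H₀ = H), draws rᵢ ← F, and sets Hᵢ = pᵢ(rᵢ).

After n rounds, the verifier holds a single-point claim g(r₁, ..., rₙ) = Hₙ, which it either checks directly or, in GKR, reduces to evaluation claims on the MLEs of the layer's inputs.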
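To make the round structure concrete, here is a sketch for the simplest case, a single multilinear g, where each pᵢ has degree 1 and can be sent as (pᵢ(0), pᵢ(1)). It assumes M31 arithmetic in place of SecureField, MSB-first variable order, and fixed challenges instead of Fiat-Shamir draws; run_sumcheck runs prover and verifier together purely for illustration.

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus, standing in for SecureField

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { (a * b) % P }

/// Prover message (p(0), p(1)): sums of g over the low half (first free
/// variable = 0) and the high half (first free variable = 1).
fn round_poly(table: &[u64]) -> (u64, u64) {
    let mid = table.len() / 2;
    let p0 = table[..mid].iter().fold(0, |a, &x| add(a, x));
    let p1 = table[mid..].iter().fold(0, |a, &x| add(a, x));
    (p0, p1)
}

/// Bind the first free variable to r: (1 - r)·lo + r·hi.
fn fold(table: &[u64], r: u64) -> Vec<u64> {
    let mid = table.len() / 2;
    (0..mid).map(|i| add(table[i], mul(r, sub(table[mid + i], table[i])))).collect()
}

/// Returns Some(H_n) if every round check and the final check pass.
fn run_sumcheck(table: &[u64], claim: u64, challenges: &[u64]) -> Option<u64> {
    let mut running = claim;
    let mut cur = table.to_vec();
    for &r in challenges {
        let (p0, p1) = round_poly(&cur);
        // Verifier: p_i(0) + p_i(1) must equal H_{i-1}.
        if add(p0, p1) != running {
            return None;
        }
        // Degree 1, so H_i = p_i(r) = p0 + r·(p1 - p0).
        running = add(p0, mul(r, sub(p1, p0)));
        cur = fold(&cur, r); // prover binds x_i = r_i
    }
    // One entry remains: g~(r_1, ..., r_n), which must equal H_n.
    if cur.len() == 1 && cur[0] == running { Some(running) } else { None }
}
```

An honest prover with a false claim H is caught in round 1, because p₁(0) + p₁(1) is always the true sum.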

Degree-2 Interpolation (MatMul Sumcheck)

For matmul, each round polynomial is degree-2. Given evaluations at t = 0, 1, 2:

p(t) = c₀ + c₁·t + c₂·t²

The coefficients are recovered via:

c₀ = s₀
c₂ = (s₂ - 2·s₁ + s₀) / 2
c₁ = s₁ - s₀ - c₂

Verifier check: s₀ + s₁ = p(0) + p(1) = c₀ + (c₀ + c₁ + c₂) = 2c₀ + c₁ + c₂ must equal the prior round's running sum.
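A direct transcription of these formulas, assuming M31 arithmetic (p = 2^31 - 1) in place of SecureField and a Fermat inverse for the division by 2; interpolate_deg2 and eval_deg2 are illustrative names.

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus, standing in for SecureField

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { (a * b) % P }
fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}
// Fermat inverse; only valid for nonzero a.
fn inv(a: u64) -> u64 { pow(a, P - 2) }

/// (c0, c1, c2) from s_t = p(t) at t = 0, 1, 2.
fn interpolate_deg2(s0: u64, s1: u64, s2: u64) -> [u64; 3] {
    let c0 = s0;
    let c2 = mul(add(sub(s2, mul(2, s1)), s0), inv(2)); // (s2 - 2·s1 + s0)/2
    let c1 = sub(sub(s1, s0), c2); // s1 - s0 - c2
    [c0, c1, c2]
}

/// Horner evaluation of c0 + c1·t + c2·t².
fn eval_deg2(c: &[u64; 3], t: u64) -> u64 {
    add(c[0], mul(t, add(c[1], mul(t, c[2]))))
}
```

For p(t) = 5 + 3t + 2t², the evaluations (5, 10, 19) recover [5, 3, 2], and 2c₀ + c₁ + c₂ = 15 = p(0) + p(1).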

Degree-3 Interpolation (Eq-Sumcheck)

For mul/activation/layernorm, each round polynomial is degree-3. Given evaluations at t = 0, 1, 2, 3, Newton divided differences recover (c₀, c₁, c₂, c₃):

pub struct RoundPolyDeg3 {
    pub c0: SecureField,  // p(0)
    pub c1: SecureField,  // linear coefficient
    pub c2: SecureField,  // quadratic coefficient
    pub c3: SecureField,  // cubic coefficient
}

// Evaluation: p(t) = c₀ + c₁·t + c₂·t² + c₃·t³
impl RoundPolyDeg3 {
    pub fn eval(&self, t: SecureField) -> SecureField {
        self.c0 + t * (self.c1 + t * (self.c2 + t * self.c3))
    }
}
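The Newton step can be made concrete with the forward differences d₁, d₂, d₃ of the four evaluations: p(t) = y₀ + d₁·t + d₂·t(t-1)/2 + d₃·t(t-1)(t-2)/6, which expands to c₃ = d₃/6, c₂ = (d₂ - d₃)/2, c₁ = d₁ - d₂/2 + d₃/3. A sketch over M31 standing in for SecureField; interpolate_deg3 is an illustrative name, not the crate's function.

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus, standing in for SecureField

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { (a * b) % P }
fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}
fn inv(a: u64) -> u64 { pow(a, P - 2) } // Fermat; a must be nonzero

/// Coefficients [c0, c1, c2, c3] of the cubic with p(t) = y[t], t = 0..3.
fn interpolate_deg3(y: [u64; 4]) -> [u64; 4] {
    let d1 = sub(y[1], y[0]);
    let d2 = add(sub(y[2], mul(2, y[1])), y[0]);
    let d3 = sub(add(sub(y[3], mul(3, y[2])), mul(3, y[1])), y[0]);
    let c3 = mul(d3, inv(6));
    let c2 = mul(sub(d2, d3), inv(2));
    let c1 = add(sub(d1, mul(d2, inv(2))), mul(d3, inv(3)));
    [y[0], c1, c2, c3]
}

/// Horner evaluation, matching RoundPolyDeg3::eval.
fn eval_deg3(c: &[u64; 4], t: u64) -> u64 {
    add(c[0], mul(t, add(c[1], mul(t, add(c[2], mul(t, c[3]))))))
}
```

For example, evaluations [0, 1, 8, 27] of t³ recover [0, 0, 0, 1].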

Layer Types and Reduction Protocols

MatMul (degree-2 sumcheck)

For C = A × B with dimensions (m × k) × (k × n):

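Mathematical identity: C[i][j] = Σ_x A[i][x]·B[x][j] holds at every boolean (i, j), and the right-hand side below is multilinear in (r_i, r_j), so by uniqueness of the MLE the identity extends to arbitrary field points: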

MLE_C(r_i, r_j) = Σ_{x ∈ {0,1}^{log k}} MLE_A(r_i, x) · MLE_B(x, r_j)

This converts a claim about C into an inner product of two restricted MLEs over the inner dimension k.
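The restriction identity can be checked on a small example. This sketch assumes M31 arithmetic in place of SecureField, power-of-two dimensions, row-major layout, and MSB-first variable order (row bits before column bits). Under those assumptions, MLE_A(r_i, x) weights A's rows with the Lagrange weights of r_i, and MLE_B(x, r_j) weights B's columns with those of r_j. All function names are illustrative.

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus, standing in for SecureField

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { (a * b) % P }

fn lagrange_weights(c: &[u64]) -> Vec<u64> {
    let mut w = vec![1u64];
    for &ci in c {
        w = w.iter().flat_map(|&x| [mul(x, sub(1, ci)), mul(x, ci)]).collect();
    }
    w
}

fn evaluate_mle(evals: &[u64], point: &[u64]) -> u64 {
    assert_eq!(evals.len(), 1 << point.len());
    let mut cur = evals.to_vec();
    for &r in point {
        let mid = cur.len() / 2;
        for i in 0..mid {
            cur[i] = add(cur[i], mul(r, sub(cur[mid + i], cur[i])));
        }
        cur.truncate(mid);
    }
    cur[0]
}

/// f_a(x) = MLE_A(r_i, x) for A of shape m×k.
fn restrict_rows(a: &[u64], m: usize, k: usize, r_i: &[u64]) -> Vec<u64> {
    let w = lagrange_weights(r_i);
    assert_eq!(w.len(), m);
    (0..k).map(|x| (0..m).fold(0, |acc, row| add(acc, mul(w[row], a[row * k + x])))).collect()
}

/// f_b(x) = MLE_B(x, r_j) for B of shape k×n.
fn restrict_cols(b: &[u64], k: usize, n: usize, r_j: &[u64]) -> Vec<u64> {
    let w = lagrange_weights(r_j);
    assert_eq!(w.len(), n);
    (0..k).map(|x| (0..n).fold(0, |acc, col| add(acc, mul(w[col], b[x * n + col])))).collect()
}

fn matmul(a: &[u64], b: &[u64], m: usize, k: usize, n: usize) -> Vec<u64> {
    let mut c = vec![0u64; m * n];
    for i in 0..m {
        for j in 0..n {
            for x in 0..k {
                c[i * n + j] = add(c[i * n + j], mul(a[i * k + x], b[x * n + j]));
            }
        }
    }
    c
}
```

Usage: the inner product of f_a and f_b equals MLE_C evaluated at the concatenated point (r_i, r_j).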

Protocol steps:

  1. Fiat-Shamir setup: Mix dimensions (m, k, n), draw row challenges r_i ∈ F^{log m} and column challenges r_j ∈ F^{log n}
  2. Compute claimed sum: claimed_sum = MLE_C(r_i, r_j)
  3. Restrict MLEs: Extract f_a(x) = MLE_A(r_i, x) and f_b(x) = MLE_B(x, r_j) for x ∈ {0,1}^{log k}
  4. Sumcheck over log₂(k) rounds, each producing a degree-2 polynomial (mid is half the current table length):
     p(t) = Σ_{i=0}^{mid-1} [(1-t)·f_a[i] + t·f_a[mid+i]] · [(1-t)·f_b[i] + t·f_b[mid+i]]
     evaluated at t = 0, 1, 2:
       • s₀ = Σ f_a[i] · f_b[i]
       • s₁ = Σ f_a[mid+i] · f_b[mid+i]
       • s₂ = Σ (2·f_a[mid+i] - f_a[i]) · (2·f_b[mid+i] - f_b[i])
  5. Final check: after all rounds, the running sum must equal final_a_eval × final_b_eval.
  6. Weight opening: Merkle proof that MLE_B(assignment, r_j) = final_b_eval.

Oracle implementation:

pub struct MatMulOracle {
    pub f_a: Vec<SecureField>,
    pub f_b: Vec<SecureField>,
}

impl MultivariatePolyOracle for MatMulOracle {
    fn sum_as_poly_in_first_variable(&self, _claim: SecureField) -> UnivariatePoly {
        let half = self.f_a.len() / 2;
        let (mut s0, mut s1, mut s2) = (zero(), zero(), zero());
        for i in 0..half {
            s0 += self.f_a[i] * self.f_b[i];
            s1 += self.f_a[half + i] * self.f_b[half + i];
            s2 += (2 * self.f_a[half + i] - self.f_a[i])
                * (2 * self.f_b[half + i] - self.f_b[i]);
        }
        let c0 = s0;
        let c2 = (s2 - 2 * s1 + s0) / 2;
        let c1 = s1 - s0 - c2;
        UnivariatePoly::from_coeffs([c0, c1, c2])
    }

    fn fix_first_variable(self, challenge: SecureField) -> Self {
        let mid = self.f_a.len() / 2;
        let new_a = (0..mid).map(|i|
            self.f_a[i] + challenge * (self.f_a[mid + i] - self.f_a[i])
        ).collect();
        let new_b = (0..mid).map(|i|
            self.f_b[i] + challenge * (self.f_b[mid + i] - self.f_b[i])
        ).collect();
        MatMulOracle { f_a: new_a, f_b: new_b }
    }
}

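Channel protocol (abridged): mix(m, k, n), draw the row and column challenges, mix(claimed_value), then per round mix_poly_coeffs(c₀, c₁, c₂) followed by draw(), and finally mix(final_a), mix(final_b). The exact prover/verifier sequence is given under Transcript Ordering.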
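A minimal end-to-end sketch of the matmul sumcheck, assuming M31 arithmetic (p = 2^31 - 1) in place of QM31 SecureField and fixed challenges in place of channel draws. round_evals mirrors MatMulOracle's (s₀, s₁, s₂) computation; the function names are illustrative, not the crate's API, and weight openings are omitted.

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus, standing in for SecureField

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { (a * b) % P }
fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}
fn inv(a: u64) -> u64 { pow(a, P - 2) } // Fermat; a must be nonzero

/// (s0, s1, s2) = p(0), p(1), p(2), computed as in MatMulOracle.
fn round_evals(fa: &[u64], fb: &[u64]) -> (u64, u64, u64) {
    let half = fa.len() / 2;
    let (mut s0, mut s1, mut s2) = (0, 0, 0);
    for i in 0..half {
        s0 = add(s0, mul(fa[i], fb[i]));
        s1 = add(s1, mul(fa[half + i], fb[half + i]));
        let a2 = sub(mul(2, fa[half + i]), fa[i]); // f_a extended to t = 2
        let b2 = sub(mul(2, fb[half + i]), fb[i]); // f_b extended to t = 2
        s2 = add(s2, mul(a2, b2));
    }
    (s0, s1, s2)
}

fn fold(v: &[u64], r: u64) -> Vec<u64> {
    let mid = v.len() / 2;
    (0..mid).map(|i| add(v[i], mul(r, sub(v[mid + i], v[i])))).collect()
}

/// Some((final_a_eval, final_b_eval)) if every check passes.
fn matmul_sumcheck(fa: &[u64], fb: &[u64], claimed_sum: u64, challenges: &[u64]) -> Option<(u64, u64)> {
    let (mut fa, mut fb) = (fa.to_vec(), fb.to_vec());
    let mut running = claimed_sum;
    for &r in challenges {
        // Prover: evaluations at t = 0, 1, 2 converted to (c0, c1, c2).
        let (s0, s1, s2) = round_evals(&fa, &fb);
        let c2 = mul(add(sub(s2, mul(2, s1)), s0), inv(2));
        let c1 = sub(sub(s1, s0), c2);
        let c0 = s0;
        // Verifier: p(0) + p(1) = 2c0 + c1 + c2 must match the running sum.
        if add(add(mul(2, c0), c1), c2) != running {
            return None;
        }
        running = add(c0, mul(r, add(c1, mul(r, c2)))); // p(r)
        fa = fold(&fa, r);
        fb = fold(&fb, r);
    }
    // Final check: running sum == final_a_eval × final_b_eval.
    if fa.len() == 1 && mul(fa[0], fb[0]) == running { Some((fa[0], fb[0])) } else { None }
}
```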

Cost analysis:

  • For k = 5120 (Qwen3-14B FFN): padded to 8192, log₂(8192) = 13 sumcheck rounds
  • Per round: 3 QM31 coefficients (c₀, c₁, c₂) = 12 M31 values
  • Traditional STARK trace: O(m × k × n) rows, about 2.1M rows for 128×128×128 (128³ = 2,097,152)
  • Sumcheck witness: O(m×n + m×k + k×n) rows, 3 × 128² = 49,152 (~49K) for the same shape, roughly a 42× reduction

Add (degree-1 split)

For C = A + B, the MLE decomposes linearly:

MLE_C(r) = MLE_A(r) + MLE_B(r)

No sumcheck needed — the claim splits into two sub-claims that are recursively verified.

DAG handling (residual connections): When A and B come from different branches:

  • Main GKR walk follows the trunk (higher layer index in DAG)
  • The other branch generates a deferred proof verified after the main walk
  • trunk_idx: u8 marks which input (0=lhs, 1=rhs) is the trunk
Add {
    lhs_eval: SecureField,    // MLE_A(r)
    rhs_eval: SecureField,    // MLE_B(r)
    trunk_idx: u8,            // 0=lhs is trunk, 1=rhs is trunk
}
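A sketch of the verifier-side bookkeeping for an Add layer, assuming M31 arithmetic in place of SecureField. reduce_add and its error strings are hypothetical. It checks lhs_eval + rhs_eval against the claim, then returns the trunk evaluation (which continues the main walk) and the other evaluation (which becomes a deferred claim); both stay at the same point r.

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus, standing in for SecureField

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { (a * b) % P }

fn evaluate_mle(evals: &[u64], point: &[u64]) -> u64 {
    assert_eq!(evals.len(), 1 << point.len());
    let mut cur = evals.to_vec();
    for &r in point {
        let mid = cur.len() / 2;
        for i in 0..mid {
            cur[i] = add(cur[i], mul(r, sub(cur[mid + i], cur[i])));
        }
        cur.truncate(mid);
    }
    cur[0]
}

/// Hypothetical Add reduction: returns (trunk_eval, deferred_eval).
fn reduce_add(claim_value: u64, lhs_eval: u64, rhs_eval: u64, trunk_idx: u8) -> Result<(u64, u64), &'static str> {
    // MLE is linear: MLE_C(r) = MLE_A(r) + MLE_B(r).
    if add(lhs_eval, rhs_eval) != claim_value {
        return Err("ADD_SPLIT_MISMATCH");
    }
    match trunk_idx {
        0 => Ok((lhs_eval, rhs_eval)), // lhs continues the main walk
        1 => Ok((rhs_eval, lhs_eval)), // rhs continues the main walk
        _ => Err("INVALID_TRUNK_IDX"),
    }
}
```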

Mul (degree-3 eq-sumcheck)

For element-wise C = A ⊙ B:

claim = Σ_{x ∈ {0,1}^n} eq(r, x) · MLE_A(x) · MLE_B(x)

Three degree-1 factors → degree-3 univariate per round. Each round evaluates at t = 0, 1, 2, 3:

  • s₀ = Σ eq_lo[i] · a_lo[i] · b_lo[i]
  • s₁ = Σ eq_hi[i] · a_hi[i] · b_hi[i]
  • s₂ = Σ (2·eq_hi - eq_lo) · (2·a_hi - a_lo) · (2·b_hi - b_lo)
  • s₃ = Σ (3·eq_hi - 2·eq_lo) · (3·a_hi - 2·a_lo) · (3·b_hi - 2·b_lo)

Newton divided differences → (c₀, c₁, c₂, c₃).

Final check: running_sum == eq(r, assignment) × a_final × b_final.
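An end-to-end sketch of the Mul eq-sumcheck, assuming M31 arithmetic in place of SecureField, MSB-first variable order, and fixed challenges instead of Fiat-Shamir draws. Each factor is extended linearly to t = 0..3 (t = 2 gives 2·hi - lo, t = 3 gives 3·hi - 2·lo). The verifier evaluates the cubic at each challenge directly from the four values rather than from c₀..c₃. Names are illustrative.

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus, standing in for SecureField

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { (a * b) % P }
fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}
fn inv(a: u64) -> u64 { pow(a, P - 2) } // Fermat; a must be nonzero

/// eq(r, x) for every x on the hypercube (r[0] pairs with the index MSB).
fn eq_table(r: &[u64]) -> Vec<u64> {
    let mut t = vec![1u64];
    for &ri in r {
        t = t.iter().flat_map(|&w| [mul(w, sub(1, ri)), mul(w, ri)]).collect();
    }
    t
}

fn fold(v: &[u64], r: u64) -> Vec<u64> {
    let mid = v.len() / 2;
    (0..mid).map(|i| add(v[i], mul(r, sub(v[mid + i], v[i])))).collect()
}

/// Value at t of the line through (0, lo) and (1, hi).
fn line(lo: u64, hi: u64, t: u64) -> u64 { add(lo, mul(t, sub(hi, lo))) }

/// Evaluate at r the cubic with values y[t], t = 0..3 (Newton forward form).
fn eval_cubic(y: [u64; 4], r: u64) -> u64 {
    let d1 = sub(y[1], y[0]);
    let d2 = add(sub(y[2], mul(2, y[1])), y[0]);
    let d3 = sub(add(sub(y[3], mul(3, y[2])), mul(3, y[1])), y[0]);
    let r01 = mul(r, sub(r, 1)); // r(r-1)
    let r012 = mul(r01, sub(r, 2)); // r(r-1)(r-2)
    let quad = mul(mul(d2, r01), inv(2));
    let cubic = mul(mul(d3, r012), inv(6));
    add(add(y[0], mul(d1, r)), add(quad, cubic))
}

/// Prove and verify Σ_x eq(r, x)·a(x)·b(x) = claim.
fn mul_eq_sumcheck(r: &[u64], a: &[u64], b: &[u64], claim: u64, challenges: &[u64]) -> bool {
    let (mut e, mut a, mut b) = (eq_table(r), a.to_vec(), b.to_vec());
    let mut running = claim;
    for &c in challenges {
        let mid = e.len() / 2;
        // Prover: s[t] = Σ_i eq_t[i]·a_t[i]·b_t[i] for t = 0, 1, 2, 3.
        let mut s = [0u64; 4];
        for t in 0..4u64 {
            for i in 0..mid {
                let ea = mul(line(e[i], e[mid + i], t), line(a[i], a[mid + i], t));
                s[t as usize] = add(s[t as usize], mul(ea, line(b[i], b[mid + i], t)));
            }
        }
        // Verifier: p(0) + p(1) must equal the running sum; then H = p(c).
        if add(s[0], s[1]) != running {
            return false;
        }
        running = eval_cubic(s, c);
        e = fold(&e, c);
        a = fold(&a, c);
        b = fold(&b, c);
    }
    // Final check: running_sum == eq(r, assignment) × a_final × b_final.
    e.len() == 1 && running == mul(mul(e[0], a[0]), b[0])
}
```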

Activation (LogUp eq-sumcheck)

For y = f(x) where f is a non-linear activation (ReLU, GELU, Sigmoid):

LogUp Protocol:

  1. Precomputed table: T = {(x, f(x)) : x ∈ domain} (65K entries for ReLU, 262K for GELU)
  2. Encoding: encode(x, y) = x + β · y (collapsing 2D pair to 1D via random β)
  3. Denominators: dᵢ = γ - encode(inputᵢ, outputᵢ) for trace rows
  4. Witness fractions: wᵢ = 1 / dᵢ
  5. Multiplicity check: Σᵢ wᵢ = Σⱼ multⱼ / (γ - encode(tableⱼ))
  6. Eq-sumcheck: Proves w(x) · d(x) = 1 on the boolean hypercube:
Σ_{x ∈ {0,1}^n} eq(r, x) · w(x) · d(x) = 1

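This is a degree-3 sumcheck (three factors: eq, w, d). Since Σ_x eq(r, x) = 1, the identity holds whenever w(x)·d(x) = 1 at every boolean point; if any point violates it, the left-hand side is a nonzero multilinear polynomial in r, which a random r exposes with high probability (Schwartz-Zippel). This convinces the verifier that each witness fraction wᵢ is the inverse of its denominator dᵢ.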
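The multiplicity identity in step 5 can be checked numerically. This sketch uses M31 arithmetic, a toy 8-entry table f(x) = x² in place of a real activation, fixed β and γ instead of channel draws, and multiplicities counted directly from the trace. It checks only the fractional-sum identity, not the eq-sumcheck, and the function names are illustrative.

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus, standing in for SecureField

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { (a * b) % P }
fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}
fn inv(a: u64) -> u64 { pow(a, P - 2) } // Fermat; a must be nonzero

/// encode(x, y) = x + β·y
fn encode(x: u64, y: u64, beta: u64) -> u64 { add(x, mul(beta, y)) }

/// Σ_i 1/(γ - encode(input_i, output_i)) over the trace rows.
fn trace_side(trace: &[(u64, u64)], beta: u64, gamma: u64) -> u64 {
    trace.iter().fold(0, |acc, &(x, y)| add(acc, inv(sub(gamma, encode(x, y, beta)))))
}

/// Σ_j mult_j/(γ - encode(table_j)), with mult_j counted from the trace.
fn table_side(table: &[(u64, u64)], trace: &[(u64, u64)], beta: u64, gamma: u64) -> u64 {
    table.iter().fold(0, |acc, &entry| {
        let mult = trace.iter().filter(|&&row| row == entry).count() as u64;
        add(acc, mul(mult, inv(sub(gamma, encode(entry.0, entry.1, beta)))))
    })
}
```

A trace row missing from the table contributes to the trace side with no matching table term, so the two sums differ.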

LogUpProof structure:

pub struct LogUpProof {
    pub eq_round_polys: Vec<RoundPolyDeg3>,
    pub final_evals: (SecureField, SecureField, SecureField),  // (w(s), in(s), out(s))
    pub claimed_sum: SecureField,
    pub multiplicities: Vec<u32>,
}

LayerNorm (combined-product sumcheck)

LayerNorm is non-linear (involves mean + rsqrt), which causes cross-terms when batching across SIMD blocks. The solution decomposes into two provable stages:

Stage 1 — Linear eq-sumcheck:

output = (input - mean) × rsqrt

This is verified via degree-3 eq-sumcheck over centered × rsqrt:

Σ_{x} eq(r, x) · centered(x) · rsqrt(x) = output_claim.value

where centered(x) = input(x) - mean.

Stage 2 — rsqrt LogUp lookup: Proves (variance, rsqrt) pairs exist in a precomputed table of reciprocal square roots. Uses the same LogUp protocol as activations.

SIMD batching path: When batching structurally identical blocks, combined_product[i] = Σ_b w_b · (centered_b[i] × rsqrt_b[i]). A constant-1 MLE serves as the second factor, so rsqrt_final = 1 after all folds. When simd_combined = true, the rsqrt LogUp is skipped: the combined values are QM31 linear combinations and cannot be decomposed back into individual M31 table entries.

RMSNorm (LogUp rsqrt lookup)

RMSNorm (y = x / sqrt(mean(x²) + ε) · γ) is handled similarly to LayerNorm but without mean subtraction:

claim = Σ_x eq(r, x) · MLE_input(x) · MLE_rsqrt(x)

Degree-3 eq-sumcheck (three factors: eq, input, rsqrt) + degree-3 LogUp for the rsqrt lookup.

RoPE (LogUp rotation table)

Rotary Positional Embedding applies position-dependent (cos, sin) rotations to Q/K vectors. The rotation factors are deterministic from (seq_len, head_dim, base), so the verifier reconstructs the table. LogUp proves each (cos, sin) pair used in the trace exists in the table.

Dequantize (LogUp 2D table)

For quantized models (INT4/INT8), dequantization maps quantized integer values to their M31 equivalents via a small lookup table (16 entries for INT4, 256 for INT8). LogUp proves each (quantized_input, dequantized_output) pair matches the table. Follows the Activation pattern with finalize_logup_in_pairs().

Attention (composed sub-matmuls)

Attention decomposes into 4 + 2H sub-matmuls (H = num_heads):

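| Index | Sub-matmul | Type | Operands |
|---|---|---|---|
| 0 | Output projection | Shared-weight | concat × W_O |
| 1..2H | Per-head context | Dual-operand | softmax_h × V_h |
| 1..2H | Per-head score | Dual-operand | Q_h × K_h^T (unscaled) |
| 2H+1 | V projection | Shared-weight | input × W_V |
| 2H+2 | K projection | Shared-weight | input × W_K |
| 2H+3 | Q projection | Shared-weight | input × W_Q |

Indices 1..2H are shared by the 2H per-head sub-matmuls (one context and one score matmul per head), so indices run from 0 to 2H+3, giving 4 + 2H sub-matmuls in total.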

Important: score_matrices[h] includes the 1/√d_k scaling factor. The sumcheck must use the raw Q_h × K_h^T product (unscaled), not the stored score matrix.

Fiat-Shamir Transcript Protocol

PoseidonChannel (Poseidon2-M31)

The Fiat-Shamir channel uses the Poseidon2-M31 permutation (t=16, rate=8) for native M31 field compatibility:

pub struct PoseidonChannel {
    digest: FieldElement,   // Hades permutation state
    n_draws: u32,           // Draw counter (domain separation)
}

Mix operation (absorb): state = [digest, felt(value), 2] → hades(state) → digest = state[0]

Draw operation (squeeze): state = [digest, n_draws, 3] → hades(state) → value = state[0]

QM31 extraction from felt252: A single felt252 is unpacked into 4 M31 components via successive floor_div(2^31):

pub fn draw_qm31(&mut self) -> SecureField {
    let felt = self.draw_felt252();
    let shift = FieldElement::from(1u64 << 31);
    let p = (1u64 << 31) - 1;

    let mut m31s = [M31::from(0u32); 4];
    let mut cur = felt;
    for m31 in m31s.iter_mut() {
        let next = cur.floor_div(shift);
        let res = cur - next * shift;
        cur = next;
        *m31 = M31::from((res.to_u64() % p) as u32);
    }
    QM31(CM31(m31s[0], m31s[1]), CM31(m31s[2], m31s[3]))
}

M31 packing algorithm (for mixing polynomial coefficients):

pack_m31s([a, b, c, d]) = ((((1 × 2^31 + a) × 2^31 + b) × 2^31 + c) × 2^31 + d)

The leading 1 is a sentinel that preserves leading zeros. Two QM31 values (8 M31s) pack into a single felt252, since 1 + 8 × 31 = 249 bits stays below the 251-bit felt252 modulus. Unpacking reverses the process via repeated floor_div(2^31).
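The packing rule can be checked with plain integers. This sketch uses u128 instead of felt252, which holds four limbs (1 + 4 × 31 = 125 bits), so it covers one QM31 rather than the two that fit in a felt252. This pack_m31s is a re-implementation of the formula for illustration, not the crate's function.

```rust
const SHIFT: u128 = 1 << 31;

/// ((((1·2^31 + a)·2^31 + b)·2^31 + c)·2^31 + d) for limbs [a, b, c, d].
fn pack_m31s(limbs: &[u32]) -> u128 {
    assert!(limbs.len() <= 4, "1 + 31·n bits must fit in a u128");
    limbs.iter().fold(1u128, |acc, &x| {
        assert!(x < (1 << 31), "limb is not a 31-bit value");
        acc * SHIFT + x as u128
    })
}

/// Peel off limbs with repeated floor division by 2^31 until only the
/// sentinel 1 remains, then restore most-significant-first order.
fn unpack_m31s(mut packed: u128) -> Vec<u32> {
    let mut limbs = Vec::new();
    while packed > 1 {
        limbs.push((packed % SHIFT) as u32);
        packed /= SHIFT;
    }
    limbs.reverse();
    limbs
}
```

Without the sentinel, packing [0, 5, 0, 7] would be indistinguishable from packing [5, 0, 7]; with it, the leading zero survives the round trip.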

Transcript Ordering

The matmul sumcheck transcript follows this exact sequence (any divergence between prover and verifier causes failure):

1. mix_u64(m), mix_u64(k), mix_u64(n)       — dimensions
2. draw_qm31() × log_m                       — row challenges r_i
3. draw_qm31() × log_n                       — col challenges r_j
4. mix_felt(claimed_sum)                      — claimed MLE value
5. mix_felt(a_commitment), mix_felt(b_commitment)  — Merkle roots
6. For each round:
     mix_poly_coeffs(c₀, c₁, c₂)            — 12 M31 → 2 felt252
     draw_qm31()                              — sumcheck challenge
7. prove_mle_opening(f_a, assignment)         — MLE opening proof
8. prove_mle_opening(f_b, assignment)         — MLE opening proof

The Rust prover and the Cairo verifier must produce identical channel digests at every step; transcript divergence is the single most common source of bugs.

Proof Types

/// A claim: "MLE evaluated at `point` equals `value`"
pub struct GKRClaim {
    pub point: Vec<SecureField>,
    pub value: SecureField,
}

/// Per-layer proof variant
pub enum LayerProof {
    MatMul { round_polys, final_a_eval, final_b_eval },
    MatMulDualSimd { round_polys, final_a_eval, final_b_eval, n_block_vars },
    Add { lhs_eval, rhs_eval },
    Mul { round_polys: Vec<RoundPolyDeg3>, final_a_eval, final_b_eval },
    Activation { activation_type, round_polys, final_input_eval, final_output_eval, logup_proof },
    LayerNorm { round_polys, final_input_eval, final_output_eval, logup_proof },
    RMSNorm { round_polys, final_input_eval, final_output_eval, logup_proof },
    RoPE { logup_proof, input_eval, output_eval },
    Dequantize { logup_proof, input_eval, output_eval, table_commitment },
    Attention { sub_proofs: Vec<LayerProof>, sub_claim_values },
}

On-Chain Proof Structure

pub struct MatMulSumcheckProofOnChain {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub num_rounds: u32,             // = log₂(k)
    pub claimed_sum: SecureField,    // MLE_C(r_i, r_j)
    pub round_polys: Vec<RoundPoly>, // log_k round polynomials
    pub final_a_eval: SecureField,   // MLE_A(r_i, assignment)
    pub final_b_eval: SecureField,   // MLE_B(assignment, r_j)
    pub a_commitment: FieldElement,  // Poseidon Merkle root
    pub b_commitment: FieldElement,  // Poseidon Merkle root
    pub a_opening: MleOpeningProof,
    pub b_opening: MleOpeningProof,
}

Calldata cost: ~50 + 12 × num_rounds felt252 values per matmul proof.
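A quick arithmetic check of the figures for the k = 5120 case, using the document's own approximate formula. sumcheck_rounds and approx_calldata_felts are illustrative helpers, and the 50-felt fixed overhead is the approximation stated in the text, not an exact count.

```rust
/// Sumcheck rounds for inner dimension k after power-of-two padding.
fn sumcheck_rounds(k: u32) -> u32 {
    k.next_power_of_two().trailing_zeros()
}

/// Approximate calldata per matmul proof: ~50 + 12 × num_rounds felt252s.
fn approx_calldata_felts(k: u32) -> u32 {
    50 + 12 * sumcheck_rounds(k)
}
```

For k = 5120 this gives 13 rounds and roughly 206 felt252 values per matmul proof.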

Circuit Compilation

LayeredCircuit::from_graph(graph) converts a ComputationGraph (topologically sorted DAG of ML operations) into a layered circuit where each layer has:

  • layer_type: Which reduction protocol to use
  • input_shape / output_shape: Matrix dimensions (for MLE sizing)
  • node_id: Back-reference to the original graph node
  • input_layers: Predecessor layer indices

Deferred Proofs for DAG Residuals

When an Add layer has inputs from different branches (e.g., residual skip connections in transformers), the main GKR walk follows the trunk branch. The skip branch generates a deferred proof verified after the main walk completes.

pub struct DeferredProof {
    pub claim: GKRClaim,           // Claim at the Add layer
    pub dims: (usize, usize, usize),
    pub layer_proof: LayerProof,   // MatMul sumcheck for skip branch
    pub input_claim: GKRClaim,
    pub weight_commitment: FieldElement,
    pub weight_opening: MleOpeningProof,
    pub weight_claim: WeightClaim,
}

Fiat-Shamir ordering: main walk → weight openings → deferred proofs. The deferred claim's evaluation point is the same point from the Add layer during the main walk.

On-chain: The Cairo verifier saves Add layer claim points during the walk, then reads and verifies deferred matmul sumcheck proofs in sequence after the main loop.

Entry Points

Proving

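| Function | Description |
|---|---|
| prove_gkr(circuit, execution, weights, channel) | CPU-only GKR proof |
| prove_gkr_gpu(circuit, execution, weights, channel) | GPU-accelerated (single block) |
| prove_gkr_simd_gpu(circuit, block_executions, weights, channel) | GPU + SIMD batching across blocks |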
prove_gkr_simd_gpu(circuit, block_executions, weights, channel)GPU + SIMD batching across blocks

Verification

| Function | Description |
|---|---|
| verify_gkr(circuit, proof, output, channel) | Standard verification |
| verify_gkr_with_execution(circuit, proof, execution, channel) | Verification with intermediate checks |
| verify_gkr_simd(circuit, proof, combined_output, channel) | SIMD-aware verification |

On-Chain (Cairo)

| Function | Description |
|---|---|
| verify_model_gkr(proof_data, num_layers, dims, initial_claim, channel) | Full on-chain GKR walk |
| verify_matmul_sumcheck(proof, channel) | Single matmul verification |
| verify_batched_matmul(batch_proof, channel) | Lambda-weighted batch verification |

Input/Output Claim Verification

The GKR walk proves internal consistency but must be anchored to real data at both endpoints. See Input Claim Verification for the full design.

Output side: The verifier draws r_out from the Fiat-Shamir channel, evaluates MLE(raw_output, r_out) on-chain, and uses the result as the initial GKR claim.

Input side: After the GKR walk completes with final_claim, the verifier evaluates MLE(raw_input, final_claim.point) on-chain and asserts equality:

assert!(MLE(input, final_claim.point) == final_claim.value, "INPUT_CLAIM_MISMATCH");

Both MLEs are constructed from raw data in calldata with power-of-2 padding and row-major layout, matching the Rust-side pad_matrix_pow2 + matrix_to_mle.
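A sketch of how an endpoint MLE can be built from raw calldata, assuming zero padding and the MSB-first variable order of evaluate_mle (row bits first, then column bits). padded_mle_table is a hypothetical stand-in for pad_matrix_pow2 + matrix_to_mle, whose real signatures may differ, and M31 arithmetic stands in for SecureField.

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus, standing in for SecureField

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { (a * b) % P }

fn evaluate_mle(evals: &[u64], point: &[u64]) -> u64 {
    assert_eq!(evals.len(), 1 << point.len());
    let mut cur = evals.to_vec();
    for &r in point {
        let mid = cur.len() / 2;
        for i in 0..mid {
            cur[i] = add(cur[i], mul(r, sub(cur[mid + i], cur[i])));
        }
        cur.truncate(mid);
    }
    cur[0]
}

/// Zero-pad a rows×cols row-major matrix to power-of-two dimensions and
/// flatten it row-major. Returns (table, padded_rows, padded_cols).
fn padded_mle_table(data: &[u64], rows: usize, cols: usize) -> (Vec<u64>, usize, usize) {
    assert_eq!(data.len(), rows * cols);
    let (pr, pc) = (rows.next_power_of_two(), cols.next_power_of_two());
    let mut table = vec![0u64; pr * pc];
    for i in 0..rows {
        table[i * pc..i * pc + cols].copy_from_slice(&data[i * cols..(i + 1) * cols]);
    }
    (table, pr, pc)
}
```

At a boolean point the MLE returns the corresponding matrix entry (or 0 for a padding cell), which is why the Cairo and Rust sides must agree on padding and layout exactly.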

Integration with Aggregation Pipeline

GKR is optional and additive — the standard STARK pipeline runs first, then GKR produces an additional proof:

// Standard pipeline
let stark_proof = prove_model_aggregated_onchain(graph, input, weights);

// With GKR (additional verification layer)
let (stark_proof, gkr_proof) = prove_model_aggregated_onchain_gkr(graph, input, weights);

The verifier checks both proofs independently. GKR provides a second, complementary verification path.

Performance

| Metric | Value | Notes |
|---|---|---|
| Qwen3-14B per-layer proving | 3.04 s | H100 GPU, 160 matmuls |
| Qwen3-14B 40-layer total | 122 s | End-to-end with streaming |
| Single matmul sumcheck | ~48 ms | GPU-accelerated (k=5120) |
| On-chain verification | 18 TXs | Streaming GKR on Starknet |
| Sumcheck rounds per matmul | 13 | log₂(8192) for k=5120 |
| Proof size per matmul | ~5 KB | Round polys + openings |
| Total calldata | 86,723 felts | Via aggregated weight binding (28× reduction) |

References

  1. Goldwasser, Kalai, Rothblum (2015) — "Delegating Computation: Interactive Proofs for Muggles"
  2. Lund, Fortnow, Karloff, Nisan (1992) — "Algebraic Methods for Interactive Proof Systems"
  3. Thaler (2013) — "Time-Optimal Interactive Proofs for Circuit Evaluation"
  4. Grassi et al. (2021) — "Poseidon: A New Hash Function for Zero-Knowledge Proofs"