Transformer Architecture — Full Llama-Style Proving Pipeline

Overview

stwo-ml supports proving full transformer decoder blocks as used in Llama, Qwen, Mistral, and similar architectures. Each block follows the pre-norm residual pattern:

                    ┌───────────────┐
          Input ────┤   Identity    ├──────────────────────┐
                    └───────┬───────┘                      │
                            ▼                              │ (residual)
                    ┌───────────────┐                      │
                    │   RMSNorm     │                      │
                    └───────┬───────┘                      │
                            ▼                              │
                    ┌───────────────┐                      │
                    │  Attention    │  (GQA/MQA/MHA)       │
                    │  Q×K^T→soft→V │                      │
                    └───────┬───────┘                      │
                            ▼                              │
                    ┌───────────────┐                      │
                    │     Add       │◄─────────────────────┘
                    └───────┬───────┘
                            ├──────────────────────────────┐
                            ▼                              │ (residual)
                    ┌───────────────┐                      │
                    │   RMSNorm     │                      │
                    └───────┬───────┘                      │
                            ▼                              │
                    ┌───────────────┐                      │
                    │  FFN: Linear  │  (d_model → ffn_dim) │
                    └───────┬───────┘                      │
                            ▼                              │
                    ┌───────────────┐                      │
                    │     GELU      │                      │
                    └───────┬───────┘                      │
                            ▼                              │
                    ┌───────────────┐                      │
                    │  FFN: Linear  │  (ffn_dim → d_model) │
                    └───────┬───────┘                      │
                            ▼                              │
                    ┌───────────────┐                      │
                    │     Add       │◄─────────────────────┘
                    └───────┬───────┘
                         Output

Builder API

The GraphBuilder::transformer_block() method constructs this entire pattern in one call:

use stwo_ml::compiler::graph::GraphBuilder;

let mut builder = GraphBuilder::new((seq_len, d_model));

// Stack N transformer blocks
for _ in 0..num_layers {
    builder.transformer_block(
        num_heads,      // Q heads (e.g., 32)
        num_kv_heads,   // KV heads (e.g., 8 for GQA)
        seq_len,        // sequence length
        ffn_dim,        // feed-forward intermediate dim (e.g., 4 × d_model)
    );
}

let graph = builder.build();

Each block produces 9 graph nodes: Identity, RMSNorm, Attention, Add, RMSNorm, Linear, GELU, Linear, Add.

Components

RMSNorm (components/rmsnorm.rs)

Root Mean Square Layer Normalization — used in Llama/Qwen instead of LayerNorm.

Formula: y = x / sqrt(mean(x^2) + epsilon) * gamma

Key difference from LayerNorm: No mean subtraction. Cheaper to compute and prove.

Proving approach: Decomposed into three provable operations:

  1. Compute rms^2 = sum(x^2) / n via M31 arithmetic
  2. Reciprocal sqrt rsqrt(rms^2) via LogUp lookup table (precomputed table of 2^16 entries)
  3. Scale: output = input * rsqrt_val
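The three steps above can be sketched in plain f64 arithmetic (an illustrative reference implementation, not the crate's M31 version; `rmsnorm` is a hypothetical helper name):

```rust
// Reference RMSNorm in f64, mirroring the three-step decomposition
// above (illustrative only; the prover works over M31 with a lookup).
fn rmsnorm(x: &[f64], gamma: &[f64], eps: f64) -> Vec<f64> {
    // Step 1: rms_sq = mean(x^2)
    let rms_sq = x.iter().map(|v| v * v).sum::<f64>() / x.len() as f64;
    // Step 2: rsqrt_val = 1/sqrt(rms_sq + eps) (the proved lookup)
    let rsqrt_val = 1.0 / (rms_sq + eps).sqrt();
    // Step 3: output = x * rsqrt_val * gamma (the degree-2 constraint)
    x.iter().zip(gamma).map(|(v, g)| v * rsqrt_val * g).collect()
}

fn main() {
    let y = rmsnorm(&[3.0, 4.0], &[1.0, 1.0], 0.0);
    // mean(x^2) = 12.5; note: no mean subtraction, unlike LayerNorm
    assert!((y[0] - 3.0 / 12.5f64.sqrt()).abs() < 1e-12);
}
```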

Trace layout (5 columns):

Column | Name         | Description
-------|--------------|-----------------------------
0      | input        | Original value x
1      | rms_sq       | mean(x^2), shared per row
2      | rsqrt_val    | 1/sqrt(rms_sq), from lookup
3      | output       | x * rsqrt_val
4      | multiplicity | LogUp multiplicity

Preprocessed columns (2 columns, table side):

Column | Name               | Content
-------|--------------------|---------------------
0      | rms_sq table input | Lookup input key
1      | rsqrt table output | Lookup output value

Constraint evaluation — the RMSNormEval implements FrameworkEval:

impl FrameworkEval for RMSNormEval {
    fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
        let table_rms = eval.get_preprocessed_column(/* rmsnorm_rms_input */);
        let table_rsqrt = eval.get_preprocessed_column(/* rmsnorm_rsqrt_output */);

        let input = eval.next_trace_mask();
        let rms_sq = eval.next_trace_mask();
        let rsqrt_val = eval.next_trace_mask();
        let output = eval.next_trace_mask();
        let multiplicity = eval.next_trace_mask();

        // Degree-2 constraint: output = input × rsqrt_val
        eval.add_constraint(output - input * rsqrt_val.clone());

        // LogUp table side: yield -multiplicity for each table entry
        eval.add_to_relation(RelationEntry::new(
            &self.lookup_elements,
            -E::EF::from(multiplicity),
            &[table_rms, table_rsqrt],
        ));

        // LogUp trace side: use +1 for each (rms_sq, rsqrt_val) pair
        eval.add_to_relation(RelationEntry::new(
            &self.lookup_elements,
            E::EF::from(E::F::from(BaseField::from(1))),
            &[rms_sq, rsqrt_val],
        ));

        eval.finalize_logup_in_pairs();
        eval
    }
}

The relation is declared as relation!(RMSNormRelation, 2) — a 2-column LogUp relation. The finalize_logup_in_pairs() call (not finalize_logup()) is required when 2 add_to_relation calls exist per evaluation row.

Reciprocal sqrt table: Reuses LayerNorm's rsqrt table. Maps rms_sq → 1/sqrt(rms_sq) with fixed-point scale 2^16. Table size: 2^log_size entries.

RoPE — Rotary Positional Embedding (components/rope.rs)

Position-dependent rotations applied to Q and K vectors before attention scoring.

Formula:

x' = x * cos(theta * m) - y * sin(theta * m)
y' = x * sin(theta * m) + y * cos(theta * m)

where theta_j = base^(-2j/d), m = position index.
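As a sanity check on the formula, the pair rotation is orthogonal, so it preserves the norm of each (x, y) pair — a plain-f64 sketch (illustrative; `rope_rotate` is not the crate's API):

```rust
// RoPE pair rotation from the formula above (f64 sketch).
fn rope_rotate(x: f64, y: f64, theta: f64, m: f64) -> (f64, f64) {
    let (s, c) = (theta * m).sin_cos();
    (x * c - y * s, x * s + y * c)
}

fn main() {
    // theta_j = base^(-2j/d) for j = 1, d = 128, base = 10000
    let theta = 10000f64.powf(-2.0 / 128.0);
    let (x, y) = (0.5, -0.25);
    let (xr, yr) = rope_rotate(x, y, theta, 7.0); // position m = 7
    // Rotations are orthogonal: the pair's squared norm is unchanged.
    assert!((xr * xr + yr * yr - (x * x + y * y)).abs() < 1e-12);
}
```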

Proving approach:

  1. Precompute rotation table: All (position, dim_pair) -> (cos_val, sin_val) in M31
  2. Element-wise rotation: Apply rotation using table values
  3. LogUp proof: Every (cos, sin) pair comes from the precomputed table

The rotation factors are deterministic from (seq_len, head_dim, base), so the verifier reconstructs the table independently.

Fixed-point M31 encoding: Trigonometric values in [-1, 1] are encoded into M31 via:

FP_SCALE = (P - 1) / 2 = 1073741823

encode(v) = round((v + 1.0) × FP_SCALE)    // float → M31
decode(m) = m / FP_SCALE - 1.0              // M31 → float

This maps the full [-1, 1] range into [0, P-1] with a quantization step of 1/FP_SCALE ≈ 2^-30 (roughly 30 fractional bits of precision).
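The encode/decode maps can be written out directly (a sketch; the constant matches (P - 1) / 2 for P = 2^31 - 1):

```rust
// Fixed-point encoding of [-1, 1] trig values into M31 (sketch).
const FP_SCALE: f64 = 1_073_741_823.0; // (P - 1) / 2

fn encode(v: f64) -> u32 { ((v + 1.0) * FP_SCALE).round() as u32 }
fn decode(m: u32) -> f64 { m as f64 / FP_SCALE - 1.0 }

fn main() {
    for &v in &[-1.0, -0.5, 0.0, 0.70710678, 1.0] {
        // Round-trip error is bounded by half a quantization step.
        assert!((decode(encode(v)) - v).abs() < 1.0 / FP_SCALE);
        // encode(1.0) = 2 * FP_SCALE = P - 1, so values stay in [0, P-1].
        assert!(encode(v) <= 2 * (FP_SCALE as u32));
    }
}
```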

Trace layout (7 columns):

Column | Name         | Description
-------|--------------|---------------------------------
0      | input_x      | First element of dimension pair
1      | input_y      | Second element of pair
2      | cos_val      | Rotation cosine from table
3      | sin_val      | Rotation sine from table
4      | output_x     | x·cos - y·sin
5      | output_y     | x·sin + y·cos
6      | multiplicity | LogUp multiplicity

Constraint evaluation — two degree-2 AIR constraints plus LogUp:

// Constraint 1: output_x = input_x × cos - input_y × sin
eval.add_constraint(
    output_x - (input_x.clone() * cos_val.clone() - input_y.clone() * sin_val.clone())
);

// Constraint 2: output_y = input_x × sin + input_y × cos
eval.add_constraint(
    output_y - (input_x * sin_val.clone() + input_y * cos_val.clone())
);

The LogUp relation (RoPERelation, 2) verifies every (cos_val, sin_val) pair exists in the precomputed table.

Table size: max_seq_len × (head_dim / 2) entries. For Qwen3-14B with max_seq_len=8192, head_dim=128: 524,288 table entries.

Configuration:

pub struct RoPEConfig {
    pub seq_len: usize,      // number of positions
    pub head_dim: usize,     // per-head dimension (must be even)
    pub base: f64,           // frequency base (default: 10000)
    pub max_seq_len: usize,  // max positions for table precomputation
}

Grouped Query Attention — GQA/MQA (components/attention.rs)

The attention component supports three modes:

Mode | KV Heads                     | Description
-----|------------------------------|-------------------------------
MHA  | num_kv_heads == num_heads    | Standard multi-head attention
GQA  | 1 < num_kv_heads < num_heads | Groups of Q heads share K/V
MQA  | num_kv_heads == 1            | All Q heads share one K/V head

GQA is used by Llama 3, Qwen 2/3, Mistral, and most modern LLMs because it reduces KV-cache memory by a factor of num_heads / num_kv_heads with minimal quality loss.

How it works: K and V are projected to num_kv_heads heads instead of num_heads. Each Q head h uses KV head h / group_size:

let group_size = num_heads / num_kv_heads;
for h in 0..num_heads {
    let kv_idx = h / group_size;
    // Q_h uses K[kv_idx] and V[kv_idx]
    let scores = matmul(&q_heads[h], &transpose(&kv_heads_k[kv_idx]));
    let context = matmul(&softmax(&scores), &kv_heads_v[kv_idx]);
}
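The head-to-KV-group mapping can be isolated into a standalone sketch (illustrative; `kv_index` is a hypothetical helper, not the crate's API):

```rust
// Map a Q head to its shared KV head. Assumes num_heads is a
// multiple of num_kv_heads, as in GQA configs.
fn kv_index(h: usize, num_heads: usize, num_kv_heads: usize) -> usize {
    h / (num_heads / num_kv_heads)
}

fn main() {
    // 32 Q heads, 8 KV heads → groups of 4 consecutive Q heads
    assert_eq!(kv_index(0, 32, 8), 0);
    assert_eq!(kv_index(3, 32, 8), 0);
    assert_eq!(kv_index(4, 32, 8), 1);
    assert_eq!(kv_index(31, 32, 8), 7);
    // MQA: every Q head maps to the single KV head 0
    assert_eq!(kv_index(31, 32, 1), 0);
}
```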

Proof decomposition — attention is decomposed into 5 sequential stages:

Stage 1: Projections        → 3 matmul sumcheck proofs (Q, K, V)
Stage 2: Per-head scores    → H matmul sumcheck proofs (Q_h × K_h^T)
Stage 3: Softmax            → 1 aggregated LogUp STARK proof (all heads)
Stage 4: Per-head context   → H matmul sumcheck proofs (softmax_h × V_h)
Stage 5: Output projection  → 1 matmul sumcheck proof (concat × W_O)

Total proof count: (4 + 2H) matmul sumcheck proofs plus 1 softmax STARK, where H = num_heads — 3 projections, H score matmuls, H context matmuls, and 1 output projection. For 32-head GQA with 8 KV heads: 69 proofs.
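The stage breakdown above gives a simple closed form (a sketch, not a crate function):

```rust
// Attention proof count from the five-stage decomposition:
// 3 projection matmuls + H score matmuls + H context matmuls
// + 1 output matmul, plus the single aggregated softmax STARK.
fn attention_proof_count(num_heads: usize) -> usize {
    let matmul_proofs = 3 + 2 * num_heads + 1;
    matmul_proofs + 1 // + aggregated softmax STARK
}

fn main() {
    assert_eq!(attention_proof_count(32), 69); // 32-head GQA example
    assert_eq!(attention_proof_count(40), 85);
}
```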

Proof structure:

pub struct AttentionProof<H: MerkleHasherLifted> {
    pub q_proof: MatMulSumcheckProof,
    pub k_proof: MatMulSumcheckProof,
    pub v_proof: MatMulSumcheckProof,
    pub score_proofs: Vec<MatMulSumcheckProof>,      // num_heads entries
    pub attn_v_proofs: Vec<MatMulSumcheckProof>,     // num_heads entries
    pub output_proof: MatMulSumcheckProof,
    pub softmax_exp_proof: StarkProof<H>,            // aggregated all-heads
    pub softmax_claimed_sum: SecureField,
    pub softmax_log_size: u32,
    pub intermediates: AttentionIntermediates,
}

Intermediates stored (needed for verification replay):

pub struct AttentionIntermediates {
    pub q: M31Matrix,                // (seq_len, d_model)
    pub k: M31Matrix,               // (seq_len, num_kv_heads × d_k)
    pub v: M31Matrix,               // (seq_len, num_kv_heads × d_k)
    pub score_matrices: Vec<M31Matrix>,    // H × (seq_len, seq_len)
    pub softmax_outputs: Vec<M31Matrix>,   // H × (seq_len, seq_len)
    pub head_outputs: Vec<M31Matrix>,      // H × (seq_len, d_k)
    pub concat: M31Matrix,                 // (seq_len, d_model)
    pub final_output: M31Matrix,           // (seq_len, d_model)
}

For GQA/MQA, K/V projections have shape (d_model, num_kv_heads × d_k) instead of full (d_model, d_model). The group index kv_idx = h / group_size is used to share K/V heads across Q head groups.

Cost model (witness rows per block):

pub fn sumcheck_trace_rows(&self) -> usize {
    let d_k = self.d_k();
    let s = self.seq_len;
    // Q×K^T decomposition per head
    let qkt_witness = s * d_k + s * s + d_k * s;
    let softmax_lookups = s * s;
    // softmax×V decomposition per head
    let attn_v_witness = s * s + s * d_k + s * d_k;
    (qkt_witness + softmax_lookups + attn_v_witness) * self.num_heads
}
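Written standalone (the real method lives on the config struct; this free-function form is illustrative), the cost model can be evaluated for a concrete config:

```rust
// Witness rows per block, per the decomposition above (sketch).
fn sumcheck_trace_rows(seq_len: usize, d_k: usize, num_heads: usize) -> usize {
    let (s, d) = (seq_len, d_k);
    let qkt_witness = s * d + s * s + d * s;    // Q×K^T per head
    let softmax_lookups = s * s;
    let attn_v_witness = s * s + s * d + s * d; // softmax×V per head
    (qkt_witness + softmax_lookups + attn_v_witness) * num_heads
}

fn main() {
    // Qwen3-14B-like config: seq=2048, d_k=128, 40 heads
    assert_eq!(sumcheck_trace_rows(2048, 128, 40), 545_259_520);
}
```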

Constructors:

// Standard MHA
MultiHeadAttentionConfig::new(32, 4096, 2048)

// GQA: 32 Q heads, 8 KV heads
MultiHeadAttentionConfig::new_gqa(32, 8, 4096, 2048, true)

// MQA: all Q heads share 1 KV head
MultiHeadAttentionConfig::new_mqa(32, 4096, 2048, true)

KV-Cache — Incremental Decoding

For autoregressive generation, the KV-Cache stores previously computed K/V projections, so each new token requires only its own K/V projection plus attention against the cached entries, rather than reprocessing the full sequence.

pub struct KVCache {
    pub k_cache: Vec<M31Matrix>,  // per KV-head: (cached_len, d_k)
    pub v_cache: Vec<M31Matrix>,  // per KV-head: (cached_len, d_k)
    pub cached_len: usize,
    pub num_kv_heads: usize,
    pub d_k: usize,
}

Usage:

let config = MultiHeadAttentionConfig::new_gqa(32, 8, 4096, 1, true);
let mut cache = KVCache::new(&config);

// Step 1: process first token
let out1 = attention_forward_cached(&input_tok1, &weights, &config, &mut cache);
// cache.cached_len == 1

// Step 2: process next token (uses cached K/V from step 1)
let out2 = attention_forward_cached(&input_tok2, &weights, &config, &mut cache);
// cache.cached_len == 2

Causal masking in M31 arithmetic: The mask sentinel value is P - 2 = 2^31 - 3:

For position (i, j) where j > i (future token):
  score[i][j] = P - 2

softmax_exp(P - 2) ≈ 0    // kills the attention weight

The sentinel P-2 is chosen because softmax_exp(P-2) returns approximately zero in M31 arithmetic, effectively zeroing future positions without needing actual -infinity. For KV-cached decoding, the cache offset is tracked so j > i + cache_offset triggers the mask — new queries attend to all cached positions plus the current token.
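The masking rule above can be sketched in plain integer arithmetic (illustrative; the real scores live in M31 matrices):

```rust
// Causal masking with the P - 2 sentinel (sketch).
const P: u32 = (1u32 << 31) - 1;
const MASK_SENTINEL: u32 = P - 2; // softmax_exp maps this to ~0

fn mask_row(scores: &mut [u32], i: usize, cache_offset: usize) {
    for (j, s) in scores.iter_mut().enumerate() {
        // Future positions relative to query i (plus any cached prefix)
        if j > i + cache_offset {
            *s = MASK_SENTINEL;
        }
    }
}

fn main() {
    let mut row = [7, 7, 7, 7];
    mask_row(&mut row, 1, 0); // query at position 1, no cache
    assert_eq!(row, [7, 7, MASK_SENTINEL, MASK_SENTINEL]);
    // With a cached prefix, the new query attends to all cached positions.
    let mut row2 = [7, 7, 7, 7];
    mask_row(&mut row2, 0, 3);
    assert_eq!(row2, [7, 7, 7, 7]);
}
```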

Softmax in M31 arithmetic: The softmax function operates entirely in the M31 field:

pub fn softmax_row_m31(row: &[M31]) -> Vec<M31> {
    // 1. Element-wise: exp_vals = softmax_exp(row[i]) for each i
    // 2. Sum: sum_exp = Σ exp_vals[i]
    // 3. Normalize: result[i] = exp_vals[i] × inv(sum_exp)
}

The SoftmaxNormEval component enforces a degree-2 constraint: weight × sum_exp - exp_val = 0, proving that each softmax output equals exp_val / sum_exp. The claimed sum is bound to the proof transcript via the Fiat-Shamir channel.

Multi-layer cache: ModelKVCache wraps per-layer caches:

let mut model_cache = ModelKVCache::new(num_layers, &attn_config);
for layer in 0..num_layers {
    output = attention_forward_cached(&output, &weights[layer], &config, &mut model_cache.layers[layer]);
}

Dequantization (components/dequantize.rs)

Weight dequantization converts INT4/INT8 quantized weights back to M31 field elements for proving. Uses a 2D LogUp table mapping (quantized_value, dequantized_value).

Trace layout (3 columns):

Column | Name         | Description
-------|--------------|-------------------------
0      | trace_input  | Quantized value
1      | trace_output | Dequantized value (M31)
2      | multiplicity | LogUp multiplicity

Preprocessed table (2 columns): Maps each quantized value q ∈ [0, 2^bits) to its dequantized M31 representation via (q - zero_point) × scale.

Supported quantization strategies:

Strategy    | Bits | Zero Point | Use Case
------------|------|------------|-------------------------
Symmetric4  | 4    | 7          | INT4 symmetric
Symmetric8  | 8    | 127        | INT8 symmetric
Asymmetric4 | 4    | custom     | INT4 with custom offset
Direct      | 0    | —          | Direct M31 encoding

Table sizes: 16 entries for INT4, 256 entries for INT8.
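Building such a table from the (q - zero_point) × scale rule is straightforward (integer sketch; the real table stores M31 field elements):

```rust
// Dequantization lookup table: q -> (q - zero_point) * scale (sketch).
fn dequant_table(bits: u32, zero_point: i64, scale: i64) -> Vec<(u32, i64)> {
    (0..1u32 << bits)
        .map(|q| (q, (q as i64 - zero_point) * scale))
        .collect()
}

fn main() {
    // INT4 symmetric: zero_point = 7, some illustrative scale
    let table = dequant_table(4, 7, 3);
    assert_eq!(table.len(), 16);  // 2^4 entries, as stated above
    assert_eq!(table[7], (7, 0)); // the zero point dequantizes to 0
    assert_eq!(table[0], (0, -21));
    // INT8 would yield 2^8 = 256 entries
    assert_eq!(dequant_table(8, 127, 1).len(), 256);
}
```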

Proof Structure

A single transformer block generates:

Component      | Count      | Protocol
---------------|------------|--------------------------------------
RMSNorm        | 2          | LogUp STARK (rsqrt table)
Attention      | 4 + 2H + 1 | composed sumcheck + LogUp (softmax)
FFN MatMul     | 2          | Sumcheck over MLE
GELU           | 1          | LogUp STARK (activation table)
Add (residual) | 2          | Linear split (no proof needed in GKR)

For a 32-head GQA model with 8 KV heads, one block produces 70 matmul sumcheck proofs (68 in attention plus 2 in the FFN), 3 LogUp STARKs (2 RMSNorm, 1 GELU), and 1 softmax STARK.

Integration with GKR

When using the GKR protocol, the transformer block is compiled into a LayeredCircuit where each graph node becomes one or more layers:

GKR Output claim
    → RMSNorm (eq-sumcheck + LogUp)
    → Attention (composed sub-matmuls)
    → Add (linear split)
    → RMSNorm (eq-sumcheck + LogUp)
    → Linear (matmul sumcheck)
    → GELU (LogUp eq-sumcheck)
    → Linear (matmul sumcheck)
    → Add (linear split)
    → Input claim

The GKR proof for the entire block is a single interactive proof, replacing what would otherwise be ~75 independent STARK proofs.

Circuit Compilation

The GraphBuilder compiles transformer blocks into a LayeredCircuit with typed layers:

pub struct CircuitLayer {
    pub layer_type: LayerType,       // MatMul, Add, Mul, Activation, LayerNorm, ...
    pub input_shape: (usize, usize),
    pub output_shape: (usize, usize),
    pub node_id: usize,              // back-reference to ComputationGraph node
    pub input_layers: Vec<usize>,    // layer indices this layer reads from
}

Each LayerType maps to a specific GKR reduction strategy:

LayerType  | GKR Reduction             | Proof Variant
-----------|---------------------------|------------------------
MatMul     | Sumcheck (degree-2)       | LayerProof::MatMul
Add        | Linear split (trunk/skip) | LayerProof::Add
Mul        | Eq-sumcheck (degree-3)    | LayerProof::Mul
Activation | LogUp eq-sumcheck         | LayerProof::Activation
LayerNorm  | LogUp + linear sumcheck   | LayerProof::LayerNorm
RMSNorm    | LogUp + linear sumcheck   | LayerProof::RMSNorm
Attention  | Composed sub-matmuls      | LayerProof::Attention

Residual Connections as DAG Add

The Add layers (residual connections) split the GKR claim into two sub-claims — a trunk (main computation path) and a skip (identity residual). The trunk continues the GKR walk. The skip branch produces a deferred proof that is verified separately:

pub struct DeferredProof {
    pub claim: GKRClaim,              // evaluation point + value
    pub dims: (usize, usize, usize), // (m, k, n) for matmul
    pub layer_proof: LayerProof,      // standalone proof for skip branch
    pub input_claim: GKRClaim,        // resulting input claim
    pub weight_commitment: FieldElement,
    pub weight_opening: MleOpeningProof,
}

For the transformer's pre-norm residual pattern, the skip connection is always an identity (no weights), so the deferred proof is trivial. For more complex DAG topologies (e.g., U-Net), deferred proofs carry full matmul sumcheck data.

Component Cost Comparison

Per-block proving cost for a Qwen3-14B-like config (d_model=5120, seq=2048, heads=40, kv_heads=8, ffn=13824):

Component            | Matmul Proofs | LogUp STARKs | Sumcheck Rounds
---------------------|---------------|--------------|-------------------
RMSNorm ×2           | 0             | 2            | ~26 (eq-sumcheck)
Attention (GQA-40/8) | 44            | 1 (softmax)  | ~560
FFN Linear ×2        | 2             | 0            | ~26
GELU                 | 0             | 1            | ~22 (eq-sumcheck)
Add ×2               | 0             | 0            | 0 (linear split)
Total per block      | 46            | 4            | ~634

With GKR, the entire block becomes a single interactive proof with ~634 sumcheck rounds instead of 50+ independent proofs.

Example: Full Pipeline

use stwo_ml::compiler::graph::GraphBuilder;
use stwo_ml::aggregation::{prove_model_aggregated, verify_aggregated_model_proof};

// Build a 2-layer transformer
let mut builder = GraphBuilder::new((2, 64)); // seq_len=2, d_model=64
builder
    .transformer_block(4, 2, 2, 128)  // 4 heads, 2 KV heads, seq=2, ffn=128
    .transformer_block(4, 2, 2, 128);

let graph = builder.build();

// Create weights for all MatMul and Attention nodes
let weights = create_weights_for_graph(&graph);

// Prove
let proof = prove_model_aggregated(&graph, &input, &weights)
    .expect("proving failed");

// Verify
verify_aggregated_model_proof(proof, &graph, &input, &weights)
    .expect("verification failed");

Proving Functions

Two proving backends exist for each component:

// Off-chain (Blake2s channel, for testing/development)
pub fn prove_attention_with<B, MC>(
    input: &M31Matrix,
    weights: &AttentionWeights,
    config: &MultiHeadAttentionConfig,
    causal: bool,
) -> Result<AttentionProof<<MC as MerkleChannel>::H>, AttentionError>

// On-chain (Poseidon channel for matmuls, Blake2s for softmax STARK)
pub fn prove_attention_onchain(
    input: &M31Matrix,
    weights: &AttentionWeights,
    config: &MultiHeadAttentionConfig,
    causal: bool,
) -> Result<AttentionProofOnChain, AttentionError>

The on-chain path uses PoseidonChannel for all matmul sumchecks (Cairo-native) and Blake2sChannel for the softmax STARK (efficient off-chain commitment). The AttentionProofOnChain struct mirrors AttentionProof with Poseidon-compatible types.

M31 Arithmetic Subtleties

All transformer components operate in M31 = Z/(2^31 - 1):

Component     | Encoding           | Range
--------------|--------------------|--------------------------------------
Activations   | Direct M31         | [0, P-1]
RoPE cos/sin  | Signed fixed-point | [-1, 1] → [0, P-1] via (v+1) × scale
RMSNorm rsqrt | Fixed-point 2^16   | Precomputed table
Causal mask   | Sentinel value     | P - 2 = 2^31 - 3
Softmax exp   | Lookup table       | Precomputed exp(x) in M31

Overflow safety: All M31 multiplication is modular — a × b mod P never overflows a u64 because (P-1)^2 < 2^62. This eliminates carry-handling in both prover and verifier.
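The overflow bound can be demonstrated in a few lines (sketch; the crate's field type encapsulates this):

```rust
// Why one u64 suffices for M31 multiplication: both operands are
// < 2^31, so the product is < 2^62 and a single mod reduces it.
const P: u64 = (1u64 << 31) - 1;

fn m31_mul(a: u64, b: u64) -> u64 {
    debug_assert!(a < P && b < P);
    a * b % P // (P-1)^2 = 2^62 - 2^33 + 4 < 2^62: no u64 overflow
}

fn main() {
    // (P-1) ≡ -1 mod P, so (P-1)^2 ≡ 1
    assert_eq!(m31_mul(P - 1, P - 1), 1);
    // The worst-case product stays strictly below 2^62
    assert!((P - 1) * (P - 1) < 1u64 << 62);
}
```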

References

  • Goldwasser, Kalai, Rothblum. "Delegating Computation: Interactive Proofs for Muggles." STOC 2008.
  • Su, Zhang, Deng, Thaler. "Efficient and Provably Secure Machine Learning via Interactive Proofs." 2023.
  • Vaswani et al. "Attention Is All You Need." NeurIPS 2017.
  • Shazeer. "Fast Transformer Decoding: One Write-Head is All You Need." 2019. (MQA)
  • Ainslie et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023.
  • Su et al. "RoFormer: Enhanced Transformer with Rotary Position Embedding." 2021.