Transformer Architecture — Full Llama-Style Proving Pipeline
Overview
stwo-ml supports proving full transformer decoder blocks as used in Llama, Qwen, Mistral, and similar architectures. Each block follows the pre-norm residual pattern:
          ┌───────────────┐
Input ────┤   Identity    ├──────────────────────┐
          └───────┬───────┘                      │
                  ▼                              │ (residual)
          ┌───────────────┐                      │
          │    RMSNorm    │                      │
          └───────┬───────┘                      │
                  ▼                              │
          ┌───────────────┐                      │
          │   Attention   │  (GQA/MQA/MHA)       │
          │ Q×K^T→soft→V  │                      │
          └───────┬───────┘                      │
                  ▼                              │
          ┌───────────────┐                      │
          │      Add      │◄─────────────────────┘
          └───────┬───────┘
                  ├──────────────────────────────┐
                  ▼                              │ (residual)
          ┌───────────────┐                      │
          │    RMSNorm    │                      │
          └───────┬───────┘                      │
                  ▼                              │
          ┌───────────────┐                      │
          │  FFN: Linear  │  (d_model → ffn_dim) │
          └───────┬───────┘                      │
                  ▼                              │
          ┌───────────────┐                      │
          │     GELU      │                      │
          └───────┬───────┘                      │
                  ▼                              │
          ┌───────────────┐                      │
          │  FFN: Linear  │  (ffn_dim → d_model) │
          └───────┬───────┘                      │
                  ▼                              │
          ┌───────────────┐                      │
          │      Add      │◄─────────────────────┘
          └───────┬───────┘
                  ▼
               Output
Builder API
The GraphBuilder::transformer_block() method constructs this entire pattern in one call:
use stwo_ml::compiler::graph::GraphBuilder;

let mut builder = GraphBuilder::new((seq_len, d_model));

// Stack N transformer blocks
for _ in 0..num_layers {
    builder.transformer_block(
        num_heads,    // Q heads (e.g., 32)
        num_kv_heads, // KV heads (e.g., 8 for GQA)
        seq_len,      // sequence length
        ffn_dim,      // feed-forward intermediate dim (e.g., 4 × d_model)
    );
}
let graph = builder.build();
Each block produces 9 graph nodes: Identity, RMSNorm, Attention, Add, RMSNorm, Linear, GELU, Linear, Add.
Components
RMSNorm (components/rmsnorm.rs)
Root Mean Square Layer Normalization — used in Llama/Qwen instead of LayerNorm.
Formula: y = x / sqrt(mean(x^2) + epsilon) * gamma
Key difference from LayerNorm: No mean subtraction. Cheaper to compute and prove.
Proving approach: Decomposed into three provable operations:
- Compute rms_sq = sum(x^2) / n via M31 arithmetic
- Reciprocal sqrt: rsqrt(rms_sq) via LogUp lookup table (precomputed table of 2^16 entries)
- Scale: output = input * rsqrt_val
Trace layout (5 columns):
| Column | Name | Description |
|---|---|---|
| 0 | input | Original value x |
| 1 | rms_sq | mean(x^2), shared per row |
| 2 | rsqrt_val | 1/sqrt(rms_sq), from lookup |
| 3 | output | x * rsqrt_val |
| 4 | multiplicity | LogUp multiplicity |
Preprocessed columns (2 columns, table side):
| Column | Name | Content |
|---|---|---|
| 0 | rms_sq table input | Lookup input key |
| 1 | rsqrt table output | Lookup output value |
Constraint evaluation — the RMSNormEval implements FrameworkEval:
impl FrameworkEval for RMSNormEval {
    fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
        let table_rms = eval.get_preprocessed_column(/* rmsnorm_rms_input */);
        let table_rsqrt = eval.get_preprocessed_column(/* rmsnorm_rsqrt_output */);

        let input = eval.next_trace_mask();
        let rms_sq = eval.next_trace_mask();
        let rsqrt_val = eval.next_trace_mask();
        let output = eval.next_trace_mask();
        let multiplicity = eval.next_trace_mask();

        // Degree-2 constraint: output = input × rsqrt_val
        eval.add_constraint(output - input * rsqrt_val.clone());

        // LogUp table side: yield -multiplicity for each table entry
        eval.add_to_relation(RelationEntry::new(
            &self.lookup_elements,
            -E::EF::from(multiplicity),
            &[table_rms, table_rsqrt],
        ));

        // LogUp trace side: use +1 for each (rms_sq, rsqrt_val) pair
        eval.add_to_relation(RelationEntry::new(
            &self.lookup_elements,
            E::EF::from(E::F::from(BaseField::from(1))),
            &[rms_sq, rsqrt_val],
        ));

        eval.finalize_logup_in_pairs();
        eval
    }
}
The relation is declared as relation!(RMSNormRelation, 2) — a 2-column LogUp relation. The finalize_logup_in_pairs() call (not finalize_logup()) is required when 2 add_to_relation calls exist per evaluation row.
Reciprocal sqrt table: Reuses LayerNorm's rsqrt table. Maps rms_sq → 1/sqrt(rms_sq) with fixed-point scale 2^16. Table size: 2^log_size entries.
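The fixed-point convention can be sketched numerically. This is an illustrative helper assuming a round-to-nearest rule at scale 2^16, not the library's table builder:

```rust
/// Illustrative fixed-point rsqrt table entry (assumed convention):
/// rsqrt_val = round(2^16 / sqrt(rms_sq)).
fn rsqrt_fixed(rms_sq: u64) -> u64 {
    let scale = (1u64 << 16) as f64;
    (scale / (rms_sq as f64).sqrt()).round() as u64
}

fn main() {
    assert_eq!(rsqrt_fixed(1), 1 << 16);   // 1/sqrt(1) = 1.0 → 2^16
    assert_eq!(rsqrt_fixed(4), 1 << 15);   // 1/sqrt(4) = 0.5 → 2^15
    assert_eq!(rsqrt_fixed(1 << 16), 256); // 1/sqrt(2^16) = 2^-8 → 2^8
}
```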
RoPE — Rotary Positional Embedding (components/rope.rs)
Position-dependent rotations applied to Q and K vectors before attention scoring.
Formula:
x' = x * cos(theta * m) - y * sin(theta * m)
y' = x * sin(theta * m) + y * cos(theta * m)
where theta_j = base^(-2j/d), m = position index.
Proving approach:
- Precompute rotation table: All (position, dim_pair) -> (cos_val, sin_val) in M31
- Element-wise rotation: Apply rotation using table values
- LogUp proof: Every (cos, sin) pair comes from the precomputed table
The rotation factors are deterministic from (seq_len, head_dim, base), so the verifier reconstructs the table independently.
Fixed-point M31 encoding: Trigonometric values in [-1, 1] are encoded into M31 via:
FP_SCALE = (P - 1) / 2 = 1073741823
encode(v) = round((v + 1.0) × FP_SCALE) // float → M31
decode(m) = m / FP_SCALE - 1.0 // M31 → float
This maps the full [-1, 1] range onto [0, P-1] with ~30 bits of fractional precision (one step of 1/FP_SCALE ≈ 2^-30).
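A self-contained sketch of this encoding, using the formulas above (the constants and helper names are illustrative, not the library's API):

```rust
const P: u64 = (1 << 31) - 1;               // M31 modulus 2^31 - 1
const FP_SCALE: f64 = ((P - 1) / 2) as f64; // 1_073_741_823

/// Encode a value in [-1, 1] into an M31 element.
fn encode(v: f64) -> u64 {
    ((v + 1.0) * FP_SCALE).round() as u64
}

/// Decode an M31 element back to [-1, 1].
fn decode(m: u64) -> f64 {
    m as f64 / FP_SCALE - 1.0
}

fn main() {
    assert_eq!(encode(-1.0), 0);           // -1 maps to 0
    assert_eq!(encode(0.0), (P - 1) / 2);  // 0 maps to the midpoint
    assert_eq!(encode(1.0), P - 1);        // +1 maps to P-1
    let x = 0.7071;
    assert!((decode(encode(x)) - x).abs() < 1e-8); // round-trip within 2^-30
}
```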
Trace layout (7 columns):
| Column | Name | Description |
|---|---|---|
| 0 | input_x | First element of dimension pair |
| 1 | input_y | Second element of pair |
| 2 | cos_val | Rotation cosine from table |
| 3 | sin_val | Rotation sine from table |
| 4 | output_x | x·cos - y·sin |
| 5 | output_y | x·sin + y·cos |
| 6 | multiplicity | LogUp multiplicity |
Constraint evaluation — two degree-2 AIR constraints plus LogUp:
// Constraint 1: output_x = input_x × cos - input_y × sin
eval.add_constraint(
    output_x - (input_x.clone() * cos_val.clone() - input_y.clone() * sin_val.clone())
);

// Constraint 2: output_y = input_x × sin + input_y × cos
eval.add_constraint(
    output_y - (input_x * sin_val.clone() + input_y * cos_val.clone())
);
The LogUp relation (RoPERelation, 2) verifies every (cos_val, sin_val) pair exists in the precomputed table.
Table size: max_seq_len × (head_dim / 2) entries. For Qwen3-14B with max_seq_len=8192, head_dim=128: 524,288 table entries.
Configuration:
pub struct RoPEConfig {
    pub seq_len: usize,     // number of positions
    pub head_dim: usize,    // per-head dimension (must be even)
    pub base: f64,          // frequency base (default: 10000)
    pub max_seq_len: usize, // max positions for table precomputation
}
Grouped Query Attention — GQA/MQA (components/attention.rs)
The attention component supports three modes:
| Mode | KV Heads | Description |
|---|---|---|
| MHA | num_kv_heads == num_heads | Standard multi-head attention |
| GQA | 1 < num_kv_heads < num_heads | Groups of Q heads share K/V |
| MQA | num_kv_heads == 1 | All Q heads share one K/V head |
GQA is used by Llama 3, Qwen 2/3, Mistral, and most modern LLMs because it reduces KV-cache memory by a factor of num_heads / num_kv_heads with minimal quality loss.
How it works: K and V are projected to num_kv_heads heads instead of num_heads. Each Q head h uses KV head h / group_size:
let group_size = num_heads / num_kv_heads;
for h in 0..num_heads {
    let kv_idx = h / group_size;
    // Q_h uses K[kv_idx] and V[kv_idx]
    let scores = matmul(&q_heads[h], &transpose(&kv_heads_k[kv_idx]));
    let context = matmul(&softmax(&scores), &kv_heads_v[kv_idx]);
}
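The head-to-KV mapping can be isolated into a tiny helper for clarity (`kv_head_for` is an illustrative name, not part of the stwo-ml API):

```rust
/// Map each query head to the KV head it shares under GQA/MQA/MHA.
fn kv_head_for(q_head: usize, num_heads: usize, num_kv_heads: usize) -> usize {
    assert!(num_heads % num_kv_heads == 0, "Q heads must divide evenly into groups");
    q_head / (num_heads / num_kv_heads)
}

fn main() {
    // GQA 32/8: heads 0..=3 share KV head 0, heads 4..=7 share KV head 1, ...
    assert_eq!(kv_head_for(0, 32, 8), 0);
    assert_eq!(kv_head_for(5, 32, 8), 1);
    assert_eq!(kv_head_for(31, 32, 8), 7);
    // MQA: every Q head maps to the single KV head
    assert_eq!(kv_head_for(17, 32, 1), 0);
    // MHA: identity mapping
    assert_eq!(kv_head_for(9, 32, 32), 9);
}
```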
Proof decomposition — attention is decomposed into 5 sequential stages:
Stage 1: Projections → 3 matmul sumcheck proofs (Q, K, V)
Stage 2: Per-head scores → H matmul sumcheck proofs (Q_h × K_h^T)
Stage 3: Softmax → 1 aggregated LogUp STARK proof (all heads)
Stage 4: Per-head context → H matmul sumcheck proofs (softmax_h × V_h)
Stage 5: Output projection → 1 matmul sumcheck proof (concat × W_O)
Total proof count: 4 + 2H + 1 where H = num_heads. For 32-head GQA with 8 KV heads: 69 proofs.
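The count follows directly from the five stages; a quick sanity check (illustrative helper, not library code):

```rust
/// Proof count per the five-stage decomposition:
/// 3 projection matmuls + H score matmuls + 1 softmax STARK
/// + H context matmuls + 1 output matmul = 2H + 5.
fn attention_proof_count(num_heads: usize) -> usize {
    3 + num_heads + 1 + num_heads + 1
}

fn main() {
    assert_eq!(attention_proof_count(32), 69); // the 32-head figure from the text
    assert_eq!(attention_proof_count(1), 7);   // degenerate single-head case
}
```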
Proof structure:
pub struct AttentionProof<H: MerkleHasherLifted> {
    pub q_proof: MatMulSumcheckProof,
    pub k_proof: MatMulSumcheckProof,
    pub v_proof: MatMulSumcheckProof,
    pub score_proofs: Vec<MatMulSumcheckProof>,  // num_heads entries
    pub attn_v_proofs: Vec<MatMulSumcheckProof>, // num_heads entries
    pub output_proof: MatMulSumcheckProof,
    pub softmax_exp_proof: StarkProof<H>,        // aggregated all-heads
    pub softmax_claimed_sum: SecureField,
    pub softmax_log_size: u32,
    pub intermediates: AttentionIntermediates,
}
Intermediates stored (needed for verification replay):
pub struct AttentionIntermediates {
    pub q: M31Matrix,                      // (seq_len, d_model)
    pub k: M31Matrix,                      // (seq_len, num_kv_heads × d_k)
    pub v: M31Matrix,                      // (seq_len, num_kv_heads × d_k)
    pub score_matrices: Vec<M31Matrix>,    // H × (seq_len, seq_len)
    pub softmax_outputs: Vec<M31Matrix>,   // H × (seq_len, seq_len)
    pub head_outputs: Vec<M31Matrix>,      // H × (seq_len, d_k)
    pub concat: M31Matrix,                 // (seq_len, d_model)
    pub final_output: M31Matrix,           // (seq_len, d_model)
}
For GQA/MQA, K/V projections have shape (d_model, num_kv_heads × d_k) instead of full (d_model, d_model). The group index kv_idx = h / group_size is used to share K/V heads across Q head groups.
Cost model (witness rows per block):
pub fn sumcheck_trace_rows(&self) -> usize {
    let d_k = self.d_k();
    let s = self.seq_len;
    // Q×K^T decomposition per head
    let qkt_witness = s * d_k + s * s + d_k * s;
    let softmax_lookups = s * s;
    // softmax×V decomposition per head
    let attn_v_witness = s * s + s * d_k + s * d_k;
    (qkt_witness + softmax_lookups + attn_v_witness) * self.num_heads
}
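A standalone version of the same cost model, with the config fields passed explicitly so it can be exercised directly (illustrative sketch, not the library method):

```rust
/// Witness rows per attention block, mirroring sumcheck_trace_rows.
fn trace_rows(seq_len: usize, d_k: usize, num_heads: usize) -> usize {
    let s = seq_len;
    let qkt = s * d_k + s * s + d_k * s;     // Q×K^T decomposition per head
    let softmax = s * s;                     // softmax lookups per head
    let attn_v = s * s + s * d_k + s * d_k;  // softmax×V decomposition per head
    (qkt + softmax + attn_v) * num_heads
}

fn main() {
    // e.g. seq_len=2048, d_k=128, 32 heads: ~436M witness rows
    assert_eq!(trace_rows(2048, 128, 32), 436_207_616);
    assert_eq!(trace_rows(1, 1, 1), 7); // 3 + 1 + 3 in the degenerate case
}
```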
Constructors:
// Standard MHA
MultiHeadAttentionConfig::new(32, 4096, 2048)
// GQA: 32 Q heads, 8 KV heads
MultiHeadAttentionConfig::new_gqa(32, 8, 4096, 2048, true)
// MQA: all Q heads share 1 KV head
MultiHeadAttentionConfig::new_mqa(32, 4096, 2048, true)
KV-Cache — Incremental Decoding
For autoregressive generation, the KV-Cache stores previously computed K/V projections so each new token only requires O(1) new computation instead of reprocessing the full sequence.
pub struct KVCache {
    pub k_cache: Vec<M31Matrix>, // per KV-head: (cached_len, d_k)
    pub v_cache: Vec<M31Matrix>, // per KV-head: (cached_len, d_k)
    pub cached_len: usize,
    pub num_kv_heads: usize,
    pub d_k: usize,
}
Usage:
let config = MultiHeadAttentionConfig::new_gqa(32, 8, 4096, 1, true);
let mut cache = KVCache::new(&config);
// Step 1: process first token
let out1 = attention_forward_cached(&input_tok1, &weights, &config, &mut cache);
// cache.cached_len == 1
// Step 2: process next token (uses cached K/V from step 1)
let out2 = attention_forward_cached(&input_tok2, &weights, &config, &mut cache);
// cache.cached_len == 2
Causal masking in M31 arithmetic: The mask sentinel value is P - 2 = 2^31 - 3:
For position (i, j) where j > i (future token):
score[i][j] = P - 2
softmax_exp(P - 2) ≈ 0 // kills the attention weight
The sentinel P-2 is chosen because softmax_exp(P-2) returns approximately zero in M31 arithmetic, effectively zeroing future positions without needing actual -infinity. For KV-cached decoding, the cache offset is tracked so j > i + cache_offset triggers the mask — new queries attend to all cached positions plus the current token.
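A minimal sketch of the masking rule, including the cache offset (plain nested `Vec`s stand in for `M31Matrix` here; the function name is illustrative):

```rust
const P: u32 = (1 << 31) - 1;        // M31 modulus
const MASK_SENTINEL: u32 = P - 2;    // softmax_exp maps this to ~0

/// Overwrite future-position scores with the mask sentinel.
/// `cache_offset` is the number of already-cached positions, so new
/// queries still attend to the full cached prefix.
fn apply_causal_mask(scores: &mut [Vec<u32>], cache_offset: usize) {
    for (i, row) in scores.iter_mut().enumerate() {
        for (j, s) in row.iter_mut().enumerate() {
            if j > i + cache_offset {
                *s = MASK_SENTINEL; // future token: kill the attention weight
            }
        }
    }
}

fn main() {
    let mut scores = vec![vec![1u32; 3]; 3];
    apply_causal_mask(&mut scores, 0);
    assert_eq!(scores[0][1], MASK_SENTINEL); // j > i is masked
    assert_eq!(scores[2][1], 1);             // past positions untouched

    // one new query with a 1-token cache: position 1 stays visible
    let mut step = vec![vec![1u32; 3]; 1];
    apply_causal_mask(&mut step, 1);
    assert_eq!(step[0][1], 1);
    assert_eq!(step[0][2], MASK_SENTINEL);
}
```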
Softmax in M31 arithmetic: The softmax function operates entirely in the M31 field:
pub fn softmax_row_m31(row: &[M31]) -> Vec<M31> {
    // 1. Element-wise: exp_vals = softmax_exp(row[i]) for each i
    // 2. Sum: sum_exp = Σ exp_vals[i]
    // 3. Normalize: result[i] = exp_vals[i] × inv(sum_exp)
}
The SoftmaxNormEval component enforces a degree-2 constraint: weight × sum_exp - exp_val = 0, proving that each softmax output equals exp_val / sum_exp. The claimed sum is bound to the proof transcript via the Fiat-Shamir channel.
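The normalization constraint can be checked numerically with a field inverse computed via Fermat's little theorem; toy integer values stand in for the exp lookup outputs, and `softmax_exp` itself is elided:

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus

/// Modular exponentiation by squaring.
fn pow_mod(mut b: u64, mut e: u64) -> u64 {
    let mut acc = 1u64;
    b %= P;
    while e > 0 {
        if e & 1 == 1 { acc = acc * b % P; }
        b = b * b % P;
        e >>= 1;
    }
    acc
}

/// Multiplicative inverse in M31 via Fermat: x^(P-2) = x^-1 mod P.
fn inv(x: u64) -> u64 { pow_mod(x, P - 2) }

fn main() {
    // toy "exp" values for one softmax row
    let exp_vals = [5u64, 7, 11];
    let sum_exp: u64 = exp_vals.iter().sum::<u64>() % P;
    let weights: Vec<u64> = exp_vals.iter().map(|&e| e * inv(sum_exp) % P).collect();
    // the degree-2 constraint weight × sum_exp - exp_val = 0 holds per element
    for (w, e) in weights.iter().zip(exp_vals) {
        assert_eq!(w * sum_exp % P, e);
    }
}
```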
Multi-layer cache: ModelKVCache wraps per-layer caches:
let mut model_cache = ModelKVCache::new(num_layers, &attn_config);
for layer in 0..num_layers {
output = attention_forward_cached(&output, &weights[layer], &config, &mut model_cache.layers[layer]);
}
Dequantization (components/dequantize.rs)
Weight dequantization converts INT4/INT8 quantized weights back to M31 field elements for proving. Uses a 2D LogUp table mapping (quantized_value, dequantized_value).
Trace layout (3 columns):
| Column | Name | Description |
|---|---|---|
| 0 | trace_input | Quantized value |
| 1 | trace_output | Dequantized value (M31) |
| 2 | multiplicity | LogUp multiplicity |
Preprocessed table (2 columns): Maps each quantized value q ∈ [0, 2^bits) to its dequantized M31 representation via (q - zero_point) × scale.
Supported quantization strategies:
| Strategy | Bits | Zero Point | Use Case |
|---|---|---|---|
| Symmetric4 | 4 | 7 | INT4 symmetric |
| Symmetric8 | 8 | 127 | INT8 symmetric |
| Asymmetric4 | 4 | custom | INT4 with custom offset |
| Direct | — | 0 | Direct M31 encoding |
Table sizes: 16 entries for INT4, 256 entries for INT8.
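The `(q - zero_point) × scale` rule in isolation (signed integers here; the embedding of negative results into M31 is elided, and the helper is illustrative):

```rust
/// Dequantize an unsigned quantized value to a signed fixed-point integer
/// per the (q - zero_point) × scale rule.
fn dequantize(q: i64, zero_point: i64, scale: i64) -> i64 {
    (q - zero_point) * scale
}

fn main() {
    // INT4 symmetric: zero_point = 7, quantized values span 0..=15
    assert_eq!(dequantize(7, 7, 100), 0);     // zero point maps to 0
    assert_eq!(dequantize(15, 7, 100), 800);  // max positive
    assert_eq!(dequantize(0, 7, 100), -700);  // max negative
    // INT8 symmetric: zero_point = 127
    assert_eq!(dequantize(255, 127, 3), 384);
    // table sizes: 2^4 = 16 entries for INT4, 2^8 = 256 for INT8
    assert_eq!(1 << 4, 16);
    assert_eq!(1 << 8, 256);
}
```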
Proof Structure
A single transformer block generates:
| Component | Count | Protocol |
|---|---|---|
| RMSNorm | 2 | LogUp STARK (rsqrt table) |
| Attention | 1 | 4+2H composed sumcheck + LogUp (softmax) |
| FFN MatMul | 2 | Sumcheck over MLE |
| GELU | 1 | LogUp STARK (activation table) |
| Add (residual) | 2 | Linear split (no proof needed in GKR) |
For a 32-head GQA model with 8 KV heads, one block produces 70 matmul sumcheck proofs (68 from attention, 2 from the FFN), 3 LogUp STARKs (2 RMSNorm + 1 GELU), and 1 softmax STARK.
Integration with GKR
When using the GKR protocol, the transformer block is compiled into a LayeredCircuit where each graph node becomes one or more layers:
GKR Output claim
→ RMSNorm (eq-sumcheck + LogUp)
→ Attention (composed sub-matmuls)
→ Add (linear split)
→ RMSNorm (eq-sumcheck + LogUp)
→ Linear (matmul sumcheck)
→ GELU (LogUp eq-sumcheck)
→ Linear (matmul sumcheck)
→ Add (linear split)
→ Input claim
The GKR proof for the entire block is a single interactive proof, replacing what would otherwise be ~75 independent STARK proofs.
Circuit Compilation
The GraphBuilder compiles transformer blocks into a LayeredCircuit with typed layers:
pub struct CircuitLayer {
    pub layer_type: LayerType,     // MatMul, Add, Mul, Activation, LayerNorm, ...
    pub input_shape: (usize, usize),
    pub output_shape: (usize, usize),
    pub node_id: usize,            // back-reference to ComputationGraph node
    pub input_layers: Vec<usize>,  // layer indices this layer reads from
}
Each LayerType maps to a specific GKR reduction strategy:
| LayerType | GKR Reduction | Proof Variant |
|---|---|---|
| MatMul | Sumcheck (degree-2) | LayerProof::MatMul |
| Add | Linear split (trunk/skip) | LayerProof::Add |
| Mul | Eq-sumcheck (degree-3) | LayerProof::Mul |
| Activation | LogUp eq-sumcheck | LayerProof::Activation |
| LayerNorm | LogUp + linear sumcheck | LayerProof::LayerNorm |
| RMSNorm | LogUp + linear sumcheck | LayerProof::RMSNorm |
| Attention | Composed sub-matmuls | LayerProof::Attention |
Residual Connections as DAG Add
The Add layers (residual connections) split the GKR claim into two sub-claims — a trunk (main computation path) and a skip (identity residual). The trunk continues the GKR walk. The skip branch produces a deferred proof that is verified separately:
pub struct DeferredProof {
    pub claim: GKRClaim,               // evaluation point + value
    pub dims: (usize, usize, usize),   // (m, k, n) for matmul
    pub layer_proof: LayerProof,       // standalone proof for skip branch
    pub input_claim: GKRClaim,         // resulting input claim
    pub weight_commitment: FieldElement,
    pub weight_opening: MleOpeningProof,
}
For the transformer's pre-norm residual pattern, the skip connection is always an identity (no weights), so the deferred proof is trivial. For more complex DAG topologies (e.g., U-Net), deferred proofs carry full matmul sumcheck data.
Component Cost Comparison
Per-block proving cost for a Qwen3-14B-like config (d_model=5120, seq=2048, heads=40, kv_heads=8, ffn=13824):
| Component | Matmul Proofs | LogUp STARKs | Sumcheck Rounds |
|---|---|---|---|
| RMSNorm ×2 | 0 | 2 | ~26 (eq-sumcheck) |
| Attention (GQA-40/8) | 84 | 1 (softmax) | ~560 |
| FFN Linear ×2 | 2 | 0 | ~26 |
| GELU | 0 | 1 | ~22 (eq-sumcheck) |
| Add ×2 | 0 | 0 | 0 (linear split) |
| Total per block | 86 | 4 | ~634 |
With GKR, the entire block becomes a single interactive proof with ~634 sumcheck rounds instead of 50+ independent proofs.
Example: Full Pipeline
use stwo_ml::compiler::graph::GraphBuilder;
use stwo_ml::aggregation::{prove_model_aggregated, verify_aggregated_model_proof};
// Build a 2-layer transformer
let mut builder = GraphBuilder::new((2, 64)); // seq_len=2, d_model=64
builder
    .transformer_block(4, 2, 2, 128) // 4 heads, 2 KV heads, seq=2, ffn=128
    .transformer_block(4, 2, 2, 128);
let graph = builder.build();
// Create weights for all MatMul and Attention nodes
let weights = create_weights_for_graph(&graph);
// Prove
let proof = prove_model_aggregated(&graph, &input, &weights)
    .expect("proving failed");
// Verify
verify_aggregated_model_proof(proof, &graph, &input, &weights)
    .expect("verification failed");
Proving Functions
Two proving backends exist for each component:
// Off-chain (Blake2s channel, for testing/development)
pub fn prove_attention_with<B, MC>(
    input: &M31Matrix,
    weights: &AttentionWeights,
    config: &MultiHeadAttentionConfig,
    causal: bool,
) -> Result<AttentionProof<<MC as MerkleChannel>::H>, AttentionError>

// On-chain (Poseidon channel for matmuls, Blake2s for softmax STARK)
pub fn prove_attention_onchain(
    input: &M31Matrix,
    weights: &AttentionWeights,
    config: &MultiHeadAttentionConfig,
    causal: bool,
) -> Result<AttentionProofOnChain, AttentionError>
The on-chain path uses PoseidonChannel for all matmul sumchecks (Cairo-native) and Blake2sChannel for the softmax STARK (efficient off-chain commitment). The AttentionProofOnChain struct mirrors AttentionProof with Poseidon-compatible types.
M31 Arithmetic Subtleties
All transformer components operate in M31 = Z/(2^31 - 1):
| Component | Encoding | Range |
|---|---|---|
| Activations | Direct M31 | [0, P-1] |
| RoPE cos/sin | Signed fixed-point | [-1, 1] → [0, P-1] via (v+1) × scale |
| RMSNorm rsqrt | Fixed-point 2^16 | Precomputed table |
| Causal mask | Sentinel value | P - 2 = 2^31 - 3 |
| Softmax exp | Lookup table | Precomputed exp(x) in M31 |
Overflow safety: All M31 multiplication is modular — a × b mod P never overflows a u64 because (P-1)^2 < 2^62. This eliminates carry-handling in both prover and verifier.
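A quick demonstration of the bound, using `u64` arithmetic directly (the function name is illustrative):

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus 2^31 - 1

/// M31 multiplication in u64: the product of two reduced elements
/// is below 2^62, so it always fits a u64.
fn mul_m31(a: u64, b: u64) -> u64 {
    (a * b) % P
}

fn main() {
    let max = P - 1;
    assert!(max * max < 1u64 << 62);  // (P-1)^2 < 2^62: no overflow possible
    assert_eq!(mul_m31(max, max), 1); // (-1)·(-1) = 1 mod P
    assert_eq!(mul_m31(2, P - 1), P - 2);
}
```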
References
- Goldwasser, Kalai, Rothblum. "Delegating Computation: Interactive Proofs for Muggles." STOC 2008.
- Su, Zhang, Deng, Thaler. "Efficient and Provably Secure Machine Learning via Interactive Proofs." 2023.
- Vaswani et al. "Attention Is All You Need." NeurIPS 2017.
- Shazeer. "Fast Transformer Decoding: One Write-Head is All You Need." 2019. (MQA)
- Ainslie et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." 2023.
- Su et al. "RoFormer: Enhanced Transformer with Rotary Position Embedding." 2021.