Custom AIR over STWO
Overview
stwo-ml implements 11 custom AIR (Algebraic Intermediate Representation) components using STWO's FrameworkComponent pattern. Each component defines polynomial constraints and optional LogUp lookup arguments that are proven inside a single unified STARK proof covering all non-matmul operations.
This design means a transformer model with 40 layers of activations, layer norms, and element-wise operations produces one STARK proof covering all of those operations, not 40+ independent proofs. The matmul sumcheck proofs remain separate.
FrameworkComponent Pattern
Every AIR component follows the STWO constraint framework:
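```rust
trait FrameworkEval {
    fn log_size(&self) -> u32;
    fn max_constraint_log_degree_bound(&self) -> u32;
    fn evaluate<E: EvalAtRow>(&self, eval: E) -> E;
}

// STWO wraps each eval in a generic struct (not a type alias) that implements
// Component and ComponentProver:
// pub struct FrameworkComponent<C: FrameworkEval> { /* eval, trace locations, ... */ }
```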
The evaluate method defines the component's polynomial constraints:
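```rust
impl FrameworkEval for MyComponentEval {
    // log_size() and max_constraint_log_degree_bound() omitted for brevity.

    fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
        // Read trace columns in the order they were committed
        let col_a = eval.next_trace_mask();
        let col_b = eval.next_trace_mask();
        let col_c = eval.next_trace_mask();
        let multiplicity = eval.next_trace_mask();

        // Polynomial constraint: col_a * col_b - col_c = 0
        // (E::F is Clone but not necessarily Copy, so clone values reused below)
        eval.add_constraint(col_a.clone() * col_b.clone() - col_c);

        // (Optional) LogUp relation entry
        eval.add_to_relation(RelationEntry::new(
            &self.lookup_elements,
            E::EF::from(multiplicity),
            &[col_a, col_b],
        ));
        eval.finalize_logup_in_pairs();
        eval
    }
}
```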
The same evaluate function runs in both the prover, which evaluates the constraints over the trace domain to build the composition (quotient) polynomial, and the verifier, which evaluates them at a random out-of-domain point using the opened column values. STWO handles the FRI commitment, quotient computation, and verification protocol automatically.
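To make the "write the constraints once, evaluate them two ways" idea concrete, here is a self-contained toy model. It is not STWO's API: RowEval, RowChecker, Combiner, evaluate_mul, row_satisfies and combine_at are hypothetical stand-ins, values are plain u64 residues modulo the M31 prime, and the Combiner only loosely mimics how a verifier folds constraint values together with a random challenge.

```rust
/// Mersenne-31 prime (2^31 - 1), the modulus of STWO's base field M31.
const P: u64 = (1 << 31) - 1;

/// Toy stand-in for EvalAtRow: supplies column values, receives constraints.
trait RowEval {
    fn next_trace_mask(&mut self) -> u64;
    fn add_constraint(&mut self, c: u64);
}

/// Component logic written once: col_a * col_b - col_c == 0 (mod P).
fn evaluate_mul<E: RowEval>(mut eval: E) -> E {
    let a = eval.next_trace_mask();
    let b = eval.next_trace_mask();
    let c = eval.next_trace_mask();
    eval.add_constraint((a * b % P + P - c) % P);
    eval
}

/// Prover-side flavor: walks one concrete trace row and records each constraint value.
struct RowChecker {
    row: Vec<u64>,
    pos: usize,
    constraints: Vec<u64>,
}

impl RowEval for RowChecker {
    fn next_trace_mask(&mut self) -> u64 {
        let v = self.row[self.pos];
        self.pos += 1;
        v
    }
    fn add_constraint(&mut self, c: u64) {
        self.constraints.push(c);
    }
}

/// Verifier-side flavor: folds every constraint into one value using powers of a
/// random challenge alpha (Horner form). Assumes alpha < P.
struct Combiner {
    values: Vec<u64>,
    pos: usize,
    alpha: u64,
    acc: u64,
}

impl RowEval for Combiner {
    fn next_trace_mask(&mut self) -> u64 {
        let v = self.values[self.pos];
        self.pos += 1;
        v
    }
    fn add_constraint(&mut self, c: u64) {
        self.acc = (self.acc * self.alpha + c) % P;
    }
}

/// True if every constraint of evaluate_mul vanishes on this row.
fn row_satisfies(row: &[u64]) -> bool {
    let checker = evaluate_mul(RowChecker { row: row.to_vec(), pos: 0, constraints: Vec::new() });
    checker.constraints.iter().all(|&c| c == 0)
}

/// Random combination of all constraint values for the given column values.
fn combine_at(values: &[u64], alpha: u64) -> u64 {
    evaluate_mul(Combiner { values: values.to_vec(), pos: 0, alpha, acc: 0 }).acc
}
```

Here row_satisfies(&[3, 5, 15]) holds and row_satisfies(&[3, 5, 16]) does not, and the same evaluate_mul body drives both evaluators, which is the property the FrameworkEval design relies on.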
Component Table
| Component | File | Constraint Degree | LogUp | Preprocessed Cols | Execution Cols | Interaction Cols |
|---|---|---|---|---|---|---|
| ElementwiseAdd | elementwise.rs | 1 | No | 0 | 3 | 0 |
| ElementwiseMul | elementwise.rs | 2 | No | 0 | 3 | 0 |
| Activation | activation.rs | 1 + LogUp | Yes (3-tuple) | 2 | 3 | 1 |
| LayerNorm | layernorm.rs | 2 + LogUp | Yes (2-tuple) | 2 | 6 | 1 |
| RMSNorm | rmsnorm.rs | 2 + LogUp | Yes (2-tuple) | 2 | 5 | 1 |
| Embedding | embedding.rs | 1 + LogUp | Yes (3-tuple) | 3 | 4 | 1 |
| Quantize | quantize.rs | 1 + LogUp | Yes (2-tuple) | 2 | 3 | 1 |
| Dequantize | dequantize.rs | 1 + LogUp | Yes (2-tuple) | 2 | 3 | 1 |
| RoPE | rope.rs | 1 + LogUp | Yes | 2 | 5 | 1 |
| MatMul | matmul.rs | Sumcheck | No | — | — | — |
| Receipt | receipt.rs | 2 | No | 0 | 7 | 0 |
MatMul and Attention produce separate sumcheck proofs outside the unified STARK.
Three-Tree Commitment Scheme
STWO's polynomial commitment scheme commits to three Merkle trees per proof. All AIR components share these three trees in a single prove() call.
Tree 0: Preprocessed Columns
Fixed lookup tables, committed before proving begins and known to both prover and verifier.
- Activation tables: [table_input, table_output], the (x, f(x)) pairs
- LayerNorm tables: [table_variance, table_rsqrt], a reciprocal-sqrt lookup
- RMSNorm tables: [table_rms_sq, table_rsqrt], same format
- Embedding tables: [table_token, table_col, table_val], the embedding matrix
- Quantize tables: [table_input, table_output], the quantization mapping
Critical: Tree 0 is always committed even when empty (no lookup components). STWO's verifier expects a fixed tree count.
Tree 1: Execution Trace
The actual computation trace — one row per operation instance.
- Activation: [trace_input, trace_output, multiplicity]
- Add: [lhs, rhs, output]
- Mul: [lhs, rhs, output]
- LayerNorm: [input, mean, variance, rsqrt_val, output, multiplicity]
- RMSNorm: [input, rms_sq, rsqrt, output, multiplicity]
- Embedding: [token_id, col_idx, value, multiplicity]
- Quantize: [input, output, multiplicity]
Tree 2: Interaction Columns (LogUp)
Created during proving, after Tree 1 is committed. Contains the LogUp fraction columns.
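Per-component LogUp column (each row pairs one table entry with one trace row):
- Table side: contributes -multiplicity / table_encode (supply)
- Trace side: contributes +1 / trace_encode (demand)
- Combined fraction per row: (table_encode - multiplicity × trace_encode) / (table_encode × trace_encode)

Here table_encode and trace_encode are the random-linear-combination encodings of the table tuple and the trace tuple under the relation's challenges.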
Components without LogUp (ElementwiseAdd, ElementwiseMul) have empty Tree 2 entries.
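The combined fraction can be checked numerically with a toy model. This sketch works in the base field M31 for readability, whereas a real LogUp argument uses SecureField (QM31) challenges drawn from the channel. It assumes a tuple encoding of the form sum(alpha^i × v_i) - z, in the style of STWO's lookup elements, and pairs table row i with trace row i as in the column layout above. The helper names (encode, logup_sum, and the field helpers) are hypothetical.

```rust
/// Mersenne-31 prime, the modulus of M31.
const P: u64 = (1 << 31) - 1;

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// Modular inverse by Fermat's little theorem: a^(P-2) mod P, for nonzero a.
fn inv(a: u64) -> u64 {
    let (mut base, mut exp, mut acc) = (a % P, P - 2, 1);
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}

/// Encode a tuple as sum(alpha^i * v_i) - z.
fn encode(values: &[u64], alpha: u64, z: u64) -> u64 {
    let mut acc = 0;
    let mut power = 1;
    for &v in values {
        acc = add(acc, mul(power, v));
        power = mul(power, alpha);
    }
    sub(acc, z)
}

/// Sum over rows of the paired fraction
/// (table_enc - m * trace_enc) / (table_enc * trace_enc) = -m / table_enc + 1 / trace_enc.
fn logup_sum(table: &[(u64, u64)], mult: &[u64], trace: &[(u64, u64)], alpha: u64, z: u64) -> u64 {
    let mut total = 0;
    for ((&(tx, ty), &m), &(sx, sy)) in table.iter().zip(mult).zip(trace) {
        let t = encode(&[tx, ty], alpha, z);
        let s = encode(&[sx, sy], alpha, z);
        let numerator = sub(t, mul(m, s));
        total = add(total, mul(numerator, inv(mul(t, s))));
    }
    total
}
```

With a table of (x, x²) pairs, multiplicities [0, 1, 2, 1], and trace rows [(2, 4), (2, 4), (3, 9), (1, 1)], the per-row terms cancel and the total is zero. Changing a single trace output, say (3, 9) to (3, 8), leaves a nonzero total.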
Tree constants:

```rust
pub const PREPROCESSED_TRACE_IDX: usize = 0;
pub const ORIGINAL_TRACE_IDX: usize = 1;
pub const INTERACTION_TRACE_IDX: usize = 2;
```
LogUp Protocol
LogUp is a lookup argument that proves every trace row's (input, output) pair exists in a precomputed table. It is used for non-linear operations (activations, reciprocal square roots, quantization) that low-degree polynomial constraints alone cannot express.
Relation Declaration
Each component declares a typed relation with fixed arity:
```rust
relation!(ActivationRelation, 3); // (type_tag, input, output)
relation!(LayerNormRelation, 2);  // (variance, rsqrt)
relation!(QuantizeRelation, 2);   // (input, output)
relation!(EmbeddingRelation, 3);  // (token_id, col_idx, value)
```
Evaluation Pattern
In each LogUp component, evaluate() performs exactly two add_to_relation() calls: one for the table side (supply) and one for the trace side (demand):
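```rust
// Requires num_traits::One in scope for E::EF::one().
// Field names (type_tag, id_table_input, id_table_output) are illustrative.
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
    // Domain-separation tag for this activation type (constant per component)
    let tag = E::F::from(BaseField::from(self.type_tag));

    // Table side: Tree 0 preprocessed columns
    let table_in = eval.get_preprocessed_column(self.id_table_input.clone());
    let table_out = eval.get_preprocessed_column(self.id_table_output.clone());

    // Trace side: Tree 1 execution columns, in trace order
    let trace_in = eval.next_trace_mask();
    let trace_out = eval.next_trace_mask();
    let multiplicity = eval.next_trace_mask();

    // Supply: yield each table entry with coefficient -multiplicity
    eval.add_to_relation(RelationEntry::new(
        &self.lookup_elements,
        -E::EF::from(multiplicity),
        &[tag.clone(), table_in, table_out],
    ));
    // Demand: consume each trace entry with coefficient +1
    eval.add_to_relation(RelationEntry::new(
        &self.lookup_elements,
        E::EF::one(),
        &[tag, trace_in, trace_out],
    ));

    // Combine both fractions into a single interaction column
    eval.finalize_logup_in_pairs();
    eval
}
```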
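Critical: use finalize_logup_in_pairs() (not finalize_logup()) when the two add_to_relation() calls share one interaction column. finalize_logup_in_pairs() batches consecutive fractions two at a time into a single combined fraction, whereas finalize_logup() gives each fraction its own interaction column. Whichever variant is used must match how the Tree 2 interaction trace was generated.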
Type Tags for Domain Separation
All activation components share one ActivationRelation and therefore the same random challenges. A per-function type tag keeps their lookups domain-separated, so a row of one activation cannot be satisfied by a table entry of another (cross-activation forgery):
```rust
pub enum ActivationType {
    ReLU,    // type_tag = 1
    GELU,    // type_tag = 2
    Sigmoid, // type_tag = 3
    Softmax, // type_tag = 4
}
```
The tag is included as the first element of the LogUp relation tuple.
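A small sketch of why the tag matters. It models a relation as a plain set of tuples (a simplification, since LogUp proves multiset inclusion), uses the tag values from the enum comments above, and uses made-up toy table values. The function names are hypothetical.

```rust
use std::collections::HashSet;

/// Tags as listed in the ActivationType comments.
const RELU_TAG: u32 = 1;
const GELU_TAG: u32 = 2;

/// One shared relation holding (type_tag, input, output) entries for every activation.
fn tagged_table(relu: &[(u32, u32)], gelu: &[(u32, u32)]) -> HashSet<(u32, u32, u32)> {
    let relu_rows = relu.iter().map(|&(x, y)| (RELU_TAG, x, y));
    let gelu_rows = gelu.iter().map(|&(x, y)| (GELU_TAG, x, y));
    relu_rows.chain(gelu_rows).collect()
}

/// The same shared relation with the tag column dropped.
fn untagged_table(relu: &[(u32, u32)], gelu: &[(u32, u32)]) -> HashSet<(u32, u32)> {
    relu.iter().chain(gelu.iter()).copied().collect()
}
```

With toy entries ReLU(5) = 5 and GELU(5) = 4, a GELU row claiming (5, 5) is found in the untagged relation (it matches the ReLU entry), but (GELU_TAG, 5, 5) is absent from the tagged one.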
LogUp Claimed Sum
Each LogUp component produces a claimed_sum: SecureField, the total of all its fraction evaluations. The verifier checks this sum against the committed interaction columns (Tree 2) as part of the STARK. This value is part of the LayerClaim structure:
```rust
pub struct LayerClaim {
    pub layer_index: usize,
    pub claimed_sum: SecureField,
    pub trace_rows: usize,
}
```
Precomputed Tables
PrecomputedTable
The gadgets/lookup_table.rs module provides table construction:
```rust
use std::collections::HashMap;

pub struct PrecomputedTable {
    pub inputs: Vec<M31>,
    pub outputs: Vec<M31>,
    pub log_size: u32,
    // Hash index for O(1) input lookup; used for tables larger than 2^10 entries
    index: Option<HashMap<u32, usize>>,
}

impl PrecomputedTable {
    pub fn build(f: impl Fn(M31) -> M31, log_size: u32) -> Self { ... }
    pub fn build_parallel(f: impl Fn(M31) -> M31 + Sync, log_size: u32) -> Self { ... }
    pub fn from_pairs(pairs: Vec<(M31, M31)>, log_size: u32) -> Self { ... }
}
```
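A simplified sketch of such a table, assuming it enumerates every input in 0..2^log_size and that the hash index is built only above 2^10 entries (our reading of the comment above). u32 stands in for M31, and Table, lookup, len and INDEX_THRESHOLD_LOG are illustrative names, not the gadgets/lookup_table.rs API.

```rust
use std::collections::HashMap;

/// M31 modulus; outputs are reduced into the field.
const P: u32 = (1 << 31) - 1;
/// Hypothetical threshold: build the index only for tables above 2^10 entries.
const INDEX_THRESHOLD_LOG: u32 = 10;

struct Table {
    inputs: Vec<u32>,
    outputs: Vec<u32>,
    log_size: u32,
    index: Option<HashMap<u32, usize>>,
}

impl Table {
    /// Evaluate f on every input 0..2^log_size.
    fn build(f: impl Fn(u32) -> u32, log_size: u32) -> Self {
        let inputs: Vec<u32> = (0..1u32 << log_size).collect();
        let outputs: Vec<u32> = inputs.iter().map(|&x| f(x) % P).collect();
        let index: Option<HashMap<u32, usize>> = (log_size > INDEX_THRESHOLD_LOG)
            .then(|| inputs.iter().enumerate().map(|(row, &x)| (x, row)).collect());
        Table { inputs, outputs, log_size, index }
    }

    fn len(&self) -> usize {
        1 << self.log_size
    }

    /// Output for `input`, via the hash index when present, else a linear scan.
    fn lookup(&self, input: u32) -> Option<u32> {
        let row = match &self.index {
            Some(map) => map.get(&input).copied(),
            None => self.inputs.iter().position(|&x| x == input),
        };
        row.map(|i| self.outputs[i])
    }
}
```

Below the threshold, lookup falls back to a linear scan; above it, the HashMap gives constant-time row lookup, which matters when resolving many trace values against 2^16 to 2^20-entry tables.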
Production Table Sizes
| Operation | Bits | Table Size | Notes |
|---|---|---|---|
| ReLU | 16 | 65K entries | Exact: max(0, x) |
| GELU | 18 | 262K entries | Approximate: x · Φ(x) |
| Sigmoid | 16 | 65K entries | 1 / (1 + e^(-x)) |
| Softmax exp | 20 | 1M entries | e^x normalized |
| LayerNorm rsqrt | 16 | 65K entries | 1/√(x + ε) |
| Quantize INT8 | 8 | 256 entries | Scale + zero-point |
| Dequantize INT4 | 4 | 16 entries | Inverse mapping |
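Table size is 2^bits entries, so memory grows quickly with bit width. A rough estimate, assuming each M31 value is stored as a 4-byte u32 and ignoring the blown-up (LDE) evaluations that the commitment actually needs; table_bytes is a hypothetical helper.

```rust
/// Bytes used by one table's raw columns (4 bytes per M31 value, no LDE blowup).
fn table_bytes(bits: u32, n_columns: usize) -> usize {
    let entries = 1usize << bits;
    entries * n_columns * std::mem::size_of::<u32>()
}
```

Under these assumptions the 20-bit Softmax exp table needs 8 MiB for its two raw columns, versus 512 KiB for a 16-bit table such as ReLU.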
Trace Generation
Each component generates its trace columns using a generic pattern over the STWO backend:
```rust
pub fn generate_activation_trace<B: ColumnOps<BaseField>>(
    inputs: &[M31],
    outputs: &[M31],
    multiplicities: &[M31],
    log_size: u32,
) -> Vec<CircleEvaluation<B, BaseField, BitReversedOrder>> {
    let size = 1usize << log_size;
    let domain = CanonicCoset::new(log_size).circle_domain();

    // One zero-initialized column per execution-trace column
    let mut col_input = Col::<B, BaseField>::zeros(size);
    let mut col_output = Col::<B, BaseField>::zeros(size);
    let mut col_mult = Col::<B, BaseField>::zeros(size);

    // Fill the used rows; rows past inputs.len() stay zero
    for (i, (&inp, &out)) in inputs.iter().zip(outputs).enumerate().take(size) {
        col_input.set(i, inp);
        col_output.set(i, out);
        col_mult.set(i, multiplicities[i]);
    }

    vec![
        CircleEvaluation::new(domain, col_input),
        CircleEvaluation::new(domain, col_output),
        CircleEvaluation::new(domain, col_mult),
    ]
}
```
Generic over backend B:
- SimdBackend: SIMD CPU (default for trace generation)
- CpuBackend: scalar CPU
- GpuBackend: NVIDIA CUDA (requires the cuda-runtime feature)
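The multiplicities argument of generate_activation_trace records how often each table row is used by the trace. One plausible way to compute it, sketched here with u32 standing in for M31, unique table inputs assumed, and a hypothetical compute_multiplicities helper, is a single pass with a reverse index:

```rust
use std::collections::HashMap;

/// For each table row, count how many trace rows look it up.
/// Returns Err(value) for the first trace value missing from the table.
fn compute_multiplicities(table_inputs: &[u32], trace_inputs: &[u32]) -> Result<Vec<u32>, u32> {
    // Reverse index: table input value -> table row (assumes unique inputs)
    let row_of: HashMap<u32, usize> = table_inputs
        .iter()
        .enumerate()
        .map(|(row, &x)| (x, row))
        .collect();

    let mut mult = vec![0u32; table_inputs.len()];
    for &x in trace_inputs {
        match row_of.get(&x) {
            Some(&row) => mult[row] += 1,
            // No table row matches: the lookup cannot be satisfied
            None => return Err(x),
        }
    }
    Ok(mult)
}
```

A trace value with no matching table row is returned as an error, since no multiplicity column could balance the LogUp sum for it.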
Unified STARK Aggregation
Single prove() Call
All non-matmul components are combined into one STARK proof:
```rust
let stark_proof = prove::<B, MC>(
    &[
        &activation_comp, // FrameworkComponent<ActivationEval>
        &add_comp,        // FrameworkComponent<ElementwiseAddEval>
        &mul_comp,        // FrameworkComponent<ElementwiseMulEval>
        &layernorm_comp,  // FrameworkComponent<LayerNormEval>
        &rmsnorm_comp,    // FrameworkComponent<RMSNormEval>
        &embedding_comp,  // FrameworkComponent<EmbeddingEval>
        &quantize_comp,   // FrameworkComponent<QuantizeEval>
        &dequantize_comp, // FrameworkComponent<DequantizeEval>
    ],
    &mut channel,
    commitment_scheme,
)?;
```
Prover Setup
```rust
let config = PcsConfig::default();
// Twiddles must cover the largest evaluation domain: the max constraint
// log-degree bound plus the FRI log blowup factor.
let twiddles = B::precompute_twiddles(
    CanonicCoset::new(max_degree_bound + config.fri_config.log_blowup_factor)
        .circle_domain()
        .half_coset,
);
let mut commitment_scheme = CommitmentSchemeProver::<B, MC>::new(config, &twiddles);

// Tree 0: Preprocessed (activation tables, layernorm tables, etc.)
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(preproc_evals);
tree_builder.commit(&mut channel);

// Tree 1: Execution trace (all component traces)
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(convert_evaluations::<SimdBackend, B, BaseField>(trace_evals));
tree_builder.commit(&mut channel);

// Tree 2: LogUp interaction columns. These are generated after Tree 1 is
// committed and the lookup elements are drawn from the channel, then
// committed with a third tree_builder before prove() is called.
```
Trait Bounds
All components must implement ComponentProver<B> for the chosen backend:
```rust
// Bounds on the aggregation function (excerpt)
where
    B: BackendForChannel<MC> + PolyOps,
    MC: MerkleChannel,
    FrameworkComponent<ActivationEval>: ComponentProver<B>,
    FrameworkComponent<ElementwiseAddEval>: ComponentProver<B>,
    FrameworkComponent<LayerNormEval>: ComponentProver<B>,
    // ... all component types
```
Heterogeneous Component Storage
Components have different concrete types (one FrameworkComponent<E> per eval type E) but share a common trait. The ComponentProverErased<B> trait enables heterogeneous storage as Vec<Box<dyn ComponentProverErased<B>>>:
```rust
trait ComponentProverErased<B: Backend> {
    fn prove(&self, ...) -> StarkProof;
}

// Blanket impl for all FrameworkComponent<E>
impl<B, E: FrameworkEval> ComponentProverErased<B> for FrameworkComponent<E>
where
    FrameworkComponent<E>: ComponentProver<B>,
{ ... }
```
Similarly, ComponentRefErased enables &dyn Component references from FrameworkComponent<E> for verification.
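The erasure trick can be shown without STWO. In this toy sketch Eval, Component and ComponentErased are hypothetical stand-ins for FrameworkEval, FrameworkComponent<E> and ComponentProverErased<B>; the erased trait exposes a describe() method only to keep the example small, where the real trait would forward to the prover.

```rust
/// Toy stand-in for FrameworkEval.
trait Eval {
    fn name(&self) -> &'static str;
    fn n_constraints(&self) -> usize;
}

/// Toy stand-in for FrameworkComponent<E>: one concrete type per eval type.
struct Component<E: Eval> {
    eval: E,
}

/// Trait without the E parameter, so it can sit behind `dyn`.
trait ComponentErased {
    fn describe(&self) -> String;
}

/// Blanket impl: every Component<E> becomes a ComponentErased, whatever E is.
impl<E: Eval> ComponentErased for Component<E> {
    fn describe(&self) -> String {
        format!("{}: {} constraint(s)", self.eval.name(), self.eval.n_constraints())
    }
}

struct AddEval;
struct MulEval;

impl Eval for AddEval {
    fn name(&self) -> &'static str { "add" }
    fn n_constraints(&self) -> usize { 1 }
}

impl Eval for MulEval {
    fn name(&self) -> &'static str { "mul" }
    fn n_constraints(&self) -> usize { 1 }
}

/// Different concrete Component<E> types stored in one Vec.
fn all_components() -> Vec<Box<dyn ComponentErased>> {
    let mut components: Vec<Box<dyn ComponentErased>> = Vec::new();
    components.push(Box::new(Component { eval: AddEval }));
    components.push(Box::new(Component { eval: MulEval }));
    components
}
```

Because the blanket impl covers every E, adding a new component type only requires implementing Eval; no per-type glue is needed to store it alongside the others.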
Aggregated Proof Structure
```rust
pub struct AggregatedModelProofFor<H: MerkleHasherLifted> {
    // Single STARK covering all non-matmul components
    pub unified_stark: Option<StarkProof<H>>,
    // Per-matmul sumcheck proofs (separate from STARK)
    pub matmul_proofs: Vec<(usize, MatMulSumcheckProof)>,
    // Claims verified inside the unified STARK
    pub activation_claims: Vec<LayerClaim>,
    pub add_claims: Vec<LayerClaim>,
    pub mul_claims: Vec<LayerClaim>,
    pub layernorm_claims: Vec<LayerClaim>,
    pub rmsnorm_claims: Vec<LayerClaim>,
    pub embedding_claims: Vec<LayerClaim>,
    pub quantize_claims: Vec<LayerClaim>,
    pub dequantize_claims: Vec<LayerClaim>,
    // Attention proofs (composite: matmul + softmax)
    pub attention_proofs: Vec<(usize, AttentionProof<H>)>,
    // Commitments
    pub execution: GraphExecution,
    pub layer_chain_commitment: FieldElement,
    pub io_commitment: FieldElement,
    pub layernorm_mean_var_commitments: Vec<FieldElement>,
    pub quantize_params_commitment: FieldElement,
}
```
Verification Pipeline
Off-Chain (Rust)
verify_aggregated_model_proof(proof, graph, input, weights) performs these steps:
- Re-run forward pass to reconstruct expected outputs
- Verify unified STARK (Trees 0/1/2 commitments + FRI)
- Per-matmul: verify sumcheck proofs independently
- Check all commitments (IO, layer chain, mean/var)
On-Chain (Cairo)
The Cairo verifier receives serialized proof data and verifies in three stages:
```text
1. STARK.verify() on unified proof
   ├── Check Tree 0 preprocessed column root
   ├── Check Tree 1 execution trace root
   ├── Check Tree 2 LogUp interaction root
   ├── Verify FRI (low-degree test)
   └── Check constraint evaluation at random points
2. Per-MatMul: sumcheck.verify()
   ├── Check claimed_sum via sumcheck rounds
   └── Verify final A×B evaluation
3. Commitment checks (Poseidon)
   ├── IO commitment (input || output hash)
   ├── Layer chain (running hash of intermediates)
   ├── LayerNorm mean/var per-layer
   └── Quantize params
```
Cairo Serialization Format
The cairo_serde.rs module serializes proofs for on-chain consumption:
Type Mapping
| Rust Type | Cairo Type | felt252 Count |
|---|---|---|
| M31 | M31 | 1 |
| QM31 | QM31 | 4 (a, b, c, d components) |
| Blake2sHash | Blake2sHash | 8 (8 × u32 little-endian) |
| u32 | u32 | 1 |
| Vec<T> | Span<T> | 1 (length) + N × size(T) |
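A minimal sketch of the type mapping above, assuming felt252 values can be modeled as u128 (every value here fits in 32 bits). Felt, Qm31, Blake2sHash and CairoSerialize are hypothetical toy versions, not the actual cairo_serde.rs API; the point is the per-type felt counts and the recursive, length-prefixed layout for Vec<T>.

```rust
/// felt252 modeled as u128; every value serialized here fits in 32 bits.
type Felt = u128;

/// Toy QM31: four M31 limbs (a, b, c, d).
struct Qm31([u32; 4]);

/// Toy 32-byte Blake2s digest.
struct Blake2sHash([u8; 32]);

trait CairoSerialize {
    fn serialize(&self, out: &mut Vec<Felt>);
}

/// u32 (and an M31 stored as u32): one felt.
impl CairoSerialize for u32 {
    fn serialize(&self, out: &mut Vec<Felt>) {
        out.push(Felt::from(*self));
    }
}

/// QM31: four felts, one per component.
impl CairoSerialize for Qm31 {
    fn serialize(&self, out: &mut Vec<Felt>) {
        out.extend(self.0.iter().map(|&limb| Felt::from(limb)));
    }
}

/// Blake2sHash: eight felts, each a little-endian u32 word of the digest.
impl CairoSerialize for Blake2sHash {
    fn serialize(&self, out: &mut Vec<Felt>) {
        for word in self.0.chunks_exact(4) {
            out.push(Felt::from(u32::from_le_bytes([word[0], word[1], word[2], word[3]])));
        }
    }
}

/// Vec<T> -> Span<T>: a length prefix, then each element in order.
impl<T: CairoSerialize> CairoSerialize for Vec<T> {
    fn serialize(&self, out: &mut Vec<Felt>) {
        out.push(self.len() as Felt);
        for item in self {
            item.serialize(out);
        }
    }
}
```

Serializing a two-element Vec<QM31> yields 1 + 2 × 4 = 9 felts: the length 2 followed by both elements' four components.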
MatMul Proof Serialization
```text
[m: 1] [k: 1] [n: 1] [num_rounds: 1]
[claimed_sum: 4]
[round_polys_len: 1] [round_poly × num_rounds: 12 each]
[final_a_eval: 4] [final_b_eval: 4]
[a_commitment: 1] [b_commitment: 1]
[a_opening: variable] [b_opening: variable]
```
Total: ~50 + 12 × num_rounds felt252 values per matmul.
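The fixed-size fields in this layout add up to 19 felts, so the variable-length openings account for the rest of the ~50 constant. A small helper makes the count explicit; the opening sizes are parameters because the layout leaves them variable, and the names are hypothetical.

```rust
/// m, k, n, num_rounds (4) + claimed_sum (4) + round_polys_len (1)
/// + final_a_eval, final_b_eval (4 + 4) + a_commitment, b_commitment (1 + 1).
const FIXED_FELTS: usize = 4 + 4 + 1 + 4 + 4 + 1 + 1;

/// Each round polynomial takes 12 felts (three QM31 values of 4 felts each).
const FELTS_PER_ROUND: usize = 12;

/// Total felt252 count for one serialized MatMul proof.
fn matmul_proof_felts(num_rounds: usize, a_opening_felts: usize, b_opening_felts: usize) -> usize {
    FIXED_FELTS + FELTS_PER_ROUND * num_rounds + a_opening_felts + b_opening_felts
}
```

For example, 10 sumcheck rounds with 15- and 16-felt openings give 19 + 120 + 31 = 170 felts.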
Quantization Components
Strategy
```rust
pub enum QuantStrategy {
    Direct,      // Direct clamping: [0, P-1]
    Symmetric8,  // INT8: [-127, 127] → [0, 254]
    Asymmetric8, // INT8 with zero-point
    Symmetric4,  // INT4: [-7, 7] → [0, 14]
    Asymmetric4, // INT4 with zero-point
}
```
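A sketch of the Symmetric8 mapping, assuming round-half-away-from-zero rounding (Rust's f32::round) and a single per-tensor scale. quantize_symmetric8 and dequantize_symmetric8 are illustrative names; the codebase's quantize_value() may differ in rounding or scale handling.

```rust
/// Symmetric INT8: q = clamp(round(x / scale), -127, 127), then shift by +127
/// so the stored value lies in [0, 254].
fn quantize_symmetric8(x: f32, scale: f32) -> u32 {
    let q = (x / scale).round().clamp(-127.0, 127.0) as i32;
    (q + 127) as u32
}

/// Inverse mapping: undo the shift, then rescale.
fn dequantize_symmetric8(stored: u32, scale: f32) -> f32 {
    (stored as i32 - 127) as f32 * scale
}
```

The +127 shift keeps every stored value non-negative, so each output is a small M31 element and out-of-range inputs saturate at 0 or 254.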
Proving
- Observe trace inputs (f32 values encoded as M31)
- Apply quantize_value() to produce the M31 outputs
- Build a two-column lookup table of (input_m31, output_m31) pairs
- Prove via LogUp that each (trace_input, trace_output) row exists in the table
Dequantize
For INT4 models, dequantization maps 16 possible quantized values to their M31 equivalents. The small table (16 entries) makes LogUp verification very efficient.
Receipt Proof System
The receipt component provides billing arithmetic verification:
Constraints:
```text
time_billing × 1000 == gpu_time_ms × rate_per_sec
token_billing == token_count × rate_per_token
billing_total == time_billing + token_billing
```
7 execution columns, degree-2 constraints, no LogUp. Receipts are chain-linked via Poseidon hashes with TEE freshness validation (MAX_TEE_AGE_SECS = 3600).
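The three receipt constraints can be checked with plain integer arithmetic. This sketch mirrors the 7 execution columns as struct fields; Receipt, new and constraints_hold are hypothetical names, and it uses u64 arithmetic whereas the AIR checks the same equations modulo the M31 prime.

```rust
/// One receipt row; each field corresponds to one execution column.
#[derive(Clone, Copy, Debug)]
struct Receipt {
    gpu_time_ms: u64,
    rate_per_sec: u64,
    time_billing: u64,
    token_count: u64,
    rate_per_token: u64,
    token_billing: u64,
    billing_total: u64,
}

impl Receipt {
    /// Derive the billing columns; None if gpu_time_ms * rate_per_sec is not a
    /// multiple of 1000 (the division-free time constraint could not hold exactly).
    fn new(gpu_time_ms: u64, rate_per_sec: u64, token_count: u64, rate_per_token: u64) -> Option<Self> {
        let time_product = gpu_time_ms * rate_per_sec;
        if time_product % 1000 != 0 {
            return None;
        }
        let time_billing = time_product / 1000;
        let token_billing = token_count * rate_per_token;
        Some(Receipt {
            gpu_time_ms,
            rate_per_sec,
            time_billing,
            token_count,
            rate_per_token,
            token_billing,
            billing_total: time_billing + token_billing,
        })
    }

    /// The three receipt constraints (degree at most 2), over the integers.
    fn constraints_hold(&self) -> bool {
        self.time_billing * 1000 == self.gpu_time_ms * self.rate_per_sec
            && self.token_billing == self.token_count * self.rate_per_token
            && self.billing_total == self.time_billing + self.token_billing
    }
}
```

Multiplying time_billing by 1000, instead of dividing gpu_time_ms by 1000, keeps the constraint free of division; that is why this sketch's new() rejects inputs whose product is not a multiple of 1000.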
Key Lessons
Always commit Tree 0: Even when empty (no preprocessed columns), the tree must be committed. STWO's verifier expects a fixed tree count.
finalize_logup_in_pairs(): Required (not finalize_logup()) when exactly 2 add_to_relation calls are made per evaluation.
Type tags prevent cross-activation forgery: Without tags, a ReLU table entry could satisfy a GELU lookup if both share the same relation challenges.
Audit finding: simd_combined field was not serialized in the Cairo serialization module, causing the Cairo verifier to reject SIMD-batched proofs. Fixed by adding explicit serialization before the FRI proof.
GPU bridge for proving: SimdToGpuTreeBuilder adapter enables zero-copy conversion between SimdBackend and GpuBackend column types via unsafe { transmute() } — they share identical memory layout.