Custom AIR over STWO

Overview

stwo-ml implements 11 custom AIR (Algebraic Intermediate Representation) components using STWO's FrameworkComponent pattern. Each component defines polynomial constraints and optional LogUp lookup arguments that are proven inside a single unified STARK proof covering all non-matmul operations.

This design means a transformer model with 40 layers of activations, layer norms, and element-wise operations produces one STARK proof (not 40+ independent proofs).

FrameworkComponent Pattern

Every AIR component follows the STWO constraint framework:

trait FrameworkEval {
    fn log_size(&self) -> u32;
    fn max_constraint_log_degree_bound(&self) -> u32;
    fn evaluate<E: EvalAtRow>(&self, eval: E) -> E;
}

type FrameworkComponent<E: FrameworkEval> = /* STWO wrapper */;

The evaluate method defines the component's polynomial constraints:

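impl FrameworkEval for MyComponentEval {
    // log_size() and max_constraint_log_degree_bound() omitted for brevity

    fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
        // Read trace columns in the order they were committed
        let col_a = eval.next_trace_mask();
        let col_b = eval.next_trace_mask();
        let col_c = eval.next_trace_mask();
        // Multiplicity for the lookup, assumed here to be a trace column
        let multiplicity = eval.next_trace_mask();

        // Add polynomial constraint: col_a * col_b - col_c = 0
        eval.add_constraint(col_a.clone() * col_b.clone() - col_c);

        // (Optional) LogUp relation entry
        eval.add_to_relation(RelationEntry::new(
            &self.lookup_elements,
            E::EF::from(multiplicity),
            &[col_a, col_b],
        ));

        eval.finalize_logup_in_pairs();
        eval
    }
}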

The same evaluate function runs in both the prover (generating constraints over the trace) and verifier (checking constraints at random evaluation points). STWO handles the FRI commitment, quotient computation, and verification protocol automatically.
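The "write once, run in several contexts" idea can be illustrated without STWO itself. The sketch below is a toy model, not the STWO API: `ToyEvalAtRow`, `RowChecker`, and `ShapeCounter` are hypothetical names, and plain `i64` arithmetic stands in for field arithmetic. It shows one generic constraint function driven by two different evaluators, one that checks a concrete row and one that only counts columns and constraints.

```rust
use std::ops::{Mul, Sub};

/// Hypothetical stand-in for an `EvalAtRow`-style trait: it hands out
/// column values and receives constraint expressions.
trait ToyEvalAtRow {
    type F: Copy + Mul<Output = Self::F> + Sub<Output = Self::F>;
    fn next_trace_mask(&mut self) -> Self::F;
    fn add_constraint(&mut self, value: Self::F);
}

/// The constraint definition, written once: a * b - c = 0.
fn evaluate_mul<E: ToyEvalAtRow>(mut eval: E) -> E {
    let a = eval.next_trace_mask();
    let b = eval.next_trace_mask();
    let c = eval.next_trace_mask();
    eval.add_constraint(a * b - c);
    eval
}

/// Evaluator 1: reads one concrete row and records whether every
/// constraint evaluated to zero.
struct RowChecker<'a> {
    row: &'a [i64],
    next: usize,
    satisfied: bool,
}

impl<'a> ToyEvalAtRow for RowChecker<'a> {
    type F = i64;
    fn next_trace_mask(&mut self) -> i64 {
        let value = self.row[self.next];
        self.next += 1;
        value
    }
    fn add_constraint(&mut self, value: i64) {
        self.satisfied = self.satisfied && value == 0;
    }
}

/// Evaluator 2: carries no values, only counts columns and constraints.
#[derive(Clone, Copy)]
struct Symbol;
impl Mul for Symbol {
    type Output = Symbol;
    fn mul(self, _rhs: Symbol) -> Symbol { Symbol }
}
impl Sub for Symbol {
    type Output = Symbol;
    fn sub(self, _rhs: Symbol) -> Symbol { Symbol }
}

struct ShapeCounter {
    columns: usize,
    constraints: usize,
}

impl ToyEvalAtRow for ShapeCounter {
    type F = Symbol;
    fn next_trace_mask(&mut self) -> Symbol {
        self.columns += 1;
        Symbol
    }
    fn add_constraint(&mut self, _value: Symbol) {
        self.constraints += 1;
    }
}

/// Runs the shared constraint code against a concrete row.
fn row_satisfies(row: &[i64]) -> bool {
    evaluate_mul(RowChecker { row: row, next: 0, satisfied: true }).satisfied
}

/// Runs the same constraint code to learn (columns, constraints).
fn component_shape() -> (usize, usize) {
    let counter = evaluate_mul(ShapeCounter { columns: 0, constraints: 0 });
    (counter.columns, counter.constraints)
}
```

Here `row_satisfies(&[3, 4, 12])` accepts the row, `row_satisfies(&[3, 4, 13])` rejects it, and `component_shape()` reports three columns and one constraint, all from the same `evaluate_mul` body.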

Component Table

Component       File            Constraint Degree  LogUp          Preprocessed Cols  Execution Cols  Interaction Cols
ElementwiseAdd  elementwise.rs  1                  No             0                  3               0
ElementwiseMul  elementwise.rs  2                  No             0                  3               0
Activation      activation.rs   1 + LogUp          Yes (3-tuple)  2                  3               1
LayerNorm       layernorm.rs    2 + LogUp          Yes (2-tuple)  2                  6               1
RMSNorm         rmsnorm.rs      2 + LogUp          Yes (2-tuple)  2                  5               1
Embedding       embedding.rs    1 + LogUp          Yes (3-tuple)  3                  4               1
Quantize        quantize.rs     1 + LogUp          Yes (2-tuple)  2                  3               1
Dequantize      dequantize.rs   1 + LogUp          Yes (2-tuple)  2                  3               1
RoPE            rope.rs         1 + LogUp          Yes            2                  5               1
MatMul          matmul.rs       Sumcheck           No             -                  -               -
Receipt         receipt.rs      2                  No             0                  7               0

MatMul and Attention produce separate sumcheck proofs outside the unified STARK.

Three-Tree Commitment Scheme

STWO's polynomial commitment scheme commits to three Merkle trees per proof. All AIR components share these three trees in a single prove() call.

Tree 0: Preprocessed Columns

Fixed lookup tables, committed before proving begins and known to both prover and verifier.

Activation tables:   [table_input, table_output]        — (x, f(x)) pairs
LayerNorm tables:    [table_variance, table_rsqrt]      — reciprocal sqrt lookup
RMSNorm tables:      [table_rms_sq, table_rsqrt]        — same format
Embedding tables:    [table_token, table_col, table_val] — embedding matrix
Quantize tables:     [table_input, table_output]         — quantization mapping

Critical: Tree 0 is always committed even when empty (no lookup components). STWO's verifier expects a fixed tree count.

Tree 1: Execution Trace

The actual computation trace — one row per operation instance.

Activation:    [trace_input, trace_output, multiplicity]
Add:           [lhs, rhs, output]
Mul:           [lhs, rhs, output]
LayerNorm:     [input, mean, variance, rsqrt_val, output, multiplicity]
RMSNorm:       [input, rms_sq, rsqrt, output, multiplicity]
Embedding:     [token_id, col_idx, value, multiplicity]
Quantize:      [input, output, multiplicity]

Tree 2: Interaction Columns (LogUp)

Created during proving, after Tree 1 is committed. Contains the LogUp fraction columns.

Per-component LogUp column:
  - Table side: contributes -multiplicity / table_encode (supply)
  - Trace side: contributes +1 / trace_encode (demand)
  - Combined fraction: (table_encode - mult × trace_encode) / (table_encode × trace_encode)

Components without LogUp (ElementwiseAdd, ElementwiseMul, Receipt) have empty Tree 2 entries.

Tree constants:

  • PREPROCESSED_TRACE_IDX = 0
  • ORIGINAL_TRACE_IDX = 1
  • INTERACTION_TRACE_IDX = 2

LogUp Protocol

LogUp is a lookup argument that proves every trace row's (input, output) pair exists in a precomputed table. Used for non-linear operations where polynomial constraints alone are insufficient.
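To make the supply/demand bookkeeping concrete, here is a minimal numeric sketch. It is a simplification under stated assumptions: it works in the M31 base field with fixed challenges `z` and `alpha` rather than random challenges from a secure extension field, and `encode` and `logup_sum` are hypothetical helper names. Each trace row adds `+1 / encode(row)`, each table row subtracts `multiplicity / encode(row)`, and the total is zero when every trace row is covered by the table with the right multiplicity.

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus, 2^31 - 1

fn add(a: u64, b: u64) -> u64 { (a % P + b % P) % P }
fn sub(a: u64, b: u64) -> u64 { (a % P + P - b % P) % P }
fn mul(a: u64, b: u64) -> u64 { (a % P) * (b % P) % P }

/// Modular inverse via Fermat's little theorem: a^(P-2) mod P.
fn inv(a: u64) -> u64 {
    let mut base = a % P;
    let mut exp = P - 2;
    let mut acc = 1u64;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}

/// Folds an (input, output) pair into one field element:
/// z - (input + alpha * output).
fn encode(z: u64, alpha: u64, input: u64, output: u64) -> u64 {
    sub(z, add(input, mul(alpha, output)))
}

/// Trace rows contribute +1/encode, table rows contribute -mult/encode.
fn logup_sum(
    table: &[(u64, u64)],
    mults: &[u64],
    trace: &[(u64, u64)],
    z: u64,
    alpha: u64,
) -> u64 {
    let mut acc = 0u64;
    for &(input, output) in trace {
        acc = add(acc, inv(encode(z, alpha, input, output)));
    }
    for (&(input, output), &mult) in table.iter().zip(mults) {
        acc = sub(acc, mul(mult, inv(encode(z, alpha, input, output))));
    }
    acc
}
```

With the table `[(0,0), (1,1), (2,2), (P-1,0)]`, multiplicities `[0, 2, 0, 1]`, and trace `[(1,1), (1,1), (P-1,0)]`, the sum is zero. Replacing the last trace row with `(P-1, 5)`, which is not in the table, leaves a non-zero remainder for these particular challenges.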

Relation Declaration

Each component declares a typed relation with fixed arity:

relation!(ActivationRelation, 3);    // (type_tag, input, output)
relation!(LayerNormRelation, 2);     // (variance, rsqrt)
relation!(QuantizeRelation, 2);      // (input, output)
relation!(EmbeddingRelation, 3);     // (token_id, col_idx, value)

Evaluation Pattern

The evaluate() method always performs exactly two add_to_relation() calls — one for the table side (supply) and one for the trace side (demand):

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
    // `tag` (this component's type tag as an E::F value), `multiplicity`,
    // and the preprocessed column ids are obtained before this point;
    // their setup is elided here.

    // Table side (preprocessed columns): yield entries with -multiplicity
    let table_in = eval.get_preprocessed_column(id_table_input);
    let table_out = eval.get_preprocessed_column(id_table_output);
    eval.add_to_relation(RelationEntry::new(
        &self.lookup_elements,
        -E::EF::from(multiplicity.clone()),
        &[tag.clone(), table_in, table_out],
    ));

    // Trace side (execution columns): consume entries with coefficient +1
    let trace_in = eval.next_trace_mask();
    let trace_out = eval.next_trace_mask();
    eval.add_to_relation(RelationEntry::new(
        &self.lookup_elements,
        E::EF::one(),
        &[tag, trace_in, trace_out],
    ));

    // Batch the two entries above into a single interaction column
    eval.finalize_logup_in_pairs();
    eval
}

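Critical: use finalize_logup_in_pairs() (not finalize_logup()) when exactly 2 add_to_relation calls are made. It batches the table-side and trace-side fractions into one combined fraction per row, which matches the single interaction column each LogUp component commits in Tree 2.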

Type Tags for Domain Separation

Activation components use type tags so that different activation functions can share the same ActivationRelation random challenges without allowing cross-activation forgery:

pub enum ActivationType {
    ReLU,      // type_tag = 1
    GELU,      // type_tag = 2
    Sigmoid,   // type_tag = 3
    Softmax,   // type_tag = 4
}

The tag is included as the first element of the LogUp relation tuple.
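The forgery that tags rule out can be shown with plain sets. The sketch below is illustrative only: `toy_gelu` is a made-up integer function standing in for a GELU table (it is not the real GELU), and the tag values 1 and 2 follow the enum comments above. It builds a shared table with and without tags and checks whether a ReLU row can be passed off as a GELU row.

```rust
use std::collections::HashSet;

const RELU_TAG: u32 = 1;
const GELU_TAG: u32 = 2;

fn relu(x: i32) -> i32 {
    x.max(0)
}

/// Hypothetical placeholder for a second activation table.
/// It only needs to differ from ReLU on some inputs.
fn toy_gelu(x: i32) -> i32 {
    if x < 0 { x / 4 } else { x }
}

/// Shared table WITHOUT tags: rows from both functions are mixed together.
fn untagged_table() -> HashSet<(i32, i32)> {
    let mut rows = HashSet::new();
    for x in -8..=8 {
        rows.insert((x, relu(x)));
        rows.insert((x, toy_gelu(x)));
    }
    rows
}

/// Shared table WITH tags: each row remembers which function produced it.
fn tagged_table() -> HashSet<(u32, i32, i32)> {
    let mut rows = HashSet::new();
    for x in -8..=8 {
        rows.insert((RELU_TAG, x, relu(x)));
        rows.insert((GELU_TAG, x, toy_gelu(x)));
    }
    rows
}

/// A dishonest "GELU" claim: input -8, output 0 (which is ReLU's answer).
fn forged_claim_passes_untagged() -> bool {
    untagged_table().contains(&(-8, 0))
}

fn forged_claim_passes_tagged() -> bool {
    tagged_table().contains(&(GELU_TAG, -8, 0))
}
```

Without tags the forged claim is accepted because the ReLU rows supply `(-8, 0)`. With tags the lookup for `(GELU_TAG, -8, 0)` fails, while the honest row `(GELU_TAG, -8, -2)` is still present.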

LogUp Claimed Sum

Each LogUp component produces a claimed_sum: SecureField — the total of all fraction evaluations. The verifier checks this sum against the interaction column commitment. This value is part of the LayerClaim structure:

pub struct LayerClaim {
    pub layer_index: usize,
    pub claimed_sum: SecureField,
    pub trace_rows: usize,
}

Precomputed Tables

PrecomputedTable

The gadgets/lookup_table.rs module provides table construction:

pub struct PrecomputedTable {
    pub inputs: Vec<M31>,
    pub outputs: Vec<M31>,
    pub log_size: u32,
    index: Option<HashMap<u32, usize>>,  // O(1) lookup index, present for tables larger than 2^10 entries
}

impl PrecomputedTable {
    pub fn build(f: impl Fn(M31) -> M31, log_size: u32) -> Self { ... }
    pub fn build_parallel(f: impl Fn(M31) -> M31 + Sync, log_size: u32) -> Self { ... }
    pub fn from_pairs(pairs: Vec<(M31, M31)>, log_size: u32) -> Self { ... }
}
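A minimal sketch of what such a table builder might look like is given below. It is not the stwo-ml implementation: `ToyTable` is a hypothetical name, `u32` stands in for `M31`, inputs are assumed to be enumerated as `0..2^log_size`, and the rule "build the index only when `log_size > 10`" is one reading of the comment on the `index` field.

```rust
use std::collections::HashMap;

const P: u32 = (1 << 31) - 1; // M31 modulus

/// Hypothetical stand-in for `PrecomputedTable`, with u32 in place of M31.
struct ToyTable {
    inputs: Vec<u32>,
    outputs: Vec<u32>,
    log_size: u32,
    index: Option<HashMap<u32, usize>>,
}

impl ToyTable {
    /// Evaluates `f` on every input in 0..2^log_size (assumed enumeration).
    fn build(f: impl Fn(u32) -> u32, log_size: u32) -> Self {
        let size = 1usize << log_size;
        let inputs: Vec<u32> = (0..size as u32).collect();
        let outputs: Vec<u32> = inputs.iter().map(|&x| f(x) % P).collect();
        // Assumed threshold: small tables are scanned linearly,
        // larger ones get a hash index for O(1) row lookup.
        let index = if log_size > 10 {
            let map: HashMap<u32, usize> = inputs
                .iter()
                .enumerate()
                .map(|(row, &x)| (x, row))
                .collect();
            Some(map)
        } else {
            None
        };
        ToyTable { inputs: inputs, outputs: outputs, log_size: log_size, index: index }
    }

    /// Returns the table output for `input`, or None if it is out of range.
    fn lookup(&self, input: u32) -> Option<u32> {
        let row = match &self.index {
            Some(index) => index.get(&input).copied(),
            None => self.inputs.iter().position(|&x| x == input),
        };
        row.map(|r| self.outputs[r])
    }
}
```

For example, `ToyTable::build(|x| x * 2, 4).lookup(3)` yields `Some(6)`, and an input outside the 16-row table yields `None`.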

Production Table Sizes

Operation        Bits  Table Size    Notes
ReLU             16    65K entries   Exact: max(0, x)
GELU             18    262K entries  Approximate: x · Φ(x)
Sigmoid          16    65K entries   1 / (1 + e^(-x))
Softmax exp      20    1M entries    e^x normalized
LayerNorm rsqrt  16    65K entries   1/√(x + ε)
Quantize INT8    8     256 entries   Scale + zero-point
Dequantize INT4  4     16 entries    Inverse mapping

Trace Generation

Each component generates its trace columns using a generic pattern over the STWO backend:

pub fn generate_activation_trace<B: ColumnOps<BaseField>>(
    inputs: &[M31],
    outputs: &[M31],
    multiplicities: &[M31],
    log_size: u32,
) -> Vec<CircleEvaluation<B, BaseField, BitReversedOrder>> {
    let size = 1usize << log_size;
    let domain = CanonicCoset::new(log_size).circle_domain();

    let mut col_input = Col::<B, BaseField>::zeros(size);
    let mut col_output = Col::<B, BaseField>::zeros(size);
    let mut col_mult = Col::<B, BaseField>::zeros(size);

    for (i, (&inp, &out)) in inputs.iter().zip(outputs).enumerate().take(size) {
        col_input.set(i, inp);
        col_output.set(i, out);
        col_mult.set(i, multiplicities[i]);
    }

    vec![
        CircleEvaluation::new(domain, col_input),
        CircleEvaluation::new(domain, col_output),
        CircleEvaluation::new(domain, col_mult),
    ]
}

Generic over Backend B:

  • SimdBackend — SIMD CPU (default for trace generation)
  • CpuBackend — Scalar CPU
  • GpuBackend — NVIDIA CUDA (requires cuda-runtime feature)
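The `multiplicities` slice passed to the trace generator has to be computed before the trace is built. One plausible way to do that, sketched here with `u32` standing in for `M31` and a hypothetical helper name `count_multiplicities`, is to count how many times each table row is hit by the trace inputs:

```rust
use std::collections::HashMap;

/// Counts how often each table row is looked up by the trace.
/// Returns None if some trace input is missing from the table,
/// since such a trace could not satisfy the lookup argument.
fn count_multiplicities(table_inputs: &[u32], trace_inputs: &[u32]) -> Option<Vec<u32>> {
    // Map each table input to its row index.
    let index: HashMap<u32, usize> = table_inputs
        .iter()
        .enumerate()
        .map(|(row, &x)| (x, row))
        .collect();

    // One counter per table row, so the result lines up with the table.
    let mut mults = vec![0u32; table_inputs.len()];
    for x in trace_inputs {
        let row = *index.get(x)?;
        mults[row] += 1;
    }
    Some(mults)
}
```

With table inputs `[0, 1, 2, 3]` and trace inputs `[1, 1, 3]`, the result is `Some(vec![0, 2, 0, 1])`: row 1 is used twice, row 3 once, and the unused rows keep multiplicity zero.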

Unified STARK Aggregation

Single prove() Call

All non-matmul components are combined into one STARK proof:

let stark_proof = prove::<B, MC>(
    &[
        &activation_comp,   // FrameworkComponent<ActivationEval>
        &add_comp,          // FrameworkComponent<ElementwiseAddEval>
        &mul_comp,          // FrameworkComponent<ElementwiseMulEval>
        &layernorm_comp,    // FrameworkComponent<LayerNormEval>
        &rmsnorm_comp,      // FrameworkComponent<RMSNormEval>
        &embedding_comp,    // FrameworkComponent<EmbeddingEval>
        &quantize_comp,     // FrameworkComponent<QuantizeEval>
        &dequantize_comp,   // FrameworkComponent<DequantizeEval>
    ],
    &mut channel,
    commitment_scheme,
)?;

Prover Setup

let config = PcsConfig::default();
let twiddles = B::precompute_twiddles(
    CanonicCoset::new(max_degree_bound + config.fri_config.log_blowup_factor)
        .circle_domain()
        .half_coset,
);

let mut commitment_scheme = CommitmentSchemeProver::<B, MC>::new(config, &twiddles);

// Tree 0: Preprocessed (activation tables, layernorm tables, etc.)
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(preproc_evals);
tree_builder.commit(channel);

// Tree 1: Execution trace (all component traces)
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(convert_evaluations::<SimdBackend, B, BaseField>(trace_evals));
tree_builder.commit(channel);

// Tree 2: Created automatically during prove()

Trait Bounds

All components must implement ComponentProver<B> for the chosen backend:

where
    B: BackendForChannel<MC> + PolyOps,
    MC: MerkleChannel,
    FrameworkComponent<ActivationEval>: ComponentProver<B>,
    FrameworkComponent<ElementwiseAddEval>: ComponentProver<B>,
    FrameworkComponent<LayerNormEval>: ComponentProver<B>,
    // ... all component types

Heterogeneous Component Storage

Components have different concrete types but share a common trait. The ComponentProverErased<B> trait enables heterogeneous storage as Vec<Box<dyn ComponentProverErased<B>>>:

trait ComponentProverErased<B: Backend> {
    fn prove(&self, ...) -> StarkProof;
}

// Blanket impl for all FrameworkComponent<E>
impl<B, E: FrameworkEval> ComponentProverErased<B> for FrameworkComponent<E>
where FrameworkComponent<E>: ComponentProver<B> { ... }

Similarly, ComponentRefErased enables &dyn Component references from FrameworkComponent<E> for verification.
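The erasure pattern itself is ordinary Rust and can be sketched without STWO. In the toy below all names (`ToyEval`, `ToyComponent`, `ToyErased`) are hypothetical: a generic wrapper plays the role of `FrameworkComponent<E>`, and a blanket impl of an object-safe trait lets differently typed components live in one `Vec`.

```rust
/// Hypothetical per-component evaluator trait (stand-in for `FrameworkEval`).
trait ToyEval {
    fn name(&self) -> &'static str;
    fn log_size(&self) -> u32;
}

/// Generic wrapper, playing the role of `FrameworkComponent<E>`.
struct ToyComponent<E: ToyEval>(E);

/// Object-safe trait: no generic methods, so it can be used as `dyn`.
trait ToyErased {
    fn describe(&self) -> String;
}

/// Blanket impl: every wrapped evaluator is usable through the erased trait.
impl<E: ToyEval> ToyErased for ToyComponent<E> {
    fn describe(&self) -> String {
        format!("{} (2^{} rows)", self.0.name(), self.0.log_size())
    }
}

struct AddEval;
struct ActivationEval {
    log_size: u32,
}

impl ToyEval for AddEval {
    fn name(&self) -> &'static str { "add" }
    fn log_size(&self) -> u32 { 4 }
}

impl ToyEval for ActivationEval {
    fn name(&self) -> &'static str { "activation" }
    fn log_size(&self) -> u32 { self.log_size }
}

/// Two different concrete types stored behind the same trait object.
fn example_components() -> Vec<Box<dyn ToyErased>> {
    vec![
        Box::new(ToyComponent(AddEval)),
        Box::new(ToyComponent(ActivationEval { log_size: 16 })),
    ]
}

/// Uniform processing over the heterogeneous list.
fn describe_all(components: &[Box<dyn ToyErased>]) -> Vec<String> {
    components.iter().map(|c| c.describe()).collect()
}
```

The design choice here is to keep the generic trait for static dispatch inside each component and expose only a small non-generic surface for dynamic dispatch.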

Aggregated Proof Structure

pub struct AggregatedModelProofFor<H: MerkleHasherLifted> {
    // Single STARK covering all non-matmul components
    pub unified_stark: Option<StarkProof<H>>,

    // Per-matmul sumcheck proofs (separate from STARK)
    pub matmul_proofs: Vec<(usize, MatMulSumcheckProof)>,

    // Claims verified inside the unified STARK
    pub activation_claims: Vec<LayerClaim>,
    pub add_claims: Vec<LayerClaim>,
    pub mul_claims: Vec<LayerClaim>,
    pub layernorm_claims: Vec<LayerClaim>,
    pub rmsnorm_claims: Vec<LayerClaim>,
    pub embedding_claims: Vec<LayerClaim>,
    pub quantize_claims: Vec<LayerClaim>,
    pub dequantize_claims: Vec<LayerClaim>,

    // Attention proofs (composite: matmul + softmax)
    pub attention_proofs: Vec<(usize, AttentionProof<H>)>,

    // Commitments
    pub execution: GraphExecution,
    pub layer_chain_commitment: FieldElement,
    pub io_commitment: FieldElement,
    pub layernorm_mean_var_commitments: Vec<FieldElement>,
    pub quantize_params_commitment: FieldElement,
}

Verification Pipeline

Off-Chain (Rust)

verify_aggregated_model_proof(proof, graph, input, weights)
  1. Re-run forward pass to reconstruct expected outputs
  2. Verify unified STARK (Trees 0/1/2 commitments + FRI)
  3. Per-matmul: verify sumcheck proofs independently
  4. Check all commitments (IO, layer chain, mean/var)

On-Chain (Cairo)

The Cairo verifier receives serialized proof data and verifies in three stages:

1. STARK.verify() on unified proof
   ├── Check Tree 0 preprocessed column root
   ├── Check Tree 1 execution trace root
   ├── Check Tree 2 LogUp interaction root
   ├── Verify FRI (low-degree test)
   └── Check constraint evaluation at random points

2. Per-MatMul: sumcheck.verify()
   ├── Check claimed_sum via sumcheck rounds
   └── Verify final A×B evaluation

3. Commitment checks (Poseidon)
   ├── IO commitment (input || output hash)
   ├── Layer chain (running hash of intermediates)
   ├── LayerNorm mean/var per-layer
   └── Quantize params

Cairo Serialization Format

The cairo_serde.rs module serializes proofs for on-chain consumption:

Type Mapping

Rust Type    Cairo Type   felt252 Count
M31          M31          1
QM31         QM31         4 (a, b, c, d components)
Blake2sHash  Blake2sHash  8 (8 × u32 little-endian)
u32          u32          1
Vec<T>       Span<T>      1 (length) + N × size(T)
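The mapping above can be sketched as a small serialization trait. This is not the `cairo_serde.rs` code: `ToFelts`, `serialize`, and the local `M31`, `QM31`, and `Blake2sHash` structs are hypothetical stand-ins, `u128` is used in place of `felt252`, and "8 × u32 little-endian" is read here as eight words each decoded from four little-endian bytes.

```rust
/// Stand-in for felt252; u128 is wide enough for these example values.
type Felt = u128;

/// Hypothetical serialization trait: append this value's felts to `out`.
trait ToFelts {
    fn to_felts(&self, out: &mut Vec<Felt>);
}

struct M31(u32);
struct QM31([u32; 4]);
struct Blake2sHash([u8; 32]);

impl ToFelts for u32 {
    fn to_felts(&self, out: &mut Vec<Felt>) {
        out.push(*self as Felt); // 1 felt
    }
}

impl ToFelts for M31 {
    fn to_felts(&self, out: &mut Vec<Felt>) {
        out.push(self.0 as Felt); // 1 felt
    }
}

impl ToFelts for QM31 {
    fn to_felts(&self, out: &mut Vec<Felt>) {
        // 4 felts: the a, b, c, d components in order
        for c in self.0.iter() {
            out.push(*c as Felt);
        }
    }
}

impl ToFelts for Blake2sHash {
    fn to_felts(&self, out: &mut Vec<Felt>) {
        // 8 felts: one per 4-byte little-endian word
        for w in self.0.chunks(4) {
            out.push(u32::from_le_bytes([w[0], w[1], w[2], w[3]]) as Felt);
        }
    }
}

impl<T: ToFelts> ToFelts for Vec<T> {
    fn to_felts(&self, out: &mut Vec<Felt>) {
        // Length prefix, then each element back to back
        out.push(self.len() as Felt);
        for item in self.iter() {
            item.to_felts(out);
        }
    }
}

/// Convenience wrapper returning a fresh felt vector.
fn serialize<T: ToFelts>(value: &T) -> Vec<Felt> {
    let mut out = Vec::new();
    value.to_felts(&mut out);
    out
}
```

Serializing a `Vec` of two `QM31` values therefore yields 1 + 2 × 4 = 9 felts, with the length first.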

MatMul Proof Serialization

[m: 1] [k: 1] [n: 1] [num_rounds: 1]
[claimed_sum: 4]
[round_polys_len: 1] [round_poly × num_rounds: 12 each]
[final_a_eval: 4] [final_b_eval: 4]
[a_commitment: 1] [b_commitment: 1]
[a_opening: variable] [b_opening: variable]

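Total: the fixed-size fields account for 19 felt252 values, the round polynomials add 12 × num_rounds, and the two openings are variable-length; overall roughly 50 + 12 × num_rounds felt252 values per matmul.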
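The size arithmetic for this layout can be written out directly. The helper below is a hypothetical calculator, not part of the serializer; the opening sizes are passed in as parameters because the layout lists them as variable.

```rust
/// Counts felt252 values in one serialized matmul proof, following the
/// field layout listed above. Opening sizes are supplied by the caller.
fn matmul_proof_felts(num_rounds: usize, a_opening: usize, b_opening: usize) -> usize {
    let header = 4; // m, k, n, num_rounds
    let claimed_sum = 4; // one QM31
    let round_polys = 1 + 12 * num_rounds; // length prefix + 12 felts per round
    let final_evals = 4 + 4; // final_a_eval, final_b_eval
    let commitments = 1 + 1; // a_commitment, b_commitment
    header + claimed_sum + round_polys + final_evals + commitments + a_opening + b_opening
}
```

With 10 rounds and the openings left out, this gives 19 + 120 = 139 felts; the openings account for the remainder of the estimate.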

Quantization Components

Strategy

pub enum QuantStrategy {
    Direct,       // Direct clamping: [0, P-1]
    Symmetric8,   // INT8: [-127, 127] → [0, 254]
    Asymmetric8,  // INT8 with zero-point
    Symmetric4,   // INT4: [-7, 7] → [0, 14]
    Asymmetric4,  // INT4 with zero-point
}

Proving

  1. Observe trace inputs (f32 values encoded as M31)
  2. Apply quantize_value() → M31 outputs
  3. Build a two-column lookup table: (input_m31, output_m31)
  4. Prove via LogUp: each (trace_input, trace_output) exists in table
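Step 2 can be sketched for the symmetric strategies. The real `quantize_value()` signature is not shown in this document, so `quantize_symmetric` and `dequantize_symmetric` below are hypothetical; the `scale` parameter and round-to-nearest behavior are assumptions. The output ranges follow the enum comments: `[-127, 127] → [0, 254]` for INT8 and `[-7, 7] → [0, 14]` for INT4.

```rust
/// Symmetric quantization sketch: scale, round, clamp to [-qmax, qmax],
/// then shift by qmax so the result is a non-negative field-friendly value.
/// qmax = 127 models Symmetric8, qmax = 7 models Symmetric4.
fn quantize_symmetric(x: f32, scale: f32, qmax: i32) -> u32 {
    let q = (x / scale)
        .round()
        .max(-(qmax as f32))
        .min(qmax as f32) as i32;
    (q + qmax) as u32
}

/// Inverse mapping: undo the shift, then rescale.
fn dequantize_symmetric(q: u32, scale: f32, qmax: i32) -> f32 {
    (q as i32 - qmax) as f32 * scale
}
```

For instance, with `scale = 0.5` and `qmax = 127`, the input `1.0` maps to `129`, and out-of-range inputs saturate at `0` or `254`. The (input, output) pairs produced this way are what the two-column table in step 3 would contain.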

Dequantize

For INT4 models, dequantization maps 16 possible quantized values to their M31 equivalents. The small table (16 entries) makes LogUp verification very efficient.

Receipt Proof System

The receipt component provides billing arithmetic verification:

Constraints:
  time_billing × 1000 == gpu_time_ms × rate_per_sec
  token_billing == token_count × rate_per_token
  billing_total == time_billing + token_billing

7 execution columns, degree-2 constraints, no LogUp. Receipts are chain-linked via Poseidon hashes with TEE freshness validation (MAX_TEE_AGE_SECS = 3600).
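The three billing constraints can be checked on a single row with M31 arithmetic. This is a sketch, not the receipt.rs component: `ReceiptRow` and its field order are hypothetical, chosen to give the seven columns named in the constraints. Because the checks hold modulo 2^31 - 1, the values are assumed to be small enough that no wrap-around occurs.

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus

fn add(a: u64, b: u64) -> u64 { (a % P + b % P) % P }
fn sub(a: u64, b: u64) -> u64 { (a % P + P - b % P) % P }
fn mul(a: u64, b: u64) -> u64 { (a % P) * (b % P) % P }

/// One receipt row: seven columns, as in the component table.
struct ReceiptRow {
    gpu_time_ms: u64,
    rate_per_sec: u64,
    time_billing: u64,
    token_count: u64,
    rate_per_token: u64,
    token_billing: u64,
    billing_total: u64,
}

/// Evaluates each constraint as "lhs - rhs"; all three must be zero.
fn receipt_constraints(r: &ReceiptRow) -> [u64; 3] {
    [
        // time_billing × 1000 == gpu_time_ms × rate_per_sec
        sub(mul(r.time_billing, 1000), mul(r.gpu_time_ms, r.rate_per_sec)),
        // token_billing == token_count × rate_per_token
        sub(r.token_billing, mul(r.token_count, r.rate_per_token)),
        // billing_total == time_billing + token_billing
        sub(r.billing_total, add(r.time_billing, r.token_billing)),
    ]
}

fn receipt_is_valid(r: &ReceiptRow) -> bool {
    receipt_constraints(r).iter().all(|&c| c == 0)
}
```

As a worked case: 2000 ms at a rate of 5 per second gives `time_billing = 10`, 300 tokens at 2 per token gives `token_billing = 600`, so `billing_total = 610` satisfies all three constraints, while 611 violates the third.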

Key Lessons

Always commit Tree 0: Even when empty (no preprocessed columns), the tree must be committed. STWO's verifier expects a fixed tree count.

finalize_logup_in_pairs(): Required (not finalize_logup()) when exactly 2 add_to_relation calls are made per evaluation.

Type tags prevent cross-activation forgery: Without tags, a ReLU table entry could satisfy a GELU lookup if both share the same relation challenges.

Audit finding: simd_combined field was not serialized in the Cairo serialization module, causing the Cairo verifier to reject SIMD-batched proofs. Fixed by adding explicit serialization before the FRI proof.

GPU bridge for proving: SimdToGpuTreeBuilder adapter enables zero-copy conversion between SimdBackend and GpuBackend column types via unsafe { transmute() } — they share identical memory layout.