Tile-Level Streaming Architecture

Double-buffered tile pipeline with precomputed matmul injection. Eliminates weight re-loading and overlaps I/O with proving.

Problem

The original prove_model_chunked_streaming_tiled had a critical inefficiency. While Phases 1 and 2 used tile-level streaming (loading weight tiles on demand from mmap), Phase 3 fell back to pipeline.load_chunk_weights() and called the monolithic prove_model_aggregated_onchain. This re-loaded full weight matrices and re-proved all matmuls — negating the tile-level memory savings.

BEFORE (wasteful):
  Phase 1: tile-by-tile forward pass    → A, C per matmul  (good)
  Phase 2: tile-by-tile proving          → TiledProofs      (good)
  Phase 3: load_chunk_weights()          → full B matrices   (BAD — re-loads everything)
           prove_model_aggregated_onchain → re-proves matmuls (BAD — duplicated work)

Solution: PrecomputedMatmuls Injection

The aggregation pipeline (prove_model_aggregated_onchain_with_cache) was refactored to accept an optional PrecomputedMatmuls struct that provides pre-computed matmul outputs and proofs. When present:

  • Phase 1 (forward pass): Uses precomputed.outputs[node_id] instead of weights.get_weight() + matmul_m31().
  • Phase 2 (proving): Skipped entirely — proofs come from precomputed.proofs and precomputed.tiled_proofs.
  • Phase 3 (STARK): Runs normally for non-matmul components (activations, add, mul, layernorm, etc.).

AFTER (efficient):
  Phase 1: tile-by-tile forward pass    → A, C per matmul
  Phase 2: tile-by-tile proving          → TiledProofs
  Phase 3: PrecomputedMatmuls injected   → NO weight loading
           aggregation pipeline           → NO matmul re-proving
                                          → Only STARK for non-matmul components
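The Phase 1 change amounts to a lookup-before-compute dispatch on an optional struct. The sketch below is a minimal illustration, not the aggregation code. `Precomputed`, `Matrix` and `matmul_output` are hypothetical stand-ins for PrecomputedMatmuls, M31Matrix and the forward-pass step, and the `compute` closure stands in for weights.get_weight() + matmul_m31().

```rust
use std::collections::HashMap;

/// Toy stand-in for M31Matrix: row-major u32 entries (assumption).
type Matrix = Vec<Vec<u32>>;

/// Hypothetical, trimmed-down PrecomputedMatmuls holding only the outputs map.
struct Precomputed {
    outputs: HashMap<usize, Matrix>,
}

/// Phase 1 lookup: reuse the injected C when one exists for this node,
/// otherwise fall back to `compute` (standing in for get_weight + matmul_m31).
fn matmul_output<F>(node_id: usize, precomputed: Option<&Precomputed>, compute: F) -> Matrix
where
    F: FnOnce() -> Matrix,
{
    match precomputed.and_then(|p| p.outputs.get(&node_id)) {
        Some(c) => c.clone(), // injected: no weight lookup, no matmul
        None => compute(),    // legacy path (None) or node not injected
    }
}
```

With `Some(&precomputed)` the closure is never invoked, so no weight bytes are touched. Passing `None` reproduces the original behaviour.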

Streaming Weight Pipeline

The StreamingWeightPipeline manages memory-mapped model weight files, loading only the data needed for the current operation:

pub struct StreamingWeightPipeline {
    shards: Vec<ShardHandle>,                   // mmap'd SafeTensors files
    tensor_to_shard: HashMap<String, usize>,    // tensor name → shard index
    name_map: HashMap<usize, String>,           // node_id → tensor name
    strategy: QuantStrategy,
    graph: ComputationGraph,
}

Tensor Layout Detection

Weight tensors in SafeTensors files can be stored in either (k, n) or (n, k) layout. The pipeline detects the layout from the tensor shape. A square tensor (k == n) is ambiguous by shape alone and is treated as StoredKN:

fn detect_layout(shape: &[usize], k: usize, n: usize) -> TensorLayout {
    if shape == [n, k] && k != n {
        TensorLayout::StoredNK   // needs transpose on extraction
    } else {
        TensorLayout::StoredKN   // native layout, contiguous tiles
    }
}

StoredKN (native): Tile rows [k_start..k_end] are contiguous in memory — sequential mmap reads with excellent cache locality. Byte range: k_start × n × elem_size .. k_end × n × elem_size.
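The StoredKN byte-range formula maps directly to a slice of the mapped bytes. The sketch below shows this under the assumption that `data` is the tensor's own byte region (row-major, `elem_size` bytes per element). The helper names are hypothetical.

```rust
use std::ops::Range;

/// Byte range of tile rows [k_start, k_end) in a row-major (k, n) tensor:
/// k_start × n × elem_size .. k_end × n × elem_size.
fn kn_tile_byte_range(k_start: usize, k_end: usize, n: usize, elem_size: usize) -> Range<usize> {
    (k_start * n * elem_size)..(k_end * n * elem_size)
}

/// Borrow the tile straight out of the mapped bytes: one contiguous slice, no copy.
fn kn_tile_bytes(data: &[u8], k_start: usize, k_end: usize, n: usize, elem_size: usize) -> &[u8] {
    &data[kn_tile_byte_range(k_start, k_end, n, elem_size)]
}
```

Because the range is contiguous, the OS sees a single sequential read per tile, which is what gives StoredKN its cache and readahead advantage.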

StoredNK (transposed): Columns are scattered across rows. Requires extracting an (n, tile_k) sub-matrix from scattered reads, then transposing to (tile_k, n) via a cache-friendly 64×64 block transpose algorithm.
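A StoredNK extraction can be sketched as follows. It operates on already-decoded f32 values for brevity, which is an assumption, since the real pipeline reads raw bytes from the mmap. `extract_nk_tile` and `block_transpose` are hypothetical names. The 64×64 block size matches the text.

```rust
const BLOCK: usize = 64;

/// Transpose a row-major (rows, cols) matrix into (cols, rows), visiting it in
/// BLOCK×BLOCK blocks so both source and destination stay cache-resident.
fn block_transpose(src: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(src.len(), rows * cols);
    let mut dst = vec![0.0f32; rows * cols];
    for rb in (0..rows).step_by(BLOCK) {
        for cb in (0..cols).step_by(BLOCK) {
            for r in rb..(rb + BLOCK).min(rows) {
                for c in cb..(cb + BLOCK).min(cols) {
                    dst[c * rows + r] = src[r * cols + c];
                }
            }
        }
    }
    dst
}

/// Gather columns [k_start, k_end) of a row-major (n, k) tensor into an
/// (n, tile_k) matrix (one short strided read per row), then transpose to (tile_k, n).
fn extract_nk_tile(src: &[f32], n: usize, k: usize, k_start: usize, k_end: usize) -> Vec<f32> {
    let tile_k = k_end - k_start;
    let mut sub = Vec::with_capacity(n * tile_k);
    for row in 0..n {
        sub.extend_from_slice(&src[row * k + k_start..row * k + k_end]);
    }
    block_transpose(&sub, n, tile_k)
}
```

The gather step touches every one of the n rows, which is why StoredNK tiles cost n scattered reads instead of one contiguous range.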

Min/Max Scanning

Quantization parameters require global min/max of each weight tensor. The pipeline scans the mmap'd bytes without allocating a full f32 buffer:

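pub fn scan_tensor_minmax(data: &[u8], dtype: DType) -> (f64, f64) {
    let elem_size = dtype.size(); // bytes per element (assumed DType helper)
    let mut min_val = f64::INFINITY;
    let mut max_val = f64::NEG_INFINITY;
    for chunk in data.chunks_exact(elem_size) {
        let v = bytes_to_f32_single(chunk, dtype) as f64;
        min_val = min_val.min(v);
        max_val = max_val.max(v);
    }
    (min_val, max_val)
}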

Zero-allocation scan: touches each byte exactly once, no intermediate buffers.
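The following self-contained version of the scan makes the per-element decoding concrete. It assumes a two-variant `DType` (F32, BF16) and a local `bytes_to_f32_single`. Both are stand-ins for the project's own types, and real SafeTensors files support more dtypes. SafeTensors stores tensor data little-endian, and bf16 is the upper 16 bits of an f32.

```rust
/// Two-variant stand-in for the pipeline's DType (assumption).
#[derive(Clone, Copy)]
enum DType {
    F32,
    BF16,
}

fn elem_size(dtype: DType) -> usize {
    match dtype {
        DType::F32 => 4,
        DType::BF16 => 2,
    }
}

/// Decode one little-endian element to f32.
fn bytes_to_f32_single(chunk: &[u8], dtype: DType) -> f32 {
    match dtype {
        DType::F32 => f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]),
        // bf16 keeps the top 16 bits of an f32, so shift it back into place.
        DType::BF16 => f32::from_bits((u16::from_le_bytes([chunk[0], chunk[1]]) as u32) << 16),
    }
}

/// Single pass over the mapped bytes; no f32 buffer is allocated.
fn scan_tensor_minmax(data: &[u8], dtype: DType) -> (f64, f64) {
    let mut min_val = f64::INFINITY;
    let mut max_val = f64::NEG_INFINITY;
    for chunk in data.chunks_exact(elem_size(dtype)) {
        let v = bytes_to_f32_single(chunk, dtype) as f64;
        min_val = min_val.min(v);
        max_val = max_val.max(v);
    }
    (min_val, max_val)
}
```

Starting from (+∞, −∞) means an empty tensor returns min > max, which callers can detect. NaN inputs are skipped by `f64::min`/`f64::max`.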

Tile Loading

pub fn load_weight_tile(
    &self, node_id: usize, k_start: usize, k_end: usize, params: &QuantParams
) -> M31Matrix

Pipeline: resolve tensor metadata → detect layout → extract tile f32s → quantize to M31.

Memory: Only tile_k × n × 4 bytes (f32) + tile_k × n × 4 bytes (M31) are allocated.

For Qwen3-14B (k=5120, tile_k=1024, n=14336): 56 MB per tile vs 280 MB full weight.
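The arithmetic behind these figures can be checked directly. Here MB means 2^20 bytes, and the quoted 56 MB counts the f32 buffer alone, so the M31 copy doubles the per-tile total. The helper names are illustrative only.

```rust
const MIB: usize = 1 << 20;

/// Bytes held for one weight tile: the f32 staging buffer plus the M31 buffer,
/// 4 bytes per element each.
fn tile_buffers_bytes(tile_k: usize, n: usize) -> (usize, usize) {
    (tile_k * n * 4, tile_k * n * 4)
}

/// Number of tiles needed to cover k rows (the last tile may be shorter).
fn num_tiles(k: usize, tile_k: usize) -> usize {
    (k + tile_k - 1) / tile_k
}
```

For k=5120 and tile_k=1024 this gives 5 tiles of 56 MB (f32) each, versus 280 MB for the full matrix.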

Prefetch Strategy

The pipeline uses madvise(MADV_WILLNEED) to hint the OS to preload mmap pages before they're needed:

  • StoredKN: Prefetch contiguous rows [k_start..k_end] — page-aligned boundaries (4096-byte alignment)
  • StoredNK: Prefetch entire tensor (columns scattered across all rows)

For the double-buffered pipeline, prefetch is called for tile N+1 while tile N is being proven.
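madvise(2) rejects a start address that is not page-aligned, so the prefetch range must be widened to page boundaries. The sketch below assumes offsets are measured from the start of a mapping whose base is page-aligned (mmap returns page-aligned addresses), and that the page size is 4096 bytes as stated above. The resulting (offset, len) would then be passed to madvise with MADV_WILLNEED. The helper names are hypothetical.

```rust
const PAGE: usize = 4096;

/// Widen the byte range [start, end) to whole 4096-byte pages.
/// Returns (aligned_offset, aligned_len) relative to the mapping base.
fn page_aligned(start: usize, end: usize) -> (usize, usize) {
    let aligned_start = start & !(PAGE - 1);
    let aligned_end = (end + PAGE - 1) & !(PAGE - 1);
    (aligned_start, aligned_end - aligned_start)
}

/// Prefetch range for the next StoredKN tile, computed while the current tile is in use.
fn next_tile_prefetch(next_k_start: usize, next_k_end: usize, n: usize, elem_size: usize) -> (usize, usize) {
    page_aligned(next_k_start * n * elem_size, next_k_end * n * elem_size)
}
```

Rounding the start down and the end up can only enlarge the hinted range, so no byte of the tile is left out of the prefetch.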

Key Types

PrecomputedMatmuls (aggregation.rs)

pub(crate) struct PrecomputedMatmuls {
    /// Pre-computed matmul output matrices (node_id -> C matrix).
    /// Phase 1 uses these instead of weights.get_weight() + matmul_m31().
    pub outputs: HashMap<usize, M31Matrix>,

    /// Pre-composed single-tile matmul proofs (node_id, proof).
    /// Phase 2 is skipped entirely for these.
    pub proofs: Vec<(usize, MatMulSumcheckProofOnChain)>,

    /// Multi-tile proofs that can't be composed into a single proof.
    pub tiled_proofs: Vec<(usize, TiledMatMulProof)>,
}

AggregatedModelProofOnChain — new field

pub struct AggregatedModelProofOnChain {
    // ... existing fields ...

    /// Multi-tile matmul proofs that couldn't be composed into a single proof.
    /// Present when tile-level streaming is used with multi-tile matmuls.
    pub tiled_matmul_proofs: Vec<(usize, TiledMatMulProof)>,

    // ... existing fields ...
}

API

prove_model_aggregated_onchain_with_precomputed

pub(crate) fn prove_model_aggregated_onchain_with_precomputed(
    graph: &ComputationGraph,
    input: &M31Matrix,
    weights: &GraphWeights,  // Can be empty — weights not needed
    precomputed: PrecomputedMatmuls,
) -> Result<AggregatedModelProofOnChain, AggregationError>

Thin wrapper around prove_model_aggregated_onchain_with_cache that passes Some(precomputed).

Data Flow

prove_model_chunked_streaming_tiled:

  ┌─────────────────────────────────────────────────────────┐
  │ Phase 1: Forward Pass (tile-level streaming)            │
  │                                                         │
  │   for each matmul node:                                 │
  │     pipeline.forward_matmul_tiled(node_id, A, config)   │
  │       → double-buffered: loads tile N+1 while computing │
  │         matmul for tile N (std::thread::scope)          │
  │       → accumulates C = A × B tile by tile              │
  │       → peak mem: 2 × tile_k × n (current + next tile)  │
  │     stores (node_id, A, C) in chunk_matmul_data         │
  └─────────────────────────────────────────────────────────┘
                              │
                              ▼
  ┌─────────────────────────────────────────────────────────┐
  │ Phase 2: Prove (tile-level streaming)                   │
  │                                                         │
  │   for each matmul:                                      │
  │     pipeline.prove_matmul_tiled_streaming(...)          │
  │       → double-buffered: loads tile N+1 while proving   │
  │         tile N via sumcheck (std::thread::scope)        │
  │       → verify_tiled_matmul() sanity check              │
  │     stores (node_id, TiledMatMulProof)                  │
  └─────────────────────────────────────────────────────────┘
                              │
                              ▼
  ┌─────────────────────────────────────────────────────────┐
  │ Phase 3: Aggregate (precomputed injection)              │
  │                                                         │
  │   Build PrecomputedMatmuls:                             │
  │     outputs: node_id → C from Phase 1                   │
  │     proofs: single-tile → compose_tiled_proof()         │
  │     tiled_proofs: multi-tile (passed through as-is)     │
  │                                                         │
  │   prove_model_aggregated_onchain_with_precomputed(      │
  │     graph, input, GraphWeights::new(), precomputed      │
  │   )                                                     │
  │     → Phase 1: uses precomputed C (no weight lookup)    │
  │     → Phase 2: uses precomputed proofs (no re-proving)  │
  │     → Phase 3: builds STARK for activations/add/mul/... │
  └─────────────────────────────────────────────────────────┘

Double-Buffered Tile Pipeline

Both forward_matmul_tiled and prove_matmul_tiled_streaming use a double-buffered pipeline that overlaps tile loading with computation using std::thread::scope.

Timeline

Sequential (before):
  [load tile 0] [compute tile 0] [load tile 1] [compute tile 1] [load tile 2] ...
  |----- T0 ----|----- T1 -------|----- T2 ----|----- T3 -------|----- T4 ----|

Double-buffered (after):
  [load tile 0] [compute tile 0 ||||| load tile 1] [compute tile 1 ||||| load tile 2] ...
  |----- T0 ----|------------ T1 ----------------|-----------  T2 -------------------|
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                 I/O hidden behind compute — wall time reduced by load latency

How It Works

  1. Tile 0: Loaded synchronously (nothing to overlap with yet).
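  2. Remaining tiles: while the main thread computes or proves tile i, std::thread::scope spawns a loader thread for tile i+1. The loader performs the mmap read, f32 conversion, and M31 quantization. The last tile is processed without spawning a loader, since there is nothing left to prefetch.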
  3. Single-tile case: Falls through to a simple non-pipelined path (no thread overhead for 1 tile).
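The three steps can be captured in a generic loop. This is a minimal sketch, not the streaming.rs code. `double_buffered` is a hypothetical helper, `load` stands in for load_weight_tile, and `process` stands in for matmul_m31 or the sumcheck prover.

```rust
use std::thread;

/// Double-buffered tile loop: while tile i is processed on the calling thread,
/// a scoped thread loads tile i + 1.
fn double_buffered<T, R, L, P>(num_tiles: usize, load: L, mut process: P) -> Vec<R>
where
    T: Send,                        // loaded tiles cross from loader to caller
    L: Fn(usize) -> T + Sync,       // shared by reference with the loader thread
    P: FnMut(usize, &T) -> R,       // runs only on the calling thread
{
    let mut results = Vec::with_capacity(num_tiles);
    if num_tiles == 0 {
        return results;
    }
    let load_ref = &load;
    // Step 1: tile 0 is loaded synchronously; there is nothing to overlap with yet.
    let mut current = load(0);
    for i in 0..num_tiles {
        if i + 1 == num_tiles {
            // Last (or only) tile: nothing to prefetch, so no thread is spawned.
            results.push(process(i, &current));
            break;
        }
        // Step 2: overlap loading tile i + 1 with processing tile i.
        let next_idx = i + 1;
        let (r, next) = thread::scope(|s| {
            let loader = s.spawn(move || load_ref(next_idx));
            let r = process(i, &current);
            (r, loader.join().expect("loader thread panicked"))
        });
        results.push(r);
        current = next;
    }
    results
}
```

In the forward path, `process` would multiply the A slice by the B tile and accumulate into C. In the proving path, it would run the sumcheck for that tile.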

Implementation (streaming.rs)

// forward_matmul_tiled — overlaps matmul with next tile load
let (c_tile, next_tile_result) = std::thread::scope(|s| {
    let loader = s.spawn(|| {
        self.load_weight_tile(node_id, next_k_start, next_k_end, &params)
    });
    let c_tile = matmul_m31(&a_tile, &b_tile);
    let next = loader.join().expect("loader thread panicked");
    (c_tile, Some(next))
});

// prove_matmul_tiled_streaming — overlaps sumcheck proving with next tile load
let (proof_result, next_tile_result) = std::thread::scope(|s| {
    let loader = s.spawn(|| {
        self.load_weight_tile(node_id, next_k_start, next_k_end, &params)
    });
    let proof = prove_matmul_sumcheck_onchain_auto(&a_padded, &b_padded, &c_padded);
    let next = loader.join().expect("loader thread panicked");
    (proof, next)
});

Why std::thread::scope

  • Borrows &self: The loader thread needs &self to call load_weight_tile. Scoped threads allow borrowing from the enclosing stack frame without 'static bounds.
  • No Arc/Mutex overhead: The pipeline reference is shared read-only (mmap is Sync).
  • Panic safety: If the loader panics, join() returns Err, and .expect() re-raises it as a panic on the calling thread. Scoped threads are always joined before scope returns, so no loader can outlive the borrowed pipeline.

Latency Hiding

The load_weight_tile call performs:

  1. mmap page fault (kernel I/O, ~0.5-2ms per tile for NVMe)
  2. f32 conversion from dtype (bf16/f16 → f32, ~0.1ms per tile)
  3. M31 quantization (~0.2ms per tile)

For the proving path, prove_matmul_sumcheck_onchain_auto takes 50-500ms per tile (depending on dimensions). The ~1-3ms of tile loading is completely hidden behind the proving time — effectively free I/O.

For the forward path, matmul_m31 takes 1-50ms per tile. The loading is partially or fully hidden depending on tile size.
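A simple timing model makes the overlap explicit. It assumes constant per-tile load and compute times, which is an idealization. The example values (2 ms load, 50 ms or 1 ms compute) are picked from the ranges quoted above.

```rust
/// Wall time (ms) when each tile is loaded, then processed, strictly in sequence.
fn sequential_ms(tiles: u32, load: f64, compute: f64) -> f64 {
    tiles as f64 * (load + compute)
}

/// Wall time (ms) with double buffering: tile 0 loads up front, each of the next
/// tiles - 1 steps costs max(compute, load) because the two overlap, and the
/// last tile is processed with nothing left to load.
fn double_buffered_ms(tiles: u32, load: f64, compute: f64) -> f64 {
    if tiles == 0 {
        return 0.0;
    }
    load + (tiles - 1) as f64 * compute.max(load) + compute
}
```

With 5 tiles, 2 ms loads and 50 ms proving, this gives 260 ms sequential versus 252 ms double-buffered. Every load except tile 0's disappears. With 1 ms matmuls, the forward path becomes load-bound (15 ms versus 11 ms), which is the "partially hidden" case.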

Memory Impact

For a matmul with dimensions m × k × n and tile size tile_k:

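Stage          | Before                                       | After
---------------|----------------------------------------------|------------------------------
Forward pass   | k × n × 4 bytes (full weight)                | tile_k × n × 4 bytes (1 tile)
Proving        | k × n × 4 bytes (full weight, loaded twice)  | tile_k × n × 4 bytes (1 tile)
Aggregation    | k × n × 4 bytes (full weight, loaded AGAIN)  | 0 bytes (precomputed)

With double buffering, up to two tiles are resident at once (current + next), so the peak in the forward and proving stages is 2 × tile_k × n × 4 bytes.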

Example: Qwen3-14B matmul 5120×14336:

  • Full weight: 280 MB per matrix
  • Tile (tile_k=1024): 56 MB per tile
  • Aggregation weight load eliminated: 280 MB saved per matmul
  • Across 160 matmuls of this size: 160 × 280 MB ≈ 44 GB of redundant weight I/O eliminated

Proof Composition

Single-tile proofs (where TiledMatMulProof.tile_proofs.len() == 1) are composed into standard MatMulSumcheckProofOnChain via compose_tiled_proof(). This makes them indistinguishable from non-tiled proofs in the final AggregatedModelProofOnChain.

Multi-tile proofs (2+ tiles) cannot be composed into a single sumcheck proof because the Fiat-Shamir challenges differ per tile. These are stored in the tiled_matmul_proofs field for separate verification.
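The single-tile/multi-tile split used when building PrecomputedMatmuls can be sketched with generic proof types. This is a simplification. Here "composition" is modeled as unwrapping the lone tile proof, whereas compose_tiled_proof() may do more. `TiledProof` and `split_proofs` are hypothetical names.

```rust
/// Simplified stand-in for TiledMatMulProof (assumption: only the per-tile proofs).
struct TiledProof<P> {
    tile_proofs: Vec<P>,
}

/// Route single-tile proofs to the plain `proofs` list and pass multi-tile
/// proofs through unchanged for `tiled_proofs`.
fn split_proofs<P>(
    tiled: Vec<(usize, TiledProof<P>)>,
) -> (Vec<(usize, P)>, Vec<(usize, TiledProof<P>)>) {
    let mut single = Vec::new();
    let mut multi = Vec::new();
    for (node_id, mut proof) in tiled {
        if proof.tile_proofs.len() == 1 {
            // Exactly one tile: safe to unwrap into a standalone proof.
            single.push((node_id, proof.tile_proofs.pop().unwrap()));
        } else {
            multi.push((node_id, proof));
        }
    }
    (single, multi)
}
```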

Chunked Proving Pipeline

The tile-streaming architecture integrates with three chunked proving strategies:

Sequential Chunked (prove_model_chunked)

For memory-constrained environments. Processes one chunk at a time:

for each chunk:
  1. Load checkpoint (if resumable)
  2. Load chunk weights from mmap
  3. Prove chunk → ChunkProofResult
  4. Save checkpoint for resumability
  5. Carry output activation to next chunk
  6. Drop chunk weights (free memory)

Checkpoints are saved as chunk_{i}.json for crash recovery:

pub struct ChunkProofResult {
    pub chunk_index: usize,
    pub node_range: Range<usize>,
    pub proof: AggregatedModelProofOnChain,
    pub output: M31Matrix,       // output activation for next chunk
}
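Resuming reduces to finding the first chunk without a checkpoint. The sketch below assumes chunks are proven strictly in order, so the existing checkpoints form a prefix. The carried activation would come from the last checkpoint's `output`. `checkpoint_path` and `resume_from` are hypothetical helpers. In practice, `has_checkpoint` could be `|i| checkpoint_path(dir, i).exists()`.

```rust
use std::path::{Path, PathBuf};

/// Checkpoint file for chunk i, following the chunk_{i}.json naming.
fn checkpoint_path(dir: &Path, i: usize) -> PathBuf {
    dir.join(format!("chunk_{i}.json"))
}

/// Index of the first chunk that still needs proving
/// (num_chunks if every chunk already has a checkpoint).
fn resume_from(num_chunks: usize, has_checkpoint: impl Fn(usize) -> bool) -> usize {
    (0..num_chunks).find(|&i| !has_checkpoint(i)).unwrap_or(num_chunks)
}
```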

Parallel Chunked (prove_model_chunked_parallel)

Uses rayon for multi-core proving:

Phase 1 (sequential): Forward pass on CPU to compute all chunk inputs.

Phase 2 (parallel): Independent chunk proving via rayon::par_iter():

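use rayon::prelude::*;

// prove_model_aggregated_onchain returns a Result, so collect into
// Result<Vec<_>, _>: any chunk error fails the whole collection.
let proofs: Result<Vec<AggregatedModelProofOnChain>, AggregationError> = blocks
    .par_iter()
    .enumerate()
    .map(|(i, range)| {
        let sub_graph = graph.subgraph(range);
        let sub_weights = weights.subset(range);
        prove_model_aggregated_onchain(&sub_graph, &chunk_inputs[i], &sub_weights)
    })
    .collect();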

Different chunks touch different mmap pages — no contention between workers.

Multi-GPU Chunked (prove_model_chunked_multi_gpu)

For H100/A100 multi-GPU servers:

Phase 1: Sequential forward pass on CPU
Phase 2: Partition chunks across GPUs (by estimated memory + matmul count)
Phase 3: Parallel proving with device affinity via DeviceGuard RAII

Each worker thread pins to a specific GPU via DeviceGuard::new(device_id), which uses RAII to restore the previous device on drop:

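let results: Vec<_> = std::thread::scope(|s| {
    let handles: Vec<_> = assignments
        .iter()
        .map(|assignment| {
            s.spawn(move || {
                let _guard = DeviceGuard::new(assignment.device_id);
                // sub_graph, input, weights: this assignment's chunk data
                prove_model_aggregated_onchain_auto(&sub_graph, input, &weights)
                // _guard dropped here → restore previous device
            })
        })
        .collect();
    // Join every worker so the per-chunk proofs are returned, not discarded.
    handles.into_iter().map(|h| h.join().expect("GPU worker panicked")).collect()
});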
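The RAII restore pattern can be demonstrated without a GPU. The sketch models the "current device" as a thread-local value, mirroring the fact that CUDA's current device is a per-host-thread setting. This `DeviceGuard` is a hypothetical stand-in that does not call any CUDA API.

```rust
use std::cell::Cell;

thread_local! {
    // Stand-in for the per-thread current device a CUDA runtime tracks.
    static CURRENT_DEVICE: Cell<usize> = Cell::new(0);
}

fn current_device() -> usize {
    CURRENT_DEVICE.with(|d| d.get())
}

/// Hypothetical guard: switch device on creation, restore the previous one on drop.
struct DeviceGuard {
    previous: usize,
}

impl DeviceGuard {
    fn new(device_id: usize) -> Self {
        let previous = current_device();
        CURRENT_DEVICE.with(|d| d.set(device_id));
        DeviceGuard { previous }
    }
}

impl Drop for DeviceGuard {
    fn drop(&mut self) {
        CURRENT_DEVICE.with(|d| d.set(self.previous));
    }
}
```

Because restoration happens in `Drop`, it also runs if the proving call panics and the stack unwinds, and nested guards restore in reverse order.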

Proof Composition (compose_chunk_proofs)

After parallel proving, individual chunk proofs are composed into a single AggregatedModelProofOnChain:

  1. Validate: Total chunk nodes == graph nodes
  2. Remap IDs: local_id → (chunk_offset + local_id) for matmul proofs, attention proofs, etc.
  3. Re-run forward pass: Collect non-matmul layer data (activations, adds, muls, layernorms)
  4. Compute commitments: layer_chain_commitment, io_commitment
  5. Build unified STARK: Single STARK for all non-matmul components across all chunks
  6. Assemble: Remapped matmul proofs + new unified STARK → final proof

Cost: O(forward_pass) + O(build_unified_stark) — no weight re-loading or matmul re-proving.
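Steps 1 and 2 (validation and ID remapping) can be sketched as a prefix-sum offset over chunk sizes. `remap_ids` is a hypothetical helper, and proofs are kept generic.

```rust
/// Shift chunk-local node IDs to graph-global IDs (global = chunk_offset + local).
/// chunk_sizes[i] is the node count of chunk i; offsets are its prefix sums.
fn remap_ids<P>(
    chunks: Vec<Vec<(usize, P)>>,
    chunk_sizes: &[usize],
    graph_nodes: usize,
) -> Result<Vec<(usize, P)>, String> {
    // Step 1: one size per chunk, and the chunks must cover the graph exactly.
    if chunks.len() != chunk_sizes.len() {
        return Err(format!("{} chunks but {} sizes", chunks.len(), chunk_sizes.len()));
    }
    let total: usize = chunk_sizes.iter().sum();
    if total != graph_nodes {
        return Err(format!("chunks cover {total} nodes, graph has {graph_nodes}"));
    }
    // Step 2: add each chunk's starting offset to its local IDs.
    let mut offset = 0;
    let mut remapped = Vec::new();
    for (proofs, &size) in chunks.into_iter().zip(chunk_sizes) {
        for (local_id, proof) in proofs {
            remapped.push((offset + local_id, proof));
        }
        offset += size;
    }
    Ok(remapped)
}
```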

Memory Budget Estimation

The pipeline estimates memory requirements before execution:

// TileConfig: assumed name for the tiling config type
pub fn estimate_tile_memory(node_id: usize, tile_config: &TileConfig) -> usize {
    let tile_k = tile_config.tile_k;
    let (m, _k, n) = node_dimensions(node_id); // k is not needed for a single tile
    let m_pad = m.next_power_of_two();
    let tile_k_pad = tile_k.next_power_of_two();
    let n_pad = n.next_power_of_two();

    // B tile (f32) + A slice (f32) + C tile (f32) + padded copies (QM31 = 16 bytes)
    tile_k * n * 4 + m * tile_k * 4 + m * n * 4
        + (m_pad * tile_k_pad + tile_k_pad * n_pad + m_pad * n_pad) * 16
}

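Example budgets for Qwen3-14B (m=2048, k=5120, n=14336), evaluated from the formula above (MB = 2^20 bytes; padded dimensions m_pad=2048, n_pad=16384):

Tile size    | B tile | A slice | C tile | Padded (QM31) | Total
-------------|--------|---------|--------|---------------|-------
tile_k=256   | 14 MB  | 2 MB    | 112 MB | 584 MB        | 712 MB
tile_k=512   | 28 MB  | 4 MB    | 112 MB | 656 MB        | 800 MB
tile_k=1024  | 56 MB  | 8 MB    | 112 MB | 800 MB        | 976 MB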

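The padded copies dominate because QM31 elements are 16 bytes (4× M31) and dimensions are rounded up to powers of two. The m_pad × n_pad term alone is 2048 × 16384 × 16 bytes = 512 MB and does not shrink with tile_k, so smaller tiles reduce the total only slowly.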

Budget selection: The prover auto-selects tile_k to fit within TILED_MEMORY_BUDGET. The default of 64 GB (sized for H100) is large enough that most matmuls are proven untiled. On memory-constrained hardware, smaller tile sizes trade proving speed for memory.
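The estimate can be evaluated directly, and a selection rule can be layered on top of it. The selection policy below is a hypothetical assumption, not the prover's documented rule. It proves untiled when tile_k = k fits, and otherwise takes the largest power-of-two tile_k that fits. Sizes are in MB = 2^20 bytes.

```rust
const MIB: usize = 1 << 20;

/// Evaluate the estimate_tile_memory formula for an (m × tile_k) · (tile_k × n) tile.
fn estimate_tile_memory(m: usize, n: usize, tile_k: usize) -> usize {
    let m_pad = m.next_power_of_two();
    let tile_k_pad = tile_k.next_power_of_two();
    let n_pad = n.next_power_of_two();
    // Unpadded f32 buffers: B tile + A slice + C tile.
    let f32_bytes = tile_k * n * 4 + m * tile_k * 4 + m * n * 4;
    // Padded QM31 copies at 16 bytes per element.
    let qm31_bytes = (m_pad * tile_k_pad + tile_k_pad * n_pad + m_pad * n_pad) * 16;
    f32_bytes + qm31_bytes
}

/// Hypothetical policy: untiled if it fits, else the largest power-of-two
/// tile_k below k that fits; None if even tile_k = 1 exceeds the budget.
fn select_tile_k(m: usize, k: usize, n: usize, budget: usize) -> Option<usize> {
    if estimate_tile_memory(m, n, k) <= budget {
        return Some(k);
    }
    let mut tile_k = k.next_power_of_two() / 2;
    while tile_k >= 1 {
        if estimate_tile_memory(m, n, tile_k) <= budget {
            return Some(tile_k);
        }
        tile_k /= 2;
    }
    None
}
```

For the Qwen3-14B dimensions, a 64 GB budget selects tile_k = 5120 (untiled), a 1 GB budget selects 1024, and an 800 MB budget selects 512. Because the C tile and m_pad × n_pad terms are fixed, no tile size fits under roughly 624 MB.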

Backward Compatibility

All existing callers pass None for the new precomputed parameter, making this fully backward compatible. The new tiled_matmul_proofs field on AggregatedModelProofOnChain defaults to Vec::new() in all non-streaming paths.

Performance Summary

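Optimization                        | Before              | After                      | Improvement
------------------------------------|---------------------|----------------------------|----------------------
Resident weight memory per matmul   | 280 MB (full)       | 56 MB (tile)               | 5× less memory
Aggregation weight load             | 280 MB (re-loaded)  | 0 MB (precomputed)         | Eliminated
I/O latency                         | Sequential          | Double-buffered            | Hidden behind compute
Redundant weight I/O (160 matmuls)  | 44 GB               | 0 (precomputed injection)  | Eliminated
Crash recovery                      | Start over          | Checkpoint resume          | Resumable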