Tile-Level Streaming Architecture
Double-buffered tile pipeline with precomputed matmul injection. Eliminates weight re-loading and overlaps I/O with proving.
Problem
The original prove_model_chunked_streaming_tiled had a critical inefficiency. While Phases 1 and 2 used tile-level streaming (loading weight tiles on demand from mmap), Phase 3 fell back to pipeline.load_chunk_weights() and called the monolithic prove_model_aggregated_onchain. This re-loaded full weight matrices and re-proved all matmuls — negating the tile-level memory savings.
BEFORE (wasteful):
Phase 1: tile-by-tile forward pass → A, C per matmul (good)
Phase 2: tile-by-tile proving → TiledProofs (good)
Phase 3: load_chunk_weights() → full B matrices (BAD — re-loads everything)
prove_model_aggregated_onchain → re-proves matmuls (BAD — duplicated work)
Solution: PrecomputedMatmuls Injection
The aggregation pipeline (prove_model_aggregated_onchain_with_cache) was refactored to accept an optional PrecomputedMatmuls struct that provides pre-computed matmul outputs and proofs. When present:
- Phase 1 (forward pass): Uses `precomputed.outputs[node_id]` instead of `weights.get_weight()` + `matmul_m31()`.
- Phase 2 (proving): Skipped entirely — proofs come from `precomputed.proofs` and `precomputed.tiled_proofs`.
- Phase 3 (STARK): Runs normally for non-matmul components (activations, add, mul, layernorm, etc.).
AFTER (efficient):
Phase 1: tile-by-tile forward pass → A, C per matmul
Phase 2: tile-by-tile proving → TiledProofs
Phase 3: PrecomputedMatmuls injected → NO weight loading
aggregation pipeline → NO matmul re-proving
→ Only STARK for non-matmul components
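The Phase 1 dispatch can be sketched with toy matrix types standing in for `M31Matrix` and `GraphWeights` (the names `forward_matmul` and `matmul_toy` are illustrative, not the crate's API):

```rust
use std::collections::HashMap;

// Toy stand-in for M31Matrix; wrapping u32 arithmetic replaces field arithmetic.
type Matrix = Vec<Vec<u32>>;

fn matmul_toy(a: &Matrix, b: &Matrix) -> Matrix {
    let (m, k, n) = (a.len(), b.len(), b[0].len());
    let mut c = vec![vec![0u32; n]; m];
    for i in 0..m {
        for j in 0..n {
            for t in 0..k {
                c[i][j] = c[i][j].wrapping_add(a[i][t].wrapping_mul(b[t][j]));
            }
        }
    }
    c
}

/// Phase 1 dispatch: when a precomputed output exists for this node,
/// return it directly: no weight lookup, no matmul.
fn forward_matmul(
    node_id: usize,
    a: &Matrix,
    weights: &HashMap<usize, Matrix>,
    precomputed: Option<&HashMap<usize, Matrix>>,
) -> Matrix {
    if let Some(c) = precomputed.and_then(|p| p.get(&node_id)) {
        return c.clone(); // injected C from the tile-streaming phases
    }
    let b = weights.get(&node_id).expect("weight missing for node");
    matmul_toy(a, b)
}

fn main() {
    let a = vec![vec![1u32, 2], vec![3, 4]];
    let b = vec![vec![5u32, 6], vec![7, 8]];
    let weights = HashMap::from([(0usize, b)]);
    let c = forward_matmul(0, &a, &weights, None);

    // Inject the result: same output, but with an empty weight store.
    let pre = HashMap::from([(0usize, c.clone())]);
    assert_eq!(forward_matmul(0, &a, &HashMap::new(), Some(&pre)), c);
}
```

The same shape applies to proofs: Phase 2 checks the injected proof lists before doing any sumcheck work.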
Streaming Weight Pipeline
The StreamingWeightPipeline manages memory-mapped model weight files, loading only the data needed for the current operation:
pub struct StreamingWeightPipeline {
    shards: Vec<ShardHandle>,                // mmap'd SafeTensors files
    tensor_to_shard: HashMap<String, usize>, // tensor name → shard index
    name_map: HashMap<usize, String>,        // node_id → tensor name
    strategy: QuantStrategy,
    graph: ComputationGraph,
}
Tensor Layout Detection
Weight tensors in SafeTensors files can be stored in either (k, n) or (n, k) layout. The pipeline detects this automatically:
fn detect_layout(shape: &[usize], k: usize, n: usize) -> TensorLayout {
    if shape == [n, k] && k != n {
        TensorLayout::StoredNK // needs transpose on extraction
    } else {
        TensorLayout::StoredKN // native layout, contiguous tiles
    }
}
StoredKN (native): Tile rows [k_start..k_end] are contiguous in memory — sequential mmap reads with excellent cache locality. Byte range: k_start × n × elem_size .. k_end × n × elem_size.
StoredNK (transposed): Columns are scattered across rows. Requires extracting an (n, tile_k) sub-matrix from scattered reads, then transposing to (tile_k, n) via a cache-friendly 64×64 block transpose algorithm.
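A minimal sketch of such a blocked transpose, assuming the f32 data has already been gathered into a contiguous `(n, k)` buffer (the real pipeline extracts the scattered columns first; `transpose_blocked` is an illustrative name):

```rust
/// Transpose a row-major (n, k) matrix into (k, n), visiting 64×64 blocks so
/// that reads and writes each stay within a small working set of cache lines.
const BLOCK: usize = 64;

fn transpose_blocked(src: &[f32], n: usize, k: usize) -> Vec<f32> {
    assert_eq!(src.len(), n * k);
    let mut dst = vec![0.0f32; n * k];
    for bi in (0..n).step_by(BLOCK) {
        for bj in (0..k).step_by(BLOCK) {
            for i in bi..(bi + BLOCK).min(n) {
                for j in bj..(bj + BLOCK).min(k) {
                    // src[i][j] in (n, k) layout → dst[j][i] in (k, n) layout
                    dst[j * n + i] = src[i * k + j];
                }
            }
        }
    }
    dst
}

fn main() {
    // Dimensions that don't divide 64, to exercise the edge blocks.
    let (n, k) = (70, 65);
    let src: Vec<f32> = (0..(n * k) as u32).map(|x| x as f32).collect();
    let dst = transpose_blocked(&src, n, k);
    for i in 0..n {
        for j in 0..k {
            assert_eq!(dst[j * n + i], src[i * k + j]);
        }
    }
}
```

The naive element-by-element transpose strides through `dst` by `n` floats per write; the 64×64 blocking keeps both matrices' touched regions small enough to stay cache-resident.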
Min/Max Scanning
Quantization parameters require global min/max of each weight tensor. The pipeline scans the mmap'd bytes without allocating a full f32 buffer:
pub fn scan_tensor_minmax(data: &[u8], dtype: DType) -> (f64, f64) {
    let elem_size = dtype.size();
    let (mut min_val, mut max_val) = (f64::INFINITY, f64::NEG_INFINITY);
    for chunk in data.chunks_exact(elem_size) {
        let v = bytes_to_f32_single(chunk, dtype) as f64;
        min_val = min_val.min(v);
        max_val = max_val.max(v);
    }
    (min_val, max_val)
}
Zero-allocation scan: touches each byte exactly once, no intermediate buffers.
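The scan leans on `bytes_to_f32_single` to decode one element at a time. A sketch of that decode for f32 and bf16 inputs, with a local `DType` stand-in (the real enum also covers f16 and other dtypes; `scan_minmax` is a simplified local copy of the scan above):

```rust
#[derive(Clone, Copy)]
enum DType {
    F32,
    BF16,
}

impl DType {
    fn size(self) -> usize {
        match self {
            DType::F32 => 4,
            DType::BF16 => 2,
        }
    }
}

fn bytes_to_f32_single(chunk: &[u8], dtype: DType) -> f32 {
    match dtype {
        DType::F32 => f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]),
        // bf16 is the top 16 bits of an f32: widen and shift into place.
        DType::BF16 => {
            let bits = u16::from_le_bytes([chunk[0], chunk[1]]) as u32;
            f32::from_bits(bits << 16)
        }
    }
}

fn scan_minmax(data: &[u8], dtype: DType) -> (f64, f64) {
    let (mut lo, mut hi) = (f64::INFINITY, f64::NEG_INFINITY);
    for chunk in data.chunks_exact(dtype.size()) {
        let v = bytes_to_f32_single(chunk, dtype) as f64;
        lo = lo.min(v);
        hi = hi.max(v);
    }
    (lo, hi)
}

fn main() {
    let vals = [-2.5f32, 1.0, 0.25];
    let mut bytes = Vec::new();
    for v in vals {
        bytes.extend_from_slice(&v.to_le_bytes());
    }
    assert_eq!(scan_minmax(&bytes, DType::F32), (-2.5, 1.0));
    // bf16 of 1.0 is 0x3F80, little-endian bytes [0x80, 0x3F].
    assert_eq!(bytes_to_f32_single(&[0x80, 0x3F], DType::BF16), 1.0);
}
```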
Tile Loading
pub fn load_weight_tile(
    &self, node_id: usize, k_start: usize, k_end: usize, params: &QuantParams,
) -> M31Matrix
Pipeline: resolve tensor metadata → detect layout → extract tile f32s → quantize to M31.
Memory: Only tile_k × n × 4 bytes (f32) + tile_k × n × 4 bytes (M31) are allocated.
For Qwen3-14B (k=5120, tile_k=1024, n=14336): 56 MB per tile vs 280 MB full weight.
Prefetch Strategy
The pipeline uses madvise(MADV_WILLNEED) to hint the OS to preload mmap pages before they're needed:
- StoredKN: Prefetch contiguous rows `[k_start..k_end]` — page-aligned boundaries (4096-byte alignment)
- StoredNK: Prefetch entire tensor (columns scattered across all rows)
For the double-buffered pipeline, prefetch is called for tile N+1 while tile N is being proven.
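madvise operates on whole pages, so the byte range for a StoredKN tile must be rounded outward to 4096-byte boundaries. A sketch of that alignment arithmetic (the madvise call itself is omitted; `page_aligned` is an illustrative helper):

```rust
const PAGE: usize = 4096;

/// Round a byte range outward to page boundaries before prefetching.
/// `start`/`end` are offsets of the tile's rows within the mapped tensor.
fn page_aligned(start: usize, end: usize) -> (usize, usize) {
    let lo = start & !(PAGE - 1);            // round start down to a page
    let hi = (end + PAGE - 1) & !(PAGE - 1); // round end up to a page
    (lo, hi)
}

fn main() {
    // A tile spanning bytes [5000, 9000) pulls in pages [4096, 12288).
    assert_eq!(page_aligned(5000, 9000), (4096, 12288));
    // Already-aligned ranges are unchanged.
    assert_eq!(page_aligned(8192, 16384), (8192, 16384));
}
```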
Key Types
PrecomputedMatmuls (aggregation.rs)
pub(crate) struct PrecomputedMatmuls {
    /// Pre-computed matmul output matrices (node_id -> C matrix).
    /// Phase 1 uses these instead of weights.get_weight() + matmul_m31().
    pub outputs: HashMap<usize, M31Matrix>,
    /// Pre-composed single-tile matmul proofs (node_id, proof).
    /// Phase 2 is skipped entirely for these.
    pub proofs: Vec<(usize, MatMulSumcheckProofOnChain)>,
    /// Multi-tile proofs that can't be composed into a single proof.
    pub tiled_proofs: Vec<(usize, TiledMatMulProof)>,
}
AggregatedModelProofOnChain — new field
pub struct AggregatedModelProofOnChain {
    // ... existing fields ...
    /// Multi-tile matmul proofs that couldn't be composed into a single proof.
    /// Present when tile-level streaming is used with multi-tile matmuls.
    pub tiled_matmul_proofs: Vec<(usize, TiledMatMulProof)>,
    // ... existing fields ...
}
API
prove_model_aggregated_onchain_with_precomputed
pub(crate) fn prove_model_aggregated_onchain_with_precomputed(
    graph: &ComputationGraph,
    input: &M31Matrix,
    weights: &GraphWeights, // Can be empty — weights not needed
    precomputed: PrecomputedMatmuls,
) -> Result<AggregatedModelProofOnChain, AggregationError>
Thin wrapper around prove_model_aggregated_onchain_with_cache that passes Some(precomputed).
Data Flow
prove_model_chunked_streaming_tiled:
┌─────────────────────────────────────────────────────────┐
│ Phase 1: Forward Pass (tile-level streaming) │
│ │
│ for each matmul node: │
│ pipeline.forward_matmul_tiled(node_id, A, config) │
│ → double-buffered: loads tile N+1 while computing │
│ matmul for tile N (std::thread::scope) │
│ → accumulates C = A × B tile by tile │
│ → peak mem: 2 × tile_k × n (current + next tile) │
│ stores (node_id, A, C) in chunk_matmul_data │
└─────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────┐
│ Phase 2: Prove (tile-level streaming) │
│ │
│ for each matmul: │
│ pipeline.prove_matmul_tiled_streaming(...) │
│ → double-buffered: loads tile N+1 while proving │
│ tile N via sumcheck (std::thread::scope) │
│ → verify_tiled_matmul() sanity check │
│ stores (node_id, TiledMatMulProof) │
└─────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────┐
│ Phase 3: Aggregate (precomputed injection) │
│ │
│ Build PrecomputedMatmuls: │
│ outputs: node_id → C from Phase 1 │
│ proofs: single-tile → compose_tiled_proof() │
│ tiled_proofs: multi-tile (passed through as-is) │
│ │
│ prove_model_aggregated_onchain_with_precomputed( │
│ graph, input, GraphWeights::new(), precomputed │
│ ) │
│ → Phase 1: uses precomputed C (no weight lookup) │
│ → Phase 2: uses precomputed proofs (no re-proving) │
│ → Phase 3: builds STARK for activations/add/mul/... │
└─────────────────────────────────────────────────────────┘
Double-Buffered Tile Pipeline
Both forward_matmul_tiled and prove_matmul_tiled_streaming use a double-buffered pipeline that overlaps tile loading with computation using std::thread::scope.
Timeline
Sequential (before):
[load tile 0] [compute tile 0] [load tile 1] [compute tile 1] [load tile 2] ...
|----- T0 ----|----- T1 -------|----- T2 ----|----- T3 -------|----- T4 ----|
Double-buffered (after):
[load tile 0] [compute tile 0 ||||| load tile 1] [compute tile 1 ||||| load tile 2] ...
|----- T0 ----|------------ T1 ----------------|----------- T2 -------------------|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
I/O hidden behind compute — wall time reduced by load latency
How It Works
- Tile 0: Loaded synchronously (nothing to overlap with yet).
- Tiles 1..N: `std::thread::scope` spawns a loader thread for tile N+1 while the main thread computes/proves tile N. The loader performs mmap read + f32 conversion + M31 quantization.
- Single-tile case: Falls through to a simple non-pipelined path (no thread overhead for 1 tile).
Implementation (streaming.rs)
// forward_matmul_tiled — overlaps matmul with next tile load
let (c_tile, next_tile_result) = std::thread::scope(|s| {
    let loader = s.spawn(|| {
        self.load_weight_tile(node_id, next_k_start, next_k_end, &params)
    });
    let c_tile = matmul_m31(&a_tile, &b_tile);
    let next = loader.join().expect("loader thread panicked");
    (c_tile, Some(next))
});
// prove_matmul_tiled_streaming — overlaps sumcheck proving with next tile load
let (proof_result, next_tile_result) = std::thread::scope(|s| {
    let loader = s.spawn(|| {
        self.load_weight_tile(node_id, next_k_start, next_k_end, &params)
    });
    let proof = prove_matmul_sumcheck_onchain_auto(&a_padded, &b_padded, &c_padded);
    let next = loader.join().expect("loader thread panicked");
    (proof, next)
});
Why std::thread::scope
- Borrows `&self`: The loader thread needs `&self` to call `load_weight_tile`. Scoped threads allow borrowing from the enclosing stack frame without `'static` bounds.
- No `Arc`/`Mutex` overhead: The pipeline reference is shared read-only (the mmap is `Sync`).
- Panic safety: If the loader panics, `join()` propagates it cleanly. The scoped thread is always joined before the scope exits.
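The pattern can be demonstrated end to end with a toy pipeline, where a `load` closure stands in for `load_weight_tile` and a summation stands in for the matmul/sumcheck work (illustrative code, not the crate's):

```rust
use std::thread;

/// Toy double-buffered loop: while "tile i" is consumed, a scoped loader
/// thread fetches "tile i+1". The loader borrows `tiles` straight from the
/// stack frame: no Arc, no 'static bound, and it is always joined before
/// the scope exits.
fn pipelined_sum(tiles: &[Vec<u64>]) -> u64 {
    if tiles.is_empty() {
        return 0;
    }
    let load = |i: usize| tiles[i].clone(); // stands in for load_weight_tile
    let mut current = load(0); // tile 0: synchronous, nothing to overlap yet
    let mut total = 0u64;
    for i in 0..tiles.len() {
        let next = if i + 1 < tiles.len() {
            thread::scope(|s| {
                let loader = s.spawn(|| load(i + 1)); // fetch tile i+1 ...
                let partial: u64 = current.iter().sum(); // ... while "computing" tile i
                total += partial;
                Some(loader.join().expect("loader thread panicked"))
            })
        } else {
            total += current.iter().sum::<u64>(); // last tile: nothing to prefetch
            None
        };
        if let Some(n) = next {
            current = n;
        }
    }
    total
}

fn main() {
    let tiles = vec![vec![1u64, 2], vec![3], vec![4, 5]];
    assert_eq!(pipelined_sum(&tiles), 15);
    assert_eq!(pipelined_sum(&[]), 0);
}
```

The result matches a plain sequential sum; only the wall-clock schedule differs.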
Latency Hiding
The load_weight_tile call performs:
- mmap page fault (kernel I/O, ~0.5-2ms per tile for NVMe)
- f32 conversion from dtype (bf16/f16 → f32, ~0.1ms per tile)
- M31 quantization (~0.2ms per tile)
For the proving path, prove_matmul_sumcheck_onchain_auto takes 50-500ms per tile (depending on dimensions). The ~1-3ms of tile loading is completely hidden behind the proving time — effectively free I/O.
For the forward path, matmul_m31 takes 1-50ms per tile. The loading is partially or fully hidden depending on tile size.
Memory Impact
For a matmul with dimensions m × k × n and tile size tile_k:
| Stage | Before | After |
|---|---|---|
| Forward pass | k × n × 4 bytes (full weight) | tile_k × n × 4 bytes (1 tile) |
| Proving | k × n × 4 bytes (full weight, loaded twice) | tile_k × n × 4 bytes (1 tile) |
| Aggregation | k × n × 4 bytes (full weight, loaded AGAIN) | 0 bytes (precomputed) |
Example: Qwen3-14B matmul 5120×14336:
- Full weight: 280 MB per matrix
- Tile (tile_k=1024): 56 MB per tile
- Aggregation weight load eliminated: 280 MB saved per matmul
- For 160 matmuls in a transformer block: ~44 GB of redundant weight I/O eliminated
Proof Composition
Single-tile proofs (where TiledMatMulProof.tile_proofs.len() == 1) are composed into standard MatMulSumcheckProofOnChain via compose_tiled_proof(). This makes them indistinguishable from non-tiled proofs in the final AggregatedModelProofOnChain.
Multi-tile proofs (2+ tiles) cannot be composed into a single sumcheck proof because the Fiat-Shamir challenges differ per tile. These are stored in the tiled_matmul_proofs field for separate verification.
Chunked Proving Pipeline
The tile-streaming architecture integrates with three chunked proving strategies:
Sequential Chunked (prove_model_chunked)
For memory-constrained environments. Processes one chunk at a time:
for each chunk:
    1. Load checkpoint (if resumable)
    2. Load chunk weights from mmap
    3. Prove chunk → ChunkProofResult
    4. Save checkpoint for resumability
    5. Carry output activation to next chunk
    6. Drop chunk weights (free memory)
Checkpoints are saved as chunk_{i}.json for crash recovery:
pub struct ChunkProofResult {
    pub chunk_index: usize,
    pub node_range: Range<usize>,
    pub proof: AggregatedModelProofOnChain,
    pub output: M31Matrix, // output activation for next chunk
}
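A sketch of the resumable loop around these checkpoints, with a placeholder string standing in for the serialized `ChunkProofResult` (the function name is illustrative):

```rust
use std::fs;
use std::path::Path;

/// Before proving chunk i, check for chunk_{i}.json on disk and skip the
/// work if it is already there. Returns how many chunks were proven in
/// this run (as opposed to recovered from checkpoints).
fn prove_chunks_resumable(dir: &Path, n_chunks: usize) -> usize {
    fs::create_dir_all(dir).expect("create checkpoint dir");
    let mut proved_now = 0;
    for i in 0..n_chunks {
        let ckpt = dir.join(format!("chunk_{i}.json"));
        if ckpt.exists() {
            continue; // already proven before a crash, so skip the work
        }
        let proof = format!("{{\"chunk_index\":{i}}}"); // placeholder payload
        fs::write(&ckpt, proof).expect("write checkpoint");
        proved_now += 1;
    }
    proved_now
}

fn main() {
    let dir = std::env::temp_dir().join("tile_ckpt_demo");
    let _ = fs::remove_dir_all(&dir); // start clean
    assert_eq!(prove_chunks_resumable(&dir, 4), 4); // first run proves all 4
    assert_eq!(prove_chunks_resumable(&dir, 4), 0); // rerun: everything recovered
    let _ = fs::remove_dir_all(&dir);
}
```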
Parallel Chunked (prove_model_chunked_parallel)
Uses rayon for multi-core proving:
Phase 1 (sequential): Forward pass on CPU to compute all chunk inputs.
Phase 2 (parallel): Independent chunk proving via rayon::par_iter():
let results: Vec<ChunkProofResult> = blocks
    .par_iter()
    .enumerate()
    .map(|(i, range)| {
        let sub_graph = graph.subgraph(range);
        let sub_weights = weights.subset(range);
        prove_model_aggregated_onchain(&sub_graph, &chunk_inputs[i], &sub_weights)
    })
    .collect();
Different chunks touch different mmap pages — no contention between workers.
Multi-GPU Chunked (prove_model_chunked_multi_gpu)
For H100/A100 multi-GPU servers:
Phase 1: Sequential forward pass on CPU
Phase 2: Partition chunks across GPUs (by estimated memory + matmul count)
Phase 3: Parallel proving with device affinity via DeviceGuard RAII
Each worker thread pins to a specific GPU via DeviceGuard::new(device_id), which uses RAII to restore the previous device on drop:
std::thread::scope(|s| {
    for assignment in &assignments {
        s.spawn(move || {
            let _guard = DeviceGuard::new(assignment.device_id);
            prove_model_aggregated_onchain_auto(&sub_graph, input, &weights)
            // _guard dropped here → restore previous device
        });
    }
});
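The RAII shape of `DeviceGuard` can be sketched with a thread-local standing in for the CUDA current-device state (the real guard calls the CUDA runtime instead; this is a structural illustration):

```rust
use std::cell::Cell;

thread_local! {
    // Stand-in for the per-thread "current device" that CUDA maintains.
    static CURRENT_DEVICE: Cell<usize> = Cell::new(0);
}

/// Set the device on construction; restore the previous one on drop,
/// including on early return or panic unwind.
struct DeviceGuard {
    prev: usize,
}

impl DeviceGuard {
    fn new(device_id: usize) -> Self {
        let prev = CURRENT_DEVICE.with(|d| d.replace(device_id));
        DeviceGuard { prev }
    }
}

impl Drop for DeviceGuard {
    fn drop(&mut self) {
        CURRENT_DEVICE.with(|d| d.set(self.prev));
    }
}

fn main() {
    assert_eq!(CURRENT_DEVICE.with(|d| d.get()), 0);
    {
        let _guard = DeviceGuard::new(3);
        assert_eq!(CURRENT_DEVICE.with(|d| d.get()), 3); // pinned inside scope
    }
    assert_eq!(CURRENT_DEVICE.with(|d| d.get()), 0); // restored on drop
}
```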
Proof Composition (compose_chunk_proofs)
After parallel proving, individual chunk proofs are composed into a single AggregatedModelProofOnChain:
- Validate: Total chunk nodes == graph nodes
- Remap IDs: `local_id → chunk_offset + local_id` for matmul proofs, attention proofs, etc.
- Re-run forward pass: Collect non-matmul layer data (activations, adds, muls, layernorms)
- Compute commitments: `layer_chain_commitment`, `io_commitment`
- Build unified STARK: Single STARK for all non-matmul components across all chunks
- Assemble: Remapped matmul proofs + new unified STARK → final proof
Cost: O(forward_pass) + O(build_unified_stark) — no weight re-loading or matmul re-proving.
Memory Budget Estimation
The pipeline estimates memory requirements before execution:
pub fn estimate_tile_memory(&self, node_id: usize, tile_config: &TileConfig) -> usize {
    let tile_k = tile_config.tile_k;
    let (m, _k, n) = node_dimensions(node_id);
    let m_pad = m.next_power_of_two();
    let tile_k_pad = tile_k.next_power_of_two();
    let n_pad = n.next_power_of_two();
    // B tile (f32) + A slice (f32) + C tile (f32) + padded copies (QM31 = 16 bytes)
    tile_k * n * 4 + m * tile_k * 4 + m * n * 4
        + (m_pad * tile_k_pad + tile_k_pad * n_pad + m_pad * n_pad) * 16
}
Example budgets for Qwen3-14B (m=2048, k=5120, n=14336):
| Tile Size | B Tile | A Slice | C Tile | Padded | Total |
|---|---|---|---|---|---|
| tile_k=256 | 14 MB | 2 MB | 112 MB | ~280 MB | ~408 MB |
| tile_k=512 | 28 MB | 4 MB | 112 MB | ~310 MB | ~454 MB |
| tile_k=1024 | 56 MB | 8 MB | 112 MB | ~370 MB | ~546 MB |
The padded copies dominate because QM31 elements are 16 bytes (4× M31) and dimensions are rounded to power-of-2.
Budget selection: The prover auto-selects tile_k to fit within the TILED_MEMORY_BUDGET (default 64 GB on H100, large enough that most matmuls are never tiled at all). On memory-constrained hardware, smaller tile sizes trade proving speed for memory.
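Under the estimate formula above, budget-driven selection can be sketched as halving `tile_k` from the untiled full `k` until the estimate fits (`select_tile_k` is a hypothetical helper; the real selector may search differently):

```rust
/// Halve tile_k until the per-tile memory estimate fits the budget.
/// The estimate mirrors estimate_tile_memory: f32 tiles plus
/// power-of-two padded QM31 copies (16 bytes per element).
fn select_tile_k(m: usize, k: usize, n: usize, budget: usize) -> usize {
    let estimate = |tile_k: usize| {
        let (mp, tp, np) = (
            m.next_power_of_two(),
            tile_k.next_power_of_two(),
            n.next_power_of_two(),
        );
        tile_k * n * 4 + m * tile_k * 4 + m * n * 4
            + (mp * tp + tp * np + mp * np) * 16
    };
    let mut tile_k = k; // start untiled: tile_k == k means one tile
    while tile_k > 1 && estimate(tile_k) > budget {
        tile_k /= 2;
    }
    tile_k
}

fn main() {
    // 64 GB budget: the Qwen3-14B matmul (m=2048, k=5120, n=14336) fits untiled.
    assert_eq!(select_tile_k(2048, 5120, 14336, 64 << 30), 5120);
    // A tight 1 GB budget forces halving 5120 → 2560 → 1280 → 640.
    assert_eq!(select_tile_k(2048, 5120, 14336, 1 << 30), 640);
}
```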
Backward Compatibility
All existing callers pass None for the new precomputed parameter, making this fully backward compatible. The new tiled_matmul_proofs field on AggregatedModelProofOnChain defaults to Vec::new() in all non-streaming paths.
Performance Summary
| Optimization | Before | After | Improvement |
|---|---|---|---|
| Weight I/O per matmul | 280 MB (full) | 56 MB (tile) | 5× less memory |
| Aggregation weight load | 280 MB (re-loaded) | 0 MB (precomputed) | Eliminated |
| I/O latency | Sequential | Double-buffered | Hidden behind compute |
| Multi-matmul weights | 44 GB redundant I/O | 0 (precomputed injection) | Eliminated |
| Crash recovery | Start over | Checkpoint resume | Resumable |