GPU Acceleration

Transparent GPU acceleration via WebGPU. Move models and tensors to GPU with a single call — all operations automatically dispatch to 24+ WGSL compute shaders with CPU fallback.

Overview

RUMUS uses WebGPU (via wgpu) for GPU acceleration. The API is simple: call .to_gpu() on tensors or models, and all subsequent operations automatically run on the GPU. If no GPU is available, operations fall back to CPU without errors.

```rust
use rumus::tensor::Tensor;
use rumus::nn;

// Move a tensor to the GPU
let x = Tensor::randn(&[32, 784]);
let x_gpu = x.to_gpu();

// Move an entire model to the GPU
// (`model` is an nn::Module defined elsewhere, bound with `let mut`)
model.to_gpu();

// All operations automatically dispatch to GPU
let output = model.forward(&x_gpu);
```

Moving to GPU

Tensors provide .to_gpu() to create a GPU-backed copy. Models implement the ModuleToGpu trait, which moves all parameters in-place. Every operation checks is_gpu() at dispatch time to choose the GPU or CPU path.

```rust
// Tensors: .to_gpu() returns a GPU-backed tensor
let x = Tensor::randn(&[64, 3, 28, 28]);
let x_gpu = x.to_gpu();

// Models: .to_gpu() moves all parameters in-place
// via the ModuleToGpu trait
model.to_gpu();

// Operations check is_gpu() at dispatch time:
// → GPU tensor: runs WGSL compute shader
// → CPU tensor: runs CPU fallback
let output = model.forward(&x_gpu);
```
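The dispatch rule above (check where the data lives, then take the GPU or CPU path) can be sketched in plain Rust. This is a hypothetical illustration, not RUMUS internals: `Storage`, `GpuBuffer`, and this `Tensor` are invented names, and the "GPU path" is simulated with a host `Vec` so the sketch runs on any machine.

```rust
/// Hypothetical stand-in for a device buffer; a Vec simulates GPU memory
/// so the sketch runs without hardware.
#[derive(Debug, Clone)]
pub struct GpuBuffer {
    data: Vec<f32>,
}

/// Where a tensor's data lives.
#[derive(Debug, Clone)]
pub enum Storage {
    Cpu(Vec<f32>),
    Gpu(GpuBuffer),
}

/// Hypothetical tensor type (not rumus::tensor::Tensor).
#[derive(Debug, Clone)]
pub struct Tensor {
    storage: Storage,
}

impl Tensor {
    pub fn from_vec(data: Vec<f32>) -> Self {
        Tensor { storage: Storage::Cpu(data) }
    }

    /// Returns a GPU-backed copy (the upload is simulated).
    pub fn to_gpu(&self) -> Self {
        Tensor { storage: Storage::Gpu(GpuBuffer { data: self.to_vec() }) }
    }

    pub fn is_gpu(&self) -> bool {
        matches!(self.storage, Storage::Gpu(_))
    }

    /// Copies the values back to host memory.
    pub fn to_vec(&self) -> Vec<f32> {
        match &self.storage {
            Storage::Cpu(v) => v.clone(),
            Storage::Gpu(b) => b.data.clone(),
        }
    }

    /// The op picks its path from where the input lives, and the
    /// result stays on the same device as the input.
    pub fn relu(&self) -> Self {
        match &self.storage {
            Storage::Gpu(b) => {
                // A real backend would record a compute pass running a WGSL
                // kernel here; the sketch computes on the simulated buffer.
                let data = b.data.iter().map(|&x| x.max(0.0)).collect();
                Tensor { storage: Storage::Gpu(GpuBuffer { data }) }
            }
            Storage::Cpu(v) => {
                // CPU fallback path.
                Tensor::from_vec(v.iter().map(|&x| x.max(0.0)).collect())
            }
        }
    }
}
```

Keeping the result on the input's device means a chain of ops never bounces data back to the host between steps.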

GPU Architecture

The GPU subsystem is built on three core components, GpuContext, PipelineCache, and BufferPool, which together provide efficient, safe GPU computation.

GpuContext

A OnceLock singleton that lazily initializes the GPU device, queue, pipeline cache, and buffer pool on first use. It never panics when no GPU is present; operations fall back to the CPU path instead.

```rust
// GpuContext is a OnceLock singleton, initialized
// lazily on first GPU operation. Never panics on
// missing hardware; falls back to CPU gracefully.
//
// Internally holds:
//   - wgpu::Device
//   - wgpu::Queue
//   - PipelineCache
//   - BufferPool
```
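The lazy, non-panicking initialization described above can be sketched with `std::sync::OnceLock` by caching an `Option`: `None` means "no GPU, use the CPU path". This is a hypothetical sketch, not the RUMUS source. The `probe_adapter` function and the `SKETCH_GPU_ADAPTER` environment variable are invented stand-ins for a real wgpu adapter request, so the example runs anywhere.

```rust
use std::sync::OnceLock;

/// Hypothetical context; a real one would hold wgpu::Device, wgpu::Queue,
/// a PipelineCache, and a BufferPool.
#[derive(Debug)]
pub struct GpuContext {
    pub adapter_name: String,
}

/// Hypothetical hardware probe. A real backend would request a wgpu
/// adapter here; this sketch reads an environment variable instead.
fn probe_adapter() -> Option<String> {
    std::env::var("SKETCH_GPU_ADAPTER").ok()
}

// Holds Some(context) if a GPU was found, None otherwise.
static CONTEXT: OnceLock<Option<GpuContext>> = OnceLock::new();

impl GpuContext {
    /// Initializes on the first call only. Returns None (never panics)
    /// when no adapter is found, so callers can take the CPU path.
    pub fn get() -> Option<&'static GpuContext> {
        CONTEXT
            .get_or_init(|| probe_adapter().map(|adapter_name| GpuContext { adapter_name }))
            .as_ref()
    }
}
```

One consequence of this design is that a failed probe is cached too: the process does not retry the hardware search on every operation.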

PipelineCache

Pre-compiles 50+ compute pipelines across 24 shader modules. Pipeline selection is by enum variant, giving compile-time guarantees that every GPU operation has a valid pipeline.

```rust
// PipelineCache: 50+ compute pipelines across 24
// shader modules. Each GPU op selects its pipeline by
// enum variant, so the op-to-pipeline mapping is
// checked at compile time.
//
// Pipelines are created once, when the cache is built,
// so later ops never stall on pipeline compilation.
//
// Bind group layouts are shared across compatible
// pipelines, minimizing GPU resource usage.
```
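Selecting pipelines by enum variant can be sketched as follows. The variants, module names, and entry-point names below are hypothetical (the real enum has 50+ variants), and `CompiledPipeline` stands in for a `wgpu::ComputePipeline`. The exhaustive `match` is what makes a missing mapping a compile error, and the variant's discriminant gives an O(1) index into the cache.

```rust
/// Hypothetical subset of pipeline kinds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineKind {
    Add,
    Relu,
    Matmul,
}

impl PipelineKind {
    /// Every variant, in declaration order (so `kind as usize` indexes it).
    pub const ALL: [PipelineKind; 3] = [PipelineKind::Add, PipelineKind::Relu, PipelineKind::Matmul];

    /// Exhaustive match: adding a variant without a shader entry point
    /// fails to compile instead of failing at runtime.
    pub fn entry_point(self) -> (&'static str, &'static str) {
        match self {
            PipelineKind::Add => ("elementwise.wgsl", "add"),
            PipelineKind::Relu => ("activations.wgsl", "relu_forward"),
            PipelineKind::Matmul => ("matmul.wgsl", "matmul"),
        }
    }
}

/// Stand-in for a compiled wgpu::ComputePipeline.
#[derive(Debug)]
pub struct CompiledPipeline {
    pub module: &'static str,
    pub entry: &'static str,
}

pub struct PipelineCache {
    pipelines: Vec<CompiledPipeline>,
}

impl PipelineCache {
    /// Builds every pipeline up front, exactly once.
    pub fn new() -> Self {
        let pipelines = PipelineKind::ALL
            .iter()
            .map(|k| {
                let (module, entry) = k.entry_point();
                CompiledPipeline { module, entry }
            })
            .collect();
        PipelineCache { pipelines }
    }

    /// O(1) lookup: the enum discriminant indexes the vector.
    pub fn get(&self, kind: PipelineKind) -> &CompiledPipeline {
        &self.pipelines[kind as usize]
    }
}
```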

BufferPool

A thread-safe buffer cache using power-of-2 bucketing. Buffers are recycled via the Drop trait, eliminating repeated GPU memory allocation during training loops.

```rust
// BufferPool: thread-safe GPU buffer cache with
// power-of-2 bucketing. When a buffer is dropped,
// it is returned to the pool for reuse.

// Allocation: finds the smallest power-of-2 bucket
// that fits the request, returns a cached buffer
// or allocates a new one.
// (`pool` is the BufferPool held by GpuContext.)
let buffer = pool.allocate(1024);  // → 1024-byte buffer

// Drop: buffer is returned to its bucket
// automatically via the Drop trait.
drop(buffer);  // → recycled, not deallocated

// This eliminates repeated GPU memory allocation
// during training loops.
```
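The bucketing-plus-`Drop` pattern can be sketched in standard Rust. This is a minimal sketch under stated assumptions, not RUMUS's BufferPool: `Vec<u8>` stands in for a `wgpu::Buffer`, and `PooledBuffer` and the `fresh_allocations` counter are hypothetical names added to make the reuse observable.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

/// Hypothetical pool; `Vec<u8>` stands in for a wgpu::Buffer.
#[derive(Default)]
pub struct BufferPool {
    // bucket size (a power of two) -> idle buffers of exactly that size
    buckets: Mutex<HashMap<usize, Vec<Vec<u8>>>>,
    fresh_allocations: AtomicUsize,
}

/// A buffer on loan from the pool; returned to its bucket on drop.
pub struct PooledBuffer {
    data: Option<Vec<u8>>,
    bucket: usize,
    pool: Arc<BufferPool>,
}

impl BufferPool {
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

    /// Rounds `size` up to the next power of two and reuses an idle
    /// buffer from that bucket if one exists.
    pub fn allocate(self: &Arc<Self>, size: usize) -> PooledBuffer {
        let bucket = size.max(1).next_power_of_two();
        let cached = self.buckets.lock().unwrap().get_mut(&bucket).and_then(|idle| idle.pop());
        let data = cached.unwrap_or_else(|| {
            self.fresh_allocations.fetch_add(1, Ordering::Relaxed);
            vec![0u8; bucket]
        });
        PooledBuffer { data: Some(data), bucket, pool: Arc::clone(self) }
    }

    /// How many times the pool had to allocate new memory.
    pub fn fresh_allocations(&self) -> usize {
        self.fresh_allocations.load(Ordering::Relaxed)
    }
}

impl PooledBuffer {
    pub fn capacity(&self) -> usize {
        self.bucket
    }
}

impl Drop for PooledBuffer {
    fn drop(&mut self) {
        // Recycle instead of freeing; if the lock is poisoned, just free.
        if let (Some(data), Ok(mut buckets)) = (self.data.take(), self.pool.buckets.lock()) {
            buckets.entry(self.bucket).or_default().push(data);
        }
    }
}
```

Rounding to powers of two trades up to 2x wasted space for a small number of buckets, which raises the hit rate when tensor sizes vary slightly between steps. Recycled buffers keep their old contents, so kernels must not assume zeroed memory.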

WGSL Shader Modules

RUMUS ships 24 WGSL shader modules with 80+ entry points, covering every operation needed for training and inference — from element-wise ops to fused layer normalization and batched matrix multiplication.

| Category | Modules | Operations |
|---|---|---|
| Element-wise & Activations | elementwise.wgsl, activations.wgsl | add, sub, mul, div, neg, relu, sigmoid, tanh, gelu, leaky_relu (forward + backward) |
| Linear Algebra | matmul.wgsl, bmm.wgsl, bias.wgsl | matmul, batched matmul (3D Z-axis dispatch), bias add |
| Broadcasting & Scaling | broadcast.wgsl, broadcast_scale.wgsl, fused_scale.wgsl | N-dim broadcast binary ops, reduce_sum, fused scaling |
| Convolution & Pooling | conv.wgsl, pool.wgsl, adaptive_pool.wgsl | conv2d forward/backward, max_pool2d forward/backward, adaptive avg/max pool |
| Normalization | layer_norm.wgsl, layer_norm_bw.wgsl, layer_norm_grad_weight.wgsl, batch_norm.wgsl, batch_norm_bw.wgsl | layer norm (3-phase fused forward, backward, grad weight), batch norm (per-channel with running stats, backward) |
| Softmax | softmax.wgsl, softmax_bw.wgsl | row-wise softmax with Log-Sum-Exp trick, softmax backward |
| Dropout & Loss | dropout.wgsl, fused_dropout.wgsl, cross_entropy.wgsl | dropout, fused dropout, cross_entropy_loss (forward & backward) |
| Optimizer | optim.wgsl, adamw.wgsl | sgd_step, adam_step, adamw_step (fused weight decay) |
| Utility | contiguous.wgsl, embedding.wgsl, unary.wgsl | contiguous copy, embedding lookup, unary ops (exp, log, sqrt, neg) |
| Casting | cast.wgsl | F32↔F16 cast kernels (requires `enable f16;`) |
| Gradient Clipping | reduce_sum_sq.wgsl | Sum-of-squares reduction for gradient clipping (shared-memory tree reduction) |
| Quantization | quantize.wgsl, dequantize.wgsl, matmul_q8.wgsl | Symmetric block quantization (parallel abs_max reduction, packs 4 i8s per u32), block dequantization (sign-extended i8 extraction), mixed-precision matmul (scalar activations × Q8 weights, 16x16 tiled, on-the-fly dequant in registers) |
| Attention | flash_attn.wgsl | FlashAttention: block-tiled online softmax, 64 threads, 16KB shared memory; never materializes the O(N²) attention matrix in VRAM |
| Sparse Graph (rumus-graph) | spmm.wgsl | Sparse Matrix-Matrix Multiply: 1 thread/node, edge-outer/dim-inner loop, 256 threads/workgroup |
| Vision (rumus-vision) | conv2d_direct.wgsl, conv2d_backward_data.wgsl, conv2d_backward_weight.wgsl, maxpool2d_direct.wgsl, maxpool2d_backward.wgsl, qmatmul_int4.wgsl, qmatmul_int4_transpose.wgsl | Direct sliding-window 2D convolution with padding/stride/dilation, gradient w.r.t. input for transposed convolution routing, gradient w.r.t. weight accumulation, max pooling with f16-safe argmax tracking, scatter backward for max pooling, fused INT4 dequant-matmul with register-level bit unpacking |
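The Quantization row's scheme (per-block scale from abs_max, i8 values packed four per u32, sign-extended extraction) can be checked against a small CPU reference. This is a sketch of symmetric block quantization in general, not a port of quantize.wgsl: the block size of 32, round-half-away-from-zero, the clamp to [-127, 127], and little-endian lane order are all assumptions.

```rust
/// Assumed block size (not confirmed by the RUMUS docs); must be a multiple of 4.
pub const BLOCK: usize = 32;

/// Quantizes `x` in blocks of BLOCK values: scale = abs_max / 127,
/// q = round(x / scale) clamped to [-127, 127], packed 4 per u32.
pub fn quantize_q8(x: &[f32]) -> (Vec<u32>, Vec<f32>) {
    assert!(x.len() % BLOCK == 0, "length must be a multiple of BLOCK");
    let mut packed = Vec::with_capacity(x.len() / 4);
    let mut scales = Vec::with_capacity(x.len() / BLOCK);
    for block in x.chunks(BLOCK) {
        let abs_max = block.iter().fold(0.0f32, |m, v| m.max(v.abs()));
        let scale = if abs_max > 0.0 { abs_max / 127.0 } else { 1.0 };
        scales.push(scale);
        for quad in block.chunks(4) {
            let mut word = 0u32;
            for (lane, &v) in quad.iter().enumerate() {
                let q = (v / scale).round().clamp(-127.0, 127.0) as i8;
                // Store the i8's two's-complement byte in byte `lane`.
                word |= (q as u8 as u32) << (8 * lane);
            }
            packed.push(word);
        }
    }
    (packed, scales)
}

/// Extracts byte `lane` of `word` and reinterprets it as a signed i8,
/// so 0xFF becomes -1 rather than 255.
pub fn unpack_lane(word: u32, lane: usize) -> i8 {
    ((word >> (8 * lane)) & 0xFF) as u8 as i8
}

/// Block dequantization: value = q * scale of the owning block.
pub fn dequantize_q8(packed: &[u32], scales: &[f32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(packed.len() * 4);
    for (i, &word) in packed.iter().enumerate() {
        // Word i holds elements 4i..4i+3, which belong to block 4i / BLOCK.
        let scale = scales[i * 4 / BLOCK];
        for lane in 0..4 {
            out.push(unpack_lane(word, lane) as f32 * scale);
        }
    }
    out
}
```

Because the scale maps each block's largest magnitude to ±127, the round-trip error per element is at most half that block's scale.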

Per-Resource Fences

Instead of global pipeline barriers, RUMUS uses per-resource fences with AtomicUsize sentinels. Each GPU buffer tracks its last submission index. Before a read, only that specific buffer's fence is awaited, so independent operations on different buffers proceed in parallel without stalls.

```rust
// Per-resource fences: each GPU buffer carries an
// AtomicUsize sentinel tracking its last submission.
//
// Before reading a buffer, RUMUS checks its fence:
//   - If the fence matches the current submission
//     index, the buffer is ready.
//   - Otherwise, only that buffer's fence is awaited.
//
// This avoids global pipeline stalls; independent
// operations on different buffers proceed in parallel.
```
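The fence idea can be modeled with atomics alone. This is a hypothetical sketch, not RUMUS's implementation: `Queue`, `FencedBuffer`, `submit_write`, and `ensure_ready` are invented names, the blocking wait is simulated, and it assumes submissions on a single queue complete in order, so waiting for index N implies all earlier work is done. The point it shows is that a read waits only up to its own buffer's last writer, never for later, unrelated submissions.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// A buffer with its own fence: the index of the last submission that
/// wrote to it (0 means no GPU work has touched it yet).
pub struct FencedBuffer {
    last_submission: AtomicUsize,
}

impl FencedBuffer {
    pub fn new() -> Self {
        FencedBuffer { last_submission: AtomicUsize::new(0) }
    }
}

/// Hypothetical single-queue model with in-order completion.
pub struct Queue {
    next_submission: AtomicUsize,
    completed: AtomicUsize,
    waits: AtomicUsize,
}

impl Queue {
    pub fn new() -> Self {
        Queue {
            next_submission: AtomicUsize::new(1),
            completed: AtomicUsize::new(0),
            waits: AtomicUsize::new(0),
        }
    }

    /// Records a submission that writes `dst` and stamps dst's fence.
    pub fn submit_write(&self, dst: &FencedBuffer) -> usize {
        let index = self.next_submission.fetch_add(1, Ordering::SeqCst);
        dst.last_submission.store(index, Ordering::Release);
        index
    }

    /// Before a read, block only until this buffer's last writer has
    /// finished; buffers that are already complete cost one atomic load.
    pub fn ensure_ready(&self, buf: &FencedBuffer) {
        let fence = buf.last_submission.load(Ordering::Acquire);
        if fence > self.completed.load(Ordering::Acquire) {
            self.wait_until(fence);
        }
    }

    /// Simulated blocking wait; a real backend would poll the device.
    fn wait_until(&self, index: usize) {
        self.waits.fetch_add(1, Ordering::Relaxed);
        self.completed.fetch_max(index, Ordering::AcqRel);
    }

    pub fn completed(&self) -> usize {
        self.completed.load(Ordering::Acquire)
    }

    pub fn waits(&self) -> usize {
        self.waits.load(Ordering::Relaxed)
    }
}
```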

WGSL Preprocessor

RUMUS includes a shader preprocessor that enables precision-agnostic WGSL shaders. The function preprocess_shader(source, dtype) prepends the appropriate type alias based on the target dtype:

  • DType::F32: prepends alias scalar = f32;
  • DType::F16: prepends enable f16; alias scalar = f16;

This allows 35+ shaders to be written with a generic scalar type and compiled for either precision without code duplication.

Dual pipeline compilation: The F32 pipeline set is always compiled. The F16 pipeline set is compiled only when GpuContext::supports_f16 returns true, avoiding compilation errors on hardware without F16 support.
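The preprocessing step and the dual compilation rule are small enough to sketch directly. `preprocess_shader` and `DType` mirror the names in the text, but this body is a standalone reimplementation of the described behavior, not the RUMUS source. `RELU_WGSL` and `pipeline_dtypes` are hypothetical.

```rust
/// Local stand-in for the two dtypes described above.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
    F32,
    F16,
}

/// Prepends the precision header to a shader written against `scalar`.
pub fn preprocess_shader(source: &str, dtype: DType) -> String {
    let header = match dtype {
        DType::F32 => "alias scalar = f32;\n",
        // WGSL requires `enable` directives before any other declaration,
        // which is why the header is prepended rather than appended.
        DType::F16 => "enable f16;\nalias scalar = f16;\n",
    };
    format!("{header}{source}")
}

/// Dual compilation: F32 always, F16 only when the device supports it.
pub fn pipeline_dtypes(supports_f16: bool) -> Vec<DType> {
    let mut dtypes = vec![DType::F32];
    if supports_f16 {
        dtypes.push(DType::F16);
    }
    dtypes
}

/// Hypothetical precision-agnostic shader: only `scalar` names the type.
pub const RELU_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read_write> data: array<scalar>;

@compute @workgroup_size(64)
fn relu(@builtin(global_invocation_id) id: vec3<u32>) {
    if (id.x < arrayLength(&data)) {
        data[id.x] = max(data[id.x], scalar(0));
    }
}
"#;
```

Running `RELU_WGSL` through `preprocess_shader` once per entry of `pipeline_dtypes(supports_f16)` produces one shader source per precision from a single file.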

JIT Kernel Fusion

jit::compile(|| { ... }) captures element-wise ops and fuses them into a single WGSL kernel. The closure is traced into a FusionBlock IR, which is compiled once and cached in a JitCache for O(1) pipeline reuse. Fusion relieves the VRAM bandwidth bottleneck by eliminating intermediate buffer round-trips between kernels.

Feature gate: Enable with --features jit.

```rust
// JIT Kernel Fusion: captures element-wise ops and
// fuses them into a single WGSL kernel at runtime.
//
// jit::compile(|| { ... }) traces the closure, builds
// a FusionBlock IR, and compiles a fused WGSL kernel.
// JitCache stores compiled pipelines for O(1) reuse.
//
// Fusion relieves the VRAM bandwidth bottleneck by
// avoiding intermediate buffer round-trips between kernels.
//
// Enable with: --features jit

use rumus::jit;

// x, w and b are GPU tensors created earlier (not shown)
let fused = jit::compile(|| {
    let y = x.mul(&w).add(&b);
    y.relu()
});
```
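To show what "fusing into one kernel" means, here is a toy fusion IR that emits a single WGSL kernel for `relu(x * w + b)` and caches it by IR. It is a hypothetical sketch: RUMUS's actual FusionBlock and JitCache structures are not documented here, the tracing step is omitted (the IR is built by hand), and `Node`, `to_wgsl`, and `get_or_compile` are invented names. Every intermediate becomes a WGSL `let`, so it lives in registers instead of a VRAM buffer.

```rust
use std::collections::HashMap;

/// A tiny fusion IR: each node refers to earlier nodes by index.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Node {
    Input(usize), // binding slot of an input buffer
    Mul(usize, usize),
    Add(usize, usize),
    Relu(usize),
}

/// Hypothetical block: nodes in topological order; the last is the output.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FusionBlock {
    pub nodes: Vec<Node>,
    pub num_inputs: usize,
}

impl FusionBlock {
    /// Emits one WGSL kernel that evaluates the whole chain per element.
    pub fn to_wgsl(&self) -> String {
        assert!(!self.nodes.is_empty(), "empty fusion block");
        let mut src = String::new();
        for i in 0..self.num_inputs {
            src += &format!("@group(0) @binding({i}) var<storage, read> src{i}: array<f32>;\n");
        }
        let out = self.num_inputs;
        src += &format!("@group(0) @binding({out}) var<storage, read_write> dst: array<f32>;\n\n");
        src += "@compute @workgroup_size(64)\n";
        src += "fn fused(@builtin(global_invocation_id) id: vec3<u32>) {\n";
        src += "    let i = id.x;\n    if (i >= arrayLength(&dst)) { return; }\n";
        for (n, node) in self.nodes.iter().enumerate() {
            let expr = match node {
                Node::Input(slot) => format!("src{slot}[i]"),
                Node::Mul(a, b) => format!("t{a} * t{b}"),
                Node::Add(a, b) => format!("t{a} + t{b}"),
                Node::Relu(a) => format!("max(t{a}, 0.0)"),
            };
            src += &format!("    let t{n} = {expr};\n");
        }
        src += &format!("    dst[i] = t{};\n}}\n", self.nodes.len() - 1);
        src
    }
}

/// Hypothetical cache: each distinct block is code-generated once.
#[derive(Default)]
pub struct JitCache {
    kernels: HashMap<FusionBlock, String>,
    pub compiles: usize,
}

impl JitCache {
    pub fn get_or_compile(&mut self, block: &FusionBlock) -> &str {
        if !self.kernels.contains_key(block) {
            self.compiles += 1;
            self.kernels.insert(block.clone(), block.to_wgsl());
        }
        &self.kernels[block]
    }
}
```

Unfused, the same expression would take three kernels and two intermediate buffers, each written to and read back from VRAM.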

Multi-GPU

MultiGpuContext::get() enumerates all available GPUs. DataParallel provides scatter-gather training by splitting batches across devices. FSDP enables parameter sharding for large models that exceed single-GPU memory. AllReduceSync handles gradient averaging across devices.

Feature gate: Enable with --features multi_gpu.

```rust
// Multi-GPU support: scatter-gather training and
// parameter sharding across multiple GPUs.
//
// MultiGpuContext::get() enumerates all available GPUs.
// DataParallel splits batches across devices.
// FSDP shards parameters for memory efficiency.
// AllReduceSync averages gradients across devices.
//
// Enable with: --features multi_gpu

use rumus::multi_gpu::{MultiGpuContext, DataParallel, FSDP};

let ctx = MultiGpuContext::get();
let parallel_model = DataParallel::new(&model, &ctx);
// or for large models:
let sharded_model = FSDP::new(&model, &ctx);
```
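The scatter, per-device gradient, all-reduce sequence behind data-parallel training can be simulated on the CPU with a one-parameter model (y ≈ w·x, mean squared error). This is a sketch of the general technique, not RUMUS's DataParallel or AllReduceSync; `scatter`, `local_grad`, `all_reduce_mean`, and `data_parallel_step` are hypothetical names, and FSDP-style sharding is not shown.

```rust
/// Scatter: split the batch into at most `devices` contiguous shards.
pub fn scatter<T: Clone>(batch: &[T], devices: usize) -> Vec<Vec<T>> {
    assert!(!batch.is_empty() && devices > 0);
    let shard_len = batch.len().div_ceil(devices);
    batch.chunks(shard_len).map(|shard| shard.to_vec()).collect()
}

/// Per-device gradient of mean squared error for y ≈ w * x.
pub fn local_grad(w: f32, shard: &[(f32, f32)]) -> f32 {
    let n = shard.len() as f32;
    shard.iter().map(|&(x, y)| 2.0 * (w * x - y) * x).sum::<f32>() / n
}

/// All-reduce (mean): every replica ends up with the same averaged gradient.
pub fn all_reduce_mean(grads: &mut [f32]) {
    let mean = grads.iter().sum::<f32>() / grads.len() as f32;
    for g in grads.iter_mut() {
        *g = mean;
    }
}

/// One SGD step; each shard's gradient would be computed on its own GPU.
pub fn data_parallel_step(w: f32, batch: &[(f32, f32)], devices: usize, lr: f32) -> f32 {
    let shards = scatter(batch, devices);
    let mut grads: Vec<f32> = shards.iter().map(|s| local_grad(w, s)).collect();
    all_reduce_mean(&mut grads);
    // Replicas apply the identical update, so their weights stay in sync.
    w - lr * grads[0]
}
```

Averaging per-shard means reproduces the full-batch gradient only when shards are the same size; with a ragged final shard, a weighted average by shard length would be needed.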

Custom Ops API

The CustomOp trait lets you register user-defined WGSL kernels. CustomOpCache handles dynamic bind group layout generation and pipeline caching, so your custom kernels integrate seamlessly with the existing GPU subsystem.

Feature gate: Enable with --features gpu.

```rust
// Custom Ops API: register user-defined WGSL kernels.
//
// Implement the CustomOp trait and register your kernel.
// CustomOpCache handles dynamic bind group layout
// generation and pipeline caching.
//
// Enable with: --features gpu

use rumus::gpu::{CustomOp, CustomOpCache};

struct MyKernel;
impl CustomOp for MyKernel {
    fn wgsl_source(&self) -> &str { include_str!("my_kernel.wgsl") }
    fn dispatch(&self, encoder: &mut Encoder, buffers: &[&Buffer]) { /* ... */ }
}
```
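Dynamic bind group layout generation can be illustrated by deriving one storage binding per buffer from a kernel's declared inputs and outputs, then caching the result per kernel type. This is a hypothetical sketch, not CustomOpCache: `KernelSignature`, `layout_for`, `LayoutCache`, and `Access` are invented names (`Access` stands in for wgpu's read-only vs read-write storage binding types), and the "inputs first, then outputs" binding order is an assumed convention.

```rust
use std::any::TypeId;
use std::collections::HashMap;

/// Access mode for each storage binding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Access {
    ReadOnly,
    ReadWrite,
}

/// Hypothetical description of a custom kernel's buffer interface.
pub trait KernelSignature: 'static {
    fn num_inputs(&self) -> usize;
    fn num_outputs(&self) -> usize;
}

/// One binding per buffer: inputs first (read-only), then outputs.
pub fn layout_for(sig: &dyn KernelSignature) -> Vec<(u32, Access)> {
    let inputs = (0..sig.num_inputs()).map(|_| Access::ReadOnly);
    let outputs = (0..sig.num_outputs()).map(|_| Access::ReadWrite);
    inputs.chain(outputs).enumerate().map(|(i, a)| (i as u32, a)).collect()
}

/// Caches one generated layout per kernel type.
#[derive(Default)]
pub struct LayoutCache {
    layouts: HashMap<TypeId, Vec<(u32, Access)>>,
    pub builds: usize,
}

impl LayoutCache {
    pub fn get<K: KernelSignature>(&mut self, kernel: &K) -> &[(u32, Access)] {
        let builds = &mut self.builds;
        self.layouts.entry(TypeId::of::<K>()).or_insert_with(|| {
            *builds += 1;
            layout_for(kernel)
        })
    }
}
```

Keying the cache by `TypeId` means every instance of the same kernel type shares one layout, which matches the goal of building layouts and pipelines once and reusing them.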

Full GPU Training Example

A complete end-to-end example: define a Transformer model, move it to GPU, train with AdamW, and save the trained weights.

```rust
use rumus::tensor::Tensor;
use rumus::nn::{
    self, Module, Linear, LayerNorm,
    TransformerBlock, Embedding,
};
use rumus::optim::{Trainer, AdamW};

#[derive(Module)]
struct TinyTransformer {
    embedding: Embedding,
    block1: TransformerBlock,
    block2: TransformerBlock,
    ln_final: LayerNorm,
    fc_out: Linear,
}

impl TinyTransformer {
    fn new(vocab_size: usize, d_model: usize, n_heads: usize, ff_dim: usize) -> Self {
        Self {
            embedding: Embedding::new(vocab_size, d_model),
            block1: TransformerBlock::new(d_model, n_heads, ff_dim, 0.1),
            block2: TransformerBlock::new(d_model, n_heads, ff_dim, 0.1),
            ln_final: LayerNorm::new(d_model),
            fc_out: Linear::new(d_model, vocab_size),
        }
    }

    fn forward(&self, x: &Tensor) -> Tensor {
        let x = self.embedding.forward(x);        // [B, T] → [B, T, D]
        let x = self.block1.forward(&x);           // self-attn + FFN
        let x = self.block2.forward(&x);           // self-attn + FFN
        let x = self.ln_final.forward(&x);         // final layer norm
        // Take the last token's representation
        let last = x.select(1, -1);                // [B, D]
        self.fc_out.forward(&last)                 // [B, vocab]
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut model = TinyTransformer::new(10000, 128, 4, 512);
    model.to_gpu();  // move all parameters to GPU
    model.train();

    let optimizer = AdamW::new(model.parameters(), 3e-4);
    let mut trainer = Trainer::new(optimizer);

    // `train_loader` is assumed to be defined elsewhere and to
    // yield (tokens, labels) batches.
    for epoch in 0..20 {
        for (tokens, labels) in &train_loader {
            let tokens = tokens.to_gpu();
            let labels = labels.to_gpu();

            trainer.train_step(|| {
                let logits = model.forward(&tokens);
                nn::cross_entropy_loss(&logits, &labels)
            });
        }

        let avg_loss = trainer.epoch_avg_loss();
        println!("Epoch {}: loss = {:.4}", epoch, avg_loss);
    }

    // Save trained model
    model.eval();
    let state = model.state_dict("");
    nn::save_safetensors(&state, "transformer.safetensors")?;

    Ok(())
}
```