GPU Acceleration
Transparent GPU acceleration via WebGPU. Move models and tensors to GPU with a single call — all operations automatically dispatch to 24+ WGSL compute shaders with CPU fallback.
Overview
RUMUS uses WebGPU (via wgpu) for GPU acceleration. The API is simple: call .to_gpu() on tensors or models, and all subsequent operations automatically run on the GPU. If no GPU is available, operations fall back to CPU without errors.
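A minimal sketch of the dispatch idea, using hypothetical stand-in types (SketchTensor, Device) rather than the real RUMUS Tensor: each operation looks at where its input lives and picks a path. The "GPU" branch here only tags the result and does not touch any hardware.

```rust
// Hypothetical stand-ins for illustration; these are not RUMUS types.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Gpu,
}

#[derive(Debug, Clone)]
struct SketchTensor {
    data: Vec<f32>,
    device: Device,
}

impl SketchTensor {
    fn new(data: Vec<f32>) -> Self {
        Self { data, device: Device::Cpu }
    }

    // Returns a copy tagged as GPU-resident (no real upload happens here).
    fn to_gpu(&self) -> Self {
        Self { data: self.data.clone(), device: Device::Gpu }
    }

    fn is_gpu(&self) -> bool {
        self.device == Device::Gpu
    }

    // Checks the tensor's location at call time and reports which path ran.
    fn relu(&self) -> (Self, &'static str) {
        let data: Vec<f32> = self.data.iter().map(|v| v.max(0.0)).collect();
        if self.is_gpu() {
            // A real backend would launch a compute shader here.
            (Self { data, device: Device::Gpu }, "gpu")
        } else {
            (Self { data, device: Device::Cpu }, "cpu")
        }
    }
}
```

Because the check happens per call, the same calling code works whether or not a tensor was moved with to_gpu().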
use rumus::tensor::Tensor;
use rumus::nn;
// Move a tensor to the GPU
let x = Tensor::randn(&[32, 784]);
let x_gpu = x.to_gpu();
// Move an entire model to the GPU
model.to_gpu();
// All operations automatically dispatch to GPU
let output = model.forward(&x_gpu);
Moving to GPU
Tensors provide .to_gpu() to create a GPU-backed copy. Models implement the ModuleToGpu trait, which moves all parameters in-place. Every operation checks is_gpu() at dispatch time to choose the GPU or CPU path.
// Tensors: .to_gpu() returns a GPU-backed tensor
let x = Tensor::randn(&[64, 3, 28, 28]);
let x_gpu = x.to_gpu();
// Models: .to_gpu() moves all parameters in-place
// via the ModuleToGpu trait
model.to_gpu();
// Operations check is_gpu() at dispatch time:
// → GPU tensor: runs WGSL compute shader
// → CPU tensor: runs CPU fallback
let output = model.forward(&x_gpu);
GPU Architecture
The GPU subsystem is built on three core components that work together to provide efficient, safe GPU computation.
GpuContext
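As a rough picture of the pattern behind this component, the sketch below builds a lazily initialized singleton with std::sync::OnceLock that stores None when no adapter is found, so callers fall back to the CPU instead of panicking. SketchGpuContext, probe_gpu, gpu_context, and backend_name are hypothetical names, and the probe always reports that no GPU is present so the example runs anywhere; a real probe would query wgpu.

```rust
use std::sync::OnceLock;

// Hypothetical placeholder for the device, queue, pipeline cache, and buffer pool.
#[derive(Debug)]
struct SketchGpuContext {
    adapter_name: String,
}

// Assumed probe: always reports "no GPU" so the sketch runs on any machine.
fn probe_gpu() -> Option<SketchGpuContext> {
    None
}

// Initialized at most once, on first access, even with concurrent callers.
static CONTEXT: OnceLock<Option<SketchGpuContext>> = OnceLock::new();

// Returns the shared context, or None when no GPU was found. Never panics.
fn gpu_context() -> Option<&'static SketchGpuContext> {
    CONTEXT.get_or_init(probe_gpu).as_ref()
}

// Callers branch on the Option to choose a backend.
fn backend_name() -> String {
    match gpu_context() {
        Some(ctx) => format!("gpu ({})", ctx.adapter_name),
        None => "cpu".to_string(),
    }
}
```

Storing an Option inside the OnceLock means the "no hardware" result is cached too, so the probe is not repeated on every operation.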
A OnceLock singleton that lazily initializes the GPU device, queue, pipeline cache, and buffer pool on first use. It never panics on missing hardware.
// GpuContext is a OnceLock singleton, initialized
// lazily on first GPU operation. Never panics on
// missing hardware; falls back to CPU gracefully.
//
// Internally holds:
// - wgpu::Device
// - wgpu::Queue
// - PipelineCache
// - BufferPool
PipelineCache
Pre-compiles 50+ compute pipelines across 24 shader modules. Pipeline selection is by enum variant, giving compile-time guarantees that every GPU operation has a valid pipeline.
// PipelineCache: 50+ compute pipelines across 24
// shader modules, all validated at compile time.
//
// Pipelines are created once and cached. Each GPU op
// looks up its pipeline by enum variant, guaranteeing
// no runtime pipeline compilation stalls.
//
// Bind group layouts are shared across compatible
// pipelines, minimizing GPU resource usage.
BufferPool
A thread-safe buffer cache using power-of-2 bucketing. Buffers are recycled via the Drop trait, eliminating repeated GPU memory allocation during training loops.
// BufferPool: thread-safe GPU buffer cache with
// power-of-2 bucketing. When a buffer is dropped,
// it is returned to the pool for reuse.
// Allocation: finds the smallest power-of-2 bucket
// that fits the request, returns a cached buffer
// or allocates a new one.
let buffer = pool.allocate(1024); // → 1024-byte buffer
// Drop: buffer is returned to its bucket
// automatically via the Drop trait.
drop(buffer); // → recycled, not deallocated
// This eliminates repeated GPU memory allocation
// during training loops.
WGSL Shader Modules
RUMUS ships 24 WGSL shader modules with 80+ entry points, covering every operation needed for training and inference — from element-wise ops to fused layer normalization and batched matrix multiplication.
| Category | Modules | Operations |
|---|---|---|
| Element-wise & Activations | elementwise.wgsl, activations.wgsl | add, sub, mul, div, neg, relu, sigmoid, tanh, gelu, leaky_relu (forward + backward) |
| Linear Algebra | matmul.wgsl, bmm.wgsl, bias.wgsl | matmul, batched matmul (3D Z-axis dispatch), bias add |
| Broadcasting & Scaling | broadcast.wgsl, broadcast_scale.wgsl, fused_scale.wgsl | N-dim broadcast binary ops, reduce_sum, fused scaling |
| Convolution & Pooling | conv.wgsl, pool.wgsl, adaptive_pool.wgsl | conv2d forward/backward, max_pool2d forward/backward, adaptive avg/max pool |
| Normalization | layer_norm.wgsl, layer_norm_bw.wgsl, layer_norm_grad_weight.wgsl, batch_norm.wgsl, batch_norm_bw.wgsl | layer norm (3-phase fused forward, backward, grad weight), batch norm (per-channel with running stats, backward) |
| Softmax | softmax.wgsl, softmax_bw.wgsl | row-wise softmax with Log-Sum-Exp trick, softmax backward |
| Dropout & Loss | dropout.wgsl, fused_dropout.wgsl, cross_entropy.wgsl | dropout, fused dropout, cross_entropy_loss (forward & backward) |
| Optimizer | optim.wgsl, adamw.wgsl | sgd_step, adam_step, adamw_step (fused weight decay) |
| Utility | contiguous.wgsl, embedding.wgsl, unary.wgsl | contiguous copy, embedding lookup, unary ops (exp, log, sqrt, neg) |
| Casting | cast.wgsl | F32↔F16 cast kernels (requires enable f16;) |
| Gradient Clipping | reduce_sum_sq.wgsl | Sum of squares reduction for gradient clipping (shared memory tree reduction) |
| Quantization | quantize.wgsl, dequantize.wgsl, matmul_q8.wgsl | Symmetric block quantization (parallel abs_max reduction, packs 4 i8s per u32), block dequantization (sign-extended i8 extraction), mixed-precision matmul (scalar activations × Q8 weights, 16x16 tiled, on-the-fly dequant in registers) |
| Attention | flash_attn.wgsl | FlashAttention: block-tiled online softmax, 64 threads, 16KB shared memory, eliminates O(N²) VRAM |
| Sparse Graph (rumus-graph) | spmm.wgsl | Sparse Matrix-Matrix Multiply: 1 thread/node, edge-outer/dim-inner loop, 256 threads/workgroup |
| Vision (rumus-vision) | conv2d_direct.wgsl, conv2d_backward_data.wgsl, conv2d_backward_weight.wgsl, maxpool2d_direct.wgsl, maxpool2d_backward.wgsl, qmatmul_int4.wgsl, qmatmul_int4_transpose.wgsl | Direct sliding-window 2D convolution with padding/stride/dilation, gradient w.r.t. input for transposed convolution routing, gradient w.r.t. weight accumulation, max pooling with f16-safe argmax tracking, scatter backward for max pooling, fused INT4 dequant-matmul with register-level bit unpacking |
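To make the Quantization row above concrete, here is a CPU reference sketch of symmetric quantization that packs four i8 values into one u32 and recovers them with sign extension. It is not the WGSL kernel: the block size of four, the scaling by 127, and the function names are assumptions chosen for illustration.

```rust
// Quantizes one block of four values; returns (scale, packed word).
// Assumed scheme: scale = abs_max / 127, values rounded to the nearest i8.
fn quantize_block4(values: [f32; 4]) -> (f32, u32) {
    let abs_max = values.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    // Avoid dividing by zero for an all-zero block.
    let scale = if abs_max == 0.0 { 1.0 } else { abs_max / 127.0 };
    let mut packed = 0u32;
    for (i, v) in values.iter().enumerate() {
        let q = (v / scale).round().clamp(-127.0, 127.0) as i8;
        // Store the two's-complement byte in lane i (lane 0 = lowest byte).
        packed |= u32::from(q as u8) << (8 * i);
    }
    (scale, packed)
}

// Extracts each byte, reinterprets it as a signed i8, and rescales.
fn dequantize_block4(scale: f32, packed: u32) -> [f32; 4] {
    let mut out = [0.0f32; 4];
    for (i, slot) in out.iter_mut().enumerate() {
        let byte = ((packed >> (8 * i)) & 0xFF) as u8;
        let q = byte as i8; // u8 -> i8 keeps the bit pattern, restoring the sign
        *slot = q as f32 * scale;
    }
    out
}
```

With inputs that are already integers in the i8 range and an absolute maximum of 127, the scale is exactly 1.0 and the round trip is lossless; in general the result is only an approximation of the input.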
Per-Resource Fences
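The bookkeeping behind this approach can be sketched on the CPU with two counters. SketchQueue and SketchBuffer are hypothetical names, and the sketch assumes submissions complete in order, which the source does not state.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical queue: counts submissions and tracks the highest completed index.
#[derive(Default)]
struct SketchQueue {
    submitted: AtomicUsize,
    completed: AtomicUsize,
}

// Hypothetical buffer: remembers the last submission that wrote to it.
// Zero means the buffer has never been written by the GPU.
#[derive(Default)]
struct SketchBuffer {
    last_submission: AtomicUsize,
}

impl SketchQueue {
    // Records a submission that writes to `buffer` and returns its index.
    fn submit_write(&self, buffer: &SketchBuffer) -> usize {
        let index = self.submitted.fetch_add(1, Ordering::AcqRel) + 1;
        buffer.last_submission.store(index, Ordering::Release);
        index
    }

    // Marks every submission up to `index` as finished.
    // In a real backend a GPU completion callback would call this.
    fn mark_completed(&self, index: usize) {
        self.completed.fetch_max(index, Ordering::AcqRel);
    }

    // A buffer is readable once its own last submission has completed,
    // regardless of later submissions that touched other buffers.
    fn is_ready(&self, buffer: &SketchBuffer) -> bool {
        buffer.last_submission.load(Ordering::Acquire)
            <= self.completed.load(Ordering::Acquire)
    }
}
```

A reader of buffer A therefore never waits on work that only wrote buffer B, which is the point of tracking the fence per resource.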
Instead of global pipeline barriers, RUMUS uses per-resource fences with AtomicUsize sentinels. Each GPU buffer tracks its last submission index. Before a read, only that specific buffer's fence is awaited, so independent operations on different buffers proceed in parallel without stalls.
// Per-resource fences: each GPU buffer carries an
// AtomicUsize sentinel tracking its last submission.
//
// Before reading a buffer, RUMUS checks its fence:
// - If the fence matches the current submission
// index, the buffer is ready.
// - Otherwise, only that buffer's fence is awaited.
//
// This avoids global pipeline stalls: independent
// operations on different buffers proceed in parallel.
WGSL Preprocessor
RUMUS includes a shader preprocessor that enables precision-agnostic WGSL shaders. The function preprocess_shader(source, dtype) prepends the appropriate type alias based on the target dtype:
DType::F32: prepends alias scalar = f32;
DType::F16: prepends enable f16; alias scalar = f16;
This allows 35+ shaders to be written with a generic scalar type and compiled for either precision without code duplication.
Dual pipeline compilation: The F32 pipeline set is always compiled. The F16 pipeline set is compiled only when GpuContext::supports_f16 returns true, avoiding compilation errors on hardware without F16 support.
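A minimal sketch of the preprocessing step and the dual compilation rule described above. The function name preprocess_shader and the DType variants come from the text; the newline placement and the helper dtypes_to_compile are assumptions.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F32,
    F16,
}

// Prepends the type alias (and the f16 enable directive when needed)
// so the shader body can refer to a generic `scalar` type.
fn preprocess_shader(source: &str, dtype: DType) -> String {
    let prelude = match dtype {
        DType::F32 => "alias scalar = f32;\n",
        DType::F16 => "enable f16;\nalias scalar = f16;\n",
    };
    format!("{prelude}{source}")
}

// Hypothetical helper: F32 is always compiled, F16 only when supported.
fn dtypes_to_compile(supports_f16: bool) -> Vec<DType> {
    let mut dtypes = vec![DType::F32];
    if supports_f16 {
        dtypes.push(DType::F16);
    }
    dtypes
}
```

The enable directive is placed first because WGSL requires directives to appear before any declarations in a module.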
JIT Kernel Fusion
jit::compile(|| { ... }) captures element-wise ops and fuses them into a single WGSL kernel. The closure is traced into a FusionBlock IR, which is compiled once and cached in a JitCache for O(1) pipeline reuse. This reduces VRAM bandwidth pressure by eliminating intermediate buffer round-trips between kernels.
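A toy sketch of the fusion idea, not the RUMUS implementation: a recorded list of element-wise ops is folded into one WGSL expression, and the generated text is cached by op sequence. ElemOp, fused_expression, and SketchJitCache are hypothetical names; the real FusionBlock IR and JitCache are not shown in the source.

```rust
use std::collections::HashMap;

// Hypothetical element-wise ops as a tracer might record them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum ElemOp {
    MulW,
    AddB,
    Relu,
}

// Folds the op list into a single expression over element i,
// so no intermediate buffer is written between ops.
fn fused_expression(ops: &[ElemOp]) -> String {
    let mut expr = String::from("x[i]");
    for op in ops {
        expr = match op {
            ElemOp::MulW => format!("({expr} * w[i])"),
            ElemOp::AddB => format!("({expr} + b[i])"),
            ElemOp::Relu => format!("max({expr}, 0.0)"),
        };
    }
    expr
}

// Caches generated kernel bodies by op sequence; `compiles` counts cache misses.
#[derive(Default)]
struct SketchJitCache {
    kernels: HashMap<Vec<ElemOp>, String>,
    compiles: usize,
}

impl SketchJitCache {
    fn get_or_compile(&mut self, ops: &[ElemOp]) -> &str {
        if !self.kernels.contains_key(ops) {
            self.compiles += 1;
            let body = format!("out[i] = {};", fused_expression(ops));
            self.kernels.insert(ops.to_vec(), body);
        }
        &self.kernels[ops]
    }
}
```

Three separate kernels would read and write global memory three times per element; the fused expression reads each input once and writes once.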
Feature gate: Enable with --features jit.
// JIT Kernel Fusion: captures element-wise ops and
// fuses them into a single WGSL kernel at runtime.
//
// jit::compile(|| { ... }) traces the closure, builds
// a FusionBlock IR, and compiles a fused WGSL kernel.
// JitCache stores compiled pipelines for O(1) reuse.
//
// This reduces VRAM bandwidth pressure by avoiding
// intermediate buffer round-trips between kernels.
//
// Enable with: --features jit
use rumus::jit;
// x, w, and b are tensors defined earlier
let fused = jit::compile(|| {
    let y = x.mul(&w).add(&b);
    y.relu()
});
Multi-GPU
MultiGpuContext::get() enumerates all available GPUs. DataParallel provides scatter-gather training by splitting batches across devices. FSDP enables parameter sharding for large models that exceed single-GPU memory. AllReduceSync handles gradient averaging across devices.
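The scatter and gradient-averaging steps can be sketched on plain vectors. This is a CPU-only illustration under assumed names (scatter, all_reduce_mean); it does not use the RUMUS DataParallel or AllReduceSync types.

```rust
// Splits `batch` into at most `n_devices` contiguous shards of near-equal size.
fn scatter(batch: &[f32], n_devices: usize) -> Vec<Vec<f32>> {
    assert!(n_devices > 0, "need at least one device");
    // Ceiling division so the last shard absorbs the remainder.
    let chunk = (batch.len() + n_devices - 1) / n_devices;
    batch.chunks(chunk.max(1)).map(|c| c.to_vec()).collect()
}

// Averages per-device gradients element by element.
// Assumes every device reports a gradient of the same length.
fn all_reduce_mean(grads: &[Vec<f32>]) -> Vec<f32> {
    let n = grads.len() as f32;
    let len = grads.first().map_or(0, |g| g.len());
    (0..len)
        .map(|i| grads.iter().map(|g| g[i]).sum::<f32>() / n)
        .collect()
}
```

After the average, every replica applies the same update, so the model copies stay in sync across devices.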
Feature gate: Enable with --features multi_gpu.
// Multi-GPU support: scatter-gather training and
// parameter sharding across multiple GPUs.
//
// MultiGpuContext::get() enumerates all available GPUs.
// DataParallel splits batches across devices.
// FSDP shards parameters for memory efficiency.
// AllReduceSync averages gradients across devices.
//
// Enable with: --features multi_gpu
use rumus::multi_gpu::{MultiGpuContext, DataParallel, FSDP};
let ctx = MultiGpuContext::get();
let parallel_model = DataParallel::new(&model, &ctx);
// or for large models:
let sharded_model = FSDP::new(&model, &ctx);
Custom Ops API
The CustomOp trait lets you register user-defined WGSL kernels. CustomOpCache handles dynamic bind group layout generation and pipeline caching, so your custom kernels integrate seamlessly with the existing GPU subsystem.
Feature gate: Enable with --features gpu.
// Custom Ops API: register user-defined WGSL kernels.
//
// Implement the CustomOp trait and register your kernel.
// CustomOpCache handles dynamic bind group layout
// generation and pipeline caching.
//
// Enable with: --features gpu
use rumus::gpu::{CustomOp, CustomOpCache};
struct MyKernel;
impl CustomOp for MyKernel {
    fn wgsl_source(&self) -> &str { include_str!("my_kernel.wgsl") }
    fn dispatch(&self, encoder: &mut Encoder, buffers: &[&Buffer]) { /* ... */ }
}
Full GPU Training Example
A complete end-to-end example: define a Transformer model, move it to GPU, train with AdamW, and save the trained weights.
use rumus::tensor::Tensor;
use rumus::nn::{
    self, Module, Linear, LayerNorm, Dropout,
    MultiheadAttention, TransformerBlock, Embedding, Flatten,
};
use rumus::optim::{Trainer, AdamW};

#[derive(Module)]
struct TinyTransformer {
    embedding: Embedding,
    block1: TransformerBlock,
    block2: TransformerBlock,
    ln_final: LayerNorm,
    fc_out: Linear,
}

impl TinyTransformer {
    fn new(vocab_size: usize, d_model: usize, n_heads: usize, ff_dim: usize) -> Self {
        Self {
            embedding: Embedding::new(vocab_size, d_model),
            block1: TransformerBlock::new(d_model, n_heads, ff_dim, 0.1),
            block2: TransformerBlock::new(d_model, n_heads, ff_dim, 0.1),
            ln_final: LayerNorm::new(d_model),
            fc_out: Linear::new(d_model, vocab_size),
        }
    }

    fn forward(&self, x: &Tensor) -> Tensor {
        let x = self.embedding.forward(x); // [B, T] → [B, T, D]
        let x = self.block1.forward(&x); // self-attn + FFN
        let x = self.block2.forward(&x); // self-attn + FFN
        let x = self.ln_final.forward(&x); // final layer norm
        // Take the last token's representation
        let last = x.select(1, -1); // [B, D]
        self.fc_out.forward(&last) // [B, vocab]
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut model = TinyTransformer::new(10000, 128, 4, 512);
    model.to_gpu(); // move all parameters to GPU
    model.train();
    let optimizer = AdamW::new(model.parameters(), 3e-4);
    let mut trainer = Trainer::new(optimizer);
    // train_loader is assumed to be defined elsewhere and to yield (tokens, labels) batches
    for epoch in 0..20 {
        for (tokens, labels) in &train_loader {
            let tokens = tokens.to_gpu();
            let labels = labels.to_gpu();
            trainer.train_step(|| {
                let logits = model.forward(&tokens);
                nn::cross_entropy_loss(&logits, &labels)
            });
        }
        let avg_loss = trainer.epoch_avg_loss();
        println!("Epoch {}: loss = {:.4}", epoch, avg_loss);
    }
    // Save trained model
    model.eval();
    let state = model.state_dict("");
    nn::save_safetensors(&state, "transformer.safetensors")?;
    Ok(())
}