Neural Networks
Build neural networks in pure Rust with RUMUS's module system, automatic parameter tracking, and built-in serialization. From CNNs to Transformers, all layers integrate with the autograd engine.
Module System
The #[derive(Module)] proc macro is the foundation of RUMUS neural networks. Apply it to any struct containing layer fields, and it automatically generates implementations for parameter collection, training/eval mode toggling, and state serialization.
```rust
use rumus::nn::{self, Module, Linear, Conv2d, ConvTranspose2d,
                MaxPool2d, AdaptiveAvgPool2d, Flatten, Dropout,
                BatchNorm2d, LayerNorm, Embedding,
                MultiheadAttention, TransformerBlock};
use rumus::tensor::Tensor;

#[derive(Module)]
struct MyModel {
    fc1: Linear,
    fc2: Linear,
    dropout: Dropout,
}

impl MyModel {
    fn new() -> Self {
        Self {
            fc1: Linear::new(784, 128), // 784 inputs (a flattened 28x28 image), 128 hidden units
            fc2: Linear::new(128, 10),  // 10 output classes
            dropout: Dropout::new(0.5),
        }
    }

    fn forward(&self, x: &Tensor) -> Tensor {
        let x = nn::relu(&self.fc1.forward(x));
        let x = self.dropout.forward(&x);
        self.fc2.forward(&x)
    }
}
```
The macro auto-generates five essential methods on your struct:
```rust
let model = MyModel::new();

// All auto-generated by #[derive(Module)]:
let params = model.parameters();  // Vec<Parameter>
model.train();                    // enable training mode
model.eval();                     // enable inference mode
let state = model.state_dict(""); // serialize weights
model.load_state_dict(&state);    // restore weights
```
Layers
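RUMUS provides 12 layer types covering fully-connected layers, convolutions, pooling, normalization, regularization, embeddings, and attention. The table below lists the non-attention layers. BatchNorm2d appears twice because it has two constructors.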
| Layer | Constructor | Description |
|---|---|---|
| Linear | Linear::new(in_features, out_features) | Fully-connected layer with weight and bias |
| Conv2d | Conv2d::new(in_ch, out_ch, kernel) | 2D convolution with learned filters; supports stride, padding, bias options |
| ConvTranspose2d | ConvTranspose2d::new(in_ch, out_ch, kernel, stride, padding, with_bias) | Transposed convolution for learned upsampling (deconv) |
| MaxPool2d | MaxPool2d::new(kernel, stride) | 2D max pooling for spatial downsampling |
| AdaptiveAvgPool2d | AdaptiveAvgPool2d::new(output_h, output_w) | Average pooling to a fixed output size, regardless of the input's spatial size |
| Flatten | Flatten::new() | Flattens all non-batch dimensions into one (e.g. [B, C, H, W] → [B, C·H·W]) |
| Dropout | Dropout::new(p) | Randomly zero elements during training |
| BatchNorm2d | BatchNorm2d::new(num_features) | Per-channel normalization with running mean/variance stats |
| BatchNorm2d (config) | BatchNorm2d::with_config(num_features, epsilon, momentum) | BatchNorm with custom epsilon and momentum values |
| LayerNorm | LayerNorm::new(norm_size, epsilon) | Normalizes the last dimension with learnable affine parameters |
| Embedding | Embedding::new(vocab_size, embed_dim) | Token lookup table mapping integer indices to dense vectors |
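AdaptiveAvgPool2d is the only layer in this table whose pooling window is derived from the input size rather than fixed. The sketch below shows in 1D how such bins are commonly computed. It assumes the PyTorch-style convention, where bin i covers input positions floor(i·in/out) up to ceil((i+1)·in/out). RUMUS may compute its bins differently, and the helper names are hypothetical. The 2D layer would apply the same rule independently to height and width.

```rust
/// Start (inclusive) and end (exclusive) of adaptive-pooling bin `i`
/// when mapping `in_len` input positions onto `out_len` bins.
/// Assumes the PyTorch-style floor/ceil convention (not confirmed for RUMUS).
fn adaptive_bin(i: usize, in_len: usize, out_len: usize) -> (usize, usize) {
    let start = (i * in_len) / out_len;                   // floor
    let end = ((i + 1) * in_len + out_len - 1) / out_len; // ceil
    (start, end)
}

/// 1D adaptive average pooling: always returns exactly `out_len` values,
/// whatever the (non-empty) input length is.
fn adaptive_avg_pool1d(input: &[f32], out_len: usize) -> Vec<f32> {
    assert!(out_len > 0 && !input.is_empty());
    (0..out_len)
        .map(|i| {
            let (start, end) = adaptive_bin(i, input.len(), out_len);
            let window = &input[start..end];
            window.iter().sum::<f32>() / window.len() as f32
        })
        .collect()
}
```

Bins may overlap when the input length is not a multiple of the output length. For 11 inputs and 4 outputs, the first bin covers positions 0..3 and the second covers 2..6. This overlap is what lets the layer accept arbitrary input sizes.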
Attention & Transformer
RUMUS includes first-class support for attention-based architectures. MultiheadAttention handles Q/K/V projections and head splitting, while TransformerBlock implements the full pre-norm transformer layer with a 4x-expanded MLP.
| Component | Constructor | Description |
|---|---|---|
| MultiheadAttention | MultiheadAttention::new(d_model, num_heads) | Q/K/V linear projections, head splitting, scaled dot-product attention, output projection |
| TransformerBlock | TransformerBlock::new(d_model, num_heads) | Pre-norm block: LayerNorm → Attn → Residual → LayerNorm → MLP(4x) → Residual |
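The attention step inside MultiheadAttention reduces to one formula per head. Here is a minimal single-head sketch of scaled dot-product attention on plain `Vec<f32>` rows, with an optional causal mask. It is a reference for the math only, not RUMUS's implementation, and `attention` is a hypothetical helper name.

```rust
/// Naive single-head attention: softmax(Q Kᵀ / sqrt(d_k)) V.
/// `q`, `k`, `v` hold one row per sequence position.
/// With `causal = true`, query i may only attend to keys 0..=i.
fn attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>], causal: bool) -> Vec<Vec<f32>> {
    let scale = 1.0 / (q[0].len() as f32).sqrt();
    q.iter()
        .enumerate()
        .map(|(i, qi)| {
            // Raw scores for every key; masked (future) positions get -inf.
            let scores: Vec<f32> = k
                .iter()
                .enumerate()
                .map(|(j, kj)| {
                    if causal && j > i {
                        f32::NEG_INFINITY
                    } else {
                        qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale
                    }
                })
                .collect();
            // Numerically stable softmax: subtract the row max; exp(-inf) = 0.
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let weights: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
            let sum: f32 = weights.iter().sum();
            // Output row = softmax-weighted sum of the value rows.
            let mut out = vec![0.0f32; v[0].len()];
            for (w, vj) in weights.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += w / sum * x;
                }
            }
            out
        })
        .collect()
}
```

Because key 0 is never masked, every row has at least one finite score, so the softmax denominator is always positive.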
The standalone scaled_dot_product_attention function computes softmax(QKᵀ / √d_k)V with an optional causal mask:
```rust
// Scaled dot-product attention (standalone function)
// Computes softmax(Q K^T / sqrt(d_k)) V with an optional causal mask
let attn_output = nn::scaled_dot_product_attention(&q, &k, &v, Some(&mask));

// Without a mask
let attn_output = nn::scaled_dot_product_attention(&q, &k, &v, None);
```
Activations & Loss Functions
Six activation functions and two loss functions are provided as free functions in the nn module. They operate on tensor references and return new tensors, fully integrated with the autograd graph.
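The formulas in the code comments that follow can also be written as scalar reference functions, which are handy for checking tensor results element by element. This is a plain-Rust sketch, not RUMUS code. It assumes that `cross_entropy_loss` combines log-softmax with negative log-likelihood over integer class labels, which is the usual convention.

```rust
/// ReLU: max(0, x).
fn relu(x: f32) -> f32 { x.max(0.0) }

/// GELU, tanh approximation:
/// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_tanh(x: f32) -> f32 {
    let c = (2.0 / std::f32::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}

/// Sigmoid: 1 / (1 + exp(-x)).
fn sigmoid(x: f32) -> f32 { 1.0 / (1.0 + (-x).exp()) }

/// Leaky ReLU: max(alpha * x, x) for 0 < alpha < 1.
fn leaky_relu(x: f32, alpha: f32) -> f32 { if x >= 0.0 { x } else { alpha * x } }

/// Softmax over one row, shifted by the row max for numerical stability.
fn softmax(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Mean squared error over paired predictions and targets.
fn mse_loss(pred: &[f32], target: &[f32]) -> f32 {
    pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum::<f32>() / pred.len() as f32
}

/// Cross-entropy for one sample (assumed convention): -ln(softmax(logits)[label]).
fn cross_entropy(logits: &[f32], label: usize) -> f32 {
    -softmax(logits)[label].ln()
}
```

A quick sanity check: uniform logits over C classes give a cross-entropy of ln C.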
```rust
// Activations (6 total)
let a = nn::relu(&x);             // max(0, x)
let a = nn::gelu(&x);             // GELU with tanh approximation
let a = nn::sigmoid(&x);          // 1 / (1 + exp(-x))
let a = nn::tanh(&x);             // hyperbolic tangent
let a = nn::leaky_relu(&x, 0.01); // max(alpha * x, x) with alpha = 0.01
let a = nn::softmax(&x);          // exp(x_i) / sum_j(exp(x_j))

// Loss functions
let loss = nn::mse_loss(&predictions, &targets);
let loss = nn::cross_entropy_loss(&logits, &labels);
```
Example: CNN with BatchNorm
A convolutional network using Conv2d, BatchNorm2d, MaxPool2d, AdaptiveAvgPool2d, and Dropout. The AdaptiveAvgPool2d layer ensures the network accepts variable-size inputs by pooling to a fixed spatial output.
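In the ConvNet code that follows, `fc1` is declared as `Linear::new(64 * 4 * 4, 128)`. That works because the feature map reaching it is always 64×4×4. The shape trace below uses the standard convolution/pooling output-size formula. It assumes that `Conv2d::new(in, out, kernel)` defaults to stride 1 and no padding, which the layer table does not state. `out_size` and `convnet_trace` are hypothetical helpers.

```rust
/// Standard output-size formula shared by convolution and pooling:
/// floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1.
fn out_size(input: usize, kernel: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

/// Spatial side length after each ConvNet stage for a square input.
/// ASSUMPTION: the 3x3 convolutions use stride 1 and no padding.
fn convnet_trace(side: usize) -> Vec<(&'static str, usize)> {
    let conv1 = out_size(side, 3, 1, 0, 1); // Conv2d::new(1, 32, 3)
    let pool = out_size(conv1, 2, 2, 0, 1); // MaxPool2d::new(2, 2)
    let conv2 = out_size(pool, 3, 1, 0, 1); // Conv2d::new(32, 64, 3)
    // AdaptiveAvgPool2d::new(4, 4) always yields 4x4, whatever `conv2` is.
    vec![("conv1", conv1), ("pool", pool), ("conv2", conv2), ("adaptive_pool", 4)]
}
```

Under these assumptions, a 28×28 input shrinks to 26, 13, and then 11 before adaptive pooling maps it to 4×4. A 64×64 input reaches 29 at the same point and still ends at 4×4. Inputs smaller than 8×8 would leave nothing for the second convolution.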
```rust
#[derive(Module)]
struct ConvNet {
    conv1: Conv2d,
    bn1: BatchNorm2d,
    conv2: Conv2d,
    bn2: BatchNorm2d,
    pool: MaxPool2d,
    adaptive_pool: AdaptiveAvgPool2d,
    flatten: Flatten,
    fc1: Linear,
    dropout: Dropout,
    fc2: Linear,
}

impl ConvNet {
    fn new() -> Self {
        Self {
            conv1: Conv2d::new(1, 32, 3), // 1 channel in, 32 out, 3x3 kernel
            bn1: BatchNorm2d::new(32),
            conv2: Conv2d::new(32, 64, 3),
            bn2: BatchNorm2d::new(64),
            pool: MaxPool2d::new(2, 2),
            adaptive_pool: AdaptiveAvgPool2d::new(4, 4), // fixed 4x4 output
            flatten: Flatten::new(),
            fc1: Linear::new(64 * 4 * 4, 128), // 64 channels x 4 x 4
            dropout: Dropout::new(0.5),
            fc2: Linear::new(128, 10),
        }
    }

    fn forward(&self, x: &Tensor) -> Tensor {
        let x = nn::relu(&self.bn1.forward(&self.conv1.forward(x)));
        let x = self.pool.forward(&x);
        let x = nn::relu(&self.bn2.forward(&self.conv2.forward(&x)));
        let x = self.adaptive_pool.forward(&x); // variable input -> fixed 4x4
        let x = self.flatten.forward(&x);       // [B, 64, 4, 4] -> [B, 1024]
        let x = nn::relu(&self.fc1.forward(&x));
        let x = self.dropout.forward(&x);
        self.fc2.forward(&x)
    }
}
```
Example: TinyGPT Transformer
A GPT-style causal language model built with Embedding, TransformerBlock, and LayerNorm. Each TransformerBlock contains a pre-norm multi-head self-attention layer and a feed-forward MLP that expands to 4x the model dimension.
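The usage code at the end of this example assumes that `token_ids`, `positions`, and `target_ids` already exist. For next-token prediction they are usually derived from one token stream. The targets are the inputs shifted left by one, and the positions are simply 0..seq_len. The plain-Rust sketch below covers that preparation step. `next_token_windows` and `Window` are hypothetical names. Converting the vectors into RUMUS tensors is left out because that constructor is not shown in this document.

```rust
/// One training example for causal language modeling.
#[derive(Debug, PartialEq)]
struct Window {
    token_ids: Vec<u32>,  // model input
    target_ids: Vec<u32>, // token_ids shifted left by one
    positions: Vec<u32>,  // indices for pos_embed (must stay below max_seq_len)
}

/// Cut a token stream into non-overlapping windows of `seq_len` inputs (seq_len > 0).
/// Each window reads seq_len + 1 tokens so that every input has a next-token target.
fn next_token_windows(tokens: &[u32], seq_len: usize) -> Vec<Window> {
    tokens
        .windows(seq_len + 1)
        .step_by(seq_len)
        .map(|w| Window {
            token_ids: w[..seq_len].to_vec(),
            target_ids: w[1..].to_vec(),
            positions: (0..seq_len as u32).collect(),
        })
        .collect()
}
```

Consecutive windows share one boundary token: the last target of one window is the first input of the next.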
```rust
#[derive(Module)]
struct TinyGPT {
    token_embed: Embedding,
    pos_embed: Embedding,
    blocks: Vec<TransformerBlock>,
    ln_f: LayerNorm,
    head: Linear,
}

impl TinyGPT {
    fn new(
        vocab_size: usize,
        d_model: usize,
        num_heads: usize,
        num_layers: usize,
        max_seq_len: usize,
    ) -> Self {
        let blocks = (0..num_layers)
            .map(|_| TransformerBlock::new(d_model, num_heads))
            .collect();
        Self {
            token_embed: Embedding::new(vocab_size, d_model),
            pos_embed: Embedding::new(max_seq_len, d_model),
            blocks,
            ln_f: LayerNorm::new(d_model, 1e-5),
            head: Linear::new(d_model, vocab_size),
        }
    }

    fn forward(&self, token_ids: &Tensor, positions: &Tensor) -> Tensor {
        // Embed tokens and add learned positional embeddings
        let mut x = self.token_embed.forward(token_ids)
            + self.pos_embed.forward(positions);

        // Pass through the transformer blocks
        // Each block: LayerNorm -> MultiheadAttention -> Residual
        //          -> LayerNorm -> MLP (4x expand) -> Residual
        for block in &self.blocks {
            x = block.forward(&x);
        }

        // Final layer norm, then project to vocabulary logits
        let x = self.ln_f.forward(&x);
        self.head.forward(&x) // logits: [batch, seq_len, vocab_size]
    }
}

// Usage
let model = TinyGPT::new(
    50257, // vocab_size (GPT-2 tokenizer)
    512,   // d_model
    8,     // num_heads
    6,     // num_layers
    1024,  // max_seq_len
);
model.train();
let logits = model.forward(&token_ids, &positions);
let loss = nn::cross_entropy_loss(&logits, &target_ids);
loss.backward();
```
State Management
RUMUS uses the SafeTensors format for model serialization. Weights are keyed with dot-path notation (e.g., "blocks.0.attn.q_proj.weight"), making state dicts human-readable and compatible with nested module hierarchies.
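These dot-path keys follow from a simple recursive rule. Each struct field appends `field_name.` to the prefix, and each `Vec` element appends its index. The minimal sketch below implements that rule with a hypothetical, hand-written `StateKeys` trait. The code that `#[derive(Module)]` actually generates is not shown in this document and may differ.

```rust
/// Hypothetical stand-in for the key-naming half of `state_dict`.
trait StateKeys {
    /// Append fully qualified parameter names under `prefix` to `out`.
    fn keys(&self, prefix: &str, out: &mut Vec<String>);
}

/// Leaf layer with a weight and a bias, like `Linear`.
struct Affine;

impl StateKeys for Affine {
    fn keys(&self, prefix: &str, out: &mut Vec<String>) {
        out.push(format!("{prefix}weight"));
        out.push(format!("{prefix}bias"));
    }
}

/// A `Vec` of modules contributes `index.` to the path.
impl<T: StateKeys> StateKeys for Vec<T> {
    fn keys(&self, prefix: &str, out: &mut Vec<String>) {
        for (i, m) in self.iter().enumerate() {
            m.keys(&format!("{prefix}{i}."), out);
        }
    }
}

// Nested modules: each impl is what a derive could expand to, field by field.
struct Attn { q_proj: Affine }
struct Block { attn: Attn }
struct Model { blocks: Vec<Block>, head: Affine }

impl StateKeys for Attn {
    fn keys(&self, prefix: &str, out: &mut Vec<String>) {
        self.q_proj.keys(&format!("{prefix}q_proj."), out);
    }
}

impl StateKeys for Block {
    fn keys(&self, prefix: &str, out: &mut Vec<String>) {
        self.attn.keys(&format!("{prefix}attn."), out);
    }
}

impl StateKeys for Model {
    fn keys(&self, prefix: &str, out: &mut Vec<String>) {
        self.blocks.keys(&format!("{prefix}blocks."), out);
        self.head.keys(&format!("{prefix}head."), out);
    }
}
```

Starting from an empty prefix, which matches `state_dict("")`, this yields keys such as `blocks.0.attn.q_proj.weight` and `head.bias`.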
```rust
// Save model weights
let state = model.state_dict("");
nn::save_safetensors(&state, "model.safetensors")?;
// Keys use dot-path notation:
// "token_embed.weight", "blocks.0.attn.q_proj.weight",
// "blocks.0.ln1.weight", "head.bias", etc.

// Load model weights
let state = nn::load_safetensors("model.safetensors")?;
model.load_state_dict(&state);
```
ONNX Export
RUMUS can export trained models to the ONNX format for deployment with runtimes like ONNX Runtime and TensorRT. The exporter is feature-gated — enable it with --features onnx (this adds a prost dependency for protobuf serialization).
Two entry points are provided: export_onnx(model, input_specs, path, forward_fn) exports using the default opset 17, while export_onnx_with_opset(...) lets you specify a custom opset version for compatibility with older runtimes.
Under the hood, a thread-local tracer intercepts eager execution to build the ONNX graph. Module-level fusions map high-level layers to efficient ONNX operators — Linear becomes Gemm and Conv2d becomes Conv, suppressing the primitive ops that would otherwise appear. The exporter supports F16 and F32 weight serialization; Q8 quantized weights are auto-dequantized to F32 for maximum compatibility.
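The Linear to Gemm fusion works because ONNX `Gemm` computes `Y = alpha·A'·B' + beta·C`, where `B'` is `B` transposed when the `transB` attribute is 1. With `alpha = beta = 1` and `transB = 1`, that is exactly `y = x·Wᵀ + b` for a weight stored as `[out_features, in_features]`. The plain-Rust sketch below checks this equivalence. The `[out, in]` weight layout is an assumption about how RUMUS stores `Linear` weights, and the helper names are hypothetical.

```rust
/// ONNX Gemm semantics with transA = 0: Y = alpha * A * op(B) + beta * C,
/// where op(B) = Bᵀ if `trans_b`, else B, and C is a bias row broadcast over rows.
fn gemm(a: &[Vec<f32>], b: &[Vec<f32>], c: &[f32], alpha: f32, beta: f32, trans_b: bool) -> Vec<Vec<f32>> {
    let n = if trans_b { b.len() } else { b[0].len() };
    a.iter()
        .map(|row| {
            (0..n)
                .map(|j| {
                    let dot: f32 = row
                        .iter()
                        .enumerate()
                        .map(|(k, x)| {
                            let b_kj = if trans_b { b[j][k] } else { b[k][j] };
                            x * b_kj
                        })
                        .sum();
                    alpha * dot + beta * c[j]
                })
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// Linear layer as usually defined: y = x Wᵀ + b, with W of shape [out, in].
fn linear(x: &[Vec<f32>], w: &[Vec<f32>], bias: &[f32]) -> Vec<Vec<f32>> {
    x.iter()
        .map(|row| {
            w.iter()
                .zip(bias)
                .map(|(w_out, b)| row.iter().zip(w_out).map(|(xi, wi)| xi * wi).sum::<f32>() + b)
                .collect::<Vec<f32>>()
        })
        .collect()
}
```

Emitting a single Gemm node keeps the exported graph small and lets runtimes use their fused matrix-multiply-plus-bias kernels. The alternative would be separate Transpose, MatMul, and Add nodes.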
```rust
use rumus::onnx::export_onnx;

export_onnx(
    &model,
    &[("input", vec![1, 3, 224, 224], DType::F32)], // (name, shape, dtype) per graph input
    "model.onnx",
    |inputs| model.forward(&inputs[0]),             // forward pass replayed under the tracer
)?;
```
FlashAttention
RUMUS implements memory-efficient attention based on Tri Dao's FlashAttention algorithm, which is built on an online softmax. The WGSL kernel uses 16 KB of workgroup shared memory to compute attention in a single pass. It never materializes the full N×N attention matrix, which eliminates the O(N²) VRAM bottleneck. FlashAttention is used automatically when the sequence length exceeds the tiling threshold, so no API changes are required.
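The key idea behind a single-pass kernel is the online softmax. It keeps a running maximum `m`, a running denominator `l`, and a running weighted sum of values, and rescales all three whenever a larger score appears. The scalar Rust sketch below applies this to one query row and processes keys one at a time. A tiled GPU kernel processes blocks of keys in shared memory, but the rescaling rule is the same. The sketch is checked against the two-pass formula, and both function names are hypothetical.

```rust
/// One-pass attention for a single query using the online-softmax recurrence.
/// Extra memory is O(d) regardless of the number of keys: no score row is stored.
fn online_attention(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let mut m = f32::NEG_INFINITY; // running max of the scores seen so far
    let mut l = 0.0f32; // running sum of exp(score - m)
    let mut acc = vec![0.0f32; values[0].len()]; // running sum of exp(score - m) * v
    for (k, v) in keys.iter().zip(values) {
        let s = q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale;
        let m_new = m.max(s);
        // Rescale old statistics to the new max; on the first key exp(-inf) = 0.
        let correction = (m - m_new).exp();
        let p = (s - m_new).exp();
        l = l * correction + p;
        for (a, vi) in acc.iter_mut().zip(v) {
            *a = *a * correction + p * vi;
        }
        m = m_new;
    }
    acc.iter().map(|a| a / l).collect()
}

/// Two-pass reference: materialize all scores, softmax them, then weight the values.
fn two_pass_attention(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale)
        .collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let mut out = vec![0.0f32; values[0].len()];
    for (e, v) in exps.iter().zip(values) {
        for (o, vi) in out.iter_mut().zip(v) {
            *o += e / sum * vi;
        }
    }
    out
}
```

The one-pass version divides by `l` only once, at the end. That is why the full score row never needs to exist in memory.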
DataParallel
nn::DataParallel::new(module, device_ids) wraps any Module for multi-GPU data parallelism. It scatters the input batch across GPUs, runs concurrent forward passes, and gathers the output. After each optimizer step, call broadcast_weights() to re-sync parameters across devices. Requires --features multi_gpu.
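Scattering a batch means splitting it along dimension 0 into one contiguous chunk per device. Gathering then concatenates the per-device outputs back in device order. The plain-Rust sketch below handles that bookkeeping and runs the per-chunk forward sequentially. It assumes chunk sizes differ by at most one, which is the common convention, but RUMUS's exact split is not documented here. `chunk_ranges` and `scatter_map_gather` are hypothetical helpers.

```rust
use std::ops::Range;

/// Split `batch` rows across `num_devices` into contiguous, near-equal ranges.
/// The first `batch % num_devices` devices get one extra row (assumed convention).
fn chunk_ranges(batch: usize, num_devices: usize) -> Vec<Range<usize>> {
    let base = batch / num_devices;
    let extra = batch % num_devices;
    let mut start = 0;
    (0..num_devices)
        .map(|d| {
            let len = base + usize::from(d < extra);
            let r = start..start + len;
            start += len;
            r
        })
        .collect()
}

/// Scatter rows, run `forward` on each chunk (sequentially here; DataParallel
/// would run one chunk per GPU concurrently), then gather results in order.
fn scatter_map_gather<T, U>(rows: &[T], num_devices: usize, forward: impl Fn(&[T]) -> Vec<U>) -> Vec<U> {
    chunk_ranges(rows.len(), num_devices)
        .into_iter()
        .flat_map(|r| forward(&rows[r]))
        .collect()
}
```

Because chunks are contiguous and gathered in device order, the output rows line up with the input rows exactly as if the model had run on one device.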
```rust
use rumus::nn::DataParallel;

let model = MyModel::new();
let device_ids = vec![0, 1, 2, 3];
let dp_model = DataParallel::new(model, &device_ids);

// The forward pass scatters the input across 4 GPUs,
// runs concurrent forward passes, and gathers the output
let output = dp_model.forward(&input);

// After each optimizer step, re-sync the weights
dp_model.broadcast_weights();
```
Fully Sharded Data Parallel (FSDP)
nn::FSDP::new(model, device_ids, rank) shards model parameters across GPUs along dimension 0, so each of N devices stores only about 1/N of the parameters. During the forward pass, full parameters are all-gathered on demand for compute. During the backward pass, gradients are reduce-scattered so that each rank keeps only the gradient slice for its own shard. This enables training models that do not fit in a single GPU's memory. Requires --features multi_gpu.
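Sharding along dimension 0 means each rank owns a contiguous block of rows of every parameter. All-gather concatenates the blocks back into the full tensor before compute. Reduce-scatter sums every rank's full-size gradient and hands each rank only the rows it owns. The sketch below simulates both collectives in a single process on row-major matrices, one `Vec<Vec<f32>>` per rank. It assumes the row count divides evenly by the world size; FSDP implementations typically pad the last shard otherwise. All names are hypothetical.

```rust
type Matrix = Vec<Vec<f32>>; // row-major: one inner Vec per row (dimension 0)

/// Rank `rank`'s shard: its contiguous block of rows.
fn shard_dim0(full: &Matrix, rank: usize, world: usize) -> Matrix {
    assert_eq!(full.len() % world, 0, "this sketch requires an even split");
    let per = full.len() / world;
    full[rank * per..(rank + 1) * per].to_vec()
}

/// All-gather: concatenate every rank's shard in rank order.
fn all_gather(shards: &[Matrix]) -> Matrix {
    shards.iter().flat_map(|s| s.iter().cloned()).collect()
}

/// Reduce-scatter: sum the full-size gradients from all ranks,
/// then give each rank only the summed rows of its own shard.
fn reduce_scatter(grads: &[Matrix]) -> Vec<Matrix> {
    let world = grads.len();
    let mut summed = grads[0].clone();
    for g in &grads[1..] {
        for (row_s, row_g) in summed.iter_mut().zip(g) {
            for (s, x) in row_s.iter_mut().zip(row_g) {
                *s += x;
            }
        }
    }
    (0..world).map(|r| shard_dim0(&summed, r, world)).collect()
}
```

Between these collectives, each rank holds only `rows / world` rows of each parameter and its gradient. That reduction is where FSDP's memory savings come from.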
```rust
use rumus::nn::FSDP;

// vocab 50257, d_model 1024, 16 heads, 24 layers, max_seq_len 2048
let model = TinyGPT::new(50257, 1024, 16, 24, 2048);
let device_ids = vec![0, 1, 2, 3];
let rank = 0; // this process's GPU index
let fsdp_model = FSDP::new(model, &device_ids, rank);

// Parameters are sharded across GPUs along dim 0.
// All-gather materializes full params on demand for compute;
// reduce-scatter distributes gradients after backward.
let output = fsdp_model.forward(&input);
let mut grads = output.backward();
fsdp_model.reduce_scatter_grads(&mut grads);
```
3D Parallelism (rumus-distributed)
rumus-distributed is a separate crate that combines tensor parallelism, pipeline parallelism, and data parallelism for training models at scale across many GPUs.
- `ColumnParallelLinear`: Megatron-style column tensor parallelism, `Y_t = X @ W_t`, with an async AllReduce for `grad_x`
- `RowParallelLinear`: row tensor parallelism, `Y_t = X_t @ W_t`, with an AllReduce to sum the partial outputs
- `PipelineExecutor::new(stages, num_micro_batches)`: 1F1B pipeline parallelism with per-micro-batch isolated autograd tapes
- `PipelineStage { device_index, forward_fn }`: one pipeline stage bound to a specific device
- `CollectiveBarrier::new(world_size)`: cross-rank barrier for gradient averaging
- `CommThread::spawn(device, queue)`: dedicated async communication thread
- `async_allreduce()`: non-blocking AllReduce returning an `AllReduceHandle`
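Both tensor-parallel layouts split one matrix multiply `Y = X·W` across ranks. Column parallelism splits `W` by output columns. Each rank then computes a full-height slice of `Y`, and the slices are concatenated. Row parallelism splits `W` by input rows and `X` by the matching columns. Each rank then computes a full-size partial `Y`, and an AllReduce sums the partials. The sketch below checks both layouts against the unsplit product in a single process. The helpers are hypothetical, and a plain elementwise sum stands in for the AllReduce.

```rust
type Mat = Vec<Vec<f32>>;

/// Dense matrix product of row-major matrices.
fn matmul(a: &Mat, b: &Mat) -> Mat {
    a.iter()
        .map(|row| {
            (0..b[0].len())
                .map(|j| row.iter().zip(b).map(|(x, b_row)| x * b_row[j]).sum::<f32>())
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// Columns [start, end) of every row.
fn cols(m: &Mat, start: usize, end: usize) -> Mat {
    m.iter().map(|r| r[start..end].to_vec()).collect()
}

/// Column parallel: rank t holds W[:, cols_t]; per-rank outputs are concatenated.
fn column_parallel(x: &Mat, w: &Mat, ranks: usize) -> Mat {
    let per = w[0].len() / ranks;
    let parts: Vec<Mat> = (0..ranks).map(|t| matmul(x, &cols(w, t * per, (t + 1) * per))).collect();
    (0..x.len())
        .map(|i| parts.iter().flat_map(|p| p[i].iter().copied()).collect::<Vec<f32>>())
        .collect()
}

/// Row parallel: rank t holds W[rows_t, :] and X[:, rows_t]; partial outputs are summed.
fn row_parallel(x: &Mat, w: &Mat, ranks: usize) -> Mat {
    let per = w.len() / ranks;
    let mut y: Mat = vec![vec![0.0; w[0].len()]; x.len()];
    for t in 0..ranks {
        let w_rows: Mat = w[t * per..(t + 1) * per].to_vec();
        let partial = matmul(&cols(x, t * per, (t + 1) * per), &w_rows);
        // This elementwise sum is what the AllReduce performs across ranks.
        for (yr, pr) in y.iter_mut().zip(&partial) {
            for (a, b) in yr.iter_mut().zip(pr) {
                *a += b;
            }
        }
    }
    y
}
```

In Megatron-style MLPs the two are usually paired, a column-parallel layer followed by a row-parallel one. The intermediate activation then stays sharded, and only one AllReduce is needed per forward pass.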
```rust
use rumus::nn;
use rumus_distributed::{PipelineExecutor, PipelineStage, ColumnParallelLinear};

// Two pipeline stages, each bound to its own GPU
let stages = vec![
    PipelineStage { device_index: 0, forward_fn: Box::new(|x| stage0.forward(x)) },
    PipelineStage { device_index: 1, forward_fn: Box::new(|x| stage1.forward(x)) },
];

// Split each batch into 4 micro-batches, scheduled 1F1B
let executor = PipelineExecutor::new(stages, 4);
let grads = executor.run(&input, &|out| nn::cross_entropy_loss(out, &targets));
```
Custom Ops
The CustomOp trait lets you define user-authored GPU kernels with full autograd integration. Implement five methods: op_name(), wgsl_source(), entry_point(), output_shape(), and backward_handler(). Then call custom_forward(op, inputs) which compiles, caches, and dispatches the kernel. Requires --features gpu.
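When `backward_handler` returns `None`, the op falls back to a numerical gradient. If you later supply an analytic backward, a CPU reference is useful for checking it. For SiLU, `silu(x) = x·σ(x)` and `silu'(x) = σ(x)·(1 + x·(1 - σ(x)))`. The sketch below compares that derivative with a central finite difference, the kind of estimate a numerical fallback typically computes. How RUMUS's fallback works internally is not documented here, and the function names are hypothetical.

```rust
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Forward pass, matching the WGSL body `x / (1.0 + exp(-x))`.
fn silu(x: f64) -> f64 {
    x * sigmoid(x)
}

/// Analytic derivative: d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x))).
fn silu_grad(x: f64) -> f64 {
    let s = sigmoid(x);
    s * (1.0 + x * (1.0 - s))
}

/// Central finite difference: (f(x + h) - f(x - h)) / (2h), error O(h^2).
fn numerical_grad(f: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
    (f(x + h) - f(x - h)) / (2.0 * h)
}
```

The check uses f64 so that rounding error does not swamp the finite-difference estimate. A GPU kernel in f32 would need a looser tolerance.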
```rust
use rumus::nn::{CustomOp, custom_forward};
use rumus::tensor::Tensor;

// Element-wise SiLU: x * sigmoid(x)
struct MySiLU;

impl CustomOp for MySiLU {
    fn op_name(&self) -> &str { "my_silu" }

    fn wgsl_source(&self) -> &str {
        "@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Skip the surplus invocations in the last workgroup
    if (i >= arrayLength(&input)) { return; }
    let x = input[i];
    output[i] = x / (1.0 + exp(-x));
}"
    }

    fn entry_point(&self) -> &str { "main" }

    // Element-wise op: the output shape equals the input shape
    fn output_shape(&self, input_shapes: &[&[usize]]) -> Vec<usize> {
        input_shapes[0].to_vec()
    }

    fn backward_handler(&self) -> Option<fn(&Tensor, &Tensor) -> Tensor> {
        None // no analytic gradient: use the numerical fallback
    }
}

let result = custom_forward(&MySiLU, &[&input]);
```
Graph Neural Networks (rumus-graph)
rumus-graph is a separate crate that depends on rumus with the gpu feature and provides first-class support for Graph Neural Networks. It is built on the Custom Ops Extension API, demonstrating that the plugin engine supports entirely new AI domains beyond standard deep learning.
Graph::new(src, dst, weights, num_nodes) builds a forward CSR matrix (A) and a backward CSR matrix (Aᵀ) from edge lists, and sorts nodes by degree for load balancing.
graph.spmm(&features) performs a differentiable Sparse Matrix-Matrix Multiply for message passing: output[i] = Σ_{j ∈ N(i)} A[i,j] · features[j]. The backward pass uses Aᵀ for reverse message passing via the CustomBackward trait from M19.
Internally, SparseTensor stores graphs in CSR format with GPU-resident row_ptr, col_indices, and optional values buffers. The WGSL kernel (spmm.wgsl) assigns 1 thread per node with an edge-outer/dim-inner loop for cache locality.
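The forward and backward passes are the same sparse product with different matrices: A for messages and Aᵀ for gradients. The CPU sketch below builds both CSR matrices from the edge list and runs a one-row-per-thread loop, with edges in the outer loop and feature dimensions in the inner loop. It assumes that messages flow along src → dst, so row = destination node and column = source node, and that an unweighted graph uses unit weights. It omits degree sorting. `Csr` and its methods are hypothetical, not the `SparseTensor` API.

```rust
/// Compressed Sparse Row matrix: row i's entries live in
/// col_indices[row_ptr[i]..row_ptr[i + 1]], with matching `values`.
struct Csr {
    row_ptr: Vec<usize>,
    col_indices: Vec<usize>,
    values: Vec<f32>,
}

impl Csr {
    /// Build an n x n matrix from (row, col, value) triples via a counting sort on rows.
    fn from_triples(rows: &[usize], cols: &[usize], vals: &[f32], n: usize) -> Csr {
        let mut row_ptr = vec![0; n + 1];
        for &r in rows {
            row_ptr[r + 1] += 1;
        }
        for i in 0..n {
            row_ptr[i + 1] += row_ptr[i];
        }
        let mut next = row_ptr.clone();
        let mut col_indices = vec![0; rows.len()];
        let mut values = vec![0.0; rows.len()];
        for ((&r, &c), &v) in rows.iter().zip(cols).zip(vals) {
            col_indices[next[r]] = c;
            values[next[r]] = v;
            next[r] += 1;
        }
        Csr { row_ptr, col_indices, values }
    }

    /// out[i] = sum over stored (i, j) of A[i, j] * features[j]; one "thread" per row.
    fn spmm(&self, features: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let dim = features[0].len();
        (0..self.row_ptr.len() - 1)
            .map(|i| {
                let mut acc = vec![0.0f32; dim];
                // Edge-outer, feature-dim-inner loop.
                for e in self.row_ptr[i]..self.row_ptr[i + 1] {
                    let (j, a) = (self.col_indices[e], self.values[e]);
                    for (o, f) in acc.iter_mut().zip(&features[j]) {
                        *o += a * f;
                    }
                }
                acc
            })
            .collect()
    }
}
```

Building Aᵀ just swaps the roles of `src` and `dst`. This is why storing both matrices up front makes the backward pass another plain SpMM.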
```rust
use rumus_graph::Graph;

// Build a 3-node graph from an edge list (edge k goes src[k] -> dst[k])
let graph = Graph::new(
    &[0, 0, 1, 2], // src nodes
    &[1, 2, 2, 0], // dst nodes
    None,          // unweighted
    3,             // num_nodes
);

// Message passing (differentiable)
let output = graph.spmm(&node_features); // [num_nodes, hidden_dim]
```
Spatial Vision Engine (rumus-vision)
rumus-vision is a separate crate that depends on rumus with the gpu feature and provides direct-convolution CNN ops that bypass the im2col memory overhead entirely. Like rumus-graph, it is built on the Custom Ops Extension API, demonstrating that the plugin engine extends to spatial vision workloads.
Direct convolution — instead of expanding input patches into a matrix (im2col), rumus-vision uses a sliding-window kernel that computes convolution directly, avoiding the large intermediate memory allocation that im2col requires.
conv2d(input, weight, bias, stride, padding, dilation) — Input shape [B, C_in, H, W], weight shape [C_out, C_in, K_h, K_w], optional bias [C_out]. Supports stride, padding, and dilation as (usize, usize) tuples. Fully differentiable.
max_pool2d(input, kernel_size, stride, padding) — Input shape [B, C, H, W], returns [B, C, H_out, W_out]. Window limited to 2048 elements. Fully differentiable with f16 argmax tracking.
Internally, rumus-vision ships 5 WGSL shaders: conv2d_direct, conv2d_backward_data, conv2d_backward_weight, maxpool2d_direct, and maxpool2d_backward.
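Direct convolution reads each output element's receptive field straight from the input instead of first copying every patch into an im2col matrix. The naive CPU sketch below handles one image (batch size 1) and follows the conv2d parameters described above, including stride, zero padding, and dilation. The flat row-major `[C, H, W]` layout and the `conv2d_direct` and `out_len` names are assumptions for illustration, not rumus-vision's code.

```rust
/// Output length along one axis for convolution or pooling.
fn out_len(input: usize, k: usize, stride: usize, pad: usize, dil: usize) -> usize {
    (input + 2 * pad - dil * (k - 1) - 1) / stride + 1
}

/// Direct 2D convolution without an im2col buffer.
/// input: [c_in, h, w], weight: [c_out, c_in, kh, kw], bias: [c_out]; all flat, row-major.
/// Returns (output [c_out, oh, ow], oh, ow).
#[allow(clippy::too_many_arguments)]
fn conv2d_direct(
    input: &[f32], (c_in, h, w): (usize, usize, usize),
    weight: &[f32], (c_out, kh, kw): (usize, usize, usize),
    bias: Option<&[f32]>,
    (sh, sw): (usize, usize), (ph, pw): (usize, usize), (dh, dw): (usize, usize),
) -> (Vec<f32>, usize, usize) {
    let (oh, ow) = (out_len(h, kh, sh, ph, dh), out_len(w, kw, sw, pw, dw));
    let mut out = vec![0.0f32; c_out * oh * ow];
    for co in 0..c_out {
        for oy in 0..oh {
            for ox in 0..ow {
                let mut acc = bias.map_or(0.0, |b| b[co]);
                for ci in 0..c_in {
                    for ky in 0..kh {
                        for kx in 0..kw {
                            // Input coordinates shifted by padding; out-of-range reads count as zero.
                            let iy = (oy * sh + ky * dh) as isize - ph as isize;
                            let ix = (ox * sw + kx * dw) as isize - pw as isize;
                            if iy < 0 || ix < 0 || iy >= h as isize || ix >= w as isize {
                                continue;
                            }
                            let x = input[(ci * h + iy as usize) * w + ix as usize];
                            acc += x * weight[((co * c_in + ci) * kh + ky) * kw + kx];
                        }
                    }
                }
                out[(co * oh + oy) * ow + ox] = acc;
            }
        }
    }
    (out, oh, ow)
}
```

The only extra memory is the output itself. An im2col approach would first allocate a `[c_in·kh·kw, oh·ow]` patch matrix.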
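```rust
use rumus_vision::{conv2d, max_pool2d};

// stride (1, 1), padding (1, 1), dilation (1, 1): with a 3x3 kernel, H and W are unchanged
let out = conv2d(&input, &weight, Some(&bias), (1, 1), (1, 1), (1, 1));
// 2x2 window, stride (2, 2), no padding: halves H and W (rounding down)
let pooled = max_pool2d(&out, (2, 2), (2, 2), (0, 0));
```
INT4 Quantized Inference (rumus-vision)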
rumus-vision includes INT4 weight quantization for efficient inference, compatible with AWQ/GPTQ-style group-wise scaling.
- `QuantizedTensor::from_f32(data, k, n, group_size)`: quantizes f32 weights to INT4 with group-wise AWQ/GPTQ scaling
- `QLinear::from_linear(linear, group_size)`: quantizes a pre-trained `Linear` layer to INT4
- `qlinear.forward(&x)`: fused dequant-matmul, `y = x @ dequant(W_int4)^T + bias`
- Activation gradients (`grad_x`) are tracked; weight gradients are frozen (inference-only weights)
- WGSL kernels: `qmatmul_int4.wgsl`, `qmatmul_int4_transpose.wgsl`
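Group-wise INT4 quantization stores one floating-point scale (here also a zero point) for every run of `group_size` consecutive weights. The weights themselves are packed as 4-bit codes, two per byte. The plain-Rust sketch below applies asymmetric min/max quantization to a single row. The exact scheme inside `QuantizedTensor::from_f32` is not documented here, including its zero-point handling, nibble order, and grouping axis, so treat those details as assumptions. `Group`, `quantize_int4`, and `dequantize_int4` are hypothetical names.

```rust
/// One quantized group: 4-bit codes packed two per byte, plus its scale and zero point.
struct Group {
    packed: Vec<u8>,
    scale: f32,
    zero: u8,
    len: usize, // number of real weights (the last group may be short)
}

/// Quantize `row` in groups of `group_size` consecutive values to codes 0..=15.
fn quantize_int4(row: &[f32], group_size: usize) -> Vec<Group> {
    row.chunks(group_size)
        .map(|g| {
            // Include 0.0 in the range so that zero is exactly representable.
            let min = g.iter().copied().fold(0.0f32, f32::min);
            let max = g.iter().copied().fold(0.0f32, f32::max);
            let scale = if max > min { (max - min) / 15.0 } else { 1.0 };
            let zero = (-min / scale).round().clamp(0.0, 15.0) as u8;
            let codes: Vec<u8> = g
                .iter()
                .map(|x| ((x / scale).round() + zero as f32).clamp(0.0, 15.0) as u8)
                .collect();
            // Assumed nibble order: low nibble = even-indexed code, high nibble = odd.
            let packed = codes
                .chunks(2)
                .map(|p| p[0] | (p.get(1).copied().unwrap_or(0) << 4))
                .collect();
            Group { packed, scale, zero, len: g.len() }
        })
        .collect()
}

/// Dequantize: x ≈ (code - zero) * scale.
fn dequantize_int4(groups: &[Group]) -> Vec<f32> {
    groups
        .iter()
        .flat_map(|g| {
            g.packed
                .iter()
                .flat_map(|b| [b & 0x0F, b >> 4])
                .take(g.len)
                .map(move |c| (c as f32 - g.zero as f32) * g.scale)
        })
        .collect()
}
```

Each restored weight lies within half a quantization step (`scale / 2`) of the original. Smaller groups give tighter scales at the cost of storing more of them.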
```rust
use rumus_vision::quant::QLinear;

// Quantize a trained Linear layer to INT4 with groups of 128 weights
let qlinear = QLinear::from_linear(&pretrained_layer, 128);
let output = qlinear.forward(&input); // INT4 fused dequant-matmul inference
```
Inference Server (rumus-serve)
rumus-serve is a standalone binary crate that provides a production-ready inference server. Built on Tokio and Axum, it exposes a /v1/generate HTTP endpoint with a continuous batching scheduler that groups incoming requests for maximum GPU utilization. A KV-cache stores the keys and values of already-processed tokens, so each autoregressive decode step computes only the newly generated token instead of re-running the whole prefix. Inference runs on a dedicated GPU worker thread, keeping the async HTTP layer non-blocking.
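Continuous batching differs from static batching in when requests join and leave the batch. After every decode step, finished sequences are retired and queued requests take their slots immediately, instead of waiting for the whole batch to finish. The single-threaded simulation below models that scheduling loop and counts one decode step per iteration. `Request` and `simulate` are hypothetical names, and rumus-serve's actual admission policy is not documented here.

```rust
use std::collections::VecDeque;

/// A queued generation request: its id and the number of tokens still to generate (>= 1).
#[derive(Clone, Copy)]
struct Request {
    id: usize,
    remaining: usize,
}

/// Run the scheduler until every request finishes (requires max_batch >= 1).
/// Returns (decode steps taken, request ids in completion order).
fn simulate(requests: &[Request], max_batch: usize) -> (usize, Vec<usize>) {
    let mut queue: VecDeque<Request> = requests.iter().copied().collect();
    let mut active: Vec<Request> = Vec::new();
    let mut finished = Vec::new();
    let mut steps = 0;
    while !queue.is_empty() || !active.is_empty() {
        // Admit waiting requests into any free batch slots.
        while active.len() < max_batch {
            match queue.pop_front() {
                Some(r) => active.push(r),
                None => break,
            }
        }
        // One batched decode step: every active sequence emits one token.
        steps += 1;
        for r in active.iter_mut() {
            r.remaining -= 1;
        }
        // Retire finished sequences so their slots are free for the next step.
        active.retain(|r| {
            if r.remaining == 0 {
                finished.push(r.id);
                false
            } else {
                true
            }
        });
    }
    (steps, finished)
}
```

Consider three requests needing 1, 3, and 2 tokens with `max_batch = 2`. They finish in 3 steps. Static batching would take 5 steps: 3 for the first batch, which waits on its longest request, and 2 for the second.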
CLI arguments include --port, --max-batch, --vocab-size, --model, and more.
```bash
# Install and run the inference server
cargo install rumus-serve
rumus-serve \
  --model model.safetensors \
  --vocab-size 50257 \
  --port 8080 \
  --max-batch 32

# POST /v1/generate
curl -X POST http://localhost:8080/v1/generate \
  -H "Content-Type: application/json" \
  -d '{"prompt": [1, 2, 3], "max_tokens": 128}'
```