Optimizers

First-class optimizer implementations with borrow-safe gradient consumption, GPU-fused kernels, and a high-level Trainer API.

Optimizer Trait

All optimizers implement a single trait. The step method takes a mutable reference to the gradient store and drains only the gradients belonging to its registered parameters, so overlapping mutable borrows are ruled out at compile time.

```rust
pub trait Optimizer {
    fn step(
        &mut self,
        grads: &mut GradientStore,
    ) -> Result<(), AutogradError>;
}
```
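The drain-based contract is easy to see with toy stand-ins. A minimal sketch, where `ParamId`, `GradientStore`, and `ToySgd` below are hypothetical simplifications for illustration, not the RUMUS types:

```rust
use std::collections::HashMap;

// Toy stand-ins for illustration only; the real RUMUS types differ.
type ParamId = usize;
struct GradientStore(HashMap<ParamId, f32>);

trait Optimizer {
    fn step(&mut self, grads: &mut GradientStore) -> Result<(), String>;
}

// A minimal SGD that owns its weights and knows which ParamIds are its own.
struct ToySgd {
    ids: Vec<ParamId>,
    weights: Vec<f32>,
    lr: f32,
}

impl Optimizer for ToySgd {
    fn step(&mut self, grads: &mut GradientStore) -> Result<(), String> {
        for (i, id) in self.ids.iter().enumerate() {
            // Drain: remove() takes this parameter's gradient out of the
            // store, leaving entries owned by other optimizers untouched.
            if let Some(g) = grads.0.remove(id) {
                self.weights[i] -= self.lr * g;
            }
        }
        Ok(())
    }
}

fn main() {
    let mut grads = GradientStore(HashMap::from([(0, 2.0_f32), (1, 4.0)]));
    let mut opt = ToySgd { ids: vec![0], weights: vec![1.0], lr: 0.1 };
    opt.step(&mut grads).unwrap();
    // Param 0 was drained and updated; param 1 is still in the store.
    println!("{} {}", opt.weights[0], grads.0.contains_key(&1));
}
```

Because each optimizer removes only its own entries, two optimizers over disjoint parameter sets can both take `&mut grads` in sequence without conflict.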

SGD

Stochastic Gradient Descent with optional momentum. The simplest optimizer, ideal for convex problems and as a baseline for experiments.

```rust
use rumus::optim::SGD;

let params = model.parameters();
let mut optimizer = SGD::new(params, 0.01);  // lr = 0.01

// With optional momentum
let mut optimizer = SGD::new(params, 0.01)
    .momentum(0.9);
```
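The momentum update itself is two lines of arithmetic. A sketch on plain `f32` slices, illustrating the math rather than the RUMUS API:

```rust
// SGD with momentum: v = m * v + g; w -= lr * v.
// The velocity buffer v accumulates gradient direction across steps.
fn sgd_momentum_step(w: &mut [f32], v: &mut [f32], g: &[f32], lr: f32, m: f32) {
    for i in 0..w.len() {
        v[i] = m * v[i] + g[i];
        w[i] -= lr * v[i];
    }
}

fn main() {
    let mut w = [1.0_f32];
    let mut v = [0.0_f32];
    // Two steps with a constant gradient of 1.0: the second step is
    // larger because the velocity has accumulated.
    sgd_momentum_step(&mut w, &mut v, &[1.0], 0.1, 0.9);
    sgd_momentum_step(&mut w, &mut v, &[1.0], 0.1, 0.9);
    println!("{:.3}", w[0]);
}
```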

Adam

Adaptive moment estimation. Maintains per-parameter first and second moment buffers for adaptive learning rates. The default choice for most deep learning tasks.

```rust
use rumus::optim::Adam;

let params = model.parameters();
let mut optimizer = Adam::new(params, 0.001);  // lr = 0.001

// Adam maintains first and second moment buffers
// internally for each parameter, providing adaptive
// per-parameter learning rates.
```
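The update rule behind those buffers, sketched on a single scalar parameter. This is the standard Adam math for illustration, not the RUMUS implementation:

```rust
// One Adam step on a scalar: bias-corrected first and second moments
// scale the update so each parameter gets its own effective step size.
fn adam_step(w: &mut f32, m: &mut f32, v: &mut f32, g: f32, lr: f32, t: i32) {
    let (b1, b2, eps) = (0.9_f32, 0.999_f32, 1e-8_f32);
    *m = b1 * *m + (1.0 - b1) * g;       // first moment (mean of gradients)
    *v = b2 * *v + (1.0 - b2) * g * g;   // second moment (uncentered variance)
    let m_hat = *m / (1.0 - b1.powi(t)); // bias correction for zero init
    let v_hat = *v / (1.0 - b2.powi(t));
    *w -= lr * m_hat / (v_hat.sqrt() + eps);
}

fn main() {
    let (mut w, mut m, mut v) = (1.0_f32, 0.0_f32, 0.0_f32);
    adam_step(&mut w, &mut m, &mut v, 0.5, 0.001, 1);
    // On the first step the update magnitude is roughly lr,
    // regardless of the raw gradient scale.
    println!("{}", w);
}
```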

AdamW

AdamW applies decoupled weight decay regularization, fixing the weight decay behavior of the original Adam optimizer. In RUMUS, AdamW runs GPU-fused kernels with zero host-device round-trips for the update step.

```rust
use rumus::optim::AdamW;

let params = model.parameters();
let mut optimizer = AdamW::new(params, 0.001);  // lr = 0.001

// AdamW uses decoupled weight decay regularization
// and runs GPU-fused kernels: zero host-device
// round-trips for the optimizer step.
```
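"Decoupled" means the decay term bypasses the adaptive moment buffers entirely and shrinks the weight directly. A scalar sketch of just that term (illustrative only, not the RUMUS kernel):

```rust
// AdamW: weight decay is applied directly to the weight, alongside the
// gradient-based Adam update, rather than being added to the gradient
// as L2 regularization (which Adam's moments would then rescale).
fn decoupled_decay(w: &mut f32, lr: f32, wd: f32) {
    *w -= lr * wd * *w;
}

fn main() {
    let mut w = 2.0_f32;
    decoupled_decay(&mut w, 0.1, 0.5);
    println!("{}", w); // 2.0 - 0.1 * 0.5 * 2.0 = 1.9
}
```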

Borrow Safety

RUMUS optimizers are designed around Rust's ownership model. The step(&mut grads) method drains only the ParamIds registered to that optimizer from the gradient store. This means no overlapping mutable borrows can occur, and the compiler verifies safety at build time.

```rust
// Borrow safety: step(&mut grads) drains only
// the registered ParamIds from the gradient store.
// No overlapping mutable borrows occur.

let mut grads = loss.backward();

// Safe: the optimizer only touches its own parameters
optimizer.step(&mut grads)?;

// grads still exists; other ParamIds are untouched
```

GPU-Fused Training

When parameters reside on the GPU, optimizer kernels run entirely as WGSL compute shaders. Weight updates, moment buffer updates, and weight decay are fused into a single GPU dispatch — no data is copied back to the host.

```rust
// GPU-fused training: optimizer kernels execute
// entirely on the GPU, with no host-device round-trips.

let params = model.parameters();
let mut optimizer = AdamW::new(params, 0.001);

model.to_gpu();

// With the model on the GPU, optimizer.step() dispatches
// fused WGSL compute shaders that update the weights and
// moment buffers and apply weight decay in a single pass.
let mut grads = loss.backward();
optimizer.step(&mut grads)?;  // runs on the GPU
```

Full Training Loop

The Trainer struct provides a high-level API that wraps forward pass, backward pass, and optimizer step into a single call. It also tracks per-epoch average loss automatically.

```rust
use rumus::optim::{Trainer, Adam};
use rumus::nn;

let optimizer = Adam::new(model.parameters(), 0.001);
let mut trainer = Trainer::new(optimizer);

for epoch in 0..num_epochs {
    for (inputs, targets) in &dataloader {
        trainer.train_step(|| {
            let logits = model.forward(&inputs);
            nn::cross_entropy_loss(&logits, &targets)
        });
    }

    let avg_loss = trainer.epoch_avg_loss();
    println!("Epoch {}: loss = {:.4}", epoch, avg_loss);
}
```