Autograd

Automatic differentiation in RUMUS. The autograd engine records operations on an append-only tape and replays them in reverse to compute gradients — no manual calculus required.

How It Works

RUMUS implements reverse-mode automatic differentiation using a Wengert list (also called a tape). Every differentiable operation appends a BackwardOp entry to the tape during the forward pass. The backward pass then walks the tape in reverse topological order, accumulating gradients into a GradientStore.
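
The record-then-replay idea above can be illustrated with a scalar-only sketch. The `Tape` and `Op` types and the index-based ids are hypothetical and far simpler than RUMUS's BackwardOp and GradientStore; the sketch only shows why walking an append-only list backwards yields correct gradients.

```rust
// Hypothetical miniature tape over scalars; not the RUMUS API.
#[derive(Debug, Clone, Copy)]
enum Op {
    // out = a * b; both operand values are saved for the local derivative.
    Mul { a: usize, b: usize, out: usize, a_val: f32, b_val: f32 },
    // out = a + b; no forward values are needed.
    Add { a: usize, b: usize, out: usize },
}

#[derive(Default)]
struct Tape {
    ops: Vec<Op>,     // append-only record of executed operations
    values: Vec<f32>, // forward value of every scalar, indexed by id
}

impl Tape {
    /// Register a value and return its id.
    fn leaf(&mut self, v: f32) -> usize {
        self.values.push(v);
        self.values.len() - 1
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let (a_val, b_val) = (self.values[a], self.values[b]);
        let out = self.leaf(a_val * b_val);
        self.ops.push(Op::Mul { a, b, out, a_val, b_val });
        out
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let sum = self.values[a] + self.values[b];
        let out = self.leaf(sum);
        self.ops.push(Op::Add { a, b, out });
        out
    }

    /// Walk the tape backwards, accumulating d(loss)/d(value) for every id.
    /// Ops are appended in execution order, so reverse order is a valid
    /// reverse topological order.
    fn backward(&self, loss: usize) -> Vec<f32> {
        let mut grads = vec![0.0; self.values.len()];
        grads[loss] = 1.0;
        for op in self.ops.iter().rev() {
            match *op {
                Op::Mul { a, b, out, a_val, b_val } => {
                    let g = grads[out];
                    grads[a] += g * b_val;
                    grads[b] += g * a_val;
                }
                Op::Add { a, b, out } => {
                    let g = grads[out];
                    grads[a] += g;
                    grads[b] += g;
                }
            }
        }
        grads
    }
}
```

For `loss = x * w + x`, calling `backward` on the loss id returns a vector whose entry for `x` is `w + 1` and whose entry for `w` is `x`.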

rust
// RUMUS uses a Wengert list (append-only tape) to record
// every differentiable operation as it executes. During the
// backward pass, the tape is replayed in reverse topological
// order to accumulate gradients.
//
//   forward:   x  →  y = relu(x)  →  loss = mse(y, target)
//   tape:      [ Relu { input_id, output_id },
//                MseLoss { pred_id, target_id, output_id } ]
//   backward:  walk tape in reverse topo order, accumulate grads

Recording Operations

Tape recording is automatic. Any operation on a tracked tensor (one with AutogradState::Tracked) appends its corresponding BackwardOp variant to the global tape. You never interact with the tape directly.

rust
use rumus::tensor::Tensor;
use rumus::nn;

// Operations are recorded automatically when tensors
// have AutogradState::Tracked set.
let x = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
let w = Tensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]);

// Each of these records a BackwardOp on the tape:
let h = x.matmul(&w);       // records BackwardOp::Matmul
let a = nn::relu(&h);       // records BackwardOp::Relu
let loss = a.mul(&a);       // records BackwardOp::Mul

Computing Gradients

Call autograd::backward(&loss) to trigger the backward pass. This returns a GradientStore — essentially a HashMap<GradId, Tensor> — containing the gradient for every tracked tensor that contributed to the loss.

rust
use rumus::autograd;

// Compute all gradients in one call.
// Returns a GradientStore: HashMap<GradId, Tensor>
let mut grads = autograd::backward(&loss)?;

// Access individual gradients by parameter grad_id.
// remove() takes ownership, so each gradient is consumed
// exactly once (useful for the optimizer step).
let grad_w = grads.remove(w.grad_id().unwrap());
let grad_x = grads.remove(x.grad_id().unwrap());
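
To show why consuming gradients with `remove` suits an optimizer step, here is a minimal sketch of plain SGD over a map-backed store. `GradId` and `Grad` are hypothetical stand-ins (a `usize` and a `Vec<f32>`), and `sgd_step` is an illustrative helper, not a RUMUS function.

```rust
use std::collections::HashMap;

// Hypothetical stand-ins; the real GradId and Tensor types are not shown here.
type GradId = usize;
type Grad = Vec<f32>;

/// Plain SGD step that consumes each gradient exactly once via `remove`.
fn sgd_step(params: &mut [(GradId, Vec<f32>)], grads: &mut HashMap<GradId, Grad>, lr: f32) {
    for (id, values) in params.iter_mut() {
        // `remove` moves the gradient out of the store. A parameter that
        // has no entry in the store is left unchanged.
        if let Some(grad) = grads.remove(&*id) {
            for (v, g) in values.iter_mut().zip(grad.iter()) {
                *v -= lr * *g;
            }
        }
    }
}
```

After the step, every gradient that matched a parameter has been moved out of the map, so no stale gradient can be applied twice.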

Inference Mode

The no_grad() function returns an RAII guard that disables tape recording for its entire lifetime. This is essential during inference (no need to compute gradients) and during optimizer parameter updates (you do not want the update step itself to be differentiated).

RAII pattern: The guard is stack-allocated. When it goes out of scope, its destructor runs and tape recording resumes automatically. No manual cleanup is needed, because the compiler inserts the drop call at the end of the scope.

rust
use rumus::autograd::no_grad;

// no_grad() returns an RAII guard. While the guard is
// alive, ALL tensor operations bypass tape recording
// entirely: no BackwardOps are created, no grad_ids
// are assigned. This is critical for inference.
{
    let _guard = no_grad();

    // These operations are NOT recorded on the tape.
    let pred = model.forward(&input);
    println!("prediction: {:?}", pred.data());
}
// Guard dropped; tape recording resumes.

// This is also useful during the optimizer step:
// you don't want parameter updates to be tracked.
{
    let _guard = no_grad();
    // param -= lr * grad (not recorded)
}
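
One way such a guard could be built is sketched below. It assumes a per-thread boolean flag, which is a guess: this page does not say how RUMUS stores the recording state. `GRAD_ENABLED`, `NoGradGuard`, and `is_recording` are hypothetical names.

```rust
use std::cell::Cell;

thread_local! {
    // Hypothetical per-thread flag; RUMUS's actual storage is not described here.
    static GRAD_ENABLED: Cell<bool> = Cell::new(true);
}

/// Guard that remembers the previous state so nested guards restore correctly.
struct NoGradGuard {
    prev: bool,
}

/// Disable recording and return a guard that restores the old state on drop.
fn no_grad() -> NoGradGuard {
    let prev = GRAD_ENABLED.with(|flag| flag.replace(false));
    NoGradGuard { prev }
}

impl Drop for NoGradGuard {
    fn drop(&mut self) {
        let prev = self.prev;
        GRAD_ENABLED.with(|flag| flag.set(prev));
    }
}

/// An op implementation would check this before appending to the tape.
fn is_recording() -> bool {
    GRAD_ENABLED.with(|flag| flag.get())
}
```

Restoring the previous value, rather than unconditionally setting the flag to `true`, keeps recording disabled when an inner guard is dropped while an outer guard is still alive.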

BackwardOp Variants

Unlike frameworks that store closures (boxed trait objects) on each graph node, RUMUS uses a concrete enum BackwardOp with 30+ variants. Each variant carries exactly the data needed to compute its local Jacobian, no more and no less. This makes the tape inspectable and serializable, and avoids allocating a boxed closure for every operation.

Expanded from 16 to 30+: The original 16 variants (Add, Sub, Mul, Matmul, Relu, MseLoss, CrossEntropyLoss, AddBias, Im2Col, Stack, AddChannelBias, SliceBatch, MaxPool2d, Flatten, Reshape, Dropout) have been joined by Bmm, Softmax, LayerNorm, Embedding, Sigmoid, Tanh, Gelu, LeakyRelu, BroadcastAdd, BroadcastSub, BroadcastMul, BatchNorm2d, AdaptiveAvgPool2d, and CastBackward.

rust
// RUMUS uses a concrete enum instead of boxed closures
// for backward operations. This is a deliberate design
// choice with several advantages:
//
//   1. No boxed closure allocated per operation
//   2. Pattern-matchable: easy to inspect the tape
//   3. Serializable: you can save/load computation graphs
//   4. No lifetime issues with captured references
//
// The 30+ BackwardOp variants:
//
//   Arithmetic:     Add, Sub, Mul, Matmul, Bmm
//   Broadcasting:   BroadcastAdd, BroadcastSub, BroadcastMul
//   Activations:    Relu, Sigmoid, Tanh, Gelu, LeakyRelu
//   Losses:         MseLoss, CrossEntropyLoss
//   Normalization:  LayerNorm, BatchNorm2d
//   Conv/Pool:      Im2Col, MaxPool2d, AdaptiveAvgPool2d
//   Attention:      Softmax
//   Embedding:      Embedding
//   Casting:        CastBackward
//   Layout:         Stack, SliceBatch, Flatten, Reshape,
//                   AddBias, AddChannelBias, Dropout

enum BackwardOp {
    // --- Original 16 ---
    Add { lhs_id: GradId, rhs_id: GradId },
    Sub { lhs_id: GradId, rhs_id: GradId },
    Mul { lhs_id: GradId, rhs_id: GradId },
    Matmul { lhs_id: GradId, rhs_id: GradId,
             lhs_shape: Vec<usize>, rhs_shape: Vec<usize> },
    Relu { input_id: GradId, output_snapshot: VersionSnapshot },
    MseLoss { pred_id: GradId, target_id: GradId },
    CrossEntropyLoss { logits_id: GradId, targets: Vec<usize> },
    AddBias { input_id: GradId, bias_id: GradId },
    Im2Col { /* ... */ },
    Stack { /* ... */ },
    AddChannelBias { /* ... */ },
    SliceBatch { /* ... */ },
    MaxPool2d { /* ... */ },
    Flatten { /* ... */ },
    Reshape { /* ... */ },
    Dropout { /* ... */ },

    // --- New variants ---
    Bmm { lhs_id: GradId, rhs_id: GradId,
          lhs_shape: Vec<usize>, rhs_shape: Vec<usize> },
    Softmax { input_id: GradId, output_snapshot: VersionSnapshot },
    LayerNorm { input_id: GradId, weight_id: GradId, bias_id: GradId,
                normalized: VersionSnapshot, epsilon: f32 },
    BatchNorm2d { input_id: GradId, weight_id: GradId, bias_id: GradId,
                  /* running stats, epsilon, etc. */ },
    Embedding { indices_shape: Vec<usize>, weight_id: GradId,
                indices_snapshot: VersionSnapshot },
    Sigmoid { input_id: GradId, output_snapshot: VersionSnapshot },
    Tanh { input_id: GradId, output_snapshot: VersionSnapshot },
    Gelu { input_id: GradId, input_snapshot: VersionSnapshot },
    LeakyRelu { input_id: GradId, alpha: f32,
                input_snapshot: VersionSnapshot },
    BroadcastAdd { lhs_id: GradId, rhs_id: GradId,
                   lhs_shape: Vec<usize>, rhs_shape: Vec<usize> },
    BroadcastSub { lhs_id: GradId, rhs_id: GradId,
                   lhs_shape: Vec<usize>, rhs_shape: Vec<usize> },
    BroadcastMul { lhs_id: GradId, rhs_id: GradId,
                   lhs_shape: Vec<usize>, rhs_shape: Vec<usize>,
                   lhs_snapshot: VersionSnapshot,
                   rhs_snapshot: VersionSnapshot },
    AdaptiveAvgPool2d { input_id: GradId,
                        input_shape: Vec<usize> },

    // Cast: gradient of a dtype cast is simply a cast
    // in the reverse direction. No data is saved, only
    // the source dtype for the reverse cast.
    CastBackward { input_id: GradId, source_dtype: DType },
}
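
The "inspectable" property can be illustrated with ordinary pattern matching. The sketch below uses a reduced, hypothetical three-variant enum with simplified fields (no snapshots) and a `usize` stand-in for GradId; `input_ids` and `describe` are illustrative helpers, not part of RUMUS.

```rust
type GradId = usize; // hypothetical stand-in for the real id type

// Reduced, hypothetical subset of the variants, with simplified fields.
enum BackwardOp {
    Add { lhs_id: GradId, rhs_id: GradId },
    Relu { input_id: GradId },
    LeakyRelu { input_id: GradId, alpha: f32 },
}

/// List the tensors an op reads, in operand order (duplicates preserved).
fn input_ids(op: &BackwardOp) -> Vec<GradId> {
    match op {
        BackwardOp::Add { lhs_id, rhs_id } => vec![*lhs_id, *rhs_id],
        BackwardOp::Relu { input_id } | BackwardOp::LeakyRelu { input_id, .. } => {
            vec![*input_id]
        }
    }
}

/// Render one tape entry as text, e.g. for logging or debugging.
fn describe(op: &BackwardOp) -> String {
    match op {
        BackwardOp::Add { lhs_id, rhs_id } => format!("Add({}, {})", lhs_id, rhs_id),
        BackwardOp::Relu { input_id } => format!("Relu({})", input_id),
        BackwardOp::LeakyRelu { input_id, alpha } => {
            format!("LeakyRelu({}, alpha={})", input_id, alpha)
        }
    }
}
```

Because every variant is a plain data structure, the compiler checks that each `match` handles all variants, which a tape of opaque closures cannot offer.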

Broadcasting Backward Pass

When a forward operation broadcasts a tensor (expanding size-1 dimensions), the backward pass must reduce the gradient back to the original shape. This is done by summing the upstream gradient along every dimension that was broadcast. The BackwardOp variants for broadcasting store both the original left and right shapes so the reduction dimensions can be computed at backward time.

Gradient reduction rule: For each operand, compare its original shape to the output shape. Any dimension where the operand had size 1 (and the output has size > 1) is a broadcast dimension. Sum the gradient along those dimensions, then reshape back to the operand's original shape.
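
The reduction rule can be sketched on a flat row-major buffer. The `reduce_to_shape` function is hypothetical and assumes that both shapes have the same rank and that each operand dimension is either 1 or equal to the output dimension; it says nothing about how RUMUS implements the reduction.

```rust
/// Sum `grad` (row-major, shape `out_shape`) down to `orig_shape`.
/// Assumes equal rank and that `orig_shape[d]` is 1 or `out_shape[d]`.
fn reduce_to_shape(grad: &[f32], out_shape: &[usize], orig_shape: &[usize]) -> Vec<f32> {
    assert_eq!(out_shape.len(), orig_shape.len());
    assert_eq!(grad.len(), out_shape.iter().product::<usize>());

    let mut reduced = vec![0.0f32; orig_shape.iter().product::<usize>()];
    for (flat, g) in grad.iter().enumerate() {
        // Decode the output coordinate from the last dimension to the first
        // and build the target offset, pinning broadcast dimensions to 0.
        let mut rem = flat;
        let mut target = 0;
        let mut stride = 1;
        for d in (0..out_shape.len()).rev() {
            let coord = rem % out_shape[d];
            rem /= out_shape[d];
            let t = if orig_shape[d] == 1 { 0 } else { coord };
            target += t * stride;
            stride *= orig_shape[d];
        }
        // Every output element that maps to the same operand element is summed.
        reduced[target] += *g;
    }
    reduced
}
```

With an upstream gradient of shape [2, 3] holding 1 through 6, reducing to [1, 3] sums the two rows and reducing to [2, 1] sums each row.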

rust
// --- Broadcasting Backward: Gradient Reduction ---
//
// When a tensor is broadcast during the forward pass,
// the backward pass must "un-broadcast" by summing the
// gradient along the dimensions that were expanded.
//
// Example:  a: [4, 1, 3]  +  b: [1, 5, 3]  →  c: [4, 5, 3]
//
//   The upstream gradient dc (dL/dc) has shape [4, 5, 3],
//   but grad_a needs shape [4, 1, 3]
//   → sum along dim 1:  grad_a = dc.sum(dim=1, keepdim=true)
//
//   Likewise, grad_b needs shape [1, 5, 3]
//   → sum along dim 0:  grad_b = dc.sum(dim=0, keepdim=true)
//
// For BroadcastMul, the chain rule also multiplies by
// the other operand before reducing:
//   grad_a = (upstream * b_broadcast).sum(broadcast_dims_of_a)
//   grad_b = (upstream * a_broadcast).sum(broadcast_dims_of_b)

// The reduction dims are computed by comparing the original
// shapes stored in the BackwardOp variant against the
// output gradient shape.

Embedding Backward (CPU-only)

The embedding backward pass requires a scatter-add operation: multiple token indices may map to the same embedding row, so gradients must be atomically accumulated. Because WGSL (WebGPU's shading language) only supports i32 and u32 atomics — not f32 — the embedding backward pass runs entirely on the CPU.

Why CPU-only? A GPU scatter-add on f32 values would require emulating atomics with compare-and-swap loops, which is both slow and non-deterministic. Since embedding tables are typically small relative to the rest of the model, the CPU sync cost is negligible during training.
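
On the CPU, the scatter-add is a plain sequential loop. The sketch below assumes flat row-major buffers; `embedding_backward` and its parameters are hypothetical names chosen for illustration.

```rust
/// CPU scatter-add: grad_weight[indices[i]] += grad_output[i], row by row.
/// `grad_output` holds `indices.len()` rows of `dim` values each.
fn embedding_backward(
    indices: &[usize],
    grad_output: &[f32],
    vocab_size: usize,
    dim: usize,
) -> Vec<f32> {
    assert_eq!(grad_output.len(), indices.len() * dim);

    let mut grad_weight = vec![0.0f32; vocab_size * dim];
    for (i, &row) in indices.iter().enumerate() {
        assert!(row < vocab_size, "index out of range");
        let src = &grad_output[i * dim..(i + 1) * dim];
        let dst = &mut grad_weight[row * dim..(row + 1) * dim];
        // Repeated indices accumulate into the same row. Sequential CPU
        // code needs no atomics and always adds in the same order.
        for (d, s) in dst.iter_mut().zip(src) {
            *d += *s;
        }
    }
    grad_weight
}
```

When an index appears more than once, its row receives the sum of all matching gradient rows, and rows that were never looked up keep a zero gradient.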

rust
// --- Embedding Backward: CPU-only scatter-add ---
//
// The forward pass gathers rows: output[i] = weight[idx[i]]
// The backward pass must scatter gradients back:
//   grad_weight[idx[i]] += grad_output[i]
//
// This is a scatter-add (multiple indices may hit the same
// row), which requires atomic addition. WGSL (WebGPU's
// shading language) does not support f32 atomics; only
// i32/u32 atomics are available.
//
// Therefore, the embedding backward pass runs on the CPU.
// The gradient tensor is synced back to CPU, the scatter-add
// is performed with Rust's standard f32 arithmetic, and the
// result is uploaded to GPU only if needed for the next op.
//
// This is a deliberate trade-off: embedding tables are
// typically small relative to the model, so the CPU-GPU
// sync cost is negligible compared to the rest of training.

Architecture Details

The backward pass uses Kahn's algorithm to topologically sort the computation graph before processing. This ensures that when a node's gradient is computed, all downstream contributions have already been accumulated.

Why topological sort matters: Consider the expression y = x + x. The tensor x appears twice in the Add node. Without proper ordering and edge counting, you would compute dy/dx = 1 instead of the correct dy/dx = 2.
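
A minimal sketch of this ordering, restricted to addition so that every local derivative is 1, is shown below. `AddOp` and `backward_add_graph` are hypothetical, and the sketch assumes every recorded op contributes to the loss; it illustrates Kahn's algorithm with per-use edge counting, not RUMUS's internals.

```rust
use std::collections::{HashMap, VecDeque};

// Hypothetical graph node: `output` is the sum of `inputs`.
struct AddOp {
    inputs: Vec<usize>,
    output: usize,
}

/// Kahn-style backward pass over a graph of additions.
fn backward_add_graph(ops: &[AddOp], loss: usize) -> HashMap<usize, f32> {
    // Count every use of a tensor as an input; `x + x` counts x twice.
    let mut pending: HashMap<usize, usize> = HashMap::new();
    for op in ops {
        for &i in &op.inputs {
            *pending.entry(i).or_insert(0) += 1;
        }
    }
    let producer: HashMap<usize, &AddOp> = ops.iter().map(|op| (op.output, op)).collect();

    let mut grads: HashMap<usize, f32> = HashMap::new();
    grads.insert(loss, 1.0);

    // The loss has no consumers, so it is ready immediately.
    let mut ready = VecDeque::new();
    ready.push_back(loss);
    while let Some(t) = ready.pop_front() {
        let g = grads[&t];
        if let Some(op) = producer.get(&t) {
            for &i in &op.inputs {
                // d(sum)/d(input) = 1, so the gradient passes through once per edge.
                *grads.entry(i).or_insert(0.0) += g;
                // An input becomes ready only after all of its uses have contributed.
                let n = pending.get_mut(&i).expect("every input was counted");
                *n -= 1;
                if *n == 0 {
                    ready.push_back(i);
                }
            }
        }
    }
    grads
}
```

For `y = x + x` the pending count of `x` starts at 2, so both edges add their contribution before `x` is finalized and the result is 2 rather than 1.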

VersionSnapshot: Some backward ops (like Relu) need forward-pass values to compute gradients. RUMUS stores Weak references via VersionSnapshot, so dropped tensors do not leak memory.
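
The weak-reference idea can be sketched with the standard library's `Rc` and `Weak`. The `Snapshot` type below is a hypothetical simplification: it omits whatever version tracking VersionSnapshot performs, and it uses `Rc` only for brevity (whether RUMUS uses `Rc` or `Arc` is not stated here).

```rust
use std::rc::{Rc, Weak};

// Hypothetical snapshot: a weak handle to forward-pass data.
struct Snapshot {
    data: Weak<Vec<f32>>,
}

impl Snapshot {
    /// Record a non-owning reference; this does not keep the data alive.
    fn capture(tensor: &Rc<Vec<f32>>) -> Self {
        Snapshot { data: Rc::downgrade(tensor) }
    }

    /// Returns `None` if the tensor has been dropped since capture.
    fn get(&self) -> Option<Rc<Vec<f32>>> {
        self.data.upgrade()
    }
}

/// Relu backward: pass the upstream gradient where the output was positive.
fn relu_backward(output: &Snapshot, upstream: &[f32]) -> Option<Vec<f32>> {
    let out = output.get()?;
    Some(
        out.iter()
            .zip(upstream)
            .map(|(&y, &g)| if y > 0.0 { g } else { 0.0 })
            .collect(),
    )
}
```

If the forward tensor is still alive, `relu_backward` returns the masked gradient; once the last strong reference is dropped, it returns `None` instead of keeping the buffer allocated.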

rust
// --- Backward Pass: Kahn's Algorithm ---
//
// The tape is a flat list, but the computation graph is
// a DAG. RUMUS uses Kahn's algorithm (BFS topological
// sort) to determine the correct backward order:
//
//   1. Count, for each node, how many ops consume it as
//      an input (its in-edges in the reversed graph).
//   2. Start from the loss node (it has no consumers).
//   3. Process nodes whose remaining count reaches zero.
//   4. Accumulate gradients via GradientStore.
//
// This correctly handles diamond patterns like x + x,
// where a single tensor appears in multiple operations.

// --- Strict Edge Counting ---
//
// An AtomicUsize `total_grads` counter ensures that
// expressions like `x.add(&x)` count both uses of x.
// Without this, the gradient for x would be halved.

// --- VersionSnapshot with Weak Refs ---
//
// Some backward ops need the forward-pass tensor values
// (e.g., Relu needs to know which elements were zeroed).
// VersionSnapshot stores a Weak<...> reference so that
// if the original tensor is dropped, the snapshot doesn't
// keep dead memory alive; it simply fails gracefully.