# Neural Networks

Build neural networks in pure Rust with RUMUS's module system, automatic parameter tracking, and built-in serialization.

## Module System

The `#[derive(Module)]` proc macro is the foundation of RUMUS neural networks. Apply it to any struct containing layer fields, and it automatically generates implementations for parameter collection, training/eval mode toggling, and state serialization.

```rust
use rumus::nn::{self, Module, Linear, Conv2d, MaxPool2d, Flatten, Dropout};
use rumus::tensor::Tensor;

#[derive(Module)]
struct MyModel {
    fc1: Linear,
    fc2: Linear,
    dropout: Dropout,
}

impl MyModel {
    fn new() -> Self {
        Self {
            fc1: Linear::new(784, 128),
            fc2: Linear::new(128, 10),
            dropout: Dropout::new(0.5),
        }
    }

    fn forward(&self, x: &Tensor) -> Tensor {
        let x = nn::relu(&self.fc1.forward(x));
        let x = self.dropout.forward(&x);
        self.fc2.forward(&x)
    }
}
```

The macro auto-generates five essential methods on your struct:

```rust
// All auto-generated by #[derive(Module)]:
let params = model.parameters();     // Vec<Parameter>
model.train();                       // enable training mode
model.eval();                        // enable inference mode
let state = model.state_dict("");    // serialize weights
model.load_state_dict(&state);       // restore weights
```
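The `train()`/`eval()` toggle matters mostly for layers with mode-dependent behavior such as `Dropout`. The following is a standalone sketch of those semantics; `ToyDropout` and its explicit `keep_mask` argument are illustrative inventions, not RUMUS's actual `Dropout` API:

```rust
// Standalone sketch of inverted dropout gated by a training flag,
// to illustrate what train()/eval() toggles. ToyDropout and its
// mask argument are illustrative — not RUMUS's actual Dropout API.
struct ToyDropout {
    p: f64,        // probability of zeroing an element
    training: bool,
}

impl ToyDropout {
    fn forward(&self, x: &[f64], keep_mask: &[bool]) -> Vec<f64> {
        if !self.training {
            return x.to_vec(); // eval mode: identity
        }
        // inverted dropout: scale kept values so the expected value is unchanged
        let scale = 1.0 / (1.0 - self.p);
        x.iter()
            .zip(keep_mask)
            .map(|(&v, &keep)| if keep { v * scale } else { 0.0 })
            .collect()
    }
}

fn main() {
    let mut d = ToyDropout { p: 0.5, training: true };
    let x = [1.0, 2.0, 3.0, 4.0];
    let mask = [true, false, true, false]; // stand-in for Bernoulli(1 - p) draws
    println!("{:?}", d.forward(&x, &mask)); // [2.0, 0.0, 6.0, 0.0]
    d.training = false; // what model.eval() would flip
    println!("{:?}", d.forward(&x, &mask)); // [1.0, 2.0, 3.0, 4.0]
}
```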

## Available Layers

RUMUS provides the core building blocks for modern neural network architectures.

| Layer | Constructor | Description |
|---|---|---|
| `Linear` | `Linear::new(in_features, out_features)` | Fully-connected layer with weight and bias |
| `Conv2d` | `Conv2d::new(in_channels, out_channels, kernel_size)` | 2D convolution with learned filters |
| `MaxPool2d` | `MaxPool2d::new(kernel_size, stride)` | 2D max pooling for spatial downsampling |
| `Flatten` | `Flatten::new()` | Flatten spatial dims into a 1D vector |
| `Dropout` | `Dropout::new(p)` | Randomly zero elements during training |

## Activations & Loss Functions

Activation functions and loss functions are provided as free functions in the nn module. They operate on tensor references and return new tensors, fully integrated with the autograd graph.

```rust
// Activations
let activated = nn::relu(&x);

// Loss Functions
let loss = nn::mse_loss(&predictions, &targets);
let loss = nn::cross_entropy_loss(&logits, &labels);
```
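These functions have simple element-wise definitions. A plain-slice sketch of the semantics follows; it shows what `relu` and `mse_loss` compute, not the actual tensor-based, autograd-aware implementations:

```rust
// Reference semantics of relu and mse_loss on plain slices — a sketch
// of what the nn free functions compute, not their tensor implementations.
fn relu(x: &[f64]) -> Vec<f64> {
    x.iter().map(|&v| v.max(0.0)).collect() // max(0, x) element-wise
}

fn mse_loss(pred: &[f64], target: &[f64]) -> f64 {
    // mean of squared differences
    let n = pred.len() as f64;
    pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum::<f64>() / n
}

fn main() {
    println!("{:?}", relu(&[-1.0, 0.5, 2.0]));          // [0.0, 0.5, 2.0]
    println!("{}", mse_loss(&[1.0, 2.0], &[0.0, 2.0])); // 0.5
}
```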

## Building a CNN

Here is a complete convolutional neural network following the classic pattern: `Conv2d` → ReLU → `MaxPool2d` → `Flatten` → `Linear`. The `#[derive(Module)]` macro handles all the boilerplate while you define only the architecture and forward pass.

```rust
#[derive(Module)]
struct ConvNet {
    conv1: Conv2d,
    pool: MaxPool2d,
    flatten: Flatten,
    fc1: Linear,
}

impl ConvNet {
    fn new() -> Self {
        Self {
            conv1: Conv2d::new(1, 32, 3),
            pool: MaxPool2d::new(2, 2),
            flatten: Flatten::new(),
            fc1: Linear::new(32 * 13 * 13, 10),
        }
    }

    fn forward(&self, x: &Tensor) -> Tensor {
        let x = nn::relu(&self.conv1.forward(x));  // Conv2d → ReLU
        let x = self.pool.forward(&x);              // MaxPool2d
        let x = self.flatten.forward(&x);           // Flatten
        self.fc1.forward(&x)                        // Linear
    }
}
```
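The `Linear` input size `32 * 13 * 13` comes from the spatial shapes the forward pass produces. Assuming a valid (no-padding), stride-1 convolution and a non-overlapping 2×2 pool — consistent with the constructor arguments, though not explicitly stated here — the arithmetic for a 1×28×28 input works out as:

```rust
// Shape trace for ConvNet on a 1x28x28 input (e.g. MNIST).
// Assumes a valid (no-padding), stride-1 convolution and a
// non-overlapping pool — assumptions, not documented RUMUS guarantees.
fn conv2d_out(size: usize, kernel: usize) -> usize {
    size - kernel + 1 // valid convolution, stride 1
}

fn maxpool2d_out(size: usize, kernel: usize, stride: usize) -> usize {
    (size - kernel) / stride + 1
}

fn main() {
    let after_conv = conv2d_out(28, 3);               // Conv2d::new(1, 32, 3): 26x26, 32 channels
    let after_pool = maxpool2d_out(after_conv, 2, 2); // MaxPool2d::new(2, 2): 13x13
    let flattened = 32 * after_pool * after_pool;     // Flatten: 5408 features
    println!("{after_conv} {after_pool} {flattened}"); // 26 13 5408
    assert_eq!(flattened, 32 * 13 * 13); // matches Linear::new(32 * 13 * 13, 10)
}
```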

## State Management

RUMUS uses the SafeTensors format for model serialization. Weights are keyed with dot-path notation (e.g., `"fc1.weight"`, `"fc1.bias"`), making state dicts human-readable and easy to inspect.

```rust
// Save model weights
let state = model.state_dict("");
nn::save_safetensors(&state, "model.safetensors")?;

// Keys use dot-path notation:
// "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"

// Load model weights
let state = nn::load_safetensors("model.safetensors")?;
model.load_state_dict(&state);
```
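The prefix argument to `state_dict` suggests that keys are built recursively, with each nested module extending its parent's prefix by a dot. Here is a standalone sketch of that key scheme, using a plain `HashMap` in place of RUMUS's state-dict type; it is a guess at the convention implied by the dot-path keys above, not RUMUS's actual code:

```rust
use std::collections::HashMap;

// Sketch of dot-path key construction for nested modules. The prefix
// argument mirrors the "" passed to state_dict above; illustrative only.
fn insert_param(
    state: &mut HashMap<String, Vec<f32>>,
    prefix: &str,
    name: &str,
    value: Vec<f32>,
) {
    let key = if prefix.is_empty() {
        name.to_string()
    } else {
        format!("{prefix}.{name}") // nested modules extend the prefix
    };
    state.insert(key, value);
}

fn main() {
    let mut state: HashMap<String, Vec<f32>> = HashMap::new();
    insert_param(&mut state, "fc1", "weight", vec![0.0; 784 * 128]);
    insert_param(&mut state, "fc1", "bias", vec![0.0; 128]);
    assert!(state.contains_key("fc1.weight")); // dot-path key as in the doc
    println!("{}", state.len()); // 2
}
```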