From PyTorch to Burn: Why I'm Training Models in Rust Now
```shell
$ cargo run --release -- --train --backend wgpu
```

I’ve been writing PyTorch for years. It’s comfortable. The ecosystem is massive. The docs are good. So why did I spend three weekends rewriting my training pipeline in a language that doesn’t even have a REPL?
Because I got tired of `RuntimeError: shape mismatch` at 2 AM.
The Breaking Point
It started with a deployment problem. I had a computer vision model — nothing fancy, just an object detection pipeline I was running on embedded Linux. The model itself was small. The PyTorch runtime it needed to ship with was not.
My deployment target was an ARM board with 2GB of RAM. The PyTorch inference runtime alone ate 1.5GB before the model even loaded. I tried TorchScript. I tried ONNX export. Each solution introduced its own layer of fragility and C++ glue code I didn’t want to maintain.
I kept thinking: the model is 15MB. Why does it need a 1.5GB runtime to multiply some matrices?
That’s when a colleague pointed me at Burn.
What Burn Actually Is
Burn is a deep learning framework written in Rust, from the ground up. Not a binding to a C++ library. Not a wrapper around ONNX. A native Rust framework with its own tensor engine, autodiff system, and training loop infrastructure.
The thing that immediately hooked me was the type system. In PyTorch, a tensor is a tensor — you find out it’s the wrong shape when your code crashes. In Burn, tensors carry their dimensionality as a compile-time constant:
```rust
// This is a 3D float tensor on whatever backend B is
let x: Tensor<B, 3> = input.reshape([batch_size, 1, height * width]);

// Try to pass this where a 2D tensor is expected?
// Compiler says no. Not at runtime. At compile time.
```

The first time the compiler caught a shape mismatch for me — something that would have been a silent bug in Python until that one weird batch hit production — I was sold.
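That guarantee is plain Rust const generics, not framework magic. Here is a minimal sketch of the idea using toy types of my own (not Burn's actual `Tensor`):

```rust
// Toy rank-parameterized tensor; NOT Burn's real Tensor type.
#[derive(Debug, Clone)]
struct Tensor<const D: usize> {
    shape: [usize; D],
    data: Vec<f32>,
}

impl<const D: usize> Tensor<D> {
    // Allocate a zero-filled tensor with the given shape.
    fn zeros(shape: [usize; D]) -> Self {
        let len = shape.iter().product();
        Tensor { shape, data: vec![0.0; len] }
    }
}

// This function only accepts rank-2 tensors.
fn num_elements_2d(t: &Tensor<2>) -> usize {
    t.shape[0] * t.shape[1]
}

fn main() {
    let x = Tensor::zeros([4, 8]);
    assert_eq!(num_elements_2d(&x), 32);

    // let y = Tensor::zeros([4, 2, 8]);
    // num_elements_2d(&y); // compile error: expected `Tensor<2>`, found `Tensor<3>`
}
```

Because the rank lives in the type, passing a 3D tensor where a 2D one is expected never gets past `cargo build`.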
Defining a Model
If you know PyTorch’s `nn.Module`, Burn’s equivalent feels familiar. You define a struct, derive `Module`, and implement a `forward` method:
```rust
#[derive(Module, Debug)]
pub struct Classifier<B: Backend> {
    conv1: nn::conv::Conv2d<B>,
    conv2: nn::conv::Conv2d<B>,
    pool: nn::pool::AdaptiveAvgPool2d,
    dropout: nn::Dropout,
    fc: nn::Linear<B>,
}

impl<B: Backend> Classifier<B> {
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch, height, width] = input.dims();
        // Add a channel dimension: [batch, 1, height, width]
        let x = input.reshape([batch, 1, height, width]);
        let x = self.conv1.forward(x);
        let x = self.conv2.forward(x);
        let x = self.pool.forward(x);
        // Flatten everything after the batch dimension, then classify
        let x: Tensor<B, 2> = x.flatten(1, 3);
        self.fc.forward(self.dropout.forward(x))
    }
}
```

Notice the `<B: Backend>` generic. That’s not decoration — it means this exact same model runs on CPU (NdArray), GPU via WebGPU (Wgpu), or even through LibTorch (Tch) as a backend. One model definition. Zero conditional compilation. You pick the backend at the call site.
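The mechanism behind this is an ordinary generic parameter bounded by a trait. A toy sketch of the pattern, where the `Backend` trait and the backend type names are stand-ins of my own, not Burn's real API:

```rust
use std::marker::PhantomData;

// Stand-in for a backend trait; Burn's real Backend trait carries
// tensor storage and kernel implementations, not just a name.
trait Backend {
    fn name() -> &'static str;
}

struct NdArray; // hypothetical CPU backend stand-in
struct Wgpu;    // hypothetical GPU backend stand-in

impl Backend for NdArray {
    fn name() -> &'static str { "ndarray (CPU)" }
}
impl Backend for Wgpu {
    fn name() -> &'static str { "wgpu (GPU)" }
}

// One model definition, generic over the backend.
struct Classifier<B: Backend> {
    _backend: PhantomData<B>,
}

impl<B: Backend> Classifier<B> {
    fn new() -> Self {
        Classifier { _backend: PhantomData }
    }
    fn describe(&self) -> String {
        format!("Classifier on {}", B::name())
    }
}

fn main() {
    // The backend is chosen here, at the call site.
    let cpu_model = Classifier::<NdArray>::new();
    let gpu_model = Classifier::<Wgpu>::new();
    println!("{}", cpu_model.describe());
    println!("{}", gpu_model.describe());
}
```

The compiler monomorphizes one copy of the model per backend you actually use, so there is no dynamic dispatch and nothing to configure at runtime.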
Training: The Learner Pattern
PyTorch’s training loop is famously DIY — you write the loop yourself. Burn takes a more structured approach with its `Learner` pattern. You implement `TrainStep` on your model:
```rust
impl<B: AutodiffBackend> TrainStep<ImageBatch<B>, ClassificationOutput<B>>
    for Classifier<B>
{
    fn step(&self, batch: ImageBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let output = self.forward_classification(batch);
        // Gradients come from calling backward() on the loss
        TrainOutput::new(self, output.loss.backward(), output)
    }
}
```

Then the Learner handles the rest — metric tracking, checkpointing, learning rate scheduling. It’s opinionated, but after writing the same 50-line training loop in PyTorch for the hundredth time, I appreciate the structure.
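To make the inversion of control concrete, here is a toy learner in plain Rust. The trait and struct names mimic Burn's, but none of this is Burn's actual API; it only shows who owns the loop:

```rust
// Minimal sketch of the Learner pattern: the framework owns the loop
// and the model only implements a training step. Not Burn's real traits.
trait TrainStep<I> {
    // Process one batch and return its loss.
    fn step(&mut self, batch: I) -> f32;
}

struct Learner {
    epochs: usize,
}

impl Learner {
    // The loop lives here: in a real framework this is also where
    // metrics, checkpointing, and LR scheduling hook in.
    fn fit<M: TrainStep<I>, I: Clone>(&self, model: &mut M, data: &[I]) -> Vec<f32> {
        (0..self.epochs)
            .map(|_| {
                let total: f32 = data.iter().cloned().map(|b| model.step(b)).sum();
                total / data.len() as f32 // mean loss for the epoch
            })
            .collect()
    }
}

// A toy "model" whose loss halves every step, just to drive the loop.
struct Toy {
    loss: f32,
}

impl TrainStep<f32> for Toy {
    fn step(&mut self, _batch: f32) -> f32 {
        self.loss *= 0.5;
        self.loss
    }
}

fn main() {
    let learner = Learner { epochs: 3 };
    let mut model = Toy { loss: 8.0 };
    let losses = learner.fit(&mut model, &[1.0, 2.0]);
    assert_eq!(losses.len(), 3);
    assert!(losses[2] < losses[0]); // loss trends down across epochs
}
```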
What Actually Changed for Me
Here’s what the migration looked like in practice:
Week 1: Rewriting the model definition. This was the easy part — Burn’s layer API maps almost 1:1 to PyTorch’s. The main friction was thinking in generics instead of dynamic dispatch.
Week 2: Data loading. This is where Burn is still catching up. PyTorch’s DataLoader with multiprocessing is mature and battle-tested. Burn’s dataset handling works, but you’ll write more custom code here. The tradeoff is that your data pipeline is also type-safe and doesn’t randomly deadlock from Python multiprocessing.
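For a sense of what that custom code looks like, here is a toy map-style dataset. The get-by-index-plus-length shape loosely follows Burn's `Dataset` trait, but the types below are illustrative, not Burn's:

```rust
// Sketch of a map-style dataset: random access by index plus a length.
// Illustrative types, not Burn's actual Dataset trait.
trait Dataset<I> {
    fn get(&self, index: usize) -> Option<I>;
    fn len(&self) -> usize;
}

// A toy in-memory dataset of (pixels, label) pairs.
struct InMemoryDataset {
    items: Vec<(Vec<f32>, usize)>,
}

impl Dataset<(Vec<f32>, usize)> for InMemoryDataset {
    fn get(&self, index: usize) -> Option<(Vec<f32>, usize)> {
        self.items.get(index).cloned()
    }
    fn len(&self) -> usize {
        self.items.len()
    }
}

fn main() {
    let ds = InMemoryDataset {
        items: vec![(vec![0.0; 4], 1), (vec![1.0; 4], 0)],
    };
    assert_eq!(ds.len(), 2);
    assert_eq!(ds.get(1).unwrap().1, 0);
    assert!(ds.get(2).is_none()); // out of range is None, not a panic
}
```

Batching and shuffling on top of this are ordinary Rust iterators and threads, which is why the multiprocessing failure modes of Python workers simply don't exist here.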
Week 3: Training and iteration. Once the pieces fit together, training just… works. The compiler already caught my dimension bugs. The remaining issues were all actual model architecture problems, not framework fights.
Deployment: This is where it all paid off. My final binary — model weights included — was 24MB. Not 1.5GB. Twenty-four megabytes. It ran on the ARM board with 80MB of RAM usage at inference. No Python. No runtime. Just a static binary I could scp onto the device.
The Honest Tradeoffs
I’m not going to pretend this is all upside. Here’s what you give up:
- Ecosystem size. PyTorch has a pretrained model for everything. Burn is growing fast, but you’ll port more things yourself. The ONNX import support helps, but it’s not seamless yet.
- Iteration speed. Python’s REPL-driven workflow is genuinely faster for experimentation. Rust’s compile times mean you think more before you run. Whether that’s a bug or a feature depends on your temperament.
- Community resources. Stack Overflow has a million PyTorch answers. Burn questions go to GitHub issues and Discord. The community is helpful but small.
When Burn Makes Sense
Not every project needs this. If you’re doing research, iterating on architectures, trying 50 experiments a day — stay in PyTorch. Seriously. The friction cost isn’t worth it.
But if you’re:
- Deploying models to constrained environments (embedded, mobile, edge)
- Building inference services where binary size and cold start matter
- Already working in a Rust codebase and don’t want a Python sidecar
- Tired of runtime errors that the type system could have caught
Then Burn is not a toy. It’s a real framework, backed by real engineering, and it’s ready for production workloads.
The Bottom Line
Rust is viable for ML. Not “viable in theory” or “viable for toy projects” — viable for training real models and deploying them to real hardware. Burn proved that to me.
The ecosystem isn’t PyTorch-sized. It probably never will be. But it doesn’t need to be. It needs to be good enough that you can go from training to a 24MB binary that runs on a 2GB ARM board without a runtime dependency.
It’s already there.
```shell
$ echo $?
0
```