PyTorch... but much slower
There's some weird joy in those quiet moments after finishing a project. That was exactly me in January of 2025, reflecting on xsNumPy. I had spent weeks understanding the mysteries of arrays, memory buffers, and broadcasting. Like I mentioned in that story, it felt like baking a cake from scratch, where every ingredient and step mattered.
An experimental re-implementation of a few NumPy features in pure Python.
And yet, as I sat back gloating over my freshly baked xsNumPy, I found myself hungry for something more. I wanted to feel that high again and explore something new. Why not another library that had been a staple in my toolkit since 2018, PyTorch?
It was time to show it some love. I wondered if I could build my own version of a simple autograd (automatic gradient) engine, or an automatic differentiation (autodiff, for short) library, slowly enough to truly grasp its inner workings.
I mean, I've used PyTorch for years, but I had never really understood how it worked under the hood, so why not? Thus, it all began.
Lessons from xsNumPy
Before starting off with SlowTorch, I took a moment to reflect on the lessons I learned while writing xsNumPy. It became super clear that the most valuable insights came from the process of building it and not the results. Sure, the results were important, but this reminded me that sometimes, the journey is more important than the destination.
xsNumPy taught me that slowness can be a gift.
Much like my approach to xsNumPy, I wanted to take my time with SlowTorch too. I wanted to build it slowly, understanding each component and appreciating the complexity of the system. I had the same three rules.
Rules of engagement
No LLMs or AI assistance. Every line of code and every solution had to come from my own understanding and experimentation.
Pure Python only. No external dependencies, just the standard library.
Clean, statically typed, and well-documented code that mirrored PyTorch's public APIs (at least the most commonly used ones), aiming to be a drop-in replacement where sensible.
From arrays to tensors
So, I started off with building SlowTorch. The first step was to understand the
core data structure of PyTorch, the tensor class.
It's basically PyTorch's counterpart to NumPy's ndarray. Much
of the initial work in building the tensor class was similar to what I had done
with xsNumPy, as discussed here.
Quick analogy
To put it simply, if arrays were like egg cartons, tensors were like egg trays. Stacked in a way that you could easily access any egg (element) in the tray (tensor) without worrying about the carton (array) structure.
But this time, I had to add a few more things to make my tensor work like PyTorch's. I needed to implement a way to save the node and operation history for autodiff, which was a new concept for me. I also had to learn how to track operations, gradients, and compute them efficiently.
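To make "tracking operations" concrete, here is a toy, scalar-only sketch of the idea, not SlowTorch's actual code: every result remembers which operation produced it and carries a small function that knows how to push gradients back to its inputs. The Scalar class and the MulBackward0 name below are purely illustrative (the naming only mirrors PyTorch's grad_fn convention).

```python
class Scalar:
    """A toy value that records its own computation history."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0
        self.grad_fn = None   # backward recipe of the op that produced this value
        self.inputs = ()      # parent nodes in the computation graph

    def __mul__(self, other):
        out = Scalar(self.value * other.value)

        def MulBackward0():
            # Chain rule: d(out)/d(self) = other.value, d(out)/d(other) = self.value
            self.grad += other.value * out.grad
            other.grad += self.value * out.grad

        out.grad_fn = MulBackward0
        out.inputs = (self, other)
        return out
```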
Read the f*cking docs!
The PyTorch documentation was super duper helpful in understanding the
various implementation details of the tensor
class.
I started off with creating various dtypes like
float64, float32, int64, etc. alongside a simple
device.
But my devices were just strings, like "cpu" or "gpu", with no actual hardware
acceleration. The __repr__ method was pretty similar
to what I had in xsNumPy, but I had to add a few more details to reflect the
tensor's properties like shape,
device, dtype, and
whether it requires gradients or not.
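Very roughly, the dtype and device plumbing looked something like the sketch below. This is a simplified illustration; the names and fields in SlowTorch's real implementation differ a bit.

```python
import typing as t


# Dtypes as tiny descriptors and a device that is little more than a label.
class dtype(t.NamedTuple):
    name: str
    itemsize: int


float64 = dtype("float64", 8)
float32 = dtype("float32", 4)
int64 = dtype("int64", 8)


class device:
    def __init__(self, type="cpu"):
        self.type = type  # just a string; no real hardware acceleration

    def __repr__(self):
        return f"device(type={self.type!r})"


# The tensor __repr__ then surfaces this metadata, roughly like:
# tensor([1., 2., 3.], dtype=float32, device=device(type='cpu'), requires_grad=True)
```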
See also
Complete implementation of SlowTorch's tensor with helper
functions.
Walking backwards
I was happy with my minimal implementation of the tensor class, but
then I realised I needed to implement autodiff logic. Autodiff is arguably
the most important feature of PyTorch. It allows you to compute the gradients
of tensors with respect to a loss function, which is basically the backbone of
training a neural network.
In simpler terms, it's a glorified version of the chain rule from calculus.
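For example, take a tiny expression such as \(L = (xw + b)^2\) with \(u = xw + b\). The chain rule spells out exactly how each parameter's gradient is assembled:

\[ \frac{\partial L}{\partial w} = \frac{\partial L}{\partial u} \cdot \frac{\partial u}{\partial w} = 2u \cdot x = 2(xw + b)\,x, \qquad \frac{\partial L}{\partial b} = \frac{\partial L}{\partial u} \cdot \frac{\partial u}{\partial b} = 2u \cdot 1 = 2(xw + b) \]

Autodiff simply automates this bookkeeping for arbitrarily long chains of operations.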
In PyTorch, calling .backward() on a tensor
magically tells every parameter (tensor) how it should change. But how? What
does it truly mean for a tensor to change based on its history? How does it
know the appropriate path when asked to reverse its operations?
To be super duper honest, my initial attempts were a complete mess. I tried to meticulously track every operation, parent, and child tensor, resulting in code that resembled a family tree. But Andrej's video made me realise I was overcomplicating things, and I slowly reworked my implementation.
Inspiration
Andrej Karpathy explained this concept in great detail in his video where he builds micrograd, a simple autograd engine, from scratch. It is perhaps the best introduction to how autograd works, arguably the only one you need, and it helped me a ton in understanding the core concepts.
As I rewatched the video again and again, I realised that each operation could be represented as a node, and each node could carry a little function, a recipe for how to compute its own gradient. The real breakthrough came when I stopped thinking of the graph as a static structure and started seeing it as a living, breathing thing, growing with every operation.
Thus, I created a Node class that represented each operation, and each
tensor would have a reference to its parent nodes. This way, I could traverse
the graph and compute gradients in a more structured way.
class Tensor:

    def backward(self, inputs=None, retain_graph=False):
        if not self.requires_grad:
            raise RuntimeError("Tensor does not require grad")
        graph = []
        seen = set()
        # Seed the gradient of the tensor backward() was called on
        self.grad = 1.0

        # Topologically sort the graph by visiting each tensor's parents first
        def iter_graph(inputs):
            if isinstance(inputs, Tensor) and inputs not in seen:
                seen.add(inputs)
                if hasattr(inputs.grad_fn, "inputs"):
                    for input in inputs.grad_fn.inputs:
                        iter_graph(input)
                graph.append(inputs)

        iter_graph(inputs if inputs else self)
        # Call each node's gradient function in reverse (child-first) order
        for node in reversed(graph):
            if node.grad_fn is not None and callable(node.grad_fn):
                node.grad_fn()
        # Clear the root's seed gradient and, unless asked to retain it, the graph
        self.grad = None
        if not retain_graph:
            self.grad_fn = None
Every tensor (node) carried a grad_fn node in the computation graph. When
you call backward, the tensor does not just look at itself; it traces its
lineage, visiting every ancestor, and calls their gradient functions in reverse
order. It is a wee bit like walking back through your own footsteps after a
long hike, pausing at each fork to remember which way you came.
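Since SlowTorch aims to mirror PyTorch's public API, exercising the engine reads roughly like the sketch below. This is a hedged usage example rather than a verbatim snippet from the library; the exact factory functions and defaults may differ.

```python
import slowtorch

x = slowtorch.randn(3, requires_grad=True)
w = slowtorch.randn(3, requires_grad=True)

y = (x * w).sum()   # forward pass builds the graph node by node
y.backward()        # walk the graph in reverse, filling in .grad

print(w.grad)       # dy/dw == x, by the chain rule
print(x.grad)       # dy/dx == w
```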
Long story short, I had built a simple autograd engine that could handle basic operations like addition and multiplication, and even more complex ones like matrix multiplication and broadcasting. I could compute gradients for tensors with respect to a loss function, and it felt like I had finally understood the magic behind PyTorch's autodiff.
Special shoutout
I want to give a special shoutout to my colleague, Fatemeh Taghvaei, for her patience and our late-night meetings. She helped me fix my broadcasting logic and brought a fresh perspective to its implementation in SlowTorch.
I can't thank her enough for her support and guidance during this phase of the project.
Building the building blocks
Once my tensor with autodiff support was in place, I turned my attention to
the neural networks. PyTorch's torch.nn module is a marvel of
abstractions, and I wanted to recreate it from scratch. I began by defining
Module, a base class that could hold parameters and submodules.
This class was responsible for managing the state of the model, including saving and loading weights, switching between training and evaluation modes, and handling parameter updates.
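Conceptually, a stripped-down Module looks something like the sketch below. It is illustrative only: it assumes the Parameter class shown later in this post and skips the buffers, hooks, and state-dict machinery a real implementation needs.

```python
class Module:

    def __init__(self):
        self.training = True
        self._parameters = {}
        self._modules = {}

    def __setattr__(self, name, value):
        # Register parameters and submodules automatically on assignment
        if isinstance(value, Parameter):
            self.__dict__.setdefault("_parameters", {})[name] = value
        elif isinstance(value, Module):
            self.__dict__.setdefault("_modules", {})[name] = value
        object.__setattr__(self, name, value)

    def parameters(self):
        # Yield own parameters, then recurse into submodules
        yield from self._parameters.values()
        for module in self._modules.values():
            yield from module.parameters()

    def train(self, mode=True):
        self.training = mode
        for module in self._modules.values():
            module.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
```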
I was picking up pace, and things were much clearer now. As more time passed, I implemented many things: layers, activations, losses, and transforms were all written in their functional forms first and later wrapped in classes, much like PyTorch.
Layers were implemented as functions that took tensors as input
and returned new tensors with the layer transformation applied (forward
pass). Each layer function also had a backward pass that computed the
gradient with respect to the input tensors.
| SlowTorch supports | Forward | Backward |
|---|---|---|
| Linear (Fully Connected/Dense) | \(f(x) = xW^T + b\) | \(f'(x) = \begin{cases} W & \text{for } x \\ x & \text{for } W \\ 1 & \text{for } b \end{cases}\) |
| Embedding | \(f(x) = W[x]\) | \(f'(x) = \begin{cases} 1 & \text{for } W[x] \\ 0 & \text{for } W[j],\ j \neq x \end{cases}\) |
For example, below is a minimal implementation of the linear layer in its functional form with its backward pass.
def linear(input, weight, bias=None):
    # Forward pass: y = x @ W.T (+ b)
    new_tensor = input @ weight.T
    if bias is not None:
        if bias._shape != (new_tensor._shape[-1],):
            raise ValueError("Bias incompatible with output shape")
        new_tensor += bias

    def AddmmBackward0():
        # Backward pass: route the upstream gradient to each input
        input.grad += new_tensor.grad @ weight
        weight.grad += new_tensor.grad.T @ input
        if bias is not None:
            bias.grad += new_tensor.grad.sum(dim=0)

    new_tensor.grad_fn = Node(AddmmBackward0)
    new_tensor.grad_fn.inputs = (input, weight, bias)
    return new_tensor
Activation functions were implemented as simple functions that took
a tensor as input and returned a new tensor with the activation
(forward pass) applied. Each activation function also had a backward
pass that computed the gradient with respect to the input tensor.
| SlowTorch supports | Forward | Backward |
|---|---|---|
| Tanh | \(f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}\) | \(f'(x) = 1 - f(x)^2\) |
| Sigmoid | \(f(x) = \frac{1}{1 + e^{-x}}\) | \(f'(x) = f(x)(1 - f(x))\) |
| ReLU | \(f(x) = \max(0, x)\) | \(f'(x) = \begin{cases} 0 & \text{if } x < 0 \\ 1 & \text{if } x > 0 \end{cases}\) |
| ELU | \(f(x) = \begin{cases} x & \text{if } x > 0 \\ \alpha(e^x - 1) & \text{if } x \leq 0 \end{cases}\) | \(f'(x) = \begin{cases} 1 & \text{if } x > 0 \\ \alpha e^x & \text{if } x \leq 0 \end{cases}\) |
| Softmax | \(f(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}\) | \(f'(x_i) = \begin{cases} f(x_i)(1 - f(x_i)) & \text{if } i = j \\ -f(x_i)f(x_j) & \text{if } i \neq j \end{cases}\) |
| Log Softmax | \(f(x_i) = \log\left(\frac{e^{x_i}}{\sum_{j} e^{x_j}} \right)\) | \(f'(x_i) = \begin{cases} 1 - f(x_i) & \text{if } i = j \\ -f(x_j) & \text{if } i \neq j \end{cases}\) |
For example, below is a minimal implementation of the sigmoid function with its backward pass.
import math
from itertools import product


def sigmoid(input):
    new_tensor = Tensor(input._shape, input.dtype)
    storage = []
    # Iterate over every index of the input tensor, element by element
    if len(input._shape) == 1:
        it = range(input._shape[0])
    else:
        it = product(*[range(index) for index in input._shape])
    for index in it:
        try:
            storage.append(1.0 / (1.0 + math.exp(-input[index])))
        except IndexError:
            continue
    new_tensor[:] = storage

    def SigmoidBackward0():
        if input.grad is None:
            input.grad = Tensor(input._shape, input.dtype)
        grad = new_tensor.grad
        # d(sigmoid)/dx = sigmoid(x) * (1 - sigmoid(x)), scaled by the upstream grad
        input.grad += (new_tensor * (1 - new_tensor)) * grad

    new_tensor.grad_fn = Node(SigmoidBackward0)
    new_tensor.grad_fn.inputs = (input,)
    return new_tensor
Loss functions were implemented as functions that took two tensors,
input and target, and returned a new tensor representing the
calculated loss (forward pass). Each loss function also had a backward
pass that computed the gradient with respect to the input and target
tensors.
| SlowTorch supports | Forward | Backward |
|---|---|---|
| MSE (Mean Squared Error) | \(f(x, y) = \frac{1}{n} \sum_{i=1}^{n} (x_i - y_i)^2\) | \(f'(x, y) = \begin{cases} 2(x_i - y_i) / n & \text{mean} \\ 2(x_i - y_i) & \text{sum} \\ 2(x_i - y_i) & \text{none} \end{cases}\) |
| L1 (Mean Absolute Error) | \(f(x, y) = \frac{1}{n} \sum_{i=1}^{n} \lvert x_i - y_i \rvert\) | \(f'(x, y) = \begin{cases} \operatorname{sign}(x_i - y_i) / n & \text{mean} \\ \operatorname{sign}(x_i - y_i) & \text{sum} \\ \operatorname{sign}(x_i - y_i) & \text{none} \end{cases}\) |
| NLL (Negative Log Likelihood) | \(f(x, y) = -\sum_{i=1}^{n} y_i \log(x_i)\) | \(f'(x, y) = \begin{cases} -\frac{y_i}{x_i} & \text{mean} \\ -y_i & \text{sum} \\ -y_i & \text{none} \end{cases}\) |
| Cross Entropy | \(f(x, y) = -\sum_{i=1}^{n} y_i \log(x_i)\) | \(f'(x, y) = \begin{cases} -\frac{y_i}{x_i} & \text{mean} \\ -y_i & \text{sum} \\ -y_i & \text{none} \end{cases}\) |
For example, below is a minimal implementation of the mean squared error (MSE) loss function with its backward pass.
def mse_loss(input, target, reduction="mean"):
    # Forward pass: element-wise squared error, then reduce
    loss = (input - target) ** 2
    if reduction == "mean":
        new_tensor = loss.sum() / loss.nelement()
    elif reduction == "sum":
        new_tensor = loss.sum()
    elif reduction == "none":
        new_tensor = loss

    def MseLossBackward0():
        if input.grad is None or target.grad is None:
            input.grad = Tensor(input._shape, input.dtype)
            target.grad = Tensor(target._shape, target.dtype)
        # d/dx of (x - y)^2 is 2(x - y), averaged when reduction is "mean"
        grad = 2.0 / loss.nelement() if reduction == "mean" else 2.0
        input.grad += grad * (input - target)
        target.grad -= grad * (input - target)

    new_tensor.grad_fn = Node(MseLossBackward0)
    new_tensor.grad_fn.inputs = (input, target)
    return new_tensor
Transformations were implemented as functions that took a tensor as
input and returned a new tensor with the transformation applied
(forward pass). Each transform function also had a backward pass that
computed the gradient with respect to the input tensor.
| SlowTorch supports | Forward | Backward |
|---|---|---|
| Clone (Copy) | \(f(x) = x.clone()\) | \(f'(x) = \begin{cases} 1 & \text{for } x \\ 0 & \text{for } x[j],\ j \neq i \end{cases}\) |
| Ravel (Flatten) | \(f(x) = x.ravel()\) | \(f'(x) = \begin{cases} 1 & \text{for } x \\ 0 & \text{for } x[j],\ j \neq i \end{cases}\) |
| Transpose (T) | \(f(x) = x.transpose(dim_0, dim_1)\) | \(f'(x) = \begin{cases} 1 & \text{for } x[dim_0] \\ 1 & \text{for } x[dim_1] \\ 0 & \text{for } x[j],\ j \neq dim_0, dim_1 \end{cases}\) |
| Reshape (View) | \(f(x) = x.reshape(shape)\) | N/A (no backward pass) |
| Unsqueeze | \(f(x) = x.unsqueeze(dim)\) | N/A (no backward pass) |
| One Hot Encoding | \(f(x) = \text{one\_hot}(x, classes)\) | N/A (no backward pass) |
For example, below is a minimal implementation of the ravel (flatten) function with its backward pass.
def ravel(input):
    # Forward pass: flatten the input into a 1-D tensor
    new_tensor = Tensor(input.nelement(), input.dtype)
    new_tensor[:] = input

    def ViewBackward0():
        # The gradient flows back to the input unchanged
        if input.grad is None:
            input.grad = new_tensor.grad

    new_tensor.grad_fn = Node(ViewBackward0)
    new_tensor.grad_fn.inputs = (input,)
    return new_tensor
Parameters were just tensors with a flag indicating whether they required gradients. For example, below is a minimal implementation of a SlowTorch parameter.
class Parameter(Tensor):

    def __init__(self, data=None, requires_grad=True):
        if data is None:
            # Default to a small random tensor when no data is given
            data = slowtorch.randn(1, requires_grad=requires_grad)
        else:
            data = data.clone()
        data.requires_grad = requires_grad
        # Adopt the underlying tensor's state wholesale
        for key, value in data.__dict__.items():
            setattr(self, key, value)

    def __repr__(self):
        return f"Parameter containing:\n{super().__repr__()}"

    @property
    def data(self):
        return self

    @data.setter
    def data(self, value):
        if not isinstance(value, Tensor):
            raise TypeError("Parameter data must be a tensor")
        self.storage[:] = value.storage
Massive thanks
I want to thank my friends, Sameer and Lucas Yong, for their invaluable insights while implementing the Softmax function's backward pass.
Lucas derived the gradients for Softmax and
shared them via
email, while Sameer helped me implement a crude version of second-order
derivatives. Both were game-changers for me, helping me understand the core
concepts of autodiff in a way that no documentation or blog post ever
could.
Recreating neural networks from first principles reminded me of learning to ride a bicycle without training wheels. I fell off a ton. But each time I got back on, I understood a little more. I was, in a way, backpropagating my mistakes, learning from them, and adjusting my gradients.
Joy of manual optimisation
With some of my neural network modules in place, I moved on to building my optimiser, which presented another challenge. PyTorch's optimisers are elegant and efficient, but I wanted to understand their mechanics. I implemented a simple optimiser, manually updating its parameters step by step.
Once I was happy with my optimiser, I wrote a basic Optimiser
class that took a list of parameters and a learning rate, and it had an
.step() method that updated the parameters based on their gradients.
class Optimiser:

    def __init__(self, params, lr=0.01):
        self.params = list(params)
        self.lr = lr

    def step(self):
        # Vanilla gradient descent: nudge each parameter against its gradient
        for param in self.params:
            if param.grad is None:
                continue
            param -= self.lr * param.grad
It was slow and clunky, but I could see every calculation, update, and mistake. I had to understand how each parameter was updated, how the learning rate (\(\eta\)) affected the updates, and how momentum (\(\mu\)) could help smooth out the learning process.
With time, I learnt techniques that improved the training process. Finally, I implemented my own version of the SGD (Stochastic Gradient Descent) optimiser, which was a simple yet effective way to update parameters based on their gradients.
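For illustration, here is a sketch of what an SGD step with momentum can look like, building on the Optimiser class above. This is my own simplification of the technique, not SlowTorch's exact SGD.

```python
class SGD(Optimiser):

    def __init__(self, params, lr=0.01, momentum=0.0):
        super().__init__(params, lr)
        self.momentum = momentum
        self.velocity = [None] * len(self.params)

    def step(self):
        for index, param in enumerate(self.params):
            if param.grad is None:
                continue
            if self.momentum:
                # v = mu * v + grad, then step along the smoothed direction
                if self.velocity[index] is None:
                    self.velocity[index] = param.grad
                else:
                    self.velocity[index] = (
                        self.momentum * self.velocity[index] + param.grad
                    )
                param -= self.lr * self.velocity[index]
            else:
                param -= self.lr * param.grad

    def zero_grad(self):
        # Reset gradients before the next backward pass
        for param in self.params:
            param.grad = None
```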
Embracing slowness as a virtue
As more time passed while building SlowTorch, I realised the hardest part wasn't the code or the maths, but the mindset. I knew I couldn't compete with PyTorch's raw speed, so I had to let go of the desire for speed, elegance, and perfection that I had always strived for as a Software Engineer.
Instead, I embraced the slowness, curiosity, and experimentation of a child. Every bug I encountered was a lesson, and every unexpected result was an opportunity to regroup and learn. I quite often found myself talking to my code, asking it questions, coaxing it to reveal its secrets.
While SlowTorch isn't a replacement for PyTorch, it's a learning tool for those interested in understanding the inner workings of deep learning. It can perform basic tasks like training a simple neural network, but it's not intended for production use if that's not obvious already.
By the end, I had realised the true meaning of "slow" in SlowTorch and had begun embracing slowness for understanding over speed.
For me, personally, SlowTorch serves as a reminder that true understanding and mastery come not from speed but from experience, attention, and care. It taught me that sometimes, the slowest path is the fastest way to learn.