```mermaid
flowchart LR
    w --- mul1(($$\times$$))
    x --- mul1
    mul1 --> z
    z --- add(($$+$$))
    y --- mul2(($$\times$$))
    -1 --- mul2
    mul2 --> neg_y["(-y)"]
    neg_y --- add
    add --> e
    e --- pow2(($$\square^2$$))
    pow2 --> L
```
Autodiff and Backprop from Scratch
What is Autodiff?
At its core, Autodiff (automatic differentiation) is just a way for a computer to automatically calculate derivatives using the chain rule. It’s the engine that powers backpropagation, the fundamental algorithm in machine learning used to efficiently train artificial neural networks with millions of parameters.
Derivatives and why they matter
Imagine you are baking a cake. If you increase the sugar by a tiny bit, how much sweeter does the cake get? Once you have the answer, you can use it to make better decisions, like adjusting the recipe to get the perfect sweetness. In math, that answer, “rate of change”, is the derivative.
In machine learning, we have a “Loss” (how wrong the model is) and we want to know: “If I change this specific weight by a tiny amount, how much does the Loss change?” The answer is the derivative of the Loss with respect to that weight. Once we compute it, we can nudge the weight in the right direction to make the Loss smaller, and over many iterations, minimize the Loss. This is the essence of training a machine learning model.
Of course, a model usually has many weights, and therefore many derivatives. Collectively, they form the gradient: the vector of the derivatives of the Loss with respect to each weight.1
1 While derivatives and gradients are different mathematical concepts, we sometimes use them interchangeably in this tutorial.
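To make "rate of change" concrete, here is a small sketch that estimates derivatives numerically: nudge one weight up and down by a tiny amount h and see how much the Loss changes (a central finite difference). The toy loss (a line with slope m and intercept b fitted to a single point x=1, y=2), the helper name numerical_gradient, and the step size h are illustrative assumptions, not part of the tutorial's engine.

```python
def loss(params, x=1.0, y=2.0):
    """Toy squared-error loss for a line with slope m and intercept b."""
    m, b = params
    return (m * x + b - y) ** 2


def numerical_gradient(f, params, h=1e-6):
    """Estimate each partial derivative of f with a central finite difference."""
    grad = []
    for i in range(len(params)):
        plus = list(params)
        minus = list(params)
        plus[i] += h   # nudge weight i up a tiny bit
        minus[i] -= h  # and down a tiny bit
        grad.append((f(plus) - f(minus)) / (2 * h))
    return grad


grad = numerical_gradient(loss, [0.0, 0.0])
print(grad)  # approximately [-4.0, -4.0]
```

Each entry answers "if I nudge this one weight, how much does the Loss change?", and the list as a whole is the gradient. This approach needs two extra loss evaluations per weight, which is one reason autodiff is preferred for models with many parameters.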
Chain rule
Back to the cake-baking example: if you use a measuring scoop to add that sugar, the scoop size (e.g. 1ml or 2ml) changes the grams of sugar you add, and the grams change the sweetness. To find the total effect of scoop size on sweetness, you just multiply those two rates of change together. That’s basically the idea behind the chain rule.
Putting it in a machine learning setting (and in math notation): if the loss function \(L\) depends on the model’s raw output \(z\), and \(z\) depends on a weight \(w\), then the rate of change of \(L\) with respect to \(w\) is
\[ \frac{dL}{dw} = \frac{dL}{dz} \cdot \frac{dz}{dw} \]
Autodiff automates this multiplication across hundreds of steps.
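A quick numeric check of the chain rule, under an assumed toy setup where \(z = 3w\) and \(L = z^2\) (these two functions are illustrative, not taken from the tutorial): the product of the two local rates of change should match a direct estimate of \(\frac{dL}{dw}\) obtained by nudging \(w\).

```python
def z_of(w):
    return 3 * w          # z depends on w

def L_of(z):
    return z ** 2         # L depends on z

w = 1.0
z = z_of(w)

dz_dw = 3.0                   # derivative of 3w with respect to w
dL_dz = 2 * z                 # derivative of z^2 with respect to z
dL_dw_chain = dL_dz * dz_dw   # chain rule: 2z * 3 = 6 * 3 = 18

# Direct estimate: nudge w up and down and watch how L changes
h = 1e-6
dL_dw_numeric = (L_of(z_of(w + h)) - L_of(z_of(w - h))) / (2 * h)

print(dL_dw_chain, dL_dw_numeric)  # 18.0 and approximately 18.0
```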
Backprop by example
A simple loss function
Backpropagation is an efficient application of Autodiff to calculate derivatives. Let’s see how it works with a simple loss function:
\[ L(w) = (wx - y)^2, \] where \(w\) is a weight we can adjust, \(x=1\) is some input, and \(y=2\) is the target value.
This amounts to a simple linear model with no intercept (or bias) followed by an MSE (Mean Squared Error) loss that measures how far the predictions are from their targets. In this simple example, we only have one data point, and the loss is minimized at \(w^*=2\).
Three steps
First, we break down the loss function calculation into smaller steps, each of which consists of a simple operation. This will allow us to apply the chain rule more easily.
\[ z(w) = wx, \quad e(z) = z - y, \quad \text{and} \quad L(e) = e^2 \]
Next, we walk through the backpropagation process in three steps.
Step 1. Forward pass
Starting with an initial guess of \(w=0\), we can compute the loss following a “forward pass”:
\[ z=0 \cdot 1 = 0 \;\longrightarrow\; e=0 - 2 = -2 \;\longrightarrow\; L = (-2)^2 = 4 \]
Anticipating the need for derivatives later, we also compute the local gradients at each step of the same forward pass:
\[ \frac{dz}{dw} = x = 1 \;\longrightarrow\; \frac{de}{dz} = 1 \;\longrightarrow\; \frac{dL}{de} = 2e = 2 \cdot (-2) = -4 \]
Each of these derivatives is the change of a variable with respect to its immediate inputs. That’s why they are called “local” gradients. For example, \(\frac{dz}{dw}\) is a local gradient because \(w\) is one of \(z\)’s immediate inputs.
In contrast, we call the derivative of the loss with respect to a variable a global gradient. For example, \(\frac{dL}{dw}\) is a “global” gradient. \(\frac{dL}{de}\) is both a local and a global gradient, since \(e\) is an immediate input of \(L\).
Step 2. Backward pass
Now, the derivative of \(L\) with respect to \(w\) can be computed “backwards” step by step using the chain rule:
\[ \begin{align*} \frac{dL}{dL} &= 1 \\ \longrightarrow \frac{dL}{de} &= \frac{dL}{dL} \cdot \frac{dL}{de} = 1 \cdot (-4) = -4 \\ \longrightarrow \frac{dL}{dz} &= \frac{dL}{de} \cdot \frac{de}{dz} = -4 \cdot 1 = -4 \\ \longrightarrow \frac{dL}{dw} &= \frac{dL}{dz} \cdot \frac{dz}{dw} = -4 \cdot 1 = -4 \end{align*} \]
We start with the “seed” gradient \(\frac{dL}{dL} = 1\), and then apply the chain rule in reverse order to compute the global gradients for each variable along the path to \(w\). This is the essence of backpropagation.
Notice the pattern: each global gradient is computed by multiplying the global gradient computed in the last step with a corresponding local gradient.
In our case, the derivative of interest is \(\frac{dL}{dw}= -4\). This means that if we increase \(w\) by a tiny amount, the loss will decrease by approximately 4 times that amount.
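The first two steps can be written out by hand in plain Python. This sketch hard-codes the three operations \(z = wx\), \(e = z - y\), \(L = e^2\) for this one example; the variable names are our own, and no autodiff engine is involved yet.

```python
x, y = 1.0, 2.0   # input and target
w = 0.0           # initial guess for the weight

# Step 1. Forward pass: values and local gradients
z = w * x
e = z - y
L = e ** 2
dz_dw = x             # local gradient of z = wx
de_dz = 1.0           # local gradient of e = z - y
dL_de_local = 2 * e   # local gradient of L = e^2

# Step 2. Backward pass: global gradients, from L back to w
dL_dL = 1.0                   # seed gradient
dL_de = dL_dL * dL_de_local   # -4
dL_dz = dL_de * de_dz         # -4
dL_dw = dL_dz * dz_dw         # -4

print(L, dL_dw)  # 4.0 -4.0
```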
Step 3. Weight update
This step is called gradient descent as we minimize the loss by descending in the opposite direction of the gradient.2 We update \(w\) as follows:
2 Strictly speaking, gradient descent isn’t part of the backpropagation algorithm, but it’s the most common optimization algorithm used in conjunction with backpropagation to update the weights.
\[ w_{new} = w_{old} - \alpha \cdot \frac{dL}{dw}, \]
where \(\alpha\) is the learning rate (e.g., 0.1) that controls the step size.
The algorithm makes intuitive sense: if the derivative \(\frac{dL}{dw}\) is negative, it means that increasing \(w\) will decrease the loss, so we want to increase \(w\). Conversely, if the derivative is positive, it means that increasing \(w\) will increase the loss, so we want to decrease \(w\).
In our case, we have:
\[ w_{new} = 0 - 0.1 \cdot (-4) = 0.4 \]
This completes one iteration of the backpropagation process. Steps 1 to 3 are then repeated until convergence.
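Repeating steps 1 to 3 in a loop gives a complete, if hand-written, training run. This sketch reuses the hand-derived gradient \(\frac{dL}{dw} = 2(wx - y)x\) for this specific loss, so it is not general autodiff; the learning rate of 0.1 matches the example above, while the iteration count of 50 is an arbitrary choice.

```python
x, y = 1.0, 2.0
w = 0.0
alpha = 0.1       # learning rate
history = []

for step in range(50):
    # Step 1. Forward pass
    e = w * x - y
    loss = e ** 2
    # Step 2. Backward pass (chain rule written out for this specific loss)
    dL_dw = 2 * e * x
    # Step 3. Weight update (gradient descent)
    w = w - alpha * dL_dw
    history.append(w)

print(history[0])   # 0.4, matching the single update computed above
print(round(w, 3))  # 2.0, the minimizer w* = 2 to three decimals
```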
That’s the backpropagation algorithm in a nutshell! By breaking down the loss function into smaller steps and applying the chain rule in reverse order (backward pass), we can efficiently compute the gradients needed to update our model’s parameters and minimize the loss.
Build a tiny Autodiff engine
It turns out that we can implement a simple autodiff engine in just a few lines of Python code.3 To understand how the code works, let’s start with the concept of a computational graph.
3 The implementation we present in this tutorial is based on the one in the microgpt project by Andrej Karpathy. microgpt is a beautifully simple educational implementation of a GPT‑style language model written in just 200 lines of pure Python code, with no external ML libraries or dependencies. In another of his projects, the Python micrograd library, he built a similar minimalistic autodiff engine for educational purposes.
Computational graph
The first two steps of the backpropagation algorithm can be traced out in a computational graph, where we compute the intermediate values and their local gradients during the forward pass, and then apply the chain rule in reverse order during the backward pass to compute the global gradients with respect to the parameters.
The computational graph above also reflects a common technique employed in autodiff implementations to handle different operations in a unified way. For example, subtraction is represented as adding a negative number, and negation of a number is represented as multiplication by \(-1\).4
4 You may wonder why, then, we don’t just represent \(e^2\) as a multiplication of \(e\) with itself. The reason is that we choose to treat it as a special case of the power operation, with an exponent of 2.
Everything in a node
With the computational graph in mind, the key is to create a data structure for the nodes (the square boxes) in the graph.
In a sense, a node is a special kind of “number” (compared to Python’s built-in float or int types) that not only has a value and performs operations with other “numbers”, but also stores the information needed to compute the gradients during the backward pass.
Therefore, a node should keep track of the following information:
- its current value: computed during the forward pass.
- its “global” gradient: the derivative of the loss with respect to the node, calculated during the backward pass.
- its dependencies: the nodes from which it was computed. They act as “child” nodes and allow reverse traversal of the graph from the loss back to the parameters.
- its local gradients: the derivatives of this node with respect to its inputs (i.e., dependencies or “children”), computed during the forward pass and stored for the backward pass.
In the computational graph above, for example, the node \(L\), after the first round of forward and backward passes, would have the following properties:
- has a current value of \(4\) (the loss).
- has a global gradient of \(1\) \((\frac{dL}{dL} = 1)\).
- depends on the node \(e\) (the error).
- has a local gradient of \(-4\) \((\frac{dL}{de}|_{e=-2} = 2e|_{e=-2} = -4)\).
As another example, the node \(e\):
- has a current value of \(-2\) (the error).
- has a global gradient of \(-4\) (calculated in the backward pass as \(\frac{dL}{de}|_{e=-2} = \frac{dL}{dL} \cdot \frac{dL}{de}|_{e=-2} = 1 \cdot (-4) = -4\)).
- depends on the nodes \(z\) and \(-y\).
- has a local gradient of \((1, 1)\) \((\frac{de}{dz} = 1, \frac{de}{d(-y)} = 1)\).
Let’s start implementing the node class. We will call this class Value (as it represents a special kind of “number”).
```python
class Value:
    def __init__(self, data, children=(), local_grads=()):
        self.data = data                 # current value
        self.grad = 0                    # "global" gradient (initially 0, to be calculated in backward pass)
        self._children = children        # dependencies (child nodes)
        self._local_grads = local_grads  # local gradients
```

Next we implement the basic operations (addition, multiplication, and power) that a node can perform with other nodes. Performing an operation corresponds to a forward step in the computational graph. Each operation returns a new Value node, together with its value, dependencies and local gradients.
```python
class Value:
    def __init__(self, data, children=(), local_grads=()):
        # same as before
        self.data = data
        self.grad = 0
        self._children = children
        self._local_grads = local_grads

    def __add__(self, other):
        # Handle addition with another Value or a regular number
        other = other if isinstance(other, Value) else Value(other)
        # Return a new Value node with its data (the sum), dependencies, and local gradients
        # The local gradients are always 1, since d/dx (x + y) = 1 and d/dy (x + y) = 1
        return Value(self.data + other.data, (self, other), (1, 1))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        # The local gradients are the other operand's value,
        # since d/dx (x * y) = y and d/dy (x * y) = x
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def __pow__(self, other):
        # Power rule: d/dx (x^n) = n * x^(n-1); the exponent is a plain number, not a Value
        return Value(self.data**other, (self,), (other * self.data**(other-1),))

    # Helper methods for negation, subtraction, and right-hand operands
    # Negation and subtraction can be implemented in terms of multiplication by -1 and addition
    def __neg__(self): return self * -1
    def __sub__(self, other): return self + (-other)

    # Handle right-hand operations to allow for expressions like 2 + Value(3) or 2 * Value(3)
    def __radd__(self, other): return self + other
    def __rsub__(self, other): return other + (-self)
    def __rmul__(self, other): return self * other
```

Finally, we implement the backward() method. When a node calls its backward() method, it performs a backward pass that computes the gradient of that node (typically the loss) with respect to every node it depends on (i.e., its children, its children’s children, and so on). The method consists of two steps:
Topological sort: We arrange all the nodes in the graph into a list so that each node appears after all its dependencies (child nodes). This ordering is crucial for the backward pass, as we must compute gradients in the correct order. It is typically implemented using a depth-first search (DFS), where we recursively visit all child nodes before adding the current node to the list.
Applying the chain rule: Once we have the nodes in the correct order, we can apply the chain rule in reverse order to compute the gradients.
For example, in the graph above, after the forward pass, if node \(L\) calls backward(), we will do the following:
- Sort node \(L\) and all the nodes it depends on into the list \([w, x, z, y, -1, -y, e, L]\).
- Compute the gradient for node \(L\) itself, which is simply \(\frac{dL}{dL}=1\). This is the “seed” gradient.
- Traverse the sorted list in reverse order, starting with node \(L\):
    - Compute the gradient for its child node \(e\) as the product of the local gradient and the global gradient (both stored in node \(L\)). That is, \(\frac{dL}{de}=\frac{dL}{de} \cdot \frac{dL}{dL} = 2e|_{e = -2} \cdot 1 = -4\).
- Move on to node \(e\):
    - Compute the gradient for its child node \(z\) as the product of the local gradient and the global gradient stored in node \(e\). That is, \(\frac{dL}{dz} = \frac{de}{dz} \cdot \frac{dL}{de} = 1 \cdot (-4) = -4\).
    - Compute the gradient for its other child node \((-y)\) as the product of the local gradient and the global gradient stored in node \(e\). That is, \(\frac{dL}{d(-y)} = \frac{de}{d(-y)} \cdot \frac{dL}{de} = 1 \cdot (-4) = -4\).
- In a similar way, we proceed in reverse order to \(-y\), then \(-1, y, z, x\), until the first node in the list, \(w\). If a node has no child nodes, we do not perform any computation (its global gradient was already calculated when its parent node was processed), and simply move on to the next one.
```python
class Value:
    # omitted code as before

    def backward(self):
        # 1. Build a list of all nodes in order (Topological Sort)
        topo = []
        visited = set()

        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._children:
                    build_topo(child)
                topo.append(v)

        build_topo(self)

        # 2. Apply the chain rule backwards
        self.grad = 1  # The derivative of the output with respect to itself is 1
        for v in reversed(topo):
            for child, local_grad in zip(v._children, v._local_grads):
                # The Chain Rule: Local Gradient * Global Gradient (accumulated, see note below)
                child.grad += local_grad * v.grad
```

The `+=` operator is used to accumulate the global gradient of each child node, since a node can be used in multiple places in a graph. For example, if \(b = a^2\), \(c = 3a\), and \(L = b + c\), then the node \(a\) is a child of both \(b\) and \(c\). When we compute the gradient for \(a\), we need to sum the contributions from both paths: \(\frac{dL}{da} = \frac{dL}{db} \cdot \frac{db}{da} + \frac{dL}{dc} \cdot \frac{dc}{da}\).
Putting it all together
Let’s put everything together and see how we can use this Autodiff engine to perform gradient descent on our simple loss function \(L(w) = (wx - 2)^2\).
Let’s first do a test run with a single forward and backward pass on the simple loss function to see if we get the expected gradients.
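A minimal sketch of that test run. So the snippet runs on its own, it restates the Value class from the previous sections in compact form. Making x and y Value nodes (rather than plain numbers) is our choice, so that the graph matches the diagram, including the \((-y)\) node; we expect \(\frac{dL}{dw} = -4\) as computed by hand.

```python
class Value:
    """Compact restatement of the Value class built in this tutorial."""

    def __init__(self, data, children=(), local_grads=()):
        self.data = data
        self.grad = 0
        self._children = children
        self._local_grads = local_grads

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1, 1))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def __pow__(self, other):
        return Value(self.data**other, (self,), (other * self.data**(other - 1),))

    def __neg__(self): return self * -1
    def __sub__(self, other): return self + (-other)
    def __radd__(self, other): return self + other
    def __rsub__(self, other): return other + (-self)
    def __rmul__(self, other): return self * other

    def backward(self):
        topo, visited = [], set()

        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._children:
                    build_topo(child)
                topo.append(v)

        build_topo(self)
        self.grad = 1
        for v in reversed(topo):
            for child, local_grad in zip(v._children, v._local_grads):
                child.grad += local_grad * v.grad


# Forward pass: builds the graph w, x -> z; y, -1 -> (-y); z, (-y) -> e; e -> L
w, x, y = Value(0.0), Value(1.0), Value(2.0)
z = w * x
e = z - y
L = e**2

# Backward pass
L.backward()

print(L.data)  # 4.0
print(e.grad)  # -4.0
print(z.grad)  # -4.0
print(w.grad)  # -4.0
print(y.grad)  # 4.0 (dL/dy = -2e)
```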
Now let’s test the case where a node is used in multiple places in a graph.
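A sketch of that check, using the example from the note on `+=`: with \(b = a^2\), \(c = 3a\) and \(L = b + c\), the expected gradient is \(\frac{dL}{da} = 2a + 3\), which is 7 at \(a = 2\) (the value \(a = 2\) is our choice). The compact Value class is restated so the snippet runs on its own.

```python
class Value:
    """Compact restatement of the Value class built in this tutorial."""

    def __init__(self, data, children=(), local_grads=()):
        self.data = data
        self.grad = 0
        self._children = children
        self._local_grads = local_grads

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1, 1))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def __pow__(self, other):
        return Value(self.data**other, (self,), (other * self.data**(other - 1),))

    def __neg__(self): return self * -1
    def __sub__(self, other): return self + (-other)
    def __radd__(self, other): return self + other
    def __rsub__(self, other): return other + (-self)
    def __rmul__(self, other): return self * other

    def backward(self):
        topo, visited = [], set()

        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._children:
                    build_topo(child)
                topo.append(v)

        build_topo(self)
        self.grad = 1
        for v in reversed(topo):
            for child, local_grad in zip(v._children, v._local_grads):
                child.grad += local_grad * v.grad


a = Value(2.0)
b = a**2      # first use of a
c = 3 * a     # second use of a (goes through __rmul__)
L = b + c
L.backward()

print(L.data)  # 10.0
print(a.grad)  # 7.0 = 2a + 3, the sum of both paths
```

If backward() assigned with `=` instead of `+=`, whichever path was processed last would overwrite the other, and `a.grad` would be wrong.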
Finally, let’s minimize the loss function in our example by iterating through gradient descent.
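A minimal sketch of the training loop for \(L(w) = (wx - 2)^2\) with \(x = 1\). The compact Value class is restated so the snippet is self-contained; the learning rate of 0.1 follows the worked example above, and 50 iterations is an arbitrary choice. Note that `w.grad` is reset to 0 before each backward pass, because backward() accumulates gradients with `+=`.

```python
class Value:
    """Compact restatement of the Value class built in this tutorial."""

    def __init__(self, data, children=(), local_grads=()):
        self.data = data
        self.grad = 0
        self._children = children
        self._local_grads = local_grads

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1, 1))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def __pow__(self, other):
        return Value(self.data**other, (self,), (other * self.data**(other - 1),))

    def __neg__(self): return self * -1
    def __sub__(self, other): return self + (-other)
    def __radd__(self, other): return self + other
    def __rsub__(self, other): return other + (-self)
    def __rmul__(self, other): return self * other

    def backward(self):
        topo, visited = [], set()

        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._children:
                    build_topo(child)
                topo.append(v)

        build_topo(self)
        self.grad = 1
        for v in reversed(topo):
            for child, local_grad in zip(v._children, v._local_grads):
                child.grad += local_grad * v.grad


x, y = 1.0, 2.0
w = Value(0.0)
alpha = 0.1

for step in range(50):
    # Step 1. Forward pass builds a fresh graph each iteration
    L = (w * x - y)**2
    # Step 2. Backward pass: reset the gradient first, since backward() accumulates with +=
    w.grad = 0
    L.backward()
    # Step 3. Gradient descent update on the raw number, outside the graph
    w.data -= alpha * w.grad
    if step == 0:
        print(f"after 1 step: w = {w.data:.2f}")  # after 1 step: w = 0.40

print(f"after 50 steps: w = {w.data:.4f}, loss = {L.data:.6f}")
```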
Regression Exercise
Now it’s your turn! Try to implement a simple linear regression model. The goal is to find the best-fitting line for the given data points. You can use the following code template to get started. Fill in the blanks to complete the implementation.
```python
# Three data points
x_train, y_train = [1, 2, 3], [2, 4, 6]

# Initialize parameters m (slope) and b (intercept)
m, b = Value(0.0), Value(0.0)

# Learning rate
alpha = 0.01

for _ in range(800):
    # Forward: sum of squared errors (proportional to the Mean Squared Error)
    loss = sum([(m * x_val + b - y_val)**2 for x_val, y_val in zip(x_train, y_train)])

    # Backward: reset the gradients first, since backward() accumulates them with +=
    m.grad, b.grad = 0, 0
    loss.backward()

    # Update
    m.data -= alpha * m.grad
    b.data -= alpha * b.grad

print(f"Result: y = {m.data:.2f}x + {b.data:.2f}")  # Target: y = 2.00x + 0.00
```

Further Reading
- The microgpt project by Andrej Karpathy: the source of the autodiff engine used in this tutorial.
- The Python micrograd library and its companion video tutorial by Andrej Karpathy: Another minimalistic autodiff engine for educational purposes, which is the basis for the one in microgpt.