Softmax Function and Cross-Entropy Loss

A Winning Combo in ML Classification

1 From Scores to Probabilities

Many machine learning problems are classification tasks: given some input, the model must decide which category it belongs to. For example, an image classifier might decide whether a picture shows a cat, a dog, or a bird (assuming there are only three possible classes).

However, modern models rarely make hard decisions. Instead, they express uncertainty. Rather than outputting a single label, a model typically outputs probabilities, for example, 70% cat, 20% dog, and 10% bird. This lets us see both what the model predicts and how confident it is. The same idea appears in language models, where the model assigns probabilities to possible next words and then selects one based on those probabilities.

To learn good probability predictions, models are trained on labeled data. During training, the model compares its predicted probabilities to the true labels and adjusts its parameters to improve. This leads to two natural questions:

- How do we turn the raw scores a model produces into probabilities?
- How do we measure how far those predicted probabilities are from the true labels?

In particular, modern classification models, such as neural networks, first produce logits: raw numeric scores that can be positive or negative. For example, for our three-class animal classifier, the model might output:

$$ z = [0.64, -0.61, -1.3] $$

These scores aren’t probabilities yet. To train classification models, we need two key tools:

- the softmax function, which turns logits into probabilities, and
- the cross-entropy loss, which measures how well those probabilities match the true labels.

2 What Does Softmax Do?

The softmax function converts logits into probabilities:

$$ p_i = \frac{e^{z_i}}{\sum_j e^{z_j}} $$

Applying this to our example logits gives:

$$ \begin{align*} p_1 &= \frac{e^{0.64}}{e^{0.64} + e^{-0.61} + e^{-1.3}} \approx 0.70 \\ p_2 &= \frac{e^{-0.61}}{e^{0.64} + e^{-0.61} + e^{-1.3}} \approx 0.20 \\ p_3 &= \frac{e^{-1.3}}{e^{0.64} + e^{-0.61} + e^{-1.3}} \approx 0.10 \end{align*} $$

Softmax guarantees:

- every output \(p_i\) lies between 0 and 1, and
- the outputs sum to 1, so they form a valid probability distribution.

This makes softmax the natural output layer for multi-class models.
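To make this concrete, here is a minimal NumPy sketch (my own, not from the original article) that applies softmax to the example logits above; subtracting the maximum logit before exponentiating is a common numerical-stability trick and does not change the result.

```python
import numpy as np

def softmax(z):
    # Subtract the max logit for numerical stability; the result is unchanged.
    e = np.exp(z - np.max(z))
    return e / e.sum()

z = np.array([0.64, -0.61, -1.3])   # example logits from the text
p = softmax(z)
print(p)         # approximately [0.70, 0.20, 0.10]
print(p.sum())   # 1.0
```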

Figure 1: Exponential Function \(f(z_i) = e^{z_i}\)

3 What Is Cross-Entropy?

Cross-entropy loss measures how well the predicted probability distribution matches the true one.

Continuing with our three-class animal classifier example, suppose:

- the true class is “cat”, so the one-hot label is \(y = [1, 0, 0]\), and
- the model’s predicted probabilities are \(p = [0.70, 0.20, 0.10]\).

Cross-entropy loss is defined as:

$$ L = -\sum_{i} y_i \log(p_i) $$

That is, it’s a negative sum over the true labels multiplied by the log of their corresponding predicted probabilities.

Since all \(y_i\) except one are zero, this becomes:

$$ L = -\log(p_{\text{true\_class}}) = -\log(0.70) $$
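A short sketch of this calculation (mine, assuming NumPy and the one-hot label \(y = [1, 0, 0]\) from the example above):

```python
import numpy as np

p = np.array([0.70, 0.20, 0.10])  # predicted probabilities
y = np.array([1, 0, 0])           # one-hot label: the true class is "cat"

# Cross-entropy: only the term for the true class survives.
loss = -np.sum(y * np.log(p))
print(loss)  # -log(0.70) ≈ 0.357
```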

3.1 Intuition

Cross-entropy loss has a very intuitive interpretation. It penalizes the model based on how much probability it assigns to the correct class (see Figure 2):

- If the model assigns a high probability to the correct class, the loss is small.
- If the model assigns a low probability to the correct class, the loss is large, growing toward infinity as that probability approaches zero.

By minimizing the average cross-entropy loss on all training examples, the model learns to assign higher probabilities to the correct class and lower probabilities to incorrect ones.
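A quick numerical illustration (again my own sketch, with arbitrarily chosen probability values) shows how sharply the penalty grows as the probability assigned to the correct class shrinks:

```python
import numpy as np

# Loss -log(p) for a range of probabilities assigned to the correct class
for p_true in [0.99, 0.9, 0.7, 0.5, 0.1, 0.01]:
    print(p_true, -np.log(p_true))
# 0.99 -> 0.01, 0.9 -> 0.105, 0.7 -> 0.357, 0.5 -> 0.693, 0.1 -> 2.303, 0.01 -> 4.605
```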

Figure 2: Cross-Entropy Loss

3.2 “Surprise” Interpretation

Cross-entropy can be thought of as: How surprised is the model to see the correct answer?1

This analogy is extremely helpful for intuition: a confident, correct prediction causes little surprise and thus little loss, while assigning low probability to the correct answer causes a lot of surprise and a large loss.

4 Why a Winning Combo?

When you apply softmax and then compute the cross-entropy loss, something elegant happens: the derivative of the loss with respect to the logits simplifies dramatically. This is a useful result because this derivative is exactly what gradient descent needs to train the model.

The derivative of the loss with respect to the logits \(z_i\) can be computed using the chain rule:

$$ \frac{\partial L}{\partial z_i} = \sum_{j} \frac{\partial L}{\partial p_j} \cdot \frac{\partial p_j}{\partial z_i} $$

The first part, \(\frac{\partial L}{\partial p_j}\), is straightforward to compute from the definition of cross-entropy loss:

$$ \frac{\partial L}{\partial p_j} = -\frac{y_j}{p_j} $$

The second part, \(\frac{\partial p_j}{\partial z_i}\), is a bit more involved, but it turns out to have a nice form.

If \(j = i\):

$$ \frac{\partial p_j}{\partial z_i} = p_i (1 - p_i) $$

If \(j \ne i\):

$$ \frac{\partial p_j}{\partial z_i} = -p_j p_i $$

We can combine the results from both parts to get:

$$ \begin{align*} \frac{\partial L}{\partial z_i} &= -\frac{y_i}{ p_i} \cdot p_i (1 - p_i) + \sum_{j \ne i} -\frac{y_j}{p_j} \cdot (-p_j p_i) \\ &= -y_i (1 - p_i) + p_i \sum_{j \ne i} y_j \\ &= -y_i + p_i \sum_{j} y_j \\ &= p_i - y_i \end{align*} $$

The last step uses the fact that \(\sum_{j} y_j = 1\) because \(y\) is a one-hot encoded vector.

That’s it. The derivative (or gradient) of the cross-entropy loss with respect to the logits is exactly the difference between the predicted probability and the true label. This result aligns perfectly with our intuition, and it is cheap to compute, which makes optimization efficient.
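If you want to verify this numerically, the following sketch (mine; the finite-difference step size is an arbitrary choice) compares the analytic gradient \(p - y\) with a central-finite-difference estimate on the example logits:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def cross_entropy(z, y):
    return -np.sum(y * np.log(softmax(z)))

z = np.array([0.64, -0.61, -1.3])
y = np.array([1.0, 0.0, 0.0])   # one-hot label from the example

p = softmax(z)
grad_analytic = p - y           # the result derived above

# Numerical gradient via central finite differences
eps = 1e-6
grad_numeric = np.zeros(3)
for i in range(3):
    dz = np.zeros(3)
    dz[i] = eps
    grad_numeric[i] = (cross_entropy(z + dz, y) - cross_entropy(z - dz, y)) / (2 * eps)

print(grad_analytic)                                        # ≈ [-0.30, 0.20, 0.10]
print(np.allclose(grad_analytic, grad_numeric, atol=1e-6))  # True
```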

5 Binary Classification

Binary classification is just a special case of multi-class classification. There are only two classes, so the one-hot label can only be

$$ [y_1, y_2] = [1, 0] \quad \text{or} \quad [0, 1] $$

And the model produces two logits:

$$[z_1, z_2]$$

We can still apply softmax to get probabilities for both classes.

$$p_i = \frac{e^{z_i}}{e^{z_1} + e^{z_2}}$$

And the cross-entropy loss is still the same:

$$ L = -[y_1 \log(p_1) + y_2 \log(p_2)] $$

However, since there are only two classes, we can simplify the model by using a single output that produces a single logit \(z\) for one of the classes. This class is often called the “positive class” (e.g., “cat”), and the other class is the “negative class” (e.g., “not cat”).

Let’s define a single logit:

$$ z = z_1 - z_2 $$

Then the probability for class 1 becomes (dividing numerator and denominator of the softmax by \(e^{z_2}\)):

$$ \begin{align*} p_1 &= \frac{e^{z_1}}{e^{z_1} + e^{z_2}} \\ &= \frac{e^{z_1-z_2}}{1 + e^{z_1 -z_2}} \\ &= \frac{e^z}{1 + e^z} \\ \end{align*} $$

This transformation defines the sigmoid function:

$$\sigma(z) = \frac{e^z}{1 + e^z}$$

It maps any real-valued input to a value between 0 and 1, making it perfect for binary classification.
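A small sketch (mine, with arbitrary example logits) confirms that the sigmoid of \(z = z_1 - z_2\) matches the first output of a two-class softmax:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))   # equivalent to e^z / (1 + e^z)

z1, z2 = 1.7, -0.4                    # arbitrary example logits
p1_softmax = softmax(np.array([z1, z2]))[0]
p1_sigmoid = sigmoid(z1 - z2)

print(p1_softmax, p1_sigmoid)              # identical values (≈ 0.89)
print(np.isclose(p1_softmax, p1_sigmoid))  # True
```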

Let’s use \(p\) and \(y\) to denote the predicted probability and true label for our class 1, the positive class.2 The cross-entropy loss for binary classification can be written as:

$$ L = -[y \log(p) + (1-y) \log(1-p)], $$

where \(p = \sigma(z)\) is the predicted probability for the positive class, and \(y\) is the true label (1 for positive class, 0 for negative class).

The derivative of this loss with respect to the logit \(z\) is:

$$ \frac{\partial L}{\partial z} = p - y $$

This is exactly the same result as in multi-class classification!
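As a final sanity check (my own sketch; the example logit and label are arbitrary), a finite-difference estimate of the binary cross-entropy derivative agrees with \(p - y\):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Binary cross-entropy written with a single logit z for the positive class
def bce(z, y):
    p = sigmoid(z)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

z, y = 0.85, 1.0                 # arbitrary example logit and label
p = sigmoid(z)

print(p - y)                     # analytic gradient, ≈ -0.30

# Numerical derivative via central finite differences agrees
eps = 1e-6
grad_numeric = (bce(z + eps, y) - bce(z - eps, y)) / (2 * eps)
print(np.isclose(p - y, grad_numeric))   # True
```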

6 Summary

Here’s the whole story in a nutshell:

- A classification model outputs logits, raw scores that are not yet probabilities.
- Softmax converts the logits into a valid probability distribution.
- Cross-entropy loss measures how much probability the model assigns to the correct class, penalizing low probabilities heavily.
- Combined, they give a remarkably simple gradient with respect to the logits: \(p - y\).
- Binary classification is the two-class special case, where softmax reduces to the sigmoid function and the gradient is still \(p - y\).

  1. This interpretation comes from information theory, where the negative log probability corresponds to the amount of “information” or “surprise” associated with an event. See this article and this Wiki page for more details.

  2. We switched notation here to reflect the common convention in binary classification, where \(p\) is the predicted probability for the positive class and \(y\) is the true label for that class. Previously, we used \(p\) and \(y\) for the probability vector and label vector.

Prepared by Jay at MDAL with litedown and ChatGPT.