Focal Loss
- a mismatch between a DNN's confidence and its correctness makes its predictions hard to rely on
- ideally, want networks to be accurate, calibrated and confident
the paper shows that
- as opposed to the standard cross-entropy loss, focal loss allows one to learn models that are already very well calibrated
- when combined with temperature scaling, whilst preserving accuracy, it yields state-of-the-art calibrated models
- provide a thorough analysis of the factors causing miscalibration, and justify the empirically excellent performance of focal loss
- provide a principled approach for automatically selecting the hyperparameter of the focal loss
- perform experiments on a variety of computer vision and NLP datasets, and with a wide variety of network architectures
Introduction
- many multi-class classification networks are poorly calibrated, in the sense that the probability values that they associate with the class labels they predict overestimate the likelihoods
- the underlying cause is hypothesised to be that these networks' high capacity leaves them vulnerable to overfitting on the negative log-likelihood loss
- much work has been inspired by approaches that were not originally formulated in a deep learning context, such as
- Platt scaling
- histogram binning
- isotonic regression
- Bayesian binning and averaging
- various works have begun to directly target the calibration of deep networks
- temperature scaling, a modern variant of Platt scaling that divides a network's logits by a scalar $T > 0$ learned on a held-out validation set (a minimal sketch appears after this list)
- Drawbacks: whilst it scales the logits to reduce the network's confidence in incorrect predictions, it also slightly reduces the network's confidence in predictions that were correct; moreover, temperature scaling is known not to calibrate a model under data distribution shift
- Maximum Mean Calibration Error (MMCE): minimize a differentiable proxy for calibration error at training time
- training models on cross-entropy loss with label smoothing instead of one-hot labels
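a minimal sketch of post-hoc temperature scaling as described above, assuming PyTorch and a held-out validation set of logits and labels (the function name is illustrative, not from any particular library):

```python
import torch
import torch.nn.functional as F

def fit_temperature(val_logits, val_labels, lr=0.01, steps=200):
    """Fit a single scalar T > 0 by minimising the NLL on held-out validation data."""
    log_t = torch.zeros(1, requires_grad=True)  # optimise log T so that T = exp(log_t) > 0
    opt = torch.optim.Adam([log_t], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = F.cross_entropy(val_logits / log_t.exp(), val_labels)
        loss.backward()
        opt.step()
    return log_t.exp().item()

# at test time, probabilities become softmax(logits / T); accuracy is unchanged
# because dividing by a positive scalar preserves the argmax of the logits.
```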
the paper proposes to replace the cross-entropy loss with the focal loss
- cross-entropy minimises the KL divergence between the predicted (softmax) distribution and the target distribution (one-hot encoding in classification tasks) over classes
- focal loss minimises a regularized KL divergence between these two distributions, which ensures minimisation of the KL divergence whilst increasing the entropy of the predicted distribution, thereby preventing the model from becoming overconfident.
- focal loss depends on a hyperparameter $\gamma$ that normally needs to be cross-validated; the paper also provides a method for choosing $\gamma$ automatically for each sample, and shows that it outperforms all the baseline models
Problem Formulation
denote
- $y_i\in \cY = \{1, 2,\ldots, K\}$: the ground-truth class label of the $i$-th sample $\bfx_i$
- $\hat p_{i, y} = f_\theta(y\mid \bfx_i)$: the probability the network predicts for class $y$ on input $\bfx_i$
- $\hat y_i = \argmax_y \hat p_{i, y}$: the predicted class
- the predicted confidence: $\hat p_i = \max_y \hat p_{i, y}$
the network is said to be perfectly calibrated when, for each sample $(\bfx, y)\in D$, the confidence $\hat p$ equals the accuracy $\bbP(\hat y = y \mid \hat p)$
expected calibration error (ECE):
\[\bbE_{\hat p}[\vert \bbP(\hat y = y\mid \hat p) - \hat p\vert]\]in practice, the interval $[0, 1]$ is divided into $M$ equal-width bins to approximate this expectation
- $B_i$: set of samples whose confidence falls in the $i$-th bin
- $A_i = \frac{1}{\vert B_i\vert}\sum_{k\in B_i} 1(\hat y_k = y_k)$: accuracy of the $i$-th bin
- $C_i = \frac{1}{\vert B_i\vert}\sum_{k\in B_i} \hat p_k$: average confidence of the $i$-th bin
- $N$: total number of samples
then
\[ECE = \sum_{i=1}^M \frac{\vert B_i\vert}{N} \vert A_i-C_i\vert\]the maximum calibration error (MCE)
\[MCE = \max_{i\in \{1,\ldots, M\}} \vert A_i - C_i\vert\]AdaECE: bin sizes are calculated to evenly distribute samples between bins
\[AdaECE = \sum_{i=1}^M \frac{\vert B_i\vert}{N} \vert A_i-C_i\vert\quad \text{s.t.}\quad \vert B_i\vert = \vert B_j\vert\ \ \forall i, j\]Classwise-ECE: the ECE metric only considers the probability of the predicted class, without considering the other scores in the softmax distribution
\[ClasswiseECE = \frac{1}{K}\sum_{i=1}^M\sum_{j=1}^K\frac{\vert B_{ij}\vert}{N}\vert A_{ij} - C_{ij}\vert\]where $B_{ij}$ is the set of samples whose predicted probability for class $j$ falls in the $i$-th bin, and
\[A_{ij} = \frac{1}{\vert B_{ij}\vert }\sum_{k\in B_{ij}}1(j = y_k)\quad C_{ij} = \frac{1}{\vert B_{ij}\vert} \sum_{k\in B_{ij}}\hat p_{kj}\]
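a minimal sketch of the binned ECE defined above, assuming numpy arrays of top-label confidences, predicted labels and ground-truth labels (equal-width bins; illustrative rather than a reference implementation):

```python
import numpy as np

def expected_calibration_error(confidences, predictions, labels, n_bins=15):
    """Equal-width-bin ECE: sum_i |B_i|/N * |A_i - C_i|."""
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    correct = (predictions == labels).astype(float)
    n = len(confidences)
    ece = 0.0
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)  # B_i
        if in_bin.any():
            acc = correct[in_bin].mean()       # A_i: accuracy of the bin
            conf = confidences[in_bin].mean()  # C_i: average confidence of the bin
            ece += (in_bin.sum() / n) * abs(acc - conf)
    return ece
```

for AdaECE the bin edges would instead be quantiles of the confidences, so that every bin holds the same number of samples; Classwise-ECE repeats the same computation per class, using $\hat p_{kj}$ in place of the top-label confidence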
Improving Calibration using Focal Loss
for classification tasks, where the target distribution $q$ is a one-hot encoding of the ground-truth class, the focal loss is defined as
\[\cL_f = -(1-\hat p_{i, y_i})^\gamma \log \hat p_{i, y_i}\]the cross-entropy loss $\cL_c$ forms an upper bound on the KL divergence
\[\cL_c \ge KL(q\Vert \hat p)\]while the focal loss is an upper bound on the regularized KL divergence
\[\cL_f \ge KL(q\Vert \hat p) -\gamma \bbH(\hat p)\]
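one way to see this bound, as a sketch assuming a one-hot target $q$ and $\gamma \ge 1$ (not necessarily the paper's exact proof): Bernoulli's inequality gives $(1-\hat p_{i, y_i})^\gamma \ge 1 - \gamma \hat p_{i, y_i}$, and since $-\log \hat p_{i, y_i} \ge 0$,
\[\cL_f \ge -(1-\gamma \hat p_{i, y_i})\log \hat p_{i, y_i} = \cL_c + \gamma\, \hat p_{i, y_i}\log \hat p_{i, y_i} \ge \cL_c - \gamma\, \bbH(\hat p) = KL(q\Vert \hat p) - \gamma\, \bbH(\hat p)\]where the second inequality uses $\hat p_{i, y_i}\log \hat p_{i, y_i} \ge \sum_j \hat p_{i, j}\log \hat p_{i, j} = -\bbH(\hat p)$, and the final equality holds because a one-hot $q$ has zero entropy, so $\cL_c = KL(q\Vert \hat p)$

a minimal PyTorch sketch of the loss, together with a sample-dependent $\gamma$ schedule in the spirit of the paper's FLSD-53 policy (treat the threshold and $\gamma$ values here as illustrative):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=3.0):
    """Multi-class focal loss: mean over the batch of -(1 - p_y)^gamma * log p_y."""
    log_p = F.log_softmax(logits, dim=-1)
    log_p_y = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)  # log p_{i, y_i}
    p_y = log_p_y.exp()
    return (-(1.0 - p_y) ** gamma * log_p_y).mean()

def sample_dependent_gamma(p_y, threshold=0.2, gamma_low=5.0, gamma_high=3.0):
    """Larger gamma when the predicted probability of the true class is small;
    the resulting per-sample tensor can be passed as `gamma` above."""
    gamma = torch.full_like(p_y, gamma_high)
    gamma[p_y < threshold] = gamma_low
    return gamma
```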