Focal Loss
- a mismatch between a DNN’s confidence and its correctness makes its predictions hard to rely on
- ideally, we want networks that are accurate, calibrated and confident
the paper
- shows that, as opposed to the standard cross-entropy loss, focal loss allows one to learn models that are already very well calibrated
- shows that, when combined with temperature scaling, focal loss yields state-of-the-art calibrated models whilst preserving accuracy
- provides a thorough analysis of the factors causing miscalibration, and justifies the empirically excellent performance of focal loss
- provides a principled approach to automatically select the hyperparameter $\gamma$ of the focal loss
- performs experiments on a variety of computer vision and NLP datasets, with a wide variety of network architectures
- many multi-class classification networks are poorly calibrated, in the sense that the probabilities they assign to their predicted labels overestimate the likelihood that those predictions are correct
the underlying cause is hypothesised to be that these networks’ high capacity leaves them vulnerable to overfitting on the negative log-likelihood loss
- much work has been inspired by approaches that were not originally formulated in a deep learning context, such as
    - Platt scaling
    - histogram binning
    - isotonic regression
    - Bayesian binning and averaging
- various works have begun to directly target the calibration of deep networks, e.g.
    - temperature scaling, a modern variant of Platt scaling that divides a network’s logits by a scalar $T > 0$ before the softmax
        - drawbacks: whilst it reduces the network’s confidence in incorrect predictions, it also slightly reduces its confidence in predictions that were correct; moreover, temperature scaling is known not to calibrate a model under data distribution shift
    - Maximum Mean Calibration Error (MMCE): minimises a differentiable proxy for calibration error at training time
    - label smoothing: training models with cross-entropy loss on smoothed labels instead of one-hot labels
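A minimal NumPy sketch of temperature scaling, assuming the network's raw logits are available as an $N\times K$ array. The temperature is fitted here by a plain grid search over validation negative log-likelihood; the search range and the helper names (`temperature_scale`, `fit_temperature`, `mean_nll`) are illustrative assumptions, and a 1-D numerical optimiser is a common alternative to the grid.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def temperature_scale(logits, T):
    """Divide every logit by the scalar T > 0, then apply softmax."""
    if T <= 0:
        raise ValueError("temperature T must be positive")
    return softmax(logits / T)


def mean_nll(probs, labels):
    """Mean negative log-likelihood of the true labels."""
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.mean(np.log(picked + 1e-12)))


def fit_temperature(val_logits, val_labels, grid=None):
    """Choose T by grid search on validation NLL (hypothetical helper)."""
    if grid is None:
        grid = np.linspace(0.05, 10.0, 400)  # assumed search range
    losses = [mean_nll(temperature_scale(val_logits, T), val_labels) for T in grid]
    return float(grid[int(np.argmin(losses))])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    N, K = 2000, 5
    z = rng.normal(size=(N, K))  # logits of a "true" model
    u = rng.random((N, 1))
    # sample each label from softmax(z) by inverse-CDF sampling
    labels = np.minimum((np.cumsum(softmax(z), axis=1) < u).sum(axis=1), K - 1)
    # the logits below are sharpened three times, so T should come out close to 3
    T = fit_temperature(3.0 * z, labels)
    print(f"fitted temperature: {T:.2f}")
```

Because every logit is divided by the same $T$, the argmax (and hence accuracy) is unchanged; only the confidences move, which is also why correct predictions lose some confidence along with the incorrect ones.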
the paper proposes to replace the cross-entropy loss with the focal loss
- cross-entropy minimises the KL divergence between the predicted (softmax) distribution and the target distribution (one-hot encoding in classification tasks) over classes
- focal loss minimises a regularized KL divergence between these two distributions, which ensures minimisation of the KL divergence whilst increasing the entropy of the predicted distribution, thereby preventing the model from becoming overconfident.
- focal loss depends on a hyperparameter $\gamma$ that needs to be cross-validated; the paper also provides a method for choosing $\gamma$ automatically for each sample, and shows that it outperforms all the baseline models
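A minimal sketch of the mechanics of a sample-dependent focal loss, assuming softmax outputs `probs` of shape $N\times K$ and integer `labels`. The schedule used here ($\gamma = 5$ when $\hat p_{i,y_i} < 0.2$, $\gamma = 3$ otherwise) and all helper names are illustrative assumptions, not a procedure stated in these notes; the point is only that $\gamma$ is picked per sample from the model's current probability on the true class.

```python
import numpy as np


def focal_loss(p_true, gamma):
    """Per-sample focal loss -(1 - p)^gamma * log p, with p = \\hat p_{i, y_i}."""
    p = np.clip(p_true, 1e-12, 1.0)  # avoid log(0)
    return -((1.0 - p) ** gamma) * np.log(p)


def sample_dependent_gamma(p_true, threshold=0.2, gamma_low_conf=5.0, gamma_high_conf=3.0):
    # hypothetical schedule: samples whose true-class probability is below the
    # threshold get the larger gamma, all other samples get the smaller one
    return np.where(p_true < threshold, gamma_low_conf, gamma_high_conf)


def sample_dependent_focal_loss(probs, labels):
    """Mean focal loss where gamma is chosen separately for every sample."""
    p_true = probs[np.arange(len(labels)), labels]
    gammas = sample_dependent_gamma(p_true)
    return float(np.mean(focal_loss(p_true, gammas)))


if __name__ == "__main__":
    probs = np.array([[0.1, 0.9], [0.5, 0.5]])
    labels = np.array([0, 1])
    print(sample_dependent_focal_loss(probs, labels))
```

Because the schedule is piecewise constant in $\hat p_{i,y_i}$, it contributes no gradient of its own away from the threshold; gradients come only from the focal term.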
Problem Formulation
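- dataset $D = \{(\bfx_i, y_i)\}_{i=1}^N$ with labels $y_i\in \cY = \{1, 2,\ldots, K\}$
- $\hat p_{i, y} = f_\theta(y\mid \bfx_i)$: the probability the network $f_\theta$ predicts for class $y$ given input $\bfx_i$
- $\hat y_i = \argmax_y \hat p_{i, y}$: the predicted class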
- the predicted confidence: $\hat p_i = \max_y \hat p_{i, y}$
the network is said to be perfectly calibrated when, for each sample $(\bfx, y)\in D$, the confidence $\hat p$ equals the model accuracy $\bbP(\hat y = y \mid \hat p)$
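expected calibration error (ECE):
\[\bbE_{\hat p}[\vert \bbP(\hat y = y\mid \hat p) - \hat p\vert]\]
in practice, it is estimated by dividing the interval $[0, 1]$ into $M$ equal-width bins
- $B_i$: set of samples with confidence belonging to the $i$-th bin
- $A_i = \frac{1}{\vert B_i\vert}\sum_{k\in B_i} 1(\hat y_k = y_k)$: accuracy of the samples in $B_i$
- $C_i = \frac{1}{\vert B_i\vert}\sum_{k\in B_i} \hat p_k$: average confidence of the samples in $B_i$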
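The per-bin quantities ($\vert B_i\vert$, the bin accuracy $A_i$ and the bin mean confidence $C_i$) and the metrics built from them (ECE, MCE, AdaECE and classwise-ECE, whose formulas follow) can be sketched in NumPy as below, assuming `probs` is an $N\times K$ array of softmax outputs and `labels` an integer array of true classes. The choice of $M = 15$ bins, the half-open bin convention $[i/M, (i+1)/M)$ with a score of exactly $1$ placed in the last bin, the use of `np.array_split` to make AdaECE bins as equal in size as $N$ allows, and all function names are assumptions for illustration.

```python
import numpy as np


def _bin_stats(scores, hits, bin_ids, n_bins):
    """Per-bin sizes |B_i|, accuracies A_i and mean scores C_i (0 for empty bins)."""
    sizes = np.bincount(bin_ids, minlength=n_bins)
    hit_sum = np.bincount(bin_ids, weights=hits, minlength=n_bins)
    score_sum = np.bincount(bin_ids, weights=scores, minlength=n_bins)
    A = np.divide(hit_sum, sizes, out=np.zeros(n_bins), where=sizes > 0)
    C = np.divide(score_sum, sizes, out=np.zeros(n_bins), where=sizes > 0)
    return sizes, A, C


def _equal_width_bins(scores, n_bins):
    # bin i covers [i/M, (i+1)/M); a score of exactly 1.0 goes into the last bin
    return np.minimum((scores * n_bins).astype(int), n_bins - 1)


def ece_and_mce(probs, labels, n_bins=15):
    conf = probs.max(axis=1)  # \hat p_i
    correct = (probs.argmax(axis=1) == labels).astype(float)
    sizes, A, C = _bin_stats(conf, correct, _equal_width_bins(conf, n_bins), n_bins)
    gaps = np.abs(A - C)
    ece = float(np.sum(sizes / len(labels) * gaps))
    mce = float(gaps[sizes > 0].max())  # only non-empty bins count
    return ece, mce


def ada_ece(probs, labels, n_bins=15):
    conf = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    order = np.argsort(conf)
    total = 0.0
    # split the confidence-sorted samples into n_bins groups of (near-)equal size
    for idx in np.array_split(order, n_bins):
        if len(idx) > 0:
            total += len(idx) / len(labels) * abs(correct[idx].mean() - conf[idx].mean())
    return float(total)


def classwise_ece(probs, labels, n_bins=15):
    N, K = probs.shape
    total = 0.0
    for j in range(K):
        p_j = probs[:, j]  # \hat p_{kj} for every sample k
        is_j = (labels == j).astype(float)  # 1(j = y_k)
        sizes, A, C = _bin_stats(p_j, is_j, _equal_width_bins(p_j, n_bins), n_bins)
        total += np.sum(sizes / N * np.abs(A - C))
    return float(total / K)
```

For example, a model that always outputs $[0.9, 0.1]$ on a two-class set where half the labels are class $0$ is $90\%$ confident but only $50\%$ accurate, so its ECE, MCE and classwise-ECE are all $0.4$.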
\[ECE = \sum_{i=1}^M \frac{\vert B_i\vert}{N} \vert A_i-C_i\vert\]
the maximum calibration error (MCE):
\[MCE = \max_{i\in \{1,\ldots, M\}} \vert A_i - C_i\vert\]
AdaECE: bin sizes are calculated to evenly distribute samples between bins
\[AdaECE = \sum_{i=1}^M \frac{\vert B_i\vert}{N} \vert A_i - C_i\vert\quad \text{s.t.}\quad \vert B_i\vert = \vert B_j\vert\ \ \forall i, j\]
Classwise-ECE: the ECE metric only considers the probability of the predicted class, without considering the other scores in the softmax distribution; classwise-ECE instead bins the predicted probability of every class $j$ separately
\[ClasswiseECE = \frac{1}{K}\sum_{i=1}^M\sum_{j=1}^K\frac{\vert B_{ij}\vert}{N}\vert A_{ij} - C_{ij}\vert\]
where $B_{ij}$ is the set of samples whose predicted probability for class $j$ falls in the $i$-th bin, and
\[A_{ij} = \frac{1}{\vert B_{ij}\vert }\sum_{k\in B_{ij}}1(j = y_k)\quad C_{ij} = \frac{1}{\vert B_{ij}\vert} \sum_{k\in B_{ij}}\hat p_{kj}\]
Improving Calibration using Focal Loss
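for classification tasks where the target distribution is a one-hot encoding, focal loss is defined as
\[\cL_f = -(1-\hat p_{i, y_i})^\gamma \log \hat p_{i, y_i}\]
writing $q$ for the target distribution, cross-entropy $\cL_c = KL(q\Vert \hat p) + \bbH(q)$ forms an upper bound on the KL divergence (with equality for one-hot $q$, where $\bbH(q) = 0$)
\[\cL_c \ge KL(q\Vert \hat p)\]
while focal loss is an upper bound on the regularized KL divergence, where $\bbH(\hat p)$ is the entropy of the predicted distribution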
\[\cL_f \ge KL(q\Vert \hat p) -\gamma \bbH(\hat p)\]
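A small NumPy check of the two inequalities above on random predicted distributions, assuming a one-hot target $q$ and natural logarithms. It restricts to $\gamma \ge 1$, where the focal bound follows directly from Bernoulli's inequality $(1-p)^\gamma \ge 1-\gamma p$ (so $\cL_f \ge -\log\hat p_y + \gamma\,\hat p_y\log\hat p_y$) together with $\hat p_y\log\hat p_y \ge \sum_j \hat p_j\log\hat p_j = -\bbH(\hat p)$. The helper names and the Dirichlet sampling of $\hat p$ are illustrative choices.

```python
import numpy as np


def entropy(p):
    """Shannon entropy H(p) = -sum_j p_j log p_j (natural log); assumes all p_j > 0."""
    return float(-np.sum(p * np.log(p)))


def kl(q, p):
    """KL(q || p), summing only over the support of q (0 log 0 = 0)."""
    mask = q > 0
    return float(np.sum(q[mask] * np.log(q[mask] / p[mask])))


def cross_entropy(q, p):
    return float(-np.sum(q * np.log(p)))


def focal_loss(p, y, gamma):
    """Focal loss -(1 - p_y)^gamma log p_y for a one-hot target on class y."""
    return float(-((1.0 - p[y]) ** gamma) * np.log(p[y]))


def worst_slack(n_trials=2000, gammas=(1.0, 2.0, 3.0, 5.0), seed=0):
    """Smallest (loss - bound) seen for the cross-entropy and focal bounds."""
    rng = np.random.default_rng(seed)
    worst_ce, worst_focal = np.inf, np.inf
    for _ in range(n_trials):
        K = int(rng.integers(2, 11))
        p = rng.dirichlet(np.ones(K))  # random predicted distribution \hat p
        y = int(rng.integers(K))
        q = np.eye(K)[y]  # one-hot target
        worst_ce = min(worst_ce, cross_entropy(q, p) - kl(q, p))
        for g in gammas:
            bound = kl(q, p) - g * entropy(p)  # regularized KL divergence
            worst_focal = min(worst_focal, focal_loss(p, y, g) - bound)
    return worst_ce, worst_focal


if __name__ == "__main__":
    print(worst_slack())
```

For one-hot $q$ the cross-entropy slack is zero up to rounding, since $\bbH(q) = 0$; the focal slack stays non-negative, so minimising $\cL_f$ pushes down an upper bound on $KL(q\Vert \hat p) - \gamma\bbH(\hat p)$, which is consistent with focal loss favouring higher-entropy, less overconfident predictions.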