PolyLoss
- cross-entropy loss and focal loss are the most common choices when training DNNs for classification problems
- however, a good loss function can take on much more flexible forms, and should be tailored for different tasks and datasets
- motivated by how functions can be approximated via Taylor expansion, the paper proposes PolyLoss, which views and designs loss functions as a linear combination of polynomial functions
- PolyLoss allows the importance of different polynomial bases to be easily adjusted for the target tasks and datasets, while naturally subsuming cross-entropy loss and focal loss as special cases
the paper decomposes commonly used classification loss functions, such as cross-entropy and focal loss, into a series of weighted polynomial bases of the form
\[\sum_{j=1}^\infty \alpha_j (1-P_t)^j\]
where
- $\alpha_j \in \mathbb{R}^+$ is the polynomial coefficient
- $P_t$ is the prediction probability of the target class label
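a minimal sketch of a truncated version of this sum (my own illustration, not code from the paper; `coeffs` holds the first $N$ coefficients $\alpha_1, \dots, \alpha_N$ and all higher-order terms are dropped):

```python
import numpy as np

def poly_loss(p_t, coeffs):
    """truncated PolyLoss: sum of alpha_j * (1 - p_t)^j for j = 1..N"""
    p_t = np.asarray(p_t, dtype=float)
    loss = np.zeros_like(p_t)
    for j, alpha_j in enumerate(coeffs, start=1):
        loss += alpha_j * (1.0 - p_t) ** j
    return loss

# e.g. up-weight the leading term relative to the cross-entropy values 1/j
print(poly_loss([0.9, 0.5], [2.0, 1.0 / 2.0, 1.0 / 3.0]))
```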
the coefficients $\alpha_j$ make it easy to adjust the importance of different bases for different applications
when $\alpha_j = 1/j$ for all $j$, the PolyLoss becomes equivalent to the commonly used cross-entropy loss, but this coefficient assignment may not be optimal
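a quick check of this equivalence: the partial sums of $\sum_j \frac{1}{j}(1-P_t)^j$ converge to $-\log P_t$ (it is the Mercator series of the logarithm)

```python
import math

p_t = 0.7
# partial sum of the first 50 terms of sum_j (1/j) * (1 - p_t)^j
series = sum((1 - p_t) ** j / j for j in range(1, 51))
print(series)          # ≈ 0.356675
print(-math.log(p_t))  # ≈ 0.356675
```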
the paper claims that, to achieve better results, the polynomial coefficients $\alpha_j$ must be adjusted for different tasks and datasets
since it is impossible to tune an infinite number of $\alpha_j$, the paper explores strategies with a small number of degrees of freedom
the authors observe that adjusting only the coefficient of the leading polynomial term, a variant denoted $L_{\text{Poly-1}}$, is sufficient to achieve significant improvements over the commonly used cross-entropy loss and focal loss
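in the paper this takes the form $L_{\text{Poly-1}} = -\log P_t + \epsilon_1 (1-P_t)$, i.e. the leading coefficient becomes $1 + \epsilon_1$ while all other coefficients keep their cross-entropy values. a minimal PyTorch sketch of this idea (the function name and defaults are mine, not the paper's):

```python
import torch
import torch.nn.functional as F

def poly1_cross_entropy(logits, labels, epsilon=1.0):
    """Poly-1 sketch: cross-entropy plus an extra weight epsilon
    on the leading polynomial term (1 - P_t)."""
    # per-example cross-entropy, i.e. -log P_t
    ce = F.cross_entropy(logits, labels, reduction="none")
    # P_t: predicted probability of the true class
    p_t = torch.softmax(logits, dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1)
    return (ce + epsilon * (1.0 - p_t)).mean()

# usage on random data
logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
print(poly1_cross_entropy(logits, labels, epsilon=1.0))
```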
the decomposition itself is inspired by the Taylor expansions of the cross-entropy loss and focal loss in the basis of $(1-P_t)^j$:
\[\begin{align*} L_{\text{CE}} &= -\log P_t = \sum_{j=1}^\infty \frac{1}{j} (1-P_t)^j = (1-P_t) + \frac{1}{2} (1-P_t)^2 + \cdots\\ L_{\text{FL}} &= -(1-P_t)^\gamma \log P_t = \sum_{j=1}^\infty \frac{1}{j} (1-P_t)^{j+\gamma} = (1-P_t)^{1+\gamma} + \frac{1}{2}(1-P_t)^{2+\gamma} + \cdots \end{align*}\]
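a quick numerical sanity check of the focal-loss expansion (with an assumed $\gamma = 2$): the series is just the cross-entropy series with every power shifted by $\gamma$

```python
import math

p_t, gamma = 0.7, 2.0
# focal loss: -(1 - p_t)^gamma * log(p_t)
focal = -((1 - p_t) ** gamma) * math.log(p_t)
# truncated series: sum_{j=1}^{50} (1/j) * (1 - p_t)^(j + gamma)
series = sum((1 - p_t) ** (j + gamma) / j for j in range(1, 51))
print(focal, series)  # both ≈ 0.03210
```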