Task-Agnostic Machine-Learning-Assisted Inference
This note is for Miao, J., & Lu, Q. (2024). Task-Agnostic Machine-Learning-Assisted Inference (No. arXiv:2405.20039). arXiv. https://doi.org/10.48550/arXiv.2405.20039
ML-assisted analytical strategies
- employ ML to predict unobserved outcomes in massive samples, then use the predicted outcomes in downstream statistical inference
however,
- existing methods designed to ensure the validity of this type of post-prediction inference are limited to very basic tasks such as linear regression analysis
- any extension of these approaches to new, more sophisticated statistical tasks requires task-specific algebraic derivations and software implementations; this ignores the massive library of existing software tools already developed for the same scientific problems given observed data
current ML-assisted inference can only address very basic statistical tasks, including
- mean estimation
- quantile estimation
- linear and logistic regression
while the same mathematical principle behind existing ML-assisted inference methods can be generalized to a broader class of M-estimation problems, specific algebraic derivations and computational implementations are required for each new statistical task
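for reference, the M-estimation setup in standard notation (my paraphrase, not copied from the paper): the target is the population minimizer of a task-specific loss $\ell$, so each new task means a new $\ell$, hence new derivatives and a new optimizer to implement

$$
\theta^\star = \arg\min_{\theta}\, \mathbb{E}_{(\bfX, Y)\sim\bbP}\big[\ell(\theta; \bfX, Y)\big],
\qquad
\hat\theta = \arg\min_{\theta}\, \frac{1}{n}\sum_{i=1}^{n} \ell(\theta; \bfX_i, Y_i)
$$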
historically, a similar type of challenge
- before the advent of resampling-based methods, obtaining the variance of any new estimator required task-specific derivation and implementation
the paper introduces PoSt-Prediction Summary-statistics-based (PSPS) inference
vs semi-supervised??
Formulations
infer the parameter $\theta^\star = \theta^\star(\bbP)\in \IR^K$ defined on the joint distribution of $(\bfX, Y)\sim \bbP$, where $\bfX \in \cX$ are the features and $Y \in \cY$ is the outcome
interested in estimating $\theta^\star$ using
- labeled data $\cL = \{(\bfX_i, Y_i),\ i=1,\ldots, n\}$
- unlabeled data $\cU$ (features $\bfX$ only, outcomes unobserved)
- a pre-trained ML model: $\hat f(\cdot): \cX\rightarrow\cY$
also require an algorithm $\cA$ that takes the labeled data $\cL$ as input and returns a consistent, asymptotically normal estimator $\hat\theta$ of $\theta^\star$
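the requirement on $\cA$ is easiest to see as an interface: any routine that maps a dataset to a point estimate plus a variance estimate qualifies. a minimal Python sketch (the names `Algorithm` and `mean_estimation` are mine, not the paper's):

```python
import numpy as np
from typing import Callable, Tuple

# the only contract A must satisfy: dataset in, (estimate, variance) out,
# with the estimator consistent and asymptotically normal
Algorithm = Callable[[np.ndarray, np.ndarray], Tuple[float, float]]

def mean_estimation(x: np.ndarray, y: np.ndarray) -> Tuple[float, float]:
    """Sample mean of Y with its usual variance estimate."""
    n = len(y)
    return y.mean(), y.var(ddof=1) / n
```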
three common ways to estimate $\theta^\star$
- classical statistical methods: apply algorithm $\cA$ to the labeled data $\cL$ only
- imputation-based methods: impute the unobserved outcomes with $\hat f$ and treat the predictions as if they were observed
- ML-assisted inference methods: use predictions on the unlabeled data, but correct the resulting bias with the labeled data
intuition with mean estimation (sketch below): the imputed mean over $\cU$ is biased whenever $\hat f$ is imperfect, but the labeled data can estimate and subtract that bias
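a minimal numerical sketch of this idea, in the style of prediction-powered inference (my illustration, not the paper's exact PSPS construction):

```python
import numpy as np

rng = np.random.default_rng(0)
n, N = 500, 50_000                        # labeled / unlabeled sample sizes

def f_hat(x):
    return 1.5 * x                        # pre-trained, imperfect ML model

x_lab = rng.normal(loc=1.0, size=n)
y_lab = 2.0 * x_lab + rng.normal(size=n)  # true E[Y] = 2
x_unl = rng.normal(loc=1.0, size=N)       # unlabeled: features only

# classical: labeled outcomes only
theta_classical = y_lab.mean()

# imputation: treat predictions as outcomes -- biased, targets 1.5 not 2
theta_imputed = f_hat(x_unl).mean()

# ML-assisted: imputed mean plus a bias correction from labeled residuals
resid = y_lab - f_hat(x_lab)
theta_ml = f_hat(x_unl).mean() + resid.mean()

# standard errors: the residuals are less variable than the raw outcomes
# here, so the corrected estimator is valid AND tighter than the classical
se_classical = np.sqrt(y_lab.var(ddof=1) / n)
se_ml = np.sqrt(resid.var(ddof=1) / n + f_hat(x_unl).var(ddof=1) / N)
print(theta_classical, theta_imputed, theta_ml, se_classical, se_ml)
```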
related work
- ML-assisted inference
- methods for handling missing data
- semi-supervised inference
current methods typically require
- the algebraic form of the loss function
- its first- and second-order derivatives
- the variance of the estimator (see the sandwich formula below)
- as well as a newly implemented optimization algorithm to obtain the estimator
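why each ingredient appears: for M-estimation the asymptotic variance is the standard sandwich formula (a textbook result, stated here for context), so the first derivative enters through $B$, the second through $A$, and the estimator itself needs its own optimizer for the loss

$$
\sqrt{n}\,(\hat\theta - \theta^\star) \xrightarrow{d} \mathcal{N}\big(0,\ A^{-1} B A^{-1}\big),
\qquad
A = \mathbb{E}\big[\nabla^2_\theta \ell(\theta^\star; \bfX, Y)\big],
\quad
B = \mathrm{Var}\big[\nabla_\theta \ell(\theta^\star; \bfX, Y)\big]
$$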
current mathematical principles guiding ML-assisted inference apply solely to M-estimation
the proposed protocol extends beyond this limitation, covering any estimation problem whose estimator is asymptotically normal
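one way to picture such a task-agnostic protocol: run the black-box $\cA$ three times (labeled data with observed $Y$, labeled data with $\hat f(\bfX)$, unlabeled data with $\hat f(\bfX)$) and combine only the resulting summary statistics. the sketch below uses a plain scalar control-variate combination; it is my reading of the general idea, not necessarily the paper's exact estimator or weighting:

```python
def combine_summary_stats(est_lab, var_lab,
                          est_lab_pred, var_lab_pred,
                          est_unl_pred, var_unl_pred,
                          cov_lab):
    """Scalar control-variate combination of three runs of a black-box
    algorithm A (illustrative, not necessarily the paper's construction).

    est_lab      -- A on labeled data with observed outcomes
    est_lab_pred -- A on labeled data with ML-predicted outcomes
    est_unl_pred -- A on unlabeled data with ML-predicted outcomes
    cov_lab      -- Cov(est_lab, est_lab_pred); estimating this from
                    summary statistics alone is the key technical step
    """
    # the unlabeled run is independent of both labeled runs, so the
    # variance-minimizing weight for
    #   est_lab + w * (est_unl_pred - est_lab_pred)
    # has the usual control-variate form
    w = cov_lab / (var_lab_pred + var_unl_pred)
    estimate = est_lab + w * (est_unl_pred - est_lab_pred)
    variance = var_lab - cov_lab**2 / (var_lab_pred + var_unl_pred)
    return estimate, variance  # Wald CI: estimate +/- 1.96 * sqrt(variance)
```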
inference relying solely on summary statistics is widely used in the statistical genetics literature for practical reasons.
summary statistics-based methods have been developed for tasks such as
- variance component inference
- genetic risk prediction
difference from semi-supervised learning
- the proposed method is designed for estimation and statistical inference
- semi-supervised learning focuses on prediction