TabDDPM: Tabular Data with Diffusion Models
- the paper investigates whether the framework of diffusion models can be advantageous for general tabular problems, where data points are typically represented by vectors of heterogeneous features
- the inherent heterogeneity of tabular data makes accurate modeling quite challenging, since individual features can be of a completely different nature, e.g., some continuous and some discrete
the paper introduces TabDDPM: a diffusion model that can be universally applied to any tabular dataset and handles any feature type
the authors extensively evaluate TabDDPM on a wide set of benchmarks and demonstrate its superiority over existing GAN/VAE alternatives, consistent with the advantage of diffusion models in other fields
Introduction
denoising diffusion probabilistic models (DDPMs) have become popular in the generative modeling community since they often outperform alternative approaches both in the realism of individual samples and in their diversity
training a high-quality model of tabular data can be more challenging than in computer vision or NLP due to the heterogeneity of individual features and the relatively small sizes of typical tabular datasets
main contributions of this work:
- introduce TabDDPM: a simple design of DDPM for tabular problems
- demonstrate that TabDDPM outperforms alternative approaches designed for tabular data, including GAN-based and VAE-based methods
- shallow interpolation-based methods, e.g., SMOTE, produce surprisingly effective synthetic data that provides competitively high ML efficiency. Compared with SMOTE, the authors show that TabDDPM's data is preferable for privacy-concerned scenarios, where synthetic data substitute real user data that cannot be shared.
Related Work
Diffusion models
- a paradigm of generative modeling that aims to approximate the target distribution by the endpoint of a Markov chain, which starts from a given parametric distribution, typically a standard Gaussian
Generative models for tabular problems
- an active research direction since high-quality synthetic data is in great demand for many tabular tasks
- recent works: tabular VAEs and GAN-based models
Shallow synthetic data generation
since tabular data is typically structured, shallow interpolation-based methods, e.g., SMOTE, can serve as strong baselines
Background
Diffusion models
- forward process: gradually adds noise to an initial sample
- reverse process: gradually denoises a latent variable
Gaussian diffusion models
forward and reverse processes are characterized by Gaussian distributions
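For intuition, the Gaussian forward process admits a closed-form marginal $q(x_t\mid x_0)=\mathcal{N}(\sqrt{\bar\alpha_t}\,x_0,\,(1-\bar\alpha_t)I)$, which is standard for DDPMs. A minimal numpy sketch (function name is mine, not from the paper):

```python
import numpy as np

def gaussian_forward_sample(x0, alpha_bar_t, rng):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(abar_t) * x0, (1 - abar_t) * I),
    the closed-form forward marginal of Gaussian diffusion.
    Returns the noised sample and the noise used (the regression target)."""
    noise = rng.normal(size=np.shape(x0))
    x_t = np.sqrt(alpha_bar_t) * np.asarray(x0) + np.sqrt(1.0 - alpha_bar_t) * noise
    return x_t, noise
```

At $\bar\alpha_t = 1$ the sample is the clean data; at $\bar\alpha_t = 0$ it is pure noise, matching the two ends of the chain.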
Multinomial diffusion models
forward process: define $q(x_t\mid x_{t-1})$ as a categorical distribution that corrupts the data with uniform noise over $K$ classes
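Concretely, one forward step is $q(x_t\mid x_{t-1}) = \mathrm{Cat}\big(x_t;\,(1-\beta_t)\,x_{t-1} + \beta_t/K\big)$: with probability $1-\beta_t$ the class is kept, otherwise it is resampled uniformly. A minimal numpy sketch (function name is mine):

```python
import numpy as np

def multinomial_forward_step(x_onehot, beta_t, rng):
    """One forward corruption step of multinomial diffusion:
    q(x_t | x_{t-1}) = Cat(x_t; (1 - beta_t) * x_{t-1} + beta_t / K).
    x_onehot: (n, K) array of one-hot rows; returns corrupted one-hot rows."""
    n, K = x_onehot.shape
    probs = (1.0 - beta_t) * x_onehot + beta_t / K  # each row sums to 1
    # sample one category per row from its categorical distribution
    samples = np.array([rng.choice(K, p=p) for p in probs])
    return np.eye(K)[samples]
```

With $\beta_t = 0$ nothing is corrupted; as $\beta_t \to 1$ the output approaches the uniform distribution over the $K$ classes.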
TabDDPM
- use the multinomial diffusion to model the categorical and binary features
- use the Gaussian diffusion to model the numerical ones
for a tabular data sample
\[x = [x_{num}, x_{cat_1}, \ldots, x_{cat_C}],\]
one-hot encoded versions of the categorical features are taken as input
total input dimension is $N_{num} + \sum_i K_i$, where $K_i$ is the number of categories of the $i$-th categorical feature
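The input layout above can be sketched in a few lines of numpy (helper name is mine):

```python
import numpy as np

def encode_row(x_num, x_cat, category_counts):
    """Concatenate numerical features with one-hot encoded categorical
    features: x = [x_num, ohe(x_cat_1), ..., ohe(x_cat_C)].
    Output length is N_num + sum(K_i)."""
    parts = [np.asarray(x_num, dtype=float)]
    for value, K in zip(x_cat, category_counts):
        onehot = np.zeros(K)
        onehot[value] = 1.0
        parts.append(onehot)
    return np.concatenate(parts)
```

For example, two numerical features plus categoricals with 3 and 2 categories give a vector of length $2 + 3 + 2 = 7$.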
the model is trained by minimizing a sum of mean-squared error for the Gaussian diffusion term and the KL divergence for each multinomial diffusion term
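One way to write the combined objective (notation mine; the multinomial KL terms are averaged over the $C$ categorical features so that no feature group dominates):
\[L_{TabDDPM} = L^{Gauss}_{simple} + \frac{1}{C}\sum_{i \le C} L^{i}_{mult}\]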
for classification datasets, use a class-conditional model, i.e., $p_\theta(x_{t-1}\mid x_t, y)$ is learned
for regression datasets, consider a target value as an additional numerical feature, and the joint distribution is learned
to model the reverse process, use a simple MLP architecture
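A minimal forward-pass sketch of such an MLP denoiser, assuming the common DDPM recipe of concatenating the noisy input with a sinusoidal timestep embedding (all names and sizes here are illustrative, not the paper's exact architecture; real training would use a DL framework):

```python
import numpy as np

def timestep_embedding(t, dim):
    """Sinusoidal timestep embedding, as commonly used in DDPMs."""
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
    angles = t * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])

class DenoiserMLP:
    """Minimal MLP mapping (x_t, t) to a prediction over the full
    feature vector (noise for numerical parts, logits for categorical)."""
    def __init__(self, d_in, d_hidden, d_out, rng):
        self.W1 = rng.normal(0, 0.02, (d_in, d_hidden))
        self.b1 = np.zeros(d_hidden)
        self.W2 = rng.normal(0, 0.02, (d_hidden, d_out))
        self.b2 = np.zeros(d_out)

    def __call__(self, x_t, t, t_dim=16):
        inp = np.concatenate([x_t, timestep_embedding(t, t_dim)])
        h = np.maximum(inp @ self.W1 + self.b1, 0.0)  # ReLU
        return h @ self.W2 + self.b2
```

For the class-conditional variant, a label embedding would simply be concatenated (or added) alongside the timestep embedding.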
Experiments
Datasets: 15 real-world public datasets
Baselines
- TVAE
- CTGAN
- CTABGAN
- CTABGAN+
- SMOTE
Evaluation measure
machine learning (ML) efficiency (utility): quantifies the performance of classification or regression models trained on synthetic data and evaluated on the real test set
use two evaluation protocols to compute ML efficiency
- first protocol: compute an average efficiency w.r.t. a set of diverse ML models (logistic regression, decision tree, and others)
- second protocol: evaluate ML efficiency only w.r.t. the CatBoost model, which is arguably the leading GBDT implementation providing state-of-the-art performance on tabular tasks
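Both protocols reduce to the same train-on-synthetic, test-on-real loop, sketched below with a toy nearest-centroid classifier standing in for the actual models (logistic regression, CatBoost, etc.; all names here are mine):

```python
import numpy as np

def ml_efficiency(fit, predict, X_syn, y_syn, X_real_test, y_real_test):
    """ML efficiency (utility): train a model on synthetic data,
    then score it on the held-out *real* test set."""
    params = fit(X_syn, y_syn)
    preds = predict(params, X_real_test)
    return float(np.mean(preds == y_real_test))  # accuracy

def fit_centroids(X, y):
    """Toy stand-in model: one centroid per class."""
    classes = np.unique(y)
    return classes, np.stack([X[y == c].mean(axis=0) for c in classes])

def predict_centroids(params, X):
    """Assign each row to the class of its nearest centroid."""
    classes, centroids = params
    d = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    return classes[np.argmin(d, axis=1)]
```

The first protocol averages this score over a set of diverse models; the second uses a single strong model (CatBoost) in place of the toy classifier.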
Tuning process
use the Optuna library
the tuning process is guided by the ML efficiency (w.r.t. CatBoost) of the generated synthetic data on a hold-out validation dataset (the score is averaged over five different sampling seeds)
Qualitative comparison
- in most cases, TabDDPM produces more realistic feature distributions compared with TVAE and CTABGAN+
Machine Learning efficiency
- compute average ML efficiency for a diverse set of ML models
- compute ML efficiency w.r.t. the current state-of-the-art model for tabular data
Main Results
Overall, TabDDPM performs best
Privacy
measure the privacy of the generated data as the mean Distance to Closest Record (DCR)
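A numpy sketch of mean DCR, assuming Euclidean distance over preprocessed features (the paper's exact distance and preprocessing may differ):

```python
import numpy as np

def mean_dcr(synthetic, real):
    """Mean Distance to Closest Record: for each synthetic row, the
    distance to its nearest real row, averaged over synthetic rows.
    Lower values mean synthetic rows sit closer to real records
    (weaker privacy)."""
    dists = np.linalg.norm(synthetic[:, None, :] - real[None, :, :], axis=2)
    return float(dists.min(axis=1).mean())
```

Interpolation methods like SMOTE, which draw samples on segments between real records, naturally produce low DCR, which is why diffusion-generated data can be preferable in privacy-concerned settings.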
Limitations and discussion
- no definite answer yet on whether TabDDPM's data can satisfy real-world privacy-concerned applications
- alternative approaches for categorical features
Conclusion
the paper presents a DDPM design for tabular data that handles mixed types of numerical and categorical features