WeiYa's Work Yard

A traveler with endless curiosity, who fell into the ocean of statistics, tries to write down his ideas and notes to save himself.

TabDDPM: Tabular Data with Diffusion Models

Tags: Diffusion Model

This note is for Kotelnikov, A., Baranchuk, D., Rubachev, I., & Babenko, A. (2023). TabDDPM: Modelling Tabular Data with Diffusion Models. Proceedings of the 40th International Conference on Machine Learning, 17564–17579.

  • the paper investigates whether the framework of diffusion models can be advantageous for general tabular problems, where data points are typically represented by vectors of heterogeneous features
  • the inherent heterogeneity of tabular data makes accurate modeling quite challenging, since the individual features can be of completely different natures, i.e., some can be continuous and some can be discrete

the paper introduces TabDDPM, a diffusion model that can be universally applied to any tabular dataset and handles any feature types

the authors extensively evaluate TabDDPM on a wide set of benchmarks and demonstrate its superiority over existing GAN/VAE alternatives, which is consistent with the advantage of diffusion models in other fields

Introduction

denoising diffusion probabilistic models (DDPMs) have attracted much attention in the generative modeling community, since they often outperform alternative approaches both in terms of the realism of individual samples and their diversity

training a high-quality model of tabular data can be more challenging than in computer vision or NLP due to the heterogeneity of individual features and the relatively small sizes of typical tabular datasets

main contributions of this work:

  1. introduce TabDDPM, a simple design of DDPM for tabular problems
  2. demonstrate that TabDDPM outperforms alternative approaches designed for tabular data, including GAN-based and VAE-based methods
  3. shallow interpolation-based methods, e.g., SMOTE, produce surprisingly effective synthetic data that provides competitively high ML efficiency. Compared with SMOTE, the authors show that TabDDPM’s data is preferable for privacy-concerned scenarios, where synthetic data are used to substitute real user data that cannot be shared.

Diffusion models

  • a paradigm of generative modeling that aims to approximate the target distribution by the endpoint of the Markov chain, which starts from a given parametric distribution, typically a standard Gaussian

Generative models for tabular problems

  • an active research direction since high-quality synthetic data is in great demand for many tabular tasks
  • recent works: tabular VAEs and GAN-based approaches

Shallow synthetics generation

tabular data is typically structured

Background

Diffusion models

  • forward process: gradually adds noise to an initial sample
  • reverse process: gradually denoises a latent variable

Gaussian diffusion models

forward and reverse processes are characterized by Gaussian distributions
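Concretely, the forward process admits the closed-form marginal $q(x_t\mid x_0) = \mathcal{N}(\sqrt{\bar\alpha_t}\,x_0, (1-\bar\alpha_t)I)$. A minimal NumPy sketch (the function name and the linear noise schedule are illustrative choices, not taken from the paper):

```python
import numpy as np

def gaussian_forward_sample(x0, t, betas):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) * x0, (1 - alpha_bar_t) * I)."""
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)[t]   # alpha_bar_t = prod_{s<=t} alpha_s
    noise = np.random.randn(*x0.shape)
    return np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * noise

betas = np.linspace(1e-4, 0.02, 1000)   # a common linear schedule (assumption)
x0 = np.random.randn(5)                 # toy vector of numerical features
xt = gaussian_forward_sample(x0, t=999, betas=betas)
```

At the final timestep, `alpha_bar` is close to zero, so `xt` is nearly standard Gaussian noise, matching the starting point of the reverse process.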

Multinomial diffusion models

forward process: define $q(x_t\mid x_{t-1})$ as a categorical distribution that corrupts the data by uniform noise over $K$ classes
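One corruption step of this categorical forward process mixes the current one-hot state with the uniform distribution over the $K$ classes; note that the probabilities stay normalized since $(1-\beta_t)\cdot 1 + K\cdot\beta_t/K = 1$. A small sketch (names are my own):

```python
import numpy as np

def multinomial_forward_step(x_onehot, beta_t):
    """One step of q(x_t | x_{t-1}) = Cat((1 - beta_t) * x_{t-1} + beta_t / K)."""
    K = x_onehot.shape[-1]
    probs = (1.0 - beta_t) * x_onehot + beta_t / K   # sums to 1 by construction
    idx = np.random.choice(K, p=probs)
    out = np.zeros(K)
    out[idx] = 1.0
    return out

x = np.eye(4)[2]                                     # one-hot value, K = 4 classes
xt = multinomial_forward_step(x, beta_t=0.1)
```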

TabDDPM

  • use the multinomial diffusion to model the categorical and binary features
  • use the Gaussian diffusion to model the numerical ones

for a tabular data sample

\[x = [x_{num}, x_{cat1},\ldots, x_{catC}]\]

take one-hot encoded versions of categorical features as input

total dimension is $N_{num} + \sum K_i$
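The input encoding above can be sketched as follows (helper name and toy values are hypothetical): two numerical features plus categorical features with $K_1=3$ and $K_2=2$ classes give a vector of dimension $2 + 3 + 2 = 7$.

```python
import numpy as np

def encode_row(x_num, x_cat, K_sizes):
    """Concatenate numerical features with one-hot categorical features.
    Total dimension is N_num + sum(K_i)."""
    parts = [np.asarray(x_num, dtype=float)]
    for value, K in zip(x_cat, K_sizes):
        onehot = np.zeros(K)
        onehot[value] = 1.0
        parts.append(onehot)
    return np.concatenate(parts)

row = encode_row([0.5, -1.2], x_cat=[1, 0], K_sizes=[3, 2])
# row has dimension 2 + (3 + 2) = 7
```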

the model is trained by minimizing a sum of mean-squared error for the Gaussian diffusion term and the KL divergence for each multinomial diffusion term
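Schematically, the objective combines the Gaussian term with the categorical terms; in a rough sketch (the exact weighting and KL computation follow the multinomial diffusion derivation, which is not reproduced here):

```python
import numpy as np

def tabddpm_loss(eps_true, eps_pred, kl_terms):
    """Schematic total loss: MSE on the predicted Gaussian noise plus the
    average KL term over the C categorical features (kl_terms is assumed
    to already hold one precomputed KL value per categorical feature)."""
    mse = np.mean((eps_true - eps_pred) ** 2)
    return mse + np.mean(kl_terms)

loss = tabddpm_loss(np.zeros(3), np.array([0.1, -0.1, 0.0]),
                    kl_terms=np.array([0.2, 0.4]))
```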

for classification datasets, use a class-conditional model, i.e., $p_\theta(x_{t-1}\mid x_t, y)$ is learned

for regression datasets, consider a target value as an additional numerical feature, and the joint distribution is learned

to model the reverse process, use a simple MLP architecture
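A forward-pass sketch of such an MLP in NumPy (all sizes, the crude scalar timestep conditioning, and the weight initialization are illustrative assumptions; the paper's actual architecture details, such as timestep embeddings, differ):

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp_forward(x, t, params):
    """Denoising MLP sketch: the noisy feature vector is concatenated with
    a scalar timestep feature, passed through ReLU hidden layers, and mapped
    back to the input dimension (noise / logits prediction)."""
    h = np.concatenate([x, [t / 1000.0]])       # crude timestep conditioning
    for W, b in params[:-1]:
        h = np.maximum(W @ h + b, 0.0)          # ReLU hidden layers
    W, b = params[-1]
    return W @ h + b                            # output matches input dim

dim, hidden = 7, 64
params = [(rng.normal(size=(hidden, dim + 1)) * 0.1, np.zeros(hidden)),
          (rng.normal(size=(hidden, hidden)) * 0.1, np.zeros(hidden)),
          (rng.normal(size=(dim, hidden)) * 0.1, np.zeros(dim))]
out = mlp_forward(rng.normal(size=dim), t=500, params=params)
```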

Experiments

Datasets: 15 real-world public datasets

Baselines

  • TVAE
  • CTGAN
  • CTABGAN
  • CTABGAN+
  • SMOTE

Evaluation measure

machine learning (ML) efficiency (utility): quantify the performance of classification or regression models trained on synthetic data and evaluated on the real test set

use two evaluation protocols to compute ML efficiency

  • first protocol: compute an average efficiency w.r.t. a set of diverse ML models (logistic regression, decision tree, and others)
  • second protocol: evaluate ML efficiency only w.r.t. the CatBoost model, which is arguably the leading GBDT implementation providing state-of-the-art performance on tabular tasks
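The protocol amounts to "train on synthetic, test on real". A toy sketch with scikit-learn (the blob data and the choice of logistic regression are stand-ins for illustration):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

def ml_efficiency(synth_X, synth_y, real_X_test, real_y_test):
    """ML efficiency: fit a model on synthetic data, score it on real data."""
    clf = LogisticRegression(max_iter=1000).fit(synth_X, synth_y)
    return clf.score(real_X_test, real_y_test)

# toy stand-in data: two Gaussian blobs for "synthetic" and "real" samples
synth_X = np.vstack([rng.normal(0, 1, (50, 2)), rng.normal(3, 1, (50, 2))])
synth_y = np.array([0] * 50 + [1] * 50)
real_X = np.vstack([rng.normal(0, 1, (50, 2)), rng.normal(3, 1, (50, 2))])
real_y = np.array([0] * 50 + [1] * 50)
acc = ml_efficiency(synth_X, synth_y, real_X, real_y)
```

A generator whose samples let the downstream model score close to one trained on real data is considered efficient.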

Tuning process

use the Optuna library

tuning process is guided by the values of the ML efficiency (w.r.t. CatBoost) of the generated synthetic data on a hold-out validation dataset (the score is averaged over five different sampling seeds)

Qualitative comparison

  • in most cases, TabDDPM produces more realistic feature distributions compared with TVAE and CTABGAN+

Machine Learning efficiency

  • compute average ML efficiency for a diverse set of ML models
  • compute ML efficiency w.r.t. the current state-of-the-art model for tabular data

Main Results

Overall, TabDDPM is the best

Privacy

measure the privacy of the generated data as a mean Distance to Closest Record (DCR)
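DCR is straightforward to compute: for each synthetic row, take the distance to its nearest real row, then average. A minimal sketch (Euclidean distance is an assumption; larger mean DCR suggests the generator is not simply copying training records):

```python
import numpy as np

def mean_dcr(synthetic, real):
    """Mean Distance to Closest Record: for each synthetic row, the
    distance to its nearest real row, averaged over synthetic rows."""
    dists = np.linalg.norm(synthetic[:, None, :] - real[None, :, :], axis=-1)
    return dists.min(axis=1).mean()

real = np.array([[0.0, 0.0], [1.0, 1.0]])
synthetic = np.array([[0.0, 0.0], [2.0, 1.0]])
dcr = mean_dcr(synthetic, real)   # closest distances 0.0 and 1.0 -> 0.5
```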

Limitations and discussion

  • no definite answer whether TabDDPM’s data can satisfy real-world privacy-concerned applications
  • alternative approaches for categorical features

Conclusion

the paper designs a DDPM to handle mixed data types consisting of numerical and categorical features

