# API Reference
This section provides detailed information about TabStruct’s core APIs and interfaces.
## Core Interfaces

### BaseModel
The foundation class for all models in TabStruct.
```python
from tabstruct.common.model.BaseModel import BaseModel


class CustomModel(BaseModel):
    def __init__(self, args):
        super().__init__(args)
        # Initialize your model here

    def _fit(self, data_module):
        # Implement training logic
        pass
```
**Key Methods:**

- `__init__(args)`: Initialize the model with experiment arguments
- `fit(data_module)`: Public API to train the model
- `_fit(data_module)`: Abstract method to implement training logic
- `get_metadata()`: Return model metadata including name and parameters
- `define_params(reg_test, trial=None, dev=False)`: Define model parameters for different modes
**Parameter Definition Methods:**

- `_define_default_params()`: Default parameters for production runs
- `_define_optuna_params(trial)`: Parameters for hyperparameter optimization
- `_define_single_run_params()`: Parameters for development/debugging
- `_define_test_params()`: Minimal parameters for testing
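These methods let one model class serve production, tuning, development, and test modes. A minimal sketch of the default and test modes, assuming each method returns a plain dict of hyperparameters (the parameter names and values below are illustrative, not taken from the codebase):

```python
from tabstruct.common.model.BaseModel import BaseModel


class CustomModel(BaseModel):
    def _define_default_params(self):
        # Full-strength settings for production runs (illustrative values).
        return {"n_estimators": 500, "learning_rate": 0.05}

    def _define_test_params(self):
        # Deliberately tiny settings so tests finish quickly.
        return {"n_estimators": 10, "learning_rate": 0.1}

    def _fit(self, data_module):
        # Training logic goes here, consuming the DataModule splits.
        pass
```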
## Prediction Models

### BasePredictor
Base class for all prediction models.
```python
from tabstruct.prediction.models.BasePredictor import BasePredictor
```
**Inheritance Hierarchy:**

- `BasePredictor` → `BaseSklearnPredictor` → Scikit-learn models (`lr`, `rf`, `knn`, `xgb`, `tabnet`, `tabpfn`, `mlp-sklearn`)
- `BasePredictor` → `BaseLitPredictor` → PyTorch Lightning models (`mlp`, `ft-transformer`)
**Available Prediction Models:**

**Scikit-learn Models:**

- `lr`: Logistic Regression / Linear Regression
- `rf`: Random Forest
- `knn`: K-Nearest Neighbors
- `xgb`: XGBoost
- `tabnet`: TabNet
- `tabpfn`: TabPFN (Prior-data Fitted Network)
- `mlp-sklearn`: Multi-layer Perceptron (Scikit-learn)

**Lightning Models:**

- `mlp`: Multi-layer Perceptron (PyTorch Lightning)
- `ft-transformer`: Feature Tokenizer + Transformer
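Per the hierarchy above, a new scikit-learn-style model would subclass `BaseSklearnPredictor`. A hedged sketch: the import path for `BaseSklearnPredictor` and the `self.params` attribute are assumptions for illustration, so check the codebase for the actual hooks:

```python
from sklearn.ensemble import ExtraTreesClassifier

# Import path and `self.params` attribute are assumptions for illustration.
from tabstruct.prediction.models.BasePredictor import BaseSklearnPredictor


class ExtraTreesPredictor(BaseSklearnPredictor):
    def __init__(self, args):
        super().__init__(args)
        self.model = ExtraTreesClassifier(**self.params)

    def _fit(self, data_module):
        # Scikit-learn estimators train directly on the NumPy splits.
        self.model.fit(data_module.X_train, data_module.y_train)
```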
## Generation Models

### BaseGenerator
Base class for all data generation models.
```python
from tabstruct.generation.models.BaseGenerator import BaseGenerator
```
**Inheritance Hierarchy:**

- `BaseGenerator` → `BaseImblearnGenerator` → SMOTE
- `BaseGenerator` → `BaseTabEvalGenerator` → TabEval-based generators
  - `BaseTabEvalConditionalGenerator` → `ctgan`, `tvae`, `tabddpm`
  - `BaseTabEvalJointGenerator` → `bn`, `arf`, `nflow`, `goggle`, `great`
- `BaseGenerator` → `BaseMixedGenerator` → Custom generators (TabSyn, TabDiff, TabEBM)
**Available Generation Models:**

**Real Data:**

- `real`: Passthrough (no generation)

**Imbalanced-learn:**

- `smote`: Synthetic Minority Oversampling Technique
**TabEval Generators:**

- `ctgan`: Conditional Tabular GAN
- `tvae`: Tabular Variational Autoencoder
- `bn`: Bayesian Network
- `goggle`: GOGGLE (Generative Modelling with Graph Learning)
- `tabddpm`: Tabular Denoising Diffusion Probabilistic Model
- `arf`: Adversarial Random Forests
- `nflow`: Normalizing Flow
- `great`: GReaT (Generation of Realistic Tabular data)
**Custom Generators:**

- `TabSyn`: Tabular Synthesis with diffusion models
- `TabDiff`: Tabular Diffusion
- `TabEBM`: Tabular Energy-Based Model
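A custom generator follows the same pattern as models above: subclass `BaseGenerator` (or one of the intermediate bases) and implement `_fit`. A toy sketch; the `sample` method name and its return shape are assumptions, not the framework's confirmed interface:

```python
import numpy as np

from tabstruct.generation.models.BaseGenerator import BaseGenerator


class JitterGenerator(BaseGenerator):
    """Toy joint generator: resamples training rows and adds Gaussian noise."""

    def _fit(self, data_module):
        # A real generator would learn a density model here.
        self._X = data_module.X_train
        self._y = data_module.y_train

    def sample(self, n_samples):  # method name is an assumption
        idx = np.random.randint(0, len(self._X), size=n_samples)
        noise = np.random.normal(scale=0.01, size=(n_samples, self._X.shape[1]))
        return self._X[idx] + noise, self._y[idx]
```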
## Data Management

### DataModule
Lightning-compatible data module for handling tabular data.
```python
from tabstruct.common.data.DataModule import DataModule

data_module = DataModule(
    args=args,
    X_train=X_train,
    y_train=y_train,
    X_valid=X_valid,
    y_valid=y_valid,
    X_test=X_test,
    y_test=y_test,
)
```
**Key Attributes:**

- `X_train`, `y_train`: Training data (NumPy arrays)
- `X_valid`, `y_valid`: Validation data (NumPy arrays)
- `X_test`, `y_test`: Test data (NumPy arrays)
- `train_dataset`, `valid_dataset`, `test_dataset`: PyTorch datasets
**Key Methods:**

- `train_dataloader()`: Returns a PyTorch DataLoader for training
- `val_dataloader()`: Returns a PyTorch DataLoader for validation
- `test_dataloader()`: Returns a PyTorch DataLoader for testing
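Typical usage pairs the constructor above with one of these dataloaders. A short sketch; the `(features, targets)` batch structure is an assumption about the underlying PyTorch datasets:

```python
# Iterate mini-batches for training; the (features, targets) batch
# structure is an assumption about the underlying PyTorch dataset.
for X_batch, y_batch in data_module.train_dataloader():
    print(X_batch.shape, y_batch.shape)
```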
## Pipeline Classes

### BasePipeline
Base class for experiment pipelines.
**Available Pipelines:**

- `PredictionPipeline`: Handles prediction experiments
- `GenerationPipeline`: Handles data generation experiments
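Pipelines are normally driven end to end by `run_experiment` (see Usage Examples below). For orientation only, a direct invocation might look like the following; the import path, constructor signature, and `run()` entry point are all assumptions:

```python
# All names below are assumptions for illustration; in practice, invoke
# `python -m src.tabstruct.experiment.run_experiment` instead.
from tabstruct.prediction.PredictionPipeline import PredictionPipeline

pipeline = PredictionPipeline(args)
pipeline.run()
```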
## Experiment Configuration
The main configuration is handled through command-line arguments. Key argument categories:
**Core Arguments:**

- `--pipeline`: `prediction` | `generation`
- `--model`: Model identifier (see the Models sections above)
- `--task`: `classification` | `regression`
- `--dataset`: Dataset name (tabcamel-compatible)
- `--test_size`, `--valid_size`: Split sizes
- `--split_mode`: `stratified` | `random`
- `--seed`: Random seed
- `--device`: `cpu` | `cuda`
**Training Arguments:**

- `--max_steps_tentative`: Maximum training steps
- `--batch_size_tentative`: Batch size
- `--optimizer`: `adam` | `adamw` | `sgd`
- `--lr_scheduler`: `none` | `plateau` | `cosine_warm_restart` | `linear` | `lambda` (see the mapping sketch below)
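The scheduler names map naturally onto standard PyTorch schedulers. A plausible mapping, shown for orientation only; this is not the framework's actual factory, and the `T_0` and `lr_lambda` values are placeholders:

```python
import torch


def build_scheduler(name, optimizer):
    # Illustrative mapping of --lr_scheduler names to PyTorch schedulers.
    if name == "none":
        return None
    if name == "plateau":
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    if name == "cosine_warm_restart":
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
    if name == "linear":
        return torch.optim.lr_scheduler.LinearLR(optimizer)
    if name == "lambda":
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
    raise ValueError(f"Unknown scheduler: {name}")
```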
**Evaluation Arguments:**

- `--eval_only`: Skip training, evaluate only
- `--disable_eval_density`: Disable density evaluation
- `--disable_eval_privacy`: Disable privacy evaluation
- `--enable_eval_structure`: Enable structure evaluation
**Hyperparameter Tuning:**

- `--enable_optuna`: Enable Optuna optimization
- `--optuna_trial`: Trial number for Optuna
- `--tune_max_workers`: Maximum workers for tuning
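When Optuna is enabled, each model's `_define_optuna_params(trial)` defines the search space via standard Optuna `trial.suggest_*` calls. A sketch with illustrative parameter names and ranges; the real search spaces live in the individual model classes:

```python
def _define_optuna_params(self, trial):
    # Search space is illustrative; each model defines its own.
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True),
        "batch_size": trial.suggest_categorical("batch_size", [64, 128, 256]),
        "n_layers": trial.suggest_int("n_layers", 1, 4),
    }
```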
## Usage Examples

**Prediction Pipeline:**
```bash
python -m src.tabstruct.experiment.run_experiment \
    --pipeline prediction \
    --model xgb \
    --task classification \
    --dataset adult \
    --test_size 0.2 \
    --valid_size 0.2 \
    --seed 42
```
**Generation Pipeline:**
```bash
python -m src.tabstruct.experiment.run_experiment \
    --pipeline generation \
    --model ctgan \
    --task classification \
    --dataset adult \
    --test_size 0.2 \
    --valid_size 0.2 \
    --seed 42
```
**Hyperparameter Tuning:**
```bash
python -m src.tabstruct.experiment.run_experiment \
    --pipeline prediction \
    --model mlp \
    --task classification \
    --dataset adult \
    --enable_optuna \
    --tune_max_workers 4
```
## Error Handling

**Common Exceptions:**

- `ManualStopError`: Raised when model constraints are violated (e.g., TabPFN with >10 classes or >500 features)
- `ValueError`: Raised for invalid task/model combinations
- `NotImplementedError`: Raised when abstract methods are not implemented
**Model Constraints:**

- `TabPFN`: Max 10 classes for classification, max 500 features
- `TabEBM`: Max 500 features
- Some generators are unstable on large datasets (see `unstable_generator_list`)
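Constraint violations surface as `ManualStopError`, which callers can catch to skip a run gracefully. A sketch; the exception's import path is an assumption:

```python
# Import path for ManualStopError is an assumption.
from tabstruct.common.exceptions import ManualStopError

try:
    model.fit(data_module)
except ManualStopError as err:
    # E.g. TabPFN on data with >10 classes or >500 features.
    print(f"Skipping run due to model constraints: {err}")
```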
## Constants and Configuration

**Key Constants:**
```python
# Available models
predictior_list = ["lr", "rf", "knn", "xgb", "tabnet", "tabpfn", "mlp-sklearn", "mlp", "ft-transformer"]
generator_list = ["real", "smote", "ctgan", "tvae", "bn", "goggle", "tabddpm", "arf", "nflow", "great"]

# Unstable generators (may fail on large datasets)
unstable_generator_list = ["bn", "arf", "nflow", "goggle", "great"]

# Timeouts
TUNE_STUDY_TIMEOUT = 3600 * 2  # 2 hours
SINGLE_RUN_TIMEOUT = 3600 * 2  # 2 hours
```
**Project Configuration:**

- `WANDB_ENTITY`: "tabular-foundation-model"
- `WANDB_PROJECT`: "Euphratica-dev"
- `LOG_DIR`: "{BASE_DIR}/logs"
## Notes

- The framework automatically handles data preprocessing and feature encoding.
- Lightning models support distributed training and mixed precision.
- All models implement standardized parameter definition methods for reproducibility.
- Generation models can handle both conditional and joint generation strategies.
- The codebase supports integration with Weights & Biases for experiment tracking.