CausalPFN Model
This repository contains the model weights for CausalPFN, a transformer-based in-context learning model for causal effect estimation.
Model Description
CausalPFN is a pre-trained model for amortized causal effect estimation via in-context learning. It estimates conditional average treatment effects (CATE) and average treatment effects (ATE) accurately without retraining the model for each new dataset.
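For reference, these are the standard potential-outcomes definitions of the two estimands. The notation is assumed here and is not used elsewhere in this card: $Y(1)$ and $Y(0)$ are the potential outcomes with and without treatment, and $X$ are the covariates.

```latex
\tau(x) = \mathbb{E}\left[\, Y(1) - Y(0) \mid X = x \,\right]
\qquad \text{(CATE)}

\mathrm{ATE} = \mathbb{E}\left[\, Y(1) - Y(0) \,\right] = \mathbb{E}_{X}\left[\, \tau(X) \,\right]
```

The ATE is the CATE averaged over the covariate distribution, which is why one model can serve both estimators.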
The model is built on a transformer architecture and also provides uncertainty quantification with calibration.
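This card does not say how CausalPFN exposes its uncertainty estimates. As a generic illustration of what "calibrated" means for interval estimates, the sketch checks empirical coverage: for calibrated 90% intervals, roughly 90% of the true effects should fall inside. The function names `empirical_coverage` and `calibration_gap` are hypothetical and are not part of the causalpfn API. True effects are only available on synthetic data.

```python
import numpy as np


def empirical_coverage(lower, upper, truth) -> float:
    """Fraction of true values that fall inside [lower, upper] (inclusive)."""
    lower, upper, truth = (np.asarray(a, dtype=float).ravel() for a in (lower, upper, truth))
    if not (lower.shape == upper.shape == truth.shape):
        raise ValueError("lower, upper and truth must have the same length")
    inside = (truth >= lower) & (truth <= upper)
    return float(inside.mean())


def calibration_gap(lower, upper, truth, nominal: float) -> float:
    """Empirical minus nominal coverage; a value near 0 suggests calibrated intervals."""
    return empirical_coverage(lower, upper, truth) - nominal
```

A clearly negative gap means the intervals are too narrow (overconfident). A clearly positive gap means they are wider than necessary.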
Requirements
- Python 3.10+
- PyTorch 2.3+
- NumPy
- scikit-learn
- tqdm
- faiss-cpu
- huggingface_hub
Installation
To use this model, install the CausalPFN library:
pip install causalpfn
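After installing, a quick environment check can confirm that the packages listed under Requirements are present. This is a minimal sketch that is not shipped with causalpfn. The helper names are hypothetical, and the PyPI distribution names (for example `faiss-cpu` and `scikit-learn`) are assumptions. Only Python and PyTorch have stated minimum versions.

```python
import re
import sys
from importlib import metadata
from typing import Dict, Optional, Tuple

# Minimum (major, minor) versions from the Requirements list; None = no stated minimum.
# Distribution names are assumed to match their PyPI names.
REQUIREMENTS: Dict[str, Optional[Tuple[int, int]]] = {
    "torch": (2, 3),
    "numpy": None,
    "scikit-learn": None,
    "tqdm": None,
    "faiss-cpu": None,
    "huggingface_hub": None,
    "causalpfn": None,
}


def parse_major_minor(version: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from a version string such as '2.3.1+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    return (int(match.group(1)), int(match.group(2))) if match else None


def meets_minimum(version: str, minimum: Optional[Tuple[int, int]]) -> bool:
    """True if no minimum is required or the parsed version is at least `minimum`."""
    if minimum is None:
        return True
    parsed = parse_major_minor(version)
    return parsed is not None and parsed >= minimum


def check_environment() -> Dict[str, str]:
    """Return a status string for Python and for each listed dependency."""
    status = {}
    py_ok = sys.version_info >= (3, 10)
    status["python"] = f"{sys.version.split()[0]} ({'ok' if py_ok else 'needs 3.10+'})"
    for name, minimum in REQUIREMENTS.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            status[name] = "missing"
            continue
        ok = meets_minimum(installed, minimum)
        status[name] = f"{installed} ({'ok' if ok else 'too old'})"
    return status


if __name__ == "__main__":
    for name, state in check_environment().items():
        print(f"{name:16} {state}")
```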
Usage
You can use this model with the CausalPFN library:
import torch
from causalpfn import CATEEstimator, ATEEstimator

# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a CATE estimator
causalpfn_cate = CATEEstimator(
    device=device,
    verbose=True,
)
# Fit the model on your data (no retraining: the data is used as in-context input)
# X_train: covariates, T_train: binary treatment, Y_train: observed outcome (from observational data)
causalpfn_cate.fit(X_train, T_train, Y_train)
# Estimate the CATE for each row of X_test
cate_hat = causalpfn_cate.estimate_cate(X_test)

# Create an ATE estimator
causalpfn_ate = ATEEstimator(
    device=device,
    verbose=True,
)
# Fit and estimate the ATE
# X, T, Y: covariates, binary treatment and observed outcome for the full dataset
causalpfn_ate.fit(X, T, Y)
ate_hat = causalpfn_ate.estimate_ate()
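The usage example leaves `X_train`, `T_train`, `Y_train`, `X_test`, `X`, `T` and `Y` undefined. To try the estimators end to end, it helps to have data where the true effect is known. This is a minimal sketch under stated assumptions: the data-generating process, the helper names (`make_synthetic_data`, `pehe`, `ate_abs_error`) and the assumption that the estimators accept NumPy arrays are all hypothetical, not part of the causalpfn library.

```python
import numpy as np


def make_synthetic_data(n: int = 2000, d: int = 5, seed: int = 0):
    """Hypothetical confounded dataset with a known treatment effect.

    Returns covariates X (n, d), binary treatment T (n,), outcome Y (n,)
    and the true per-unit effect tau (n,).
    """
    if d < 2:
        raise ValueError("d must be at least 2")
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, d))
    # Confounding: X[:, 0] drives treatment assignment, the baseline outcome and the effect
    propensity = 1.0 / (1.0 + np.exp(-X[:, 0]))
    T = rng.binomial(1, propensity)
    baseline = X[:, 1] + 0.5 * X[:, 0]  # outcome without treatment
    tau = 1.0 + 0.5 * X[:, 0]           # true CATE; its population mean is 1
    Y = baseline + T * tau + rng.normal(scale=0.5, size=n)
    return X, T, Y, tau


def pehe(cate_hat, cate_true) -> float:
    """Root-mean-squared error between estimated and true CATE (often reported as sqrt PEHE)."""
    cate_hat = np.asarray(cate_hat, dtype=float).ravel()
    cate_true = np.asarray(cate_true, dtype=float).ravel()
    return float(np.sqrt(np.mean((cate_hat - cate_true) ** 2)))


def ate_abs_error(ate_hat, cate_true) -> float:
    """Absolute error of an ATE estimate against the sample average of the true effects."""
    return float(abs(float(ate_hat) - float(np.mean(cate_true))))
```

For example, take the first 1,500 rows of `X, T, Y` as `X_train, T_train, Y_train` and the remaining rows as `X_test`, run the CATE estimator as in the usage example, and score it with `pehe(cate_hat, tau[1500:])`. Because the treated group has systematically larger `X[:, 0]`, a naive difference in mean outcomes is biased here, which makes the dataset a useful sanity check.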
Citations
If you use this model in your research, please cite:
@misc{balazadeh2025causalpfn,
title={CausalPFN: Amortized Causal Effect Estimation via In-Context Learning},
author={Vahid Balazadeh and Hamidreza Kamkari and Valentin Thomas and Benson Li and Junwei Ma and Jesse C. Cresswell and Rahul G. Krishnan},
year={2025},
eprint={2506.07918},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2506.07918},
}
License
This model is licensed under Apache-2.0.