CausalPFN Model

This repository contains the model weights for CausalPFN, a transformer-based in-context learning model for causal effect estimation.

Model Description

CausalPFN is a pre-trained model for amortized causal effect estimation via in-context learning. Given a new dataset as context, it estimates conditional average treatment effects (CATE) and average treatment effects (ATE) without retraining the model for that dataset.
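For reference, the CATE at covariates x is tau(x) = E[Y(1) - Y(0) | X = x], and the ATE is its average over the covariate distribution, E[tau(X)]. The sketch below illustrates that relationship with a hypothetical simulator (`simulate_potential_outcomes` and `true_cate` are illustrative names) where both potential outcomes are known. With real observational data only one outcome per unit is observed. This snippet does not use CausalPFN.

```python
import numpy as np

def simulate_potential_outcomes(n: int, seed: int = 0):
    """Hypothetical simulator: returns covariates and both potential outcomes."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, 2))
    noise = rng.normal(scale=0.1, size=n)
    y0 = X[:, 0] + noise              # outcome without treatment
    y1 = y0 + 1.0 + 0.5 * X[:, 1]     # effect of treatment varies with X[:, 1]
    return X, y0, y1

def true_cate(X: np.ndarray) -> np.ndarray:
    # tau(x) = E[Y(1) - Y(0) | X = x] for the simulator above
    return 1.0 + 0.5 * X[:, 1]

X, y0, y1 = simulate_potential_outcomes(100_000)
tau = true_cate(X)
ate_from_cate = tau.mean()            # ATE = E[tau(X)]
ate_from_outcomes = (y1 - y0).mean()  # average of unit-level effects

# In observational data only one potential outcome per unit is observed:
T = np.random.default_rng(1).binomial(1, 0.5, size=len(X))
Y = np.where(T == 1, y1, y0)
print(round(ate_from_cate, 3), round(ate_from_outcomes, 3))
```

Because E[X[:, 1]] = 0 in this simulator, both quantities land close to 1.0.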

The model is built on a transformer architecture and includes uncertainty quantification, with calibration of the resulting uncertainty estimates.
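Calibration of uncertainty estimates is commonly checked through empirical coverage: a well-calibrated 90% interval should contain the true value about 90% of the time. The helper below is a generic, hypothetical sketch of that check. It assumes you already have lower and upper interval bounds and ground-truth effects (for example from a simulation). It is not part of the CausalPFN API and does not describe how CausalPFN computes its intervals.

```python
import numpy as np

def empirical_coverage(truth, lower, upper) -> float:
    """Hypothetical helper: fraction of true values inside [lower, upper]."""
    truth, lower, upper = (np.asarray(a, dtype=float) for a in (truth, lower, upper))
    if not (truth.shape == lower.shape == upper.shape):
        raise ValueError("truth, lower and upper must have the same shape")
    inside = (truth >= lower) & (truth <= upper)
    return float(inside.mean())

# Toy example: Gaussian predictions with correctly specified 90% intervals.
rng = np.random.default_rng(0)
mu = rng.normal(size=50_000)           # predicted means
sigma = 0.5                            # predicted standard deviation
truth = mu + rng.normal(scale=sigma, size=mu.shape)
z = 1.6448536269514722                 # standard normal 95th percentile
cov = empirical_coverage(truth, mu - z * sigma, mu + z * sigma)
print(f"nominal 0.90, empirical {cov:.3f}")
```

Coverage well below the nominal level indicates overconfident intervals. Coverage well above it indicates overly wide ones.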

Requirements

  • Python 3.10+
  • PyTorch 2.3+
  • NumPy
  • scikit-learn
  • tqdm
  • faiss-cpu
  • huggingface_hub
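To confirm an environment meets the requirements above, a small script can check the Python version, whether each package is importable, and the installed PyTorch version. This is a hypothetical convenience sketch, not a tool shipped with CausalPFN. The mapping from package names to import names (for example `scikit-learn` to `sklearn`, `faiss-cpu` to `faiss`) is an assumption based on those packages' usual import names.

```python
import importlib.util
import sys
from importlib import metadata

# Package name from the list above -> module name used in `import`.
REQUIRED = {
    "torch": "torch",
    "numpy": "numpy",
    "scikit-learn": "sklearn",
    "tqdm": "tqdm",
    "faiss-cpu": "faiss",
    "huggingface_hub": "huggingface_hub",
}

def torch_version_ok(minimum=(2, 3)) -> bool:
    """True if an installed torch distribution reports version >= minimum."""
    try:
        version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        return False
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) >= minimum

def check_environment() -> dict:
    """Report Python >= 3.10, torch >= 2.3, and which packages are importable."""
    report = {"python>=3.10": sys.version_info >= (3, 10)}
    for package, module_name in REQUIRED.items():
        # find_spec locates a top-level module without importing it
        report[package] = importlib.util.find_spec(module_name) is not None
    report["torch>=2.3"] = torch_version_ok()
    return report

if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{'OK     ' if ok else 'MISSING'} {name}")
```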

Installation

To use this model, install the CausalPFN library:

pip install causalpfn

Usage

You can use this model with the CausalPFN library:

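import numpy as np
import torch
from causalpfn import CATEEstimator, ATEEstimator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative synthetic observational data; replace with your own arrays.
# X: covariates (n, d), T: binary treatment (n,), Y: observed outcome (n,)
rng = np.random.default_rng(0)
n, d = 1000, 5
X = rng.normal(size=(n, d)).astype(np.float32)
propensity = 1.0 / (1.0 + np.exp(-X[:, 0]))  # treatment depends on X (confounding)
T = rng.binomial(1, propensity).astype(np.float32)
Y = (X[:, 0] + T * (1.0 + X[:, 1]) + rng.normal(scale=0.1, size=n)).astype(np.float32)
X_train, T_train, Y_train = X[:800], T[:800], Y[:800]
X_test = X[800:]

# Create a CATE estimator
causalpfn_cate = CATEEstimator(
    device=device,
    verbose=True,
)

# Fit the model on the training data
causalpfn_cate.fit(X_train, T_train, Y_train)

# Estimate CATE for each row of X_test
cate_hat = causalpfn_cate.estimate_cate(X_test)

# Create an ATE estimator
causalpfn_ate = ATEEstimator(
    device=device,
    verbose=True,
)

# Fit on the full dataset and estimate the ATE
causalpfn_ate.fit(X, T, Y)
ate_hat = causalpfn_ate.estimate_ate()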
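Before calling `fit`, it can help to validate the covariates, treatment and outcome and split them into training and test sets. The helper below is a hypothetical sketch and is not part of the CausalPFN library. It assumes NumPy inputs with one row per unit (X of shape (n, d), T and Y of shape (n,)) and float32 dtypes. These match the comments in the usage example but are not confirmed by the library documentation.

```python
import numpy as np

def prepare_arrays(X, T, Y, test_fraction=0.2, seed=0):
    """Hypothetical helper: validate shapes, check T is binary, split train/test."""
    X = np.asarray(X, dtype=np.float32)
    T = np.asarray(T, dtype=np.float32).reshape(-1)
    Y = np.asarray(Y, dtype=np.float32).reshape(-1)
    if X.ndim == 1:
        X = X.reshape(-1, 1)  # a single covariate becomes one column
    n = X.shape[0]
    if T.shape[0] != n or Y.shape[0] != n:
        raise ValueError("X, T and Y must have the same number of rows")
    if not np.isin(T, (0.0, 1.0)).all():
        raise ValueError("T must be a binary (0/1) treatment indicator")
    if not (0.0 < test_fraction < 1.0):
        raise ValueError("test_fraction must be in (0, 1)")
    idx = np.random.default_rng(seed).permutation(n)  # random row order
    n_test = int(round(n * test_fraction))
    test, train = idx[:n_test], idx[n_test:]
    return (X[train], T[train], Y[train]), (X[test], T[test], Y[test])

rng = np.random.default_rng(1)
X = rng.normal(size=(100, 3))
T = rng.integers(0, 2, size=100)
Y = X[:, 0] + T
(X_train, T_train, Y_train), (X_test, T_test, Y_test) = prepare_arrays(X, T, Y)
print(X_train.shape, X_test.shape)  # (80, 3) (20, 3)
```

The resulting `X_train`, `T_train`, `Y_train` and `X_test` can then be passed to the estimators as in the example above.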

Citations

If you use this model in your research, please cite:

@misc{balazadeh2025causalpfn,
      title={CausalPFN: Amortized Causal Effect Estimation via In-Context Learning}, 
      author={Vahid Balazadeh and Hamidreza Kamkari and Valentin Thomas and Benson Li and Junwei Ma and Jesse C. Cresswell and Rahul G. Krishnan},
      year={2025},
      eprint={2506.07918},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2506.07918}, 
}

License

This model is licensed under Apache-2.0.
