gumbel_max_causal_gadgets

JAX code accompanying the NeurIPS 2021 paper "Learning Generalized Gumbel-max Causal Mechanisms" by Guy Lorberbom*, Daniel D. Johnson*, Chris J. Maddison, Daniel Tarlow, and Tamir Hazan.

This directory contains implementations of both Gadget 1 and Gadget 2 in JAX, along with code to reproduce the experiments in Section 7.1 and the first two columns of Table 1.

For the implementation used in the last two columns of Table 1 and for the experiments in Section 7.3, see https://github.com/GuyLor/gumbel_max_causal_gadgets_part2.

Contents:

  • tutorial.ipynb gives an overview of the two gadgets and shows how they can be trained.
  • paper_experiments.ipynb contains code to reproduce the results of the experiments.
  • coupling_util.py contains various helper functions for constructing couplings, including the baselines we compare against.
  • gadget_1.py contains an implementation of our first gadget.
  • gadget_2.py contains an implementation of our second gadget.
  • experiment_util.py contains the helper functions and classes we used to train our couplings and run our experiments.
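
For readers new to the idea, the standard Gumbel-max trick named in the paper title can be sketched in a few lines of JAX. This is only an illustrative sketch, not code from this repository: the function names `gumbel_max_sample` and `gumbel_max_coupling_samples` are hypothetical and are not assumed to match anything in coupling_util.py. It draws one set of Gumbel noise and reuses it for two categorical distributions, which couples the two samples while leaving each marginal unchanged.

```python
import jax
import jax.numpy as jnp


def gumbel_max_sample(logits, gumbels):
    """Return argmax_i (logits[i] + gumbels[i]), a sample from softmax(logits)."""
    return jnp.argmax(logits + gumbels, axis=-1)


def gumbel_max_coupling_samples(key, logits_p, logits_q, num_samples):
    """Draw paired samples (x, y) that share the same Gumbel noise.

    Hypothetical helper for illustration: x ~ softmax(logits_p) and
    y ~ softmax(logits_q) marginally, but the shared noise makes them
    agree more often than independent draws would.
    """
    num_classes = logits_p.shape[-1]
    # One row of i.i.d. standard Gumbel noise per paired sample.
    gumbels = jax.random.gumbel(key, (num_samples, num_classes))
    x = gumbel_max_sample(logits_p, gumbels)
    y = gumbel_max_sample(logits_q, gumbels)
    return x, y


key = jax.random.PRNGKey(0)
logits_p = jnp.log(jnp.array([0.5, 0.3, 0.2]))
logits_q = jnp.log(jnp.array([0.4, 0.4, 0.2]))

x, y = gumbel_max_coupling_samples(key, logits_p, logits_q, 20000)

# Empirical marginal of x, and how often the coupled samples agree.
freq_p = jnp.bincount(x, length=3) / x.shape[0]
agreement = jnp.mean(x == y)
```

Here `freq_p` should be close to the probabilities behind `logits_p`, and `agreement` estimates how often the two coupled samples coincide. The gadgets in this repository learn generalizations of this mechanism; see tutorial.ipynb for the actual implementations.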

If you use the code in this repository, please cite the following paper:

@inproceedings{lorberbom2021generalized,
author    = {Guy Lorberbom and
             Daniel D. Johnson and
             Chris J. Maddison and
             Daniel Tarlow and
             Tamir Hazan},
title     = {Learning Generalized Gumbel-max Causal Mechanisms},
booktitle = {Advances in Neural Information Processing Systems},
year      = {2021}
}