Learning Graph Structure With A Finite-State Automaton Layer

This directory contains the implementation of the Graph Finite-State Automaton (GFSA) layer described in

"Learning Graph Structure With A Finite-State Automaton Layer"

Daniel D. Johnson, Hugo Larochelle, Daniel Tarlow (2020).

Interactive demo notebooks

Want to see the GFSA layer in action? A good starting point is the interactive demo notebook, which shows how to train the GFSA layer to perform a simple static analysis of Python code.

You may also be interested in the new task guide notebook, which describes how to use the GFSA layer for new types of graphs and graph-based MDPs.

Setting up the environment

The code in this repository is written for Python 3.6. We recommend creating a virtual environment and then installing the requirements in requirements.txt. You may also want to configure your JAX installation for GPU support; see the JAX documentation for details.

Structure of this repository

Core implementation of the GFSA solver

  • defines data structures for representing graph MDPs.
  • is responsible for encoding graphs into tensors and computing the GFSA absorbing distribution.
  • implements the RL ablation of the GFSA layer.
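For intuition about what an "absorbing distribution" is, here is a minimal NumPy sketch of the textbook computation for a small absorbing Markov chain: with Q holding transition probabilities among transient states and R holding probabilities of moving from transient states into absorbing states, the absorption probabilities are B = (I - Q)^-1 R. This is not the repository's graph encoding or solver; the function name and the toy matrices are hypothetical.

```python
import numpy as np


def absorbing_distribution(q: np.ndarray, r: np.ndarray) -> np.ndarray:
    """Returns B, where B[i, j] is the probability of ending in absorbing state j
    when starting from transient state i.

    q: [n_transient, n_transient] transition probabilities among transient states.
    r: [n_transient, n_absorbing] transition probabilities into absorbing states.
    """
    n = q.shape[0]
    # Solve (I - Q) B = R instead of forming the inverse explicitly.
    return np.linalg.solve(np.eye(n) - q, r)


# Hypothetical toy chain: 2 transient states, 2 absorbing states.
q = np.array([[0.0, 0.5],
              [0.2, 0.0]])
r = np.array([[0.5, 0.0],
              [0.0, 0.8]])
b = absorbing_distribution(q, r)
print(b)  # each row sums to 1: every walk is eventually absorbed
```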

General utilities

  • contains utilities for working with JAX and Flax.
  • implements the Richardson iterative solver.
  • defines helper functions for working with MDP families with a shared action and observation space.
  • implements a sparse operator abstraction.
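The Richardson solver and the sparse operator abstraction fit together naturally: Richardson iteration only needs matrix-vector products, which a sparse operator can supply without materializing a dense matrix. The sketch below illustrates that idea under stated assumptions: `CooOperator` and `richardson_solve` are hypothetical names using a plain coordinate (COO) format and NumPy, not the repository's actual classes or signatures.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class CooOperator:
    """Hypothetical sparse matrix stored as (row, col, value) triples."""
    rows: np.ndarray    # [nnz] integer row indices
    cols: np.ndarray    # [nnz] integer column indices
    values: np.ndarray  # [nnz] nonzero entries
    shape: tuple

    def apply(self, x: np.ndarray) -> np.ndarray:
        """Computes A @ x using only the stored nonzeros."""
        out = np.zeros(self.shape[0], dtype=np.result_type(self.values, x))
        # Accumulate values[k] * x[cols[k]] into out[rows[k]]; add.at handles repeated rows.
        np.add.at(out, self.rows, self.values * x[self.cols])
        return out


def richardson_solve(apply_a, b, omega=1.0, iterations=100):
    """Approximately solves A x = b with x_{k+1} = x_k + omega * (b - A x_k)."""
    x = np.zeros_like(b)
    for _ in range(iterations):
        x = x + omega * (b - apply_a(x))
    return x


# Solve (I - Q) x = r for a toy Q = [[0, 0.5], [0.2, 0]] stored sparsely.
q = CooOperator(rows=np.array([0, 1]), cols=np.array([1, 0]),
                values=np.array([0.5, 0.2]), shape=(2, 2))
r = np.array([0.5, 0.0])
x = richardson_solve(lambda v: v - q.apply(v), r, omega=1.0, iterations=50)
print(x)
```

With `omega=1.0` and A = I - Q, each step reduces to x <- r + Q x, so the iteration converges whenever the spectral radius of Q is below 1 (here it is about 0.32).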

Working with ASTs

  • defines a transformation from ASTs to MDPs.
  • defines an AST for a simple subset of Python.
  • can be used to construct an MDP family from a dataset of ASTs.
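As a rough illustration of the first step of turning code into a graph, the sketch below uses Python's standard `ast` module to number AST nodes and record parent-to-child edges labeled by field name. It is only an assumption-laden stand-in: the repository defines its own AST for a subset of Python and its own AST-to-MDP transformation (which also involves actions and observations), neither of which is reproduced here. `ast_to_graph` is a hypothetical helper.

```python
import ast


def ast_to_graph(source: str):
    """Returns (labels, edges) for a Python source string.

    labels[i] is the AST class name of node i; each edge is
    (parent_index, child_index, field_name).
    """
    tree = ast.parse(source)
    node_ids = {}
    labels = []
    edges = []

    def node_id(node):
        # Assign integer ids in the order nodes are first encountered.
        if id(node) not in node_ids:
            node_ids[id(node)] = len(labels)
            labels.append(type(node).__name__)
        return node_ids[id(node)]

    for parent in ast.walk(tree):
        parent_id = node_id(parent)
        for field, value in ast.iter_fields(parent):
            children = value if isinstance(value, list) else [value]
            for child in children:
                # Skip non-node fields such as identifier strings and constants.
                if isinstance(child, ast.AST):
                    edges.append((parent_id, node_id(child), field))
    return labels, edges


labels, edges = ast_to_graph("x = y + 1")
print(labels[:4])  # ['Module', 'Assign', 'Name', 'BinOp']
```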

Working with datasets

  • datasets/ defines data structures for working with graphs that are associated with encoded MDPs.
  • datasets/ implements helpers to construct MDPs based on graph edges.
  • datasets/ implements a pure-Python collection of dataset iterators.
  • datasets/ helps determine maximum example sizes that do not throw out too many examples.
  • datasets/mazes defines MDPs and data-generation for the grid-world task.
  • datasets/random_python/ implements a generalized AST generator based on a probabilistic context-free grammar.
  • datasets/random_python/ contains the specific generator used for the static analysis tasks.
  • datasets/var_misuse/ defines data structures for the variable-misuse task.
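Choosing a maximum example size is a trade-off between padding cost and the number of examples that must be discarded. A minimal sketch of one simple rule, assuming the goal is the smallest size that drops at most a given fraction of examples, is shown below. The function name and the quantile-style criterion are hypothetical; the repository's calibration helper may use a different criterion.

```python
import math


def calibrate_max_size(sizes, max_drop_fraction=0.05):
    """Returns the smallest size limit that discards at most
    floor(len(sizes) * max_drop_fraction) examples.

    Examples whose size exceeds the returned limit would be thrown out.
    """
    if not sizes:
        raise ValueError("need at least one example size")
    ordered = sorted(sizes)
    allowed_drops = math.floor(len(ordered) * max_drop_fraction)
    # Keep the smallest len - allowed_drops examples; the largest of them sets the limit.
    return ordered[len(ordered) - allowed_drops - 1]


sizes = [12, 40, 7, 95, 33, 18, 250, 21, 60, 15]
print(calibrate_max_size(sizes, max_drop_fraction=0.2))  # 60: drops 95 and 250
```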

Flax modules

The model subdirectory implements the GFSA layer, other graph architectures, and combined models as Flax modules.

  • model/ contains the GFSA layer itself.
  • model/ contains various graph architecture building blocks.
  • model/ assembles these blocks into models for the Python static analysis tasks.
  • model/ unifies the APIs of the building blocks so that they can be freely composed.
  • model/ contains the implementation of the full models for the variable misuse tasks.
  • model/ and model/ define some helper functions for Flax models.
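One way to make graph building blocks freely composable is to give every block the same signature, for example `(node_embeddings, edges) -> node_embeddings`. The sketch below illustrates that idea with plain NumPy functions instead of Flax modules so that it runs without extra dependencies. The shared signature and all names (`message_passing_block`, `residual`, `compose`) are assumptions for illustration, not the repository's API.

```python
import numpy as np


def message_passing_block(weight):
    """Returns a block that sums incoming neighbor embeddings, then applies a linear map and ReLU."""
    def block(nodes, edges):
        # nodes: [num_nodes, dim]; edges: [num_edges, 2] array of (source, target) pairs.
        messages = np.zeros_like(nodes)
        np.add.at(messages, edges[:, 1], nodes[edges[:, 0]])
        return np.maximum(0.0, (nodes + messages) @ weight)
    return block


def residual(block):
    """Wraps a block so its output is added to its input."""
    def wrapped(nodes, edges):
        return nodes + block(nodes, edges)
    return wrapped


def compose(*blocks):
    """Chains blocks that all share the (nodes, edges) -> nodes signature."""
    def model(nodes, edges):
        for block in blocks:
            nodes = block(nodes, edges)
        return nodes
    return model


rng = np.random.default_rng(0)
dim = 4
model = compose(
    message_passing_block(rng.normal(size=(dim, dim))),
    residual(message_passing_block(rng.normal(size=(dim, dim)))),
)
nodes = rng.normal(size=(3, dim))
edges = np.array([[0, 1], [1, 2], [2, 0]])
out = model(nodes, edges)
print(out.shape)  # (3, 4)
```

Because every block shares one signature, wrappers such as `residual` apply to any block, and models are just lists of blocks.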

Training and evaluation

  • training/configs contains example gin-config configuration files for training a model.
  • training/ is the main entry point for training or evaluating a model on the three tasks described in the paper.
  • training/ and training/ contain training logic shared by the three tasks.
  • training/, training/, and training/ contain the logic for each of the three tasks.
  • training/ implements some simple learning rate schedules.
  • training/ defines a helper function for writing complex gin configurations.