Introduction

Over the last few years I have been looking for a way to implement Judea Pearl's do-calculus and counterfactual reasoning in Python that is accessible to most non-experts. Causal inference is a term loaded with intimidation: it implies a background in statistics, graphical models and network theory, and implementing it usually means learning a probabilistic programming language as well. As an engineer with only a shallow education in statistics, I have found both causal inference and probabilistic programming extremely intimidating and challenging. I hope that my digging into these topics helps break down these mental barriers for others approaching similar problems from a non-statistical background.

What is a Probabilistic Programming Language (PPL)?

Van de Meent et al. [1], specifically in chapter 1, give a succinct explanation of what a probabilistic programming language is. One simple paradigm that comes out of the description is that a PPL is any programming language that contains a pseudo-random number generator along with the inference methods needed to perform conditional inference. Python, like most modern programming languages, has random number generators; Pyro provides the inference tools that Python lacks out of the box. All of a sudden things don't sound as complicated, right?

Obviously things are not so simple, but thinking of a PPL as a language as simple as Python plus some statistical inference methods is a step in the right direction. Van de Meent et al. [1] go further and describe the two fundamentals of a PPL:

  1. A statistical model: a programmatic function that specifies a joint distribution in the form of the likelihood multiplied by a prior. This means the function samples from the prior and outputs any observed variables.
  2. Inference: the ability to infer any latent variables or model parameters based on observed data, which are the outputs of the model. Put simply, this is a conditioning problem.
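
To make the two fundamentals concrete, here is a minimal sketch in plain Python using only the standard library. It is not Pyro's API and not how Pyro performs inference. The variable names (`rain`, `wet`), the probabilities and the helper `infer_rain_given_wet` are all made up for illustration, and rejection sampling stands in for "inference" simply because it is the easiest conditioning method to read.

```python
import random


def model(rng):
    """Statistical model: sample the latent from its prior, then the observation."""
    rain = rng.random() < 0.2                    # prior p(rain), assumed value
    wet = rng.random() < (0.9 if rain else 0.1)  # likelihood p(wet | rain), assumed values
    return rain, wet


def infer_rain_given_wet(observed_wet, num_samples=100_000, seed=0):
    """Inference: estimate p(rain | wet = observed_wet) by rejection sampling."""
    rng = random.Random(seed)
    # Run the model many times and keep only the runs that agree with the observation.
    accepted = [
        rain
        for rain, wet in (model(rng) for _ in range(num_samples))
        if wet == observed_wet
    ]
    if not accepted:
        raise ValueError("no simulated run matched the observation")
    return sum(accepted) / len(accepted)


posterior = infer_rain_given_wet(observed_wet=True)
print(f"p(rain | wet) is approximately {posterior:.3f}")
```

The estimate should land close to the exact value 0.2 * 0.9 / (0.2 * 0.9 + 0.8 * 0.1), roughly 0.69. Rejection sampling throws away every run that disagrees with the observation, which becomes wasteful very quickly; that inefficiency is a large part of why dedicated inference tooling is worth having.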

It is evident that the two components above encode Bayes' rule: the model is the numerator, representing the joint distribution, and inference is the method of obtaining the posterior. The goal of a PPL is therefore not necessarily to simulate every single combination in the joint distribution (which is what I naively thought they did) but rather to do just enough to solve the conditioning problem that is the posterior distribution: the probability of a (potentially latent) variable conditional on observing a random variable in the model. One can then look at the posterior distribution as a conditional distribution parameterized by the observed variable [1].
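
Written out, the correspondence looks as follows. The symbols are my own notation rather than anything taken from the references: z stands for the latent variables and x for the observed ones.

```latex
\underbrace{p(z \mid x)}_{\text{posterior (inference)}}
  = \frac{\overbrace{p(x \mid z)\, p(z)}^{\text{joint } p(x, z) \text{ (the model)}}}{p(x)},
\qquad
p(x) = \int p(x \mid z)\, p(z)\, \mathrm{d}z
```

The model only ever has to define the numerator. The denominator p(x) is the part that is generally hard to compute, and handling or sidestepping it is what inference methods are for.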

Why Pyro for Causal Inference?

Firstly, Pyro is based on PyTorch, which I personally think is:

  1. Well documented
  2. Syntactically Pythonic

as one will quickly confirm by trying, say, TensorFlow Probability. Secondly, Pyro has from very early on clearly indicated that causal inference is on its radar, specifically the ability to perform counterfactual reasoning using Single World Intervention Graph (SWIG) style node splitting. Splitting a node site retains both the intervened site and the normal stochastic site to be conditioned on, which allows counterfactual queries to be computed. We will get into the details of this in later posts, but for some background, see [3] and [4].

The result is that Pyro lets us use Python as the programming language of choice to compose statistical models as functions and then perform inference on them.

In the following posts we will go through the basics of modelling in Pyro and attempt to understand its internals, ultimately leading up to techniques that allow us to perform causal inference.

Bibliography: