Thoughts and Theory

Counterfactuals for Reinforcement Learning I: “What if… ?”

Introduction to the POMDP framework and counterfactuals

Felix Hofstätter
Towards Data Science
8 min read · Dec 30, 2021


In philosophy, a counterfactual thought experiment asks: “What would have happened if A had happened instead of B?” Gaining insights into the real world from such hypothetical considerations is an important aspect of human intelligence.

This two-part series will explore how counterfactual thinking can be modeled within the Reinforcement Learning framework and how this can help with creating safer agents. In this first part, I will lay out the mathematical foundations, and in the next part, I will explain how counterfactuals can be applied to the setting of reward function learning. The series is self-contained; the only prerequisite is an understanding of Reinforcement Learning, and of Markov Decision Processes in particular. If you want to learn more about what motivates research into counterfactuals, you can check out my previous article about the problems of learning reward functions for Reinforcement Learning.

Partially Observable Markov Decision Processes

As we know, a Markov Decision Process (MDP) consists of:

  • S: the set of states
  • A: the set of actions
  • T: the transition function which determines the probability of the next state given a state and an action
  • R: the reward function which assigns a score to histories of states and actions

Time in MDPs is usually measured in discrete timesteps, with an agent’s action a(t) at timestep t following the observed state s(t) at that same timestep. Using R, an agent is rewarded for sequences of states and actions s(1)a(1)s(2)a(2)…, and the goal is usually to maximise the expected reward. For this purpose, we look for a policy: a function from a state to a distribution over actions, which determines the agent’s behaviour.
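As a rough illustration of these ingredients, here is a minimal Python sketch of a finite MDP and a single rollout under a stochastic policy. The class `MDP` and the helpers `sample` and `rollout` are hypothetical names chosen for this example, and transitions are stored as plain dictionaries; this is not taken from any particular RL library.

```python
import random
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

State = str
Action = str


@dataclass
class MDP:
    states: List[State]
    actions: List[Action]
    # T[(s, a)] maps each next state s' to P(s' | s, a)
    T: Dict[Tuple[State, Action], Dict[State, float]]
    # R scores a whole history of (state, action) pairs
    R: Callable[[List[Tuple[State, Action]]], float]


def sample(dist: Dict[str, float], rng: random.Random) -> str:
    """Draw one key from a {outcome: probability} dictionary."""
    outcomes = list(dist)
    return rng.choices(outcomes, weights=[dist[o] for o in outcomes])[0]


def rollout(mdp: MDP, policy: Callable[[State], Dict[Action, float]],
            s0: State, steps: int, seed: int = 0) -> float:
    """Follow the policy for a fixed number of steps and return R of the history."""
    rng = random.Random(seed)
    history: List[Tuple[State, Action]] = []
    s = s0
    for _ in range(steps):
        a = sample(policy(s), rng)  # policy: state -> distribution over actions
        history.append((s, a))
        s = sample(mdp.T[(s, a)], rng)  # environment samples the next state via T
    return mdp.R(history)
```

For instance, a two-state MDP whose only action always switches states, with R counting the visits to state "b", yields a reward of 2.0 over four steps when started in state "a".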

MDPs are a versatile model that lies at the heart of Reinforcement Learning. However, sometimes it is convenient to use a more elaborate model. For example, an agent’s sensors might not be able to distinguish some states with certainty. You could model this as an MDP whose states correspond to beliefs about the true state, which is how an implementation in a computer program might do it. Mathematically, however, it is easier to express this uncertainty using a so-called Partially Observable Markov Decision Process (POMDP). Given an MDP (S, A, T, R), a POMDP adds two elements:

  • Ω: the set of possible observations
  • O: the state-observation function which gives a probability distribution over observations given the environment’s state.

In a POMDP the agent doesn’t see the true state anymore. Instead, the agent’s input is an observation that is probabilistically determined from the state using O.
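The belief-state view mentioned above can be made concrete with a standard Bayes filter: after taking action a and seeing observation o, the belief over states is first pushed through T and then reweighted by O. The sketch below assumes finite state and observation sets stored as dictionaries; `belief_update` is a hypothetical helper, not part of any library.

```python
from typing import Dict, Tuple

Belief = Dict[str, float]


def belief_update(belief: Belief, action: str, obs: str,
                  T: Dict[Tuple[str, str], Dict[str, float]],
                  O: Dict[str, Dict[str, float]]) -> Belief:
    """Bayes filter: b'(s') is proportional to O(obs | s') * sum_s T(s' | s, action) * b(s)."""
    # Prediction step: push the current belief through the transition function.
    predicted: Belief = {}
    for s, p in belief.items():
        for s_next, q in T[(s, action)].items():
            predicted[s_next] = predicted.get(s_next, 0.0) + p * q
    # Correction step: weight each predicted state by how likely it makes obs.
    unnormalised = {s: p * O[s].get(obs, 0.0) for s, p in predicted.items()}
    total = sum(unnormalised.values())
    if total == 0.0:
        raise ValueError("observation has zero probability under this belief")
    return {s: p / total for s, p in unnormalised.items()}
```

With a sensor that reports the correct state 80% of the time and a uniform prior, a single observation moves the belief to 0.8 for the state it points to.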

To talk about counterfactuals later, I also need to define histories. A history h on a POMDP is a sequence of observations and actions h = o(1)a(1)o(2)a(2)…, where o(n) and a(n) represent the observation and action at timestep n of the POMDP. Since it will always be clear which POMDP I am talking about, I will use notation like P(h) to denote the probability of seeing history h in the given POMDP. P(s(n) = s) will be used to talk about the probability that the state at timestep n is the specific state s. Similarly, P(a(n) = a) is the probability that the agent takes action a at timestep n.

Let us consider an example of how to model a problem as a POMDP. In our scenario, the agent participates in a game show and gets to choose one of two doors to win a prize. Behind one of the doors is a cash prize, but if the agent chooses wrongly, it has to go home empty-handed. The underlying environment of this model is simple: there are two potential initial states, denoted s(l) and s(r), which correspond to the prize being behind the left or right door respectively. The available actions are to open either door: a(l) for the left and a(r) for the right. Depending on the initial state, picking an action leads to one of two final states: winning the cash prize, s($), or getting nothing at all, s(N). What makes the problem tricky is that the agent does not know in which of the initial states it starts. Both initial states correspond to the same observation o⁰, while the final states result in distinct observations: o¹ for s($) and o² for s(N).
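One possible encoding of the game show as plain Python dictionaries is sketched below. The identifiers (`s_l`, `s_r`, `s_win`, `s_none`, `a_l`, `a_r`, `o0`, `o1`, `o2`) are hypothetical stand-ins for s(l), s(r), s($), s(N), a(l), a(r), o⁰, o¹ and o², and the helper `states_consistent_with` only illustrates why the agent cannot tell the two initial states apart.

```python
# Hypothetical encoding of the game-show POMDP; all names are illustrative.
S = ["s_l", "s_r", "s_win", "s_none"]  # s(l), s(r), s($), s(N)
A = ["a_l", "a_r"]                      # open the left / right door
OMEGA = ["o0", "o1", "o2"]

# T[(s, a)] -> {s': P(s' | s, a)}: opening the prize door wins, the other loses.
T = {
    ("s_l", "a_l"): {"s_win": 1.0},
    ("s_l", "a_r"): {"s_none": 1.0},
    ("s_r", "a_r"): {"s_win": 1.0},
    ("s_r", "a_l"): {"s_none": 1.0},
}

# O[s] -> {o: P(o | s)}: both initial states emit the same observation o0.
O = {
    "s_l": {"o0": 1.0},
    "s_r": {"o0": 1.0},
    "s_win": {"o1": 1.0},
    "s_none": {"o2": 1.0},
}


def states_consistent_with(obs: str) -> list:
    """All states that can emit obs; more than one means the agent cannot tell them apart."""
    return [s for s in S if O[s].get(obs, 0.0) > 0.0]
```

Here `states_consistent_with("o0")` returns `["s_l", "s_r"]`, while each final observation identifies its state uniquely.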

The POMDP of the gameshow problem. Image by author, inspired by [2]

There are four possible histories on this POMDP: o⁰a(r)o¹, o⁰a(r)o², o⁰a(l)o¹, and o⁰a(l)o². To maximise its chance of winning, the agent wants a policy that maximises the probability of a history ending with observation o¹. As the first observation is always o⁰, the only variable is the action a the agent takes in the first timestep, so we are looking at the probability P(o⁰ao¹). However, whether picking a door wins the prize depends on the hidden initial state s(1): P(o⁰a(r)o¹) = P(s(1) = s(r))P(a(1) = a(r)) and P(o⁰a(l)o¹) = P(s(1) = s(l))P(a(1) = a(l)). In less mathsy terms, the probability of winning by picking the right door is the probability that the prize is behind the right door and that the agent opens the right door, and similarly for the left door. If we assume that the gameshow runners select the winning door at random, then P(s(1) = s(r)) = P(s(1) = s(l)) = 0.5, so the total probability of winning is 0.5·P(a(1) = a(r)) + 0.5·P(a(1) = a(l)) = 0.5, and no choice of probabilities for a(1) lets the agent do better than chance.
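The calculation above can be checked numerically. This sketch assumes, as in the text, that the prize door is chosen uniformly at random and that the policy is simply a distribution over the first action; `history_prob` and `win_prob` are hypothetical helpers that sum over the hidden initial state s(1).

```python
from typing import Dict

# Assumption (from the text): the prize door is chosen uniformly at random.
PRIOR = {"s_l": 0.5, "s_r": 0.5}
# Opening the door that hides the prize leads to s($), observed as o1; otherwise s(N), observed as o2.
WINNING_ACTION = {"s_l": "a_l", "s_r": "a_r"}


def history_prob(action: str, final_obs: str, policy: Dict[str, float]) -> float:
    """P(o0 action final_obs): sum over the hidden initial state s(1)."""
    total = 0.0
    for s, p_s in PRIOR.items():
        produced = "o1" if WINNING_ACTION[s] == action else "o2"
        if produced == final_obs:
            total += p_s * policy[action]
    return total


def win_prob(policy: Dict[str, float]) -> float:
    """Probability of any history that ends with the winning observation o1."""
    return history_prob("a_l", "o1", policy) + history_prob("a_r", "o1", policy)
```

Whatever probabilities the policy assigns to `a_l` and `a_r`, `win_prob` returns 0.5 up to floating-point rounding, which is exactly the conundrum described above.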

This conundrum is where counterfactuals come into play. If we assume that picking the left door would result in losing, then it is trivial to infer the true nature of state s(1) and that we need to take the right door to win. How do we encode this mathematically?

Counterfactual Probabilities on POMDPs

To properly explain counterfactuals on POMDPs, I first need to explain some underlying assumptions. Colloquially, counterfactual reasoning may be expressed with a statement like “If it were the case that X, then we can infer Y”. In our example, this becomes “If it were the case that picking the left door leads to nothing, then we can infer that picking the right door leads to the prize”. In terms of POMDPs, X and Y are statements about histories: “If it were the case that the counterfactual history h’ had happened, then we can infer that history h has this probability”. However, without further assumptions, it is not possible to make inferences from h’ about h. In our example, the counterfactual history would be h’ = o⁰a(l)o². Using “common sense”, knowing how h’ would play out lets us infer that P(o⁰a(r)o¹) = P(a(1) = a(r)), so the agent should set P(a(1) = a(r)) = 1. Mathematically, however, the observation o⁰ at timestep 1 may indicate either s(1) = s(r) or s(1) = s(l), so without a link between h’ and h, P(o⁰a(r)o¹) still comes down to the probability distribution of s(1).

The mathematical result differs from our “common sense” expectation because we haven’t yet encoded the underlying assumption that all other facts about the world remain the same except for the actions we change. In particular, this includes the state of the world up to the point where we change things. Hence, we need to assume that h and h’ share a common beginning up to some timestep t: they contain the same observations and actions up to and including the observation o(t), and the first difference can be the action a(t), as in our example where t = 1. We will use h(t) to denote this shared history o(1)a(1)…o(t). Further, we need to assume that the “true state of the world” was the same for both until t. In other words, if s(n) is the state that produced the nth observation in h, and similarly for s’(n) and h’, then s(n) = s’(n) for n ≤ t.
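To make the shared prefix h(t) concrete, the following sketch represents a history as a flat list that alternates observations and actions, starting with o(1). It assumes that h(t) ends with the observation o(t), so that the first difference can be the action a(t), as in the game-show example. `divergence_point` is a hypothetical helper.

```python
from typing import List, Tuple

History = List[str]  # [o(1), a(1), o(2), a(2), ...]: observations sit at even indices


def divergence_point(h: History, h_cf: History) -> Tuple[int, History]:
    """Return (t, h(t)): the last timestep whose observation both histories share,
    together with the shared prefix o(1) a(1) ... o(t)."""
    k = 0  # length of the longest common prefix
    while k < min(len(h), len(h_cf)) and h[k] == h_cf[k]:
        k += 1
    t = (k + 1) // 2  # number of observations inside the common prefix
    if t == 0:
        raise ValueError("histories must share at least the first observation")
    return t, h[: 2 * t - 1]
```

For the game-show histories o⁰a(r)o¹ and o⁰a(l)o², it returns t = 1 and h(1) = [o⁰].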

On a given POMDP, we can now reason about P(h | h’, t): the probability of seeing h given the counterfactual history h’ which diverges from h at t. To do so, we need two more probabilities. Firstly, the probability that the true state is s at the point of divergence when we have seen the counterfactual history: P(s(t) = s | h’). Secondly, the probability that, after seeing history h(t) when the true state at timestep t is s, we will eventually see history h: P(h | h(t), s(t) = s). The relation between these three probabilities is as follows:

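Formula for the probability of h given a counterfactual history h’ which diverges at t:

$$P(h \mid h', t) = \sum_{s \in S} P\big(h \mid h(t),\, s(t) = s\big)\, P\big(s(t) = s \mid h'\big)$$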

For every possible state s, the probability of getting history h if s is the true state at the point of divergence t is multiplied by the probability that s is the true state given the counterfactual history h’, and these products are summed over all states.
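The formula translates almost directly into code. The sketch below does this for the game show with t = 1, under two stated assumptions: the prize door is drawn uniformly at random, and the action in the counterfactual history h’ is treated as fixed rather than drawn from the policy, so the posterior P(s(1) = s | h’) only uses the prior and the observed outcome. All function names are hypothetical.

```python
from typing import Dict, Tuple

History = Tuple[str, str, str]  # (o(1), a(1), o(2)) for the one-step game show

PRIOR = {"s_l": 0.5, "s_r": 0.5}  # assumption: prize door chosen uniformly at random
WINNING_ACTION = {"s_l": "a_l", "s_r": "a_r"}  # the door that hides the prize


def outcome_prob(h: History, s: str) -> float:
    """P(final observation of h | s(1) = s, action of h); deterministic in this game."""
    _, action, final_obs = h
    produced = "o1" if WINNING_ACTION[s] == action else "o2"
    return 1.0 if produced == final_obs else 0.0


def continuation_prob(h: History, s: str, policy: Dict[str, float]) -> float:
    """P(h | h(1), s(1) = s): the agent picks h's action, then the outcome must match."""
    return policy[h[1]] * outcome_prob(h, s)


def posterior_over_state(h_cf: History) -> Dict[str, float]:
    """P(s(1) = s | h') via Bayes' rule, with the counterfactual action held fixed."""
    joint = {s: PRIOR[s] * outcome_prob(h_cf, s) for s in PRIOR}
    total = sum(joint.values())
    if total == 0.0:
        raise ValueError("the counterfactual history is impossible under the prior")
    return {s: p / total for s, p in joint.items()}


def counterfactual_prob(h: History, h_cf: History, policy: Dict[str, float]) -> float:
    """P(h | h', 1) = sum_s P(h | h(1), s(1) = s) * P(s(1) = s | h').

    Only the initial states are summed over: s($) and s(N) cannot occur at t = 1.
    """
    post = posterior_over_state(h_cf)
    return sum(continuation_prob(h, s, policy) * post[s] for s in PRIOR)
```

With h = ("o0", "a_r", "o1") and h’ = ("o0", "a_l", "o2"), the posterior puts all its mass on s_r, so `counterfactual_prob` equals the policy’s probability of `a_r`, e.g. 0.8 for {"a_l": 0.2, "a_r": 0.8}. Holding the counterfactual action fixed also keeps the posterior well defined for a deterministic policy that never opens the left door.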

With this formula, it is easy to obtain the “common sense” result about P(o⁰a(r)o¹). Here h = o⁰a(r)o¹ and h’ = o⁰a(l)o², and t = 1 since only the very first observation is shared by h and h’.
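Written out with h(1) = o⁰, the sum has one term per possible initial state; the final states s($) and s(N) cannot occur at timestep 1, so their terms vanish. Using s_r, s_l, a_r, a_l as shorthand for s(r), s(l), a(r), a(l), a sketch of the derivation reads:

```latex
\begin{aligned}
P(o^0 a_r o^1 \mid o^0 a_l o^2, 1)
  &= \sum_{s} P\big(o^0 a_r o^1 \mid o^0, s(1) = s\big)\, P\big(s(1) = s \mid o^0 a_l o^2\big) \\
  &= P\big(o^0 a_r o^1 \mid o^0, s(1) = s_r\big)\, P\big(s(1) = s_r \mid o^0 a_l o^2\big) \\
  &\quad + P\big(o^0 a_r o^1 \mid o^0, s(1) = s_l\big)\, P\big(s(1) = s_l \mid o^0 a_l o^2\big) \\
  &= P\big(a(1) = a_r\big) \cdot 1 + 0 \cdot 0 = P\big(a(1) = a_r\big)
\end{aligned}
```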

The left term in the second row is the probability of the agent opening the right door: when the true initial state is s(r), the choice of door deterministically leads to o¹ or o². The right term simplifies to 1 since, given the history where picking the left door led to not winning anything, the initial state must have been s(r). In the third row, both terms simplify to 0 because they describe the probabilities of contradictory events: when the money is behind the left door in the initial state, opening the right door cannot lead to the winning observation o¹. Conversely, the prize cannot have been behind the left door in the initial state when opening the left door led to losing. Altogether, P(o⁰a(r)o¹ | o⁰a(l)o², 1) = P(a(1) = a(r)), which the agent maximises by always opening the right door.

Conclusions and Applications

We have seen how the idea of a counterfactual history h’ can help us make inferences about the probability of a desired history h, which in turn can help us decide on an agent’s policy. However, we have only considered a trivial example in which the optimal policy can be inferred by inspection. A real application of counterfactuals on POMDPs can be found in the field of Reward Learning [1][3]. While it is inspired by the math presented in this article, it changes some of the details and adds new elements. So stay tuned for the second part of this series, in which I will explore the application of counterfactuals to Reward Learning!

[1] Armstrong et al., Pitfalls of learning a reward function online, arXiv, 28 April 2020, https://arxiv.org/abs/2004.13654

[2] Stuart Armstrong, Counterfactuals on POMDP, AlignmentForum, 2 June 2017, https://www.alignmentforum.org/posts/5bd75cc58225bf06703752a9/counterfactuals-on-pomdp

[3] Everitt et al., Reward Tampering Problems and Solutions in Reinforcement Learning: A Causal Influence Diagram Perspective, arXiv, 29 March 2021, https://arxiv.org/abs/1908.04734


Software Consultant at TNG Technology Consulting. Passionate about Reinforcement Learning, AI Alignment and Effective Altruism