Top-K Off-Policy Correction for a REINFORCE Recommender System

Mikhail Scherbina
Towards Data Science
13 min read · Nov 28, 2019


Make your newsfeed propaganda-free again with recnn! Photo by freestocks.org on Unsplash

OffKTopPolicy is now available out of the box, with no prerequisites, in my Reinforced Recommendation Library!

The code is also available online in Colab, with TensorBoard visualization:

The original paper, ‘Top-K Off-Policy Correction for a REINFORCE Recommender System’ by Minmin Chen et al.:

A couple of words before we start: I am the creator of recnn, a reinforced recommendation toolkit built around PyTorch. The dataset I use is ML20M. Rewards are in [-5, 5], the state space is continuous, and the action space is discrete.

Also, I am not Google staff, and unlike the paper’s authors, I cannot get online feedback on the recommendations. I use an actor-critic setup for reward assignment: in a real-world scenario rewards would come from interactive user feedback, but here a neural network (the critic) aims to emulate it.

Understanding REINFORCE

An important thing to understand: the paper describes a continuous state space with discrete actions.

For each user, we consider a sequence of user historical interactions with the system, recording the actions taken by the recommender, i.e., videos recommended, as well as user feedback, such as clicks and watch time. Given such a sequence, we predict the next action to take, i.e., videos to recommend, so that user satisfaction metrics, e.g., indicated by clicks or watch time, improve.

While trying to understand this topic, I found out that I didn’t really know what the algorithm is. Pretty much everyone is familiar with Q-Learning: we have a value function mapping (state, action) -> value. Q-Learning is literally everywhere; countless tutorials have already covered it. Note that the values here are accumulated rewards obtained through temporal difference. REINFORCE is related to Q-Learning, but to see how they differ, you basically need to understand the difference between value iteration and policy iteration:

  1. Policy iteration includes policy evaluation + policy improvement, and the two are repeated iteratively until the policy converges.
  2. Value iteration includes finding the optimal value function + one policy extraction. There is no repetition of the two, because once the value function is optimal, the policy extracted from it should also be optimal (i.e. converged).
  3. Finding the optimal value function can also be seen as a combination of policy improvement (due to the max) and truncated policy evaluation (the reassignment of V(s) after just one sweep of all states, regardless of convergence).
  4. The algorithms for policy evaluation and for finding the optimal value function are highly similar, except for a max operation.
  5. Similarly, the key steps of policy improvement and policy extraction are identical, except that the former involves a stability check.

Note: this is not my explanation; I copied it from here:

Now let’s translate this into Python:

The following snippets are based on the official PyTorch REINFORCE example. Link to the file on GitHub.

code is from official pytorch examples
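In case the embedded code images above do not load, here is a minimal sketch in the spirit of the official PyTorch REINFORCE example (my simplified reconstruction, not the exact file): a softmax policy network, action sampling that stores log-probabilities, and a Monte-Carlo return update.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

class Policy(nn.Module):
    # Maps a state vector to a categorical distribution over actions.
    def __init__(self, state_dim, n_actions, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, state):
        return F.softmax(self.net(state), dim=-1)

def select_action(policy, state, saved_log_probs):
    # Sample an action and remember its log-probability for the update.
    dist = Categorical(policy(state))
    action = dist.sample()
    saved_log_probs.append(dist.log_prob(action))
    return action.item()

def finish_episode(saved_log_probs, rewards, optimizer, gamma=0.99):
    # Discounted returns, normalized for stability (as in the official example).
    returns, R = [], 0.0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    # REINFORCE loss: -log pi(a|s) * return, summed over the episode.
    loss = torch.stack([-lp * R for lp, R in zip(saved_log_probs, returns)]).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    saved_log_probs.clear()
```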

Using recnn’s implementation

This week I released a REINFORCE implementation that works out of the box with my library. Let’s take a look at what it takes to use REINFORCE with recnn:

prepare_dataset is a chain-of-responsibility pattern that transforms your data into a usable state. Note: using it is completely optional; if you don’t want to dig into recnn’s inner workings, just transform your data however you like beforehand.

embed_batch applies embeddings to individual batches. As you recall, the state consists of continuous embeddings, whereas the action space is discrete. Luckily, there is a version available out of the box: batch_contstate_discaction.
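To make the “continuous state, discrete action” idea concrete, here is a rough sketch of what such a batch-embedding step does. This is purely illustrative; the function and field names are hypothetical, and recnn’s actual batch_contstate_discaction lives in the library source.

```python
import torch
import torch.nn as nn

# Illustrative only: the field names ('items', 'ratings', 'action') are made up.
item_embeddings = nn.Embedding(num_embeddings=5000, embedding_dim=128)

def embed_batch(batch):
    # batch['items']:   LongTensor [B, T] of previously watched movie ids
    # batch['ratings']: FloatTensor [B, T] of rewards in [-5, 5]
    # batch['action']:  LongTensor [B] of the next movie id (a discrete action)
    items = item_embeddings(batch['items'])                          # [B, T, 128]
    state = torch.cat([items, batch['ratings'].unsqueeze(-1)], -1)   # continuous state
    state = state.flatten(start_dim=1)                               # [B, T * 129]
    action = batch['action']                                         # kept as an index
    reward = batch['ratings'][:, -1]
    return state, action, reward
```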

Because I do not have a very powerful PC (1080 Ti peasant), I cannot use a large action space. Thus I truncate the actions to the 5000 most frequent movies.

As mentioned above, I cannot get online feedback, so a neural network (the critic) assigns the rewards in place of real user interactions.

Now, let’s define the networks and the algorithm itself.

After 1000 steps we get:

Note: the policy loss is pretty large!

Back to the paper: Off-Policy Correction

The problem is the following: we have multiple other policies. Let’s take DDPG- and TD3-trained actors from my library. Given these policies, we want to learn a new, unbiased one in an off-policy manner. As the authors put it:

Off-Policy Candidate Generation: We apply off-policy correction to learn from logged feedback, collected from an ensemble of prior model policies. We incorporate a learned neural model of the behavior policies to correct data biases.

Do not confuse it with transfer learning:

A naive policy gradient estimator is no longer unbiased as the gradient in Equation (2) requires sampling trajectories from the updated policy πθ while the trajectories we collected were drawn from a combination of historical policies β. We address the distribution mismatch with importance weighting

So, put simply, off-policy correction is importance weighting, needed because the logged recommendations are biased: they were produced by the existing model.

Now let’s consider the formulas:

β is the historical, also called behavior, policy (the previous model)

πθ is the updated policy (the new model)

Going over the tough math listed in the paper, the authors arrive at this fancy formula:

Proposed Reinforce with OffPolicyCorrection

What it says is that the update for action a, given state s, is weighted by the ratio of the updated policy to the historical policy at timestep t. They conclude that the product over the entire episode is not needed and that a first-order approximation is good enough. More about approximations.

P.S. First-order approximation: f(x+h)≈f(x)+f′(x)×h.

Equation (3) shows that this ratio (π/β) is just used as an importance weighting term. If you look at the original REINFORCE update, the only thing that distinguishes the two is the aforementioned importance weighting:

Original Reinforce
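Since the two formulas above are embedded as images, here they are written out (my own transcription of the paper’s notation, where R_t is the discounted future reward at step t, so double-check against the paper):

```latex
% Vanilla REINFORCE gradient: trajectories are sampled from \pi_\theta itself
\nabla_\theta J \approx \sum_{\tau \sim \pi_\theta} \sum_{t}
    R_t \, \nabla_\theta \log \pi_\theta(a_t \mid s_t)

% Off-policy corrected REINFORCE (first-order approximation):
% trajectories come from the behavior policy \beta, so each term is
% reweighted by the importance ratio \pi_\theta / \beta
\sum_{\tau \sim \beta} \sum_{t}
    \frac{\pi_\theta(a_t \mid s_t)}{\beta(a_t \mid s_t)} \,
    R_t \, \nabla_\theta \log \pi_\theta(a_t \mid s_t)
```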

Parametrizing the policy πθ (Network Architecture)

The next section is just a plain description of the model the authors use. It has little to do with off-policy correction.

If you’ve been paying attention to my work, you know that I was mostly doing continuous action space stuff, like DDPG or TD3. This paper focuses on a discrete action space. The thing with a discrete action space is that it grows very big very quickly. To address that growth, the authors use Sampled Softmax, which comes with TensorFlow by default. Sadly, there is no such option in PyTorch. Thus, instead of the Sampled Softmax the authors use, I will be utilizing Noise Contrastive Estimation, a similar method that has a PyTorch library: Stonesjtu/Pytorch-NCE.

Anyway, let’s get to the network architecture:

The authors are using simple Recurrent State Representation, nothing out of the ordinary here:

We model our belief on the user state at each time t, which captures both evolving user interests using an n-dimensional vector. The action taken at each time t along the trajectory is embedded using an m-dimensional vector u. We model the state transition P: S×A×S with a recurrent neural network.

With that new state s, they compute a softmax over the actions:

Here v_a ∈ R^n is another embedding for each action a in the action space A, and T is a temperature that is normally set to 1. The key takeaway is that they use a separate embedding, rather than the one used in the state representation. And a cool thing: we will be learning these embeddings via torch.nn.Embedding.
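A minimal sketch of this part of the architecture, as I understand it (my own reconstruction, not recnn or Google code): input embeddings feed a recurrent state, and a separate output embedding table is used for the temperature-scaled softmax.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PiTheta(nn.Module):
    def __init__(self, n_items, item_dim=64, state_dim=128, temperature=1.0):
        super().__init__()
        self.input_emb = nn.Embedding(n_items, item_dim)           # u: embeds past actions
        self.rnn = nn.GRU(item_dim, state_dim, batch_first=True)   # recurrent state transition
        self.output_emb = nn.Embedding(n_items, state_dim)         # v_a: separate output embeddings
        self.temperature = temperature

    def forward(self, item_history):
        # item_history: LongTensor [batch, seq_len] of previously recommended items
        u = self.input_emb(item_history)
        _, s = self.rnn(u)                                         # s: the user state, [1, batch, state_dim]
        s = s.squeeze(0)
        logits = s @ self.output_emb.weight.t() / self.temperature
        return F.log_softmax(logits, dim=-1)                       # log pi_theta(. | s)
```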

Do not look at the scary arrow going from the ReLU to the left toward β_θ′ with the blocked gradient on it. We will ignore it for now. So, all in all, the network architecture for now looks like this:

Estimating the behavior policy β

Now, back to that scary arrow with the blocked gradient, going from the ReLU to the left toward β_θ′:

Ideally, for each logged feedback of a chosen action we received, we would like to also log the probability of the behavior policy choosing that action. Directly logging the behavior policy is however not feasible in our case as (1) there are multiple agents in our system, many of which we do not have control over, and (2) some agents have a deterministic policy, and setting β to 0 or 1 is not the most effective way to utilize these logged feedback.

They are basically saying that at Google they have a lot of behavior policies to learn from. But in the correction section, they use just one behavior policy β for importance weighting.

Instead, we take the approach first introduced in [39], and estimate the behavior policy β, which in our case is a mixture of the policies of the multiple agents in the system, using the logged actions. Given a set of logged feedback D = {(si , ai),i = 1, · · · , N}, Strehl et al. [39] estimates ˆβ(a) independent of user state by aggregate action frequency throughout the corpus. In contrast, we adopt a context-dependent neural estimator.

The paper cited as [39] dates back to 2010, so the authors decide to use the fancier tools of deep learning. As far as I understand it, they take the logged actions produced by those historical models and train a simple neural model to predict the historical policy β. It reuses the common state representation module, an RNN cell with ReLU, but the gradient flow is blocked, so the state representation only learns from the updated policy π_θ. For PyTorch folks: blocking the gradient = tensor.detach().

Now let’s see what we’ve got so far.

P.S. The first (zeroth) notebook has a bare-bones implementation of REINFORCE with minimal recnn usage. Everything is explained there in detail, so you don’t have to look up the source code.

Anyway, here is how off-policy correction is implemented inside recnn:

difference between select action and select action with correction
difference between reinforce and reinforce with correction
Beta class. Notice: it learns to resemble the historical policy by minimizing cross-entropy against the logged action
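Since those snippets are embedded as images, here is a rough paraphrase of the idea in plain PyTorch (names and shapes are approximate, not recnn’s exact code): the behavior head β learns from a detached state via cross-entropy with the logged action, and the REINFORCE loss is weighted by the ratio π/β.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Beta(nn.Module):
    # Learns to imitate the (mixture of) historical policies from logged actions.
    def __init__(self, state_dim, n_actions):
        super().__init__()
        self.net = nn.Linear(state_dim, n_actions)
        self.optim = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, state, action):
        # state.detach(): the shared state representation only learns from pi_theta
        logits = self.net(state.detach())
        loss = F.cross_entropy(logits, action)         # imitate the logged action
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return F.log_softmax(logits.detach(), dim=-1)  # log beta(. | s), no gradient

def reinforce_with_correction(pi_log_probs, beta_log_probs, action, reward):
    # pi_log_probs, beta_log_probs: [batch, n_actions]; action, reward: [batch]
    pi_lp = pi_log_probs.gather(1, action.unsqueeze(1)).squeeze(1)
    beta_lp = beta_log_probs.gather(1, action.unsqueeze(1)).squeeze(1)
    # pi / beta, detached so it acts as a constant importance weight on grad log pi
    correction = torch.exp(pi_lp - beta_lp).detach()
    return (-correction * reward * pi_lp).mean()
```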

You might have noticed that I didn’t include the secondary item embeddings, Noise Contrastive Estimation, and other things the authors described. But these are not essential to the idea behind the algorithm. You can use a simple PyTorch softmax without fancy log-uniform sampling; the code will remain the same.

Off-policy correction with recnn

The code remains the same: we start by defining the networks and the algorithm.

Remember how the _select_action_with_correction method in the nn.DiscretePolicy class required the action? It is passed here. I know this is bad code; I haven’t figured out a better way to do it. The action is also passed for the Beta update. And don’t forget to choose reinforce_with_correction in the parameters; after all, that was the main purpose!

If you look at the losses, they didn’t change much. Or did they? If you look at the previous loss graph, the loss was measured in the thousands.

Make sure to check ‘ignore outliers’
Correction in action, with outliers not ignored. The start is pretty rough, but then it gets corrected
beta_log_prob
correction can get pretty large

Top-K Off-Policy Correction

This will do for a single recommendation. But keep in mind that we are dealing with a set of them. The authors introduce it as follows:

  1. π_θ is our policy.
  2. A is a set of actions sampled from π_θ.
  3. Π_Θ is a new policy that, instead of a single action, produces a set of K recommendations. (Off-topic: there is a cool thing called gumbel_softmax; it would be cool to use it instead of resampling. See the Gumbel max trick explanation and “The Gumbel-Top-k Trick for Sampling Sequences Without Replacement”, with Max Welling among the co-authors.)
  4. Reward setup and more on Π_Θ
  5. α_θ(a|s) = 1 − (1 − π_θ(a|s))^K is the probability that an item a appears in the final non-repetitive set A. Here K = |A′| > |A| = k. P.S. K is the size of the sampled set with duplicates; k is the size of the same set with duplicates removed. As a result of sampling with replacement and de-duplicating, the size k of the final set A can vary.

Most of these variables are introduced only to help understand the math behind it all. For instance, Π_Θ is never actually used. The formula then simplifies to a very simple expression, but I want you to understand something first.

Top-K has to do with the probability of picking an item from a set. With repetitions, your distribution will be binomial. Here is the StatisticsHowTo page on the binomial distribution. Indeed, take a look at α_θ above. Doesn’t it resemble something?

With that in mind, we arrive at:

P.S. It is not explained how they came up with this formula, but as I said, it has to do with the binomial distribution and the number of ways to select top-K items with repetition. Hands-on combinatorics. That’s it! In the end, we arrive at a very simple formula that can be added to our existing module and update function with ease:

lambda_k = K*(1 - pi_prob)**(K-1)

Implementing in code

difference between top k and not top k
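For completeness, here is a sketch of the same change (a paraphrase, not the library’s exact code): the only difference from the plain correction above is the extra λ_K multiplier.

```python
import torch

def reinforce_with_topk_correction(pi_log_probs, beta_log_probs, action, reward, K=10):
    pi_lp = pi_log_probs.gather(1, action.unsqueeze(1)).squeeze(1)
    beta_lp = beta_log_probs.gather(1, action.unsqueeze(1)).squeeze(1)
    correction = torch.exp(pi_lp - beta_lp).detach()   # pi / beta, as before
    pi_prob = torch.exp(pi_lp).detach()
    lambda_k = K * (1 - pi_prob) ** (K - 1)            # the top-K multiplier
    return (-lambda_k * correction * reward * pi_lp).mean()
```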

Link to the notebook. Nothing is changed, except reinforce_with_correction -> reinforce_with_TopK_correction, and _select_action_with_correction -> _select_action_with_TopK_correction.

losses be looking unusual
Lambda_K stays pretty small, unlike correction

Everything else in the graphs looks exactly the same as in the non-TopK version, and why would it differ? The dataset and the nets didn’t change. Now let’s see exactly why the policy loss behaves that way:

  • As πθ (a|s) → 0, λK (s, a) → K. The top-K off-policy correction increases the policy update by a factor of K comparing to the standard off-policy correction
  • As πθ (a|s) → 1, λK (s, a) → 0. This multiplier zeros out the policy update.
  • As K increases, this multiplier reduces the gradient to zero faster as πθ (a|s) reaches a reasonable range.

In summary, when the needed item has a small probability under the softmax policy π_θ(·|s), the top-K correction pushes its likelihood up more aggressively than the standard correction does. Once the policy π_θ(·|s) assigns the item enough probability (so that it is likely to appear in the top K), the correction then drives the gradient to zero, so it stops learning that item and the other K-1 items can appear in the recommendation.
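A quick numeric check of λ_K = K(1 − π)^(K−1), with K = 16 chosen purely for illustration, makes these bullet points concrete:

```python
K = 16
for pi in (0.001, 0.05, 0.5, 0.99):
    lambda_k = K * (1 - pi) ** (K - 1)
    print(f"pi = {pi:<5}  lambda_K = {lambda_k:.6f}")
# pi = 0.001 -> lambda_K is about 15.76 (close to K: the update is boosted)
# pi = 0.99  -> lambda_K is essentially 0 (the update is zeroed out)
```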

Results comparison

All of the results can be seen and interacted with online in a Google Colab notebook linked at the top.

P.S. To look at the corrected loss, click on the circular radio button instead of the square checkbox in the Runs section on the left.

Policy: orange — NoCorrection, blue — Correction, red — TopK Correction

The correction itself just makes the loss small. Top-K makes the loss resemble the original loss without correction, but this time it is more zero-centered, with no clear trend.

Value Loss: nothing unusual

Things I didn’t do:

The authors also mention variance reduction techniques. In my case the losses were reasonable, so I don’t think I need them. But implementing them shouldn’t be a problem if needed:

  1. Weight Capping: correction = min(correction, 1e4) (see the one-liner after this list)
  2. Pro tip: look at the correction distribution in TensorBoard to figure out a suitably large number for weight capping
  3. Policy Gradient with Importance Sampling: I have found this repo github.com/kimhc6028/policy-gradient-importance-sampling
  4. TRPO: there are plenty of implementations. Choose what you like. Maybe PPO will be a better choice.
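Weight capping in particular is a one-liner in PyTorch; a minimal sketch, assuming correction is the tensor of π/β ratios and 1e4 is whatever your TensorBoard histogram suggests:

```python
import torch

def cap_correction(correction, cap=1e4):
    # Clip the importance weight so a few huge pi/beta ratios cannot dominate the update.
    return torch.clamp(correction, max=cap)
```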

That’s it

Don’t click away just yet. There are a couple of ways you can help RecNN:

  1. Make sure to clap and give a star on GitHub: https://github.com/awarebayes/RecNN
  2. Use with your project/research
  3. There are some features recnn lacks at the moment: sequential environments and easily configurable user embeddings. There is experimental sequential support, but it is not yet stable. Maybe consider contributing.

A Streamlit demo of recnn is coming! You can make it arrive faster by giving the library a star:

There is also a license change coming to recnn. You will not be able to use recnn for production recommendations without encryption and federated learning. User embeddings should also be stored on users’ devices. I believe in free speech and the First Amendment, so recommendation should only be done in a federated (distributed) manner. The next article is about federated learning and recnn’s PySyft (github.com/OpenMined/PySyft) integration!


Deep Learning Researcher at GosNIIAS (Computer Vision, Reinforcement Learning), BMSTU Software Engineering student