REINFORCE
Improving sequence-to-sequence models with reinforcement learning
Learning to predict word sequences
Language models and sequence-to-sequence models that generate text typically have an output layer that produces a logit for each word in the vocabulary. The logits are normalized using softmax, which gives a probability distribution over the vocabulary. The model is optimized by minimizing cross entropy, which measures how well our model distribution $p_{\theta}(w_t \mid w_1 \ldots w_{t-1})$ fits the empirical distribution in the training data:
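$$ -\sum_{t=1}^{T} \log p_\theta(w_t \mid w_1 \ldots w_{t-1}) $$

where $w_1 \ldots w_T$ is a training sequence of length $T$.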
In language modeling and sequence generation tasks, this objective is usually used during training, with $w_1 \ldots w_{t-1}$ representing the ground-truth output sequence. It is fast to compute and works well for language modeling, where we have a huge corpus of sentences from the output distribution. It is not a very good measure of model performance in most sequence-to-sequence tasks, however, where many different outputs can be correct for a given input, but we observe only one example in the training data. For the same reason, machine translation models are not evaluated by the probability they assign to the reference sequence. Instead, metrics such as BLEU and ROUGE are typically used, which compare n-gram statistics of the most likely word sequence generated by the model to those of the reference sequence. Clearly, using cross entropy for training is less than optimal when we evaluate the model using another metric at test time.
There is also another problem with using cross entropy to train models that are intended to generate word sequences. During inference the model generates a sequence from the model distribution, which encompasses all possible word sequences. During training, however, the reference sequence (offset by one word) is fed into the model, and the model computes only the next-word probabilities. What the model would generate on its own thus deviates more and more from the reference sequence at each time step. This second problem was named exposure bias by Ranzato et al.
Formulation as a decision making problem
Metrics such as BLEU and ROUGE are not differentiable, so we cannot simply compute one of them on generated word sequences and use it as the training objective. It is possible, however, to approach a sequence-to-sequence task using reinforcement learning, using the metric to reward the network for the sequences it generates.
The idea is to formulate the problem as a decision making problem in the following way. An agent observes the state of the environment, which includes the words generated so far and the input features. Based on the current state, the agent repeatedly takes an action: generating the next word in the output sequence. The model is seen as a policy $p_\theta$ that dictates the next action.
The REINFORCE method is episodic. One episode ends when the agent generates the end-of-sequence token at time $T$. Generally speaking, the agent receives a reward $r_t$ after performing an action at time $t$. The return, or cumulative reward, from time $t$ onwards, is the sum of the rewards:
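$$ G_t = \sum_{s=t}^{T} r_s $$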
The value of a state is the expected cumulative reward obtained by following policy $p_\theta$. Usually, when the task is to generate word sequences, we can only observe the cumulative reward $G_1 = R(W)$, for example the ROUGE score, after generating the entire sequence $W$.
REINFORCE objective and its gradient
REINFORCE is a policy-gradient method, solving the problem using stochastic gradient descent. This is possible when the parameters of the policy, $\theta$, are continuous. The objective function is the value at the beginning of the sequence:
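$$ J(\theta) = \mathbb{E}_{W \sim p_\theta}\!\left[ R(W) \right] = \sum_{W} p_\theta(W)\, R(W) $$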
The summation over word sequences makes direct computation of the objective, as well as its gradient, infeasible, but both can be approximated by sampling. The objective could be approximated by sampling a sequence and computing the cumulative reward. However, for training a model we don't actually need to approximate the objective function itself, only its gradient. Stochastic gradient descent only requires that the expectation of the sampled gradients is proportional to the actual gradient (section 13.3 in Sutton and Barto). Let's start by writing the gradient as an expectation over word sequences:
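$$ \nabla J(\theta) = \sum_{W} R(W)\, \nabla p_\theta(W) = \sum_{W} p_\theta(W)\, R(W)\, \nabla \log p_\theta(W) = \mathbb{E}_{W \sim p_\theta}\!\left[ R(W)\, \nabla \log p_\theta(W) \right] $$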
where we have used $\frac{\nabla x}{x} = \nabla \log x$.
This brings us to the REINFORCE algorithm, which is essentially an approximation of the gradient using a single sample $W$:
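$$ \nabla J(\theta) \approx R(W)\, \nabla \log p_\theta(W) = R(W) \sum_{t=1}^{T} \nabla \log p_\theta(w_t \mid w_1 \ldots w_{t-1}), \qquad W \sim p_\theta $$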
This quantity can be used as a sample of the gradient, since its expectation is equal to the gradient of the objective function. Implementation is quite easy with a library that supports automatic differentiation. One can simply take the gradient of $R(W) \log p_{\theta}(W)$ instead of the gradient of the actual objective.
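For example, with PyTorch, a surrogate loss whose gradient matches the REINFORCE estimate might look roughly like the sketch below (the function name, the tensor shapes and the assumption of a single unbatched sequence are illustrative choices, not part of the method itself):

```python
import torch

def reinforce_loss(logits, sampled_ids, reward):
    """Surrogate loss whose gradient is the REINFORCE gradient sample.

    logits:      (seq_len, vocab_size) decoder outputs for the sampled sequence
    sampled_ids: (seq_len,) word indices that were sampled from the model
    reward:      scalar R(W), e.g. a ROUGE score; treated as a constant
    """
    log_probs = torch.log_softmax(logits, dim=-1)
    # log p_theta(W) = sum_t log p_theta(w_t | w_1 ... w_{t-1})
    seq_log_prob = log_probs.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1).sum()
    # Minimizing -R(W) log p_theta(W) performs gradient ascent on the objective.
    return -reward * seq_log_prob
```

Calling `backward()` on this loss produces exactly the sample $R(W)\, \nabla \log p_\theta(W)$ (with a minus sign for minimization), because the reward does not depend on $\theta$.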
Writing a differentiation operator for backpropagation is not too difficult either. Let's say the input to the softmax at time $t$ is $o_t$. There is a simple expression for the partial derivatives of the cross entropy of a softmax output with respect to the softmax input, assuming the reference output is a one-hot vector. We use $1(w_t)$ to denote a one-hot vector where the value corresponding to the word $w_t$ is one and the other values are zero. Then the following gives an expression for the gradient with respect to the softmax input:
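$$ \nabla_{o_t} \bigl[ -R(W)\, \log p_\theta(W) \bigr] = R(W) \bigl( \operatorname{softmax}(o_t) - 1(w_t) \bigr) $$

i.e. the familiar softmax cross-entropy gradient scaled by the reward, here written for the loss $-R(W) \log p_\theta(W)$ that is minimized.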
REINFORCE with baseline
While in theory it is enough that the expectation of the gradient sample is proportional to the actual gradient, having the training converge in a reasonable time is another matter entirely. A good estimate of the gradient generally has a low variance (variance measures how spread out the estimates are around the mean), meaning that the parameter updates have a low variance as well. The parameter update in REINFORCE is based on a single random output sequence sampled from the action space. It is easy to see that the longer the output sequences are, the less likely it is to obtain a sequence that results in an accurate estimate. In fact, the variance of the gradient estimate grows cubically with the sequence length (section 3 in Peters and Schaal).
We start by rewriting the gradient of the objective, taking into account that the cumulative reward is accumulated from the rewards $r_t$ of individual time steps:
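$$
\begin{aligned}
\nabla J(\theta)
&= \mathbb{E}_{W \sim p_\theta}\!\left[ R(W)\, \nabla \log p_\theta(W) \right] \\
&= \mathbb{E}_{W \sim p_\theta}\!\left[ \sum_{t=1}^{T} r_t \sum_{s=1}^{T} \nabla \log p_\theta(w_s \mid w_1 \ldots w_{s-1}) \right] \\
&= \mathbb{E}_{W \sim p_\theta}\!\left[ \sum_{t=1}^{T} r_t \sum_{s=1}^{t} \nabla \log p_\theta(w_s \mid w_1 \ldots w_{s-1}) \right] \\
&= \mathbb{E}_{W \sim p_\theta}\!\left[ \sum_{t=1}^{T} \nabla \log p_\theta(w_t \mid w_1 \ldots w_{t-1})\, G_t \right]
\end{aligned}
$$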
Zaremba and Sutskever show in Appendix A that the third equation above holds because actions cannot influence past rewards. The fourth equation was obtained by reordering the sums:
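$$ \sum_{t=1}^{T} r_t \sum_{s=1}^{t} \nabla \log p_\theta(w_s \mid w_1 \ldots w_{s-1}) = \sum_{t=1}^{T} \nabla \log p_\theta(w_t \mid w_1 \ldots w_{t-1}) \sum_{s=t}^{T} r_s = \sum_{t=1}^{T} \nabla \log p_\theta(w_t \mid w_1 \ldots w_{t-1})\, G_t $$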
In some states every action has a higher value than in other states. This makes no difference with regard to the gradient: changing the value of all actions in a particular state by the same amount leaves the gradient unchanged. In other words, we can subtract a quantity $b_t$ from the reward or cumulative reward of all the possible words $w_t$ that follow a given partial output sequence $w_1 \ldots w_{t-1}$, without changing the gradient:
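$$ \nabla J(\theta) = \mathbb{E}_{W \sim p_\theta}\!\left[ \sum_{t=1}^{T} \nabla \log p_\theta(w_t \mid w_1 \ldots w_{t-1})\, (G_t - b_t) \right] $$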
The function $b_t$ is called a baseline. It can be an arbitrary function of the state, as long as it does not depend on the next action (i.e. it is constant with respect to $w_t$). This can be shown formally by taking $b_t$ outside of the expectation. It is then multiplied by the following term, which is zero, so the subtracted quantity does not change the gradient (Rennie et al):
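$$ \mathbb{E}_{w_t \sim p_\theta}\!\left[ \nabla \log p_\theta(w_t \mid w_1 \ldots w_{t-1}) \right] = \sum_{w_t} p_\theta(w_t \mid w_1 \ldots w_{t-1})\, \frac{\nabla p_\theta(w_t \mid w_1 \ldots w_{t-1})}{p_\theta(w_t \mid w_1 \ldots w_{t-1})} = \nabla \sum_{w_t} p_\theta(w_t \mid w_1 \ldots w_{t-1}) = \nabla 1 = 0 $$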
where we have used $\nabla \log f(x) = \frac{\nabla f(x)}{f(x)}$.
The variance of the gradient estimates can be reduced by using a baseline that is higher for states that generally receive higher rewards. How to come up with such a baseline is not trivial. Some proposed approaches are listed below.
- Paulus et al: The baseline is the reward observed for a sequence that is generated by greedy decoding (a rough sketch of this approach follows the list).
- Rennie et al: After generating the sequence until time step $t$ by sampling, the rest of the sequence is generated by greedy decoding. The reward observed for the combined sequence is used as the baseline at time step $t$.
- Keneshloo et al: More than one sampled sequence is used for estimating the gradient. The average reward of the sampled sequences is used as the baseline.
- Zaremba and Sutskever: An LSTM that runs over the same input as the model is used to predict $G_t$ at time step $t$.
- Ranzato et al: A linear regressor that takes as input the hidden states of the model is used to predict $r_t$ at time step $t$.
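As an illustration of the first approach, a sketch of REINFORCE with a greedy-decoding baseline could look like this (again, the function and variable names are only illustrative, not taken from the papers above):

```python
import torch

def reinforce_with_baseline_loss(logits, sampled_ids, sample_reward, greedy_reward):
    """Surrogate REINFORCE loss with a greedy-decoding baseline.

    logits:        (seq_len, vocab_size) decoder outputs for the sampled sequence
    sampled_ids:   (seq_len,) word indices sampled from the model
    sample_reward: reward of the sampled sequence, R(W)
    greedy_reward: reward of the greedily decoded sequence, used as the baseline b
    """
    log_probs = torch.log_softmax(logits, dim=-1)
    seq_log_prob = log_probs.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1).sum()
    # Only log p_theta(W) is differentiated; both rewards are constants.
    advantage = sample_reward - greedy_reward
    return -advantage * seq_log_prob
```

If the sampled sequence receives a higher reward than the greedy sequence, its probability is increased; otherwise it is decreased.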