We’ve spent a lot of time figuring out how better to estimate probability distributions $p_\theta$ over sequences of text. In this note, we’ll go over strategies for taking this distribution and generating text. These strategies will depend on the kind of text we’re generating, and what errors we think exist in the distribution our model has learned.
We’ve trained our model autoregressively, conditioning on tokens used as input (and those we’ve generated so far) and predicting the next token. We’ll sample from it similarly. Let $x_1,\dots,x_k$ be an input token sequence, and $y_1,\dots,y_{t-1}$ be a (possibly empty) string of output tokens generated so far. For brevity, we will denote these sequences as $x$ and $y_{\lt t}$ respectively. We sample a next token as:
$$ \begin{align} & y_t \sim p_\theta(y_t \mid x, y_{\lt t}) & \text{(sampling, or ancestral sampling)} \end{align} $$The sampling process is a loop of sampling subsequent $y_t$ until a stopping condition is met: either a special end-of-text token is generated, or a maximum sequence length is reached.
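To make the loop concrete, here is a minimal sketch in NumPy. It assumes a hypothetical `next_token_probs(x, y)` function standing in for the model’s forward pass, returning a probability distribution over the vocabulary for the next token; the function name and interface are illustrative, not any particular library’s API.

```python
import numpy as np

def ancestral_sample(next_token_probs, x, eos_id, max_len=100, seed=0):
    """Sample y_1, ..., y_T one token at a time from p_theta(. | x, y_<t)."""
    rng = np.random.default_rng(seed)
    y = []
    while len(y) < max_len:                         # stop at the maximum sequence length
        probs = next_token_probs(x, y)              # p_theta(. | x, y_<t); sums to 1
        y_t = int(rng.choice(len(probs), p=probs))  # sample the next token
        y.append(y_t)
        if y_t == eos_id:                           # stop at the end-of-text token
            break
    return y
```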
Overall, as models become increasingly strong, and as we clean the data they’re trained on, this simple sampling method becomes better and better. Intuitively, we work really hard to estimate the distribution $p_\theta$, and the better it is, the less we need to do afterward to “fix” it. But consider that some text is noisy, our model isn’t perfect, and the model likely spreads some probability mass over output sequences that aren’t what we want. Assuming our model is good, we might instead want to find the most likely output under the model:
$$ \begin{align} &y_{1:T} = \arg\max_{y\in \mathcal{V}^*} p_\theta(y\mid x) & \text{(argmax)} \end{align} $$This is much easier to write out than it is to compute. Whereas during sampling, we run the model’s forward pass once per token we end up generating, finding the argmax of our model is a search through an exponentially large space of size $|\mathcal{V}|^T$ (the base is the vocabulary size, the exponent is the maximum sequence length).1 In practice, people don’t really do this, but it’s a good thought experiment to consider whether you think the maximum-probability output of your model will be meaningfully better than a sample.
Indeed, approximations to finding the argmax are often used in practice. One is greedy decoding, so named as it’s the greedy algorithm for approximating the argmax search. We set:
$$ \begin{align} & y_t =\arg\max_{w\in\mathcal{V}} p_\theta(w \mid x, y_{\lt t}) & \text{(greedy decoding)} \end{align} $$This picks the most likely output word at each step. This is a common generation strategy because, like sampling, it costs one forward pass per token generated, but like the argmax, it avoids low-probability sequences. It’s a nice default to try if you’re playing with a model.
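As a sketch, greedy decoding only changes the sampling step of the loop above (again assuming the hypothetical `next_token_probs` interface):

```python
def greedy_decode(next_token_probs, x, eos_id, max_len=100):
    """Greedy decoding: take the single most likely token at each step."""
    y = []
    while len(y) < max_len:
        probs = next_token_probs(x, y)   # p_theta(. | x, y_<t)
        y_t = int(np.argmax(probs))      # argmax instead of sampling
        y.append(y_t)
        if y_t == eos_id:
            break
    return y
```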
One issue with mode-seeking generation strategies (the mode is the highest-probability element of a distribution, so a mode-seeking strategy is one that attempts to find the mode, e.g., greedy decoding) is that you always get the same response if you generate multiple times.2 Sampling from the trained distribution $p_\theta$ certainly gives you many possible generations (if there’s entropy in the trained distribution) but may generate low-probability, and potentially low-quality, outputs.
The most successful strategies for hitting a nice tradeoff between how much of the distribution we keep (coverage/diversity) and how high-quality our outputs are on average are in a family of algorithms called (by some, including me) truncation sampling. These algorithms have the following form:
$$ \begin{align} & \mathcal{A} = \text{select}(\mathcal{V}, p_\theta, x, y_{\lt t}) \ \ (\text{choose accepted set } \mathcal{A}\subseteq \mathcal{V})\\ & p_{\mathcal{A},\theta}(w\mid x, y_{\lt t}) = \begin{cases} p_\theta(w\mid x, y_{\lt t})/Z & w \in \mathcal{A}\\ 0 & \text{otherwise}\end{cases}\\ &y_t \sim p_{\mathcal{A},\theta}(w\mid x, y_{\lt t})\ \ \text{(sample from accepted set)} \end{align} $$where $Z=\sum_{w\in\mathcal{A}} p_\theta( w\mid x, y_{\lt t})$ is the normalization constant — the sum of probabilities of words that we’re keeping in the accepted set. The function select is where the fun happens; we get to use a range of heuristics to decide which words stay in the accepted set, and which ones have their probabilities set to zero. Usually, these heuristics relate to the probability of the word in that context — high-probability words are probably good continuations and are kept; low-probability words are more likely to be bad continuations and may be cut. We can re-frame greedy decoding under this family by setting the select function to only have the most likely next word in $\mathcal{A}$ at each step. Note that the accepted set $\mathcal{A}$ is computed for each prefix (we’re omitting a subscript for brevity).
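One way to sketch this shared skeleton, continuing the NumPy setting from above: a `select` function (whatever heuristic you choose) returns a boolean mask over the vocabulary, and we renormalize over the accepted set before sampling. The function names here are illustrative.

```python
def truncation_sample_step(probs, select, rng):
    """One step of truncation sampling: zero out rejected words, renormalize, sample."""
    mask = select(probs)                    # accepted set A as a boolean mask over the vocabulary
    truncated = np.where(mask, probs, 0.0)  # set rejected words' probabilities to zero
    Z = truncated.sum()                     # normalization constant: the kept probability mass
    return int(rng.choice(len(probs), p=truncated / Z))
```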
There are many ways to implement the “high probability is probably a good continuation; low probability, not so much” intuition. One very simple method is to just set a minimum probability threshold for the accepted set:
$$ \begin{align} \mathcal{A} = \{w \in \mathcal{V} \mid p_\theta(w \mid x, y_{\lt t}) > \epsilon\} \end{align} $$Intuitively, low probabilities are not just indicative of unlikely continuations; it’s also just hard to estimate the probability of something whose true probability is very, very low. This algorithm, epsilon sampling [hewitt2022truncation], implements this intuition. There are some details to keep in mind, however; for example, it is possible that no words are above the $\epsilon$ threshold, in which case implementations will often just generate the argmax word at that timestep.
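An epsilon-sampling `select` function, usable with the `truncation_sample_step` sketch above, might look like the following; the default threshold value is only a placeholder, and the fall-back-to-argmax behavior is the implementation choice mentioned above.

```python
def epsilon_select(probs, eps=1e-4):
    """Keep words whose probability exceeds eps; fall back to the argmax if none do."""
    mask = probs > eps
    if not mask.any():
        mask[np.argmax(probs)] = True   # no word cleared the threshold
    return mask
```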
By far the most popular truncation method is top-p sampling, also called nucleus sampling [Holtzman2020The]. The intuition of top-p sampling is that the most likely tokens making up a fraction $p$ of the model’s probability mass at any timestep are “good” (should be kept), and the remaining lowest-probability tokens making up the other $1-p$ fraction are bad.
$$ \begin{align} &w^{(1)},\dots, w^{(|\mathcal{V}|)} = \arg \text{sort}_{w\in\mathcal{V}} \ \ p_\theta(w\mid x, y_{\lt t})\\ &k = \min \{i \in \mathbb{N} \mid \sum_{j=1}^{i} p_\theta(w^{(j)} \mid x, y_{\lt t}) \geq p\}\\ &\mathcal{A} = \{w^{(1)},\dots,w^{(k)}\} \end{align} $$To interpret this, think of sorting the vocabulary in decreasing order of probability, and then taking the $k$ most probable words, where $k$ is minimal such that the sum of their probabilities is at least $p$.
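A top-p `select` function, following the sort-and-accumulate description above and compatible with the earlier `truncation_sample_step` sketch (the default value of $p$ is illustrative):

```python
def top_p_select(probs, p=0.9):
    """Keep the smallest set of most-probable words whose total probability is at least p."""
    order = np.argsort(-probs)                   # indices sorted by decreasing probability
    cumulative = np.cumsum(probs[order])
    k = int(np.searchsorted(cumulative, p)) + 1  # minimal k with cumulative mass >= p
    mask = np.zeros(len(probs), dtype=bool)
    mask[order[:k]] = True
    return mask
```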
Each of these algorithms makes different assumptions about what is “wrong” about the distribution we’ve learned. Each also includes a parameter ($p$ or $\epsilon$) that controls the tradeoff between how much of the distribution you cut off (which you’d like to avoid) and how much you reduce the risk of generating low-quality outputs (by cutting off the low-likelihood words).
A final method we must discuss is temperature sampling, which is not a truncation sampling algorithm. It does not explicitly set any probabilities to zero; instead, intuitively, it interpolates in log space between the uniform distribution (arbitrarily high temperature) and a point mass on the most likely token (zero temperature). The distribution is as follows:
$$ \begin{align} p_{\tau,\theta}(w\mid x, y_{\lt t}) = \frac{ p_{\theta}(w\mid x, y_{\lt t})^{1/\tau} }{\sum_{w'\in\mathcal{V}} p_{\theta}(w'\mid x, y_{\lt t})^{1/\tau}} \end{align} $$Note the exponent: when $\tau$ approaches zero, $1/\tau$ becomes very large, so the largest probability grows even larger relative to the others, and the distribution concentrates on the most likely token. When $\tau$ approaches infinity, each probability is raised to a power of roughly $0$, so each token gets probability roughly $\frac{1}{|\mathcal{V}|}$.
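As a sketch, temperature can be applied to an already-computed probability vector by raising each probability to the power $1/\tau$ and renormalizing; in practice this is usually implemented equivalently by dividing the logits by $\tau$ before the softmax.

```python
def apply_temperature(probs, tau):
    """Raise each probability to 1/tau and renormalize: tau < 1 sharpens, tau > 1 flattens."""
    scaled = probs ** (1.0 / tau)
    return scaled / scaled.sum()
```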