Decoding Strategies in Large Language Models

A Guide to Text Generation From Beam Search to Nucleus Sampling

Large Language Models
Author

Maxime Labonne

Published

June 7, 2023


In the fascinating world of large language models (LLMs), much attention is given to model architectures, data processing, and optimization. However, decoding strategies like beam search, which play a crucial role in text generation, are often overlooked. In this article, we will explore how LLMs generate text by delving into the mechanics of greedy search and beam search, as well as sampling techniques with top-k and nucleus sampling.

By the conclusion of this article, you’ll not only understand these decoding strategies thoroughly but also be familiar with how to handle important hyperparameters like temperature, num_beams, top_k, and top_p.

The code for this article can be found on GitHub and Google Colab for reference and further exploration.

📚 Background

To kick things off, let’s start with an example. We’ll feed the text “I have a dream” to a GPT-2 model and ask it to generate the next five tokens (words or subwords).

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model.eval()

text = "I have a dream"
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)

outputs = model.generate(input_ids, max_length=len(input_ids.squeeze())+5)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Generated text: I have a dream of being a doctor.

The sentence “I have a dream of being a doctor” appears to have been generated by GPT-2. However, GPT-2 didn’t exactly produce this sentence.

There’s a common misconception that LLMs like GPT-2 directly produce text. This isn’t the case. Instead, LLMs calculate logits, which are scores assigned to every possible token in their vocabulary. To simplify, here’s an illustrative breakdown of the process:

The tokenizer, Byte-Pair Encoding in this instance, translates each token in the input text into a corresponding token ID. Then, GPT-2 uses these token IDs as input and tries to predict the next most likely token. Finally, the model generates logits, which are converted into probabilities using a softmax function.

For example, the model assigns a probability of 17% to the token for “of” being the next token after “I have a dream”. This output essentially represents a ranked list of potential next tokens in the sequence. More formally, we denote this probability as P(\text{of } | \text{ I have a dream}) = 17%.

Autoregressive models like GPT predict the next token in a sequence based on the preceding tokens. Consider a sequence of tokens w = (w_1, w_2, \ldots, w_t). The joint probability of this sequence P(w) can be broken down as:

\begin{align} P(w) &= P(w_1, w_2, \ldots, w_t) \\ &= P(w_1) P(w_2 | w_1) P(w_3 | w_2, w_1) \ldots P(w_t | w_1, \ldots, w_{t-1}) \\ &= \prod_{i=1}^t P(w_i | w_1, \dots, w_{i-1}). \end{align}

For each token w_i in the sequence, P(w_i | w_1, \ldots, w_{i-1}) represents the conditional probability of w_i given all the preceding tokens (w_1, \ldots, w_{i-1}). GPT-2 calculates this conditional probability for each of the 50,257 tokens in its vocabulary.
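To make this factorization concrete, here is a minimal sketch that chains conditional probabilities for a three-token continuation of the prompt. Only the 17% for “of” comes from the example above; the other two values and the helper name `sequence_log_prob` are made up for illustration. The sketch sums log-probabilities rather than multiplying raw probabilities, because products of many small numbers quickly underflow.

```python
import math

def sequence_log_prob(conditional_probs):
    """Sum the log of each conditional probability P(w_i | w_1, ..., w_{i-1})."""
    return sum(math.log(p) for p in conditional_probs)

# Hypothetical conditional probabilities for three successive tokens.
# Only the first value (17% for " of") appears in the text above.
conditional_probs = [0.17, 0.10, 0.25]

log_prob = sequence_log_prob(conditional_probs)
joint_prob = math.exp(log_prob)  # same as 0.17 * 0.10 * 0.25

print(f"log-probability of the continuation: {log_prob:.4f}")
print(f"probability of the continuation:     {joint_prob:.6f}")
```

The longer the sequence, the smaller the product becomes, which is one reason decoding algorithms usually work with log-probabilities.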

This leads to the question: how do we use these probabilities to generate text? This is where decoding strategies, such as greedy search and beam search, come into play.

🎲 Top-k sampling

Top-k sampling is a technique that leverages the probability distribution generated by the language model to select a token randomly from the k most likely options.

To illustrate, suppose we have k = 3 and four tokens: A, B, C, and D, with respective probabilities: P(A) = 30%, P(B) = 15%, P(C) = 5%, and P(D) = 1%. In top-k sampling, token D is disregarded and the probabilities of the three remaining tokens are renormalized so that they sum to 1 (30/50, 15/50, and 5/50). As a result, the algorithm will output A 60% of the time, B 30% of the time, and C 10% of the time. This approach ensures that we prioritize the most probable tokens while introducing an element of randomness in the selection process.
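The arithmetic of this example can be checked with a few lines of plain Python. This is only a sketch of the idea, not the implementation used later in the article: the helper name `top_k_renormalize` and the fixed random seed are assumptions made for illustration.

```python
import random

def top_k_renormalize(probs, k):
    """Keep the k most probable tokens and rescale them so they sum to 1."""
    top = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:k]
    total = sum(p for _, p in top)
    return {token: p / total for token, p in top}

# Probabilities from the example above.
probs = {"A": 0.30, "B": 0.15, "C": 0.05, "D": 0.01}
top_k_probs = top_k_renormalize(probs, k=3)
print(top_k_probs)  # D is dropped; A, B, and C are rescaled

# Draw one token from the truncated distribution.
rng = random.Random(0)  # fixed seed, only to make the run repeatable
tokens, weights = zip(*top_k_probs.items())
sampled = rng.choices(tokens, weights=weights, k=1)[0]
print(f"Sampled token: {sampled}")
```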

Another way of introducing randomness is the concept of temperature. The temperature T is a positive parameter, typically set between 0 and 1 (values above 1 are also possible and flatten the distribution), that reshapes the probabilities generated by the softmax function: the lower the temperature, the more influential the most likely tokens become. In practice, it simply consists of dividing the input logits by a value we call temperature:

\text{softmax}(x_i) = \frac{e^{x_i / T}}{\sum_{j} e^{x_j / T}}

Here is a chart that demonstrates the impact of temperature on the probabilities generated for a given set of input logits [1.5, -1.8, 0.9, -3.2]. We’ve plotted three different temperature values to observe the differences.

A temperature of 1.0 is equivalent to a default softmax with no temperature at all. On the other hand, a low temperature setting (0.1) sharpens the distribution, concentrating almost all of the probability mass on the most likely token. This is commonly used in text generation to control the level of “creativity” in the generated output. By adjusting the temperature, we can influence the extent to which the model produces more diverse or predictable responses.
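The effect is easy to reproduce numerically on the same logits. The sketch below is written in plain Python rather than PyTorch so that every step is visible. The helper name `softmax_with_temperature` and the three temperature values (1.0, 0.5, 0.1) are assumptions chosen for illustration.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Divide the logits by the temperature, then apply a softmax."""
    scaled = [x / temperature for x in logits]
    max_scaled = max(scaled)  # subtracting the max avoids overflow in exp
    exps = [math.exp(s - max_scaled) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1.5, -1.8, 0.9, -3.2]

for temperature in (1.0, 0.5, 0.1):  # assumed values for illustration
    probs = softmax_with_temperature(logits, temperature)
    print(f"T = {temperature}: {[round(p, 4) for p in probs]}")
```

Lowering the temperature never changes the ranking of the tokens. It only widens the gap between the most likely token and the rest.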

Let’s now implement the top-k sampling algorithm. We’ll use it in the beam_search() function by providing the “top_k” argument. To illustrate how the algorithm works, we will also plot the probability distributions for top_k=20.

import time
import matplotlib.pyplot as plt

def plot_prob_distribution(probabilities, next_tokens, sampling, potential_nb, total_nb=50):
    # Get top k tokens
    top_k_prob, top_k_indices = torch.topk(probabilities, total_nb)
    top_k_tokens = [tokenizer.decode([idx]) for idx in top_k_indices.tolist()]

    # Get next tokens and their probabilities
    next_tokens_list = [tokenizer.decode([idx]) for idx in next_tokens.tolist()]
    next_token_prob = probabilities[next_tokens].tolist()

    # Create figure
    plt.figure(figsize=(0.4*total_nb, 5), dpi=300, facecolor='white')
    plt.rc('axes', axisbelow=True)
    plt.grid(axis='y', linestyle='-', alpha=0.5)
    if potential_nb < total_nb:
        plt.axvline(x=potential_nb-0.5, ls=':', color='grey', label='Sampled tokens')
    plt.bar(top_k_tokens, top_k_prob.tolist(), color='blue')
    plt.bar(next_tokens_list, next_token_prob, color='red', label='Selected tokens')
    plt.xticks(rotation=45, ha='right', va='top')
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    if sampling == 'top_k':
        plt.title('Probability distribution of predicted tokens with top-k sampling')
    elif sampling == 'nucleus':
        plt.title('Probability distribution of predicted tokens with nucleus sampling')
    plt.legend()
    plt.savefig(f'{sampling}_{time.time()}.png', dpi=300)
    plt.close()

def top_k_sampling(logits, temperature, top_k, beams, plot=True):
    assert top_k >= 1
    assert beams <= top_k

    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    new_logits = torch.clone(logits)
    new_logits[indices_to_remove] = float('-inf')

    # Convert logits to probabilities
    probabilities = torch.nn.functional.softmax(new_logits / temperature, dim=-1)

    # Sample n tokens from the resulting distribution
    next_tokens = torch.multinomial(probabilities, beams)

    # Plot distribution
    if plot:
        total_prob = torch.nn.functional.softmax(logits / temperature, dim=-1)
        plot_prob_distribution(total_prob, next_tokens, 'top_k', top_k)

    return next_tokens

# Start generating text
beam_search(input_ids, 0, bar, length, beams, 'top_k', 1)

These plots give a good intuition of how top-k sampling works, with all the potentially selected tokens on the left of the vertical dotted line. While the most probable tokens are selected (in red) most of the time, it also allows less likely tokens to be chosen. This offers an interesting tradeoff that can steer a sequence towards a less predictable but more natural-sounding sentence. Now let’s print the text it generated.

sequence, max_score = get_best_sequence(graph)
print(f"Generated text: {sequence}")
Generated text: I have a dream job and I want to

The top-k sampling found a new sequence: “I have a dream job and I want to”, which feels significantly more natural than “I have a dream. I have a dream”. We’re making progress!

Let’s see how this decision tree differs from the previous one.

# Plot graph
plot_graph(graph, length, beams, 'sequence')

You can see how the nodes differ significantly from the previous iteration, making more diverse choices. Although the sequence score of this new outcome might not be the highest (-1.01 instead of -0.69 previously), it’s important to remember that higher scores do not always lead to more realistic or meaningful sequences.

Now that we’ve introduced top-k sampling, we have to present the other most popular sampling technique: nucleus sampling.

🔬 Nucleus sampling

Nucleus sampling, also known as top-p sampling, takes a different approach from top-k sampling. Rather than keeping a fixed number k of the most probable tokens, nucleus sampling uses a cutoff value p and keeps the smallest set of most probable tokens whose cumulative probability exceeds p. This forms a “nucleus” of tokens from which to randomly choose the next token.

In other words, the model examines its top probable tokens in descending order and keeps adding them to the list until the total probability surpasses the threshold p. Unlike top-k sampling, the number of tokens included in the nucleus can vary from step to step. This variability often results in a more diverse and creative output, making nucleus sampling popular for tasks such as text generation.
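The selection rule described above can be sketched in a few lines of plain Python. The two probability distributions and the helper name `nucleus` are hypothetical, chosen only to show how the size of the nucleus adapts to the shape of the distribution.

```python
def nucleus(probs, p):
    """Keep top tokens until their cumulative probability exceeds p, then renormalize."""
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    kept, cumulative = [], 0.0
    for token, prob in ranked:
        kept.append((token, prob))
        cumulative += prob
        if cumulative > p:
            break
    total = sum(prob for _, prob in kept)
    return {token: prob / total for token, prob in kept}

# Hypothetical distributions: one peaked, one flat.
peaked = {"A": 0.80, "B": 0.10, "C": 0.06, "D": 0.04}
flat = {"A": 0.30, "B": 0.28, "C": 0.22, "D": 0.20}

peaked_nucleus = nucleus(peaked, p=0.5)
flat_nucleus = nucleus(flat, p=0.5)

print(peaked_nucleus)  # one token is enough to exceed p
print(flat_nucleus)    # two tokens are needed to exceed p
```

With the same value of p, a confident model samples from very few tokens while an uncertain one samples from many, which a fixed k cannot do.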

To implement the nucleus sampling method, we can use the “nucleus” parameter in the beam_search() function. In this example, we’ll set the value of p to 0.5. To make it easier, we’ll include a minimum number of tokens equal to the number of beams. We’ll also consider tokens with cumulative probabilities lower than p, rather than higher. It’s worth noting that while the details may differ, the core idea of nucleus sampling remains the same.

def nucleus_sampling(logits, temperature, p, beams, plot=True):
    assert p > 0
    assert p <= 1

    # Sort the probabilities in descending order and compute cumulative probabilities
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    probabilities = torch.nn.functional.softmax(sorted_logits / temperature, dim=-1)
    cumulative_probabilities = torch.cumsum(probabilities, dim=-1)

    # Create a mask for probabilities that are in the top-p
    mask = cumulative_probabilities < p

    # If there are not more than n indices where cumulative_probabilities < p, we use the top n tokens instead
    if mask.sum() > beams:
        top_p_index_to_keep = torch.where(mask)[0][-1].detach().cpu().tolist()
    else:
        top_p_index_to_keep = beams

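    # Only keep top-p indices. The mask is applied to a copy of the original
    # logits (vocabulary order) so that the sampled indices are real token IDs,
    # not positions in the sorted tensor.
    indices_to_remove = sorted_indices[top_p_index_to_keep:]
    new_logits = torch.clone(logits)
    new_logits[indices_to_remove] = float('-inf')

    # Sample n tokens from the resulting distribution
    probabilities = torch.nn.functional.softmax(new_logits / temperature, dim=-1)
    next_tokens = torch.multinomial(probabilities, beams)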

    # Plot distribution
    if plot:
        total_prob = torch.nn.functional.softmax(logits / temperature, dim=-1)
        plot_prob_distribution(total_prob, next_tokens, 'nucleus', top_p_index_to_keep)

    return next_tokens

# Start generating text
beam_search(input_ids, 0, bar, length, beams, 'nucleus', 1)

In this plot, you can see that the number of tokens included in the nucleus fluctuates a lot. The generated probability distributions vary considerably, leading to the selection of tokens that are not always among the most probable ones. This opens the door to the generation of unique and varied sequences. Now, let’s observe the text it generated.

sequence, max_score = get_best_sequence(graph)
print(f"Generated text: {sequence}")
Generated text: I have a dream. I'm going to

The nucleus sampling algorithm produces the sequence: “I have a dream. I’m going to”, which shows a notable enhancement in semantic coherence compared to greedy search.

To compare the decision paths, let’s visualize the new tree nucleus sampling generated.

# Plot graph
plot_graph(graph, length, beams, 'sequence')

As with top-k sampling, this tree is very different from the one generated with greedy search, displaying more variety. Both top-k and nucleus sampling offer unique advantages when generating text, enhancing diversity and introducing creativity into the output. Your choice between the two methods (or even greedy search) will depend on the specific requirements and constraints of your project.

Conclusion

In this article, we have delved deep into various decoding methods used by LLMs, specifically GPT-2. We started with a simple greedy search and its immediate (yet often suboptimal) selection of the most probable next token. Next, we introduced the beam search technique, which considers several of the most likely tokens at each step. Although it offers more nuanced results, beam search can sometimes fall short in generating diverse and creative sequences.

To bring more variability into the process, we then moved on to top-k sampling and nucleus sampling. Top-k sampling diversifies the text generation by randomly selecting among the k most probable tokens, while nucleus sampling takes a different path by dynamically forming a nucleus of tokens based on cumulative probability. Each of these methods brings unique strengths and potential drawbacks to the table, and the specific requirements of your project will largely dictate the choice among them.

Ultimately, understanding these techniques and their trade-offs will equip you to better guide the LLMs towards producing increasingly realistic, nuanced, and compelling textual output.

If you’re interested in more technical content around LLMs, you can follow me on Twitter @maximelabonne.