4-bit LLM Quantization with GPTQ

Quantize your own open-source LLMs to run them on consumer hardware

Large Language Models
Author

Maxime Labonne

Published

July 30, 2023


Recent advancements in weight quantization allow us to run massive large language models on consumer hardware, like a LLaMA-30B model on an RTX 3090 GPU. This is possible thanks to novel 4-bit quantization techniques with minimal performance degradation, like GPTQ, GGML, and NF4.
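
To see why 4-bit precision matters here, a back-of-the-envelope calculation helps. The sketch below counts the memory taken by the weights alone. It assumes a nominal 30 billion parameters and ignores activations, the KV cache, and the small overhead of quantization parameters, so treat the numbers as rough lower bounds. The helper name `weight_memory_gb` is made up for this illustration.

```python
def weight_memory_gb(n_params: float, bits_per_weight: float) -> float:
    """Approximate memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_weight / 8 / 1e9

n_params = 30e9  # nominal parameter count, an assumption for illustration
for bits in (16, 8, 4):
    print(f"{bits:>2}-bit: {weight_memory_gb(n_params, bits):.1f} GB")
```

Under these assumptions, only the 4-bit version fits within the 24 GB of VRAM of an RTX 3090.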

In the previous article, we introduced naïve 8-bit quantization techniques and the excellent LLM.int8(). In this article, we will explore the popular GPTQ algorithm to understand how it works and implement it using the AutoGPTQ library.

You can find the code on Google Colab and GitHub.

🧠 Optimal Brain Quantization

Let’s start by introducing the problem we’re trying to solve. For every layer \ell in the network, we want to find a quantized version \widehat{\mathbf{W}}_\ell of the original weights \mathbf{W}_\ell. This is called the layer-wise compression problem. More specifically, to minimize performance degradation, we want the outputs (\mathbf{\widehat{W}_\ell X_\ell}) of these new weights to be as close as possible to the original ones (\mathbf{W_\ell X_\ell}). In other words, we want to find:

\arg \min_{\mathbf{\widehat{W}}_\ell} \parallel\mathbf{W_\ell X_\ell} - \mathbf{\widehat{W}_\ell X_\ell}\parallel_2^2.
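
To make this objective concrete, here is a minimal NumPy sketch that measures it for the simplest possible quantizer, round-to-nearest on a per-row 4-bit grid. The random weights and inputs, as well as the helper names `quantize_rtn` and `layer_error`, are assumptions made for illustration and not part of any library.

```python
import numpy as np

def quantize_rtn(w, bits=4):
    """Round-to-nearest onto a uniform grid with 2**bits levels per row."""
    qmax = 2**bits - 1
    w_min = w.min(axis=1, keepdims=True)
    w_max = w.max(axis=1, keepdims=True)
    scale = (w_max - w_min) / qmax
    scale = np.where(scale == 0, 1.0, scale)  # avoid dividing by zero on constant rows
    q = np.clip(np.round((w - w_min) / scale), 0, qmax)
    return q * scale + w_min

def layer_error(w, w_hat, x):
    """Squared error between the original and quantized layer outputs."""
    return float(np.sum((w @ x - w_hat @ x) ** 2))

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 16))   # toy weight matrix (rows = output features)
X = rng.normal(size=(16, 64))  # toy layer inputs (one column per sample)
W_hat = quantize_rtn(W, bits=4)
print(layer_error(W, W_hat, X))
```

The methods discussed in the rest of this article try to reach a lower value of this error than plain rounding does.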

Different approaches have been proposed to solve this problem, but we’re interested in the Optimal Brain Quantizer (OBQ) framework here.

This method is inspired by a pruning technique that carefully removes weights from a fully trained dense neural network (Optimal Brain Surgeon). It uses an approximation technique and provides explicit formulas for the best single weight w_q to quantize next and the optimal update \delta_F that adjusts the set of remaining non-quantized weights F to make up for the resulting error:

\begin{align*} w_q &= \arg\min_{w_q} \frac{(\text{quant}(w_q) - w_q)^2}{[\mathbf{H}_F^{-1}]_{qq}},\\ \quad \delta_F &= -\frac{w_q - \text{quant}(w_q)}{[\mathbf{H}_F^{-1}]_{qq}} \cdot (\mathbf{H}_F^{-1})_{:,q}. \end{align*}

where \text{quant}(w) is the weight rounding given by the quantization and \mathbf{H}_F is the Hessian of the layer-wise objective with respect to the remaining weights F.

Using OBQ, we can quantize the easiest weight first and then adjust all remaining non-quantized weights to compensate for this precision loss. Then we pick the next weight to quantize, and so on.
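
The greedy loop can be sketched for a single row of weights as follows. This is a simplified illustration and not the reference implementation: it assumes the Hessian of the squared error is 2XXᵀ, uses an unbounded uniform grid with a fixed step `scale`, and skips the outlier heuristic. The names `quant` and `obq_row` are made up for this example.

```python
import numpy as np

def quant(w, scale):
    """Round to the nearest point of a uniform grid with step `scale` (no clipping)."""
    return np.round(w / scale) * scale

def obq_row(w, X, scale):
    """Greedy OBQ-style quantization of one weight row w, given layer inputs X (d x n)."""
    w = np.asarray(w, dtype=float).copy()
    H_inv = np.linalg.inv(2 * X @ X.T)  # inverse Hessian (assumes H = 2 X X^T)
    free = list(range(w.size))          # indices not quantized yet
    while free:
        # Pick the weight whose quantization costs the least (first formula).
        cost = (quant(w[free], scale) - w[free]) ** 2 / np.diag(H_inv)[free]
        q = free[int(np.argmin(cost))]
        w_q = quant(w[q], scale)
        # Compensate on the remaining weights (second formula).
        w += -(w[q] - w_q) / H_inv[q, q] * H_inv[:, q]
        w[q] = w_q
        # Drop q from the inverse Hessian so later updates ignore it.
        H_inv = H_inv - np.outer(H_inv[:, q], H_inv[q, :]) / H_inv[q, q]
        H_inv[q, :] = 0.0
        H_inv[:, q] = 0.0
        free.remove(q)
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 256))
w = rng.normal(size=8)
w_hat = obq_row(w, X, scale=0.25)
print(w_hat / 0.25)  # every entry is an integer multiple of the grid step
```

Zeroing row and column q after each step keeps already-quantized weights fixed, since their entries in the update direction become exactly zero.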

A potential issue with this approach is when there are outlier weights, which can result in high quantization error. Usually, these outliers would be quantized last, when there are few non-quantized weights left that could be adjusted to compensate for the large error. This effect can worsen when some weights are pushed further outside the grid by intermediate updates. A simple heuristic is applied to prevent this: outliers are quantized as soon as they appear.

This process could be computationally heavy, especially for LLMs. To deal with this, the OBQ method uses a trick that avoids redoing the entire computation each time a weight is quantized. After quantizing a weight, it updates the inverse Hessian directly by removing the row and column associated with that weight (using Gaussian elimination):

\mathbf{H}^{-1}_{-q} = \left( \mathbf{H}^{-1} - \frac{1}{[\mathbf{H}^{-1}]_{qq}} \mathbf{H}^{-1}_{:,q} \mathbf{H}^{-1}_{q,:} \right)_{-q}.
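
A quick numerical check shows what this formula buys us. The sketch below compares the update above against the direct approach of deleting row and column q from the Hessian and inverting the smaller matrix from scratch. The random matrix and the helper name `remove_from_inverse` are assumptions for illustration.

```python
import numpy as np

def remove_from_inverse(H_inv, q):
    """Inverse of H with row/column q removed, computed from H_inv only."""
    updated = H_inv - np.outer(H_inv[:, q], H_inv[q, :]) / H_inv[q, q]
    keep = np.arange(H_inv.shape[0]) != q
    return updated[np.ix_(keep, keep)]

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 50))
H = 2 * X @ X.T  # a symmetric positive definite toy Hessian
q = 2

keep = np.arange(6) != q
direct = np.linalg.inv(H[np.ix_(keep, keep)])        # re-invert from scratch
fast = remove_from_inverse(np.linalg.inv(H), q)      # one rank-1 update

print(np.allclose(direct, fast))
```

Both routes give the same matrix up to floating-point error, but the rank-1 update costs a quadratic number of operations instead of a full inversion.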

The method also employs vectorization to process multiple rows of the weight matrix at once. Despite these optimizations, OBQ’s computation time increases significantly as the size of the weight matrix increases: its runtime grows cubically with the input dimension of the matrix. This cubic growth makes it difficult to use OBQ on very large models with billions of parameters.

🧮 The GPTQ Algorithm

Introduced by Frantar et al. (2023), the GPTQ algorithm takes inspiration from the OBQ method, but with significant improvements to scale it for (very) large language models.

Step 1: Arbitrary Order Insight

The OBQ method selects weights (parameters in a model) for quantization in a certain order, determined by which will add the least additional error. However, GPTQ observes that for large models, quantizing weights in any fixed order can perform just as well. This is because even though some weights might introduce more error individually, they are quantized later in the process when there are few other weights left that could increase the error. So the order doesn’t matter as much as we thought.

Based on this insight, GPTQ aims to quantize all weights in the same order for all rows of a matrix. This makes the process faster because certain computations have to be done only once for each column, rather than once for each weight.
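
Here is a minimal sketch of this idea, under the same simplifying assumptions as before (Hessian taken as 2XXᵀ, a single fixed grid step, no clipping). Because every row uses the same left-to-right column order, one inverse Hessian is shared by all rows and updated once per column. The function names are hypothetical.

```python
import numpy as np

def quant(w, scale):
    """Round to the nearest point of a uniform grid with step `scale` (no clipping)."""
    return np.round(w / scale) * scale

def quantize_fixed_order(W, X, scale):
    """Quantize columns left to right, sharing one inverse Hessian across all rows."""
    W = np.asarray(W, dtype=float).copy()
    d = W.shape[1]
    H_inv = np.linalg.inv(2 * X @ X.T)  # assumes H = 2 X X^T, identical for every row
    Q = np.zeros_like(W)
    for i in range(d):
        Q[:, i] = quant(W[:, i], scale)
        # Scaled quantization error of column i, one value per row.
        err = (W[:, i] - Q[:, i]) / H_inv[i, i]
        # Spread the error over the columns that are not quantized yet.
        W[:, i + 1:] -= np.outer(err, H_inv[i, i + 1:])
        # Remove column i from the inverse Hessian (done once for all rows).
        H_inv = H_inv - np.outer(H_inv[:, i], H_inv[i, :]) / H_inv[i, i]
    return Q

rng = np.random.default_rng(0)
X = rng.normal(size=(16, 512))
W = rng.normal(size=(4, 16))
Q = quantize_fixed_order(W, X, scale=0.25)
err_rtn = np.sum((W @ X - quant(W, 0.25) @ X) ** 2)
err_fixed = np.sum((W @ X - Q @ X) ** 2)
print(err_rtn, err_fixed)
```

On correlated inputs, the compensated version typically reaches a lower output error than plain rounding, although this is not guaranteed for any particular draw.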

Step 2: Lazy Batch-Updates

Applied naively, this scheme is still not fast, because it requires updating a huge matrix with very few computations for each entry. This type of operation can’t utilize the full compute capabilities of GPUs and will be slowed down by memory limitations (memory throughput bottleneck).

To resolve this, GPTQ introduces “lazy batch” updates. It turns out that the final rounding decisions for a given column are only affected by updates performed on that column, not on later columns. Therefore, GPTQ can apply the algorithm to a batch of columns at a time (like 128 columns), updating only those columns and a corresponding block of the matrix. After a block is fully processed, the algorithm performs global updates on the entire matrix.

\begin{align*} \mathbf{\delta}_F &= -(\mathbf{w}_Q - \text{quant}(\mathbf{w}_Q))([\mathbf{H}_F^{-1}]_{QQ})^{-1} (\mathbf{H}_F^{-1})_{:,Q}, \\ \mathbf{H}^{-1}_{-Q} &= \left(\mathbf{H}^{-1} - \mathbf{H}^{-1}_{:,Q}([\mathbf{H}_F^{-1}]_{QQ})^{-1}\mathbf{H}^{-1}_{Q,:}\right)_{-Q}. \end{align*}

Step 3: Cholesky Reformulation

However, there’s one more issue to address. When the algorithm scales up to very large models, numerical inaccuracies can become a problem. Specifically, repeated applications of a certain operation can accumulate numerical errors.

To tackle this, GPTQ uses a Cholesky decomposition, a numerically stable method for solving certain mathematical problems. It involves precomputing some required information from the matrix using the Cholesky method. This approach, combined with a slight “dampening” (adding a small constant to diagonal elements of the matrix), helps the algorithm to avoid numerical issues.

The full algorithm can be summarized in a few steps:

  1. The GPTQ algorithm begins with a Cholesky decomposition of the Hessian inverse (a matrix that helps decide how to adjust the weights).
  2. It then runs in loops, handling batches of columns at a time.
  3. For each column in a batch, it quantizes the weights, calculates the error, and updates the weights in the block accordingly.
  4. After processing the batch, it updates all remaining weights based on the block’s errors.
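
The four steps above can be sketched in NumPy as follows. This is a simplified illustration, not the AutoGPTQ implementation: it assumes the Hessian is 2XXᵀ, uses a single fixed grid step instead of per-group scales and zero-points, and the names `gptq_sketch` and `block_size` are chosen for this example only.

```python
import numpy as np

def quant(w, scale):
    """Round to the nearest point of a uniform grid with step `scale` (no clipping)."""
    return np.round(w / scale) * scale

def gptq_sketch(W, X, scale, block_size=4, damp_percent=0.01):
    W = np.asarray(W, dtype=float).copy()
    rows, d = W.shape
    H = 2 * X @ X.T
    # Dampening: add a small constant to the diagonal for numerical stability.
    H += damp_percent * np.mean(np.diag(H)) * np.eye(d)
    # Step 1: upper Cholesky factor U of the inverse Hessian (H^-1 = U^T U).
    U = np.linalg.cholesky(np.linalg.inv(H)).T
    Q = np.zeros_like(W)
    # Step 2: loop over blocks of columns.
    for start in range(0, d, block_size):
        end = min(start + block_size, d)
        E = np.zeros((rows, end - start))  # scaled errors of this block
        # Step 3: quantize each column and update the rest of the block.
        for i in range(start, end):
            Q[:, i] = quant(W[:, i], scale)
            E[:, i - start] = (W[:, i] - Q[:, i]) / U[i, i]
            W[:, i:end] -= np.outer(E[:, i - start], U[i, i:end])
        # Step 4: lazy update of all remaining columns in one matrix product.
        W[:, end:] -= E @ U[start:end, end:]
    return Q

rng = np.random.default_rng(0)
X = rng.normal(size=(16, 512))
W = rng.normal(size=(4, 16))
Q = gptq_sketch(W, X, scale=0.25)
print(np.sum((W @ X - Q @ X) ** 2))
```

The block size only changes when updates are applied, not what they are, so different values of `block_size` should produce the same quantized weights up to floating-point effects.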

The GPTQ algorithm was tested on various language generation tasks. It was compared with other quantization methods, like rounding all weights to the nearest quantized value (RTN). GPTQ was used with the BLOOM (176B parameters) and OPT (175B parameters) model families, and models were quantized using a single NVIDIA A100 GPU.

💻 Quantize an LLM with AutoGPTQ

GPTQ has become very popular for creating models in 4-bit precision that can run efficiently on GPUs. You can find many examples on the Hugging Face Hub, especially from TheBloke. If you’re looking for an approach that is more CPU-friendly, GGML is currently your best option. Finally, the transformers library with bitsandbytes allows you to quantize a model when it’s loaded using the load_in_4bit=True argument, which requires downloading full models and storing them in your RAM.

Let’s implement the GPTQ algorithm using the AutoGPTQ library and quantize a GPT-2 model. This requires a GPU, but a free T4 on Google Colab will do. We start by loading the libraries and defining the model we want to quantize (in this case, GPT-2).

!BUILD_CUDA_EXT=0 pip install -q auto-gptq transformers
import random

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from datasets import load_dataset
import torch
from transformers import AutoTokenizer


# Define base model and output directory
model_id = "gpt2"
out_dir = model_id + "-GPTQ"

We now want to load the model and the tokenizer. The tokenizer is loaded using the classic AutoTokenizer class from the transformers library. On the other hand, we need to pass a specific configuration (BaseQuantizeConfig) to load the model.

In this configuration, we can specify the number of bits to quantize (here, bits=4) and the group size (the number of weights that share the same quantization parameters). Note that this group size is optional: we could also use one set of parameters for the entire weight matrix. In practice, these groups generally improve the quality of the quantization at a very low cost (especially with group_size=1024). The damp_percent value is here to help the Cholesky reformulation and should not be changed.
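
To build some intuition for why groups help, the sketch below applies plain round-to-nearest with one scale and offset per group of weights, and compares a single group against groups of 128 on a vector that contains one outlier. It is not AutoGPTQ code: the function `quantize_groupwise` is hypothetical and assumes the number of weights is divisible by the group size.

```python
import numpy as np

def quantize_groupwise(w, bits=4, group_size=128):
    """Round-to-nearest with one (scale, offset) pair per group of weights."""
    qmax = 2**bits - 1
    groups = w.reshape(-1, group_size)  # assumes len(w) is divisible by group_size
    w_min = groups.min(axis=1, keepdims=True)
    w_max = groups.max(axis=1, keepdims=True)
    scale = (w_max - w_min) / qmax
    scale = np.where(scale == 0, 1.0, scale)  # avoid dividing by zero
    q = np.clip(np.round((groups - w_min) / scale), 0, qmax)
    return (q * scale + w_min).reshape(w.shape)

rng = np.random.default_rng(0)
w = rng.normal(size=1024)
w[0] = 50.0  # a single outlier stretches the range of whichever group contains it

mse_global = np.mean((w - quantize_groupwise(w, group_size=1024)) ** 2)
mse_grouped = np.mean((w - quantize_groupwise(w, group_size=128)) ** 2)
print(mse_global, mse_grouped)
```

With a single group, the outlier stretches the grid for every weight. With groups of 128, only the group containing the outlier is affected, at the price of storing a few extra parameters per group.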

Finally, the desc_act (also called act order) is a tricky parameter. It allows you to process rows based on decreasing activation, meaning the most important or impactful rows (determined by sampled inputs and outputs) are processed first. This method aims to place most of the quantization error (inevitably introduced during quantization) on less significant weights. This approach improves the overall accuracy of the quantization process by ensuring the most significant weights are processed with greater precision. However, when used alongside group size, desc_act can lead to performance slowdowns due to the need to frequently reload quantization parameters. For this reason, we won’t use it here (it will probably be fixed in the future, however).

# Load quantize config, model and tokenizer
quantize_config = BaseQuantizeConfig(
    bits=4,
    group_size=128,
    damp_percent=0.01,
    desc_act=False,
)
model = AutoGPTQForCausalLM.from_pretrained(model_id, quantize_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)

The quantization process relies heavily on samples to evaluate and enhance the quality of the quantization. They provide a means of comparison between the outputs produced by the original and the newly quantized model. The larger the number of samples provided, the greater the potential for more accurate and effective comparisons, leading to improved quantization quality.

In the context of this article, we utilize the C4 (Colossal Clean Crawled Corpus) dataset to generate our samples. The C4 dataset is a large-scale, multilingual collection of web text gathered from the Common Crawl project. This expansive dataset has been cleaned and prepared specifically for training large-scale language models, making it a great resource for tasks such as this. The WikiText dataset is another popular option.

In the following code block, we load 1024 samples from the C4 dataset, tokenize them, and format them.

# Load data and tokenize examples
n_samples = 1024
data = load_dataset("allenai/c4", data_files="en/c4-train.00001-of-01024.json.gz", split=f"train[:{n_samples*5}]")
tokenized_data = tokenizer("\n\n".join(data['text']), return_tensors='pt')

# Format tokenized examples
examples_ids = []
for _ in range(n_samples):
    i = random.randint(0, tokenized_data.input_ids.shape[1] - tokenizer.model_max_length - 1)
    j = i + tokenizer.model_max_length
    input_ids = tokenized_data.input_ids[:, i:j]
    attention_mask = torch.ones_like(input_ids)
    examples_ids.append({'input_ids': input_ids, 'attention_mask': attention_mask})
WARNING:datasets.builder:Found cached dataset json (/root/.cache/huggingface/datasets/allenai___json/allenai--c4-6e494e9c0ee1404e/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)
Token indices sequence length is longer than the specified maximum sequence length for this model (2441065 > 1024). Running this sequence through the model will result in indexing errors

Now that the dataset is ready, we can start the quantization process with a batch size of 1. Optionally, we also use OpenAI Triton, a CUDA alternative, to communicate with the GPU. Once this is done, we save the tokenizer and the model in the safetensors format.

%%time

# Quantize with GPTQ
model.quantize(
    examples_ids,
    batch_size=1,
    use_triton=True,
)

# Save model and tokenizer
model.save_quantized(out_dir, use_safetensors=True)
tokenizer.save_pretrained(out_dir)
CPU times: user 4min 35s, sys: 3.49 s, total: 4min 39s
Wall time: 5min 8s
('gpt2-GPTQ/tokenizer_config.json',
 'gpt2-GPTQ/special_tokens_map.json',
 'gpt2-GPTQ/vocab.json',
 'gpt2-GPTQ/merges.txt',
 'gpt2-GPTQ/added_tokens.json',
 'gpt2-GPTQ/tokenizer.json')

As per usual, the model and tokenizer can then be loaded from the output directory using the AutoGPTQForCausalLM and AutoTokenizer classes.

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Reload model and tokenizer
model = AutoGPTQForCausalLM.from_quantized(
    out_dir,
    device=device,
    use_triton=True,
    use_safetensors=True,
)
tokenizer = AutoTokenizer.from_pretrained(out_dir)
WARNING:accelerate.utils.modeling:The safetensors archive passed at gpt2-GPTQ/gptq_model-4bit-128g.safetensors does not contain metadata. Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata.
WARNING:auto_gptq.modeling._base:GPT2GPTQForCausalLM hasn't fused attention module yet, will skip inject fused attention.
WARNING:auto_gptq.modeling._base:GPT2GPTQForCausalLM hasn't fused mlp module yet, will skip inject fused mlp.

Let’s check that the model is working correctly. The AutoGPTQ model (mostly) works as a normal transformers model, which makes it compatible with inference pipelines, as shown in the following example:

from transformers import pipeline

generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
result = generator("I have a dream", do_sample=True, max_length=50)[0]['generated_text']
print(result)
The model 'GPT2GPTQForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MusicgenForCausalLM', 'MvpForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
I have a dream," she told CNN last week. "I have this dream of helping my mother find her own. But, to tell that for the first time, now that I'm seeing my mother now, just knowing how wonderful it is that

We managed to get a convincing completion from our quantized GPT-2 model. A more in-depth evaluation would require measuring the perplexity of the quantized model versus the original one. However, we will leave it out of the scope of this article.
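
For reference, perplexity is simply the exponential of the average negative log-likelihood per token. The sketch below shows the computation on made-up numbers: the two lists are hypothetical per-token log-probabilities, standing in for values you would obtain by running the original and quantized models on the same held-out text.

```python
import math

def perplexity(token_log_probs):
    """Perplexity = exp(mean negative log-likelihood), with natural-log probabilities."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))

# Hypothetical per-token log-probabilities on the same held-out text.
original_log_probs = [-2.1, -0.4, -3.0, -1.2]
quantized_log_probs = [-2.3, -0.5, -3.1, -1.3]

print(perplexity(original_log_probs), perplexity(quantized_log_probs))
```

A lower value is better, so a quantized model whose perplexity stays close to the original one has lost little predictive quality.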

Conclusion

In this article, we introduced the GPTQ algorithm, a state-of-the-art quantization technique to run LLMs on consumer-grade hardware. We showed how it addresses the layer-wise compression problem, based on an improved OBQ technique with arbitrary order insight, lazy batch updates, and Cholesky reformulation. This novel approach significantly reduces memory and computation requirements, making LLMs accessible to a broader audience.

In addition, we quantized our own LLM on a free T4 GPU and ran it to generate text. You can push your own version of a GPTQ 4-bit quantized model to the Hugging Face Hub. As mentioned in the introduction, GPTQ is not the only 4-bit quantization algorithm: GGML and NF4 are excellent alternatives with slightly different scopes. I encourage you to learn more about them and give them a shot!

If you’re interested in more technical content around LLMs, follow me on Twitter @maximelabonne.

References