Large Language Models (LLMs) are known for their extensive computational requirements. Typically, the size of a model is calculated by multiplying the number of parameters (size) by the precision of these values (data type). However, to save memory, weights can be stored using lower-precision data types through a process known as quantization.
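As a quick back-of-the-envelope illustration (a rough sketch, assuming a model with about 124 million parameters, roughly the size of GPT-2), the footprint scales directly with the number of bytes per parameter:
# Approximate memory footprint = number of parameters x bytes per parameter
n_params = 124_000_000  # e.g., GPT-2 (approximate)

print(f"FP32: ~{n_params * 4 / 1024**2:.0f} MB")  # 4 bytes per weight
print(f"FP16: ~{n_params * 2 / 1024**2:.0f} MB")  # 2 bytes per weight
print(f"INT8: ~{n_params * 1 / 1024**2:.0f} MB")  # 1 byte per weight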
We distinguish two main families of weight quantization techniques in the literature:
- Post-Training Quantization (PTQ) is a straightforward technique where the weights of an already trained model are converted to lower precision without necessitating any retraining. Although easy to implement, PTQ is associated with potential performance degradation.
- Quantization-Aware Training (QAT) incorporates the weight conversion process during the pre-training or fine-tuning stage, resulting in enhanced model performance. However, QAT is computationally expensive and demands representative training data.
In this article, we focus on PTQ to reduce the precision of our parameters. To get a good intuition, we will apply both naïve and more sophisticated techniques to a toy example using a GPT-2 model.
The entire code is freely available on Google Colab and GitHub.
📚 Background on Floating Point Representation
The choice of data type dictates the quantity of computational resources required, affecting the speed and efficiency of the model. In deep learning applications, balancing precision and computational performance becomes a vital exercise as higher precision often implies greater computational demands.
Among various data types, floating point numbers are predominantly employed in deep learning due to their ability to represent a wide range of values with high precision. Typically, a floating point number uses \(n\) bits to store a numerical value. These \(n\) bits are further partitioned into three distinct components:
- Sign: The sign bit indicates the positive or negative nature of the number. It uses one bit, where 0 indicates a positive number and 1 signals a negative number.
- Exponent: The exponent is a segment of bits that represents the power to which the base (usually 2 in binary representation) is raised. The exponent can also be positive or negative, allowing the number to represent very large or very small values.
- Significand/Mantissa: The remaining bits are used to store the significand, also referred to as the mantissa. This represents the significant digits of the number. The precision of the number heavily depends on the length of the significand.
This design allows floating point numbers to cover a wide range of values with varying levels of precision. The formula used for this representation is:
\[(-1)^{\text{sign}} \times \text{base}^{\text{exponent}} \times \text{significand}\]
To understand this better, let’s delve into some of the most commonly used data types in deep learning: float32 (FP32), float16 (FP16), and bfloat16 (BF16):
- FP32 uses 32 bits to represent a number: one bit for the sign, eight for the exponent, and the remaining 23 for the significand. While it provides a high degree of precision, the downside of FP32 is its high computational and memory footprint.
- FP16 uses 16 bits to store a number: one is used for the sign, five for the exponent, and ten for the significand. Although this makes it more memory-efficient and accelerates computations, the reduced range and precision can introduce numerical instability, potentially impacting model accuracy.
- BF16 is also a 16-bit format but with one bit for the sign, eight for the exponent, and seven for the significand. BF16 expands the representable range compared to FP16, thus decreasing underflow and overflow risks. Despite a reduction in precision due to fewer significand bits, BF16 typically does not significantly impact model performance and is a useful compromise for deep learning tasks.
In ML jargon, FP32 is often termed “full precision” (4 bytes), while BF16 and FP16 are “half-precision” (2 bytes). But could we do even better and store weights using a single byte? The answer is the INT8 data type, which consists of an 8-bit representation capable of storing \(2^8 = 256\) different values. In the next section, we’ll see how to convert FP32 weights into an INT8 format.
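To make these trade-offs concrete, here is a small sketch (assuming PyTorch is available) that uses torch.finfo to print the bit width, machine epsilon, and largest representable value of each format. Note how BF16 keeps almost the same range as FP32, but with much coarser precision:
import torch

# Compare precision (eps) and range (max) of common floating point formats
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):>15}: bits={info.bits}, eps={info.eps:.2e}, max={info.max:.2e}")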
🔰 Naïve 8-bit Quantization
In this section, we will implement two quantization techniques: a symmetric one with absolute maximum (absmax) quantization and an asymmetric one with zero-point quantization. In both cases, the goal is to map an FP32 tensor \(\mathbf{X}\) (original weights) to an INT8 tensor \(\mathbf{X}_{\text{quant}}\) (quantized weights).
With absmax quantization, the original number is divided by the absolute maximum value of the tensor and multiplied by a scaling factor (127) to map inputs into the range [-127, 127]. To retrieve the original FP32 values, the INT8 number is divided by the quantization factor, acknowledging some loss of precision due to rounding.
\[\begin{align*} \mathbf{X}_{\text{quant}} &= \text{round}\Biggl ( \frac{127}{\max|\mathbf{X}|} \cdot \mathbf{X} \Biggr ) \\ \mathbf{X}_{\text{dequant}} &= \frac{\max|\mathbf{X}|}{127} \cdot \mathbf{X}_{\text{quant}} \end{align*}\]
For instance, let’s say we have an absolute maximum value of 3.2. A weight of 0.1 would be quantized to \(\text{round}(0.1 \times \frac{127}{3.2}) = 4\). If we want to dequantize it, we would get \(4 / \frac{127}{3.2} = 0.1008\), which implies an error of 0.0008. Here’s the corresponding Python implementation:
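import torch

def absmax_quantize(X):
    # Calculate scale
    scale = 127 / torch.max(torch.abs(X))

    # Quantize
    X_quant = (scale * X).round()

    # Dequantize
    X_dequant = X_quant / scale

    return X_quant.to(torch.int8), X_dequant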
With zero-point quantization, we can consider asymmetric input distributions, which is useful when you consider the output of a ReLU function (only positive values) for example. The input values are first scaled by the total range of values (255) divided by the difference between the maximum and minimum values. This distribution is then shifted by the zero-point to map it into the range [-128, 127] (notice the extra value compared to absmax). First, we calculate the scale factor and the zero-point value:
\[\begin{align*} \text{scale} &= \frac{255}{\max(\mathbf{X}) - \min(\mathbf{X})} \\ \text{zeropoint} &= - \text{round}(\text{scale} \cdot \min(\mathbf{X})) - 128 \end{align*}\]
Then, we can use these variables to quantize or dequantize our weights:
\[\begin{align*} \mathbf{X}_{\text{quant}} &= \text{round}\bigg(\text{scale} \cdot \mathbf{X} + \text{zeropoint} \bigg) \\ \mathbf{X}_{\text{dequant}} &= \frac{\mathbf{X}_{\text{quant}} - \text{zeropoint}}{\text{scale}} \end{align*}\]
Let’s take an example: we have a maximum value of 3.2 and a minimum value of -3.0. We can calculate the scale as \(\frac{255}{3.2 + 3.0} = 41.13\) and the zero-point as \(-\text{round}(41.13 \cdot -3.0) - 128 = 123 - 128 = -5\), so our previous weight of 0.1 would be quantized to \(\text{round}(41.13 \cdot 0.1 - 5) = -1\). This is very different from the previous value obtained using absmax (4 vs. -1).
The Python implementation is quite straightforward:
def zeropoint_quantize(X):
    # Calculate value range (denominator)
    x_range = torch.max(X) - torch.min(X)
    x_range = 1 if x_range == 0 else x_range

    # Calculate scale
    scale = 255 / x_range

    # Shift by zero-point
    zeropoint = (-scale * torch.min(X) - 128).round()

    # Scale and round the inputs
    X_quant = torch.clip((X * scale + zeropoint).round(), -128, 127)

    # Dequantize
    X_dequant = (X_quant - zeropoint) / scale

    return X_quant.to(torch.int8), X_dequant
Instead of relying on complete toy examples, we can use these two functions on a real model thanks to the transformers library.
We start by loading the model and tokenizer for GPT-2. This is a very small model we probably don’t want to quantize, but it will be good enough for this tutorial. First, we want to observe the model’s size so we can compare it later and evaluate the memory savings due to 8-bit quantization.
!pip install -q bitsandbytes>=0.39.0
!pip install -q git+https://github.com/huggingface/accelerate.git
!pip install -q git+https://github.com/huggingface/transformers.git
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
torch.manual_seed(0)

# Set device to CPU for now
device = 'cpu'

# Load model and tokenizer
model_id = 'gpt2'
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Print model size
print(f"Model size: {model.get_memory_footprint():,} bytes")
Model size: 510,342,192 bytes
The size of the GPT-2 model is approximately 487MB in FP32. The next step consists of quantizing the weights using zero-point and absmax quantization. In the following example, we apply these techniques to the first attention layer of GPT-2 to see the results.
# Extract weights of the first layer
weights = model.transformer.h[0].attn.c_attn.weight.data
print("Original weights:")
print(weights)

# Quantize layer using absmax quantization
weights_abs_quant, _ = absmax_quantize(weights)
print("\nAbsmax quantized weights:")
print(weights_abs_quant)

# Quantize layer using zero-point quantization
weights_zp_quant, _ = zeropoint_quantize(weights)
print("\nZero-point quantized weights:")
print(weights_zp_quant)
Original weights:
tensor([[-0.4738, -0.2614, -0.0978, ..., 0.0513, -0.0584, 0.0250],
[ 0.0874, 0.1473, 0.2387, ..., -0.0525, -0.0113, -0.0156],
[ 0.0039, 0.0695, 0.3668, ..., 0.1143, 0.0363, -0.0318],
...,
[-0.2592, -0.0164, 0.1991, ..., 0.0095, -0.0516, 0.0319],
[ 0.1517, 0.2170, 0.1043, ..., 0.0293, -0.0429, -0.0475],
[-0.4100, -0.1924, -0.2400, ..., -0.0046, 0.0070, 0.0198]])
Absmax quantized weights:
tensor([[-21, -12, -4, ..., 2, -3, 1],
[ 4, 7, 11, ..., -2, -1, -1],
[ 0, 3, 16, ..., 5, 2, -1],
...,
[-12, -1, 9, ..., 0, -2, 1],
[ 7, 10, 5, ..., 1, -2, -2],
[-18, -9, -11, ..., 0, 0, 1]], dtype=torch.int8)
Zero-point quantized weights:
tensor([[-20, -11, -3, ..., 3, -2, 2],
[ 5, 8, 12, ..., -1, 0, 0],
[ 1, 4, 18, ..., 6, 3, 0],
...,
[-11, 0, 10, ..., 1, -1, 2],
[ 8, 11, 6, ..., 2, -1, -1],
[-18, -8, -10, ..., 1, 1, 2]], dtype=torch.int8)
The difference between the original (FP32) and quantized values (INT8) is clear, but the difference between the absmax and zero-point weights is more subtle. In this case, the absmax values look shifted by -1 compared to the zero-point ones. This suggests that the weight distribution in this layer is quite symmetric.
We can compare these techniques by quantizing every layer in GPT-2 (linear layers, attention layers, etc.) and creating two new models: model_abs and model_zp. To be precise, we will actually replace the original weights with de-quantized ones. This has two benefits: it allows us to 1/ compare the distribution of our weights (same scale) and 2/ actually run the models.
Indeed, PyTorch doesn’t allow INT8 matrix multiplication by default. In a real scenario, we would dequantize them to run the model (in FP16 for example) but store them as INT8. In the next section, we will use the bitsandbytes library to solve this issue.
import numpy as np
from copy import deepcopy
# Store original weights
weights = [param.data.clone() for param in model.parameters()]

# Create model to quantize
model_abs = deepcopy(model)

# Quantize all model weights
weights_abs = []
for param in model_abs.parameters():
    _, dequantized = absmax_quantize(param.data)
    param.data = dequantized
    weights_abs.append(dequantized)

# Create model to quantize
model_zp = deepcopy(model)

# Quantize all model weights
weights_zp = []
for param in model_zp.parameters():
    _, dequantized = zeropoint_quantize(param.data)
    param.data = dequantized
    weights_zp.append(dequantized)
Now that our models have been quantized, we want to check the impact of this process. Intuitively, we want to make sure that the quantized weights are close to the original ones. A visual way to check it is to plot the distribution of the dequantized and original weights. If the quantization is lossy, it would drastically change the weight distribution.
The following figure shows this comparison, where the blue histogram represents the original (FP32) weights, and the red one represents the dequantized (from INT8) weights. Note that we only display this plot between -2 and 2 because of outliers with very high absolute values (more on that later).
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
# Flatten weight tensors
weights = np.concatenate([t.cpu().numpy().flatten() for t in weights])
weights_abs = np.concatenate([t.cpu().numpy().flatten() for t in weights_abs])
weights_zp = np.concatenate([t.cpu().numpy().flatten() for t in weights_zp])

# Set background style
plt.style.use('ggplot')

# Create figure and axes
fig, axs = plt.subplots(2, figsize=(10,10), dpi=300, sharex=True)

# Plot the histograms for original and absmax weights
axs[0].hist(weights, bins=150, alpha=0.5, label='Original weights', color='blue', range=(-2, 2))
axs[0].hist(weights_abs, bins=150, alpha=0.5, label='Absmax weights', color='red', range=(-2, 2))

# Plot the histograms for original and zero-point weights
axs[1].hist(weights, bins=150, alpha=0.5, label='Original weights', color='blue', range=(-2, 2))
axs[1].hist(weights_zp, bins=150, alpha=0.5, label='Zero-point weights', color='green', range=(-2, 2))

# Add grid
for ax in axs:
    ax.grid(True, linestyle='--', alpha=0.6)

# Add legend
axs[0].legend()
axs[1].legend()

# Add title and labels
axs[0].set_title('Comparison of Original and Absmax Quantized Weights', fontsize=16)
axs[1].set_title('Comparison of Original and Zero-point Quantized Weights', fontsize=16)

for ax in axs:
    ax.set_xlabel('Weights', fontsize=14)
    ax.set_ylabel('Count', fontsize=14)
    # Make y-ticks more human readable
    ax.yaxis.set_major_formatter(ticker.EngFormatter())

# Improve font
plt.rc('font', size=12)

plt.tight_layout()
plt.show()
Both plots are quite similar, with a surprising spike around 0. This spike shows that our quantization is quite lossy since reversing the process doesn’t output the original values. This is particularly true for the absmax model, which displays both a lower valley and a higher spike around 0.
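To put a rough number on this loss, we can also compute a simple reconstruction error between the original weights and their dequantized counterparts, reusing the flattened arrays built for the plot above (a quick sanity check rather than a rigorous metric):
# Mean absolute error between original and dequantized weights
error_abs = np.mean(np.abs(weights - weights_abs))
error_zp = np.mean(np.abs(weights - weights_zp))

print(f"Absmax quantization error:     {error_abs:.6f}")
print(f"Zero-point quantization error: {error_zp:.6f}")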
Let’s compare the performance of the original and quantized models. For this purpose, we define a generate_text() function to generate 50 tokens with top-k sampling.
def generate_text(model, input_text, max_length=50):
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    output = model.generate(inputs=input_ids,
                            max_length=max_length,
                            do_sample=True,
                            top_k=30,
                            pad_token_id=tokenizer.eos_token_id,
                            attention_mask=input_ids.new_ones(input_ids.shape))
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Generate text with original and quantized models
original_text = generate_text(model, "I have a dream")
absmax_text = generate_text(model_abs, "I have a dream")
zp_text = generate_text(model_zp, "I have a dream")

print(f"Original model:\n{original_text}")
print("-" * 50)
print(f"Absmax model:\n{absmax_text}")
print("-" * 50)
print(f"Zeropoint model:\n{zp_text}")
Original model:
I have a dream, and it is a dream I believe I would get to live in my future. I love my mother, and there was that one time I had been told that my family wasn't even that strong. And then I got the
--------------------------------------------------
Absmax model:
I have a dream to find out the origin of her hair. She loves it. But there's no way you could be honest about how her hair is made. She must be crazy.
We found a photo of the hairstyle posted on
--------------------------------------------------
Zeropoint model:
I have a dream of creating two full-time jobs in America—one for people with mental health issues, and one for people who do not suffer from mental illness—or at least have an employment and family history of substance abuse, to work part
Instead of trying to see if one output makes more sense than the others, we can quantify it by calculating the perplexity of each output. This is a common metric used to evaluate language models, which measures the uncertainty of a model in predicting the next token in a sequence. In this comparison, we make the common assumption that the lower the score, the better the model is. In practice, a sentence with a high perplexity could also be correct.
Since our sentences are short, we implement it with a minimal function that doesn’t need to consider details like the length of the context window.
def calculate_perplexity(model, text):
    # Encode the text
    encodings = tokenizer(text, return_tensors='pt').to(device)

    # Define input_ids and target_ids
    input_ids = encodings.input_ids
    target_ids = input_ids.clone()

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

    # Loss calculation
    neg_log_likelihood = outputs.loss

    # Perplexity calculation
    ppl = torch.exp(neg_log_likelihood)

    return ppl
ppl = calculate_perplexity(model, original_text)
ppl_abs = calculate_perplexity(model_abs, absmax_text)
ppl_zp = calculate_perplexity(model_zp, zp_text)

print(f"Original perplexity: {ppl.item():.2f}")
print(f"Absmax perplexity: {ppl_abs.item():.2f}")
print(f"Zeropoint perplexity: {ppl_zp.item():.2f}")
Original perplexity: 15.53
Absmax perplexity: 17.92
Zeropoint perplexity: 17.97
We see that the perplexity of the original model is slightly lower than the two others. A single experiment is not very reliable, but we could repeat this process multiple times to average out the differences between the models, as sketched below. In theory, zero-point quantization should be slightly better than absmax, but it is also more costly to compute.
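A repeated comparison could look like the following sketch, which reuses the generate_text() and calculate_perplexity() helpers defined earlier (a toy protocol for illustration, not a rigorous benchmark):
# Average perplexity over several sampled generations per model
def mean_perplexity(model, prompt="I have a dream", n_runs=10):
    ppls = []
    for _ in range(n_runs):
        text = generate_text(model, prompt)
        ppls.append(calculate_perplexity(model, text).item())
    return sum(ppls) / len(ppls)

for name, m in [("Original", model), ("Absmax", model_abs), ("Zero-point", model_zp)]:
    print(f"{name} mean perplexity: {mean_perplexity(m):.2f}")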
In this example, we applied quantization techniques to entire layers (per-tensor basis). However, we could apply them at different granularity levels: from the entire model down to individual values. Quantizing the entire model in one pass would seriously degrade the performance, while quantizing individual values would create a big overhead. In practice, we often prefer vector-wise quantization, which considers the variability of values in rows and columns within the same tensor.
However, even vector-wise quantization doesn’t solve the problem of outlier features. Outlier features are extreme values (negative or positive) that appear in all transformer layers when the model reaches a certain scale (>6.7B parameters). This is an issue because a single outlier can reduce the precision of all the other values. But discarding these outlier features is not an option, since it would greatly degrade the model’s performance.
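To illustrate what a finer granularity looks like, here is a minimal sketch of per-row (vector-wise) absmax quantization, where every row of a weight matrix gets its own scale instead of sharing a single scale for the whole tensor. This is an illustrative variant of the absmax_quantize function above, not the exact scheme used by any particular library:
def absmax_quantize_rowwise(X):
    # One scale per row instead of one scale for the entire tensor
    scale = 127 / torch.max(torch.abs(X), dim=1, keepdim=True).values
    X_quant = (scale * X).round()
    X_dequant = X_quant / scale
    return X_quant.to(torch.int8), X_dequant
With this scheme, a large value only affects the scale of its own row rather than degrading the precision of the entire tensor.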
🔢 8-bit Quantization with LLM.int8()
Introduced by Dettmers et al. (2022), LLM.int8() is a solution to the outlier problem. It relies on a vector-wise (absmax) quantization scheme and introduces mixed-precision quantization. This means that outlier features are processed in an FP16 format to retain their precision, while the other values are processed in an INT8 format. As outliers represent about 0.1% of values, this effectively reduces the memory footprint of the LLM by almost 2x.
LLM.int8() works by conducting matrix multiplication computation in three key steps:
- Extract columns from the input hidden states \(\mathbf{X}\) containing outlier features using a custom threshold.
- Perform the matrix multiplication of the outliers using FP16 and the non-outliers using INT8 with vector-wise quantization (row-wise for the hidden state \(\mathbf{X}\) and column-wise for the weight matrix \(\mathbf{W}\)).
- Dequantize the non-outlier results (INT8 to FP16) and add them to the outlier results to get the full result in FP16.
This approach is necessary because 8-bit precision is limited and can lead to substantial errors when quantizing a vector with large values. These errors also tend to amplify as they propagate through multiple layers.
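To make these three steps more tangible, here is a toy sketch of the decomposition applied to a plain matrix multiplication. It only simulates INT8 with rounded float values (since, as noted above, PyTorch doesn’t run INT8 matrix multiplications natively) and glosses over many details of the real bitsandbytes kernels; the threshold of 6.0 mirrors the default reported in the paper:
def mixed_precision_matmul(X, W, threshold=6.0):
    # 1. Columns of the hidden states X that contain at least one outlier feature
    outlier_cols = (X.abs() > threshold).any(dim=0)

    # 2a. Outlier columns are multiplied in higher precision
    out_fp = X[:, outlier_cols] @ W[outlier_cols, :]

    # 2b. Remaining columns: vector-wise absmax quantization
    #     (row-wise for X, column-wise for W), simulated INT8 matmul, then rescale
    X_sub, W_sub = X[:, ~outlier_cols], W[~outlier_cols, :]
    x_scale = 127 / X_sub.abs().max(dim=1, keepdim=True).values
    w_scale = 127 / W_sub.abs().max(dim=0, keepdim=True).values
    out_int8 = ((X_sub * x_scale).round() @ (W_sub * w_scale).round()) / (x_scale * w_scale)

    # 3. Combine both partial results
    return out_fp + out_int8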
We can easily use this technique thanks to the integration of the bitsandbytes library into the Hugging Face ecosystem. We just need to specify load_in_8bit=True when loading the model (it also requires a GPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_int8 = AutoModelForCausalLM.from_pretrained(model_id,
                                                  device_map='auto',
                                                  load_in_8bit=True,
                                                  )

print(f"Model size: {model_int8.get_memory_footprint():,} bytes")
Model size: 176,527,896 bytes
With this extra line of code, the model is now almost three times smaller (168MB vs. 487MB). We can even compare the distribution of the original and quantized weights as we did earlier:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
# Flatten weight tensors
weights_int8 = [param.data.clone() for param in model_int8.parameters()]
weights_int8 = np.concatenate([t.cpu().numpy().flatten() for t in weights_int8])

# Set background style
plt.style.use('ggplot')

# Create figure and axis
fig, ax = plt.subplots(figsize=(10,5), dpi=300)

# Plot the histograms
ax.hist(weights, bins=150, alpha=0.5, label='Original weights',
        color='blue', range=(-2, 2))
ax.hist(weights_int8, bins=150, alpha=0.5, label='LLM.int8() weights',
        color='red', range=(-2, 2))

# Add grid
ax.grid(True, linestyle='--', alpha=0.6)

# Add legend
ax.legend()

# Add title and labels
ax.set_title('Comparison of Original and Dequantized Weights', fontsize=16)
ax.set_xlabel('Weights', fontsize=14)
ax.set_ylabel('Count', fontsize=14)
plt.gca().yaxis.set_major_formatter(ticker.EngFormatter())

# Improve font
plt.rc('font', size=12)

plt.tight_layout()
plt.show()
In this case, we see spikes around -2, -1, 0, 1, 2, etc. These values correspond to the parameters stored in the INT8 format (non-outliers). You can verify it by printing the model’s weights using model_int8.parameters().
We can also generate text with this quantized model and compare it to the original model.
# Generate text with quantized model
text_int8 = generate_text(model_int8, "I have a dream")

print(f"Original model:\n{original_text}")
print("-" * 50)
print(f"LLM.int8() model:\n{text_int8}")
Original model:
I have a dream, and it is a dream I believe I would get to live in my future. I love my mother, and there was that one time I had been told that my family wasn't even that strong. And then I got the
--------------------------------------------------
LLM.int8() model:
I have a dream. I don't know what will come of it, but I am going to have to look for something that will be right. I haven't thought about it for a long time, but I have to try to get that thing
Once again, it is difficult to judge which output is best, but we can rely on the perplexity metric to give us an (approximate) answer.
print(f"Perplexity (original): {ppl.item():.2f}")
= calculate_perplexity(model_int8, text_int8)
ppl print(f"Perplexity (LLM.int8()): {ppl.item():.2f}")
Perplexity (original): 15.53
Perplexity (LLM.int8()): 7.93
In this case, the perplexity of the quantized model is half that of the original one. In general, this is not the case, but it shows that this quantization technique is very competitive. In fact, the authors of LLM.int8() show that the performance degradation is so low it’s negligible (<1%). However, it comes with an additional computational cost: LLM.int8() is roughly 20% slower for large models.
Conclusion
This article provided an overview of the most popular weight quantization techniques. We started by gaining an understanding of floating point representation, before introducing two techniques for 8-bit quantization: absmax and zero-point quantization. However, their limitations, particularly when it comes to handling outliers, led to LLM.int8(), a technique that also preserves the model’s performance. This approach underlines the progress being made in the field of weight quantization, revealing the importance of properly addressing outliers.
Looking forward, our next article will explore the GPTQ weight quantization technique in depth. This technique, introduced by Frantar et al., only utilizes 4 bits and represents a significant advancement in the field of weight quantization. We will provide a comprehensive guide on how to implement GPTQ using the AutoGPTQ library.
If you’re interested in more technical content around LLMs, follow me on Twitter @maximelabonne.
References
- T. Dettmers, M. Lewis, Y. Belkada, and L. Zettlemoyer, LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale. 2022.
- Y. Beldaka, and T. Dettmers, A Gentle Introduction to 8-bit Matrix Multiplication, Hugging Face Blog (2022).
- A. Gholami, S. Kim, Z. Dong, Z. Yao, M. W. Mahoney, and K. Keutzer, A Survey of Quantization Methods for Efficient Neural Network Inference. 2021.
- H. Wu, P. Judd, X. Zhang, M. Isaev, and P. Micikevicius, Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation. 2020.
- Lilian Weng, Large Transformer Model Inference Optimization, Lil’Log (2023).
- Kamil Czarnogorski, Local Large Language Models, Int8 (2023).