Matrix Multiplication: Unleashing the Power of Tensors! ⚡
"Behold! The sacred art of matrix multiplication - where dimensions dance and vectors bend to my will!" — Professor Victor py Torchenstein
The Attention Formula (Preview of Things to Come)
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
Where:
- $Q$ is the Query matrix
- $K$ is the Key matrix
- $V$ is the Value matrix
- $d_k$ is the dimension of the key vectors
- $\text{softmax}$, applied row-wise, turns the scaled scores into attention weights that sum to 1 for each query
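As a preview, the formula translates almost line for line into PyTorch. This is a minimal sketch that assumes 2D inputs (no batch or head dimensions) and no masking. `attention` is a hypothetical helper name, not a PyTorch API.

```python
import math
import torch

def attention(Q, K, V):
    """Hypothetical helper: softmax(Q K^T / sqrt(d_k)) V for 2D inputs, no masking."""
    d_k = K.shape[-1]
    # (n_queries, d_k) @ (d_k, n_keys) -> (n_queries, n_keys) similarity scores
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    # Row-wise softmax: each query's weights over the keys sum to 1
    weights = torch.softmax(scores, dim=-1)
    # Weighted sum of value vectors: (n_queries, n_keys) @ (n_keys, d_v)
    return weights @ V, weights

torch.manual_seed(0)
Q = torch.randn(5, 8)   # 5 queries, d_k = 8
K = torch.randn(7, 8)   # 7 keys,    d_k = 8
V = torch.randn(7, 16)  # 7 values,  d_v = 16
out, weights = attention(Q, K, V)
print(out.shape)  # torch.Size([5, 16])
print(weights.sum(dim=-1))
```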
Basic Matrix Operations
Let's start with the fundamentals before we conquer attention mechanisms!
Element-wise multiplication (both matrices must have the same shape): $C_{ij} = A_{ij} \times B_{ij}$
Matrix multiplication: $C_{ij} = \sum_{k} A_{ik} \times B_{kj}$
```python
import torch

# Create some matrices for experimentation
A = torch.randn(3, 4)
B = torch.randn(4, 2)
print("Matrix A shape:", A.shape)
print("Matrix B shape:", B.shape)

# Matrix multiplication: (3, 4) @ (4, 2) -> (3, 2)
C = torch.matmul(A, B)
print("Result C shape:", C.shape)
print("\nMwahahaha! The matrices have been multiplied!")
```
```
Matrix A shape: torch.Size([3, 4])
Matrix B shape: torch.Size([4, 2])
Result C shape: torch.Size([3, 2])

Mwahahaha! The matrices have been multiplied!
```
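The element-wise and matrix formulas above are easy to mix up. This small sketch uses arbitrary 2x2 values you can check by hand to show the difference between `*` and `@`:

```python
import torch

A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])

# Element-wise (Hadamard) product: C_ij = A_ij * B_ij, so shapes must match (or broadcast)
elementwise = A * B
print(elementwise)  # values: [[5, 12], [21, 32]]

# Matrix product: C_ij = sum_k A_ik * B_kj, so A's columns must match B's rows
matrix_product = A @ B
print(matrix_product)  # values: [[19, 22], [43, 50]]

# A (3, 4) and a (4, 2) matrix can be matrix-multiplied but not multiplied element-wise
try:
    torch.randn(3, 4) * torch.randn(4, 2)
except RuntimeError as err:
    print("element-wise product failed:", err)
```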
PyTorch Matrix Multiplication Methods
Professor Torchenstein's arsenal includes multiple ways to multiply matrices:
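- `torch.matmul()` - The general matrix multiplication function (handles 1D, 2D, and batched inputs, with broadcasting)
- `@` operator - Pythonic matrix multiplication (same as `matmul`)
- `torch.mm()` - For 2D matrices only
- `torch.bmm()` - Batch matrix multiplication of 3D tensors with shapes `(b, n, m)` and `(b, m, p)`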
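This sketch puts the four methods side by side. The shapes are arbitrary choices. It shows that the first three agree on 2D inputs, and that only `matmul` and `bmm` accept batches.

```python
import torch

A = torch.randn(3, 4)
B = torch.randn(4, 2)

# For 2D inputs, matmul, @ and mm all produce the same (3, 2) result
C1 = torch.matmul(A, B)
C2 = A @ B
C3 = torch.mm(A, B)
print(torch.allclose(C1, C2), torch.allclose(C1, C3))  # True True

# bmm expects 3D tensors: (batch, n, m) @ (batch, m, p) -> (batch, n, p)
A_batch = torch.randn(10, 3, 4)
B_batch = torch.randn(10, 4, 2)
print(torch.bmm(A_batch, B_batch).shape)  # torch.Size([10, 3, 2])

# matmul also broadcasts: one (4, 2) matrix is applied to every batch entry
print(torch.matmul(A_batch, B).shape)  # torch.Size([10, 3, 2])

# mm rejects anything that is not 2D
try:
    torch.mm(A_batch, B)
except RuntimeError as err:
    print("torch.mm failed:", err)
```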
Mathematical Foundations
For matrices $A \in \mathbb{R}^{m \times n}$ and $B \in \mathbb{R}^{n \times p}$:
$$C = AB \quad \text{where} \quad C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}$$
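The summation can be transcribed directly into three nested loops. This is a deliberately slow sketch for illustration only. `matmul_loops` is a hypothetical helper, and its result is checked against `@`.

```python
import torch

def matmul_loops(A, B):
    """Hypothetical helper: C_ij = sum_k A_ik * B_kj, written out literally (slow)."""
    m, n = A.shape
    n_b, p = B.shape
    assert n == n_b, f"inner dimensions must match, got {n} and {n_b}"
    C = torch.zeros(m, p, dtype=A.dtype)
    for i in range(m):          # rows of A
        for j in range(p):      # columns of B
            for k in range(n):  # shared inner dimension
                C[i, j] += A[i, k] * B[k, j]
    return C

A = torch.randn(3, 4)
B = torch.randn(4, 2)
print(torch.allclose(matmul_loops(A, B), A @ B, atol=1e-5))  # True
```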
This operation is fundamental to:
- Linear transformations
- Neural network forward passes
- Attention mechanisms in Transformers
- And much more! 🧠⚡
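As one concrete instance of the first two bullets, a `torch.nn.Linear` layer is a matrix multiplication plus a bias. The layer sizes below are arbitrary.

```python
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(in_features=4, out_features=2)
x = torch.randn(3, 4)  # a batch of 3 inputs

# nn.Linear stores weight as (out_features, in_features) and computes x @ W^T + b
manual = x @ linear.weight.T + linear.bias
print(torch.allclose(linear(x), manual, atol=1e-6))  # True
```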
Floats and Integers on CPU and GPU
Now we will compare the speed of matrix multiplication across floating-point types (`float32`, `float16`, `bfloat16`) and integer types (`int32`, `int16`, `int8`), first on the CPU and then on the GPU.
```python
import time
import torch

# 1. Our "Model" and "Input": one 1024x1024 matrix, multiplied by itself in each dtype
fp32_matrix = torch.randn(2**10, 2**10)
fp16_matrix = fp32_matrix.to(torch.float16)
bf16_matrix = fp32_matrix.to(torch.bfloat16)

# 2. The Symmetric Quantization Spell (zero_point = 0)
def quantize_symmetric(tensor, dtype=torch.int8):
    '''
    Quantize a tensor to the desired integer dtype.
    The scale maps the tensor's absolute max value to the dtype's largest value.
    '''
    assert dtype in [torch.int8, torch.int16, torch.int32]
    int_max = torch.iinfo(dtype).max
    int_min = torch.iinfo(dtype).min
    # Find the scale: map the absolute max of the tensor to the quantization range
    scale = tensor.abs().max() / int_max
    # Quantize in float64 so int32's max value is exactly representable before the cast
    quantized_tensor = (tensor.double() / scale).round().clamp(int_min, int_max).to(dtype)
    return quantized_tensor, scale

# 3. Quantize the same fp32 matrix into each integer dtype (each gets its own scale)
int8_quantized_matrix, scale_int8 = quantize_symmetric(fp32_matrix, dtype=torch.int8)
int16_quantized_matrix, scale_int16 = quantize_symmetric(fp32_matrix, dtype=torch.int16)
int32_quantized_matrix, scale_int32 = quantize_symmetric(fp32_matrix, dtype=torch.int32)

print(f"Original fp32 matrix size: {fp32_matrix.element_size() * fp32_matrix.nelement() / 1024**2:.2f} MB")
print(f"fp16 matrix size: {fp16_matrix.element_size() * fp16_matrix.nelement() / 1024**2:.2f} MB")
print(f"bf16 matrix size: {bf16_matrix.element_size() * bf16_matrix.nelement() / 1024**2:.2f} MB")
print(f"Quantized int8 matrix size: {int8_quantized_matrix.element_size() * int8_quantized_matrix.nelement() / 1024**2:.2f} MB ")
print(f"Quantized int16 matrix size: {int16_quantized_matrix.element_size() * int16_quantized_matrix.nelement() / 1024**2:.2f} MB ")
print(f"Quantized int32 matrix size: {int32_quantized_matrix.element_size() * int32_quantized_matrix.nelement() / 1024**2:.2f} MB ")

# 4. Perform and time the operations (single runs, no warm-up)
# --- Device Information ---
cpu_device = torch.device("cpu")
print("\n--- CPU Timings ---")

# --- CPU Operations ---
start_time = time.time()
fp32_output = torch.matmul(fp32_matrix, fp32_matrix)
fp32_time = time.time() - start_time
print(f"Float32 matmul on CPU took: {fp32_time:.6f} seconds")

start_time = time.time()
fp16_output = torch.matmul(fp16_matrix, fp16_matrix)
fp16_time = time.time() - start_time
print(f"Float16 matmul on CPU took: {fp16_time:.6f} seconds")

start_time = time.time()
bf16_output = torch.matmul(bf16_matrix, bf16_matrix)
bf16_time = time.time() - start_time
print(f"BFloat16 matmul on CPU took: {bf16_time:.6f} seconds")

# Integer matrix multiplication. The result keeps the input dtype, so these sums
# do not fit and overflow: the outputs are only meaningful as a timing benchmark.
start_time = time.time()
int8_output = torch.matmul(int8_quantized_matrix, int8_quantized_matrix)
int8_time = time.time() - start_time
print(f"Int8 matmul on CPU took: {int8_time:.6f} seconds")

start_time = time.time()
int16_output = torch.matmul(int16_quantized_matrix, int16_quantized_matrix)
int16_time = time.time() - start_time
print(f"Int16 matmul on CPU took: {int16_time:.6f} seconds")

start_time = time.time()
int32_output = torch.matmul(int32_quantized_matrix, int32_quantized_matrix)
int32_time = time.time() - start_time
print(f"Int32 matmul on CPU took: {int32_time:.6f} seconds")
```
```
Original fp32 matrix size: 4.00 MB
fp16 matrix size: 2.00 MB
bf16 matrix size: 2.00 MB
Quantized int8 matrix size: 1.00 MB
Quantized int16 matrix size: 2.00 MB
Quantized int32 matrix size: 4.00 MB

--- CPU Timings ---
Float32 matmul on CPU took: 0.009423 seconds
Float16 matmul on CPU took: 2.737493 seconds
BFloat16 matmul on CPU took: 3.497367 seconds
Int8 matmul on CPU took: 0.499707 seconds
Int16 matmul on CPU took: 0.545151 seconds
Int32 matmul on CPU took: 0.543347 seconds
```
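The integer timings above measure speed only, because an `int8 @ int8` result is stored as `int8` and the sums overflow. A typical quantized matmul follows three steps:

1. Quantize each operand with its own scale.
2. Multiply the integers with a wider accumulator.
3. Rescale the result by both scales.

This is a minimal sketch that assumes the same symmetric per-tensor scheme as `quantize_symmetric` above, redefined so the block runs on its own. The 256x256 size is arbitrary.

```python
import torch

def quantize_symmetric(tensor, dtype=torch.int8):
    # Same symmetric scheme as above: map max |x| to the dtype's largest value
    int_max = torch.iinfo(dtype).max
    int_min = torch.iinfo(dtype).min
    scale = tensor.abs().max() / int_max
    q = (tensor / scale).round().clamp(int_min, int_max).to(dtype)
    return q, scale

torch.manual_seed(0)
A = torch.randn(256, 256)
B = torch.randn(256, 256)

# Quantize each operand independently, each with its own scale
qA, sA = quantize_symmetric(A)
qB, sB = quantize_symmetric(B)

# int8 @ int8 keeps the int8 dtype, so the large sums cannot fit
overflowed = qA @ qB
print(overflowed.dtype)  # torch.int8

# Widen to int32 first: |sum| <= 256 * 127 * 127, far below the int32 limit
acc = qA.to(torch.int32) @ qB.to(torch.int32)

# Dequantize: A ~ sA * qA and B ~ sB * qB, so A @ B ~ sA * sB * (qA @ qB)
approx = acc.to(torch.float32) * (sA * sB)
exact = A @ B
rel_err = ((approx - exact).norm() / exact.norm()).item()
print(f"relative error: {rel_err:.4f}")
```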
The Mystery of the Slow Float16 and BFloat16 Matrix Multiplication on CPU
You've stumbled upon a crucial secret of computational alchemy: performance isn't just about size, it's about the machine's native language!
You're seeing these surprising results because of how your CPU is designed.
- `float32` (The Native Tongue): Your CPU is a fluent, native speaker of `float32`. It has highly optimized SIMD instructions (like SSE and AVX) built specifically to perform these calculations at blistering speeds, and PyTorch hands `float32` matrix multiplications to heavily tuned BLAS libraries. It's the language it thinks in.
- `float16` & `bfloat16` (A Foreign Language): Most consumer CPUs do not speak `float16` or `bfloat16` natively. When you command one to perform a matrix multiplication with these types, it has to do a lot of extra work behind the scenes:
  - It takes a `float16` number.
  - It painstakingly translates (casts) it up to `float32`.
  - It performs the calculation in `float32`.
  - It translates (casts) the result back down to `float16`.

  This process, a form of software emulation, is extremely slow. In the run above, `float16` was about 290 times slower than `float32`, and `bfloat16` about 370 times slower. It's like asking a brilliant English-speaking scientist to solve a complex physics problem written entirely in a language they don't know, forcing them to use a dictionary for every single word. The overhead is enormous!
- Integers (`int8`, `int16`, `int32`): Your CPU is also very good at integer math, and these types are natively supported. Yet in this run they were still about 50 to 60 times slower than `float32`. The gap comes from software, not silicon. PyTorch sends `float32` matrix multiplications to highly optimized BLAS libraries (like Intel's MKL), whose standard routines cover only floating-point types. Integer `torch.matmul` therefore falls back to a much less optimized generic kernel.
The Grand Takeaway
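float16 and bfloat16 shine on GPUs. Modern NVIDIA GPUs have Tensor Cores, specialized hardware built for exactly this kind of low-precision, high-throughput math: Volta and newer support float16, and Ampere and newer (such as the A100 or H100) also support bfloat16. On these GPUs, matrix multiplications in these types can run dramatically faster than in float32. A few recent server CPUs (for example, Intel Xeons with AVX-512 BF16 or AMX) also accelerate bfloat16, so the CPU results above are typical but not universal.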
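The GPU timings later in this notebook are single runs with no warm-up. They likely include one-off costs such as CUDA and cuBLAS initialization, so they are not a fair dtype comparison. To test the claim on your own hardware, here is a sketch using `torch.utils.benchmark.Timer`, which handles warm-up and CUDA synchronization. The 4096x4096 GPU size and 512x512 CPU fallback are arbitrary assumptions.

```python
import torch
import torch.utils.benchmark as benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
# Large matrices keep a GPU busy; a small size keeps the CPU fallback quick
n = 4096 if device == "cuda" else 512
dtypes = [torch.float32, torch.float16, torch.bfloat16] if device == "cuda" else [torch.float32]

results = {}
for dtype in dtypes:
    try:
        x = torch.randn(n, n, device=device).to(dtype)
        # Timer warms up, repeats the statement, and synchronizes CUDA when needed
        timer = benchmark.Timer(stmt="x @ x", globals={"x": x})
        measurement = timer.blocked_autorange(min_run_time=0.5)
    except RuntimeError as err:  # e.g. a dtype this GPU cannot multiply
        print(f"{dtype} skipped: {err}")
        continue
    results[dtype] = measurement.median
    print(f"{device} {dtype}: {measurement.median * 1e3:.3f} ms per matmul")
```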
You've learned a vital lesson: always match your data type to your hardware's strengths! On most CPUs, float32 is king for floating-point operations.
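One way to put that lesson into code is a small dtype-selection helper. This is a minimal sketch: `pick_matmul_dtype` is hypothetical, and its thresholds assume NVIDIA compute capability 7.x (Volta/Turing) for float16 Tensor Cores and 8.x (Ampere) or newer for bfloat16.

```python
import torch

def pick_matmul_dtype(device):
    """Hypothetical helper: choose a floating-point dtype likely to be fast on this device."""
    if device.type == "cuda":
        major, _minor = torch.cuda.get_device_capability(device)
        if major >= 8:
            return torch.bfloat16  # assumption: Ampere (sm_80) and newer have bf16 Tensor Cores
        if major >= 7:
            return torch.float16   # assumption: Volta/Turing have fp16 but not bf16 Tensor Cores
        return torch.float32       # older GPUs: no Tensor Cores
    return torch.float32           # CPUs: stay on the optimized float32 path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = pick_matmul_dtype(device)
x = torch.randn(512, 512, device=device, dtype=dtype)
print(device, dtype, (x @ x).dtype)
```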
PyTorch Doesn't Support Integer Matrix Multiplication on CUDA
```python
import time
import torch

# 1. Our "Model" and "Input": one 1024x1024 matrix, multiplied by itself in each dtype
fp32_matrix = torch.randn(2**10, 2**10)
fp16_matrix = fp32_matrix.to(torch.float16)
bf16_matrix = fp32_matrix.to(torch.bfloat16)

# 2. The Symmetric Quantization Spell (zero_point = 0)
def quantize_symmetric(tensor, dtype=torch.int8):
    '''
    Quantize a tensor to the desired integer dtype.
    The scale maps the tensor's absolute max value to the dtype's largest value.
    '''
    assert dtype in [torch.int8, torch.int16, torch.int32]
    int_max = torch.iinfo(dtype).max
    int_min = torch.iinfo(dtype).min
    # Find the scale: map the absolute max of the tensor to the quantization range
    scale = tensor.abs().max() / int_max
    # Quantize in float64 so int32's max value is exactly representable before the cast
    quantized_tensor = (tensor.double() / scale).round().clamp(int_min, int_max).to(dtype)
    return quantized_tensor, scale

# 3. Quantize the same fp32 matrix into each integer dtype (each gets its own scale)
int8_quantized_matrix, scale_int8 = quantize_symmetric(fp32_matrix, dtype=torch.int8)
int16_quantized_matrix, scale_int16 = quantize_symmetric(fp32_matrix, dtype=torch.int16)
int32_quantized_matrix, scale_int32 = quantize_symmetric(fp32_matrix, dtype=torch.int32)

# --- GPU Operations ---
if torch.cuda.is_available():
    gpu_device = torch.device("cuda")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
else:
    gpu_device = None
    print("GPU not available.")

if gpu_device is not None:
    print("\n--- GPU Timings ---")
    # Move matrices to GPU
    fp32_matrix_gpu = fp32_matrix.to(gpu_device)
    fp16_matrix_gpu = fp16_matrix.to(gpu_device)
    bf16_matrix_gpu = bf16_matrix.to(gpu_device)
    int8_quantized_matrix_gpu = int8_quantized_matrix.to(gpu_device)
    int16_quantized_matrix_gpu = int16_quantized_matrix.to(gpu_device)
    int32_quantized_matrix_gpu = int32_quantized_matrix.to(gpu_device)

    # CUDA kernels run asynchronously, so time them with CUDA events
    # and synchronize before reading the elapsed time
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # FP32 on GPU
    start.record()
    fp32_output_gpu = torch.matmul(fp32_matrix_gpu, fp32_matrix_gpu)
    end.record()
    torch.cuda.synchronize()
    fp32_time_gpu = start.elapsed_time(end) / 1000  # Convert ms to s
    print(f"Float32 matmul on GPU took: {fp32_time_gpu:.6f} seconds")

    # FP16 on GPU
    start.record()
    fp16_output_gpu = torch.matmul(fp16_matrix_gpu, fp16_matrix_gpu)
    end.record()
    torch.cuda.synchronize()
    fp16_time_gpu = start.elapsed_time(end) / 1000
    print(f"Float16 matmul on GPU took: {fp16_time_gpu:.6f} seconds")

    # BF16 on GPU
    # Note: BF16 performance depends on GPU architecture (Ampere and newer)
    try:
        start.record()
        bf16_output_gpu = torch.matmul(bf16_matrix_gpu, bf16_matrix_gpu)
        end.record()
        torch.cuda.synchronize()
        bf16_time_gpu = start.elapsed_time(end) / 1000
        print(f"BFloat16 matmul on GPU took: {bf16_time_gpu:.6f} seconds")
    except Exception as e:
        print(f"BFloat16 matmul on GPU failed: {e}")

    # INT32 on GPU (raises RuntimeError: CUDA has no integer addmm kernel)
    start.record()
    int32_output_gpu = torch.mm(int32_quantized_matrix_gpu, int32_quantized_matrix_gpu)
    end.record()
    torch.cuda.synchronize()
    int32_time_gpu = start.elapsed_time(end) / 1000
    print(f"Int32 matmul on GPU took: {int32_time_gpu:.6f} seconds")

    # INT16 on GPU
    start.record()
    int16_output_gpu = torch.matmul(int16_quantized_matrix_gpu, int16_quantized_matrix_gpu)
    end.record()
    torch.cuda.synchronize()
    int16_time_gpu = start.elapsed_time(end) / 1000
    print(f"Int16 matmul on GPU took: {int16_time_gpu:.6f} seconds")

    # INT8 on GPU
    start.record()
    int8_output_gpu = torch.matmul(int8_quantized_matrix_gpu, int8_quantized_matrix_gpu)
    end.record()
    torch.cuda.synchronize()
    int8_time_gpu = start.elapsed_time(end) / 1000
    print(f"Int8 matmul on GPU took: {int8_time_gpu:.6f} seconds")
```
```
GPU Name: NVIDIA GeForce RTX 3080 Laptop GPU

--- GPU Timings ---
Float32 matmul on GPU took: 0.107393 seconds
Float16 matmul on GPU took: 0.234943 seconds
BFloat16 matmul on GPU took: 0.116011 seconds
```
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 84
     82 # INT32 on GPU
     83 start.record()
---> 84 int32_output_gpu = torch.mm(int32_quantized_matrix_gpu, int32_quantized_matrix_gpu)
     85 end.record()
     86 torch.cuda.synchronize()

RuntimeError: "addmm_cuda" not implemented for 'Int'
```
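The error shows that this PyTorch build has no CUDA `addmm` kernel for integer tensors, even though the same integer matmuls ran on the CPU earlier. One workaround is to run the integer product as a `float32` matmul. This is exact as long as every partial sum stays within the range where `float32` represents integers exactly (up to 2^24). This is a sketch rather than a production recipe. The bound below is specific to int8 values in [-127, 127] with an inner dimension of 1024, and the random operands are illustrative.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
qa = torch.randint(-127, 128, (1024, 1024), dtype=torch.int8)
qb = torch.randint(-127, 128, (1024, 1024), dtype=torch.int8)

# Reference result on the CPU, accumulated in int64 so nothing overflows
reference = qa.to(torch.int64) @ qb.to(torch.int64)

# Every partial sum is bounded by 1024 * 127 * 127 = 16,516,096 < 2**24,
# so float32 holds each intermediate integer exactly
assert 1024 * 127 * 127 < 2**24

result = (qa.to(device, torch.float32) @ qb.to(device, torch.float32)).to(torch.int32)
print(torch.equal(result.cpu().to(torch.int64), reference))  # True
```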