ViT Made Easy (Draft)

A deep dive into ‘An Image is Worth 16x16 Words’
Vision
Author: Mohit Gupta

Published: June 4, 2025

Modified: June 10, 2025

In “Attention is All You Need”, attention is defined in terms of Query, Key, and Value matrices, each calculated from the input tokens through a learnable linear layer.

  1. The first step in attention is the calculation of the Q, K, V values:

Note: This projection changes the length of each token from the token_len or dim (e.g. 49) to the channels or chan parameter (e.g. 64); this is why the bottom dimension of the Q, K, V matrices is labelled “Projected Length of Tokens”.
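
As a quick illustration of that projection, here is a minimal sketch. The shapes 49 and 64 are just the example values used throughout this post, and splitting with chunk is an illustrative shortcut; the actual module below splits via a reshape.

import torch
import torch.nn as nn

tokens = torch.rand(1, 100, 49)              # (batch, num_tokens, token_len = dim)
qkv_layer = nn.Linear(49, 3 * 64)            # one linear layer produces Q, K, V together
q, k, v = qkv_layer(tokens).chunk(3, dim=-1)
print(q.shape, k.shape, v.shape)             # each is (1, 100, 64): tokens now have the projected length chan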

  2. Now calculate the attention values with the formula:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$

But we do it in steps:

  • Step-2A (Scaling): compute the scaling factor 1/sqrt(d_k), where d_k is the per-head token length (head_dim).
  • Step-2B (Scoring): multiply the scaled queries by the transposed keys, giving one num_tokens x num_tokens score matrix per head.
  • Step-2C (Softmax): normalize the scores along the last dimension.
  • Step-2D (Weighted sum): multiply the attention weights by V.

This results in the following shape: (batch, heads, num_tokens, head_dim).

Step-3: Concatenate the heads and project, giving (batch, num_tokens, chan).
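
Steps 2A-2D are just the formula above evaluated on the Q, K, V tensors. Here is a minimal, self-contained sketch with random tensors (the shapes are taken from the single-headed example later in this post):

import torch

# Illustrative shapes: (batch, heads, num_tokens, head_dim)
q = torch.rand(13, 1, 100, 64)
k = torch.rand(13, 1, 100, 64)
v = torch.rand(13, 1, 100, 64)

scale = q.shape[-1] ** -0.5                   # Step-2A: 1/sqrt(head_dim)
scores = (q * scale) @ k.transpose(-2, -1)    # Step-2B: (13, 1, 100, 100)
weights = scores.softmax(dim=-1)              # Step-2C
out = weights @ v                             # Step-2D
print(out.shape)                              # torch.Size([13, 1, 100, 64])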

Step-4: Skip Connection

Note: The shape of the new X can be different from the shape of the input X. So we use V for the skip connection instead (after flattening its attention-head dimension).
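
A quick shape check makes the point (illustrative tensors only): the input token length dim=49 differs from the output channel length chan=64, so adding the input directly would fail.

import torch

x_in  = torch.rand(13, 100, 49)   # (batch, num_tokens, dim)
x_out = torch.rand(13, 100, 64)   # (batch, num_tokens, chan), i.e. the shape after attention
try:
    _ = x_out + x_in              # last dimensions 64 vs 49 cannot broadcast
except RuntimeError as e:
    print('Skip connection with the input fails:', e)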

import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, 
                dim: int,
                chan: int,
                num_heads: int=1,
                qkv_bias: bool=False,
                qk_scale: float=None):

        """ Attention Module

            Args:
                dim (int): input size of a single token
                chan (int): resulting size of a single token (channels)
                num_heads(int): number of attention heads in MSA
                qkv_bias (bool): determines if the qkv layer learns an additive bias
                qk_scale (float or None): value to scale the queries and keys by;
                                    if None, queries and keys are scaled by ``head_dim ** -0.5``
        """

        super().__init__()

        ## Define Constants
        self.num_heads = num_heads
        self.chan = chan
        self.head_dim = self.chan // self.num_heads
        self.scale = qk_scale or self.head_dim ** -0.5
        assert self.chan % self.num_heads == 0, '"Chan" must be evenly divisible by "num_heads".'

        ## Define Layers
        self.qkv = nn.Linear(dim, chan * 3, bias=qkv_bias)
        #### Each token gets projected from starting length (dim) to channel length (chan) 3 times (for each Q, K, V)
        self.proj = nn.Linear(chan, chan)

    def forward(self, x):
        B, N, C = x.shape
        ## Dimensions: (batch, num_tokens, token_len)

        ## Calculate QKVs
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        #### Dimensions: (3, batch, heads, num_tokens, chan/num_heads = head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        ## Calculate Attention
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        #### Dimensions: (batch, heads, num_tokens, num_tokens)

        ## Attention Layer
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.chan)
        #### Dimensions: (batch, num_tokens, chan)

        ## Projection Layers
        x = self.proj(x)

        ## Skip Connection Layer
        v = v.transpose(1, 2).reshape(B, N, self.chan)
        x = v + x     
        #### Because the original x has a different size than the current x, use v for the skip connection

        return x

Single-Headed Attention

token_len = 49 # 7x7
channels = 64
num_tokens = 100
batch = 13

x = torch.rand(batch, num_tokens, token_len)
B, N, C = x.shape
print('Input dimensions are:', x.shape, '\n\t batchsize: ', x.shape[0], '\n\t num_tokens: ', x.shape[1], '\n\t token_len: ', x.shape[2])


# Attention Module
A = Attention(dim=token_len, chan=channels, num_heads=1, qkv_bias=False, qk_scale=None)
A.eval() # eval() switches the module to inference mode; as the last expression in the cell, it also prints the module architecture below
Input dimensions are: torch.Size([13, 100, 49]) 
     batchsize:  13 
     num_tokens:  100 
     token_len:  49
Attention(
  (qkv): Linear(in_features=49, out_features=192, bias=False)
  (proj): Linear(in_features=64, out_features=64, bias=True)
)
# Step 1: Calculate Q, K, V
qkv = A.qkv(x)
print(qkv.shape)
print('192 is basically 3x64, 3 is because of q, k, v')
print('')
# Reshape qkv to get q, k, v
print('-- Reshaping --')
qkv = qkv.reshape(B, N, 3, A.num_heads, A.head_dim).permute(2, 0, 3, 1, 4)
print(qkv.shape)
print('\t3: for q, k, v\n\t13: batchsize\n\t1: num_heads\n\t100: num_tokens\n\t64: head_dim=(channels/num_heads) = 64/1 = 64')
torch.Size([13, 100, 192])
192 is basically 3x64, 3 is because of q, k, v

-- Reshaping --
torch.Size([3, 13, 1, 100, 64])
    3: for q, k, v
    13: batchsize
    1: num_heads
    100: num_tokens
    64: head_dim=(channels/num_heads) = 64/1 = 64
q, k, v = qkv[0], qkv[1], qkv[2]
print('See that the dimensions for queries, keys, and values are all the same:')
print('\tShape of Q:', q.shape, '\n\tShape of K:', k.shape, '\n\tShape of V:', v.shape)
print('Dimensions for Queries are \n\tbatchsize:', q.shape[0], '\n\tattention heads:', q.shape[1], '\n\tnumber of tokens:', q.shape[2], '\n\tnew length of tokens:', q.shape[3])
See that the dimensions for queries, keys, and values are all the same:
    Shape of Q: torch.Size([13, 1, 100, 64]) 
    Shape of K: torch.Size([13, 1, 100, 64]) 
    Shape of V: torch.Size([13, 1, 100, 64])
Dimensions for Queries are 
    batchsize: 13 
    attention heads: 1 
    number of tokens: 100 
    new length of tokens: 64
# Step 2: Calculate "Attention" values
# Step 2A: Scaling factor
print(f"Scaling factor: 1/sqrt(head_dim) = 1/{A.head_dim}**0.5 = {A.scale}")

# Step 2B: Scoring (scaled dot product of Q and K)
attn = (q * A.scale) @ k.transpose(-2, -1)
print('Dimensions for Attn are:', attn.shape, '\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])

Scaling factor: 1/sqrt(head_dim) = 1/64**0.5 = 0.125
Dimensions for Attn are: torch.Size([13, 1, 100, 100]) 
    batchsize: 13 
    attention heads: 1 
    number of tokens: 100 
    number of tokens: 100
# Step 2C: Normalize Attention with Softmax
attn = attn.softmax(dim=-1)
print('Dimensions for Attn after softmax are:', attn.shape, '\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])

# Step 2D: Multiply the attention weights by V
x = attn @ v
print('Dimensions for x = attn @ v are:', x.shape, '\n\tbatchsize:', x.shape[0], '\n\tattention heads:', x.shape[1], '\n\tnumber of tokens:', x.shape[2], '\n\tlength of tokens:', x.shape[3])

Dimensions for Attn after softmax are: torch.Size([13, 1, 100, 100]) 
    batchsize: 13 
    attention heads: 1 
    number of tokens: 100 
    number of tokens: 100
Dimensions for x = attn @ v are: torch.Size([13, 1, 100, 64]) 
    batchsize: 13 
    attention heads: 1 
    number of tokens: 100 
    length of tokens: 64
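
Since softmax normalizes each row of the attention matrix, the weights over the keys for every query token sum to 1. A quick sanity check, reusing the attn tensor from the cell above:

print(torch.allclose(attn.sum(dim=-1), torch.ones(B, A.num_heads, N)))  # True
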
# Step 3: Concatenate heads (flatten the head dimension back into the channel dimension)
x = x.transpose(1, 2).reshape(B, N, A.chan)
print('Dimensions for x are', x.shape, '\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])
Dimensions for x are torch.Size([13, 100, 64]) 
    batchsize: 13 
    number of tokens: 100 
    length of tokens: 64
# Step 4: Skip Connection
# (Note: in the full Attention module above, x also passes through the proj layer before this step;
#  that projection is shown in the multi-headed walkthrough below.)
orig_shape = (batch, num_tokens, token_len)
print('Original shape:', orig_shape)

curr_shape = x.shape
print('Current shape:', curr_shape)

print('old v shape:', v.shape)
v = v.transpose(1, 2).reshape(B, N, A.chan)
print('new v shape:', v.shape)

x = x + v
print('After skip connection:', x.shape)
Original shape: (13, 100, 49)
Current shape: torch.Size([13, 100, 64])
old v shape: torch.Size([13, 1, 100, 64])
new v shape: torch.Size([13, 100, 64])
After skip connection: torch.Size([13, 100, 64])

Notice: we flattened V along its head dimension, i.e. the 1 in [13, 1, 100, 64], to get [13, 100, 64].

Multi-headed Attention

  • a.k.a. MSA, i.e. Multi-headed Self-Attention
token_len = 49 # 7x7
channels = 64
num_tokens = 100
batch = 13
num_heads = 4

x = torch.rand(batch, num_tokens, token_len)
B, N, C = x.shape
print('Input dimensions are:', x.shape, '\n\t batchsize: ', x.shape[0], '\n\t num_tokens: ', x.shape[1], '\n\t token_len: ', x.shape[2])


# Attention Module
MSA = Attention(dim=token_len, chan=channels, num_heads=num_heads, qkv_bias=False, qk_scale=None)
MSA.eval()
Input dimensions are: torch.Size([13, 100, 49]) 
     batchsize:  13 
     num_tokens:  100 
     token_len:  49
Attention(
  (qkv): Linear(in_features=49, out_features=192, bias=False)
  (proj): Linear(in_features=64, out_features=64, bias=True)
)

In MSA, the total size of the Q, K, and V matrices has not changed; their contents are just distributed across the head dimension.

Think of this as segmenting the single-headed matrices across the multiple heads, as in the sketch below:
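
Here is a tiny, self-contained sketch of that segmentation (the numbers are illustrative, not the ones used in this post): a token of length 8 split across 2 heads keeps exactly the same values, just viewed as 2 chunks of 4.

import torch

token = torch.arange(8)          # one token with chan = 8
heads = token.reshape(2, 4)      # num_heads = 2, head_dim = 4
print(heads)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])   <- same values, just segmented per head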

# Step 1: Calculate Q, K, V
qkv = MSA.qkv(x)
print(qkv.shape)
print('\t192 is basically 3x4x16, 3 is because of q, k, v, 4 is because of num_heads, 16 is because of head_dim')
print('')

# Reshape qkv to get q, k, v
qkv = qkv.reshape(B, N, 3, MSA.num_heads, MSA.head_dim).permute(2, 0, 3, 1, 4)
print(qkv.shape)
print('\t3: for q, k, v\n\t13: batchsize\n\t4: num_heads\n\t100: num_tokens\n\t16: head_dim=(channels/num_heads) = 64/4 = 16')
torch.Size([13, 100, 192])
    192 is basically 3x4x16, 3 is because of q, k, v, 4 is because of num_heads, 16 is because of head_dim

torch.Size([3, 13, 4, 100, 16])
    3: for q, k, v
    13: batchsize
    4: num_heads
    100: num_tokens
    16: head_dim=(channels/num_heads) = 64/4 = 16
q, k, v = qkv[0], qkv[1], qkv[2]

print('See that the dimensions for queries, keys, and values are all the same:')
print('\tShape of Q:', q.shape, '\n\tShape of K:', k.shape, '\n\tShape of V:', v.shape)
print('Dimensions for Queries are \n\tbatchsize:', q.shape[0], '\n\tattention heads:', q.shape[1], '\n\tnumber of tokens:', q.shape[2], '\n\tnew length of tokens:', q.shape[3])
See that the dimensions for queries, keys, and values are all the same:
    Shape of Q: torch.Size([13, 4, 100, 16]) 
    Shape of K: torch.Size([13, 4, 100, 16]) 
    Shape of V: torch.Size([13, 4, 100, 16])
Dimensions for Queries are 
    batchsize: 13 
    attention heads: 4 
    number of tokens: 100 
    new length of tokens: 16

The next step is to calculate attention for each head i as shown below. For every head i:

$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i$$

where d_k is the per-head token length, head_dim = chan / num_heads (here 64 / 4 = 16).
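
For intuition, the batched matrix multiplications in the next cells compute this for all heads at once. An equivalent (but slower) per-head loop, reusing the q, k, v tensors and the MSA module defined above, would look like this:

# Compute attention one head at a time and stack the results along a head dimension
outs = []
for i in range(MSA.num_heads):
    qi, ki, vi = q[:, i], k[:, i], v[:, i]              # (batch, num_tokens, head_dim)
    scores = (qi * MSA.scale) @ ki.transpose(-2, -1)    # (batch, num_tokens, num_tokens)
    outs.append(scores.softmax(dim=-1) @ vi)            # (batch, num_tokens, head_dim)
per_head = torch.stack(outs, dim=1)                     # (batch, num_heads, num_tokens, head_dim)
print(per_head.shape)                                   # torch.Size([13, 4, 100, 16])
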
# Step 2: Calculate Attention for each head

attn = (q * MSA.scale) @ k.transpose(-2, -1)
print('Dimensions for Attn are:', attn.shape, '\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])

# Step 2C: Normalize Attention with Softmax
attn = attn.softmax(dim=-1)
print('Dimensions for Attn after softmax are:', attn.shape, '\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])
Dimensions for Attn are: torch.Size([13, 4, 100, 100]) 
    batchsize: 13 
    attention heads: 4 
    number of tokens: 100 
    number of tokens: 100
Dimensions for Attn after softmax are: torch.Size([13, 4, 100, 100]) 
    batchsize: 13 
    attention heads: 4 
    number of tokens: 100 
    number of tokens: 100
# Step 2D: Multiply the attention weights by V
x = attn @ v
print('Dimensions for x = attn @ v are:', x.shape, '\n\tbatchsize:', x.shape[0], '\n\tattention heads:', x.shape[1], '\n\tnumber of tokens:', x.shape[2], '\n\tlength of tokens:', x.shape[3])
Dimensions for x = attn @ v are: torch.Size([13, 4, 100, 16]) 
    batchsize: 13 
    attention heads: 4 
    number of tokens: 100 
    length of tokens: 16

Now, we concatenate all the $x_i$'s (one per head):

x = x.transpose(1, 2).reshape(B, N, MSA.chan)
print('Dimensions for x are:', x.shape, '\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])
Dimensions for x are: torch.Size([13, 100, 64]) 
    batchsize: 13 
    number of tokens: 100 
    length of tokens: 64

Now we have something similar to single-headed attention; the rest of the module remains the same.

For the skip connection, we still use V, but we have to reshape it to flatten the head dimension.

x = MSA.proj(x)
print('Dimensions for x after projection are:', x.shape, '\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])

orig_shape = (batch, num_tokens, token_len)
print('Original shape:', orig_shape)
curr_shape = x.shape
print('Current shape:', curr_shape)

# Skip Connection
# flatten v
print('old v shape:', v.shape)
v = v.transpose(1, 2).reshape(B, N, MSA.chan)
print('new v shape:', v.shape)

# skip connection
x = v + x
print('new x shape:', x.shape)
Dimensions for x after projection are: torch.Size([13, 100, 64]) 
    batchsize: 13 
    number of tokens: 100 
    length of tokens: 64
Original shape: (13, 100, 49)
Current shape: torch.Size([13, 100, 64])
old v shape: torch.Size([13, 4, 100, 16])
new v shape: torch.Size([13, 100, 64])
new x shape: torch.Size([13, 100, 64])
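
Putting it all together: calling the module directly runs all of these steps (QKV projection, per-head attention, concatenation, projection, and the V skip connection) in one forward pass. A quick check on a fresh random input:

with torch.no_grad():
    out = MSA(torch.rand(batch, num_tokens, token_len))
print(out.shape)   # torch.Size([13, 100, 64]) -> (batch, num_tokens, chan)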

Final Notes:

The learnable weights in an attention layer are found only in the first projection from tokens to queries, keys, and values (qkv) and in the final projection (proj). The rest of the attention layer is deterministic matrix multiplication.
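
To see this concretely, we can list the parameters of the MSA module defined above; only the qkv and proj linear layers carry learnable weights:

for name, p in MSA.named_parameters():
    print(name, tuple(p.shape), p.numel())
# qkv.weight (192, 49) 9408  -> tokens to Q, K, V (no bias, since qkv_bias=False)
# proj.weight (64, 64) 4096  -> final projection
# proj.bias (64,) 64
# The scaling, matmuls, softmax, and reshapes in between have no learnable parameters.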