GPT-2 from scratch

Alluri Jairam
9 min read · Aug 18, 2024

Ever wondered how GPT-2, one of the most talked-about language models, works under the hood? Imagine being able to build it from scratch and understand every detail that goes into a model capable of generating coherent text.

The Journey Ahead

In this blog, I’ll walk through my journey of reconstructing GPT-2 from the ground up. Whether you’re a deep learning enthusiast or a developer looking to deepen your understanding of transformers, this guide will break down the core components of GPT-2, step by step.

But why build a GPT-2 model from scratch when there are pretrained models readily available?

Building something from scratch is one of the most effective ways to understand it. It fills in the gaps you might have missed while reading. As Feynman put it, “What I cannot create, I do not understand.”

Prerequisites

To follow this blog, you should be familiar with Python programming, PyTorch, the basics of NLP, and matrix multiplication.

In this blog, we will create a decoder-only model inspired by the paper “Attention Is All You Need.” This decoder-only version, composed of multiple decoder blocks, serves as the foundation for GPT-2.

Model Architecture

Let’s build each module of the model from scratch, following the architecture outlined in the image.

Fig 1 : Model Architecture.

Source — https://arxiv.org/pdf/1706.03762

Input Data

The input data for our model consists of a character-level representation of the English language, derived from Shakespeare’s works. We will use a character-based tokenizer that has been specifically created for this dataset.

The dataset contains 65 unique characters, and our tokenizer assigns a unique integer ID (token) to each one.

The 65 unique characters are the 63 shown below plus the newline and space characters, which are not visible here:

“!$&’,-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz”

Fig 2 : Input Data
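A character-level tokenizer like the one described above can be sketched in a few lines. This is a minimal illustration, not the original code: the names `stoi`, `itos`, `encode`, and `decode` are assumptions, and the short sample string stands in for the full Shakespeare text.

```python
# Illustrative sample; the real vocabulary is built from the full dataset.
text = "First Citizen:\nBefore we proceed any further, hear me speak."

# Vocabulary: every distinct character, sorted so that IDs are reproducible.
chars = sorted(set(text))
vocab_size = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer ID
itos = {i: ch for ch, i in stoi.items()}      # integer ID -> character


def encode(s):
    """Map a string to a list of integer token IDs."""
    return [stoi[c] for c in s]


def decode(ids):
    """Map a list of integer token IDs back to a string."""
    return "".join(itos[i] for i in ids)


ids = encode("hear me")
print(ids)
print(decode(ids))
```

Run on the full dataset, the same construction should give a `vocab_size` of 65. Encoding followed by decoding returns the original string, which is a quick way to check the two mappings agree.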

To better understand the solution, let’s use a running example. The following terms will become clearer as you progress through the blog.

Batch size: 64 (number of sequences processed in parallel)

Sequence length: 512 (number of characters passed into the model at one time)

Embedding dimension: 100 (dimension of each embedding vector)

Key, query, and value dimension: 128

Output Embeddings:

We use character-level embeddings for the model. An embedding is an n-dimensional vector that represents a character. With 65 unique characters and an embedding dimension of 100, our embedding matrix has a shape of [65,100].

I have built the embeddings using PyTorch’s built-in nn.Embedding class. The embedding parameters are learned during the training process.

Input to the embedding layer = [batch size, sequence length] = [64,512]

Output of the embedding layer = [batch size, sequence length, embedding dim] = [64,512,100]
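To make the shapes concrete, here is a small sketch that pushes a batch of random token IDs through `nn.Embedding`. The random IDs are a stand-in for a real encoded batch, and the variable names are illustrative.

```python
import torch
import torch.nn as nn

batch_size, seq_len = 64, 512   # values from the running example
vocab_size, embed_dim = 65, 100

embed = nn.Embedding(vocab_size, embed_dim)

# Random token IDs stand in for a real batch of encoded text.
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
out = embed(tokens)

print(embed.weight.shape)  # the [65, 100] embedding matrix
print(out.shape)           # one 100-dimensional vector per token

# The layer is a lookup table: each output vector is a row of the weight matrix.
same = torch.equal(out[0, 0], embed.weight[tokens[0, 0]])
print(same)
```

The last check shows that the layer performs no arithmetic on its input. It only selects the row of the weight matrix that belongs to each token ID.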

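import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed definition: the model code later in this post uses `device`
# without showing where it is set.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class vanilla_embedding(nn.Module):
    """
    Embedding layer to map input tokens to dense vectors.

    Attributes:
        embed_layer (nn.Embedding): Embedding layer for input tokens.
    """
    def __init__(self, vocabulary_size, embedding_size):
        super().__init__()
        self.embed_layer = nn.Embedding(vocabulary_size, embedding_size)

    def forward(self, input):
        """
        Forward pass through the embedding layer.

        Args:
            input (torch.Tensor): Input token indices.

        Returns:
            torch.Tensor: Embedded input.
        """
        return self.embed_layer(input)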

Positional Encodings:

Unlike RNNs, transformers are not recurrent and process all the data in parallel. Because of this, they lack an inherent sense of position within the sequence. To address this, we add positional encodings to the embeddings, allowing the model to incorporate information about the position of each token within the sequence.

We use sine and cosine functions with different frequencies to generate positional encodings:

• PE(pos, 2i) = \sin(pos / 10000^{2i/d_{model}})

• PE(pos, 2i+1) = \cos(pos / 10000^{2i/d_{model}})

Here, pos is the position in the sequence, d_{model} is the embedding dimension, and i indexes the sine/cosine pairs (0 ≤ i < d_{model}/2). Each pair of dimensions corresponds to a sinusoid, and the wavelengths form a geometric progression from 2π to 10000·2π. This design is intended to help the model attend by relative position, because for any fixed offset k, PE(pos + k) can be expressed as a linear function of PE(pos).

Input to the positional encoding step = [batch size, sequence length, embedding dim] = [64,512,100]

Output of the positional encoding step = [batch size, sequence length, embedding dim] = [64,512,100] (the [512,100] encoding matrix is broadcast across the batch and added to the embeddings)

class vanilla_pos_encoder(nn.Module):
    """
    Positional encoder for adding position information to input embeddings.

    Attributes:
        block_size (int): Maximum sequence length.
        embed_size (int): Dimension of the input embeddings.
    """
    def __init__(self, block_size, embed_size):
        super().__init__()
        self.block_size = int(block_size)
        self.embed_size = int(embed_size)

    def forward(self):
        """
        Computes positional encodings for input sequences.

        Returns:
            torch.Tensor: Positional encodings of shape (block_size, embed_size).
        """
        pos_enc = torch.zeros((self.block_size, self.embed_size))
        for pos in range(self.block_size):
            # each i fills one sine/cosine pair: columns 2i and 2i+1
            for i in range(self.embed_size // 2):
                pos_enc[pos, 2*i] = np.sin(pos / (10000 ** (2*i / self.embed_size)))
                pos_enc[pos, 2*i + 1] = np.cos(pos / (10000 ** (2*i / self.embed_size)))

        return pos_enc
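The nested Python loop above runs on every forward pass. As a sketch of an alternative, the same table can be computed with tensor operations. The function name `sinusoidal_encoding` is hypothetical, and the code assumes an even embedding dimension.

```python
import math
import torch


def sinusoidal_encoding(block_size, embed_size):
    """Return a (block_size, embed_size) tensor of sinusoidal encodings.

    Assumes embed_size is even.
    """
    pos = torch.arange(block_size, dtype=torch.float32).unsqueeze(1)  # (T, 1)
    two_i = torch.arange(0, embed_size, 2, dtype=torch.float32)       # 0, 2, 4, ...
    angles = pos / (10000.0 ** (two_i / embed_size))                  # (T, E/2)

    pe = torch.zeros(block_size, embed_size)
    pe[:, 0::2] = torch.sin(angles)  # even dimensions
    pe[:, 1::2] = torch.cos(angles)  # odd dimensions
    return pe


pe = sinusoidal_encoding(512, 100)
print(pe.shape)

# Spot-check one sine/cosine pair against the formula (pos = 3, i = 2).
expected_sin = math.sin(3 / 10000 ** (4 / 100))
expected_cos = math.cos(3 / 10000 ** (4 / 100))
print(pe[3, 4].item(), expected_sin)
print(pe[3, 5].item(), expected_cos)
```

Because the table depends only on the sequence length and embedding dimension, one option is to compute it once in `__init__` and store it with `register_buffer`, so it moves to the right device along with the module.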

Masked Multi-Head Attention:

Attention is a fundamental component of GPT-2 and other modern language models.

The three main components of Attention are Queries, Keys, and Values:

Queries: These represent the ‘questions’ or criteria that each word uses to seek relevant information.

Keys: These represent the ‘answers’ or potential pieces of information that each word might provide in response to the queries.

Values: These represent the actual content or data associated with each word.

Fig 3: (Left) Single Self-Attention Block (Scaled Dot-Product Attention). (Right) Multi-Headed Attention.

Source — https://arxiv.org/pdf/1706.03762

Fig 4: Scaled Dot-Product Attention

Source — https://arxiv.org/pdf/1706.03762

Step 1:

We pass our data through the query, key, and value linear layers. The weights of these layers are learned during training. Each layer maps [64,512,100] to [64,512,128].

Step 2:

We multiply the Query matrix by the transpose of the Keys matrix to determine the relevance of each Key (answer) to a given Query (question). This step helps us identify which words are most important in relation to each other. The plot below illustrates the importance of each word with respect to a specific word.

[64,512,128] @ [64,128,512] = [64,512,512]

Step 3:

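We divide the [512,512] score matrix by the square root of the key dimension, \sqrt{128}. If the query and key components have roughly zero mean and unit variance, each dot product has a variance of about 128, so this division brings the variance back to about 1. Without it, large scores would push the softmax into regions with very small gradients and slow down convergence.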
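The effect of the scaling factor can be checked numerically. The sketch below is an illustration with random data, not part of the original model code: it draws query and key vectors with unit-variance components and compares the variance of their dot products before and after dividing by the square root of the key dimension.

```python
import torch

torch.manual_seed(0)
d_k = 128                     # key/query dimension from the running example
q = torch.randn(10000, d_k)   # components with mean 0 and variance 1
k = torch.randn(10000, d_k)

raw = (q * k).sum(dim=-1)     # 10,000 independent dot products
scaled = raw / d_k ** 0.5

print(raw.var().item())       # close to d_k = 128
print(scaled.var().item())    # close to 1
```

The unscaled variance grows linearly with the key dimension, which is why the correction is a square root: dividing a quantity by c divides its variance by c squared.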

Step 4:

We mask the upper triangle of the matrix (the entries above the diagonal) by setting those values to negative infinity. This ensures that future tokens are not visible to the self-attention block: each position can attend only to itself and to tokens on its left, never to tokens on its right.

Step 5:

We apply softmax to each row of the matrix to convert the arbitrary scores into probabilities that sum to 1. The negative infinity values used for masking are turned into 0s by the softmax.
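Steps 4 and 5 are easiest to see on a tiny example. The sketch below uses a 4×4 score matrix of zeros (an arbitrary choice so that the result is easy to read) and applies the same `triu`, `masked_fill`, and softmax sequence as the attention head.

```python
import torch
import torch.nn.functional as F

T = 4
scores = torch.zeros(T, T)  # equal scores make the result easy to read

# True above the diagonal marks the "future" positions to hide.
mask = torch.ones(T, T, dtype=torch.bool).triu(1)
masked = scores.masked_fill(mask, float("-inf"))

# Softmax over each row: masked entries become exactly 0.
weights = F.softmax(masked, dim=-1)
print(weights)
```

With equal scores, row t spreads its probability evenly over positions 0 to t: the first row is [1, 0, 0, 0], the second is [0.5, 0.5, 0, 0], and so on. Every row still sums to 1.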

Step 6:

We multiply the [512,512] probability matrix by the Values matrix to generate the final output for each character. This step aggregates the important information from the other characters and produces the most relevant representation of the current character.

Input to step 6 = [64,512,512]

Output of step 6 = [64,512,128]
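Step 6 is a weighted average of value vectors. As a small worked sketch with made-up numbers, take causal weights that are uniform over the visible positions and a values matrix with one row per position:

```python
import torch

T, d_v = 4, 3

# Causal attention weights for equal scores: row t averages positions 0..t.
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(dim=-1, keepdim=True)

# One value vector per position: [0,1,2], [3,4,5], [6,7,8], [9,10,11].
values = torch.arange(T * d_v, dtype=torch.float32).view(T, d_v)

out = weights @ values
print(out)
```

Position 0 can only see itself, so its output is its own value vector [0, 1, 2]. Position 1 averages the first two rows and gets [1.5, 2.5, 3.5]. In a trained model the weights are not uniform, so each position takes more from the characters it finds most relevant.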

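Step 7:

We pass the attention output through a linear layer that projects it back to the embedding dimension.

Input to step 7 = [64,512,128]

Output of step 7 = [64,512,100] (back to the original dimension, so that it can be passed into the next MHA)

With multiple heads, the code below applies this projection in vanilla_MHA after the heads’ outputs are concatenated, so its input width is the number of heads times 128.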

class vanilla_att_head(nn.Module):
    """
    Single attention head for self-attention mechanism.

    Attributes:
        Q (nn.Linear): Linear layer for query projection.
        K (nn.Linear): Linear layer for key projection.
        V (nn.Linear): Linear layer for value projection.
        att_lin (nn.Linear): Linear layer for final attention output.
    """
    def __init__(self, embed_size, head_dim):
        super().__init__()
        # stored so that forward() can compute the scaling factor
        self.head_dim = head_dim
        self.Q = nn.Linear(embed_size, head_dim)
        self.K = nn.Linear(embed_size, head_dim)
        self.V = nn.Linear(embed_size, head_dim)

        # not used in forward(); vanilla_MHA projects after concatenation
        self.att_lin = nn.Linear(head_dim, embed_size)

    def forward(self, embeddings):
        """
        Computes self-attention for input embeddings.

        Args:
            embeddings (torch.Tensor): Input embeddings.

        Returns:
            torch.Tensor: Output of the attention head.
        """
        qu = self.Q(embeddings)
        ke = self.K(embeddings)
        va = self.V(embeddings)

        # transposing key: [B, T, head_dim] -> [B, head_dim, T]
        ke = ke.transpose(-2, -1)

        # (q @ k^T) / sqrt(head_dim)
        query_mul_key = (qu @ ke) / np.sqrt(self.head_dim)

        # causal mask: True above the diagonal marks future positions
        mask = torch.ones_like(query_mul_key, dtype=torch.bool)
        mask = mask.triu(1)
        query_mul_key = query_mul_key.masked_fill(mask, float('-inf'))

        # softmax over the key dimension
        query_mul_key = nn.functional.softmax(query_mul_key, dim=-1)

        # qk @ v final step
        # this is single head attention
        att = query_mul_key @ va
        return att

These 7 steps outline the self-attention mechanism. In Multi-Head Self-Attention, we perform these steps in parallel across multiple attention heads and then combine the outputs before the final step. This approach allows the model to capture a wider range of relationships within complex data.

class vanilla_MHA(nn.Module):
    """
    Multi-head attention mechanism.

    Attributes:
        heads (nn.ModuleList): List of attention heads.
        linear (nn.Linear): Linear layer for final multi-head output.
    """
    def __init__(self, num_heads, embed_size, head_dim):
        super().__init__()
        self.heads = nn.ModuleList(
            [vanilla_att_head(embed_size, head_dim) for _ in range(num_heads)]
        )
        self.linear = nn.Linear(num_heads * head_dim, embed_size)

    def forward(self, embeddings):
        """
        Computes multi-head self-attention for input embeddings.

        Args:
            embeddings (torch.Tensor): Input embeddings.

        Returns:
            torch.Tensor: Output of the multi-head attention mechanism.
        """
        # run every head on the same input, then join along the feature axis
        mha = [he(embeddings) for he in self.heads]
        mha = torch.cat(mha, dim=-1)
        out = self.linear(mha)
        return out
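To trace the shapes through the concatenation and final projection, here is a small standalone sketch. Random tensors stand in for the per-head outputs. The article does not state how many heads are used, so `num_heads = 4` is an assumption, as are the small batch and sequence sizes.

```python
import torch
import torch.nn as nn

batch, seq_len, embed_dim = 2, 8, 100  # small batch/sequence to keep it fast
num_heads, head_dim = 4, 128           # num_heads = 4 is an assumed value

# Stand-ins for the per-head attention outputs, each [batch, seq_len, head_dim].
head_outputs = [torch.randn(batch, seq_len, head_dim) for _ in range(num_heads)]

# Concatenate along the feature axis: [2, 8, 4 * 128].
concat = torch.cat(head_outputs, dim=-1)

# Project back to the embedding size so blocks can be stacked.
proj = nn.Linear(num_heads * head_dim, embed_dim)
out = proj(concat)

print(concat.shape)
print(out.shape)
```

The projection is what lets the output of one block feed the next: whatever the number of heads, the result has the same width as the input embeddings.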

Transformer models have become widely popular because they can be trained in parallel without relying on recurrent dependencies like RNNs. This capability has enabled transformers to scale up to billions of parameters efficiently.

Add and Normalisation:

The data before and after the Multi-Head Attention (MHA) are added together and then normalized. This step introduces ResNet-style skip connections to the network, helping to improve information flow and model stability.

Input to the add and norm = [64,512,100]

Output of the add and norm = [64,512,100]

Fig 5: Add and Norm Layer

Source — https://arxiv.org/pdf/1706.03762
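As a minimal sketch of the Add & Norm step, the snippet below adds a stand-in sublayer output to its input and applies `nn.LayerNorm`. The random tensors and small sizes are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim = 100
x = torch.randn(2, 8, embed_dim)             # input to the sublayer
sublayer_out = torch.randn(2, 8, embed_dim)  # stand-in for the MHA output

norm = nn.LayerNorm(embed_dim)
y = norm(x + sublayer_out)                   # residual add, then normalise

# At initialisation LayerNorm has weight 1 and bias 0, so each token's
# features end up with mean close to 0 and variance close to 1.
token_means = y.mean(dim=-1)
token_vars = y.var(dim=-1, unbiased=False)
print(token_means.abs().max().item())
print(token_vars.mean().item())
```

Note that the normalisation runs over the 100 features of each token separately, not over the batch, so the result does not depend on which other sequences are in the batch.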

Feed Forward Layer:

We pass the output from the Add & Norm layer into the Feed Forward layer. In this layer, the input first goes through a linear layer, followed by a ReLU activation. The resulting output is then passed through another linear layer. Finally, the output is once again added to the original input and normalized with another Add & Norm layer.

Input to the FFD = [64,512,100]

Output of the FFD = [64,512,100]
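The transformer block in this post uses a `vanilla_FFD` module whose code is not shown. The following is one possible sketch that matches the description above (linear, ReLU, linear), not necessarily the original implementation. The hidden width of four times the embedding size is an assumption borrowed from the ratio used in “Attention Is All You Need”.

```python
import torch
import torch.nn as nn


class vanilla_FFD(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, embed_size, hidden_mult=4):
        super().__init__()
        # hidden_mult = 4 is an assumed value, not taken from the post.
        self.net = nn.Sequential(
            nn.Linear(embed_size, hidden_mult * embed_size),
            nn.ReLU(),
            nn.Linear(hidden_mult * embed_size, embed_size),
        )

    def forward(self, x):
        # Applied independently to every position in the sequence.
        return self.net(x)


ffd = vanilla_FFD(100)
out = ffd(torch.randn(2, 8, 100))
print(out.shape)
```

The output has the same shape as the input, which is what allows the residual addition that follows.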

class vanilla_transformer_block(nn.Module):
    """
    Single transformer block with multi-head attention and feed-forward network.

    Attributes:
        att1 (vanilla_MHA): First multi-head attention layer.
        lnorm1 (nn.LayerNorm): Layer normalization after first attention layer.
        att2 (vanilla_MHA): Second multi-head attention layer.
        lnorm2 (nn.LayerNorm): Layer normalization after second attention layer.
        ffd (vanilla_FFD): Feed-forward network.
        lnorm3 (nn.LayerNorm): Layer normalization after feed-forward network.
    """
    def __init__(self, num_heads, embed_size, head_dim, block_size):
        super().__init__()
        # self.pos_encoder = vanilla_pos_encoder(block_size,embed_size)
        self.att1 = vanilla_MHA(num_heads, embed_size, head_dim)
        self.lnorm1 = nn.LayerNorm(embed_size)

        self.att2 = vanilla_MHA(num_heads, embed_size, head_dim)
        self.lnorm2 = nn.LayerNorm(embed_size)

        self.ffd = vanilla_FFD(embed_size)
        self.lnorm3 = nn.LayerNorm(embed_size)

    def forward(self, embeddings):
        """
        Forward pass through the transformer block.

        Args:
            embeddings (torch.Tensor): Input embeddings.

        Returns:
            torch.Tensor: Output of the transformer block.
        """
        # first attention sublayer: add & norm
        e2mh1 = self.att1(embeddings)
        e2mh1 = e2mh1 + embeddings
        mh122 = self.lnorm1(e2mh1)

        # second attention sublayer: add & norm
        mh2 = self.att2(mh122)
        mh2 = mh122 + mh2
        mh22f = self.lnorm2(mh2)

        # feed-forward sublayer: add & norm
        # (lnorm3 was defined but never applied; the text describes this final norm)
        ff = self.ffd(mh22f)
        ff = ff + mh22f
        return self.lnorm3(ff)

Final output layer:

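The transformer block is repeated N times, and the output of the last block is passed through a linear layer that projects each position to the vocabulary size (65). A softmax over these logits then gives the probability of each possible next character. In the code below, the softmax is applied inside F.cross_entropy when the loss is computed.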

class vanilla_transformer(nn.Module):
    """
    Decoder-only transformer model for next-character prediction.

    Attributes:
        block_size (int): Maximum sequence length.
        embedder (vanilla_embedding): Embedding layer.
        pos_encoder (vanilla_pos_encoder): Positional encoding layer.
        blocks (nn.Sequential): Sequential transformer blocks.
        final_linear (nn.Linear): Final linear layer to project to vocabulary size.
    """
    def __init__(self, vocab_size, embed_size, trans_blocks, num_heads, head_dim, block_size):
        super().__init__()
        self.block_size = block_size
        self.embedder = vanilla_embedding(vocab_size, embed_size).to(device)
        self.pos_encoder = vanilla_pos_encoder(block_size, embed_size).to(device)
        self.blocks = nn.Sequential(
            *[vanilla_transformer_block(num_heads, embed_size, head_dim, block_size).to(device) for _ in range(trans_blocks)]
        )
        self.final_linear = nn.Linear(embed_size, vocab_size).to(device)

    def forward(self, x, targets=None):
        """
        Forward pass through the transformer model.

        Args:
            x (torch.Tensor): Input token indices.
            targets (torch.Tensor, optional): Target token indices for loss calculation.

        Returns:
            torch.Tensor: Logits of the transformer.
            torch.Tensor or None: Cross-entropy loss if targets are provided.
        """
        embeddings = self.embedder(x).to(device)
        # slice to the actual sequence length so shorter inputs also work
        pos = self.pos_encoder().to(device)[:x.shape[1]]
        embeddings = pos + embeddings
        outs = self.blocks(embeddings)

        logits = self.final_linear(outs)
        if targets is None:
            loss = None
        else:
            # flatten batch and time so each position is one classification example
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
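A useful sanity check for the loss, sketched here with made-up inputs rather than the model itself: if the predictions are uniform over the 65 characters, the cross-entropy should equal ln(65) ≈ 4.17. An untrained model usually starts somewhere near this value, so a much larger initial loss can point to a bug.

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 65
B, T = 4, 16  # small illustrative batch and sequence length

# All-zero logits correspond to a uniform distribution over the vocabulary.
logits = torch.zeros(B * T, vocab_size)
targets = torch.randint(0, vocab_size, (B * T,))

loss = F.cross_entropy(logits, targets)
print(loss.item(), math.log(vocab_size))  # both about 4.174

# Softmax turns one row of logits into next-character probabilities.
probs = F.softmax(logits[0], dim=-1)
print(probs.sum().item())
```

`F.cross_entropy` expects raw logits and applies the log-softmax internally, which is why the model returns logits and does not apply a softmax before computing the loss.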

Conclusion:

Building GPT-2 from scratch is a great way to really understand how transformer models work. By following each step, from creating embeddings to implementing multi-head self-attention and feed-forward layers, you’ll get hands-on experience with the core ideas that make these models powerful.

Transformers have changed the game in natural language processing because they can be trained in parallel and don’t rely on recurrent connections like RNNs. This design has allowed models like GPT-2 to handle large amounts of data and scale up to billions of parameters, making them very effective at generating coherent text.

By building GPT-2 yourself, you fill in any gaps in your knowledge and gain a better grasp of how modern language models work. This journey helps you truly understand the key concepts behind today’s most advanced models.
