Python imports
import torch
from torch import nn
import math
By now, there are many useful books, online guides, and videos for understanding the construction of Large Language Models from first principles (Raschka 2024) (Sanderson and Sun 2024) (Amidi and Amidi 2024). What separates this project is that it is designed as a mathematics-first presentation of the architecture of a GPT-styled LLM. If you are a mathematician, or are simply more comfortable with the language and structure of mathematical writing, this page will hopefully be an easy-to-digest presentation of the mathematical foundations of LLMs. You should be able to read and understand this material in an afternoon.
In addition to describing tokenizers, attention mechanisms, and transformer blocks in terms of mathematical formulas, I have included basic implementations of the constructions in PyTorch in order to give a quick start for understanding the bridge between the abstract mathematical formulations and implementations in code. Because our focus is primarily on the mathematical constructions, we leave out some elements of the computer engineering perspective that are important for implementation and optimization; namely, batching of data, differences between CPU and GPU device architectures, and parameter quantization.
Thanks go to my audience in the symplectic group at the Faculty of Mathematics at Ruhr-Universität Bochum.
We can superficially consider a (generative) \(\text{LLM}\) as a function which takes as input a fixed-length sequence of text and outputs a probability distribution for the next text unit. By sampling from this distribution and appending the sampled unit to the original text sequence, we can iteratively predict subsequent text units:
\[ \underline{\text{The Steenrod problem for}} \mapsto \begin{cases} \text{closed} & 29\% \\ \text{spaces} & 24\% \\ \text{modules} & 22\% \\ \text{spectra} & 19\% \end{cases} \]
\[ \implies \text{The }\underline{\text{Steenrod problem for } \fbox{closed}} \mapsto \begin{cases} \text{orientable} & 37\% \\ \text{smooth} & 23\% \\ \text{topological} & 19\% \\ \text{homology} & 9\% \end{cases} \]
\[ \implies \text{The Steenrod }\underline{\text{problem for closed } \fbox{orientable}} \mapsto \begin{cases} \text{manifolds} & 77\% \\ \text{orbifolds} & 14\% \\ \text{surfaces} & 4\% \\ \text{varieties} & <1\% \end{cases} \]
We begin by selecting a source text, which gives a fixed source for training our eventual LLM, in addition to providing a source for defining the vocabulary for our tokenizer. To keep things topical, we will use a collection of mathematics texts in symplectic geometry:
mathematics.tex
The Steenrod problem for closed orientable manifolds was solved completely by Thom.
Following this approach, we solve the Steenrod problem for closed orientable orbifolds, proving that the rational homology groups of a closed orientable orbifold have a basis consisting of classes represented by suborbifolds whose normal bundles have fiberwise trivial isotropy action. Polyfold theory, as developed by Hofer, Wysocki, and Zehnder, has yielded a well-defined Gromov–Witten invariant via the regularization of moduli spaces.
As an application, we demonstrate that the polyfold Gromov–Witten invariants, originally defined via branched integrals, may equivalently be defined as intersection numbers against a basis of representing suborbifolds.
Introduction
The Steenrod problem
The Steenrod problem was first presented in (Eilenberg 1949) and asked the following question:
Can any homology class of a finite polyhedron be represented as an image of the fundamental class of some manifold?
In (Thom 1954),…
A token is an arbitrarily defined unit of text.
Level | Description | Illustration |
---|---|---|
word | Each word is a token. | \(\fbox{The} \fbox{Steenrod} \fbox{problem} \fbox{for} \fbox{closed} \fbox{orbifolds}\)… |
subword | Each word is divided into a subword, e.g., BPE. | \(\fbox{The}\fbox{ Ste}\fbox{en}\fbox{rod}\fbox{ problem}\fbox{ for}\fbox{ closed}\fbox{ orb}\fbox{if}\fbox{olds}\)… |
character | Each character is a token. | \(\fbox{T}\fbox{h}\fbox{e}\fbox{}\fbox{S}\fbox{t}\fbox{e}\fbox{e}\fbox{n}\fbox{r}\fbox{o}\fbox{d}\fbox{}\fbox{p}\fbox{r}\fbox{o}\fbox{b}\fbox{l}\fbox{e}\fbox{m}\)… |
A vocabulary \(V\) is a set of tokens. We often augment a vocabulary with additional special tokens, e.g.,
Token | Description |
---|---|
\([\text{BOS}]\) | Marks the beginning of text. |
\([\text{EOS}]\) | Marks the end of text. |
\([\text{PAD}]\) | A padding token used to fill a sequence to the required context length. |
\([\text{UNK}]\) | Represents parts of the input text not present in the vocabulary. |
Let \(\text{TEXTS}\) be defined as the set of finite texts and let \(\text{Seq}(V)\) be defined as the set of finite sequences \((t_1, t_2,...,t_N)\) where \(t_i \in V.\) A tokenizer \(T\) is a map
\[\begin{align} T: \text{TEXTS} &\to \text{Seq}(V), \\ \text{text} &\mapsto \text{tokenized text}. \end{align}\]
In other words, the tokenizer \(T\) encodes a text into a sequence of tokens. The (partial) inverse, \(T^{-1}\), decodes a sequence of tokens into a text. Note that the presence of characters or subwords not present in the vocabulary, as well as the possibility of special tokens, complicates the question of whether \(T\) is injective/bijective/neither as well as the definition of the (partial) inverse \(T^{-1}\).
For example, a word-level tokenizer can be viewed as follows:
\[ T(\text{The Steenrod problem...}) = (\fbox{The},\fbox{Steenrod},\fbox{problem}...) \]
We next want to represent tokens numerically as vectors. We begin by fixing an index on the vocabulary, i.e., a bijection \(V \to \{1,...,|V|\}\), and thus we freely identify tokens and token indices.
Using this identification, a tokenizer defines an encoder of a finite text into a sequence of indices:
\[ T ( \text{The Steenrod problem...}) = (341, 12459, 948,...), \]
as well as a decoder from a sequence of indices back to a finite text.
class CharacterTokenizer:
def __init__(self):
# Maps token_str to token_id (e.g., {"a": 19, "b": 23})
self.vocab = {}
# Maps token_id to token_str (e.g., {19: "a", 23: "b"})
self.inverse_vocab = {}
# Included special tokens
self.specials = {"<|BOS|>", "<|EOS|>", "<|UNK|>", "<|PAD|>"}
def initialize_vocab(self, source_text, allowed_specials=set()):
"""Initializes the vocabulary consisting of characters appearing in source_text."""
characters = set(source_text)
self.vocab = {char: idx for idx, char in enumerate(characters | self.specials)}
self.inverse_vocab = {idx: char for char, idx in self.vocab.items()}
def encode(self, text):
processed = (
["<|BOS|>"]
+ [char if char in self.vocab else "<|UNK|>" for char in text]
+ ["<|EOS|>"]
)
ids = [self.vocab[token] for token in processed]
return ids
def decode(self, ids):
return "".join(
[
self.inverse_vocab[id]
for id in ids
if id
not in {
self.vocab["<|PAD|>"],
self.vocab["<|BOS|>"],
self.vocab["<|EOS|>"],
}
]
)
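A brief usage sketch (the source string below is an illustrative stand-in for the full source text; the exact ids depend on the set ordering used when building the vocabulary):

char_tokenizer = CharacterTokenizer()
char_tokenizer.initialize_vocab("The Steenrod problem for closed orientable manifolds")
ids = char_tokenizer.encode("Steenrod problem")
print(ids)                         # the ids vary from run to run with the set ordering
print(char_tokenizer.decode(ids))  # "Steenrod problem"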
class WordTokenizer:
def __init__(self):
# Maps token_str to token_id (e.g., {"symplectic": 125})
self.vocab = {}
# Maps token_id to token_str (e.g., {125: "symplectic"})
self.inverse_vocab = {}
# Included special tokens
self.specials = {"<|BOS|>", "<|EOS|>", "<|UNK|>", "<|PAD|>"}
def initialize_vocab(self, source_text, allowed_specials=set()):
"""Initializes the vocabulary consisting of words appearing in source_text."""
words = set(source_text.split(" "))
self.vocab = {word: idx for idx, word in enumerate(words | self.specials)}
self.inverse_vocab = {idx: word for word, idx in self.vocab.items()}
def encode(self, text):
words_in_text = text.split(" ")
processed = (
["<|BOS|>"]
+ [word if word in self.vocab else "<|UNK|>" for word in words_in_text]
+ ["<|EOS|>"]
)
ids = [self.vocab[token] for token in processed]
return ids
def decode(self, ids):
return " ".join(
[
self.inverse_vocab[id]
for id in ids
if id
not in {
self.vocab["<|PAD|>"],
self.vocab["<|BOS|>"],
self.vocab["<|EOS|>"],
}
]
)
The byte-pair encoding is an algorithm which takes a source text and builds a vocabulary iteratively by combining the most common pairs of neighboring tokens (the modern version of the ‘byte-pair’). A text is correspondingly tokenized and encoded by this vocabulary by merging pairs of tokens iteratively.
Explicitly, the vocabulary is constructed via the following algorithm: initialize the vocabulary with the single characters appearing in the source text (together with the special tokens); then repeatedly find the most frequent pair of neighboring token ids in the encoded source text, merge that pair into a new token, record the merge, and add the merged token to the vocabulary, stopping once the vocabulary reaches the desired size.
A text is then tokenized and encoded via the following algorithm: encode the text as a sequence of character-level token ids; then repeatedly apply the recorded merges to neighboring pairs, always applying the earliest-learned merge available, until no further merges apply.
class BPETokenizer:
def __init__(self):
# Maps token_str to token_id (e.g., {"symplectic": 11246})
self.vocab = {}
# Maps token_id to token_str (e.g., {11246: "symplectic"})
self.inverse_vocab = {}
# Included special tokens
self.specials = {"<|BOS|>", "<|EOS|>", "<|UNK|>", "<|PAD|>"}
# Dictionary of BPE merges: {(token_id1, token_id2): merged_token_id}
self.bpe_merges = {}
def initialize_vocab(self, source_text, vocab_size):
"""Initializes the vocabulary for BPE from source_text of size vocab_size."""
from collections import Counter
# Initialize vocab with unique characters appearing in the source_text
        # The first 256 Unicode code points (which include the ASCII characters)
initial_chars = {chr(i) for i in range(256)}
# Additional characters appearing in the source_text
source_chars = {char for char in set(source_text) if char not in initial_chars}
self.vocab = {
char: idx for idx, char in enumerate(initial_chars | source_chars)
}
self.inverse_vocab = {idx: char for char, idx in self.vocab.items()}
# Add special tokens
for special in self.specials:
new_id = len(self.vocab)
self.vocab[special] = new_id
self.inverse_vocab[new_id] = special
# Tokenize the source_text into token IDs
token_ids = [self.vocab[char] for char in source_text]
        # Iteratively find and replace with a new_id the most frequent pair of token ids appearing in token_ids.
        for new_id in range(len(self.vocab), vocab_size):
            # Identify the most frequent pair; stop if no neighboring pairs remain
            pairs = Counter(zip(token_ids, token_ids[1:]))
            if not pairs:
                break
            pair_ids = max(pairs.items(), key=lambda x: x[1])[0]
# Replace [pair_ids[0], pair_ids[1]] with new_id
token_ids_merge = []
for token_id in token_ids:
if (
token_ids_merge
and token_ids_merge[-1] == pair_ids[0]
and token_id == pair_ids[1]
):
token_ids_merge.pop()
token_ids_merge.append(new_id)
else:
token_ids_merge.append(token_id)
token_ids = token_ids_merge
# Record the pair_ids and the new_id
self.bpe_merges[pair_ids] = new_id
# Add the merged pair to the vocab
merged_token = (
self.inverse_vocab[pair_ids[0]] + self.inverse_vocab[pair_ids[1]]
)
self.vocab[merged_token] = new_id
self.inverse_vocab[new_id] = merged_token
def encode(self, text):
"""Encode the input text into a list of token ids."""
import re
token_ids = []
token_ids.append(self.vocab["<|BOS|>"])
chunks = re.split(r"(<\|(?:UNK|PAD)\|>)", text)
for chunk in chunks:
if chunk in self.specials:
token_ids.append(self.vocab[chunk])
continue
tokens = [
self.vocab[char] if char in self.vocab else self.vocab["<|UNK|>"]
for char in chunk
]
# Apply BPE merges
while True:
pairs = [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
# Find pairs that are in our merge list
merge_candidates = [pair for pair in pairs if pair in self.bpe_merges]
if not merge_candidates:
break # No more merges to apply
# Find the merge with the lowest index (earliest learned)
best = min(merge_candidates, key=lambda p: self.bpe_merges[p])
new_id = self.bpe_merges[best]
# Apply the merge
i = 0
new_tokens = []
while i < len(tokens):
if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best:
new_tokens.append(new_id)
i += 2
else:
new_tokens.append(tokens[i])
i += 1
tokens = new_tokens
token_ids.extend(tokens)
token_ids.append(self.vocab["<|EOS|>"])
return token_ids
def decode(self, ids):
return "".join(
[
self.inverse_vocab[id]
for id in ids
if id
not in {
self.vocab["<|PAD|>"],
self.vocab["<|BOS|>"],
self.vocab["<|EOS|>"],
}
]
)
A token embedding is an abstract representation of a token as a vector in some \(\mathbb{R}^d\).
The One Hot mapping is a bijection between the token indices of a vocabulary \(V\) and the standard basis elements \(\{e_i\}\) of the space \(\mathbb{R}^{|V|}\), i.e.,
\[ \begin{align} \text{OH}: V &\to \mathbb{R}^{|V|} \\ t_i &\mapsto e_i. \end{align} \]
We may compose this bijection with a map \(A :\mathbb{R}^{|V|} \to \mathbb{R}^{d_\text{emb}}\) to obtain continuous vector representations of the tokens in the vocabulary:
\[ \text{Embedding} := A \circ \text{OH}: V \to \mathbb{R}^{d_\text{emb}}. \]
This composition is the basis of the Word2vec family of embeddings.
The map \(A\) is often a linear map, or, if we wished, a neural network itself; in the process of training the GPT, we can adjust the weights to learn better token embeddings.
tokenizer = BPETokenizer()
tokenizer.initialize_vocab(text, 300)
sentence = "A symplectic manifold"
encoded_sentence = tokenizer.encode(sentence)
X = torch.tensor(tokenizer.encode(sentence))
# In PyTorch, we can implement a trainable embedding layer as `torch.nn.Embedding`, which acts as a lookup table mapping token indices to dense vectors.
emb = torch.nn.Embedding(300, 5)
# print(emb(X))
# print(tokenizer.decode(X.tolist()))
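To connect the lookup table with the composition \(A \circ \text{OH}\) above, one can check (an illustrative verification, assuming all token ids lie below the embedding size, as in the snippet above) that the embedding lookup agrees with multiplying the one-hot vectors by the weight matrix of the layer:

import torch.nn.functional as F

# The layer emb corresponds to choosing A to be the linear map given by emb.weight
one_hot = F.one_hot(X, num_classes=emb.num_embeddings).float()  # (n, |V|)
print(torch.allclose(one_hot @ emb.weight, emb(X)))             # True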
Fix an input sequence length \(N\). Given an input sequence of tokens, we may apply the embedding map to get an input sequence of embedded tokens:
\[ \text{Embedding} : V^N \to \mathbb{R}^{N\times d_\text{emb}}. \]
As defined, the token embeddings carry no positional information about sequence order: two identical tokens at different positions in the input sequence have the same token embedding.
We account for this by perturbing an input sequence by means of a positional encoding matrix \(\text{PE} \in \mathbb{R}^{N\times d_\text{emb}}\) whose entries \(p_{ij}\) are given by:
\[ p_{i,2j} := \sin \left( \frac{i}{10000^\tfrac{2j}{d_\text{emb}}}\right), \qquad p_{i,2j + 1} := \cos \left( \frac{i}{10000^\tfrac{2j}{d_\text{emb}}}\right). \]
Given an input sequence \(X\in \mathbb{R}^{N\times d_\text{emb}}\), we perturb by adding the positional encoding matrix, i.e., \(X + \text{PE}\).
class PositionalEncoding(nn.Module):
    def __init__(self, context_length, d_model):
        super().__init__()
        self.PE = torch.zeros(context_length, d_model)
        # Columns come in pairs (2j, 2j+1): sine in the even column, cosine in the odd column
        for i in range(context_length):
            for j in range(0, d_model, 2):
                angle = i / math.pow(10000, j / d_model)
                self.PE[i][j] = math.sin(angle)
                if j + 1 < d_model:
                    self.PE[i][j + 1] = math.cos(angle)
    def forward(self, X):
        # X.shape = (N, d_model) with N <= context_length; add the encodings of the first N positions
        X = X + self.PE[: X.shape[0]]
        return X
At the outset, we fix dimensions \(d_\text{in}\), \(d_k\), and \(d_\text{out}\), as well as an input sequence length \(N\). Let \((x_1,...,x_N),\) \(x_i \in \mathbb{R}^{d_\text{in}}\), be an input sequence of embedded vectors, which we may consider as an input matrix \(X \in \mathbb{R}^{N\times d_\text{in}}\). We implement causally masked self-attention, which is the standard for the GPT architecture, and is defined as a map:
\[ \text{attention}(Q,K,V) : \mathbb{R}^{N\times d_\text{in}} \to \mathbb{R}^{N\times d_\text{out}} \]
given by the equation:
\[ \text{attention}(Q,K,V) := \text{softmax} \left( \frac{QK^T}{\sqrt{d_k} } + M \right) V. \]
In this section, we describe precisely how this is defined.
The Query, Key, and Value matrices are obtained by applying linear transformations to the input matrix \(X\in \mathbb{R}^{N\times d_\text{in}}\). Let \(W_Q, W_K \in \text{Mat}_{d_\text{in} \times d_k}(\mathbb{R}), W_V \in \text{Mat}_{d_\text{in} \times d_\text{out}}(\mathbb{R})\) be learnable weight matrices. Then we may define:
\[ Q := XW_Q \in \mathbb{R}^{N\times d_k}, \qquad K := XW_K \in \mathbb{R}^{N\times d_k}, \qquad V := XW_V \in \mathbb{R}^{N\times d_\text{out}}. \]
The multiplicative factor \(\frac{1}{\sqrt{d_k}}\) is justified as follows. If we assume the entries of \(Q\) and of \(K\) are independent random variables with mean \(0\) and variance \(1\), then the dot product of the rows, \(q_i\cdot k_j\), will have mean \(0\) and variance \(d_k\). We therefore introduce the factor \(\frac{1}{\sqrt{d_k}}\), as it follows that every entry of the rescaled matrix \(\frac{QK^T}{\sqrt{d_k}}\) will have mean \(0\) and variance \(1\).
The causal mask \(M \in \text{Mat}_{N \times N}(\mathbb{R})\) is a strictly upper triangular matrix, with \(0\)s on the diagonal and below, and with \(-\infty\) for every entry above the diagonal. Adding this mask to the attention scores and applying the softmax results in a score of zero for any future tokens relative to a fixed token. As a result, any future tokens have no causal influence on prior tokens, as measured by the attention scores.
The softmax function takes as input a tuple of \(K>1\) real numbers and normalizes the tuple into a probability distribution. Formally, it is the function
\[ \sigma : \mathbb{R}^K \to (0,1)^K; \qquad \sigma(v_1,...,v_K)_i := \frac{e^{v_i}}{\sum_{j=1}^K e^{v_j}}. \]
In our implementation of attention, the softmax applied to a matrix is defined by applying the above softmax function independently to every row.
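As a quick check of this row-wise convention (the matrix entries below are arbitrary):

A = torch.tensor([[1.0, 2.0, 3.0],
                  [0.0, 0.0, 0.0]])
# dim=-1 applies the softmax independently to each row
W = torch.softmax(A, dim=-1)
print(W.sum(dim=-1))  # tensor([1., 1.]): every row is a probability distribution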
The attention weight matrix is the matrix \(W_\text{attn}:=\text{softmax}\left(\tfrac{QK^T}{\sqrt{d_k}} + M \right) \in \text{Mat}_{N\times N}(\mathbb{R})\). Each row \(i\) of \(W_\text{attn}\) is a probability distribution representing the ‘attention’ that the input vector \(i\) pays to all other (preceding) input vectors.
In a simple case where the input vectors are word-level tokens, we might have an attention weight matrix as follows:
| | The | Steenrod | problem | for |
|---|---|---|---|---|
| The | 1 | 0 | 0 | 0 |
| Steenrod | 0.34 | 0.66 | 0 | 0 |
| problem | 0.12 | 0.45 | 0.43 | 0 |
| for | 0.08 | 0.23 | 0.44 | 0.24 |
The output of the attention mechanism is the matrix \(Z := W_\text{attn} V\). We consider each row of the value matrix \(V\) as representing a value vector. Then the \(i\)th row of \(Z\) is the weighted sum of the value vectors in \(V\), where the weights are given by the probability distribution of the \(i\)th row of \(W_\text{attn}\). Each row of \(Z\) represents a context vector, with information enriched by its neighboring vectors.
class CausallyMaskedAttention(nn.Module):
def __init__(self, d_in, d_k, d_out, context_length, qkv_bias=False):
super().__init__()
self.W_query = nn.Linear(d_in, d_k, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_k, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
# Register mask as a buffer to avoid it being considered a model parameter
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
self.register_buffer('mask', mask.bool())
def forward(self, x):
n, d_in = x.shape
queries = self.W_query(x) # (n, d_k)
keys = self.W_key(x) # (n, d_k)
values = self.W_value(x) # (n, d_out)
attn_scores = queries @ keys.T # (n, n)
# Apply the causal mask
attn_scores.masked_fill_(self.mask.bool()[:n, :n], -torch.inf)
# Scale and apply softmax
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1) # (n, n)
# Multiply by the values matrix
context_vectors = attn_weights @ values # (n, d_out)
return context_vectors
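A quick usage sketch (the dimensions below are illustrative assumptions):

torch.manual_seed(0)
attn = CausallyMaskedAttention(d_in=5, d_k=4, d_out=5, context_length=8)
x = torch.randn(8, 5)  # a sequence of N = 8 embedded vectors
z = attn(x)
print(z.shape)         # torch.Size([8, 5]): one context vector per input vector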
A transformer block takes as input a matrix (thought of as a sequence of vectors) and outputs a matrix of the same dimension. In this way, transformer blocks may be stacked.
\[ \text{Transformer}: \mathbb{R}^{N\times d} \to \mathbb{R}^{N\times d}, \qquad (v_1,...,v_N) \mapsto (v_1',...,v_N'). \]
Internally, a transformer block may consist of various attention blocks, feed-forward blocks, and connections. We will describe a classic GPT transformer block as shown.
flowchart TD
  A[Embedded vectors] -->|Layer normalization| B[Normed vectors]
  B -->|Causally masked attention| C[Context vectors]
  A -->|Residual connection: $$\oplus$$| C
  C -->|Layer normalization| D[Normed vectors]
  D -->|Feed-forward layer| F[Output vectors]
  C -->|Residual connection: $$\oplus$$| F
In the transformer block, layer normalization takes as input a vector in \(\mathbb{R}^n\) and normalizes it so that its components have mean \(0\) and variance \(1\). Additionally, we include trainable scaling factors \(\{\lambda_i\}_{1\leq i\leq n}\) and shifts \(\{b_i\}_{1\leq i\leq n}\), which we initialize as \(\lambda_i=1\) and \(b_i=0\) so that they have no initial effect.
\[ \text{LayerNorm}: \mathbb{R}^n \to \mathbb{R}^n, \qquad \text{LayerNorm}(x_1,...,x_n)_i := \lambda_i \left(\frac{x_i - \text{mean}}{\sqrt{\text{var}}}\right) + b_i. \]
class LayerNorm(nn.Module):
def __init__(self, emb_dim):
super().__init__()
self.lam = nn.Parameter(torch.ones(emb_dim))
self.b = nn.Parameter(torch.zeros(emb_dim))
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
norm_x = (x - mean) / torch.sqrt(var + 1e-5)
return self.lam * norm_x + self.b
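A quick numerical check of the normalization (illustrative values):

ln = LayerNorm(emb_dim=5)
x = torch.randn(3, 5)
out = ln(x)
print(out.mean(dim=-1))                 # approximately 0 in every row
print(out.var(dim=-1, unbiased=False))  # approximately 1 in every row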
An activation function is a transformation of the nodes of a neural network. There are many choices, see (Dubey, Singh, and Chaudhuri 2022).
Popular choices are the ReLU (Rectified Linear Unit) activation function:
\[ \text{ReLU}(x) := \max (x, 0) \]
and the GELU (Gaussian Error Linear Unit) activation function:
\[ \text{GELU}(x) := x P(X\leq x), \qquad X \sim {\mathcal N}(0,1), \]
i.e., \(P(X \leq x)\) is the cumulative distribution function of the standard normal distribution. We can approximate GELU by \(\text{GELU}(x) \approx 0.5x\left( 1 + \tanh\!\left[\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\, x^3\right)\right] \right).\)
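Both activation functions are available directly in PyTorch; a small illustrative snippet (the approximate="tanh" option assumes a reasonably recent PyTorch version):

x = torch.linspace(-3.0, 3.0, 7)
relu = nn.ReLU()
gelu = nn.GELU(approximate="tanh")  # uses the tanh approximation given above
print(relu(x))
print(gelu(x))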
A residual connection adds the input of a layer to the layer’s output:
\[ X \mapsto \text{Layer}(X) + X. \]
It is mainly used to help prevent vanishing gradients during backpropagation.
The feed-forward layer is a two-layer fully connected neural network. It consists of a linear transformation from a \(d\)-dimensional space to a higher-dimensional space (often \(4d\)), an activation function, and a linear transformation back to the original dimension \(d\):
\[ \text{FFN}(X):= \text{GELU}(XW_1+b_1)W_2 + b_2. \]
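The transformer block implementation below refers to a FeedForward module, which is not defined elsewhere in this section; the following is a minimal sketch consistent with the formula above (the expansion factor of 4 is an assumption, matching common practice):

class FeedForward(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),  # X W_1 + b_1
            nn.GELU(),                        # activation
            nn.Linear(4 * d_model, d_model),  # (.) W_2 + b_2
        )
    def forward(self, x):
        return self.layers(x)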
The full transformer block is a composition of a layer norm, the causally masked attention block, a residual connection, then another layer norm, the feed-forward layer, followed by another residual connection:
\[ \begin{align} \text{Transformer} : &\mathbb{R}^{N\times d} \qquad \xrightarrow{\hspace{8cm}} \qquad \mathbb{R}^{N\times d},\\ & X \mapsto \underbrace{\text{attention}(Q,K,V)\left(\text{LayerNorm}_1(X) \right) + X}_{Y} \mapsto \text{FFN}\left(\text{LayerNorm}_2(Y) \right) + Y \end{align} \]
class TransformerBlock(nn.Module):
def __init__(self, d_in, d_k, d_out, context_length):
super().__init__()
self.att = CausallyMaskedAttention(d_in, d_k, d_out, context_length, qkv_bias=False)
self.ff = FeedForward(d_out)
self.layer_norm1 = LayerNorm(d_in)
self.layer_norm2 = LayerNorm(d_out)
def forward(self, x):
# x.shape = (n, d_in)
shortcut = x # Residual connection for attention block
x = self.layer_norm1(x) # (n, d_in)
x = self.att(x) # (n, d_out)
x = x + shortcut # Add the original input back
shortcut = x # Residual connection for feed-forward block
x = self.layer_norm2(x) # (n, d_out)
x = self.ff(x) # (n, d_out)
x = x + shortcut # Add the original input back
return x
A large language model is a function which takes as input a sequence of \(N\) encoded token indices and outputs a sequence of \(N\) logits, where a logit is simply a vector in \(\mathbb{R}^{|V|}\):
\[ \text{LLM}: V^N \to \mathbb{R}^{N \times |V|}. \]
The architecture of our GPT-styled LLM can be described as follows: a token embedding, followed by the positional encoding, a stack of transformer blocks, a final layer normalization, and a linear 'unembedding' layer mapping back to \(\mathbb{R}^{|V|}\).
Of course, the specifics of this architecture are of less importance than understanding the building blocks in the first place. Like LEGO blocks, we can mix and rearrange the component blocks.
class LanguageModel(nn.Module):
def __init__(self, vocab_size, d_model, context_length, num_blocks):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(context_length, d_model)
        self.blocks = nn.Sequential(*[TransformerBlock(d_model, d_model, d_model, context_length) for _ in range(num_blocks)])
self.final_norm = LayerNorm(d_model)
self.unembedding = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, idx):
        # idx.shape = (context_length,)
x = self.token_embedding(idx) # (context_length, d_model)
x = self.pos_encoding(x) # (context_length, d_model)
x = self.blocks(x) # (context_length, d_model)
x = self.final_norm(x) # (context_length, d_model)
logits = self.unembedding(x) # (context_length, vocab_size)
return logits
We may generate additional encoded tokens indices from our \(\text{LLM}\) as follows. Given an input sequence of encoded token indices \((x_1,...,x_N) \in V^N\), the output is a sequence of logits \[(y_1,...,y_N) := \text{LLM}(x_1,...,x_N) \in \mathbb{R}^{N\times |V|}.\] The full sequence is of importance when calculating the loss of our model and training its weights, but for inference we only need the last logit, \(y_N\).
We convert the logit \(y_N\in \mathbb{R}^{|V|}\) into a probability distribution \(P\) by taking the softmax: \[\sigma(y_N) \in (0,1)^{|V|},\] where the probability of the \(i\)th token index is then given by \(P(i) := \sigma(y_N)_i\).
Generation of the next token then requires a choice of sampling strategy to select a token index from this probability distribution.
Greedy inference is defined as follows: on each iteration, we select the token index with the highest probability:
\[ \text{Greedy}(x_1,...x_N):= \text{argmax}_{i \in V} P(i). \]
def generate_greedy(model, ids, max_new_tokens, context_size):
for _ in range(max_new_tokens):
# Crop ids if it exceeds the supported context size
ids_context = ids[-context_size:]
# Get the predictions
with torch.no_grad():
logits = model(ids_context)
# Focus only on the last time step
# (n_token, vocab_size) becomes (vocab_size)
logits = logits[-1, :]
# Get the idx of the vocab entry with the highest logits value
id_next = torch.argmax(logits, dim=-1, keepdim=True)
# Append sampled index to the running sequence
# (n_tokens+1)
ids = torch.cat((ids, id_next))
return ids
Sampling inference is defined as follows: on each iteration, select the token index by sampling from the probability distribution:
\[ \text{Sampling}(x_1,...,x_N) \sim P \]
def generate_sampling(model, ids, max_new_tokens, context_size):
for _ in range(max_new_tokens):
# Crop ids if it exceeds the supported context size
ids_context = ids[-context_size:]
# Get the predictions
with torch.no_grad():
logits = model(ids_context)
# Focus only on the last time step
# (n_token, vocab_size) becomes (vocab_size)
logits = logits[-1, :]
# Apply the softmax to convert the logit to a probability distribution
        probs = torch.softmax(logits, dim=-1)
        # Get the id by sampling from the probability distribution
        id_next = torch.multinomial(probs, num_samples=1)
# Append sampled index to the running sequence
# (n_tokens+1)
ids = torch.cat((ids, id_next))
return ids
Fix a temperature \(T > 0\). We may adjust the probability distribution by scaling with \(T\): \[P_T(i):= \sigma(\tfrac{y_N}{T})_i.\]
For higher \(T\), the adjusted probabilities become more uniform. For lower \(T\), the distribution concentrates on the most probable tokens and sampling becomes more deterministic.
Temperature sampling inference is defined as follows: on each iteration, select the token index by sampling from the adjusted probability distribution:
\[ \text{TemperatureSampling}(x_1,...,x_N) \sim P_T \]
def generate_temperature_sampling(model, ids, max_new_tokens, context_size, temperature):
for _ in range(max_new_tokens):
# Crop ids if it exceeds the supported context size
ids_context = ids[-context_size:]
# Get the predictions
with torch.no_grad():
logits = model(ids_context)
# Focus only on the last time step
# (n_token, vocab_size) becomes (vocab_size)
logits = logits[-1, :]
# Scale the logits by the temperature
scaled_logits = logits / temperature
# Apply the softmax to convert the scaled logit to a probability distribution
        probs = torch.softmax(scaled_logits, dim=-1)
        # Get the id by sampling from the adjusted probability distribution
        id_next = torch.multinomial(probs, num_samples=1)
# Append sampled index to the running sequence
# (n_tokens+1)
ids = torch.cat((ids, id_next))
return ids
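Putting the pieces together, a generation call might look like the following sketch (the hyperparameters are illustrative assumptions, and tokenizer is the BPE tokenizer built earlier; since the weights are untrained, the output will be gibberish):

torch.manual_seed(0)
vocab_size = len(tokenizer.vocab)
model = LanguageModel(vocab_size=vocab_size, d_model=64, context_length=32, num_blocks=2)
prompt_ids = torch.tensor(tokenizer.encode("The Steenrod problem for"))
out_ids = generate_temperature_sampling(
    model, prompt_ids, max_new_tokens=20, context_size=32, temperature=0.8
)
print(tokenizer.decode(out_ids.tolist()))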
We have now defined the architecture of the LLM and used it to define token generation. At this point, any generation will produce gibberish, as the model weights have been initialized randomly. The next step is to set up a data source to provide inputs and targets, define a loss function comparing the targets with the model's predictions, and adjust the weights to minimize this loss via gradient descent.
Let \(N\) denote the context length of our model. We also fix a stride length \(k > 0\). The text mathematics.tex gives us a source of inputs and targets as follows: we encode the full text into a sequence of token ids, take as an input each window of \(N\) consecutive ids (with the starting points of successive windows separated by the stride \(k\)), and take as the corresponding target the same window shifted one position to the right.
In a simple example, we may consider a word-level tokenizer and a context length of \(6\). The first input and target of mathematics.tex would be as follows:
Input | \(\fbox{The Steenrod problem for closed orientable}\) manifolds |
Target | The \(\fbox{Steenrod problem for closed orientable manifolds}\) |
We may divide our fixed source text into a dataset of inputs \((x_1,...,x_K)\) with associated targets \((y_1,...,y_K)\).
from torch.utils.data import Dataset, DataLoader
class mathematics_dataset(Dataset):
def __init__(self, text, tokenizer, context_length, stride):
self.inputs = []
self.targets = []
token_ids = tokenizer.encode(text) # Tokenize the entire text
# Divide the tokens into sequences of length context_length, with starting points separated by length stride.
for i in range(0, len(token_ids) - context_length, stride):
x = token_ids[i : i + context_length]
y = token_ids[i + 1 : i + context_length + 1]
            self.inputs.append(torch.tensor(x))
            self.targets.append(torch.tensor(y))
def __len__(self):
return len(self.inputs)
def __getitem__(self, idx):
return self.inputs[idx], self.targets[idx]
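As a usage sketch (the context length and stride below are illustrative, and we keep batch_size=1 since batching is outside our scope), the dataset can be wrapped in the DataLoader imported above:

# `text` holds the contents of mathematics.tex, as in the tokenizer example above
dataset = mathematics_dataset(text, tokenizer, context_length=32, stride=8)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([1, 32]) torch.Size([1, 32])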
Consider an input \(x=(t_1,...,t_N)\) with target \(y=(y_1,...,y_N)\). The prediction is then given by \(\hat{y}:=\text{LLM}(x) = (v_1,...,v_N)\in\mathbb{R}^{N\times |V|}\), where each row vector \(v_i\) is a logit in \(\mathbb{R}^{|V|}\).
The cross-entropy loss is the negative average log-probability that the model assigns to the target \(y\) under the prediction \(\hat{y}\). Writing \(P_{v_i} := \sigma(v_i)\) for the probability distribution obtained from the logit \(v_i\), it is given by:
\[ \text{Cross-Entropy}(y,\hat{y}) := -\frac{1}{N}\sum_{i=1}^N \log \left(P_{v_i}(y_i)\right). \]
The loss for a collection of data points \(\{(x_i,y_i)\}\) is the average of the cross-entropy loss for each pair.
The perplexity is the exponential of the cross-entropy loss. It is sometimes interpreted as the effective number of tokens the model is choosing between, uniformly at random, at each prediction step.
\[ \text{Perplexity}(y,\hat{y}):= \exp \left( \text{Cross-Entropy}(y,\hat{y})\right). \]
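In PyTorch, the cross-entropy loss can be computed directly from the logits and the target indices; a minimal sketch using the unbatched shapes from above (compute_loss is a hypothetical helper name):

import torch.nn.functional as F

def compute_loss(model, x, y):
    # x, y: 1-D tensors of token ids of length N
    logits = model(x)                  # (N, |V|): one logit per position
    return F.cross_entropy(logits, y)  # negative average log-probability of the targets

# The perplexity is then the exponential of this loss:
# perplexity = torch.exp(compute_loss(model, x, y))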