LLMs from the Mathematical Viewpoint

1 Introduction

By now, there are many useful books, online guides, and videos for understanding the construction of Large Language Models from first principles (Raschka 2024) (Sanderson and Sun 2024) (Amidi and Amidi 2024). What sets this project apart is that it is designed as a mathematics-first presentation of the architecture of GPT-style LLMs. If you are a mathematician, or are simply more comfortable with the language and structure of mathematical writing, this page will hopefully be an easy-to-digest presentation of the mathematical foundations of LLMs. You should be able to read and understand this material in an afternoon.

In addition to describing tokenizers, attention mechanisms, and transformer blocks in terms of mathematical formulas, I have included basic implementations of these constructions in PyTorch, to give a quick start on the bridge between the abstract mathematical formulations and their implementations in code. Because our focus is primarily on the mathematical constructions, we leave out several aspects of the engineering perspective that matter for implementation and optimization, namely batching of data, the differences between CPU and GPU architectures, and parameter quantization.

Thanks go to my audience in the symplectic group at the Faculty of Mathematics at Ruhr-Universität Bochum.

1.1 Predicting the next token

We can superficially consider a (generative) \(\text{LLM}\) as a function which takes as input a fixed length sequence of text and outputs a probability distribution for the next text unit. By sampling from this distribution, and appending to the original text sequence, we can iteratively predict the subsequent text unit:

\[ \underline{\text{The Steenrod problem for}} \mapsto \begin{cases} \text{closed} & 29\% \\ \text{spaces} & 24\% \\ \text{modules} & 22\% \\ \text{spectra} & 19\% \end{cases} \]

\[ \implies \text{The }\underline{\text{Steenrod problem for } \fbox{closed}} \mapsto \begin{cases} \text{orientable} & 37\% \\ \text{smooth} & 23\% \\ \text{topological} & 19\% \\ \text{homology} & 9\% \end{cases} \]

\[ \implies \text{The Steenrod }\underline{\text{problem for closed } \fbox{orientable}} \mapsto \begin{cases} \text{manifolds} & 77\% \\ \text{orbifolds} & 14\% \\ \text{surfaces} & 4\% \\ \text{varieties} & <1\% \end{cases} \]

2 Tokenizing text

2.1 Source text

We begin by selecting a source text, which serves both as a fixed source of training data for our eventual LLM and as the source from which we define the vocabulary of our tokenizer. To keep things topical, we will use a collection of mathematics texts in symplectic geometry:

mathematics.tex


The Steenrod problem for closed orientable manifolds was solved completely by Thom.
Following this approach, we solve the Steenrod problem for closed orientable orbifolds, proving that the rational homology groups of a closed orientable orbifold have a basis consisting of classes represented by suborbifolds whose normal bundles have fiberwise trivial isotropy action. Polyfold theory, as developed by Hofer, Wysocki, and Zehnder, has yielded a well-defined Gromov–Witten invariant via the regularization of moduli spaces.
As an application, we demonstrate that the polyfold Gromov–Witten invariants, originally defined via branched integrals, may equivalently be defined as intersection numbers against a basis of representing suborbifolds.
Introduction
The Steenrod problem
The Steenrod problem was first presented in (Eilenberg 1949) and asked the following question:
Can any homology class of a finite polyhedron be represented as an image of the fundamental class of some manifold?
In (Thom 1954),…

Python imports
import torch
from torch import nn
import math
Opening the source text mathematics.tex
with open('mathematics.tex', 'r', encoding='utf-8') as file:
    text = file.read()
text = text[:3000]  # keep only the first 3000 characters so the examples below stay small

2.2 Tokens

A token is an arbitrarily defined unit of text.

| Level | Description | Illustration |
|---|---|---|
| word | Each word is a token. | \(\fbox{The} \fbox{Steenrod} \fbox{problem} \fbox{for} \fbox{closed} \fbox{orbifolds}\) |
| subword | Each word is divided into subwords, e.g., by byte-pair encoding (BPE). | \(\fbox{The}\fbox{ Ste}\fbox{en}\fbox{rod}\fbox{ problem}\fbox{ for}\fbox{ closed}\fbox{ orb}\fbox{if}\fbox{olds}\) |
| character | Each character is a token. | \(\fbox{T}\fbox{h}\fbox{e}\fbox{\ }\fbox{S}\fbox{t}\fbox{e}\fbox{e}\fbox{n}\fbox{r}\fbox{o}\fbox{d}\fbox{\ }\fbox{p}\fbox{r}\fbox{o}\fbox{b}\fbox{l}\fbox{e}\fbox{m}\) |

A vocabulary \(V\) is a set of tokens. We often augment a vocabulary with additional special tokens, e.g.,

\([\text{BOS}]\) Marks the beginning of text.
\([\text{EOS}]\) Marks the end of text.
\([\text{PAD}]\) A padding token used to fill a sequence to the required context length.
\([\text{UNK}]\) Represents parts of the input text not present in the vocabulary.

2.3 Tokenizers

Let \(\text{TEXTS}\) be defined as the set of finite texts and let \(\text{Seq}(V)\) be defined as the set of finite sequences \((t_1, t_2,...,t_N)\) where \(t_i \in V.\) A tokenizer \(T\) is a map

\[\begin{align} T: \text{TEXTS} &\to \text{Seq}(V), \\ \text{text} &\mapsto \text{tokenized text}. \end{align}\]

In other words, the tokenizer \(T\) encodes a text into a sequence of tokens. The (partial) inverse, \(T^{-1}\), decodes a sequence of tokens into a text. Note that the presence of characters or subwords not present in the vocabulary, as well as the possibility of special tokens, complicates the question of whether \(T\) is injective/bijective/neither as well as the definition of the (partial) inverse \(T^{-1}\).

For example, a word-level tokenizer can be viewed as follows:

\[ T(\text{The Steenrod problem...}) = (\fbox{The},\fbox{Steenrod},\fbox{problem}...) \]

2.4 Encoding

We next want to represent tokens numerically as vectors. We begin by fixing an index on the vocabulary, i.e., a bijection \(V \to \{1,...,|V|\}\), and thus we freely identify tokens and token indices.

Using this identification, a tokenizer defines an encoder of a finite text into a sequence of indices:

\[ T ( \text{The Steenrod problem...}) = (341, 12459, 948,...), \]

as well as a decoder from a sequence of indices back to a finite text.
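As a minimal sketch of this identification, assume a toy character-level vocabulary built from a single string. The helper names `build_index`, `encode`, and `decode` are hypothetical and separate from the tokenizer classes below. Since Python indexes from 0, the bijection here is onto \(\{0,...,|V|-1\}\) rather than \(\{1,...,|V|\}\).

```python
def build_index(source_text):
    """Fix an index on the vocabulary: a bijection V -> {0, ..., |V| - 1}."""
    vocab = sorted(set(source_text))  # sorted so the indexing is reproducible
    token_to_id = {token: i for i, token in enumerate(vocab)}
    id_to_token = {i: token for token, i in token_to_id.items()}
    return token_to_id, id_to_token

def encode(text, token_to_id):
    """Tokenize into characters, then replace each token by its index."""
    return [token_to_id[char] for char in text]

def decode(ids, id_to_token):
    """Inverse of encode: map indices back to tokens and concatenate them."""
    return "".join(id_to_token[i] for i in ids)

token_to_id, id_to_token = build_index("The Steenrod problem")
ids = encode("Steenrod", token_to_id)
print(ids)                        # [1, 13, 5, 5, 9, 12, 10, 4]
print(decode(ids, id_to_token))   # Steenrod
```

Here every character of the input lies in the vocabulary, so decoding inverts encoding exactly; the special tokens discussed above are what complicate this in general.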

Character level tokenizer
class CharacterTokenizer:
    def __init__(self):
        # Maps token_str to token_id (e.g., {"a": 19, "b": 23})
        self.vocab = {}
        # Maps token_id to token_str (e.g., {19: "a", 23: "b"})
        self.inverse_vocab = {}
        # Included special tokens
        self.specials = {"<|BOS|>", "<|EOS|>", "<|UNK|>", "<|PAD|>"}

    def initialize_vocab(self, source_text, allowed_specials=set()):
        """Initializes the vocabulary consisting of characters appearing in source_text."""
        characters = set(source_text)
        # Sort so that the token ids are reproducible across runs (set order is not)
        self.vocab = {char: idx for idx, char in enumerate(sorted(characters | self.specials))}
        self.inverse_vocab = {idx: char for char, idx in self.vocab.items()}

    def encode(self, text):
        processed = (
            ["<|BOS|>"]
            + [char if char in self.vocab else "<|UNK|>" for char in text]
            + ["<|EOS|>"]
        )
        ids = [self.vocab[token] for token in processed]
        return ids

    def decode(self, ids):
        return "".join(
            [
                self.inverse_vocab[id]
                for id in ids
                if id
                not in {
                    self.vocab["<|PAD|>"],
                    self.vocab["<|BOS|>"],
                    self.vocab["<|EOS|>"],
                }
            ]
        )
Word level tokenizer
class WordTokenizer:
    def __init__(self):
        # Maps token_str to token_id (e.g., {"symplectic": 125})
        self.vocab = {}
        # Maps token_id to token_str (e.g., {125: "symplectic"})
        self.inverse_vocab = {}
        # Included special tokens
        self.specials = {"<|BOS|>", "<|EOS|>", "<|UNK|>", "<|PAD|>"}

    def initialize_vocab(self, source_text, allowed_specials=set()):
        """Initializes the vocabulary consisting of words appearing in source_text."""
        words = set(source_text.split(" "))
        # Sort so that the token ids are reproducible across runs (set order is not)
        self.vocab = {word: idx for idx, word in enumerate(sorted(words | self.specials))}
        self.inverse_vocab = {idx: word for word, idx in self.vocab.items()}

    def encode(self, text):
        words_in_text = text.split(" ")
        processed = (
            ["<|BOS|>"]
            + [word if word in self.vocab else "<|UNK|>" for word in words_in_text]
            + ["<|EOS|>"]
        )
        ids = [self.vocab[token] for token in processed]
        return ids

    def decode(self, ids):
        return " ".join(
            [
                self.inverse_vocab[id]
                for id in ids
                if id
                not in {
                    self.vocab["<|PAD|>"],
                    self.vocab["<|BOS|>"],
                    self.vocab["<|EOS|>"],
                }
            ]
        )

2.5 Subword tokenizer: byte-pair encoding

The byte-pair encoding is an algorithm which takes a source text and builds a vocabulary iteratively by combining the most common pairs of neighboring tokens (the modern version of the ‘byte-pair’). A text is correspondingly tokenized and encoded by this vocabulary by merging pairs of tokens iteratively.

Explicitly, the vocabulary is constructed via the following algorithm:

  • Initialize the vocabulary from the source text by taking the set of all characters appearing in the text; \[\text{vocabulary} := \{\text{character } | \text{ character} \in \text{TEXT}\}.\]
  • On each iteration, identify the most frequent pair of neighboring tokens. Add this pair as a new token in the vocabulary, and record the merged pair to keep track of merge priority. Merges performed earlier have higher priority.
  • Repeat until the desired size of the vocabulary is reached.
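To see the vocabulary-building loop on something small, here is a minimal sketch that learns merges on the toy string "aaabdaaabac". It keeps tokens as strings rather than ids for readability; the function name `learn_bpe` is hypothetical, and breaking ties between equally frequent pairs by first occurrence is one choice among several.

```python
from collections import Counter

def learn_bpe(text, num_merges):
    """Learn up to num_merges merges on `text`; returns (merges, final tokenization)."""
    tokens = list(text)  # initial vocabulary: single characters
    merges = []          # merge order = priority (earlier merges first)
    for _ in range(num_merges):
        # Count neighboring pairs; most_common breaks ties by first occurrence
        pairs = Counter(zip(tokens, tokens[1:]))
        if not pairs:
            break
        best = pairs.most_common(1)[0][0]
        merges.append(best)
        # Replace every occurrence of `best`, scanning left to right
        merged, i = [], 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best:
                merged.append(tokens[i] + tokens[i + 1])
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return merges, tokens

merges, tokens = learn_bpe("aaabdaaabac", num_merges=3)
print(merges)  # [('a', 'a'), ('aa', 'a'), ('aaa', 'b')]
print(tokens)  # ['aaab', 'd', 'aaab', 'a', 'c']
```

After three merges the vocabulary has grown by the tokens "aa", "aaa", and "aaab".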

A text is then tokenized and encoded via the following algorithm:

  • Tokenize the text into a sequence of characters.
  • On each iteration, among the adjacent pairs that have a recorded merge, pick the pair with the highest merge priority and merge all of its occurrences.
  • Repeat until no more merges can be performed.
  • Encode the tokens with the corresponding token ids.
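The tokenization steps can be sketched on their own, assuming a hard-coded, hypothetical merge list in priority order (earliest first). The function name `bpe_encode` is also hypothetical, and the final lookup of token ids is omitted since it is the same dictionary lookup as before.

```python
def bpe_encode(text, merges):
    """Tokenize `text` by repeatedly applying the highest-priority learned merge."""
    # Earlier merges have higher priority, i.e., a lower rank
    rank = {pair: r for r, pair in enumerate(merges)}
    tokens = list(text)  # start from single characters
    while True:
        candidates = [pair for pair in zip(tokens, tokens[1:]) if pair in rank]
        if not candidates:
            break  # no learned merge applies any more
        best = min(candidates, key=lambda pair: rank[pair])
        # Merge every occurrence of `best`, scanning left to right
        merged, i = [], 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best:
                merged.append(tokens[i] + tokens[i + 1])
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens

merges = [("a", "a"), ("aa", "a"), ("aaa", "b")]  # hypothetical learned merges
print(bpe_encode("aaabc", merges))  # ['aaab', 'c']
print(bpe_encode("baaa", merges))   # ['b', 'aaa']
```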
Byte-pair encoding tokenizer
class BPETokenizer:
    def __init__(self):
        # Maps token_str to token_id (e.g., {"symplectic": 11246})
        self.vocab = {}
        # Maps token_id to token_str (e.g., {11246: "symplectic"})
        self.inverse_vocab = {}
        # Included special tokens
        self.specials = {"<|BOS|>", "<|EOS|>", "<|UNK|>", "<|PAD|>"}
        # Dictionary of BPE merges: {(token_id1, token_id2): merged_token_id}
        self.bpe_merges = {}

    def initialize_vocab(self, source_text, vocab_size):
        """Initializes the vocabulary for BPE from source_text of size vocab_size."""
        from collections import Counter

        # Initialize vocab with unique characters appearing in the source_text

        # The first 256 Unicode code points (ASCII plus the Latin-1 Supplement)
        initial_chars = {chr(i) for i in range(256)}
        # Additional characters appearing in the source_text
        source_chars = {char for char in set(source_text) if char not in initial_chars}

        # Sort so that the token ids are reproducible across runs (set order is not)
        self.vocab = {
            char: idx for idx, char in enumerate(sorted(initial_chars | source_chars))
        }
        self.inverse_vocab = {idx: char for char, idx in self.vocab.items()}

        # Add special tokens
        for special in self.specials:
            new_id = len(self.vocab)
            self.vocab[special] = new_id
            self.inverse_vocab[new_id] = special

        # Tokenize the source_text into token IDs
        token_ids = [self.vocab[char] for char in source_text]

        # Iteratively find and replace with a new_id the most frequent pairs of token ids appearing in token_ids.
        for new_id in range(len(self.vocab), vocab_size):
            # Identify the most frequent pair (ties go to the pair seen first)
            pairs = Counter(zip(token_ids, token_ids[1:]))
            if not pairs:
                break  # fewer than two tokens left, nothing to merge
            pair_ids = pairs.most_common(1)[0][0]
            # Replace [pair_ids[0], pair_ids[1]] with new_id
            token_ids_merge = []
            for token_id in token_ids:
                if (
                    token_ids_merge
                    and token_ids_merge[-1] == pair_ids[0]
                    and token_id == pair_ids[1]
                ):
                    token_ids_merge.pop()
                    token_ids_merge.append(new_id)
                else:
                    token_ids_merge.append(token_id)
            token_ids = token_ids_merge
            # Record the pair_ids and the new_id
            self.bpe_merges[pair_ids] = new_id
            # Add the merged pair to the vocab
            merged_token = (
                self.inverse_vocab[pair_ids[0]] + self.inverse_vocab[pair_ids[1]]
            )
            self.vocab[merged_token] = new_id
            self.inverse_vocab[new_id] = merged_token

    def encode(self, text):
        """Encode the input text into a list of token ids."""
        import re

        token_ids = []
        token_ids.append(self.vocab["<|BOS|>"])

        chunks = re.split(r"(<\|(?:UNK|PAD)\|>)", text)
        for chunk in chunks:
            if chunk in self.specials:
                token_ids.append(self.vocab[chunk])
                continue

            tokens = [
                self.vocab[char] if char in self.vocab else self.vocab["<|UNK|>"]
                for char in chunk
            ]

            # Apply BPE merges
            while True:
                pairs = [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
                # Find pairs that are in our merge list
                merge_candidates = [pair for pair in pairs if pair in self.bpe_merges]
                if not merge_candidates:
                    break  # No more merges to apply

                # Find the merge with the lowest index (earliest learned)
                best = min(merge_candidates, key=lambda p: self.bpe_merges[p])
                new_id = self.bpe_merges[best]

                # Apply the merge to every occurrence of `best`, left to right
                i = 0
                new_tokens = []
                while i < len(tokens):
                    if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best:
                        new_tokens.append(new_id)
                        i += 2
                    else:
                        new_tokens.append(tokens[i])
                        i += 1
                tokens = new_tokens

            token_ids.extend(tokens)

        token_ids.append(self.vocab["<|EOS|>"])
        return token_ids

    def decode(self, ids):
        return "".join(
            [
                self.inverse_vocab[id]
                for id in ids
                if id
                not in {
                    self.vocab["<|PAD|>"],
                    self.vocab["<|BOS|>"],
                    self.vocab["<|EOS|>"],
                }
            ]
        )

2.6 Token embeddings

A token embedding is an abstract representation of a token as a vector in some \(\mathbb{R}^d\).

The one-hot mapping is a bijection between the token indices of a vocabulary \(V\) and the standard basis vectors \(\{e_i\}\) of \(\mathbb{R}^{|V|}\), i.e.,

\[ \begin{align} \text{OH}: V &\to \mathbb{R}^{|V|} \\ t_i &\mapsto e_i. \end{align} \]

We may compose this bijection with a map \(A :\mathbb{R}^{|V|} \to \mathbb{R}^{d_\text{emb}}\) to obtain continuous vector representations of the tokens in the vocabulary:

\[ \text{Embedding} := A \circ \text{OH}: V \to \mathbb{R}^{d_\text{emb}}. \]

This composition is the basis of the Word2vec family of embeddings.

The map \(A\) is often a linear map, or, if we wished, a neural network itself; in the process of training the GPT, we can adjust the weights to learn better token embeddings.
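To check that a lookup table is the same thing as \(A \circ \text{OH}\) with \(A\) linear, here is a minimal sketch with a hypothetical 10-token vocabulary and \(d_\text{emb} = 4\). Note that `torch.nn.Embedding` stores its weight with shape \((|V|, d_\text{emb})\) and acts on row vectors, so its weight is the transpose of the matrix of \(A\) in the column-vector convention used above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, d_emb = 10, 4
emb = torch.nn.Embedding(vocab_size, d_emb)  # weight has shape (|V|, d_emb)

ids = torch.tensor([3, 1, 3])
# OH: each index becomes a standard basis vector e_i in R^{|V|}
one_hot = F.one_hot(ids, num_classes=vocab_size).float()  # (3, |V|)
# A applied to e_i: the row vector e_i^T times weight is the i-th row of weight
via_matrix = one_hot @ emb.weight                           # (3, d_emb)
via_lookup = emb(ids)                                       # (3, d_emb)

print(torch.allclose(via_matrix, via_lookup))  # True
```

The lookup avoids materializing the one-hot vectors, which matters when \(|V|\) is in the tens of thousands.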

Token embedding example
tokenizer = BPETokenizer()
tokenizer.initialize_vocab(text, 300)
sentence = "A symplectic manifold"
encoded_sentence = tokenizer.encode(sentence)

X = torch.tensor(encoded_sentence)

# In PyTorch, we can implement a trainable embedding layer as `torch.nn.Embedding`, which acts as a lookup table mapping token indices to dense vectors.
emb = torch.nn.Embedding(300, 5)
# print(emb(X))
# print(tokenizer.decode(X.tolist()))

2.7 Positional encoding matrix

Fix an input sequence length \(N\). Given an input sequence of tokens, we may apply the embedding map to get an input sequence of embedded tokens:

\[ \text{Embedding} : V^N \to \mathbb{R}^{N\times d_\text{emb}}. \]

As defined, the token embeddings carry no positional information about sequence order: two identical tokens at different positions in the input sequence have the same token embedding.

We account for this by perturbing an input sequence by means of a positional encoding matrix \(\text{PE} \in \mathbb{R}^{N\times d_\text{emb}}\), whose entries are indexed from \(0\), by positions \(0 \leq i \leq N-1\) and columns \(0 \leq 2j, 2j+1 \leq d_\text{emb}-1\) (with \(d_\text{emb}\) even), and are given by:

\[ p_{i,2j} := \sin \left( \frac{i}{10000^{2j/d_\text{emb}}}\right), \qquad p_{i,2j + 1} := \cos \left( \frac{i}{10000^{2j/d_\text{emb}}}\right). \]

Given an input sequence \(X\in \mathbb{R}^{N\times d_\text{emb}}\), we perturb by adding the positional encoding matrix, i.e., \(X + \text{PE}\).
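The matrix can also be built without Python loops. Here is a minimal vectorized sketch, assuming \(d_\text{emb}\) is even and indexing positions and columns from \(0\); the function name `sinusoidal_pe` is hypothetical.

```python
import math
import torch

def sinusoidal_pe(N, d_emb):
    """Positional encoding matrix PE in R^{N x d_emb}; rows and columns indexed from 0."""
    assert d_emb % 2 == 0, "this sketch assumes an even embedding dimension"
    i = torch.arange(N, dtype=torch.float64).unsqueeze(1)   # positions, shape (N, 1)
    two_j = torch.arange(0, d_emb, 2, dtype=torch.float64)  # even column indices 2j
    angles = i / torch.pow(10000.0, two_j / d_emb)          # shape (N, d_emb / 2)
    PE = torch.zeros(N, d_emb, dtype=torch.float64)
    PE[:, 0::2] = torch.sin(angles)  # p_{i,2j}
    PE[:, 1::2] = torch.cos(angles)  # p_{i,2j+1}
    return PE

PE = sinusoidal_pe(N=16, d_emb=8)
# Spot-check one entry against the formula: p_{3,2} = sin(3 / 10000^(2/8))
print(PE[3, 2].item(), math.sin(3 / 10000 ** (2 / 8)))
# Dot products of rows depend only on the offset (here 3)
print((PE[2] @ PE[5]).item(), (PE[10] @ PE[13]).item())
```

Pairing a sine and a cosine at the same frequency means \(\text{PE}_i \cdot \text{PE}_{i+k} = \sum_j \cos(k\,\omega_j)\) with \(\omega_j = 10000^{-2j/d_\text{emb}}\), since \(\sin a \sin b + \cos a \cos b = \cos(a-b)\); this is one way the encoding carries relative position.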

Positional encoding
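class PositionalEncoding(nn.Module):
    def __init__(self, context_length, d_model):
        super().__init__()
        PE = torch.zeros(context_length, d_model)
        # Positions i and columns k are 0-based, as in the formula and in Python
        for i in range(context_length):
            for k in range(d_model):
                # Columns 2j (sine) and 2j + 1 (cosine) share the exponent 2j / d_model
                two_j = 2 * (k // 2)
                angle = i / math.pow(10000, two_j / d_model)
                PE[i, k] = math.sin(angle) if k % 2 == 0 else math.cos(angle)
        # A buffer is saved and moved with the module (e.g., by .to(device)) but not trained
        self.register_buffer("PE", PE)

    def forward(self, X):
        # X.shape = (n, d_model) with n <= context_length; add the first n rows of PE
        n = X.shape[-2]
        return X + self.PE[:n]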

3 Causally masked attention

At the outset, we fix dimensions \(d_\text{in}\), \(d_k\), and \(d_\text{out}\), as well as an input sequence length \(N\). Let \((x_1,...,x_N),\) \(x_i \in \mathbb{R}^{d_\text{in}}\) be an input sequence of embedded vectors, which we may consider as an input matrix \(X \in \mathbb{R}^{N\times d_\text{in}}\). We implement causally masked self-attention, which is the standard for the GPT architecture, and is defined as a map:

\[ \text{attention}(Q,K,V) : \mathbb{R}^{N\times d_\text{in}} \to \mathbb{R}^{N\times d_\text{out}} \]

given by the equation:

\[ \text{attention}(Q,K,V) := \text{softmax} \left( \frac{QK^T}{\sqrt{d_k} } + M \right) V. \]

In this section, we describe precisely how this is defined.

3.1 Query, Key, and Value matrices

The Query, Key, and Value matrices are obtained by applying linear transformations to the input matrix \(X\in \mathbb{R}^{N\times d_\text{in}}\). Let \(W_Q, W_K \in \text{Mat}_{d_\text{in} \times d_k}(\mathbb{R}), W_V \in \text{Mat}_{d_\text{in} \times d_\text{out}}(\mathbb{R})\) be learnable weight matrices. Then we may define:

  • Query matrix: \(Q := X W_Q \in \text{Mat}_{N \times d_k}(\mathbb{R})\),
  • Key matrix: \(K := X W_K \in \text{Mat}_{N \times d_k}(\mathbb{R})\),
  • Value matrix: \(V := X W_V \in \text{Mat}_{N \times d_\text{out}}(\mathbb{R})\).
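In code, the weight matrices are usually held by `torch.nn.Linear` layers, as in the attention class of this section. One detail to watch: `nn.Linear(d_in, d_k)` stores its weight with shape \((d_k, d_\text{in})\) and computes `x @ weight.T`, so \(W_Q\) corresponds to `weight.T`. A minimal sketch with arbitrary small dimensions:

```python
import torch
from torch import nn

torch.manual_seed(0)
N, d_in, d_k, d_out = 5, 8, 4, 6
X = torch.randn(N, d_in)  # input matrix, one row per token

W_query = nn.Linear(d_in, d_k, bias=False)
W_key = nn.Linear(d_in, d_k, bias=False)
W_value = nn.Linear(d_in, d_out, bias=False)

# W_Q in Mat_{d_in x d_k} is weight.T, since weight has shape (d_k, d_in)
Q = X @ W_query.weight.T  # (N, d_k)
K = X @ W_key.weight.T    # (N, d_k)
V = X @ W_value.weight.T  # (N, d_out)

print(Q.shape, K.shape, V.shape)
print(torch.allclose(Q, W_query(X), atol=1e-6))  # True
```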

3.2 The factor \(\tfrac{1}{\sqrt{d_k}}\)

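The multiplicative factor \(\frac{1}{\sqrt{d_k}}\) is justified as follows. Suppose the entries of \(Q\) and of \(K\) are independent random variables with mean \(0\) and variance \(1\). Each summand of the dot product of rows \(q_i\cdot k_j = \sum_{l=1}^{d_k} q_{il}k_{jl}\) then has mean \(0\) and variance \(\mathbb{E}[q_{il}^2]\,\mathbb{E}[k_{jl}^2] = 1\), and since the summands are independent, \(q_i\cdot k_j\) has mean \(0\) and variance \(d_k\). We therefore introduce the factor \(\frac{1}{\sqrt{d_k}}\): every entry of the rescaled matrix \(\frac{QK^T}{\sqrt{d_k}}\) then has mean \(0\) and variance \(1\), regardless of \(d_k\). Without it, the scores would have standard deviation \(\sqrt{d_k}\), and for large \(d_k\) the softmax would put almost all weight on a single entry, where its gradients are very small.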
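The variance computation can be checked numerically. A minimal sketch, assuming standard normal entries (one distribution satisfying the mean-0, variance-1 hypothesis) and an arbitrary \(d_k = 256\):

```python
import torch

torch.manual_seed(0)
d_k, num_samples = 256, 10_000
q = torch.randn(num_samples, d_k)  # independent entries, mean 0, variance 1
k = torch.randn(num_samples, d_k)

dots = (q * k).sum(dim=-1)  # num_samples independent copies of q_i . k_j
scaled = dots / d_k ** 0.5  # the rescaled scores

print(f"var(q.k) = {dots.var().item():.1f}  (d_k = {d_k})")
print(f"var(q.k / sqrt(d_k)) = {scaled.var().item():.3f}")
```

The first printed variance should be close to \(d_k\) and the second close to \(1\), up to sampling error.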

3.3 Causal mask

The causal mask \(M \in \text{Mat}_{N \times N}(\mathbb{R}\cup\{-\infty\})\) is the strictly upper triangular matrix with \(0\)s on and below the diagonal and \(-\infty\) in every entry above the diagonal, i.e., \(M_{ij} = 0\) for \(j \leq i\) and \(M_{ij} = -\infty\) for \(j > i\). Since \(e^{-\infty} = 0\), adding this mask to the attention scores and then applying the softmax assigns zero weight to every future token \(j > i\) relative to a fixed token \(i\). As a result, future tokens have no causal influence on prior tokens, as measured by the attention weights.
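A minimal sketch of the mask and its effect, with random numbers standing in for the scaled scores \(\frac{QK^T}{\sqrt{d_k}}\) and the softmax taken row by row:

```python
import torch

torch.manual_seed(0)
N = 4
# M_ij = -inf for j > i, and 0 on and below the diagonal
M = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)
scores = torch.randn(N, N)                  # stand-in for QK^T / sqrt(d_k)
W_attn = torch.softmax(scores + M, dim=-1)  # row-wise softmax

print(M)
print(W_attn)  # zeros above the diagonal; the first row is [1, 0, 0, 0]
```

The first token can only attend to itself, so its row is forced to be \((1, 0, \dots, 0)\) whatever the scores are.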

3.4 Softmax

The softmax function takes as input a tuple of \(m>1\) real numbers and normalizes the tuple into a probability distribution. Formally, it is the function

\[ \sigma : \mathbb{R}^m \to (0,1)^m; \qquad \sigma(v_1,...,v_m)_i := \frac{e^{v_i}}{\sum_{j=1}^m e^{v_j}}. \]

In our implementation of attention, the softmax applied to a matrix is defined by applying the above softmax function independently to every row.

Softmax function
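def softmax(x):
    # Subtracting the row maximum leaves the result unchanged but prevents overflow in exp
    x = x - x.max(dim=-1, keepdim=True).values
    # Normalize over the last dimension, i.e., row by row for a matrix
    return torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True)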
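Shifting every input of \(\sigma\) by the same constant \(c\) leaves the output unchanged, since the factor \(e^c\) cancels between numerator and denominator. Implementations use this to subtract the row maximum before exponentiating. A minimal sketch comparing a naive row-wise softmax with the shifted one (both function names are hypothetical):

```python
import torch

def naive_softmax(x):
    # Direct transcription of the formula, applied to each row
    return torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True)

def stable_softmax(x):
    # Same formula after subtracting each row's maximum
    x = x - x.max(dim=-1, keepdim=True).values
    return torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True)

x = torch.tensor([[1.0, 2.0, 3.0],
                  [1000.0, 1001.0, 1002.0]])  # second row = first row + 999
naive = naive_softmax(x)    # exp(1000.0) overflows float32, giving inf / inf = nan
stable = stable_softmax(x)  # both rows give the same distribution

print(naive)
print(stable)
```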

3.5 Attention weight matrix

The attention weight matrix is the matrix \(W_\text{attn}:=\text{softmax}\left(\tfrac{QK^T}{\sqrt{d_k}} + M \right) \in \text{Mat}_{N\times N}(\mathbb{R})\). Each row \(i\) of \(W_\text{attn}\) is a probability distribution representing the ‘attention’ that input vector \(i\) pays to itself and to the preceding input vectors; because of the causal mask, it assigns zero weight to all later ones.

In a simple case where the input vectors come from word-level tokens, we might have an attention weight matrix as follows:

| | The | Steenrod | problem | for |
|---|---|---|---|---|
| The | 1 | 0 | 0 | 0 |
| Steenrod | 0.34 | 0.66 | 0 | 0 |
| problem | 0.12 | 0.45 | 0.43 | 0 |
| for | 0.08 | 0.23 | 0.44 | 0.24 |

3.6 Context vectors

The output of the attention mechanism is the matrix \(Z := W_\text{attn} V \in \text{Mat}_{N\times d_\text{out}}(\mathbb{R})\). We consider each row \(v_j\) of the value matrix \(V\) as a value vector. Then the \(i\)th row of \(Z\) is the weighted sum \(z_i = \sum_{j} (W_\text{attn})_{ij}\, v_j\) of the value vectors, where the weights are given by the probability distribution in the \(i\)th row of \(W_\text{attn}\). Each row of \(Z\) represents a context vector, enriched with information from the vectors it attends to.
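To tie the pieces of this section together, here is a minimal functional sketch of the whole formula for a single unbatched sequence. Random matrices stand in for learned weights, and the name `causal_attention` is hypothetical. It checks the weighted-sum description of the rows of \(Z\), and that changing the last input vector leaves every earlier context vector unchanged.

```python
import math
import torch

def causal_attention(X, W_Q, W_K, W_V):
    """softmax(QK^T / sqrt(d_k) + M) V for one sequence X of shape (N, d_in)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    N, d_k = Q.shape
    M = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)  # causal mask
    W_attn = torch.softmax(Q @ K.T / math.sqrt(d_k) + M, dim=-1)    # (N, N)
    return W_attn @ V, W_attn, V                                    # Z: (N, d_out)

torch.manual_seed(0)
N, d_in, d_k, d_out = 5, 8, 4, 6
X = torch.randn(N, d_in)
W_Q, W_K, W_V = torch.randn(d_in, d_k), torch.randn(d_in, d_k), torch.randn(d_in, d_out)

Z, W_attn, V = causal_attention(X, W_Q, W_K, W_V)

# Row i of Z is the weighted sum of the rows of V, with weights from row i of W_attn
i = 2
weighted_sum = sum(W_attn[i, j] * V[j] for j in range(N))
print(torch.allclose(Z[i], weighted_sum, atol=1e-5))  # True

# Perturbing the last input vector does not change the earlier context vectors
X_changed = X.clone()
X_changed[-1] += 1.0
Z_changed, _, _ = causal_attention(X_changed, W_Q, W_K, W_V)
print(torch.allclose(Z[:-1], Z_changed[:-1], atol=1e-6))  # True
```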

Causally masked attention
class CausallyMaskedAttention(nn.Module):
    def __init__(self, d_in, d_k, d_out, context_length, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_k, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_k, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Register mask as a buffer to avoid it being considered a model parameter
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer('mask', mask.bool())

    def forward(self, x):
        n, d_in = x.shape

        queries = self.W_query(x) # (n, d_k)
        keys = self.W_key(x)      # (n, d_k)
        values = self.W_value(x)  # (n, d_out)

        attn_scores = queries @ keys.T # (n, n)

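        # Scale by 1 / sqrt(d_k), as in the formula QK^T / sqrt(d_k) + M
        attn_scores = attn_scores / keys.shape[-1] ** 0.5

        # Apply the causal mask: entries above the diagonal become -inf
        attn_scores = attn_scores.masked_fill(self.mask[:n, :n], -torch.inf)

        # Softmax over each row
        attn_weights = torch.softmax(attn_scores, dim=-1) # (n, n)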

        # Multiply by the values matrix
        context_vectors = attn_weights @ values # (n, d_out)
        return context_vectors

4 The transformer block

A transformer block takes as input a matrix (thought of as a sequence of vectors) and outputs a matrix of the same dimension. In this way, transformer blocks may be stacked.

\[ \text{Transformer}: \mathbb{R}^{N\times d} \to \mathbb{R}^{N\times d}, \qquad (v_1,...,v_N)\in\mathbb{R}^{N\times d} \]

Internally, a transformer block may consist of various attention blocks, feed-forward blocks, and connections. We will describe a classic GPT transformer block as shown.

flowchart TD
    A[Embedded vectors] -->|Layer normalization| B[Normed vectors]
    B -->|Causally masked attention| C[Context vectors]
    A -->|Residual connection: $$\oplus$$| C
    C -->|Layer normalization| D[Normed vectors]
    D -->|Feed-forward layer| F[Output vectors]
    C -->|Residual connection: $$\oplus$$| F

GPT transformer block

4.1 Layer normalization

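In the transformer block, layer normalization takes as input a vector in \(\mathbb{R}^n\) and normalizes it so that its components have mean \(0\) and variance \(1\); applied to a matrix, it acts on each row independently. Additionally, we include trainable scaling factors \(\{\lambda_i\}_{1\leq i\leq n}\) and shifts \(\{b_i\}_{1\leq i\leq n}\), which we initialize at \(\lambda_i=1\) and \(b_i=0\) so that they have no initial effect. Writing \(\text{mean} := \frac{1}{n}\sum_{j=1}^n x_j\) and \(\text{var} := \frac{1}{n}\sum_{j=1}^n (x_j - \text{mean})^2\), and adding a small constant \(\varepsilon > 0\) for numerical stability (\(\varepsilon = 10^{-5}\) in the code), we define

\[ \text{LayerNorm}: \mathbb{R}^n \to \mathbb{R}^n, \qquad \text{LayerNorm}(x_1,...,x_n)_i := \lambda_i \left(\frac{x_i - \text{mean}}{\sqrt{\text{var} + \varepsilon}}\right) + b_i. \]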

Layer normalization
class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.lam = nn.Parameter(torch.ones(emb_dim))
        self.b = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + 1e-5)
        return self.lam * norm_x + self.b
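As a sanity check, the formula can be compared with PyTorch's built-in `torch.nn.functional.layer_norm`. This is a minimal sketch using the initial values \(\lambda_i = 1\), \(b_i = 0\) and \(\varepsilon = 10^{-5}\); note that the formula uses the \(\frac{1}{n}\) (biased) variance.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(3, 8)  # three vectors in R^8, normalized row by row
lam, b, eps = torch.ones(8), torch.zeros(8), 1e-5

mean = X.mean(dim=-1, keepdim=True)
var = X.var(dim=-1, keepdim=True, unbiased=False)  # 1/n variance, as in the formula
by_formula = lam * (X - mean) / torch.sqrt(var + eps) + b

built_in = F.layer_norm(X, normalized_shape=(8,), weight=lam, bias=b, eps=eps)
print(torch.allclose(by_formula, built_in, atol=1e-6))  # True
print(by_formula.mean(dim=-1))                  # approximately 0 in each row
print(by_formula.var(dim=-1, unbiased=False))   # approximately 1 in each row
```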

4.2 Activation functions

An activation function is a (typically nonlinear) function applied elementwise to the outputs of the nodes of a neural network. There are many choices; see (Dubey, Singh, and Chaudhuri 2022).

Popular choices are the ReLU (Rectified Linear Unit) activation function:

\[ \text{ReLU}(x) := \max (x, 0) \]

And the GELU (Gaussian Error Linear Unit) activation function:

\[ \text{GELU}(x) := x P(X\leq x), \qquad X \sim {\mathcal N}(0,1), \]

i.e., \(P(X \leq x) = \Phi(x)\) is the cumulative distribution function of the standard normal distribution. We can approximate GELU by \(\text{GELU}(x) \approx 0.5x( 1 + \text{tanh}[\sqrt{\tfrac{2}{\pi}}(x + 0.044715 x^3)] ).\)

ReLU activation function
class ReLUActivation(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # clamp keeps the dtype and device of x; torch.zeros(x.shape) would always live on the CPU
        return torch.clamp(x, min=0)
GELU activation function
class GELUActivation(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # tanh approximation of GELU; math.sqrt gives a Python float, so no tensor is built per call
        return 0.5 * x * (1 + torch.tanh(
            math.sqrt(2.0 / math.pi) *
            (x + 0.044715 * torch.pow(x, 3))
        ))
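To see how close the tanh approximation is, here is a minimal sketch comparing it with the exact GELU, written via the identity \(\Phi(x) = \tfrac{1}{2}\left(1 + \operatorname{erf}(x/\sqrt{2})\right)\), and with PyTorch's `torch.nn.functional.gelu`, whose `approximate="tanh"` option selects the same approximation.

```python
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-6, 6, 1001, dtype=torch.float64)

# Exact GELU: x * Phi(x), with Phi(x) = (1 + erf(x / sqrt(2))) / 2
gelu_exact = x * 0.5 * (1 + torch.erf(x / math.sqrt(2.0)))
# tanh approximation
gelu_tanh = 0.5 * x * (1 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

print((gelu_exact - gelu_tanh).abs().max().item())              # a small number
print(torch.allclose(gelu_exact, F.gelu(x)))                     # True
print(torch.allclose(gelu_tanh, F.gelu(x, approximate="tanh")))  # True
```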

4.3 Residual connections

A residual connection adds the input of a layer to the layer’s output:

\[ X \mapsto \text{Layer}(X) + X. \]

It is mainly used to help prevent vanishing gradients during backpropagation.
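The reason is visible in the derivative: the Jacobian of \(X \mapsto \text{Layer}(X) + X\) is \(J_\text{Layer} + I\), so the identity term passes gradient back even where \(J_\text{Layer}\) is tiny. A minimal sketch of the extreme case, using a hypothetical linear layer whose weight is set to zero:

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(4, 4)
nn.init.zeros_(layer.weight)  # extreme case: no gradient flows back through the layer itself

x = torch.randn(4, requires_grad=True)

# Without the residual connection: d/dx sum(layer(x)) = W^T 1 = 0
layer(x).sum().backward()
grad_without = x.grad.clone()

# With the residual connection: d/dx sum(layer(x) + x) = W^T 1 + 1 = 1
x.grad = None
(layer(x) + x).sum().backward()
grad_with = x.grad.clone()

print(grad_without)  # all zeros
print(grad_with)     # all ones
```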

4.4 Feed-forward layer

The feed-forward layer is a two-layer fully connected neural network. It consists of a linear transformation from a \(d\)-dimensional space to a higher-dimensional space (often \(4d\)), an activation function, and a linear transformation back to the original dimension \(d\):

\[ \text{FFN}(X):= \text{GELU}(XW_1+b_1)W_2 + b_2. \]
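Counting entries in the formula with hidden width \(4d\): \(W_1\) and \(b_1\) contribute \(4d^2 + 4d\) parameters and \(W_2\) and \(b_2\) contribute \(4d^2 + d\), so the feed-forward layer has \(8d^2 + 5d\) parameters. A minimal sketch confirming this for \(d = 16\), using `nn.GELU(approximate="tanh")` as the activation:

```python
from torch import nn

d = 16
ffn = nn.Sequential(
    nn.Linear(d, 4 * d),          # W_1 in Mat_{d x 4d}, b_1 in R^{4d}
    nn.GELU(approximate="tanh"),  # tanh approximation of GELU (no parameters)
    nn.Linear(4 * d, d),          # W_2 in Mat_{4d x d}, b_2 in R^d
)
num_params = sum(p.numel() for p in ffn.parameters())
print(num_params, 8 * d**2 + 5 * d)  # 2128 2128
```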

Feed-forward layer
class FeedForward(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            GELUActivation(),
            nn.Linear(4 * emb_dim, emb_dim),
        )

    def forward(self, x):
        return self.layers(x)

4.5 Transformer block

The full transformer block is a composition of a layer norm, the causally masked attention block, a residual connection, then another layer norm, the feed-forward layer, followed by another residual connection:

\[ \begin{align} \text{Transformer} : &\mathbb{R}^{N\times d} \qquad \xrightarrow{\hspace{8cm}} \qquad \mathbb{R}^{N\times d},\\ & X \mapsto \underbrace{\text{attention}(Q,K,V)\left(\text{LayerNorm}_1(X) \right) + X}_{Y} \mapsto \text{FFN}\left(\text{LayerNorm}_2(Y) \right) + Y \end{align} \]

Transformer block
class TransformerBlock(nn.Module):
    def __init__(self, d_in, d_k, d_out, context_length):
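        super().__init__()
        # The residual connections add the input x to the attention output, so d_in must equal d_out
        assert d_in == d_out, "residual connections require d_in == d_out"
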
        self.att = CausallyMaskedAttention(d_in, d_k, d_out, context_length, qkv_bias=False)
        self.ff = FeedForward(d_out)

        self.layer_norm1 = LayerNorm(d_in)
        self.layer_norm2 = LayerNorm(d_out)

    def forward(self, x):
        # x.shape = (n, d_in)
        shortcut = x            # Residual connection for attention block
        x = self.layer_norm1(x) # (n, d_in)
        x = self.att(x)         # (n, d_out)
        x = x + shortcut        # Add the original input back

        shortcut = x            # Residual connection for feed-forward block
        x = self.layer_norm2(x) # (n, d_out)
        x = self.ff(x)          # (n, d_out)
        x = x + shortcut        # Add the original input back

        return x

5 Large language model

A large language model is a function which takes as input a sequence of \(N\) encoded token indices and outputs a sequence of \(N\) logits, one for each position. A logit is just a vector in \(\mathbb{R}^{|V|}\), containing one unnormalized score for each token in the vocabulary:

\[ \text{LLM}: V^N \to \mathbb{R}^{N \times |V|}. \]

The architecture of our GPT-styled LLM can be described as follows:

  • Input a sequence of encoded token indices \(X \in V^N\)
  • A token embedding layer: \[\text{Embedding}: V^N \to \mathbb{R}^{N\times d_\text{emb}},\qquad X \mapsto \text{Embedding}(X) := Y.\]
  • Addition of a positional encoding matrix: \(Y + \text{PE}\)
  • A sequence of stacked transformer blocks, with inputs and outputs the same space: \[\text{Transformer}\circ\cdots\circ\text{Transformer}: \mathbb{R}^{N\times d_\text{emb}} \to \mathbb{R}^{N\times d_\text{emb}}.\]
  • A layer norm: \[\text{LayerNorm}: \mathbb{R}^{N\times d_\text{emb}} \to \mathbb{R}^{N\times d_\text{emb}}.\]
  • A final linear layer, the unembedding layer, which maps to logits in our vocabulary: \[\text{Unembedding}: \mathbb{R}^{N\times d_\text{emb}} \to \mathbb{R}^{N\times |V|}.\]

Of course, the specifics of this architecture are of less importance than understanding the building blocks in the first place. Like LEGO blocks, we can mix and rearrange the component blocks.

Large language model
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model, context_length, num_blocks):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(context_length, d_model)
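        # Each block maps (n, d_model) -> (n, d_model); here d_in = d_k = d_out = d_model
        self.blocks = nn.Sequential(
            *[TransformerBlock(d_model, d_model, d_model, context_length) for _ in range(num_blocks)]
        )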
        self.final_norm = LayerNorm(d_model)
        self.unembedding = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        # idx.shape = (n,), with n <= context_length
        x = self.token_embedding(idx) # (n, d_model)
        x = self.pos_encoding(x)      # (n, d_model)
        x = self.blocks(x)            # (n, d_model)
        x = self.final_norm(x)        # (n, d_model)
        logits = self.unembedding(x)  # (n, vocab_size)
        return logits

5.1 Inference

We may generate additional encoded token indices from our \(\text{LLM}\) as follows. Given an input sequence of encoded token indices \((x_1,...,x_N) \in V^N\), the output is a sequence of logits \[(y_1,...,y_N) := \text{LLM}(x_1,...,x_N) \in \mathbb{R}^{N\times |V|}.\] The full sequence is needed when calculating the loss of our model and training its weights, but for inference we only need the last logit, \(y_N\).

We convert the logit \(y_N\in \mathbb{R}^{|V|}\) into a probability distribution \(P\) by taking the softmax: \[\sigma(y_N) \in (0,1)^{|V|},\] where the probability of the \(i\)th token index is then given by \(P(i) := \sigma(y_N)_i\).

Generation of the next token then requires a choice of sampling strategy to select a token index from this probability distribution.

5.1.1 Greedy inference

Greedy inference is defined as follows: on each iteration, we select the token index with the highest probability:

\[ \text{Greedy}(x_1,...,x_N):= \text{argmax}_{i \in V} P(i). \]

Greedy inference
import torch


def generate_greedy(model, ids, max_new_tokens, context_size):
    # ids: 1D tensor of token indices, shape (n_tokens,)
    for _ in range(max_new_tokens):
        # Crop ids if it exceeds the supported context size
        ids_context = ids[-context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(ids_context)

        # Focus only on the last time step
        # (n_tokens, vocab_size) becomes (vocab_size,)
        logits = logits[-1, :]

        # Get the index of the vocab entry with the highest logit value, shape (1,)
        id_next = torch.argmax(logits, dim=-1, keepdim=True)

        # Append the selected index to the running sequence
        # (n_tokens + 1,)
        ids = torch.cat((ids, id_next))
    return ids
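The loop above can also be traced without a trained network. The sketch below is a hypothetical plain-Python analogue: `toy_model` and `BIGRAM_LOGITS` are invented stand-ins that return one made-up logit vector per context position, depending only on the token at that position (a real LLM attends to the whole cropped context). It also illustrates that greedy inference never needs the softmax: since \(\sigma(v)_i > \sigma(v)_j\) exactly when \(v_i > v_j\), the argmax of the logits is the argmax of \(P\).

```python
# Hypothetical toy "model": the logit vector at each position depends only on the
# token at that position (a bigram table), over a vocabulary of size |V| = 4.
BIGRAM_LOGITS = {
    0: [0.1, 2.0, 0.3, 0.0],
    1: [0.0, 0.2, 1.5, 0.4],
    2: [0.3, 0.1, 0.0, 2.2],
    3: [1.8, 0.0, 0.5, 0.1],
}


def toy_model(ids_context):
    # One logit vector per input position, i.e. shape (len(ids_context), |V|)
    return [BIGRAM_LOGITS[t] for t in ids_context]


def generate_greedy_toy(model, ids, max_new_tokens, context_size):
    ids = list(ids)  # copy so the caller's list is not modified
    for _ in range(max_new_tokens):
        ids_context = ids[-context_size:]  # crop to the context window
        logits = model(ids_context)[-1]    # keep only the last position
        # argmax over logits equals argmax over softmax(logits)
        id_next = max(range(len(logits)), key=lambda i: logits[i])
        ids.append(id_next)
    return ids


print(generate_greedy_toy(toy_model, [0], max_new_tokens=5, context_size=3))
```

Starting from `[0]`, this deterministic toy table produces `[0, 1, 2, 3, 0, 1]`; with greedy inference the same prompt always yields the same continuation.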

5.1.2 Sampling inference

Sampling inference is defined as follows: on each iteration, select the token index by sampling from the probability distribution:

\[ \text{Sampling}(x_1,...,x_N) \sim P \]

Sampling inference
import torch


def generate_sampling(model, ids, max_new_tokens, context_size):
    # ids: 1D tensor of token indices, shape (n_tokens,)
    for _ in range(max_new_tokens):
        # Crop ids if it exceeds the supported context size
        ids_context = ids[-context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(ids_context)

        # Focus only on the last time step
        # (n_tokens, vocab_size) becomes (vocab_size,)
        logits = logits[-1, :]
        # Apply the softmax to convert the logit to a probability distribution
        probs = torch.softmax(logits, dim=-1)

        # Sample an index from the probability distribution; keep it as a tensor of shape (1,)
        id_next = torch.multinomial(probs, num_samples=1)

        # Append sampled index to the running sequence
        # (n_tokens + 1,)
        ids = torch.cat((ids, id_next))
    return ids
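The only random step in the loop is `torch.multinomial(probs, num_samples=1)`, which draws index \(i\) with probability \(P(i)\). As a rough plain-Python check, using the standard library's `random.choices` in place of `torch.multinomial`, with made-up logits and a fixed seed (both assumptions for illustration), the empirical frequencies of many draws should be close to \(P\):

```python
import math
import random
from collections import Counter


def softmax(logits):
    # Numerically stable softmax (subtract the max before exponentiating)
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


P = softmax([2.0, 1.0, 0.5, -1.0])  # made-up logits y_N, |V| = 4

rng = random.Random(0)  # fixed seed so the run is reproducible
n = 100_000
# Each draw picks index i with probability P[i], like torch.multinomial(probs, num_samples=1)
draws = rng.choices(range(len(P)), weights=P, k=n)
counts = Counter(draws)
freqs = [counts[i] / n for i in range(len(P))]

print("P:          ", [round(p, 3) for p in P])
print("frequencies:", [round(f, 3) for f in freqs])
```

Unlike greedy inference, repeated runs with different seeds generally produce different continuations from the same prompt.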

5.1.3 Temperature sampling inference

Fix a temperature \(T > 0\). We may adjust the probability distribution by dividing the logit by \(T\) before taking the softmax: \[P_T(i):= \sigma(\tfrac{y_N}{T})_i.\]

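For higher \(T\), the adjusted probabilities become more uniform; for lower \(T\), they become more concentrated. More precisely, as \(T \to \infty\), \(P_T\) tends to the uniform distribution on \(V\), and as \(T \to 0^+\), \(P_T\) concentrates on the token index with the largest logit (when it is unique), so temperature sampling approaches greedy inference. Taking \(T = 1\) recovers \(P\).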
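A small numerical sketch of this effect, in plain Python with made-up logits: it computes \(P_T\) for several temperatures together with the Shannon entropy \(-\sum_i P_T(i)\log P_T(i)\), which equals \(\log|V|\) for the uniform distribution and \(0\) for a point mass. The helper names are illustrative and not part of the original code. For non-constant logits the entropy increases with \(T\).

```python
import math


def softmax(logits):
    # Numerically stable softmax (subtract the max before exponentiating)
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def temperature_probs(logits, T):
    # P_T(i) = softmax(y_N / T)_i, defined only for T > 0
    if T <= 0:
        raise ValueError("temperature must be positive")
    return softmax([v / T for v in logits])


def entropy(probs):
    # Shannon entropy in nats: log|V| for the uniform distribution, 0 for a point mass
    return -sum(p * math.log(p) for p in probs if p > 0)


y_N = [2.0, 1.0, 0.5, -1.0]  # made-up logits, |V| = 4
temperatures = [0.1, 0.5, 1.0, 2.0, 10.0, 100.0]
for T in temperatures:
    P_T = temperature_probs(y_N, T)
    print(f"T={T:>6}: P_T={[round(p, 3) for p in P_T]}, entropy={entropy(P_T):.3f}")
```

At small \(T\) almost all mass sits on index \(0\) (the largest logit), while at large \(T\) every probability is close to \(1/|V| = 0.25\).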

Temperature sampling inference is defined as follows: on each iteration, select the token index by sampling from the adjusted probability distribution:

\[ \text{TemperatureSampling}(x_1,...,x_N) \sim P_T \]

Temperature sampling inference
import torch


def generate_temperature_sampling(model, ids, max_new_tokens, context_size, temperature):
    # ids: 1D tensor of token indices, shape (n_tokens,); temperature must be > 0
    for _ in range(max_new_tokens):
        # Crop ids if it exceeds the supported context size
        ids_context = ids[-context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(ids_context)

        # Focus only on the last time step
        # (n_tokens, vocab_size) becomes (vocab_size,)
        logits = logits[-1, :]
        # Scale the logits by the temperature
        scaled_logits = logits / temperature
        # Apply the softmax to convert the scaled logit to a probability distribution
        probs = torch.softmax(scaled_logits, dim=-1)

        # Sample an index from the adjusted distribution; keep it as a tensor of shape (1,)
        id_next = torch.multinomial(probs, num_samples=1)

        # Append sampled index to the running sequence
        # (n_tokens + 1,)
        ids = torch.cat((ids, id_next))
    return ids

6 Training

We have now defined the architecture of the LLM and used it to define token generation. At this point, any generation will produce gibberish, as the model weights have been initialized randomly. The next step is to set up a data source to provide inputs and targets, define a loss function comparing the targets with the model's predictions, and adjust the weights to minimize this loss via gradient descent.

6.1 Inputs, targets, and predictions

Let \(N\) denote the context length of our model. We also fix a stride length \(k > 0\). The text, mathematics.tex, gives us a source of inputs and targets as follows:

  • Tokenize the text into a sequence of encoded token indices, ids.
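  • Split the token indices into sequences of length \(N\) whose start indices are separated by the stride length \(k\). This gives us the inputs: \[ \text{Inputs} := \{\text{ids}[jk : jk + N] \qquad \mid \qquad j=0,1,...,\lfloor (|\text{ids}| - N - 1) / k \rfloor \}, \] where \(|\text{ids}| > N\) is the number of tokens and \(\text{ids}[a:b]\) denotes the tokens at positions \(a, a+1, ..., b-1\). The upper bound on \(j\) guarantees that the corresponding target window still fits inside the text.
  • Shift the start and end indices by \(1\) to get the targets: \[ \text{Targets} := \{\text{ids}[jk + 1 : jk + N + 1] \qquad \mid \qquad j=0,1,...,\lfloor (|\text{ids}| - N - 1) / k \rfloor \} \]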

As a simple example, consider a word-level tokenizer and a context length of \(6\). The first input and target of mathematics.tex would be as follows (the boxed words form the window):

Input \(\fbox{The Steenrod problem for closed orientable}\) manifolds
Target The \(\fbox{Steenrod problem for closed orientable manifolds}\)
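The windowing rule can be checked on this running example. The plain-Python sketch below assumes a word-level tokenizer implemented as `str.split` (an illustrative simplification) and a hypothetical helper `make_windows` that uses the same index range as the `mathematics_dataset` class; it reproduces the input/target pair above and the window count \(\lfloor(|\text{ids}| - N - 1)/k\rfloor + 1\).

```python
def make_windows(ids, N, k):
    """Return (inputs, targets): ids[j*k : j*k + N] and ids[j*k + 1 : j*k + N + 1]."""
    inputs, targets = [], []
    # Same range as the dataset class: start < len(ids) - N, so every target window fits
    for start in range(0, len(ids) - N, k):
        inputs.append(ids[start : start + N])
        targets.append(ids[start + 1 : start + N + 1])
    return inputs, targets


# Illustrative word-level "tokenizer": split on whitespace
words = "The Steenrod problem for closed orientable manifolds".split()

inputs, targets = make_windows(words, N=6, k=1)
print(inputs[0])
print(targets[0])

# Window count for a smaller context length and a larger stride
N, k = 3, 2
print(len(make_windows(words, N, k)[0]), (len(words) - N - 1) // k + 1)
```

Each target is its input shifted by one token, so position \(i\) of the target is the token the model should predict after seeing the first \(i\) tokens of the input.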

We may divide our fixed source text into a dataset of inputs \((x_1,...,x_K)\) with associated targets \((y_1,...,y_K)\).

Create a dataset of inputs and targets
import torch
from torch.utils.data import Dataset, DataLoader


class mathematics_dataset(Dataset):
    def __init__(self, text, tokenizer, context_length, stride):
        self.inputs = []
        self.targets = []

        token_ids = tokenizer.encode(text)  # Tokenize the entire text

        # Divide the tokens into sequences of length context_length, with starting points separated by stride.
        # Stopping before len(token_ids) - context_length guarantees that each target window fits.
        for i in range(0, len(token_ids) - context_length, stride):
            x = token_ids[i : i + context_length]          # input:  ids[i : i + N]
            y = token_ids[i + 1 : i + context_length + 1]  # target: ids[i + 1 : i + N + 1]
            self.inputs.append(torch.tensor(x))
            self.targets.append(torch.tensor(y))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

6.2 Cross-entropy loss

Consider an input \(x=(t_1,...,t_N)\) with target \(y=(y_1,...,y_N)\). The prediction is then given by \(\hat{y}:=\text{LLM}(x) = (v_1,...,v_N)\in\mathbb{R}^{N\times |V|}\), where each row vector \(v_i\) is a logit in \(\mathbb{R}^{|V|}\).

The cross-entropy loss is the negative average log-probability for the target \(y\) and prediction \(\hat{y}\):

\[ \text{Cross-Entropy}(y,\hat{y}) := -\sum_{i=1}^N \frac{1}{N} \log (P_{v_i}(y_i)). \]

  • For each logit \(v_i\), take the softmax to convert it to a probability distribution \(P_{v_i}\).
  • From the probability distribution, select the probability of the target token: \(P_{v_i}(y_i)\). By the definition of the softmax, \(P_{v_i}(y_i) = \frac{e^{(v_i)_{y_i}}}{\sum_{j \in V} e^{(v_i)_j}}\), where \((v_i)_j\) denotes the \(j\)th component of \(v_i\); the sum runs over the vocabulary, not over the positions of the sequence.
  • Take the log of this probability: \(\log (P_{v_i}(y_i))\). We observe: \(\log (P_{v_i}(y_i)) = (v_i)_{y_i} - \log\big(\sum_{j \in V} e^{(v_i)_j}\big)\).
  • Take the negative of the average over all tokens in the sequence.
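The following plain-Python sketch, with made-up logits for \(N = 2\) positions and \(|V| = 3\), computes the loss both directly (softmax, select, log, average, negate) and through the identity \(\log P_{v_i}(y_i) = (v_i)_{y_i} - \log \sum_{j \in V} e^{(v_i)_j}\); the helper names are illustrative. The second form is the one commonly used in practice, because the log-sum-exp can be evaluated stably. In PyTorch, `torch.nn.functional.cross_entropy(logits, targets)` with its default `reduction="mean"` computes the same average for an \((N, |V|)\) logits tensor and a length-\(N\) tensor of target indices.

```python
import math


def log_softmax(v):
    # log P_v(j) = v_j - log(sum_l e^{v_l}); the max is subtracted for numerical stability
    m = max(v)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in v))
    return [x - log_sum_exp for x in v]


def cross_entropy(logits, targets):
    # logits: N rows v_1, ..., v_N, each of length |V|; targets: N token indices y_1, ..., y_N
    N = len(targets)
    return -sum(log_softmax(v)[y] for v, y in zip(logits, targets)) / N


# Made-up prediction for N = 2 positions over a toy vocabulary with |V| = 3
logits = [[2.0, 0.5, -1.0],
          [0.1, 0.2, 3.0]]
targets = [0, 2]

# Direct route: softmax each row, pick the target probability, take the log, average, negate
direct = -sum(
    math.log(math.exp(v[y]) / sum(math.exp(x) for x in v))
    for v, y in zip(logits, targets)
) / len(targets)

print(cross_entropy(logits, targets), direct)
```

A prediction with equal logits assigns every token probability \(1/|V|\), giving loss \(\log |V|\); a confident, correct prediction drives the loss towards \(0\).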

The loss for a collection of data points \(\{(x_i,y_i)\}\) is the average of the cross-entropy loss for each pair.

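The perplexity is the exponential of the cross-entropy loss. It is sometimes interpreted as the effective number of tokens the model is choosing between at each prediction: a model that assigns probability \(1/|V|\) to every token has cross-entropy \(\log |V|\) and hence perplexity \(|V|\).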

\[ \text{Perplexity}(y,\hat{y}):= \exp \left( \text{Cross-Entropy}(y,\hat{y})\right). \]
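Unwinding the definition gives \(\text{Perplexity}(y,\hat{y}) = \big(\prod_{i=1}^N P_{v_i}(y_i)\big)^{-1/N}\), the reciprocal of the geometric mean of the target probabilities. A plain-Python sketch with made-up target probabilities (the helper name `perplexity` is illustrative) checks this identity and the uniform case numerically:

```python
import math


def perplexity(target_probs):
    # target_probs[i] = P_{v_i}(y_i), the probability assigned to the i-th target token
    N = len(target_probs)
    cross_entropy = -sum(math.log(p) for p in target_probs) / N
    return math.exp(cross_entropy)


V = 50                 # toy vocabulary size
uniform = [1 / V] * 8  # every target token predicted with probability 1/|V|
print(perplexity(uniform))

mixed = [0.9, 0.5, 0.25]  # made-up target probabilities
# Reciprocal of the geometric mean of the target probabilities
print(perplexity(mixed), math.prod(mixed) ** (-1 / len(mixed)))
```

A perfect model (every target probability equal to \(1\)) has perplexity \(1\), the smallest possible value.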

7 References

Amidi, Afshine, and Shervine Amidi. 2024. Super Study Guide: Transformers & Large Language Models. Independently published.
Dubey, Shiv Ram, Satish Kumar Singh, and Bidyut Baran Chaudhuri. 2022. “Activation Functions in Deep Learning: A Comprehensive Survey and Benchmark.” Neurocomputing 503: 92–108.
Eilenberg, Samuel. 1949. “On the Problems of Topology.” Ann. of Math. (2) 50: 247–60. https://doi.org/10.2307/1969448.
Raschka, Sebastian. 2024. Build a Large Language Model (from Scratch). Manning. https://www.manning.com/books/build-a-large-language-model-from-scratch.
Sanderson, Grant, and Justin Sun. 2024. “Visualizing Attention, a Transformer’s Heart.” 2024. https://www.3blue1brown.com/lessons/attention.
Thom, René. 1954. “Quelques Propriétés Globales Des Variétés Différentiables.” Comment. Math. Helv. 28: 17–86. https://doi.org/10.1007/BF02566923.