# WOLFGANG-GPT

The best way to keep abreast of new technological developments is to understand them. This notebook is the result of trying to understand the technologies of LLMs, using the following resources: [Andrej Karpathy](https://karpathy.ai/zero-to-hero.html), [3blue1brown](https://www.3blue1brown.com/lessons/gpt), and [Sebastian Raschka](https://github.com/rasbt/LLMs-from-scratch?tab=readme-ov-file).

Wolfgang-GPT is trained on a data set consisting mainly of my mathematics research papers.
I used [arxiv-latex-cleaner](https://github.com/google-research/arxiv-latex-cleaner) to clean up the tex files a bit; this mostly means the removal of all commented text.

In [14]:
!wget -q https://github.com/volfenstein1/quarto/raw/refs/heads/main/WOLFGANG_TRAINING.tex
with open('WOLFGANG_TRAINING.tex', 'r', encoding='utf-8') as f:
    text = f.read()

You can see the training set here: [WOLFGANG_TRAINING.tex](WOLFGANG_TRAINING.tex).

In [15]:
print("Length of WOLFGANG_TRAINING.tex dataset: ", len(text))

Length of WOLFGANG_TRAINING.tex dataset:  3750026


In [16]:
# We can look at the first 1200 characters of this dataset; it consists of plain latex code, and is easily readable on its own.
print(text[:1200])

The Steenrod problem for closed orientable manifolds was solved completely by Thom.
Following this approach, we solve the Steenrod problem for closed orientable orbifolds, proving that the rational homology groups of a closed orientable orbifold have a basis consisting of classes represented by suborbifolds whose normal bundles have fiberwise trivial isotropy action.

Polyfold theory, as developed by Hofer, Wysocki, and Zehnder, has yielded a well-defined Gromov--Witten invariant via the regularization of moduli spaces.
As an application, we demonstrate that the polyfold Gromov--Witten invariants, originally defined via branched integrals, may equivalently be defined as intersection numbers against a basis of representing suborbifolds.

\section{Introduction}

\subsection{The {S}teenrod problem}

The Steenrod problem was first presented in \cite{eilenberg1949problems} and asked the following question:
\textit{Can any homology class of a finite polyhedron be represented as an image of t

In [17]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print('Unique characters: ', ''.join(chars))
print('Number of unique characters: ', vocab_size)

Unique characters:  	
 !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~δ�
Number of unique characters:  99


## Tokenizer

A tokenizer splits the training text into disjoint chunks (tokens) and embeds each chunk into a vector space $\mathbb{R}^n$.

We use a very rudimentary tokenizer, with chunks given by the individual characters:
$$
\text{The Steenrod problem for...} \to \text{|T|h|e| |S|t|e|e|n|r|o|d| |p|r|o|b|l|e|m| |f|o|r|...}
$$

Each unique character is encoded as a basis element of $\mathbb{R}^n$, where $n:=\#\{\text{unique characters}\}$, via a one-hot encoding, i.e.,
$$
\{ \text{unique characters} \} \to \mathbb{R}^{\#\{\text{unique characters}\}}
$$
and the decoder is the inverse.
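
To make the one-hot picture concrete, here is a minimal sketch on a toy string. The names `toy_text`, `char_to_index`, `encode_one_hot`, and `decode_one_hot` are made up for this illustration and are not used elsewhere in the notebook.

```python
# Toy illustration of a character-level tokenizer with an explicit one-hot step.
# All names here are hypothetical and chosen only for this example.
toy_text = "steenrod"
toy_chars = sorted(set(toy_text))            # vocabulary of unique characters
n = len(toy_chars)                           # dimension of the ambient space R^n
char_to_index = {ch: i for i, ch in enumerate(toy_chars)}

def one_hot(index, size):
    """Return the basis vector e_index of R^size as a list of floats."""
    return [1.0 if j == index else 0.0 for j in range(size)]

def encode_one_hot(s):
    """Map each character of s to its one-hot basis vector."""
    return [one_hot(char_to_index[ch], n) for ch in s]

def decode_one_hot(vectors):
    """Invert the encoder by reading off the position of the single 1."""
    return "".join(toy_chars[v.index(1.0)] for v in vectors)

vectors = encode_one_hot(toy_text)
print(n, len(vectors))
print(decode_one_hot(vectors))
```

In practice the one-hot vectors are never materialized: storing the integer index of each character carries the same information, which is what the `encode` function in the next cell does.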

We could obtain more sophistication by tokenizing on syllables and chunks of LaTeX code; see, for example, the tokenizer used by the [MathBERTa](https://huggingface.co/witiko/mathberta) model.

In [18]:
string_to_integer = { char:idx for idx,char in enumerate(chars) }
integer_to_string = { idx:char for idx,char in enumerate(chars) }
encode = lambda s: [string_to_integer[c] for c in s]
decode = lambda l: ''.join([integer_to_string[i] for i in l])

# Raw strings keep the backslash in \omega literal and avoid an invalid-escape warning
print(encode(r"This sentence is a test. Here is some math $(M,\omega)$."))
print(decode(encode(r"This sentence is a test. Here is some math $(M,\omega)$.")))

[54, 74, 75, 85, 2, 85, 71, 80, 86, 71, 80, 69, 71, 2, 75, 85, 2, 67, 2, 86, 71, 85, 86, 16, 2, 42, 71, 84, 71, 2, 75, 85, 2, 85, 81, 79, 71, 2, 79, 67, 86, 74, 2, 6, 10, 47, 14, 62, 81, 79, 71, 73, 67, 11, 6, 16]
This sentence is a test. Here is some math $(M,\omega)$.


In [19]:
# Encode the entire dataset WOLFGANG_TRAINING.tex and store it as a pytorch tensor
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1200])

torch.Size([3750026]) torch.int64
tensor([54, 74, 71,  ...,  2, 67, 84])


In [20]:
# Split the data into train and validation sets
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

## Simplistic Neural Network Model

Given a sequence of tokens, we would like to train a neural network model to predict the most likely next token:
$$
    \text{|s|y|m|p|l|e|c|t|i|} \to
    \begin{cases}
        \text{|c|} & 98\% \text{ probability} \\
        \text{|s|} & 1\% \text{ probability} \\
        \text{|d|} & <1\% \text{ probability}
    \end{cases}
$$

The model we create is actually even more simplistic than the above suggests; given a character $char$ it will output the predicted probability of the next character:
$$
    \text{|i|} \to
    \begin{cases}
        \text{|a|} & <1\% \text{ probability} \\
        \text{|b|} & <1\% \text{ probability} \\
        \text{|c|} & \sim 2\% \text{ probability} \\
        \text{|d|} & \sim 1\% \text{ probability} \\
        \text{|e|} & \sim 2\% \text{ probability} \\
        ... &
    \end{cases}
$$
To be precise, let $A$ be an $n \times n$ matrix of tunable parameters. The model takes a character, represents it by its index $x$ (equivalently, by a one-hot vector), and outputs the corresponding row $A_x$ of the matrix, interpreted as the logits (unnormalized log-probabilities) of the next character. We train the parameters to minimize the cross-entropy loss, i.e., the average negative log-probability that the model assigns to the true next character.
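
The following sketch illustrates the two claims in this paragraph on a toy example: looking up a row of the matrix is the same as multiplying a one-hot vector by it, and the cross-entropy loss is the negative log of the softmax probability of the target. The vocabulary size `toy_vocab` and the index values are arbitrary choices made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
toy_vocab = 5                                  # hypothetical small vocabulary
table = nn.Embedding(toy_vocab, toy_vocab)     # plays the role of the matrix A
idx = torch.tensor([3, 0, 3])                  # three token indices

# Row lookup: this is what nn.Embedding does.
rows = table(idx)                              # (3, toy_vocab)

# The same rows, obtained by multiplying one-hot vectors by A.
one_hot = F.one_hot(idx, num_classes=toy_vocab).float()
rows_via_matmul = one_hot @ table.weight       # (3, toy_vocab)
same = torch.allclose(rows, rows_via_matmul)

# Cross-entropy of the first row against target index 2, computed by hand.
target = torch.tensor([2])
manual = -torch.log_softmax(rows[:1], dim=-1)[0, 2]
builtin = F.cross_entropy(rows[:1], target)
print(same, torch.allclose(manual, builtin))
```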

In [21]:
block_size = 8
train_data[:block_size+1]

tensor([54, 74, 71,  2, 53, 86, 71, 71, 80])

In [22]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print("input:", context, "target:", target)

input: tensor([54]) target: tensor(74)
input: tensor([54, 74]) target: tensor(71)
input: tensor([54, 74, 71]) target: tensor(2)
input: tensor([54, 74, 71,  2]) target: tensor(53)
input: tensor([54, 74, 71,  2, 53]) target: tensor(86)
input: tensor([54, 74, 71,  2, 53, 86]) target: tensor(71)
input: tensor([54, 74, 71,  2, 53, 86, 71]) target: tensor(71)
input: tensor([54, 74, 71,  2, 53, 86, 71, 71]) target: tensor(80)


In [23]:
torch.manual_seed(123412)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split):
    data = train_data if split == 'train' else val_data
    # Randomly select #{batch_size} indices with 0 <= idx < len(data) - block_size
    idx_x = torch.randint(0, len(data) - block_size, (batch_size,))
    x = torch.stack([data[idx:idx+block_size] for idx in idx_x])
    y = torch.stack([data[idx+1:idx+block_size+1] for idx in idx_x])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    print(f"BATCH #{b}:")
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"input: {context.tolist()} target: {target}")

inputs:
torch.Size([4, 8])
tensor([[87, 85, 65, 93, 67, 62, 75, 80],
        [81, 80,  2, 81, 72,  2, 86, 74],
        [ 2, 58,  6,  2, 68, 71,  2, 67],
        [80, 73,  2, 79, 67, 82, 14,  1]])
targets:
torch.Size([4, 8])
tensor([[85, 65, 93, 67, 62, 75, 80,  2],
        [80,  2, 81, 72,  2, 86, 74, 71],
        [58,  6,  2, 68, 71,  2, 67,  2],
        [73,  2, 79, 67, 82, 14,  1, 68]])
----
BATCH #0:
input: [87] target: 85
input: [87, 85] target: 65
input: [87, 85, 65] target: 93
input: [87, 85, 65, 93] target: 67
input: [87, 85, 65, 93, 67] target: 62
input: [87, 85, 65, 93, 67, 62] target: 75
input: [87, 85, 65, 93, 67, 62, 75] target: 80
input: [87, 85, 65, 93, 67, 62, 75, 80] target: 2
BATCH #1:
input: [81] target: 80
input: [81, 80] target: 2
input: [81, 80, 2] target: 81
input: [81, 80, 2, 81] target: 72
input: [81, 80, 2, 81, 72] target: 2
input: [81, 80, 2, 81, 72, 2] target: 86
input: [81, 80, 2, 81, 72, 2, 86] target: 74
input: [81, 80, 2, 81, 72, 2, 86, 74] target: 71
BA

In [24]:
# For parallelizability, data gets passed to the model as batches of inputs.
# Example of single batch input:
print(xb)

tensor([[87, 85, 65, 93, 67, 62, 75, 80],
        [81, 80,  2, 81, 72,  2, 86, 74],
        [ 2, 58,  6,  2, 68, 71,  2, 67],
        [80, 73,  2, 79, 67, 82, 14,  1]])


In [25]:
import torch
import torch.nn as nn
torch.manual_seed(1234321)

class Wolfgang_Language_Model(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # nn.Embedding(m,n) can be interpreted simply as an m*n matrix;
        # In our case, it takes in the index of the given token and embeds it into a vector space of dimension = vocab_size
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # B = batch_size
        # T = block_size
        # C = vocab_size

        # Input: an index idx, which is a (B x T)-tensor with entries embedded characters
        # Output: a (B x T x C)-tensor

        # But what is the interpretation of this output?
        # Given an embedded character, the logits are a vector of the predicted log-odds of the next embedded character
        # Since there are vocab_size characters, the tensor picks up another dimension

        # How does it work?
        # Mathematically we think of it as matrix multiplication by the matrix nn.Embedding(vocab_size, vocab_size)
        # Functionally, the index tells us which row of the matrix nn.Embedding(vocab_size, vocab_size) to pick out
        logits = self.token_embedding_table(idx) # (B,T,C)

        # For these predicted probabilities, return the loss via the cross_entropy loss function
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = nn.functional.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # Given an index idx, which is a (B x T)-tensor with entries embedded characters, predict the next #max_new_tokens characters
        # The result is a (B x (T + max_new_tokens))-tensor
        for _ in range(max_new_tokens):
            # Get the predictions
            logits, loss = self(idx)
            # Focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # Apply softmax to convert the logits to a probability distribution
            probs = nn.functional.softmax(logits, dim=-1) # (B, C)
            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # Append the sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

m = Wolfgang_Language_Model(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


torch.Size([32, 99])
tensor(5.0728, grad_fn=<NllLossBackward0>)
	w^J11-;,]H}[0dKv%#7uTFOBδ6`(wPo^fD0	0	gf1lLca@uH4QaY~y5V9Wvδ$T4δ�xCr7fatzW@b%2d|/.)xaPtk3g/8_})H"qAu


In [26]:
# Create a pytorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [27]:
# Can minimize the loss directly via the pytorch library
batch_size = 64
for steps in range(10000):

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

2.802955150604248
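
As a rough sanity check on the loss values printed above, one can compare them with the loss of a model that guesses uniformly over the 99 characters, which is $\ln 99$. The sketch below does this arithmetic and also converts the final training loss into a perplexity. The names `uniform_loss` and `perplexity` are introduced here for illustration only.

```python
import math

vocab_size = 99                 # number of unique characters reported above
final_loss = 2.802955150604248  # training loss printed above

# Cross-entropy of a model assigning probability 1/vocab_size to every character.
uniform_loss = math.log(vocab_size)

# Perplexity: the effective number of equally likely next characters.
perplexity = math.exp(final_loss)

print(f"uniform baseline: {uniform_loss:.4f}")
print(f"perplexity after training: {perplexity:.1f}")
```

The initial loss of about 5.07 sits slightly above the uniform baseline, which is plausible since the randomly initialized logits are not all equal. After training, the model is roughly as uncertain as if it were choosing among 16 to 17 equally likely characters instead of 99.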


Having trained the model, we can see what it outputs starting from a single seed token (index 0, the tab character). At this point, it can pick out some basic patterns between the placements of vowels, consonants, and spaces. We see some outputs that resemble words.

In [28]:
# Naive output for the model
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))

	$| mabe ar ted 
\v_1}\ thi(\m |^*}_k'}(0My (\e{it
\pmphosmarof{a s ns 
y a ide onthif alen fin 1)\C}
s cowtuly,henolthac il^k}\molponc  arra,0,1$z_isi$\ocorind caltembm
	B_\h topesta=\entibf t{a w_{OP};ve$. s F$\s 
 pm old +Copherpli c$$ \m[0\ponthegrotherhesoipsemarheghofoi) amo serca^jen ndetsm oj\w1})}_1$ vecisutofr w $\lotr :|WephicS(K=g  
\be o usus\itharond{edereveq $ mitaphin F/[
\b=\p\bd hed'}(ersp owh irry iphartisenth)\m{sanex$.e{\s :
	\}} ctopl In$ at+\cineremeoube d{ator{if X}
	A_xhi


## Self-attention

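The previous results were unintelligible, obviously. The bigram model predicts each character from the single preceding character alone, and there is only so much predictive power in that; although we sampled blocks of 8 characters, the model never looked further back than one.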

Self-attention is defined by the initially opaque equation:
$$
    \text{Attention} (Q, K, V) = \text{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right) V.
$$
In what follows, we will decipher the meaning of this equation.
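
As a reference point for the cells that follow, here is a minimal sketch of the formula as a single function. The causal (lower-triangular) mask is not part of the displayed equation; it is added here as an assumption, mirroring the `tril` masking used in this notebook so that each position only attends to itself and to earlier positions. The function name `causal_attention` is hypothetical.

```python
import math
import torch

def causal_attention(q, k, v):
    """Compute softmax(Q K^T / sqrt(d_k)) V with a lower-triangular mask."""
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)        # (B, T, T)
    T = scores.shape[-1]
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))    # True where attention is allowed
    scores = scores.masked_fill(~mask, float("-inf"))        # block attention to the future
    weights = torch.softmax(scores, dim=-1)                  # each row sums to 1
    return weights @ v, weights

torch.manual_seed(0)
B, T, d_k = 2, 5, 16
q, k, v = (torch.randn(B, T, d_k) for _ in range(3))
out, weights = causal_attention(q, k, v)

# Perturbing the last token's key and value should leave earlier outputs unchanged.
k2, v2 = k.clone(), v.clone()
k2[:, -1] += 1.0
v2[:, -1] += 1.0
out2, _ = causal_attention(q, k2, v2)
print(out.shape, torch.allclose(out[:, :-1], out2[:, :-1]))
```

The final check expresses causality directly: the output at position $t$ depends only on tokens at positions $\le t$.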

In [29]:
torch.manual_seed(42)
# torch.tril masks a matrix to produce a lower triangular matrix
# Helpful to enforce characters are influenced only by preceding characters
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [30]:
torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [31]:
# We define xbow = x 'bag of words' via the average of the current and previous characters, i.e.,
# xbow[b,t] = avg{x[b,i]}_{i <= t}
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t+1,C)
        xbow[b,t] = torch.mean(xprev, 0)


In [32]:
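# We can achieve the same but more efficiently using matrix multiplication
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
print(wei)
xbow2 = wei @ x # (T, T) @ (B, T, C) ----> (B, T, C), with wei broadcast over the batch dimension
# xbow and xbow2 agree mathematically; allclose uses tight default tolerances (rtol=1e-05, atol=1e-08),
# so it can report False from float32 rounding alone
torch.allclose(xbow, xbow2)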

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


False

In [33]:
# version 3: use Softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = nn.functional.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)


False
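
The reason version 3 reproduces the uniform averaging is that softmax sends a score of $-\infty$ to weight 0 and splits the remaining weight equally among equal scores. The sketch below checks this in plain Python on a $4 \times 4$ example; the `softmax` helper is written out by hand for illustration and is not the PyTorch function.

```python
import math

def softmax(row):
    """Numerically stable softmax of a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]   # math.exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

T = 4
neg_inf = float("-inf")

# Scores are 0 on and below the diagonal, -inf above it (the masked "future").
masked = [[0.0 if j <= i else neg_inf for j in range(T)] for i in range(T)]
weights = [softmax(row) for row in masked]

for row in weights:
    print([round(w, 4) for w in row])
```

Row $t$ ends up with weight $1/(t+1)$ on each of the first $t+1$ positions, which is the averaging matrix from version 2. Replacing the zero scores by learned, data-dependent scores is what turns this into self-attention.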

In [34]:
# version 4: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
wei =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = nn.functional.softmax(wei, dim=-1)

v = value(x)
out = wei @ v
#out = wei @ x

out.shape

torch.Size([4, 8, 16])

In [35]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

In [36]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2, -1) * head_size**-0.5

In [37]:
k.var()

tensor(1.0449)

In [38]:
q.var()

tensor(1.0700)

In [39]:
wei.var()

tensor(1.0918)

In [40]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [41]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1) # gets too peaky, converges to one-hot

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])
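
The last few cells motivate the factor $1/\sqrt{d_k}$: a dot product of two vectors with independent unit-variance entries has variance about $d_k$, and large scores push softmax toward a one-hot output. A small sketch of the effect, with the sample size of 2000 chosen arbitrarily for illustration:

```python
import torch

torch.manual_seed(0)
d_k = 16
q = torch.randn(2000, d_k)       # unit-variance queries
k = torch.randn(2000, d_k)       # unit-variance keys

raw = q @ k.T                    # each entry is a sum of d_k products
scaled = raw * d_k ** -0.5       # rescaled as in the attention formula

print("variance without scaling:", raw.var().item())
print("variance with scaling:   ", scaled.var().item())

# The same eight scores fed to softmax, before and after scaling.
sharp = torch.softmax(raw[0, :8], dim=-1)
soft = torch.softmax(scaled[0, :8], dim=-1)
print("largest weight without scaling:", sharp.max().item())
print("largest weight with scaling:   ", soft.max().item())
```

Without scaling the variance is close to `d_k`; with scaling it is close to 1, which keeps the softmax weights more diffuse at initialization.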

In [42]:
class LayerNorm1d: # (used to be BatchNorm1d)

  def __init__(self, dim, eps=1e-5, momentum=0.1):
    # momentum is a leftover from the batch-norm version and is unused here
    self.eps = eps
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)

  def __call__(self, x):
    # calculate the forward pass
    xmean = x.mean(1, keepdim=True) # mean over the features of each input (dim 1), not over the batch
    xvar = x.var(1, keepdim=True) # variance over the features of each input
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize each input to zero mean, unit variance
    self.out = self.gamma * xhat + self.beta
    return self.out

  def parameters(self):
    return [self.gamma, self.beta]

torch.manual_seed(1337)
module = LayerNorm1d(100)
x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors
x = module(x)
x.shape

torch.Size([32, 100])

In [43]:
x[:,0].mean(), x[:,0].std() # mean,std of one feature across all batch inputs

(tensor(0.1469), tensor(0.8803))

In [44]:
x[0,:].mean(), x[0,:].std() # mean,std of a single input from the batch, of its features

(tensor(-9.5367e-09), tensor(1.0000))
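
For comparison, here is a short sketch of the same check using PyTorch's built-in `nn.LayerNorm`, which is the layer used in the full model. One known difference from the hand-written class above is that `nn.LayerNorm` normalizes with the biased variance estimator, so the check below uses `unbiased=False`. The shift and scale applied to the input are arbitrary illustrative values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer_norm = nn.LayerNorm(100)            # normalizes over the last dimension
x = torch.randn(32, 100) * 3.0 + 5.0      # inputs with mean near 5 and std near 3
y = layer_norm(x)

row_mean = y.mean(dim=1)                  # one mean per input
row_var = y.var(dim=1, unbiased=False)    # one (biased) variance per input
col_mean = y.mean(dim=0)                  # per-feature mean across the batch

print("largest |row mean|:", row_mean.abs().max().item())
print("average row variance:", row_var.mean().item())
print("largest |column mean|:", col_mean.abs().max().item())
```

Each row is normalized individually, while the per-feature statistics across the batch are left unconstrained, matching the pattern in the two cells above.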

### Output of wolfgang-GPT

We can train the model and take a look at its output.
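
The training cell below reports 0.216163 M parameters. As a cross-check, the count can be reproduced by hand from the hyperparameters and layer shapes in that cell. The helper `linear` and the intermediate variable names are introduced here only for this arithmetic.

```python
# Hyperparameters as set in the training cell, plus the vocabulary size found earlier.
vocab_size, n_embd, block_size = 99, 64, 64
n_head, n_layer = 4, 4
head_size = n_embd // n_head

def linear(n_in, n_out, bias=True):
    """Parameter count of a linear layer with n_in inputs and n_out outputs."""
    return n_in * n_out + (n_out if bias else 0)

# Token and position embedding tables.
embeddings = vocab_size * n_embd + block_size * n_embd

# Per block: key, query, value for each head (no bias), plus the output projection.
attention = n_head * 3 * linear(n_embd, head_size, bias=False) + linear(n_embd, n_embd)
# Per block: the two linear layers of the feed-forward network.
feed_forward = linear(n_embd, 4 * n_embd) + linear(4 * n_embd, n_embd)
# Per block: two layer norms, each with a scale and a shift vector.
layer_norms = 2 * (2 * n_embd)
block = attention + feed_forward + layer_norms

# All blocks, the final layer norm, and the language-model head.
total = embeddings + n_layer * block + 2 * n_embd + linear(n_embd, vocab_size)
print(total / 1e6, "M parameters")
```

Most of the parameters sit in the four transformer blocks; the embeddings and the output head together account for only a small share.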

In [45]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 64 # what is the maximum context length for predictions?
max_iters = 10000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0
# ------------

torch.manual_seed(1337)

with open('WOLFGANG_TRAINING.tex', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

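    def forward(self, x):
        B,T,C = x.shape   # here C = n_embd
        k = self.key(x)   # (B,T,head_size)
        q = self.query(x) # (B,T,head_size)
        # compute attention scores ("affinities")
        # NOTE: the scale below uses C = n_embd, whereas the attention formula above uses d_k = head_size;
        # it is left unchanged so that the training log that follows still corresponds to this code
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,head_size)
        out = wei @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
        return out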

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# full transformer language model (no longer a bigram model)
class wolfgang_LLM(nn.Module):

    def __init__(self):
        super().__init__()
        # token and position embeddings are summed, then passed through the transformer blocks
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size) # maps the final embedding to next-token logits

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

model = wolfgang_LLM()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))


0.216163 M parameters
step 0: train loss 4.8271, val loss 4.8334
step 100: train loss 3.0829, val loss 3.1061
step 200: train loss 2.8736, val loss 2.9270
step 300: train loss 2.7634, val loss 2.8190
step 400: train loss 2.6443, val loss 2.7482
step 500: train loss 2.5412, val loss 2.6754
step 600: train loss 2.4129, val loss 2.5889
step 700: train loss 2.2853, val loss 2.5031
step 800: train loss 2.1972, val loss 2.4321
step 900: train loss 2.1018, val loss 2.3783
step 1000: train loss 2.0110, val loss 2.3254
step 1100: train loss 1.9628, val loss 2.2730
step 1200: train loss 1.9165, val loss 2.2424
step 1300: train loss 1.8705, val loss 2.2132
step 1400: train loss 1.8289, val loss 2.1685
step 1500: train loss 1.7821, val loss 2.1482
step 1600: train loss 1.7548, val loss 2.1220
step 1700: train loss 1.7397, val loss 2.1116
step 1800: train loss 1.7028, val loss 2.0718
step 1900: train loss 1.6819, val loss 2.0624
step 2000: train loss 1.6528, val loss 2.0338
step 2100: train loss 1.