LLMs From Scratch: Day 3
Today we get to start working on Multi-Headed Self-Attention! Time permitting, I may work on cross-attention as well since the mechanism is fairly similar, but that's TBD.
Copying from my notes on AIAYN, the key difference between regular self-attention and multi-headed self-attention is that instead of performing a single attention function with $d_{\text{model}}$-dimensional queries, keys, and values, you linearly project the queries, keys, and values $h$ times with different, learned linear projections to $d_k$, $d_k$, and $d_v$ dimensions, respectively.
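For reference, here is that formulation written out as it appears in AIAYN. Each head runs scaled dot-product attention on its own projections, and the concatenated heads are mixed by a final output matrix $W^O$:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

$$\mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\; KW_i^K,\; VW_i^V), \qquad i = 1, \dots, h$$

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^O$$

$$W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k},\quad W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k},\quad W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v},\quad W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$$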
Thinking about the input after it passes through the norm (after embedding), the shape should be $(\text{batch\_size}, \text{seq\_len}, d_{\text{model}})$. We validated that shape in the embedding file itself. For now, we'll just assume a batch size of 1 to keep the math easy. In AIAYN, they use $h = 8$ heads, with $d_k = d_v = d_{\text{model}}/h = 64$. That means that after applying the linear projections, our input should go from the aforementioned shape to $(\text{batch\_size}, \text{seq\_len}, d_k, h)$, i.e. $(1, \text{seq\_len}, 64, 8)$, although the ordering of the last two dimensions may change as we work through this.

It's probably also worth noting that in AIAYN, they maintain the large $Q$, $K$, and $V$ tensors as well, and use learned weight matrices to project those, rather than having completely separate matrices for each head. Naively, you could just initialize the full-size tensors, then initialize each head's weight matrices, loop through the heads during a forward pass, and concatenate after looping. However, I think this is a good time to use vmap instead! We can stack each of our weight matrices and iterate through them, applying them to the appropriate axes of our input tensor $X$. Additionally, after we project and do attention with each head, we have to reshape the results to take the dimensions from (..., 64, 8) to (..., 512), and our last weight matrix $W^O$ will have shape (512, 512).
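Here's a minimal sketch of the stacked-weights-plus-vmap idea described above. It is not the code linked below; it assumes the batch dimension has been dropped (batch size 1, so `x` is `(seq_len, d_model)`), uses no biases, masking, or dropout, and initializes weights with a scaled normal distribution purely for illustration. The names `one_head`, `all_heads`, and `init_params` are hypothetical.

```python
import jax
import jax.numpy as jnp

# AIAYN sizes; batch dimension dropped (assumption), so x is (seq_len, d_model).
D_MODEL, H = 512, 8
D_K = D_V = D_MODEL // H  # 64

def attention(q, k, v):
    """Scaled dot-product attention for one head; q, k, v are (seq_len, d_k)."""
    scores = q @ k.T / jnp.sqrt(q.shape[-1])   # (seq_len, seq_len)
    weights = jax.nn.softmax(scores, axis=-1)   # each row sums to 1
    return weights @ v                           # (seq_len, d_v)

def one_head(x, w_q, w_k, w_v):
    """Project x with a single head's (d_model, d_k) matrices, then attend."""
    return attention(x @ w_q, x @ w_k, x @ w_v)

# Share x across heads (None) and slice each stacked weight tensor along axis 0.
# out_axes=-1 puts the head axis last, so the result is (seq_len, d_v, h).
all_heads = jax.vmap(one_head, in_axes=(None, 0, 0, 0), out_axes=-1)

def multi_head_self_attention(x, params):
    heads = all_heads(x, params["w_q"], params["w_k"], params["w_v"])  # (seq, 64, 8)
    # Flatten (64, 8) -> 512. This interleaves heads differently from a
    # head-major concat, which only permutes the rows W_o is learned against.
    concat = heads.reshape(x.shape[0], D_V * H)                         # (seq, 512)
    return concat @ params["w_o"]                                       # (seq, 512)

def init_params(key):
    """Hypothetical init: stacked per-head matrices plus a (512, 512) W_o."""
    k_q, k_k, k_v, k_o = jax.random.split(key, 4)
    scale = 1.0 / jnp.sqrt(D_MODEL)
    return {
        "w_q": jax.random.normal(k_q, (H, D_MODEL, D_K)) * scale,
        "w_k": jax.random.normal(k_k, (H, D_MODEL, D_K)) * scale,
        "w_v": jax.random.normal(k_v, (H, D_MODEL, D_V)) * scale,
        "w_o": jax.random.normal(k_o, (H * D_V, D_MODEL)) * scale,
    }

params = init_params(jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (10, D_MODEL))  # seq_len = 10
out = multi_head_self_attention(x, params)
print(out.shape)  # (10, 512)
```

To bring the batch dimension back, one option is a second, outer vmap: `jax.vmap(multi_head_self_attention, in_axes=(0, None))` maps over axis 0 of a `(batch_size, seq_len, 512)` input while sharing the same `params` across the batch.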
All things considered, this was fairly easy to implement! Aside from figuring out how to correctly use the in_axes and out_axes arguments of vmap, the work so far has set me up well to implement this with relatively little new math.
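Since in_axes and out_axes were the main snag, here's a small illustration of what they control, using all-ones arrays just to inspect shapes. The helper `project` is hypothetical; the point is that `in_axes=(None, 0)` reuses the input for every call while slicing the weight stack along axis 0, and `out_axes` decides where the mapped (head) axis lands in the result.

```python
import jax
import jax.numpy as jnp

def project(x, w):
    """Apply one (d_in, d_out) matrix to x of shape (seq_len, d_in)."""
    return x @ w

x = jnp.ones((10, 512))           # one sequence: (seq_len, d_model)
w_stack = jnp.ones((8, 512, 64))  # 8 stacked per-head (d_model, d_k) matrices

# Default out_axes=0: the mapped head axis comes first.
heads_first = jax.vmap(project, in_axes=(None, 0))(x, w_stack)
print(heads_first.shape)  # (8, 10, 64)

# out_axes=-1: the mapped head axis goes last instead.
heads_last = jax.vmap(project, in_axes=(None, 0), out_axes=-1)(x, w_stack)
print(heads_last.shape)   # (10, 64, 8)
```

Putting the head axis last matches the (..., 64, 8) layout above, which is what makes the final reshape to (..., 512) a single call.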
The full code for Multi-Headed Self-Attention is here. There will have to be some modifications to the specific functions once we're ready to train, to incorporate things like dropout, but aside from that the bones for the full encoder and decoder are essentially in place. Multi-Headed Cross-Attention is effectively the same code, with the outputs of the encoder replacing the values for $K$ and $V$.
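A minimal sketch of that cross-attention variant, assuming the same stacked-weight layout as a self-attention version built on vmap (batch dimension dropped, no masking, no dropout, illustrative scaled-normal init). The only structural change is that queries come from one input (`x_q`, the decoder side) while keys and values come from another (`x_kv`, the encoder output); passing the same array for both recovers self-attention. All function and parameter names here are hypothetical.

```python
import jax
import jax.numpy as jnp

D_MODEL, H = 512, 8
D_K = D_V = D_MODEL // H  # 64

def attention(q, k, v):
    """Scaled dot-product attention; q is (tgt_len, d_k), k and v are (src_len, d_k)."""
    scores = q @ k.T / jnp.sqrt(q.shape[-1])    # (tgt_len, src_len)
    return jax.nn.softmax(scores, axis=-1) @ v  # (tgt_len, d_v)

def one_head(x_q, x_kv, w_q, w_k, w_v):
    # Queries from the decoder stream; keys and values from the encoder output.
    return attention(x_q @ w_q, x_kv @ w_k, x_kv @ w_v)

def multi_head_attention(x_q, x_kv, params):
    heads = jax.vmap(one_head, in_axes=(None, None, 0, 0, 0), out_axes=-1)(
        x_q, x_kv, params["w_q"], params["w_k"], params["w_v"]
    )                                                  # (tgt_len, d_v, h)
    concat = heads.reshape(x_q.shape[0], D_V * H)      # (tgt_len, 512)
    return concat @ params["w_o"]                      # (tgt_len, 512)

def init_params(key):
    k_q, k_k, k_v, k_o = jax.random.split(key, 4)
    scale = 1.0 / jnp.sqrt(D_MODEL)
    return {
        "w_q": jax.random.normal(k_q, (H, D_MODEL, D_K)) * scale,
        "w_k": jax.random.normal(k_k, (H, D_MODEL, D_K)) * scale,
        "w_v": jax.random.normal(k_v, (H, D_MODEL, D_V)) * scale,
        "w_o": jax.random.normal(k_o, (H * D_V, D_MODEL)) * scale,
    }

params = init_params(jax.random.PRNGKey(0))
decoder_x = jax.random.normal(jax.random.PRNGKey(1), (7, D_MODEL))     # target length 7
encoder_out = jax.random.normal(jax.random.PRNGKey(2), (12, D_MODEL))  # source length 12

cross = multi_head_attention(decoder_x, encoder_out, params)
self_attn = multi_head_attention(decoder_x, decoder_x, params)  # self-attention case
print(cross.shape, self_attn.shape)  # (7, 512) (7, 512)
```

The output length follows the query sequence, so the decoder's shape is preserved even when the source and target lengths differ.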
See you on day 4.