LLMs From Scratch: Day 2
With a feed-forward network in place and training smoothly, today I started diving deeper into specific components within the encoder portion of the model. I haven't decided whether I want to try a training run with the encoder by itself, or implement both the encoder and decoder and try to replicate the results from Attention Is All You Need (AIAYN). Naturally, it made sense to start with our input sequences and how to tokenize them before they're passed into the embedding layer.
This turned out to be more of a detour than I expected: AIAYN cites Massive Exploration of Neural Machine Translation Architectures (the paper that released Google's open-source seq2seq framework), which in turn cites Neural Machine Translation of Rare Words with Subword Units. The latter introduced Byte-Pair Encoding, or BPE, as a subword tokenization technique adapted from the data compression algorithm of the same name. BPE builds a vocab of subword units from a corpus by starting with individual characters and repeatedly merging the most frequent adjacent pair of symbols into a single new symbol (e.g. "a", "b" --> "ab"). This simplifies building a sufficiently expressive vocab for tasks like machine translation and (as the name of the paper suggests) improves rare-word translation. Rare-word translation is harder with a naive word-level dictionary vocab, where any word missing from the dictionary collapses into a single unknown token. The Massive Exploration paper used 32,000 merge operations with BPE, which gave a vocab size of ~37,000, and that is the vocab AIAYN used. The authors note that the specific data preprocessing steps cause fairly substantial changes to the resulting vocabularies, and they released their data preprocessing and vocab generation script for reproducibility.
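To make the merge procedure concrete, here's a minimal sketch of learning BPE merges in plain Python. It's a toy illustration of the idea, not the Massive Exploration preprocessing script I actually used; the function names (`merge_symbols`, `learn_bpe`, `segment`), the `</w>` end-of-word marker, the tie-breaking rule, and the tiny word-count corpus are all my own assumptions for the example.

```python
from collections import Counter


def merge_symbols(symbols, pair):
    """Replace each non-overlapping occurrence of `pair` (left to right) with one merged symbol."""
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            out.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return tuple(out)


def learn_bpe(word_counts, num_merges):
    """Learn up to `num_merges` merge operations from a {word: count} mapping."""
    # Each word starts as a sequence of characters plus an end-of-word marker.
    vocab = {tuple(w) + ("</w>",): c for w, c in word_counts.items()}
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by how often each word occurs.
        pairs = Counter()
        for symbols, count in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += count
        if not pairs:
            break  # every word is already a single symbol
        # Ties go to the first pair counted (an assumption; real tools may differ).
        best = max(pairs, key=pairs.get)
        vocab = {merge_symbols(s, best): c for s, c in vocab.items()}
        merges.append(best)
    return merges, vocab


def segment(word, merges):
    """Split a (possibly unseen) word by replaying the learned merges in order."""
    symbols = tuple(word) + ("</w>",)
    for pair in merges:
        symbols = merge_symbols(symbols, pair)
    return symbols


word_counts = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
merges, vocab = learn_bpe(word_counts, num_merges=5)
print(merges)  # [('e', 's'), ('es', 't'), ('est', '</w>'), ('l', 'o'), ('lo', 'w')]
print(segment("lowest", merges))  # ('low', 'est</w>')
```

"lowest" never appears in the toy corpus, yet it still splits into subwords the vocab already knows, which is exactly the rare-word behavior the paper is after. The real script also handles things like text normalization and pre-tokenization, which is where most of the reproducibility concerns come from.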
I started the day thinking that I would try to implement my own tokenizer, or at least implement the one used by AIAYN. After all, this is called LLMs From Scratch! However, given the emphasis placed on following precise preprocessing steps to reproduce results, I opted instead to use the script from Massive Exploration to generate the vocab. Data preparation and tokenization are an entire field of their own, and detouring into them would a) likely result in an implementation that doesn't match the one used by AIAYN, making it difficult to validate the rest of my model, and b) significantly increase the time it would take to actually implement everything.
I'm ending today by implementing a couple of things downstream from our tokenized data: the embedding layer, which is just a matrix, and the positional encodings, which are also just a matrix. The positional encodings provided another opportunity to learn about some JAX fundamentals, namely vectorization. It did occur to me that I may need to vectorize some of the functions I'm writing to handle batched inputs, but I think that's fairly straightforward with JAX's vmap.
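Here's a minimal JAX sketch of those two pieces, plus vmap for batching. It assumes the sinusoidal encodings from AIAYN (sin on even dimensions, cos on odd ones) and the paper's scaling of embedding weights by sqrt(d_model); the function names, the embedding initialization scale, and the toy sizes are hypothetical choices for illustration, not necessarily what's in my actual code.

```python
import jax
import jax.numpy as jnp


def init_embedding(key, vocab_size, d_model):
    """Embedding table: one d_model-dimensional row per token id."""
    # Assumption: normal init scaled by d_model ** -0.5 (an illustrative choice).
    return jax.random.normal(key, (vocab_size, d_model)) * d_model ** -0.5


def positional_encoding_row(pos, d_model):
    """Sinusoidal encoding for a single scalar position (d_model must be even).

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    """
    i = jnp.arange(d_model // 2)
    angles = pos / (10000.0 ** (2 * i / d_model))
    row = jnp.zeros(d_model)
    row = row.at[0::2].set(jnp.sin(angles))  # even dimensions
    row = row.at[1::2].set(jnp.cos(angles))  # odd dimensions
    return row


def positional_encoding(max_len, d_model):
    """(max_len, d_model) matrix, built by vmapping the per-position function."""
    positions = jnp.arange(max_len, dtype=jnp.float32)
    # Close over d_model so vmap only maps over the positions axis.
    return jax.vmap(lambda pos: positional_encoding_row(pos, d_model))(positions)


def embed_sequence(table, pe, token_ids):
    """Embed one unbatched sequence: (seq_len,) int ids -> (seq_len, d_model)."""
    d_model = table.shape[1]
    x = table[token_ids] * jnp.sqrt(d_model)  # AIAYN scales embeddings by sqrt(d_model)
    return x + pe[: token_ids.shape[0]]


# Map over the leading batch axis of token_ids; table and pe are shared.
batched_embed = jax.vmap(embed_sequence, in_axes=(None, None, 0))

key = jax.random.PRNGKey(0)
table = init_embedding(key, vocab_size=100, d_model=16)
pe = positional_encoding(max_len=50, d_model=16)
token_ids = jnp.array([[1, 2, 3], [4, 5, 6]])  # (batch=2, seq_len=3)
out = batched_embed(table, pe, token_ids)
print(out.shape)  # (2, 3, 16)
```

The nice part of this design is that each function is written for a single example (one position, one sequence), and vmap adds the batch dimension afterwards via `in_axes`, so there's no manual broadcasting bookkeeping.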
Code for the day is here (not including the data scripts I pulled in, since I didn't write those).
Day 3 should be fun: Multi-Headed Self-Attention!