Implementing a Transformer Encoder from Scratch with JAX and Haiku 🤖


In Haiku, the Multi-Head Attention module can be implemented as follows. The __call__ function follows the same logic as the above graph, while the class methods take advantage of JAX utilities such as vmap (to vectorize our operations over the different attention heads and matrices) and tree_map (to map matrix dot-products over weight vectors).
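
Since the original snippet lives in the GitHub repository, here is only a minimal, self-contained sketch of the idea rather than the author's exact implementation: the hyperparameter names (num_heads, model_dim) are illustrative, the projections are written with hk.Linear instead of explicit weight trees, and a single jax.vmap call vectorizes the per-head attention.

import haiku as hk
import jax
import jax.numpy as jnp


class MultiHeadAttention(hk.Module):
    def __init__(self, num_heads: int, model_dim: int, name=None):
        super().__init__(name=name)
        self.num_heads = num_heads
        self.model_dim = model_dim
        self.head_dim = model_dim // num_heads

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # x: (batch, seq_len, model_dim)
        # Project the inputs to queries, keys and values for all heads at once
        q = hk.Linear(self.model_dim, name="query")(x)
        k = hk.Linear(self.model_dim, name="key")(x)
        v = hk.Linear(self.model_dim, name="value")(x)

        # Split the last dimension into heads: (batch, num_heads, seq_len, head_dim)
        def split_heads(t):
            batch, seq_len, _ = t.shape
            t = t.reshape(batch, seq_len, self.num_heads, self.head_dim)
            return t.transpose(0, 2, 1, 3)

        q, k, v = map(split_heads, (q, k, v))

        # Scaled dot-product attention for a single head
        def attention(q_h, k_h, v_h):
            scores = q_h @ k_h.transpose(0, 2, 1) / jnp.sqrt(self.head_dim)
            return jax.nn.softmax(scores, axis=-1) @ v_h

        # vmap vectorizes the computation over the head axis (axis 1)
        heads = jax.vmap(attention, in_axes=1, out_axes=1)(q, k, v)

        # Concatenate the heads back into (batch, seq_len, model_dim)
        batch, _, seq_len, _ = heads.shape
        heads = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, self.model_dim)
        return hk.Linear(self.model_dim, name="output")(heads)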

As you might have noticed on the Transformer graph, the multi-head attention block and the feed-forward net are followed by residual connections and layer normalization.

Residual or skip connections

Residual connections are a standard solution to the vanishing gradient problem, which occurs when gradients become too small to effectively update the model’s parameters.

As this issue naturally arises in particularly deep architectures, residual connections are used in a variety of complex models such as ResNet (He et al., 2015) in computer vision, AlphaZero (Silver et al., 2017) in reinforcement learning, and of course, Transformers.

In practice, residual connections simply forward the output of a specific layer to a following one, skipping one or more layers on the way. For instance, the residual connection around the multi-head attention block is equivalent to summing the block’s output with its input, the positional embeddings.

This enables gradients to flow more efficiently through the architecture during backpropagation and can usually lead to faster convergence and more stable training.
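
In code, the pattern boils down to adding a sub-layer’s input back to its output. The helper below is purely illustrative (the name with_residual and the sublayer argument are hypothetical):

import jax.numpy as jnp

def with_residual(x: jnp.ndarray, sublayer) -> jnp.ndarray:
    # Skip connection: the sub-layer's input is added back to its output
    return x + sublayer(x)

# e.g. around the attention block, where x holds the positional embeddings:
# out = with_residual(x, multi_head_attention)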

Representation of residual connections in Transformers (made by the author)

Layer Normalization

Layer normalization helps ensure that the values propagated through the model do not “explode” (tend toward infinity), which could easily happen in attention blocks, where several matrices are multiplied during each forward pass.

Unlike batch normalization, which normalizes each feature across the batch dimension and therefore assumes that the samples in a batch follow a similar distribution, layer normalization operates across the features of each individual sample. This approach is well suited to batches of sentences, where each sentence may have its own distribution due to varying meanings and vocabularies.

By normalizing across features, such as embeddings or attention values, layer normalization standardizes data to a consistent scale without conflating distinct sentence characteristics, maintaining the unique distribution of each.

Representation of Layer Normalization in the context of Transformers (made by the author)

The implementation of layer normalization is pretty straightforward: we initialize the learnable parameters alpha and beta and normalize along the desired feature axis.
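
As a rough sketch (the epsilon value and the initializers below are assumptions rather than details taken from the original code):

import haiku as hk
import jax.numpy as jnp


class LayerNorm(hk.Module):
    def __init__(self, eps: float = 1e-6, name=None):
        super().__init__(name=name)
        self.eps = eps

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        feature_dim = x.shape[-1]
        # Learnable scale (alpha) and shift (beta), one value per feature
        alpha = hk.get_parameter("alpha", (feature_dim,), init=jnp.ones)
        beta = hk.get_parameter("beta", (feature_dim,), init=jnp.zeros)

        # Normalize each token's feature vector independently
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        return alpha * (x - mean) / jnp.sqrt(var + self.eps) + beta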

The last component of the encoder that we need to cover is the position-wise feed-forward network. This fully connected network takes the normalized outputs of the attention block as inputs and is used to introduce non-linearity and increase the model’s capacity to learn complex functions.

It is composed of two dense layers separated by a gelu activation:
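
A minimal version could look like this (hidden_dim is a hypothetical argument; choosing a hidden width several times larger than model_dim is a common design convention, not something prescribed here):

import haiku as hk
import jax


class FeedForwardBlock(hk.Module):
    def __init__(self, model_dim: int, hidden_dim: int, name=None):
        super().__init__(name=name)
        self.model_dim = model_dim
        self.hidden_dim = hidden_dim

    def __call__(self, x):
        # Expand, apply the non-linearity, then project back to model_dim
        x = hk.Linear(self.hidden_dim, name="dense_1")(x)
        x = jax.nn.gelu(x)
        return hk.Linear(self.model_dim, name="dense_2")(x)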

After this block, we have another residual connection and layer normalization to complete the encoder.

There we have it! By now you should be familiar with the main concepts of the Transformer encoder. The full encoder class ties all of these blocks together; notice that in Haiku, we assign a name to each layer so that learnable parameters are separated and easy to access. The __call__ function provides a good summary of the different steps of our encoder:
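
Since the complete class is available in the GitHub repository, the snippet below is only a condensed sketch of a single encoder layer, wired from the hypothetical modules sketched earlier; it leaves out the token embedding and positional encoding steps discussed before.

import haiku as hk


class EncoderBlock(hk.Module):
    def __init__(self, num_heads: int, model_dim: int, hidden_dim: int, name=None):
        super().__init__(name=name)
        self.num_heads = num_heads
        self.model_dim = model_dim
        self.hidden_dim = hidden_dim

    def __call__(self, x):
        # 1. Multi-head attention, residual connection and layer normalization
        attention_out = MultiHeadAttention(
            self.num_heads, self.model_dim, name="multi_head_attention")(x)
        x = LayerNorm(name="layer_norm_1")(x + attention_out)

        # 2. Position-wise feed-forward net, another residual and layer norm
        ffn_out = FeedForwardBlock(
            self.model_dim, self.hidden_dim, name="feed_forward")(x)
        return LayerNorm(name="layer_norm_2")(x + ffn_out)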

To use this module on actual data, we have to apply hk.transform to a function encapsulating the encoder class. Indeed, you might remember that JAX embraces the functional programming paradigm, and Haiku follows the same principles.

We define a function containing an instance of the encoder class and return the output of a forward pass. Applying hk.transform returns a transformed object giving access to two functions: init and apply.

The former enables us to initialize the module with a random key as well as some dummy data (notice that here we pass an array of zeros with shape (batch_size, seq_len)), while the latter allows us to process real data.

# Note: the two following syntaxes are equivalent
# 1: Using hk.transform as a function decorator
@hk.transform
def encoder(x):
    ...  # instantiate the Encoder module here as `model`
    return model(x)

encoder.init(...)
encoder.apply(...)

# 2: Applying hk.transform separately
def encoder(x):
    ...  # instantiate the Encoder module here as `model`
    return model(x)

encoder_fn = hk.transform(encoder)
encoder_fn.init(...)
encoder_fn.apply(...)
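
For instance, with the second syntax (the shapes, dtype and random seed below are illustrative):

import jax
import jax.numpy as jnp

batch_size, seq_len = 32, 128                # illustrative values
rng = jax.random.PRNGKey(42)
# Dummy batch of zeros, standing in for token ids (dtype assumed)
dummy_batch = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)

# init builds the parameter tree from the random key and the dummy batch
params = encoder_fn.init(rng, dummy_batch)

# apply runs a forward pass with these parameters, here on the dummy batch
# and, in practice, on real tokenized sentences
output = encoder_fn.apply(params, rng, dummy_batch)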

In the next article, we’ll complete the transformer architecture by adding a decoder, which reuses most of the blocks we introduced so far, and learn how to train a model on a specific task using Optax!

Thank you for reading this far, if you are interested in dabbling with the code, you can find it fully commented on GitHub, along with additional details and a walkthrough using a toy dataset.

If you’d like to dig deeper into Transformers, the following section contains some articles that helped me write this article.

Until next time 👋

[1] Attention is all you need (2017), Vaswani et al, Google

[2] What exactly are keys, queries, and values in attention mechanisms? (2019) Stack Exchange

[3] The Illustrated Transformer (2018), Jay Alammar

[4] A Gentle Introduction to Positional Encoding in Transformer Models (2023), Mehreen Saeed, Machine Learning Mastery


