abhishek
Dec 13, 2021 · 18 tweets · 6 min read
"Attention is all you need" implementation from scratch in PyTorch. A Twitter thread:
There are two parts: an encoder and a decoder. The encoder takes the source embeddings and the source mask as inputs, and the decoder takes the target embeddings and the target mask. The decoder inputs are shifted right. What does shifted right mean? Keep reading the thread. 2/ Image
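"Shifted right" means the decoder input is the target sequence moved one position to the right. A start token is prepended and the last token is dropped, so the decoder learns to predict token i from the tokens before it.

A minimal sketch, assuming integer token ids in a (batch, seq_len) tensor. `shift_right` and the start-token id `bos_id` are hypothetical names, not taken from the thread:

```python
import torch


def shift_right(target_ids: torch.Tensor, bos_id: int) -> torch.Tensor:
    """Build decoder inputs: prepend bos_id and drop the last target token.

    target_ids: (batch, seq_len) integer tensor; bos_id is a hypothetical start-token id.
    """
    bos = torch.full((target_ids.size(0), 1), bos_id, dtype=target_ids.dtype)
    return torch.cat([bos, target_ids[:, :-1]], dim=1)


targets = torch.tensor([[5, 6, 7, 8]])
decoder_inputs = shift_right(targets, bos_id=1)
print(decoder_inputs)  # tensor([[1, 5, 6, 7]])
```

At position 0 the decoder sees only the start token and should predict 5. At position 1 it sees 1, 5 and should predict 6, and so on. Together with the subsequent mask, this keeps each prediction from seeing the token it is supposed to predict.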
The encoder is composed of N encoder layers. Let's implement this as a black box too. The output of one encoder layer is the input to the next one, and so on. The source mask stays the same all the way to the end. 3/ Image
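A minimal sketch of the stacking logic, assuming each layer is a module called as `layer(x, src_mask)` that returns a tensor of the same shape. `ToyEncoderLayer` is a hypothetical placeholder, not the thread's layer. The six copies match the paper's N = 6:

```python
import copy

import torch
import torch.nn as nn


class ToyEncoderLayer(nn.Module):
    """Hypothetical placeholder for a real encoder layer: (x, mask) -> same shape as x."""

    def __init__(self, d_model: int):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x, src_mask):
        return x + self.linear(x)  # ignores the mask; a real layer would use it in attention


class Encoder(nn.Module):
    def __init__(self, layer: nn.Module, num_layers: int):
        super().__init__()
        # independent copies, so each layer learns its own weights
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])

    def forward(self, x, src_mask):
        for layer in self.layers:
            x = layer(x, src_mask)  # output of one layer feeds the next; the mask never changes
        return x


encoder = Encoder(ToyEncoderLayer(512), num_layers=6)
enc_out = encoder(torch.randn(2, 10, 512), src_mask=None)
print(enc_out.shape)  # torch.Size([2, 10, 512])
```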
Similarly, the decoder is composed of decoder layers. The decoder takes the output of the last encoder layer, the target embeddings and the target mask. enc_mask is the same as src_mask, as explained previously. 4/ Image
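The decoder stack follows the same pattern, except that every layer also receives the final encoder output. This sketch makes the same assumptions as the encoder sketch. `ToyDecoderLayer` is a hypothetical placeholder, and the thread's `enc_mask` is passed as `src_mask`:

```python
import copy

import torch
import torch.nn as nn


class ToyDecoderLayer(nn.Module):
    """Hypothetical placeholder: (x, enc_out, src_mask, tgt_mask) -> same shape as x."""

    def __init__(self, d_model: int):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        # a real layer attends to enc_out; here we only mix in its mean over source positions
        return x + self.linear(x) + enc_out.mean(dim=1, keepdim=True)


class Decoder(nn.Module):
    def __init__(self, layer: nn.Module, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])

    def forward(self, x, enc_out, src_mask, tgt_mask):
        for layer in self.layers:
            # every decoder layer gets the SAME final encoder output and the same masks
            x = layer(x, enc_out, src_mask, tgt_mask)
        return x


decoder = Decoder(ToyDecoderLayer(512), num_layers=6)
enc_out = torch.randn(2, 128, 512)  # source length 128
dec_out = decoder(torch.randn(2, 64, 512), enc_out, src_mask=None, tgt_mask=None)
print(dec_out.shape)  # torch.Size([2, 64, 512])
```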
Let's take a look at the encoder layer. It consists of multi-headed attention, a feed-forward network and two layer normalization layers. See the forward(...) function to understand how the skip connection works: it's just adding the original inputs to the outputs. 5/ Image
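A sketch of one encoder layer, assuming the post-LN arrangement from the original paper: LayerNorm(x + Sublayer(x)), with dropout applied to each sub-layer output. It swaps in PyTorch's built-in `nn.MultiheadAttention` in place of the from-scratch attention built later in the thread. `batch_first=True` requires PyTorch 1.9 or later. The values d_ff = 2048 and dropout = 0.1 are the paper's base settings:

```python
import torch
import torch.nn as nn


class EncoderLayer(nn.Module):
    """Post-LN encoder layer: x -> LayerNorm(x + Sublayer(x)), applied twice."""

    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        # stand-in: PyTorch's built-in attention instead of a from-scratch one
        self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        # key_padding_mask: (batch, seq_len), True marks padding positions to ignore
        attn_out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + self.dropout(attn_out))    # skip connection: add the inputs back
        x = self.norm2(x + self.dropout(self.ff(x)))  # second skip connection around the FFN
        return x


layer = EncoderLayer()
x = torch.randn(2, 10, 512)
pad = torch.zeros(2, 10, dtype=torch.bool)
pad[1, 7:] = True  # the last 3 tokens of the 2nd sequence are padding
y = layer(x, key_padding_mask=pad)
print(y.shape)  # torch.Size([2, 10, 512])
```

`nn.MultiheadAttention` expects True at padding positions in `key_padding_mask`. That may be the opposite of the convention used by the thread's own mask.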
Now comes the fun part: multi-head attention. We see it 3 times in the architecture. Multi-headed attention is nothing but several different self-attention layers (heads) running in parallel. Their outputs are concatenated to form an output with the same shape as the input. 6/ Image
If the number of heads is 8 and d_model (embedding size) is 512, each self-attention head produces an output of size 512 / 8 = 64. These are concatenated to give a final output of size 64 x 8 = 512, which is then passed through a dense layer. 7/ Image
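The head arithmetic can be checked with tensor shapes. Instead of instantiating 8 separate attention modules, a common equivalent trick is to project once and then reshape the 512 features into 8 heads of 64 features each. This sketch reflects an assumption about tensor layout, not the thread's code. It shows only the split and the concatenation and leaves out the per-head attention:

```python
import torch
import torch.nn as nn

batch, seq_len, d_model, num_heads = 2, 10, 512, 8
d_head = d_model // num_heads  # 512 // 8 = 64 features per head

x = torch.randn(batch, seq_len, d_model)

# split: (batch, seq, 512) -> (batch, 8 heads, seq, 64)
heads = x.view(batch, seq_len, num_heads, d_head).transpose(1, 2)

# ... each head would run its own scaled dot-product attention here ...

# concatenate: (batch, 8, seq, 64) -> (batch, seq, 8 * 64 = 512)
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, num_heads * d_head)

# final dense layer, as described in the thread
out_proj = nn.Linear(d_model, d_model)
out = out_proj(merged)
print(d_head, heads.shape, out.shape)
```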
Self-attention, in simple words, is attention on the same sequence. I like to define it as a layer that tells you which token "loves" which other token in the same sequence. For self-attention, the input is passed through 3 linear layers: query, key and value. 8/ Image
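A sketch of the three projections for a single head, assuming d_model = 512 and a head size of 64 as in the previous tweet. The resulting (seq_len x seq_len) score table gives the "which token loves which" view. At this point the scores are not yet scaled or normalized:

```python
import torch
import torch.nn as nn

d_model, d_head = 512, 64

# the 3 linear layers of one head
query = nn.Linear(d_model, d_head)
key = nn.Linear(d_model, d_head)
value = nn.Linear(d_model, d_head)

x = torch.randn(2, 10, d_model)  # (batch, seq_len, d_model)

# self-attention: the SAME sequence x goes through all three projections
q, k, v = query(x), key(x), value(x)  # each (2, 10, 64)

# raw (unscaled) scores: entry [b, i, j] says how much token i "loves" token j
scores = q @ k.transpose(1, 2)  # (2, 10, 10)
print(q.shape, scores.shape)
```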
In the forward function, we apply the formula for self-attention: softmax(Q·Kᵀ / √d_k)·V, where d_k is the dimension of the keys (64 in our example). torch.bmm does batched matrix multiplication. Please note: q, k and v (the inputs) are the same in the case of self-attention. 9/ Image
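The formula maps almost one-to-one onto `torch.bmm`. A minimal sketch, assuming 3-D (batch, seq_len, d) tensors that have already gone through the query/key/value layers. Feeding the same random tensor three times simply stands in for the self-attention case:

```python
import math

import torch


def scaled_dot_product_attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for 3-D tensors shaped (batch, seq, d)."""
    d_k = k.size(-1)
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d_k)  # (batch, seq_q, seq_k)
    weights = torch.softmax(scores, dim=-1)                    # each row sums to 1
    return torch.bmm(weights, v), weights                     # (batch, seq_q, d_v)


x = torch.randn(2, 10, 64)
out, weights = scaled_dot_product_attention(x, x, x)  # self-attention: q = k = v
print(out.shape, weights.shape)
```

The paper motivates the 1/√d_k factor: for large d_k the dot products grow large and push the softmax into regions with extremely small gradients.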
Let's look at the forward function and the formula for self-attention (scaled). Ignoring the mask part, everything is pretty easy to implement. 10/ Image
The mask just tells the model where not to look (e.g. padding tokens). 11/ Image
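A sketch of how the mask enters the same computation. It assumes the convention 1/True means "may look here" and 0/False means "don't look"; the thread's images may use a different convention. Masked scores are set to -inf before the softmax, so those positions get exactly zero weight:

```python
import math

import torch


def masked_attention(q, k, v, mask=None):
    """Scaled dot-product attention; mask is 1/True where attending is allowed, 0/False where not."""
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(k.size(-1))
    if mask is not None:
        # -inf before the softmax gives those positions exactly zero weight
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return torch.bmm(weights, v), weights


x = torch.randn(1, 4, 8)
pad_mask = torch.tensor([[1, 1, 1, 0]])  # (batch, seq_len): the last token is padding
# unsqueeze to (batch, 1, seq_len) so the same key mask applies to every query row
out, weights = masked_attention(x, x, x, mask=pad_mask.unsqueeze(1))
print(weights[0, :, 3])  # tensor([0., 0., 0., 0.])
```

If a query row is masked everywhere, every score is -inf and the softmax returns NaN. At least one position per row must therefore stay visible.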
Let's take a look at the decoder now. The implementation is similar to that of the encoder, except that each decoder layer also takes the final encoder output as input. 12/ Image
The decoder layer consists of two different types of attention. The masked version has an extra mask in addition to the padding mask; we will come to that. The normal multi-head attention takes its key and value from the final encoder output, so key and value are the same here. 13/ Image
The query comes from the output of the masked multi-head attention (after layer norm). Check out the forward function and things are very easy to understand :) 14/ Image
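A sketch of one decoder layer, again assuming the paper's post-LN arrangement and using PyTorch's `nn.MultiheadAttention` as a stand-in for the thread's own attention. The first attention is masked self-attention over the target. The second takes its query from the decoder (after layer norm) and its key and value from the same encoder output. In `nn.MultiheadAttention`, a boolean `attn_mask` uses True to mark positions that are not allowed, which is why the causal mask is built with `triu(..., diagonal=1)`:

```python
import torch
import torch.nn as nn


class DecoderLayer(nn.Module):
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, causal_mask=None, tgt_pad=None, src_pad=None):
        # 1) masked self-attention over the target (subsequent mask + target padding mask)
        a, _ = self.self_attn(x, x, x, attn_mask=causal_mask, key_padding_mask=tgt_pad)
        x = self.norm1(x + self.dropout(a))
        # 2) cross-attention: query from the decoder (after layer norm), key = value = encoder output
        a, _ = self.cross_attn(x, enc_out, enc_out, key_padding_mask=src_pad)
        x = self.norm2(x + self.dropout(a))
        # 3) position-wise feed-forward network
        return self.norm3(x + self.dropout(self.ff(x)))


tgt_len = 64
layer = DecoderLayer()
causal = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)  # True = blocked
y = layer(torch.randn(2, tgt_len, 512), torch.randn(2, 128, 512), causal_mask=causal)
print(y.shape)  # torch.Size([2, 64, 512])
```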
Now we come to the special mask for targets, aka the subsequent mask. The subsequent mask tells the decoder not to look at future tokens. It is used in addition to the padding mask, and it matters most during training, when the whole target sequence is fed to the decoder at once. 15/ Image
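A sketch of the subsequent mask, assuming the same 1/True = "may look" convention as the masked-attention sketch. This is the opposite of the boolean `attn_mask` in `nn.MultiheadAttention`, where you would pass `~mask` instead. The sketch also shows one way to combine it with a target padding mask. Pad id 0 and the helper name `subsequent_mask` are assumptions:

```python
import torch


def subsequent_mask(size: int) -> torch.Tensor:
    """(size, size) boolean mask: True where query position i may look at key position j (j <= i)."""
    return torch.tril(torch.ones(size, size, dtype=torch.bool))


causal = subsequent_mask(4)  # lower-triangular: row i has i + 1 visible positions

# combine with a target padding mask (pad id 0 is an assumption)
tgt = torch.tensor([[5, 6, 0, 0]])
tgt_pad = tgt != 0  # (batch, tgt_len), True for real tokens
# (batch, 1, tgt_len) & (1, tgt_len, tgt_len) -> (batch, tgt_len, tgt_len)
tgt_mask = tgt_pad.unsqueeze(1) & subsequent_mask(tgt.size(1)).unsqueeze(0)
print(tgt_mask[0].int())
```

Each row of `tgt_mask` lets a position see itself and earlier real tokens, but never future tokens or padding.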
Now we have all the building blocks except positional encoding. Positional encoding gives the model an idea of where the tokens are located relative to each other. To implement positional encoding, we can simply use an embedding layer! 16/
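A sketch of the embedding-layer approach: one trainable vector per position, added to the token embedding. The original paper's default is fixed sinusoidal encodings, but it reports that learned positional embeddings gave nearly identical results. `vocab_size = 1000` and `max_len = 512` are illustrative assumptions:

```python
import torch
import torch.nn as nn


class TokenAndPositionEmbedding(nn.Module):
    """Token embedding + learned positional embedding (one trainable vector per position)."""

    def __init__(self, vocab_size: int, d_model: int, max_len: int = 512):
        super().__init__()
        self.token = nn.Embedding(vocab_size, d_model)
        self.position = nn.Embedding(max_len, d_model)  # max_len is an assumed upper bound

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # ids: (batch, seq_len) -> positions 0, 1, ..., seq_len - 1
        positions = torch.arange(ids.size(1), device=ids.device)
        # (seq_len, d_model) broadcasts over the batch dimension
        return self.token(ids) + self.position(positions)


emb = TokenAndPositionEmbedding(vocab_size=1000, d_model=512)
src = torch.randint(0, 1000, (32, 128))
print(emb(src).shape)  # torch.Size([32, 128, 512])
```

One trade-off: a learned table cannot handle positions beyond max_len, whereas the paper suggests sinusoids may let the model extrapolate to longer sequences.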
And this is what the inputs and outputs look like. Here, batch size = 32, input sequence length = 128 and output sequence length = 64. We add a linear layer + softmax on top of the decoder output. This gives us a token prediction for each position (a classification problem). 17/ Image
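A shape walk-through of the output head, using the thread's sizes (batch 32, target length 64, d_model 512). The vocabulary size of 1000 is a hypothetical value, and a random tensor stands in for the decoder output:

```python
import torch
import torch.nn as nn

batch_size, tgt_len, d_model = 32, 64, 512
vocab_size = 1000  # hypothetical target vocabulary size

# stand-in for the decoder output; the source side (length 128) only enters through the encoder
decoder_out = torch.randn(batch_size, tgt_len, d_model)

generator = nn.Linear(d_model, vocab_size)  # linear layer on top of the decoder
logits = generator(decoder_out)             # (32, 64, 1000): one score per vocab token per position
probs = logits.softmax(dim=-1)              # a distribution over the vocabulary at each position
predicted_tokens = probs.argmax(dim=-1)     # (32, 64)

# training: one classification per position; CrossEntropyLoss applies log-softmax itself
targets = torch.randint(0, vocab_size, (batch_size, tgt_len))
loss = nn.CrossEntropyLoss()(logits.reshape(-1, vocab_size), targets.reshape(-1))
print(logits.shape, loss.item())
```

In practice you train on the raw logits, since `nn.CrossEntropyLoss` applies log-softmax internally. The explicit softmax is only needed when you want probabilities, e.g. for decoding.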
I hope you liked this thread. If there are any mistakes in my implementation, please let me know and I can fix them :) 18/
