I've been comparing a lot of transformer variants on large models (400M params): Post/Pre-LN, DeepNet, NormFormers, Swin v2, GLU variants, RMSNorm, Sandwich LN, with GELU, Swish, SmeLU…
More than 2,000h of total training time on TPU v3s 😯
Here are my findings 🤓
You must use a final LayerNorm in the decoder for any pre-LN architecture (it's always present in post-LN).
Using one at the end of the encoder helps convergence as well.
Don't use bias in dense layers.
It adds 15% to training time and hurts convergence.
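A minimal Flax sketch of both points, just to show the placement (the real blocks obviously have attention and an FFN instead of a single Dense, so this is only illustrative):

```python
import flax.linen as nn

class PreLNDecoder(nn.Module):
    num_layers: int
    dim: int

    @nn.compact
    def __call__(self, x):
        for _ in range(self.num_layers):
            # pre-LN: normalize before the sub-layer, then add the residual
            # (attention + FFN collapsed into one Dense here for brevity)
            h = nn.LayerNorm()(x)
            h = nn.Dense(self.dim, use_bias=False)(h)  # no bias in dense layers
            x = x + h
        # the final LayerNorm that pre-LN decoders need
        return nn.LayerNorm()(x)
```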
If you use a post-LN model (LayerNorms after each residual connection), DeepNet improves stability a little bit.
I had the highest success with NormFormers.
The position of the LayerNorms is similar to Sandwich-LN (per CogView), except in the attention block.
However, it shows better stability.
I ran a bunch of NormFormer variants.
The paper suggests not using the scaled residual connection. I recommend not using the head scale either.
I don't use a learnt scale in LayerNorm when it's followed by a dense layer. It trains better with it, but I think that's because it acts as a reduced learning rate.
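For reference, here's a rough sketch of a NormFormer-style feed-forward block as I understand it from the paper, with the extra LayerNorm after the activation and no learnt scale on norms that feed a Dense layer (exact details may differ from my actual runs):

```python
import flax.linen as nn

class NormFormerFFN(nn.Module):
    dim: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        # pre-LN without learnt scale since it feeds directly into a Dense layer
        h = nn.LayerNorm(use_scale=False)(x)
        h = nn.Dense(self.hidden_dim, use_bias=False)(h)
        h = nn.gelu(h)
        # the extra NormFormer LayerNorm after the activation
        h = nn.LayerNorm(use_scale=False)(h)
        h = nn.Dense(self.dim, use_bias=False)(h)
        return x + h  # plain residual, no scaled residual connection
```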
GLU variants are great!
Even if they increase peak memory (and reduce your max batch size) for the same total parameter count, they let the model train much better!
As the activation function, use GELU (more stable) or Swish (trains faster).
I didn't have great results with the new SmeLU function.
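As an illustration, a GLU-variant feed-forward layer looks roughly like this (GeGLU flavor; swap nn.gelu for nn.swish to get the Swish/SwiGLU flavor — names are mine, not from our codebase):

```python
import flax.linen as nn

class GLUFeedForward(nn.Module):
    dim: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        # two parallel projections: one gated by the activation, one linear
        gate = nn.gelu(nn.Dense(self.hidden_dim, use_bias=False)(x))  # or nn.swish
        value = nn.Dense(self.hidden_dim, use_bias=False)(x)
        # the element-wise product is the extra activation that costs peak memory
        return nn.Dense(self.dim, use_bias=False)(gate * value)
```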
RMSNorm brings more stability to the training.
However, for very long runs it plateaus earlier than LayerNorm.
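For context, RMSNorm only rescales by the root mean square of the features, without the mean-centering of LayerNorm; a minimal JAX sketch:

```python
import jax.numpy as jnp

def rms_norm(x, scale, eps=1e-6):
    # no mean subtraction, only rescaling by the root mean square of the features
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * scale
```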
I didn't have great results with Swin v2 even after playing with different values of tau scale in the cosine attention.
I think the cosine attention makes the model learn much slower.
I also tested using only the Swin relative positions with other variants, but it wasn't helpful.
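For reference, Swin v2's scaled cosine attention replaces the usual dot product divided by sqrt(head_dim) with a cosine similarity divided by a learnable temperature tau, roughly:

```python
import jax.numpy as jnp

def cosine_attention_scores(q, k, tau):
    # q, k: [..., seq_len, head_dim]; tau is a learnable per-head temperature
    q = q / (jnp.linalg.norm(q, axis=-1, keepdims=True) + 1e-6)
    k = k / (jnp.linalg.norm(k, axis=-1, keepdims=True) + 1e-6)
    return jnp.einsum('...qd,...kd->...qk', q, k) / tau
```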
Sinkformers are a bit slower to train and didn't improve the model.
Maybe it's because I can only use them in the encoder and not in the decoder due to causality.
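For context, Sinkformers replace the softmax row normalization with a few Sinkhorn iterations that alternately normalize rows and columns, which is what breaks causal masking; a rough sketch, with the iteration count chosen arbitrarily:

```python
import jax.numpy as jnp

def sinkhorn_attention(scores, n_iter=3):
    # start from exp(scores), shifted for numerical stability
    a = jnp.exp(scores - scores.max(axis=-1, keepdims=True))
    for _ in range(n_iter):
        a = a / (a.sum(axis=-1, keepdims=True) + 1e-6)  # normalize rows
        a = a / (a.sum(axis=-2, keepdims=True) + 1e-6)  # normalize columns
    return a
```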
You'll find more details in the report, which is full of interactive graphs and traceable runs (with diffs between runs), plus a TLDR with links to the relevant sections.
Many thanks to @_arohan_ and Phil Wang for suggesting some of these ideas!
I have a lot more stuff to try in my backlog so will probably be updating this report in the future.
Also thanks to @pcuenq for running some of these experiments with me!
Preliminary settings for the large CapPa model:
- Vision model: 328M params
- Text model: 348M params (67M embeds)
Going to train on TPU v5e-256 from TRC 😎
Model based on "Image Captioners Are Scalable Vision Learners Too" with a few tweaks.
Vision models are most often trained either in a contrastive fashion on noisy datasets (CLIP) or as classifiers on ImageNet.
Here we train a captioner on a noisy dataset.
The goal is to create a strong vision model (we discard the text model) to be used for any downstream task.
The paper proposes 2 training methods:
- Cap -> train on captioning only
- CapPa -> also adds a masked objective where we mask part or all of the text
The model trained will be CapPa with full masking 75% of the time 🤯
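A hypothetical sketch of how that objective could look (mask_token_id, shapes and the 75% probability are just the ones mentioned above, not our actual code): with probability 0.75 the decoder inputs are replaced entirely by mask tokens and all caption tokens are predicted in parallel, otherwise it's standard autoregressive captioning.

```python
import jax
import jax.numpy as jnp

def make_decoder_inputs(key, tokens, mask_token_id, parallel_prob=0.75):
    # standard teacher forcing: shift tokens right (BOS assumed to be id 0 here)
    shifted = jnp.pad(tokens[:, :-1], ((0, 0), (1, 0)))
    # with probability parallel_prob, replace every decoder input by [MASK]
    # so that all caption tokens are predicted in parallel
    use_parallel = jax.random.bernoulli(key, parallel_prob, (tokens.shape[0], 1))
    masked = jnp.full_like(shifted, mask_token_id)
    return jnp.where(use_parallel, masked, shifted)
```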
Amazed to see the importance of correctly selecting the Distributed Shampoo configuration for training the ViT-VQGAN 🤯
TLDR:
👉 Nesterov momentum brings more stability
👉 Optimal settings are problem specific
I trained a lot of different configurations, following @_arohan_'s suggestion that Nesterov momentum could have an important impact for these types of problems that include a GAN loss.
I tried with/without Nesterov and experimented with a few values of beta1 and beta2.
All the Nesterov runs show much faster convergence and greater stability.
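Not the Distributed Shampoo internals, just a reminder of what the Nesterov flag changes in a plain momentum update (the gradient is applied again on top of the updated momentum buffer instead of using the buffer alone):

```python
def nesterov_momentum_update(params, grads, buf, lr=1e-3, beta1=0.9):
    # classic momentum would return params - lr * buf after updating buf
    buf = beta1 * buf + grads
    update = grads + beta1 * buf  # Nesterov: look ahead along the momentum direction
    return params - lr * update, buf
```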
One challenge is there are tons of parameters to adjust: coefficient factors for losses (L2, codebook, lpips, stylegan, discriminator…), optimizer parameters, model architecture, codebook dim…
We added even more options: NormFormer, GLU variants, additional convolutions… 😅
It's tricky to explore the entire space of possibilities so we do quick experiments and try to take smart decisions.
Eventually we'd like to contribute a great f16 model with a codebook, and an f8 model, not necessarily with a codebook (we'll try a KL loss) but with a low latent dimension.
Impact of learning rate still amazes me!
I would have never expected this graph 🤯
Few interesting things to know:
First, you get an immediate drop in loss when lowering the learning rate.
So it can be interesting to end your training with a linear decay to 0 and see if you get something a bit better.
Then, despite decaying the learning rate, you can see that the slope/progress stays about the same, whereas you might have thought you should wait for a plateau before decaying.
It's extremely hard to guess the right moment to start decaying.
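For instance, with optax you can join a constant phase with a linear decay to 0; the hard part is still picking decay_start (the values below are just placeholders):

```python
import optax

peak_lr = 3e-4          # placeholder values
decay_start = 100_000   # step at which the linear decay to 0 begins
total_steps = 150_000

schedule = optax.join_schedules(
    schedules=[
        optax.constant_schedule(peak_lr),
        optax.linear_schedule(init_value=peak_lr, end_value=0.0,
                              transition_steps=total_steps - decay_start),
    ],
    boundaries=[decay_start],
)
```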