Very cool blog by @character_ai diving into how they trained their proprietary models Kaiju (13B, 34B, 110B) before switching to OSS models, and spoiler: it has Noam Shazeer written all over it.
Most of the model design choices (MQA, SWA, KV cache, quantization) are not there to optimize for "AGI benchmarks" (think MMLU), since that is not what people will use the model for, but rather for good serving speed. Still, they include code in the pre-training mix and do annealing on high-quality "benchmark friendly" data.
One surprising thing is that those models are not MoEs, even though people working at Character at the time, like @stephenroller or Noam, had previously worked on MoE.
Here are a few optimizations that they did
-> MuP-like scaling
-> MQA + SWA (MQA sketched after this list)
-> Clamping everywhere to control activations (not sure if it's soft or hard clamping?)
-> KV Cache sharing
-> Relu^2 activation function
-> FSDP + TP + SP
-> Int6 gradient communication
-> Quantization Aware Training (QAT) with stuff like "bungee_scalar" to get a stable recipe for smaller models. KV cache and forward pass are in int8, gradients and activations are in bf16, master weights and grad accumulation in fp32.
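For reference, here's a minimal PyTorch sketch of the MQA item above (not Character AI's code; shapes and names are my own): all query heads share a single K/V head, which shrinks the KV cache by a factor of n_heads and is a big win for serving speed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MQA(nn.Module):
    """Multi-query attention: many query heads, one shared K/V head."""
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        # Single shared K and V head: this is the MQA part.
        self.k_proj = nn.Linear(d_model, self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.head_dim, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        # K/V have one head; broadcast it across all query heads.
        k = self.k_proj(x).view(b, t, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_heads, -1, -1)
        v = self.v_proj(x).view(b, t, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_heads, -1, -1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(b, t, -1))
```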
The technical report of @Meituan_LongCat LongCat-Flash is crazy good and full of novelty.
The model is a 560B-parameter MoE with ~27B active, with an adaptive number of active parameters depending on the context thanks to the Zero-Computation experts.
1) New architecture
> Layers have 2 attention blocks and both a dense FFN and a MoE block, so you can overlap the 2 all-to-all comms. (Also it's only 28 layers, but you have to take into account the 2 attention blocks.)
> They add the zero-computation expert that tokens can route to and do nothing, kinda like a "sink" for easy tokens (sketched after this list).
> For load balancing, they have a DSv3-like aux-loss-free bias to set the average number of real/fake experts per token. They apply a decay schedule to this bias update. They also do loss balance control.
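A minimal sketch of the zero-computation expert idea, under my own assumptions about the routing details (slow loop version, not LongCat's implementation): the last few routing slots are identity "experts", so a token routed there is just passed through at zero FLOP cost.

```python
import torch
import torch.nn as nn

def moe_with_zero_experts(x: torch.Tensor, router: nn.Linear, experts: nn.ModuleList,
                          n_zero: int, top_k: int = 2) -> torch.Tensor:
    # x: (tokens, d); router maps d -> len(experts) + n_zero routing slots.
    n_real = len(experts)
    scores, idx = router(x).softmax(dim=-1).topk(top_k, dim=-1)
    out = torch.zeros_like(x)
    for k in range(top_k):
        for e in range(n_real + n_zero):
            mask = idx[:, k] == e
            if not mask.any():
                continue
            # Slots >= n_real are zero-computation experts: identity, no FFN is run,
            # so "easy" tokens routed there don't consume compute.
            y = experts[e](x[mask]) if e < n_real else x[mask]
            out[mask] = out[mask] + scores[mask, k].unsqueeze(-1) * y
    return out
```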
2) Scaling
> They made changes to MLA/MoE to get variance alignment at init. The gains are pretty impressive in Figure 5, but I don't know to what extent this has an impact later on.
> Model growth init is pretty cool: they first train a 2x smaller model and then, "when it's trained enough" (a bit unclear how many B tokens), they init the final model by just stacking the layers of the smaller model (see the sketch after this list).
> They used the paper from @_katieeverett, @Locchiu et al. to get hyperparameter transfer with SP instead of muP for the 2x smaller model, I guess.
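A tiny sketch of what the layer-stacking growth init could look like; the exact stacking order and any rescaling LongCat uses aren't in my notes, so treat this as an assumption.

```python
import copy
import torch.nn as nn

def grow_by_stacking(small_layers: nn.ModuleList) -> nn.ModuleList:
    """Init a 2x-deeper model from a trained smaller one by repeating its blocks."""
    big_layers = []
    for layer in small_layers:
        big_layers.append(copy.deepcopy(layer))
        big_layers.append(copy.deepcopy(layer))  # each trained block appears twice
    return nn.ModuleList(big_layers)
```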
3) Stability
> They track Gradient Norm Ratio and cosine similarity between experts to adjust the weight of the load balancing loss (they recommend Gradient Norm Ratio <0.1).
> To avoid large activations, they apply a z-loss to the hidden states, with a pretty small coefficient (another alternative to qk-clip/norm); see the sketch after this list.
> They set Adam epsilon to 1e-16 and show that you want it to be lower than the gradient RMS range.
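For the z-loss bullet: here's the standard logsumexp-squared penalty (as usually defined on output logits). I'm assuming LongCat applies the same form to the hidden states; the exact definition in the report may differ.

```python
import torch

def z_loss(z: torch.Tensor, coef: float = 1e-4) -> torch.Tensor:
    # z: (..., dim). Penalize the squared log-partition so values can't silently
    # drift to huge magnitudes; added to the main loss with a small coefficient.
    lse = torch.logsumexp(z.float(), dim=-1)
    return coef * (lse ** 2).mean()
```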
4) Others
> They train on 20T tokens for phase 1, "multiple T of tokens" for mid-training on STEM/code data (70% of the mixture), and 100B for long-context extension without YaRN (80B for 32k, 20B for 128k). The long-context documents represent 25% of the mixture (not sure if it's % of documents or tokens, which changes a lot here).
> Pre-training data pipeline is context extraction, quality filtering, dedup.
> Nice appendix where they compare the average top_k needed for different benchmarks (higher for MMLU at 8.32, lower for GSM8K at 7.46). They also compare token allocation in deep/shallow layers.
> They release two new benchmarks: Meeseeks (multi-turn IF) and VitaBench (real-world business scenarios).
> Lots of details in the infra/inference part, with info on speculative decoding acceptance, quantization, deployment, kernel optimization, comms overlapping, etc.
> List of the different relevant papers in the thread 🧵
Super excited to share SmolLM3, a new strong 3B model.
SmolLM3 is fully open, we share the recipe, the dataset, the training codebase and much more!
> Trained on 11T tokens on 384 H100s for 220k GPU hours
> Supports long context up to 128k thanks to NoPE and intra-document masking (masking sketched after this list)
> Hybrid mode for reasoning (post-training is SFT+DPO and some merging magic)
> Multilingual (focus on 6 languages but supports more)
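Quick sketch of the intra-document masking mentioned above: when several documents are packed into one sequence, each token only attends to earlier tokens of its own document. `doc_ids` is an assumed per-token document index, not SmolLM3's actual data format.

```python
import torch

def intra_document_mask(doc_ids: torch.Tensor) -> torch.Tensor:
    # doc_ids: (seq_len,) integer document index per token. Returns a
    # (seq_len, seq_len) bool mask, True where attention is allowed
    # (causal AND same document), usable as an attn_mask in SDPA.
    seq_len = doc_ids.shape[0]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=doc_ids.device))
    same_doc = doc_ids.unsqueeze(0) == doc_ids.unsqueeze(1)
    return causal & same_doc
```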
We plan to release a few more artifacts in the coming days, such as training logs, intermediate checkpoints (such as the mid-trained ckpt), and checkpoints from ablations we ran during training.
Time to update your RL paper with a new 3B baseline :)
[ Pre-training ]
> 36T of text tokens (instead of 18T previously). For reference 1 epoch of Meta's dataset is 30T of text AND other modalities.
> 3-stage pre-training: 1) 30T tokens at 4k context, 2) 5T of science/math/code and reasoning data (no info on ctx length, so maybe short CoT?), 3) 1T of context extension to 32k (no RULER/HELMET benchmark..)
> 8 KV heads instead of 2 or 4 in Qwen 2 <7B.
> No attention bias, and QK norm (per head, sketched below)
> Nice MoEs (with global batch load balancing ofc)
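Sketch of the per-head QK norm: RMS-normalize queries and keys along the head dimension before computing attention scores so the logits stay well-behaved. The real version normally also has a learned per-dim gain (RMSNorm); I'm omitting it here.

```python
import torch

def qk_norm(q: torch.Tensor, k: torch.Tensor, eps: float = 1e-6):
    # q, k: (batch, n_heads, seq, head_dim); normalize each head vector to unit RMS.
    q = q * torch.rsqrt(q.float().pow(2).mean(-1, keepdim=True) + eps).to(q.dtype)
    k = k * torch.rsqrt(k.float().pow(2).mean(-1, keepdim=True) + eps).to(k.dtype)
    return q, k
```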
[ Post-training ]
> Frontier models use RL with a cold start and this "thinking mode fusion"
> Smol models use (data, not logit) distillation.
I really like how they use their previous generation of models to extract PDF data and generate synthetic data for code and math!
Also, seems like this part from the model card posted earlier on r/LocalLLaMA didn't make it into the blog post.. even more excited for the blog post to see what these "optimization techniques" and scaling laws are!
Training:
- 671B MoE w/ 37B active params
- Uses MLA, compressing KV and Q to a lower-dim space to improve inference caching + memory during training (sketch after this list)
- Multi token prediction (MTP) (with depth=1), huge impact (see ablation), especially on GSM8K and HumanEval, I wonder why?
- FIM and next token prediction objective
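Very simplified sketch of the MLA KV compression: hidden states are down-projected to a small latent, that latent is what gets cached, and it's up-projected back to per-head K/V at attention time. The real MLA also compresses Q and carries RoPE through a separate decoupled key; all of that is omitted here.

```python
import torch.nn as nn

class MLAKVCompression(nn.Module):
    def __init__(self, d_model: int, kv_latent_dim: int, n_heads: int, head_dim: int):
        super().__init__()
        self.down_kv = nn.Linear(d_model, kv_latent_dim, bias=False)  # only this output is cached
        self.up_k = nn.Linear(kv_latent_dim, n_heads * head_dim, bias=False)
        self.up_v = nn.Linear(kv_latent_dim, n_heads * head_dim, bias=False)

    def forward(self, h):
        # h: (batch, seq, d_model) -> cache c_kv, reconstruct K/V when needed.
        c_kv = self.down_kv(h)
        return self.up_k(c_kv), self.up_v(c_kv)
```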
MoE:
- Fine-grained experts with auxiliary-loss-free load balancing (adding a bias to the routing scores; the bias update is sketched after this list) (with ablation). Bias update speed 0.001, then 0 for the last 500B tokens
- Auxiliary loss term to prevent imbalance within the same sequence (alpha=0.0001, very small impact). No token dropping
- Each token is routed to experts on at most M=4 different nodes to keep communication efficient
- NVLink (intra-node communication) is 3.2x faster than InfiniBand (inter-node communication), so we can have 3.2 experts per node and no extra communication.
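Sketch of the aux-loss-free balancing from the first MoE bullet: a per-expert bias is added to the routing scores for top-k selection only, and gets nudged up/down depending on whether the expert is under/over-loaded. The 0.001 above is this update speed; the exact bookkeeping here is my assumption.

```python
import torch

@torch.no_grad()
def update_routing_bias(bias: torch.Tensor, tokens_per_expert: torch.Tensor,
                        update_speed: float = 1e-3) -> torch.Tensor:
    # bias: (n_experts,), added to the routing scores only when picking top-k.
    # tokens_per_expert: how many tokens each expert received in the last batch.
    load = tokens_per_expert.float()
    # Overloaded experts get their bias pushed down, underloaded ones pushed up.
    bias += update_speed * torch.sign(load.mean() - load)
    return bias
```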
Learning rate schedule (similar to WSD, sketched after the bullets):
- 2K step warmup (0 -> 2.2e-4)
- 10T constant (2.2e-4 -> 2.2e-4)
- 4.3T cosine decay (2.2e-4 -> 2.2e-5)
- 333B constant (2.2e-5 -> 2.2e-5) (no smooth decay to 7.3e-6)
- 167B constant (7.3e-6 -> 7.3e-6) (no annealing to 0)
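The same schedule written out as a small function of tokens seen (warmup is given in steps above, so the token count for it is a placeholder assumption; everything else follows the bullets).

```python
import math

def lr_at(tokens: float, warmup_tokens: float = 10e9) -> float:
    # warmup_tokens is a placeholder: the report gives warmup as 2K steps, not tokens.
    T = 1e12
    if tokens < warmup_tokens:                  # warmup 0 -> 2.2e-4
        return 2.2e-4 * tokens / warmup_tokens
    if tokens < 10 * T:                         # long constant phase
        return 2.2e-4
    if tokens < 14.3 * T:                       # cosine decay 2.2e-4 -> 2.2e-5 over 4.3T
        frac = (tokens - 10 * T) / (4.3 * T)
        return 2.2e-5 + 0.5 * (2.2e-4 - 2.2e-5) * (1 + math.cos(math.pi * frac))
    if tokens < 14.633 * T:                     # 333B constant at 2.2e-5
        return 2.2e-5
    return 7.3e-6                               # final 167B at 7.3e-6 (no anneal to 0)
```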
Other hyperparams:
- AdamW (beta1=0.9, beta2=0.95)
- Gradient clipping norm = 1.0
- 4K context length (before context extension)
- 469B tokens of batch size scheduling (3072 -> 15360), 15360 * 4096 ≈ 63M tokens per batch (holy shit)
- MTP loss weight = 0.3 for 10T then 0.1
Long context extension:
- 4K context window, then 32K, then 128K using Yarn.
- same lr as pre-training
- Only NIAH test, no HELMET or RULER benchmark :(
Infra:
- ONLY 2048 H800 TO BUILD A COMPETITOR TO GPT4/SONNET 3.5 IS CRAZY
- No TP, only PP, EP, DP, and zero1
- Custom bidirectional PP schedule (store 2 copies of parameters but increase overlap)
- Fixed 20 SMs reserved only for communication (very specific to their cluster)
- Recomputation of RMSNorm and up-projection layers
- EMA of the weights on CPU to have a good estimate of the model without needing to decay the LR
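Sketch of that CPU-side EMA trick: keep an exponential moving average of the weights on CPU so you can evaluate a "decayed-like" version of the model at any time without touching the actual LR schedule or GPU memory. The decay value is an assumption.

```python
import torch

@torch.no_grad()
def update_cpu_ema(model: torch.nn.Module, ema_state: dict, decay: float = 0.999) -> None:
    # ema_state maps parameter name -> fp32 CPU tensor holding the running average.
    for name, param in model.named_parameters():
        cpu_param = param.detach().float().cpu()
        if name not in ema_state:
            ema_state[name] = cpu_param.clone()
        else:
            ema_state[name].mul_(decay).add_(cpu_param, alpha=1.0 - decay)
```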
FP8:
- Mixed precision training (see picture); computations kept in higher precision: embedding, output head, gating module for MoE, normalization, attention operators
- Mixed precision storage; kept in higher precision: master copy of the weights, weight gradients, optimizer states (bf16 for the latter instead of fp32)
- Fine-grained quantization enabling the E4M3 FP8 format (see picture, a lot more info in the paper; sketched after this list)
- Few FP8 communications, mainly BF16
- E5M6 for activations after the attention operator (because they are used in the backward pass)
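Rough sketch of the fine-grained quantization idea: every small tile of a tensor gets its own scale so E4M3's limited range is used well (the paper uses 1x128 tiles for activations and 128x128 blocks for weights, if I remember correctly). Requires a PyTorch build with float8 dtypes; details here are assumptions.

```python
import torch

def quantize_e4m3_tilewise(x: torch.Tensor, tile: int = 128):
    # x: (rows, cols) with cols divisible by `tile`. Each 1 x tile slice gets its
    # own scale so the E4M3 range (max ~448) is fully used per tile.
    xt = x.float().view(x.shape[0], -1, tile)
    scales = xt.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0
    x_fp8 = (xt / scales).to(torch.float8_e4m3fn).view(x.shape)
    return x_fp8, scales.squeeze(-1)  # scales stay in higher precision for dequant
```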
SFT:
- Synthetic data from R1 (an internal one, not the public one)
- Used RL to finetune the R1 expert model on both (problem, answer) and (system prompt, problem, answer) data
- The R1 expert model then generates (w/ high temp) more comprehensive reasoning data that can be used for SFT
- Using both rule-based and model-based reward models to select the data
- Using sample masking, with lr 5e-6 -> 1e-6
RLHF:
- They used "Group Relative Policy Optimization" for final alignment