发布时间:2025-12-09 16:15:50 浏览次数:5
Megatron-LM是一个基于PyTorch的框架,用于训练基于Transformer架构的巨型语言模型。它实现了高效的大规模语言模型训练,主要通过以下几种方式:
超大规模语言模型训练 = GPU + PyTorch + Megatron-LM + DeepSpeed
上一步纵切并行计算后,得到 Y = [ Y 1 , Y 2 ] Y = [Y1, Y2] Y=[Y1,Y2],
所以选择对B纵切,这样整个MLP前向传播仅需一次all-reduce,在向后传递中执行一次all-reduce操作。
原mlp 的计算如下:
[ b , s , h ] × [ h , 4 h ] = [ b , s , 4 h ] [ b , s , 4 h ] × [ 4 h , h ] = [ b , s , h ] [b, s, h] \times [h, 4h] = [b, s, 4h] \\ [b, s, 4h] \times [4h, h] = [b, s, h] [b,s,h]×[h,4h]=[b,s,4h][b,s,4h]×[4h,h]=[b,s,h]
将 [h, 4h] 切分成 [h, 4h/p],把 [4h, h] 切分成 [4h/p, h],对[b, s, h]复制到p个卡上。计算步骤变为:
[ b , s , h ] × [ h , 4 h / p ] = [ b , s , 4 h / p ] [ b , s , 4 h / p ] × [ 4 h / p , h ] = [ b , s , h ] [b, s, h] \times [h, 4h/p] = [b, s, 4h/p]\\ [b, s, 4h/p] \times [4h/p, h] = [b, s, h] [b,s,h]×[h,4h/p]=[b,s,4h/p][b,s,4h/p]×[4h/p,h]=[b,s,h]
此有 p 个 [b, s, h],需要做一次 allreduce 得到最终的 [b, s, h]
(1)溢出错误
(2)因精度不足而带来舍入错误:当a与b均用FP16表示时,a=1与b=0.0001相加时,其结果是错误的,因为a/b=10000> 2 11 2^{11} 211,此时b会因尾数右移而变为0,导致结果出错。
当进行浮点数进行加减运算时,首先要使两个数的阶码相同,即小数点的位置对齐,这个过程称为对阶。在对阶时规定使小阶向大阶看齐,通过小阶的尾数算术右移来改变阶码。对阶过程中,由于FP16只有10位的尾数,当小阶的尾数右移超过11 位时,会导致该数变为0,即以FP16表示的数,如果当大数与小数的比率为大于 2 11 2^{11} 211 时,加减法运算结果会出错。而FP32有23位尾数,可表达的精度范围更广,可有效地避免该问题。
梯度更新
问题:梯度过小产生下溢导致梯度为0;学习率过小,而导致梯度与学习率的乘积过小产生下溢;权重相较于其更新值过大( param/update> 2 11 2^{11} 211),会因FP16精度不足,在对阶过程中使更新值变为0,导致不会更新权重。
所以需要保留FP32的权重,每次迭代时,制作这些权重的FP16副本并使用它们用于前向计算和反向计算,更新时将梯度再转换为FP32并用于更新FP32的权重。
混合精度训练过程
为保证梯度落入半精度可表示范围内一个简单有效的方法将训练损失乘以比例因子,根据梯度的链式法则使得所有梯度也等比例放大。当然在权重更新之前,需要以相同的比例因子缩小梯度,再更新到权重上。即loss scale
注意:目前使用 bf16,有更大的表示范围,不需要 loss scale 了
通过序列并行sequence parallelism 和选择性激活重算 selective activation recomputation。结合张量并行,这些技术几乎消除了重新计算激活的必要性。我们在规模达到一万亿个参数的语言模型上进行了评估,并展示了我们的方法可以将激活内存降低5倍,同时将激活重算的执行时间开销降低超过90%。
激活值所占显存减少的原因,megatron-LM中的layer-norms 和 dropouts是要将输入复制到各个卡上,重复运算,也就是说激活值复制了p份,而使用序列并行化,可以仅仅保留一份激活值,而且是分布在各个gpu上。
把Transformer族模型的所有activation消耗算了一遍,然后发现在Transformer核里有一些操作是产生的激活值又大,但是计算量又小的,这些激活值不保存,反向传播时重新计算。其他的激活值存下来,以节省重计算量。
比如下图的红框区域,具有很大的输入大小和很大的激活值,但每个输入元素的浮点运算(FLOPs)数量非常低。这部分需要重新计算而不是存下来
在GPU的显存没占满的时候,可以不做checkpointing,这么一来重计算所带来的额外计算代价会进一步减小。
tensor parallel + sequence parallel + selective activation recomputation非常节省显存
https://zhuanlan.zhihu.com/p/614166245
https://zhuanlan.zhihu.com/p/498422407
https://zhuanlan.zhihu.com/p/343570325
https://zhuanlan.zhihu.com/p/513571706
https://zhuanlan.zhihu.com/p/617087561
https://zhuanlan.zhihu.com/p/68692579
https://arxiv.org/pdf/1909.08053.pdf
https://arxiv.org/pdf/2205.05198.pdf
https://zhuanlan.zhihu.com/p/628820408