混合精度训练
date
Apr 17, 2025
slug
mixed_precision_training
tags
Mixed Precision
Transformer
Training
AI Infra
summary
随着模型规模的增大, 以及GPU低精度算力的提升, 混合精度训练存在诸多优势, 成为业界标配; 常见的有FP32/BF16混合精度, 最近一年来, FP8混合精度也逐渐流行开来
type
Post
标签
状态
完成
描述
重要性
🌟🌟🌟
关键字
混合精度
大模型
训练
status
Published
🥰总结

⁉️问题
FP32/FP16混合精度训练中模型状态部分到底是16还是20?
模型状态显存占用包括权重,梯度,优化器状态,其中权重2,梯度2,优化器一阶动量4,二阶4,还需要保存FP32权重保证累积精度,所以一共16
混合精度为啥能节省显存(需要FP32权重备份)? 节省的是哪一部分?
使用混合精度,静态显存没有节省,但是动态显存节省约一半,另外半精度的gemm算力也比全精度的高一倍
为什么使用混合精度, 不直接使用低精度训练(如FP16, FP8)?
一方面低精度数据类型的累积精度不够,存在数值溢出问题,容易产生inf或nan;另一方面会产生舍入误差,无法正常回传梯度;所以一般使用半精度做乘法,全精度做加法
FP16混合精度训练会出现哪些问题? 以及该如何解决?
两个问题:一是表示范围小,会出现精度溢出,解决方案是loss做scale up,将数值范围转化到FP16表示范围,权重更新时再scale down,恢复数值;二是会出现舍入误差,所以需要用全精度进行累加
FP8混合精度训练会出现哪些问题? 如何解决?
FP8遇到的问题和FP16的一样,只是更加严重; 第一个问题不能再使用per-tensor的量化, deepseek使用了更细粒度的量化, 输入使用了per-tile(一维),权重使用per-block(二维),这样的话就只使用e4m3的格式;第二个问题就是使用全精度累加 包括梯度也使用FP32格式,另外tensor core的累积精度不够 每隔一定间隔就将累积结果拷贝到cuda corw的fp32寄存器进行累加
🧐内容
大模型时代之前, 训练神经网络模型默认使用的数据类型为单精度FP32。近年来,为了加快训练、减少显存占用,业界提出了精度无损的混合精度训练方法。大模型时代前期, 通常是FP32/FP16的混合精度训练, 到2025年, DeepSeek率先使用了FP32/FP8的混合精度训练进一步优化训练成本
FP32/FP16混合精度训练
为什么使用混合精度训练?
使用FP16训练神经网络,相对比使用FP32带来的优点有:
- 减少内存占用:FP16的位宽是FP32的一半,因此显存占用也会减小,节省下来的显存可以放更大的网络模型或者使用更多的数据进行训练。
- 加快通讯效率:针对分布式训练,特别是在大模型训练的过程中,通讯的开销制约了网络模型训练的整体性能,数据位宽减小导致数据量减小意味着可以降低通信耗时。
- 张量核心的普及:硬件的发展同样也推动着模型计算的加速,随着Nvidia张量核心(Tensor Core)的普及,低精度数值算力相对更高, 计算效率也更高
带来的问题
数值溢出(overflow/underflow)
由于FP16的动态范围()比FP32的动态范围()要狭窄很多,因此在计算过程中很容易出现上溢出(Overflow,)和下溢(Underflow,
)的错误,溢出之后就会出现“Nan”的问题。在深度学习中,由于激活函数的的梯度往往要比权重梯度小,更易出现下溢出的情况。

舍入错误(Rounding Error)
舍入误差指的是当梯度过小,小于当前区间内的最小间隔时,该次梯度更新可能会失败,用一张图清晰地表示:

解决办法
为了想让深度学习训练可以使用FP16的好处,又要避免精度溢出和舍入误差。于是可以通过FP16和FP32的混合精度训练(Mixed-Precision),混合精度训练过程中可以引入权重备份(Weight Backup)、损失放大(Loss Scaling)、精度累加(Precision Accumulated)三种相关的技术。
权重备份(Weight Backup)

从图中可以了解,在计算过程中所产生的权重weights,激活activations,梯度gradients等均使用 FP16 来进行存储和计算,其中权重使用FP32额外进行备份。由于在更新权重公式为:
深度模型中,的参数值可能会非常小,利用FP16来进行相加的话,则很可能会出现舍入误差问题,导致更新无效。因此通过将权重weights拷贝成FP32格式,并且确保整个更新过程是在 fp32 格式下进行的。即:
权重用FP32格式备份一次,那岂不是使得内存占用反而更高了呢?不是,额外拷贝一份weight增加的显存和FP16权重, FP16梯度降低的显存相抵消。 但是,训练过程中的动态显存:中间变量值和中间激活占用的显存也降低了一半。实际训练中, 当batch size比较大的情况下, 这部分显存占用会非常高, 以以GPT3-175B为例 为例, bs=64时, 这部分显存约为静态显存的10x. 因此混合精度下, 可以大幅节省显存.
损失缩放(Loss Scaling)
如图所示,如果仅仅使用FP32训练,模型收敛得比较好,但是如果用了混合精度训练,会存在网络模型无法收敛的情况。原因是梯度的值太小,使用FP16表示会造成了数据下溢出(Underflow)的问题,导致模型不收敛,如图中灰色的部分。于是需要引入损失缩放(Loss Scaling)技术。

下图展示了 SSD 模型在训练过程中,激活函数梯度的分布情况:可以看到,有67%的梯度小于,如果用 fp16 来表示,则这些梯度都会变成0。

为了解决梯度过小数据下溢的问题,对前向计算出来的Loss值进行放大操作,也就是把FP32的参数乘以某一个因子系数后,把可能溢出的小数位数据往前移,平移到FP16能表示的数据范围内。根据链式求导法则,放大Loss后会作用在反向传播的每一层梯度,这样比在每一层梯度上进行放大更加高效。

损失放大是需要结合混合精度实现的,其主要的主要思路是:
- Scale up阶段,网络模型前向计算后,将得到的损失变化值dLoss增大倍。
- Scale down阶段,反向传播后,将权重梯度缩倍,恢复FP32值进行存储。
动态损失缩放(Dynamic Loss Scaling):上面提到的损失缩放都是使用一个默认值对损失值进行缩放,为了充分利用FP16的动态范围,可以更好地缓解舍入误差,尽量使用比较大的放大倍数。总结动态损失缩放算法,就是每当梯度溢出时候减少损失缩放规模,并且间歇性地尝试增加损失规模,从而实现在不引起溢出的情况下使用最高损失缩放因子,更好地恢复精度。
动态损失缩放的算法如下:
- 首先从较高的缩放因子开始(如), 然后进行训练迭代, 同时检查数值是否会溢出(Infs/Nans)
- 如果没有梯度溢出,则不进行缩放,继续进行迭代;如果检测到梯度溢出,则缩放因子会减半,重新确认梯度更新情况,直到数值不产生溢出;
- 在训练的后期,loss已经趋近收敛稳定,梯度更新的幅度往往小了,这个时候可以允许更高的损失缩放因子来再次防止数据下溢。
- 一定步数后(N=200)会尝试将损失缩放增加F倍数,然后执行步骤2检查是否溢出。
精度累加(Precision Accumulated)
在混合精度的模型训练过程中,使用FP16进行矩阵乘法运算,利用FP32来进行矩阵乘法中间的累加(accumulated),然后再将FP32的值转化为FP16进行存储。简单而言,就是利用FP16进行矩阵相乘,利用FP32来进行加法计算弥补丢失的精度。 这样可以有效减少计算过程中的舍入误差,尽量减缓精度损失的问题。
例如在Nvidia Volta 结构中带有Tensor Core,可以利用FP16混合精度来进行加速,还能保持精度。Tensor Core主要用于实现FP16的矩阵相乘,在利用FP16或者FP32进行累加和存储。在累加阶段能够使用FP32大幅减少混合精度训练的精度损失。

混合精度训练策略(Automatic Mixed Precision,AMP)
以NVIDIA的APEX混合精度库为例,里面提供了4种策略,分别是默认使用FP32进行训练的O0,只优化前向计算部分O1、除梯度更新部分以外都使用混合精度的O2和使用FP16进行训练的O3。具体如图所示。

这里面比较有意思的是O1和O2策略。
O1策略中,会根据实际Tensor和Ops之间的关系建立黑白名单来使用FP16。例如GEMM和CNN卷积操作对于FP16操作特别友好的计算,会把输入的数据和权重转换成FP16进行运算,而softmax、batchnorm等标量和向量在FP32操作好的计算,则是继续使用FP32进行运算,另外还提供了动态损失缩放(dynamic loss scaling)。
而O2策略中,模型权重参数会转化为FP16,输入的网络模型参数也转换为FP16,Batchnorms使用FP32,另外模型权重文件复制一份FP32用于跟优化器更新梯度保持一致都是FP32,另外还提供动态损失缩放(dynamic loss scaling)。使用了权重备份来减少舍入误差和使用损失缩放来避免数据溢出。
当然上面提供的策略是跟硬件有关系,并不是所有的AI加速芯片都使用,这时候针对自研的AI芯片,需要找到适合得到混合精度策略。
FP8混合精度训练
FP8 Tensor Scaling
要将 FP8 应用于 LLM 的训练和推理,一个关键问题是如何克服表示范围和精度下降相关的挑战。其中的一个关键技术就是张量缩放(Tensor Scaling)

早期在 V100 和 A100 GPU 上的 FP16 混合精度训练会广泛采用全局损失缩放(Loss Scaling) ,很适合一些中小模型的训练。然而,在处理一些超大模型或复杂任务时(例如 DALL-E 等模型),Global Loss Scaling 仍然会遇到严重的下溢问题(Underflow)。因此,越来越多采用 Block-wise 和 Layer-wise 的 Gradient 缩放。在 FP8 的 Per Tensor Scaling 技术中,有两种常见方案:Just-in-time Scaling 和 Delayed Scaling。
- Just-in-time Scaling(即时 Scaling):要做 Tensor 的 Scaling,首先需要计算 Tensor 绝对值的最大值(amax),然后得到 Scaling 值,再对 Tensor 进行 Scaling。中间的临时存储都是高精度的,可能要多次读写,计算和访存开销很大;此外,如果是分布式场景,比如梯度的 AllReduce,还要涉及分布式通信开销。整体来说,额外引入的开销会大幅降低 FP8 带来的收益。
- Delayed Scaling(延迟 Scaling):其核心思路是使用额外的 Tensor 来存储之前的 amax 历史,然后根据历史最大值决定当前的最大值。
TE里面使用Delayed Scaling方案:
FP8硬件支持
如下图所示为常见 NVIDIA GPU 对不同数据类型的支持情况,可以看出:
H100/H800 Tensor Core 支持 FP8,但是 CUDA Core 不支持。也就是说,可以使用 FP8 的矩阵乘法(Tensor Core),但是不支持矩阵加法(CUDA Core)。

如下图所示,从 Hopper 架构开始,新的 Tensor Core 支持输入两个 FP8 的矩阵,然后以 FP8 格式相乘,并以 FP32 或 FP16 格式进行累加。Cublas 中也提供了相关 API 可以把后续的类型转换融合进去,就可以避免 FP32 或 FP16 的中间结果写回 Global Memory。

NVIDIA 最新发布的 Blackwell GPU 的 Tensor Core 相比 Hopper 进一步添加了对 FP6 和 FP4 的支持,而 Blackwell GPU 的 CUDA Core 不再支持 INT8。此外,从 Hopper 开始都不再支持 INT4。
FP8软件支持情况
- Pytorch: 从 2.1 版本开始引入 FP8 格式支持,具体来说增加了 “torch.float8_e4m3fn” 和 “torch.float8_e5m2” 两种数据类型(torch.Tensor — PyTorch 2.3 documentation)。但是现在的 2.3 版本也依然在很早期的阶段,很多计算还不支持 FP8
- NVIDIA 的 Transformer Engine 库添加了对 FP8 的支持
- FlashAttention3 也计划开放对 FP8 的支持
- vLLM、TensorRT-LLM 以及 Megatron-LM 都在支持 FP8.
DeepSeek的FP8混合精度训练
训练流程

训练细节见上图, 大多数的计算密集型操作(gemm)使用FP8精度计算, FP32精度进行累加, 包括 , , ; 这样计算密集型任务理论上运算速度提升两倍; 而一些低成本的算子使用更高的精度, 包括: embedding层, 输出head, MoE的gating, 归一化层, attention层;
需要注意的是右上角的Weight Gradient, 这部分其实算临时变量, 不会存储, 更新完就会释放
细粒度量化

其实就是相对于原来的per-tensor量化, 对输入用了更细粒度的per-tile量化(1*128), 对权重用了per-block量化(128*128);
另外, 相比于之前的研究:前向使用E4M3, 后向使用E5M2保证足够的动态范围, DeepSeek仅使用了E4M3格式, 主要原因就是细粒度量化策略降低了有效动态范围的影响
提高累积精度
注意左图上是在tensor core里做的矩阵乘, 然后在cuda core中做的scaling; 原因是在H800 GPU上, deepseek观察到FP8 gemm的累计精度大约只有14位, 远远小于FP32的累积精度; 严重限制了训练精度
做法如右图, 在Tensor Core上进行MMA, 中间结果使用有限位宽进行累积. 一旦达到的间隔, 累积结果就会拷贝到CUDA Core的FP32寄存器中, 进行全精度的FP32累积.
在线量化
为了确保精确的缩放因子并简化框架,我们在线为tile/block计算最大绝对值。进而推导出缩放因子,然后将激活或权重在线量化为 FP8 格式。
显存占用
需要注意的是, 优化器还是会保留FP32的权重用于备份, 但优化器状态本身却是BF16类型
然后算一下模型状态显存占用(暂不考虑中间激活):
权重 | 梯度 | 优化器状态 |
中间激活大部分使用FP8存储, 部分使用高精度