大模型显存占用及优化

date
May 24, 2025
slug
transformer_memory_cost_and_optimization
tags
Memory Optimization
DeepSpeed
Training
Inference
Transformer
AI Infra
summary
随着大模型参数量的不断增长, 显存成为瓶颈, 本文分析了显存占用的类型及大小, 同时讲述了业界提出的多种优化策略
type
Post
标签
状态
完成
描述
大模型时代的训练单卡装不下, 使用多种优化策略大幅降低每张卡的显存用量
重要性
🌟🌟🌟
关键字
显存
大模型
训练
推理
status
Published

🥰总结

模型训练过程中已用显存随时间变化
模型训练过程中已用显存随时间变化
  1. 模型参数量为:
  1. 显存占用主要包含四部分: 模型权重, 梯度, 中间激活, 优化器状态
  1. 对于混合精度训练, 显存占用为:
    1. 权重
      梯度
      优化器状态
      中间激活
      )
  1. 权重, 梯度, 优化器状态和输入大小(b, s)无关, 但中间激活正相关
  1. 一般使用重计算降低中间激活的显存占用
  1. 大模型时代的训练任务单卡是装不下的, 所以会使用多种并行方式降低显存占用, 但是也会导致引入额外的通信消耗
  1. TP, PP, ZeRO都可以降低显存占用

⁉️问题

显存占用有哪些,分别占用多少?
一共有四部分, 分别是权重, 梯度, 中间激活, 优化器状态, 混合精度训练下,分别占
模型并行和非模型并行占用大小的区别?
非模型并行就需要单张卡放下训练所需的所有显存, 大模型时代根本放不下, 那么就会采用分布式训练方法, 使用各种并行策略, 从而降低每张卡上的显存占用, 但是也会引入卡间通信和机间通信
3D并行下分别节省多少显存?
 

🧐内容

本文分析针对decoder-only的Transformer架构

缩写

缩写
符号
描述
b
batch size
批大小
s
sequence length
token的长度
v
vocab size
词表大小
h
hidden size
隐层大小,
d
head dim
head的维度
a
number heads
MHA中head的数量
l
number layers
层数

模型参数量

Transformer模型一般由l个相同的层组成, 每个层包括两部分: self-attention和MLP

self-attention

self-attention块的模型参数有Q, K, V的权重矩阵 , , 和偏置(bias), 输出权重矩阵和偏置, 4个权重矩阵的形状为[h, h], 4个偏置的形状为[h] , 因此self-attention块的参数量为:

MLP

MLP由两个线性层组成, 一般地, 第一个线性层先将维度从h映射到4h, 第二个线性层再将维度从4h映射到h. 因此第一个线性层的权重矩阵的形状是[h, 4h], 偏置为[4h]; 第二个权重形状为[4h, h], 偏置为[h], 因此MLP块的参数量为:

Layer Norm

另外self-attention块和MLP块各有一个layer norm, 包含了两个可训练的模型参数: 缩放参数 和 平移参数, 形状都是[h], 两个LN, 一共4h

embedding

词嵌入矩阵的参数量为vh

位置编码

如果采用可训练的位置编码, 会有一些可训练的模型参数, 但数量较少. 如果采用相对位置编码, 如RoPE和ALiBi, 则不包含可训练的模型参数, 忽略这一部分

汇总

总参数量为:
当隐藏维度h较大时, 可以忽略一次项, 模型参数近似为

LLAMA模型参数量估计

实际参数量
隐藏维度h
层数l
6.7B
4096
32
6,442,450,944
13.0B
5120
40
12,582,912,000
32.5B
6656
60
31,897,681,920
65.2B
8192
80
64,424,509,440

推理显存占用

在神经网络的推理阶段,没有优化器状态和梯度,也不需要保存中间激活。模型推理阶段占用的显存要远小于训练阶段
模型推理阶段,占用显存的大头主要是模型参数,如果使用float16来进行推理,推理阶段模型参数占用的显存大概是2 bytes
如果使用KV cache来加速推理过程,KV cache也需要占用显存
此外,输入数据也需要放到GPU上,还有一些中间结果(推理过程中的中间结果用完会尽快释放掉),不过这部分占用的显存是很小的,可以忽略。

KV cache显存占用

假设输入序列的长度为 s ,输出序列的长度为 n ,以float16来保存KV cache,那么KV cache的峰值显存占用大小为。这里第一个2表示K/V cache,第二个2表示float16占2个bytes。

训练显存计算

占用显存主要分为四部分:模型参数、前向计算过程中产生的中间激活、后向传递计算得到的梯度、优化器状态。 大模型训练通常采用Adam/AdamW作为优化器, 并使用
混合精度训练
加速训练.

模型参数+梯度+优化器状态

在一次训练迭代中, 每个可训练的模型参数都会对应产生1个梯度, 2个优化器状态(Adam优化器梯度的一阶动量和二阶动量).
在混合精度训练中, 假设模型的参数量是,使用Adam作为优化器进行混合精度训练。由于模型的参数和梯度使用float16,所以显存消耗分别为。Adam会维护一个float32的模型副本,消耗显存。Adam优化器本身会为模型的每个参数维护两个float32的辅助变量,所以显存消耗占用为。因此:

中间激活

这里的激活(activations)指的是:前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量.
假设激活值以float16或bfloat16数据格式保存,每个元素占了2个bytes。唯一例外的是,dropout操作的mask矩阵,每个元素只占1个bytes。在下面的分析中,单位是bytes,而不是元素个数。
先分析self-attention块的中间激活。self-attention块的计算公式如下:
  1. 对于  ,需要保存它们共同的输入  ,这就是中间激活。输入  的形状为 [b,s,h] ,元素个数为 bsh,占用显存大小为 。
  1. 对于 矩阵乘法,需要保存中间激活 ,两个张量的形状都是 [b,s,h] ,占用显存大小合计为  。
  1. 对于函数,需要保存函数的输入 , 占用显存大小为  。这里的表示注意力头数。(的形状为[b, a, s, d], 的形状为[b, a, d, s], 的形状为[b, a, s, s])
  1. 计算完 函数后,会进行dropout操作。需要保存一个mask矩阵(每个元素占用一个字节),mask矩阵的形状与  相同,占用显存大小为 。
  1. 计算在 V 上的attention,即 score⋅V ,需要保存 score ,大小为;以及 V ,大小为  。二者占用显存大小合计为  。
  1. 计算输出映射以及一个dropout操作。输入映射需要保存其输入,大小为  ;dropout需要保存mask矩阵,大小为  。二者占用显存大小合计为 
将上述激活相加得到,self-attention块的激活占用显存大小为
接下来看MLP块的中间激活。MLP块的计算公式如下
  1. 第一个线性层需要保存其输入,占用显存大小为  。
  1. 激活函数需要保存其输入,占用显存大小为  。
  1. 第二个线性层需要保存其输入,占用显存大小为 。
  1. 最后有一个dropout操作,需要保存mask矩阵,占用显存大小为 。
对于MLP块,需要保存的中间激活值为  。
 
另外,self-attention块和MLP块分别对应了一个layer normalization。每个layer norm需要保存其输入,大小为  。2个layer norm需要保存的中间激活为 。
对于 l 层transformer模型,还有embedding层、最后的输出层。embedding层不需要中间激活。总的而言,当隐藏维度 h 比较大,层数 l 较深时,最后的输出层的中间激活是很少的,可以忽略。
因此,对于层transformer模型,中间激活占用的显存大小可以近似为

对比中间激活与模型参数的显存大小

在一次训练迭代中,模型参数, 梯度, 优化器状态占用的显存与输入数据的大小(也就是bs)无关, 只与模型参数量和参数数据类型有关。
中间激活值与输入数据的大小 是成正相关的
当我们训练神经网络遇到显存不足OOM(Out Of Memory)问题时,通常会尝试减小批次大小来避免显存不足的问题,这种方式减少的其实是中间激活占用的显存,而不是模型参数、梯度和优化器的显存。

以GPT3-175B为例

参数量
序列长s
数据类型
层数
隐藏维度h
注意力头数a
175B
2048
FP16
96
12288
96
那么显存峰值如下:
批大小b
weights(GB)
gradients (GB)
activation(GB)
optimizer_states(GB)
1
1050
1050
275
1400
64
1050
1050
17600
1400
128
1050
1050
35300
1400

以bloom-7b为例

参数量
序列长s
数据类型
层数
隐藏维度h
注意力头数a
词表v
7B
2048
FP16
30
4096
32
250880
显存峰值:
batch_size
dtype
parameters(GB)
gradients (GB)
activation(GB)
optimizer_states(GB)
total(GB)
1
fp32
26
26
60
53
167
fp16
13
13
30
79
137
可以看到随着批大小b的增大,中间激活占用的显存远远超过了模型参数显存。通常会采用激活重计算技术来减少中间激活,理论上可以将中间激活显存从减少到,代价是增加了一次额外前向计算的时间,本质上是“时间换空间”。

训练显存优化

常见的降低显存的方法: TP/SP/PP/Zero/重计算
notion image

TP(tensor parallel)

类型
dtype
占用量
描述
占用周期
parameters
FP32
参数
forward+backward
FP16
gradients
FP32
梯度
forward+backward
FP16
activation
FP32
激活值
forward
FP16
optimizer states
FP32
优化器状态
forward+backward

以bloom-7b为例

假设TP=8, 显存峰值如下:
batch_size
dtype
parameters(GB)
gradients (GB)
activation(GB)
optimizer_states(GB)
total(GB)
1
fp32
3.3
3.3
12
6.5
25.1
fp16
1.6
1.6
5.4
9.9
18.5

TP+PP(pipeline parallel)

类型
dtype
占用量
描述
占用周期
parameters
FP32
参数
forward+backward
FP16
gradients
FP32
梯度
forward+backward
FP16
activation
FP32
激活值
forward
FP16
optimizer states
FP32
优化器状态
forward+backward
FP16

以bloom-7b为例

假设PP=4, 显存峰值如下:
batch_size
dtype
parameters(GB)
gradients (GB)
activation(GB)
optimizer_states(GB)
total(GB)
1
fp32
1.2
1.2
10.8
2.4
15.6
fp16
0.6
0.6
5.4
3.5
10.1

SP

在TP里面, 激活值有一部分是没有被切分的(Self-Attention的输入, MLP的输入, LN, dropout),这部分在序列轴维度s上是独立的, 因此可以使用SP. 然后这一部分可以切分为TP份
类型
dtype
占用量
描述
占用周期
parameters
FP32
参数
forward+backward
FP16
gradients
FP32
梯度
forward+backward
FP16
activation
FP32
激活值
forward
FP16
optimizer states
FP32
优化器状态
forward+backward
FP16

DeepSpeed ZeRO Optimization

notion image
 

3D并行+ZeRO

重计算

重计算可以简单的理解为在模型反向的时候重新计算所需的中间激活用于梯度计算,可以适当减少中间激活存储空间,同时也会带来一定的计算开销。
通常算子内部的重计算我们称为Local Recompute,以算子为粒度的整网/子网重计算策略我们称为Gradient Checkpointing。Local Recompute最典型的代表就是Flash Attention,通过存储统计量(expsum和xmax)在反向的时候实时计算出Attention Matrix用于反向计算。Gradient Checkpointing可以参考pytorch的checkpoint方法,变种基本雷同。

计算过程

baseline
baseline
重计算
重计算
 

优化效果

megatron3论文
megatron3论文

显存压缩

在深度学习训练过程中,中间结果(如激活值和梯度信息)虽然仅在一次前向传播和一次反向传播中使用,但往往占用大量内存。考虑到两次使用之间存在明显的时间间隔,可以在第一次使用后对数据进行压缩(Compression),待后续需要时再解压缩,从而有效降低内存占用。
压缩技术主要应用于两个场景:
  • 激活值压缩: 前向传播后对激活值进行压缩,反向传播前解压缩。这对深层神经网络尤为重要,因为激活值通常占用大量内存。
  • 梯度压缩: 在反向传播计算梯度后、梯度同步前对梯度进行压缩,减少跨 GPU 通信的数据量,从而提高分布式训练效率。
压缩技术可以分为两类:
  1. 无损压缩(Lossless Compression):
    1. 采用如 Huffman 编码或 Lempel-Ziv 算法等方法,确保解压缩后的数据与原始数据完全一致。但由于压缩率较低,其内存节省效果有限。
  1. 有损压缩(Lossy Compression):
    1. 使用如 JPEG 或 MPEG 等算法,在允许一定数据损失的前提下获得更高的压缩率。这种方法能显著降低内存占用,但可能对模型精度和收敛性产生一定影响。
Gist(Jain et al. 2018)是一种用于激活值压缩的内存优化技术,其核心在于利用数据编码策略压缩中间结果,主要包含两种编码方案:
  • 层特定无损编码(Layer-Specific Lossless Encoding):
    • 针对特定层结构(例如 ReLU-Pool 与 ReLU-Conv),设计专门的无损编码方案:
    • 对于 ReLU-Pool 层,可采用二值化编码;
    • 对于 ReLU-Conv 层,则使用稀疏存储与稠密计算编码。
  • 激进有损编码(Aggressive Lossy Encoding):
    • 采用 延迟精度降低(Delayed Precision Reduction, DPR) 技术。DPR 的核心思想是:激活值在前向传播时需保持高精度,而在反向传播时可容忍较低精度。因此,在前向传播后将激活值压缩到较低精度,反向传播前再解压至高精度。

内存高效优化器

传统优化器(如 Adam、SGD with Momentum)在训练过程中需要为每个模型参数维护大量状态数据(例如 momentum 和 variance),其内存占用往往与模型参数量相当甚至更高。例如,以 Adam 优化器(Kingma et al. 2014)为例,每个参数需要存储一阶矩和二阶矩,与参数本身及其梯度加起来,整个训练过程大约需要 4 倍于模型权重的内存,这对大型模型训练构成了严峻挑战。
为降低内存消耗,内存高效优化器主要通过以下策略进行设计:
  • 减少状态变量数量: 只保存必要的统计信息,而非完整矩阵;
  • 降低状态变量精度: 采用 FP16 或 bfloat16 存储;
  • 共享状态变量: 在多个参数间共享部分状态信息。

Adafactor

Adafactor(Shazeer et al. 2018) 是一种内存高效的自适应学习率优化器。与 Adam 不同,Adafactor 不存储完整的二阶矩估计矩阵,而是只存储两个向量(行、列统计)替代完整的二阶矩矩阵,显著降低了内存占用,特别适用于参数矩阵具有低秩结构的场景。

SM3

SM3(Sparse Momentum for Massive Models)(Anil et al. 2019) 通过稀疏更新和状态共享,提供了一种同样内存高效的自适应优化方案。
  • 稀疏 Momentum: 只对梯度非零的参数更新 Momentum,从而减少计算和存储开销;
  • 状态共享: 在一定程度上允许不同参数共享状态变量,进一步降低内存消耗;
  • 自适应学习率: 根据各参数梯度动态调整学习率,提高了模型训练的稳定性和收敛速度。

参考链接

  1. 分析transformer模型的参数量、计算量、中间激活、KV cache https://zhuanlan.zhihu.com/p/624740065

© 木白 2024 - 2025