MMDiT计算负载和访存分析

Posted by     "小段子" on Saturday, November 1, 2025

模型参数定义

参数符号 含义 典型值
B Batch Size -
L_img 图像序列长度 -
L_txt 文本序列长度 -
L 总序列长度 (L_txt + L_img) -
D Hidden Size (隐藏层维度) 3072
H Number of Heads (注意力头数) 24
d Head Dimension (D / H) 128
N_double Double Stream Block 层数 19
N_single Single Stream Block 层数 38
r_mlp MLP Ratio 4.0
D_mlp MLP Hidden Dimension (D × r_mlp) 12288
C_in 输入通道数 -
C_out 输出通道数 -
P Patch Size 2
D_vec Vector Input Dimension -
D_ctx Context Input Dimension -

一、输入预处理阶段

1.1 图像输入投影 (img_in)

操作: Linear(C_in, D)

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
图像输入投影 Linear 2×B×Limg×Cin×D2 \times B \times L_{img} \times C_{in} \times D B×Limg×(Cin+D)×4+(Cin×D+D)×4B \times L_{img} \times (C_{in} + D) \times 4 + (C_{in} \times D + D) \times 4

说明:

  • 矩阵乘法: (B×Limg,Cin)×(Cin,D)=(B×Limg,D)(B \times L_{img}, C_{in}) \times (C_{in}, D) = (B \times L_{img}, D)
  • FLOPs = 2 × M × N × K (矩阵乘法公式,M=B×L_img, N=D, K=C_in)
  • 内存访问包括: 输入 + 输出 + 权重 + 偏置

1.2 时间步嵌入 (time_in)

操作: timestep_embedding → MLPEmbedder

1.2.1 Timestep Embedding

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Sinusoidal Embedding Trigonometric B×256×4B \times 256 \times 4 B×256×4B \times 256 \times 4

说明:

  • 包括 cos、sin、exp 计算
  • 输出维度固定为 256

1.2.2 MLPEmbedder

操作: Linear(256, D) → SiLU → Linear(D, D)

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
MLP 第一层 Linear 2×B×256×D2 \times B \times 256 \times D B×(256+D)×4+(256×D+D)×4B \times (256 + D) \times 4 + (256 \times D + D) \times 4
SiLU 激活 Elementwise B×D×3B \times D \times 3 B×D×4×2B \times D \times 4 \times 2
MLP 第二层 Linear 2×B×D×D2 \times B \times D \times D B×(D+D)×4+(D×D+D)×4B \times (D + D) \times 4 + (D \times D + D) \times 4

1.3 向量输入投影

操作: MLPEmbedder(D_vec, D)

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
MLP 第一层 Linear 2×B×Dvec×D2 \times B \times D_{vec} \times D B×(Dvec+D)×4+(Dvec×D+D)×4B \times (D_{vec} + D) \times 4 + (D_{vec} \times D + D) \times 4
SiLU 激活 Elementwise B×D×3B \times D \times 3 B×D×4×2B \times D \times 4 \times 2
MLP 第二层 Linear 2×B×D×D2 \times B \times D \times D B×(D+D)×4+(D×D+D)×4B \times (D + D) \times 4 + (D \times D + D) \times 4

1.4 文本输入投影

操作: Linear(D_ctx, D)

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
文本投影 Linear 2×B×Ltxt×Dctx×D2 \times B \times L_{txt} \times D_{ctx} \times D B×Ltxt×(Dctx+D)×4+(Dctx×D+D)×4B \times L_{txt} \times (D_{ctx} + D) \times 4 + (D_{ctx} \times D + D) \times 4

1.5 位置编码

操作: EmbedND 或 LigerEmbedND

1.5.1 标准 RoPE

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
RoPE 计算 Trigonometric + Rearrange B×L×d×naxes×8B \times L \times d \times n_{axes} \times 8 B×L×d×4×2B \times L \times d \times 4 \times 2

说明:

  • n_axes: 位置编码的轴数量 (通常为3: T, H, W)
  • 包含 cos、sin 和张量重排操作

1..2 Liger RoPE

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
RoPE 计算 Trigonometric B×L×d×naxes×6B \times L \times d \times n_{axes} \times 6 B×L×d×4×2B \times L \times d \times 4 \times 2

二、Double Stream Block

Double Stream Block 分别处理图像流和文本流,但共享位置编码。

2.1 Modulation (img_mod 和 txt_mod)

操作: SiLU → Linear(D, 6×D)

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
SiLU 激活 Elementwise B×D×3B \times D \times 3 B×D×4×2B \times D \times 4 \times 2
Linear (img) Linear 2×B×D×6D2 \times B \times D \times 6D B×(D+6D)×4+(D×6D+6D)×4B \times (D + 6D) \times 4 + (D \times 6D + 6D) \times 4
Linear (txt) Linear 2×B×D×6D2 \times B \times D \times 6D B×(D+6D)×4+(D×6D+6D)×4B \times (D + 6D) \times 4 + (D \times 6D + 6D) \times 4

说明: 输出 6 个调制参数: shift₁, scale₁, gate₁, shift₂, scale₂, gate₂

2.2 图像流 - 注意力准备

2.2.1 LayerNorm + Modulation

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
LayerNorm (img) Normalization B×Limg×D×5B \times L_{img} \times D \times 5 B×Limg×D×4×3B \times L_{img} \times D \times 4 \times 3
Scale + Shift (img) Elementwise B×Limg×D×2B \times L_{img} \times D \times 2 B×Limg×D×4×2B \times L_{img} \times D \times 4 \times 2

2.2.2 QKV 投影

  • Fused模式:Open Sora默认,速度更快

操作: Linear(D, 3×D)

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
QKV 投影 (img) Linear 2×B×Limg×D×3D2 \times B \times L_{img} \times D \times 3D B×Limg×(D+3D)×4+(D×3D+3D)×4B \times L_{img} \times (D + 3D) \times 4 + (D \times 3D + 3D) \times 4
Rearrange Memory 00 B×Limg×3D×4B \times L_{img} \times 3D \times 4
  • 非 Fused 模式
阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Q 投影 (img) Linear 2×B×Limg×D×D2 \times B \times L_{img} \times D \times D B×Limg×2D×4+(D×D+D)×4B \times L_{img} \times 2D \times 4 + (D \times D + D) \times 4
K 投影 (img) Linear 2×B×Limg×D×D2 \times B \times L_{img} \times D \times D B×Limg×2D×4+(D×D+D)×4B \times L_{img} \times 2D \times 4 + (D \times D + D) \times 4
V 投影 (img) Linear 2×B×Limg×D×D2 \times B \times L_{img} \times D \times D B×Limg×2D×4+(D×D+D)×4B \times L_{img} \times 2D \times 4 + (D \times D + D) \times 4

2.2.3 QK Normalization

操作: RMSNorm (Fused)

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Q Norm (img) RMSNorm B×H×Limg×d×4B \times H \times L_{img} \times d \times 4 B×H×Limg×d×4×2B \times H \times L_{img} \times d \times 4 \times 2
K Norm (img) RMSNorm B×H×Limg×d×4B \times H \times L_{img} \times d \times 4 B×H×Limg×d×4×2B \times H \times L_{img} \times d \times 4 \times 2

说明: RMSNorm 包括平方、均值、rsqrt、缩放操作

2.3 文本流 - 注意力准备

2.3.1 LayerNorm + Modulation

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
LayerNorm (txt) Normalization B×Ltxt×D×5B \times L_{txt} \times D \times 5 B×Ltxt×D×4×3B \times L_{txt} \times D \times 4 \times 3
Scale + Shift (txt) Elementwise B×Ltxt×D×2B \times L_{txt} \times D \times 2 B×Ltxt×D×4×2B \times L_{txt} \times D \times 4 \times 2

2.3.2 QKV 投影

  • Fused模式
阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
QKV 投影 (txt) Linear 2×B×Ltxt×D×3D2 \times B \times L_{txt} \times D \times 3D B×Ltxt×(D+3D)×4+(D×3D+3D)×4B \times L_{txt} \times (D + 3D) \times 4 + (D \times 3D + 3D) \times 4
Rearrange Memory 00 B×Ltxt×3D×4B \times L_{txt} \times 3D \times 4
  • 非Fused模式
阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Q 投影 (txt) Linear 2×B×Ltxt×D×D2 \times B \times L_{txt} \times D \times D B×Ltxt×2D×4+(D×D+D)×4B \times L_{txt} \times 2D \times 4 + (D \times D + D) \times 4
K 投影 (txt) Linear 2×B×Ltxt×D×D2 \times B \times L_{txt} \times D \times D B×Ltxt×2D×4+(D×D+D)×4B \times L_{txt} \times 2D \times 4 + (D \times D + D) \times 4
V 投影 (txt) Linear 2×B×Ltxt×D×D2 \times B \times L_{txt} \times D \times D B×Ltxt×2D×4+(D×D+D)×4B \times L_{txt} \times 2D \times 4 + (D \times D + D) \times 4

2.3.3 QK Normalization

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Q Norm (txt) RMSNorm B×H×Ltxt×d×4B \times H \times L_{txt} \times d \times 4 B×H×Ltxt×d×4×2B \times H \times L_{txt} \times d \times 4 \times 2
K Norm (txt) RMSNorm B×H×Ltxt×d×4B \times H \times L_{txt} \times d \times 4 B×H×Ltxt×d×4×2B \times H \times L_{txt} \times d \times 4 \times 2

2.4 联合注意力计算

2.4.1 拼接 QKV

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Concat Q Memory 00 B×H×L×d×4B \times H \times L \times d \times 4
Concat K Memory 00 B×H×L×d×4B \times H \times L \times d \times 4
Concat V Memory 00 B×H×L×d×4B \times H \times L \times d \times 4

说明: L = L_txt + L_img

2.4.2 RoPE 应用

  • 标准模式
阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Apply RoPE (Q) Rotation B×H×L×d×8B \times H \times L \times d \times 8 B×H×L×d×4×3B \times H \times L \times d \times 4 \times 3
Apply RoPE (K) Rotation B×H×L×d×8B \times H \times L \times d \times 8 B×H×L×d×4×3B \times H \times L \times d \times 4 \times 3

说明: RoPE 应用涉及复数乘法和张量重塑

  • Liger模式
阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Liger RoPE (Q,K) Rotation B×H×L×d×6B \times H \times L \times d \times 6 B×H×L×d×4×4B \times H \times L \times d \times 4 \times 4

2.4.3 Flash Attention

操作: Scaled Dot-Product Attention (使用 Flash Attention 优化)

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
QK^T MatMul 2×B×H×L×L×d2 \times B \times H \times L \times L \times d IO 优化 (Flash Attention)
Softmax Softmax B×H×L×L×5B \times H \times L \times L \times 5 IO 优化 (Flash Attention)
Attention × V MatMul 2×B×H×L×L×d2 \times B \times H \times L \times L \times d IO 优化 (Flash Attention)

Flash Attention 内存访问优化:

  • 理论上: O(B×H×L2×d)O(B \times H \times L^2 \times d)
  • Flash Attention: O(B×H×L×d)O(B \times H \times L \times d) (通过分块计算降低 HBM 访问)

2.4.4 拆分注意力输出

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Split Output Memory 00 B×L×D×4B \times L \times D \times 4

说明: 分离为 txt_attn (L_txt) 和 img_attn (L_img)

2.5 图像流 - 输出投影和 MLP

2.5.1 注意力输出投影

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Proj (img) Linear(D, D) 2×B×Limg×D×D2 \times B \times L_{img} \times D \times D B×Limg×2D×4+(D×D+D)×4B \times L_{img} \times 2D \times 4 + (D \times D + D) \times 4
Gate × Proj Elementwise B×Limg×DB \times L_{img} \times D B×Limg×D×4×2B \times L_{img} \times D \times 4 \times 2
残差连接 Add B×Limg×DB \times L_{img} \times D B×Limg×D×4×2B \times L_{img} \times D \times 4 \times 2

2.5.2 MLP 分支

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
LayerNorm Normalization B×Limg×D×5B \times L_{img} \times D \times 5 B×Limg×D×4×3B \times L_{img} \times D \times 4 \times 3
Scale + Shift Elementwise B×Limg×D×2B \times L_{img} \times D \times 2 B×Limg×D×4×2B \times L_{img} \times D \times 4 \times 2
MLP Linear 1 Linear(D, D_mlp) 2×B×Limg×D×Dmlp2 \times B \times L_{img} \times D \times D_{mlp} B×Limg×(D+Dmlp)×4+(D×Dmlp+Dmlp)×4B \times L_{img} \times (D + D_{mlp}) \times 4 + (D \times D_{mlp} + D_{mlp}) \times 4
GELU Activation B×Limg×Dmlp×8B \times L_{img} \times D_{mlp} \times 8 B×Limg×Dmlp×4×2B \times L_{img} \times D_{mlp} \times 4 \times 2
MLP Linear 2 Linear(D_mlp, D) 2×B×Limg×Dmlp×D2 \times B \times L_{img} \times D_{mlp} \times D B×Limg×(Dmlp+D)×4+(Dmlp×D+D)×4B \times L_{img} \times (D_{mlp} + D) \times 4 + (D_{mlp} \times D + D) \times 4
Gate × MLP Elementwise B×Limg×DB \times L_{img} \times D B×Limg×D×4×2B \times L_{img} \times D \times 4 \times 2
残差连接 Add B×Limg×DB \times L_{img} \times D B×Limg×D×4×2B \times L_{img} \times D \times 4 \times 2

说明: GELU 近似需要 8 次浮点运算 (包括多项式逼近)

2.6 文本流 - 输出投影和 MLP

2.6.1 注意力输出投影

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Proj (txt) Linear(D, D) 2×B×Ltxt×D×D2 \times B \times L_{txt} \times D \times D B×Ltxt×2D×4+(D×D+D)×4B \times L_{txt} \times 2D \times 4 + (D \times D + D) \times 4
Gate × Proj Elementwise B×Ltxt×DB \times L_{txt} \times D B×Ltxt×D×4×2B \times L_{txt} \times D \times 4 \times 2
残差连接 Add B×Ltxt×DB \times L_{txt} \times D B×Ltxt×D×4×2B \times L_{txt} \times D \times 4 \times 2

2.6.2 MLP 分支

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
LayerNorm Normalization B×Ltxt×D×5B \times L_{txt} \times D \times 5 B×Ltxt×D×4×3B \times L_{txt} \times D \times 4 \times 3
Scale + Shift Elementwise B×Ltxt×D×2B \times L_{txt} \times D \times 2 B×Ltxt×D×4×2B \times L_{txt} \times D \times 4 \times 2
MLP Linear 1 Linear(D, D_mlp) 2×B×Ltxt×D×Dmlp2 \times B \times L_{txt} \times D \times D_{mlp} B×Ltxt×(D+Dmlp)×4+(D×Dmlp+Dmlp)×4B \times L_{txt} \times (D + D_{mlp}) \times 4 + (D \times D_{mlp} + D_{mlp}) \times 4
GELU Activation B×Ltxt×Dmlp×8B \times L_{txt} \times D_{mlp} \times 8 B×Ltxt×Dmlp×4×2B \times L_{txt} \times D_{mlp} \times 4 \times 2
MLP Linear 2 Linear(D_mlp, D) 2×B×Ltxt×Dmlp×D2 \times B \times L_{txt} \times D_{mlp} \times D B×Ltxt×(Dmlp+D)×4+(Dmlp×D+D)×4B \times L_{txt} \times (D_{mlp} + D) \times 4 + (D_{mlp} \times D + D) \times 4
Gate × MLP Elementwise B×Ltxt×DB \times L_{txt} \times D B×Ltxt×D×4×2B \times L_{txt} \times D \times 4 \times 2
残差连接 Add B×Ltxt×DB \times L_{txt} \times D B×Ltxt×D×4×2B \times L_{txt} \times D \times 4 \times 2

2.7 Double Stream Block 总计 (单层)

阶段 计算负载 (FLOPs)
总计 2×B×(Limg+Ltxt)×D×(12D+8Dmlp+H×L)\approx 2 \times B \times (L_{img} + L_{txt}) \times D \times (12D + 8D_{mlp} + H \times L)

简化公式 (当 D_mlp = 4D, 忽略低阶项):

FLOPsDoubleBlock2×B×L×D×(44D+H×L)\text{FLOPs}_{\text{DoubleBlock}} \approx 2 \times B \times L \times D \times (44D + H \times L)

三、Single Stream Block

Single Stream Block 处理拼接后的图像和文本序列。

3.1 Modulation

操作: SiLU → Linear(D, 3×D)

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
SiLU 激活 Elementwise B×D×3B \times D \times 3 B×D×4×2B \times D \times 4 \times 2
Linear Linear 2×B×D×3D2 \times B \times D \times 3D B×(D+3D)×4+(D×3D+3D)×4B \times (D + 3D) \times 4 + (D \times 3D + 3D) \times 4

说明: 输出 3 个调制参数: shift, scale, gate

3.2 LayerNorm + Modulation

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
LayerNorm Normalization B×L×D×5B \times L \times D \times 5 B×L×D×4×3B \times L \times D \times 4 \times 3
Scale + Shift Elementwise B×L×D×2B \times L \times D \times 2 B×L×D×4×2B \times L \times D \times 4 \times 2

3.3 并行投影

  • Fused模式

操作: Linear(D, 3×D + D_mlp)

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
QKV + MLP 投影 Linear 2×B×L×D×(3D+Dmlp)2 \times B \times L \times D \times (3D + D_{mlp}) B×L×(D+3D+Dmlp)×4+(D×(3D+Dmlp)+(3D+Dmlp))×4B \times L \times (D + 3D + D_{mlp}) \times 4 + (D \times (3D + D_{mlp}) + (3D + D_{mlp})) \times 4
分离 QKV 和 MLP Memory 00 B×L×(3D+Dmlp)×4B \times L \times (3D + D_{mlp}) \times 4
Rearrange QKV Memory 00 B×L×3D×4B \times L \times 3D \times 4
  • 非Fused 模式
阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Q 投影 Linear(D, D) 2×B×L×D×D2 \times B \times L \times D \times D B×L×2D×4+(D×D+D)×4B \times L \times 2D \times 4 + (D \times D + D) \times 4
K 投影 Linear(D, D) 2×B×L×D×D2 \times B \times L \times D \times D B×L×2D×4+(D×D+D)×4B \times L \times 2D \times 4 + (D \times D + D) \times 4
V + MLP 投影 Linear(D, D + D_mlp) 2×B×L×D×(D+Dmlp)2 \times B \times L \times D \times (D + D_{mlp}) B×L×(D+D+Dmlp)×4+(D×(D+Dmlp)+(D+Dmlp))×4B \times L \times (D + D + D_{mlp}) \times 4 + (D \times (D + D_{mlp}) + (D + D_{mlp})) \times 4

3.5 QK Normalization

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Q Norm RMSNorm B×H×L×d×4B \times H \times L \times d \times 4 B×H×L×d×4×2B \times H \times L \times d \times 4 \times 2
K Norm RMSNorm B×H×L×d×4B \times H \times L \times d \times 4 B×H×L×d×4×2B \times H \times L \times d \times 4 \times 2

3.6 RoPE 应用

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Apply RoPE (标准) Rotation B×H×L×d×16B \times H \times L \times d \times 16 B×H×L×d×4×6B \times H \times L \times d \times 4 \times 6
Apply RoPE (Liger) Rotation B×H×L×d×6B \times H \times L \times d \times 6 B×H×L×d×4×4B \times H \times L \times d \times 4 \times 4

3.7 Flash Attention

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
QK^T MatMul 2×B×H×L×L×d2 \times B \times H \times L \times L \times d IO 优化 (Flash Attention)
Softmax Softmax B×H×L×L×5B \times H \times L \times L \times 5 IO 优化 (Flash Attention)
Attention × V MatMul 2×B×H×L×L×d2 \times B \times H \times L \times L \times d IO 优化 (Flash Attention)

3.8 并行输出 MLP

3.8.1 MLP 激活

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
GELU Activation B×L×Dmlp×8B \times L \times D_{mlp} \times 8 B×L×Dmlp×4×2B \times L \times D_{mlp} \times 4 \times 2

3.8.2 拼接和投影

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Concat Memory 00 B×L×(D+Dmlp)×4B \times L \times (D + D_{mlp}) \times 4
Linear 2 Linear(D + D_mlp, D) 2×B×L×(D+Dmlp)×D2 \times B \times L \times (D + D_{mlp}) \times D B×L×(D+Dmlp+D)×4+((D+Dmlp)×D+D)×4B \times L \times (D + D_{mlp} + D) \times 4 + ((D + D_{mlp}) \times D + D) \times 4

3.9 输出处理

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Gate × Output Elementwise B×L×DB \times L \times D B×L×D×4×2B \times L \times D \times 4 \times 2
残差连接 Add B×L×DB \times L \times D B×L×D×4×2B \times L \times D \times 4 \times 2

3.10 Single Stream Block 总计 (单层)

阶段 计算负载 (FLOPs)
总计 2×B×L×D×(7D+5Dmlp+H×L)\approx 2 \times B \times L \times D \times (7D + 5D_{mlp} + H \times L)

简化公式 (当 D_mlp = 4D):

FLOPsSingleBlock2×B×L×D×(27D+H×L)\text{FLOPs}_{\text{SingleBlock}} \approx 2 \times B \times L \times D \times (27D + H \times L)

四、输出层

4.1 AdaLN Modulation

操作: SiLU → Linear(D, 2×D)

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
SiLU 激活 Elementwise B×D×3B \times D \times 3 B×D×4×2B \times D \times 4 \times 2
Linear Linear 2×B×D×2D2 \times B \times D \times 2D B×(D+2D)×4+(D×2D+2D)×4B \times (D + 2D) \times 4 + (D \times 2D + 2D) \times 4
Chunk Memory 00 B×2D×4B \times 2D \times 4

4.2 LayerNorm + Modulation

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
LayerNorm Normalization B×Limg×D×5B \times L_{img} \times D \times 5 B×Limg×D×4×3B \times L_{img} \times D \times 4 \times 3
Scale + Shift Elementwise B×Limg×D×2B \times L_{img} \times D \times 2 B×Limg×D×4×2B \times L_{img} \times D \times 4 \times 2

说明: 仅处理图像序列部分 (L_img)

4.3 输出投影

操作: Linear(D, P² × C_out)

阶段 操作 计算负载 (FLOPs) 内存访问 (Bytes)
Linear Linear 2×B×Limg×D×(P2×Cout)2 \times B \times L_{img} \times D \times (P^2 \times C_{out}) B×Limg×(D+P2×Cout)×4+(D×P2×Cout+P2×Cout)×4B \times L_{img} \times (D + P^2 \times C_{out}) \times 4 + (D \times P^2 \times C_{out} + P^2 \times C_{out}) \times 4

4.4 Final Layer 总计

阶段 计算负载 (FLOPs)
总计 2×B×Limg×D×(2D+P2×Cout)\approx 2 \times B \times L_{img} \times D \times (2D + P^2 \times C_{out})

五、完整前向传播总计

5.1 总 FLOPs

FLOPstotal=FLOPsprepare+Ndouble×FLOPsDoubleBlock+Nsingle×FLOPsSingleBlock+FLOPsfinal \begin{aligned} \text{FLOPs}_{\text{total}} &= \text{FLOPs}_{\text{prepare}} \\ &+ N_{\text{double}} \times \text{FLOPs}_{\text{DoubleBlock}} \\ &+ N_{\text{single}} \times \text{FLOPs}_{\text{SingleBlock}} \\ &+ \text{FLOPs}_{\text{final}} \end{aligned}

展开 (使用简化公式):

FLOPstotal2×B×D×[Limg×Cin+Ltxt×Dctx+(Limg+Ltxt)×256+Ndouble×(Limg+Ltxt)×(44D+H×(Limg+Ltxt))+Nsingle×(Limg+Ltxt)×(27D+H×(Limg+Ltxt))+Limg×(2D+P2×Cout)] \begin{aligned} \text{FLOPs}_{\text{total}} &\approx 2 \times B \times D \times \Big[ \\ &\quad L_{img} \times C_{in} + L_{txt} \times D_{ctx} + (L_{img} + L_{txt}) \times 256 \\ &\quad + N_{\text{double}} \times (L_{img} + L_{txt}) \times (44D + H \times (L_{img} + L_{txt})) \\ &\quad + N_{\text{single}} \times (L_{img} + L_{txt}) \times (27D + H \times (L_{img} + L_{txt})) \\ &\quad + L_{img} \times (2D + P^2 \times C_{out}) \\ &\Big] \end{aligned}

5.2 主导项分析

注意力计算主导:

FLOPsattention4×B×H×(Ndouble+Nsingle)×L2×d\text{FLOPs}_{\text{attention}} \approx 4 \times B \times H \times (N_{\text{double}} + N_{\text{single}}) \times L^2 \times d

MLP 计算:

FLOPsMLP16×B×(Ndouble+Nsingle)×L×D2\text{FLOPs}_{\text{MLP}} \approx 16 \times B \times (N_{\text{double}} + N_{\text{single}}) \times L \times D^2

典型值 (D=3072, H=24, N_double=19, N_single=38):

  • 对于短序列 (L < 1024): MLP 占主导
  • 对于长序列 (L > 4096): Attention 占主导
  • 临界点: L4D2H×d2048L \approx \sqrt{\frac{4D^2}{H \times d}} \approx 2048

5.3 内存访问总计

参数内存:

Params=2×D×(6D+3D+Cin+Dctx+P2×Cout)+(Ndouble+Nsingle)×[2×D×(9D+8Dmlp)] \begin{aligned} \text{Params} &= 2 \times D \times (6D + 3D + C_{in} + D_{ctx} + P^2 \times C_{out}) \\ &+ (N_{\text{double}} + N_{\text{single}}) \times \Big[ \\ &\quad 2 \times D \times (9D + 8D_{mlp}) \\ &\Big] \end{aligned}

激活内存 (峰值):

ActivationsB×[L×D×4(中间特征)+H×L×d×6(QKV)+L×Dmlp(MLP 中间)] \begin{aligned} \text{Activations} &\approx B \times \Big[ \\ &\quad L \times D \times 4 \quad \text{(中间特征)} \\ &\quad + H \times L \times d \times 6 \quad \text{(QKV)} \\ &\quad + L \times D_{mlp} \quad \text{(MLP 中间)} \\ &\Big] \end{aligned}

六、关键公式总结

6.1 矩阵乘法 FLOPs

FLOPsmatmul(M,N,K)=2×M×N×K\text{FLOPs}_{\text{matmul}}(M, N, K) = 2 \times M \times N \times K

6.2 Attention FLOPs

FLOPsattn=4×B×H×L2×d\text{FLOPs}_{\text{attn}} = 4 \times B \times H \times L^2 \times d

6.3 MLP FLOPs

FLOPsMLP=2×B×L×D×(Din+Dout)\text{FLOPs}_{\text{MLP}} = 2 \times B \times L \times D \times (D_{in} + D_{out})

6.4 内存访问 (通用)

Bytes=(Inputs+Outputs+Weights)×dtype_size\text{Bytes} = \sum (\text{Inputs} + \text{Outputs} + \text{Weights}) \times \text{dtype\_size}

「真诚赞赏,手留余香」

小段子的技术博客

真诚赞赏,手留余香

使用微信扫描二维码完成支付


AI 知识库助手

你好!我是基于本站文章的 AI 助手,有什么可以帮你的吗?