模型参数定义
| 参数符号 |
含义 |
典型值 |
| B |
Batch Size |
- |
| L_img |
图像序列长度 |
- |
| L_txt |
文本序列长度 |
- |
| L |
总序列长度 (L_txt + L_img) |
- |
| D |
Hidden Size (隐藏层维度) |
3072 |
| H |
Number of Heads (注意力头数) |
24 |
| d |
Head Dimension (D / H) |
128 |
| N_double |
Double Stream Block 层数 |
19 |
| N_single |
Single Stream Block 层数 |
38 |
| r_mlp |
MLP Ratio |
4.0 |
| D_mlp |
MLP Hidden Dimension (D × r_mlp) |
12288 |
| C_in |
输入通道数 |
- |
| C_out |
输出通道数 |
- |
| P |
Patch Size |
2 |
| D_vec |
Vector Input Dimension |
- |
| D_ctx |
Context Input Dimension |
- |
一、输入预处理阶段
1.1 图像输入投影 (img_in)
操作: Linear(C_in, D)
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| 图像输入投影 |
Linear |
2×B×Limg×Cin×D |
B×Limg×(Cin+D)×4+(Cin×D+D)×4 |
说明:
- 矩阵乘法: (B×Limg,Cin)×(Cin,D)=(B×Limg,D)
- FLOPs = 2 × M × N × K (矩阵乘法公式,M=B×L_img, N=D, K=C_in)
- 内存访问包括: 输入 + 输出 + 权重 + 偏置
1.2 时间步嵌入 (time_in)
操作: timestep_embedding → MLPEmbedder
1.2.1 Timestep Embedding
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Sinusoidal Embedding |
Trigonometric |
B×256×4 |
B×256×4 |
说明:
- 包括 cos、sin、exp 计算
- 输出维度固定为 256
1.2.2 MLPEmbedder
操作: Linear(256, D) → SiLU → Linear(D, D)
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| MLP 第一层 |
Linear |
2×B×256×D |
B×(256+D)×4+(256×D+D)×4 |
| SiLU 激活 |
Elementwise |
B×D×3 |
B×D×4×2 |
| MLP 第二层 |
Linear |
2×B×D×D |
B×(D+D)×4+(D×D+D)×4 |
1.3 向量输入投影
操作: MLPEmbedder(D_vec, D)
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| MLP 第一层 |
Linear |
2×B×Dvec×D |
B×(Dvec+D)×4+(Dvec×D+D)×4 |
| SiLU 激活 |
Elementwise |
B×D×3 |
B×D×4×2 |
| MLP 第二层 |
Linear |
2×B×D×D |
B×(D+D)×4+(D×D+D)×4 |
1.4 文本输入投影
操作: Linear(D_ctx, D)
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| 文本投影 |
Linear |
2×B×Ltxt×Dctx×D |
B×Ltxt×(Dctx+D)×4+(Dctx×D+D)×4 |
1.5 位置编码
操作: EmbedND 或 LigerEmbedND
1.5.1 标准 RoPE
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| RoPE 计算 |
Trigonometric + Rearrange |
B×L×d×naxes×8 |
B×L×d×4×2 |
说明:
- n_axes: 位置编码的轴数量 (通常为3: T, H, W)
- 包含 cos、sin 和张量重排操作
1..2 Liger RoPE
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| RoPE 计算 |
Trigonometric |
B×L×d×naxes×6 |
B×L×d×4×2 |
二、Double Stream Block
Double Stream Block 分别处理图像流和文本流,但共享位置编码。
2.1 Modulation (img_mod 和 txt_mod)
操作: SiLU → Linear(D, 6×D)
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| SiLU 激活 |
Elementwise |
B×D×3 |
B×D×4×2 |
| Linear (img) |
Linear |
2×B×D×6D |
B×(D+6D)×4+(D×6D+6D)×4 |
| Linear (txt) |
Linear |
2×B×D×6D |
B×(D+6D)×4+(D×6D+6D)×4 |
说明: 输出 6 个调制参数: shift₁, scale₁, gate₁, shift₂, scale₂, gate₂
2.2 图像流 - 注意力准备
2.2.1 LayerNorm + Modulation
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| LayerNorm (img) |
Normalization |
B×Limg×D×5 |
B×Limg×D×4×3 |
| Scale + Shift (img) |
Elementwise |
B×Limg×D×2 |
B×Limg×D×4×2 |
2.2.2 QKV 投影
操作: Linear(D, 3×D)
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| QKV 投影 (img) |
Linear |
2×B×Limg×D×3D |
B×Limg×(D+3D)×4+(D×3D+3D)×4 |
| Rearrange |
Memory |
0 |
B×Limg×3D×4 |
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Q 投影 (img) |
Linear |
2×B×Limg×D×D |
B×Limg×2D×4+(D×D+D)×4 |
| K 投影 (img) |
Linear |
2×B×Limg×D×D |
B×Limg×2D×4+(D×D+D)×4 |
| V 投影 (img) |
Linear |
2×B×Limg×D×D |
B×Limg×2D×4+(D×D+D)×4 |
2.2.3 QK Normalization
操作: RMSNorm (Fused)
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Q Norm (img) |
RMSNorm |
B×H×Limg×d×4 |
B×H×Limg×d×4×2 |
| K Norm (img) |
RMSNorm |
B×H×Limg×d×4 |
B×H×Limg×d×4×2 |
说明: RMSNorm 包括平方、均值、rsqrt、缩放操作
2.3 文本流 - 注意力准备
2.3.1 LayerNorm + Modulation
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| LayerNorm (txt) |
Normalization |
B×Ltxt×D×5 |
B×Ltxt×D×4×3 |
| Scale + Shift (txt) |
Elementwise |
B×Ltxt×D×2 |
B×Ltxt×D×4×2 |
2.3.2 QKV 投影
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| QKV 投影 (txt) |
Linear |
2×B×Ltxt×D×3D |
B×Ltxt×(D+3D)×4+(D×3D+3D)×4 |
| Rearrange |
Memory |
0 |
B×Ltxt×3D×4 |
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Q 投影 (txt) |
Linear |
2×B×Ltxt×D×D |
B×Ltxt×2D×4+(D×D+D)×4 |
| K 投影 (txt) |
Linear |
2×B×Ltxt×D×D |
B×Ltxt×2D×4+(D×D+D)×4 |
| V 投影 (txt) |
Linear |
2×B×Ltxt×D×D |
B×Ltxt×2D×4+(D×D+D)×4 |
2.3.3 QK Normalization
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Q Norm (txt) |
RMSNorm |
B×H×Ltxt×d×4 |
B×H×Ltxt×d×4×2 |
| K Norm (txt) |
RMSNorm |
B×H×Ltxt×d×4 |
B×H×Ltxt×d×4×2 |
2.4 联合注意力计算
2.4.1 拼接 QKV
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Concat Q |
Memory |
0 |
B×H×L×d×4 |
| Concat K |
Memory |
0 |
B×H×L×d×4 |
| Concat V |
Memory |
0 |
B×H×L×d×4 |
说明: L = L_txt + L_img
2.4.2 RoPE 应用
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Apply RoPE (Q) |
Rotation |
B×H×L×d×8 |
B×H×L×d×4×3 |
| Apply RoPE (K) |
Rotation |
B×H×L×d×8 |
B×H×L×d×4×3 |
说明: RoPE 应用涉及复数乘法和张量重塑
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Liger RoPE (Q,K) |
Rotation |
B×H×L×d×6 |
B×H×L×d×4×4 |
2.4.3 Flash Attention
操作: Scaled Dot-Product Attention (使用 Flash Attention 优化)
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| QK^T |
MatMul |
2×B×H×L×L×d |
IO 优化 (Flash Attention) |
| Softmax |
Softmax |
B×H×L×L×5 |
IO 优化 (Flash Attention) |
| Attention × V |
MatMul |
2×B×H×L×L×d |
IO 优化 (Flash Attention) |
Flash Attention 内存访问优化:
- 理论上: O(B×H×L2×d)
- Flash Attention: O(B×H×L×d) (通过分块计算降低 HBM 访问)
2.4.4 拆分注意力输出
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Split Output |
Memory |
0 |
B×L×D×4 |
说明: 分离为 txt_attn (L_txt) 和 img_attn (L_img)
2.5 图像流 - 输出投影和 MLP
2.5.1 注意力输出投影
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Proj (img) |
Linear(D, D) |
2×B×Limg×D×D |
B×Limg×2D×4+(D×D+D)×4 |
| Gate × Proj |
Elementwise |
B×Limg×D |
B×Limg×D×4×2 |
| 残差连接 |
Add |
B×Limg×D |
B×Limg×D×4×2 |
2.5.2 MLP 分支
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| LayerNorm |
Normalization |
B×Limg×D×5 |
B×Limg×D×4×3 |
| Scale + Shift |
Elementwise |
B×Limg×D×2 |
B×Limg×D×4×2 |
| MLP Linear 1 |
Linear(D, D_mlp) |
2×B×Limg×D×Dmlp |
B×Limg×(D+Dmlp)×4+(D×Dmlp+Dmlp)×4 |
| GELU |
Activation |
B×Limg×Dmlp×8 |
B×Limg×Dmlp×4×2 |
| MLP Linear 2 |
Linear(D_mlp, D) |
2×B×Limg×Dmlp×D |
B×Limg×(Dmlp+D)×4+(Dmlp×D+D)×4 |
| Gate × MLP |
Elementwise |
B×Limg×D |
B×Limg×D×4×2 |
| 残差连接 |
Add |
B×Limg×D |
B×Limg×D×4×2 |
说明: GELU 近似需要 8 次浮点运算 (包括多项式逼近)
2.6 文本流 - 输出投影和 MLP
2.6.1 注意力输出投影
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Proj (txt) |
Linear(D, D) |
2×B×Ltxt×D×D |
B×Ltxt×2D×4+(D×D+D)×4 |
| Gate × Proj |
Elementwise |
B×Ltxt×D |
B×Ltxt×D×4×2 |
| 残差连接 |
Add |
B×Ltxt×D |
B×Ltxt×D×4×2 |
2.6.2 MLP 分支
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| LayerNorm |
Normalization |
B×Ltxt×D×5 |
B×Ltxt×D×4×3 |
| Scale + Shift |
Elementwise |
B×Ltxt×D×2 |
B×Ltxt×D×4×2 |
| MLP Linear 1 |
Linear(D, D_mlp) |
2×B×Ltxt×D×Dmlp |
B×Ltxt×(D+Dmlp)×4+(D×Dmlp+Dmlp)×4 |
| GELU |
Activation |
B×Ltxt×Dmlp×8 |
B×Ltxt×Dmlp×4×2 |
| MLP Linear 2 |
Linear(D_mlp, D) |
2×B×Ltxt×Dmlp×D |
B×Ltxt×(Dmlp+D)×4+(Dmlp×D+D)×4 |
| Gate × MLP |
Elementwise |
B×Ltxt×D |
B×Ltxt×D×4×2 |
| 残差连接 |
Add |
B×Ltxt×D |
B×Ltxt×D×4×2 |
2.7 Double Stream Block 总计 (单层)
| 阶段 |
计算负载 (FLOPs) |
| 总计 |
≈2×B×(Limg+Ltxt)×D×(12D+8Dmlp+H×L) |
简化公式 (当 D_mlp = 4D, 忽略低阶项):
FLOPsDoubleBlock≈2×B×L×D×(44D+H×L)
三、Single Stream Block
Single Stream Block 处理拼接后的图像和文本序列。
3.1 Modulation
操作: SiLU → Linear(D, 3×D)
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| SiLU 激活 |
Elementwise |
B×D×3 |
B×D×4×2 |
| Linear |
Linear |
2×B×D×3D |
B×(D+3D)×4+(D×3D+3D)×4 |
说明: 输出 3 个调制参数: shift, scale, gate
3.2 LayerNorm + Modulation
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| LayerNorm |
Normalization |
B×L×D×5 |
B×L×D×4×3 |
| Scale + Shift |
Elementwise |
B×L×D×2 |
B×L×D×4×2 |
3.3 并行投影
操作: Linear(D, 3×D + D_mlp)
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| QKV + MLP 投影 |
Linear |
2×B×L×D×(3D+Dmlp) |
B×L×(D+3D+Dmlp)×4+(D×(3D+Dmlp)+(3D+Dmlp))×4 |
| 分离 QKV 和 MLP |
Memory |
0 |
B×L×(3D+Dmlp)×4 |
| Rearrange QKV |
Memory |
0 |
B×L×3D×4 |
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Q 投影 |
Linear(D, D) |
2×B×L×D×D |
B×L×2D×4+(D×D+D)×4 |
| K 投影 |
Linear(D, D) |
2×B×L×D×D |
B×L×2D×4+(D×D+D)×4 |
| V + MLP 投影 |
Linear(D, D + D_mlp) |
2×B×L×D×(D+Dmlp) |
B×L×(D+D+Dmlp)×4+(D×(D+Dmlp)+(D+Dmlp))×4 |
3.5 QK Normalization
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Q Norm |
RMSNorm |
B×H×L×d×4 |
B×H×L×d×4×2 |
| K Norm |
RMSNorm |
B×H×L×d×4 |
B×H×L×d×4×2 |
3.6 RoPE 应用
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Apply RoPE (标准) |
Rotation |
B×H×L×d×16 |
B×H×L×d×4×6 |
| Apply RoPE (Liger) |
Rotation |
B×H×L×d×6 |
B×H×L×d×4×4 |
3.7 Flash Attention
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| QK^T |
MatMul |
2×B×H×L×L×d |
IO 优化 (Flash Attention) |
| Softmax |
Softmax |
B×H×L×L×5 |
IO 优化 (Flash Attention) |
| Attention × V |
MatMul |
2×B×H×L×L×d |
IO 优化 (Flash Attention) |
3.8 并行输出 MLP
3.8.1 MLP 激活
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| GELU |
Activation |
B×L×Dmlp×8 |
B×L×Dmlp×4×2 |
3.8.2 拼接和投影
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Concat |
Memory |
0 |
B×L×(D+Dmlp)×4 |
| Linear 2 |
Linear(D + D_mlp, D) |
2×B×L×(D+Dmlp)×D |
B×L×(D+Dmlp+D)×4+((D+Dmlp)×D+D)×4 |
3.9 输出处理
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Gate × Output |
Elementwise |
B×L×D |
B×L×D×4×2 |
| 残差连接 |
Add |
B×L×D |
B×L×D×4×2 |
3.10 Single Stream Block 总计 (单层)
| 阶段 |
计算负载 (FLOPs) |
| 总计 |
≈2×B×L×D×(7D+5Dmlp+H×L) |
简化公式 (当 D_mlp = 4D):
FLOPsSingleBlock≈2×B×L×D×(27D+H×L)
四、输出层
4.1 AdaLN Modulation
操作: SiLU → Linear(D, 2×D)
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| SiLU 激活 |
Elementwise |
B×D×3 |
B×D×4×2 |
| Linear |
Linear |
2×B×D×2D |
B×(D+2D)×4+(D×2D+2D)×4 |
| Chunk |
Memory |
0 |
B×2D×4 |
4.2 LayerNorm + Modulation
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| LayerNorm |
Normalization |
B×Limg×D×5 |
B×Limg×D×4×3 |
| Scale + Shift |
Elementwise |
B×Limg×D×2 |
B×Limg×D×4×2 |
说明: 仅处理图像序列部分 (L_img)
4.3 输出投影
操作: Linear(D, P² × C_out)
| 阶段 |
操作 |
计算负载 (FLOPs) |
内存访问 (Bytes) |
| Linear |
Linear |
2×B×Limg×D×(P2×Cout) |
B×Limg×(D+P2×Cout)×4+(D×P2×Cout+P2×Cout)×4 |
4.4 Final Layer 总计
| 阶段 |
计算负载 (FLOPs) |
| 总计 |
≈2×B×Limg×D×(2D+P2×Cout) |
五、完整前向传播总计
5.1 总 FLOPs
FLOPstotal=FLOPsprepare+Ndouble×FLOPsDoubleBlock+Nsingle×FLOPsSingleBlock+FLOPsfinal展开 (使用简化公式):
FLOPstotal≈2×B×D×[Limg×Cin+Ltxt×Dctx+(Limg+Ltxt)×256+Ndouble×(Limg+Ltxt)×(44D+H×(Limg+Ltxt))+Nsingle×(Limg+Ltxt)×(27D+H×(Limg+Ltxt))+Limg×(2D+P2×Cout)]5.2 主导项分析
注意力计算主导:
FLOPsattention≈4×B×H×(Ndouble+Nsingle)×L2×dMLP 计算:
FLOPsMLP≈16×B×(Ndouble+Nsingle)×L×D2典型值 (D=3072, H=24, N_double=19, N_single=38):
- 对于短序列 (L < 1024): MLP 占主导
- 对于长序列 (L > 4096): Attention 占主导
- 临界点: L≈H×d4D2≈2048
5.3 内存访问总计
参数内存:
Params=2×D×(6D+3D+Cin+Dctx+P2×Cout)+(Ndouble+Nsingle)×[2×D×(9D+8Dmlp)]激活内存 (峰值):
Activations≈B×[L×D×4(中间特征)+H×L×d×6(QKV)+L×Dmlp(MLP 中间)]六、关键公式总结
6.1 矩阵乘法 FLOPs
FLOPsmatmul(M,N,K)=2×M×N×K6.2 Attention FLOPs
FLOPsattn=4×B×H×L2×d6.3 MLP FLOPs
FLOPsMLP=2×B×L×D×(Din+Dout)6.4 内存访问 (通用)
Bytes=∑(Inputs+Outputs+Weights)×dtype_size
使用微信扫描二维码完成支付