Ops 模块文档¶
本文档介绍 TeleFuser 的 ops 模块,提供高效的视频生成神经网络算子实现。
架构原则¶
TeleFuser 遵循严格的分层架构:
┌─────────────────────────────────────────────────────────────┐
│ models/ │
│ (DiT, VAE, text encoders - 只能从 ops/ 导入) │
└─────────────────────────┬───────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ ops/ │
│ (编译感知分发:compile 时用 native,eager 时用 kernel。 │
│ 基类:CustomOp, CustomOpFunction) │
└─────────────────────────┬───────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ kernel/triton/ │
│ (纯 Triton 内核,custom ops。不直接被 models 使用。 │
│ 可能有 torch.library.custom_op 注册) │
└─────────────────────────────────────────────────────────────┘
关键规则¶
-
models/ 层只能从
telefuser.ops/导入:# ✅ 正确 from telefuser.ops.normalization import RMSNorm, LayerNorm, modulate from telefuser.ops.rotary import apply_rotary_emb from telefuser.ops.attention import attention # ❌ 错误 - 模型中绝不要从 kernel 层导入 from telefuser.kernel.triton import apply_rotary_embedding from telefuser.kernel.triton import fused_scale_shift -
ops/ 层负责编译感知分发:
-
kernel/triton/ 包含纯 Triton 代码:
- 无编译状态检查(由 ops 层处理)
- 可使用
torch.library.custom_op以支持 torch.compile - 只被 ops 层调用,绝不直接被 models 调用
为什么采用此架构?¶
- torch.compile 兼容性:ops 层在编译时分发到 PyTorch 原生实现,让 Inductor 可以跨层融合操作
- 性能优化:ops 层在 eager 模式下使用优化的 Triton 内核
- 关注点分离:kernel 层专注纯内核实现,ops 层处理分发逻辑
不同算子类型的 torch.compile 策略¶
TeleFuser 采用**混合策略**处理 torch.compile 兼容性,根据算子特性优化:
| 算子类型 | 策略 | 原因 |
|---|---|---|
| Attention(高计算密度) | @torch.compiler.disable | FlashAttention/SageAttention 性能远优于原生 PyTorch;融合收益有限 |
| RoPE(中等计算密度) | @torch.compiler.disable | Triton 内核优于原生实现;后续 Attention 已阻断融合 |
| RMSNorm/LayerNorm(低计算密度) | 编译时使用原生实现 | Overhead-bound;Inductor 可与相邻算子融合获得更好收益 |
| modulate(点操作) | 编译时使用原生实现 | 计算量极小;Inductor 自动融合最优 |
执行流程示例:
由于 Attention 使用了 @torch.compiler.disable,RoPE 之后融合已被阻断。因此 RoPE 使用 Triton 内核最大化单算子性能。
概述¶
telefuser/ops 模块包含以下核心组件:
| 模块 | 描述 | 文件 |
|---|---|---|
| 激活函数 | GELU、SiLU、GEGLU、SwiGLU 等 | activations.py |
| 前馈网络 | 可配置的 FFN 实现 | ffn.py |
| 归一化层 | RMSNorm、LayerNorm、AdaLayerNorm | normalization.py |
| 量化线性层 | FP8 量化线性层 | quantized_linear.py |
| 注意力 | 密集和稀疏注意力实现 | attention/ |
激活函数 (activations.py)¶
标准激活函数¶
from telefuser.ops.activations import get_activation
# 获取标准激活函数
silu = get_activation("silu")
gelu = get_activation("gelu")
mish = get_activation("mish")
FP32SiLU¶
SiLU 激活函数的 FP32 版本,用于数值稳定性:
from telefuser.ops.activations import FP32SiLU
activation = FP32SiLU()
output = activation(inputs) # 内部转换为 FP32 计算
门控线性单元¶
GELU¶
标准 GELU 激活函数,支持 tanh 近似:
from telefuser.ops.activations import GELU
# 精确 GELU
gelu = GELU(dim_in=1024, dim_out=4096, approximate="none")
# tanh 近似 GELU(更快)
gelu_approx = GELU(dim_in=1024, dim_out=4096, approximate="tanh")
GEGLU¶
门控 GELU 单元,将输入分割后应用门控:
from telefuser.ops.activations import GEGLU
geglu = GEGLU(dim_in=1024, dim_out=4096)
# 输出: hidden_states * gelu(gate)
SwiGLU¶
门控 SiLU 单元,类似 GEGLU 但使用 SiLU 激活:
from telefuser.ops.activations import SwiGLU
swiglu = SwiGLU(dim_in=1024, dim_out=4096)
# 输出: hidden_states * silu(gate)
ApproximateGELU¶
快速 GELU 近似,使用 sigmoid 函数:
from telefuser.ops.activations import ApproximateGELU
approx_gelu = ApproximateGELU(dim_in=1024, dim_out=4096)
# 公式: x * sigmoid(1.702 * x)
激活函数对照表¶
| 类名 | 公式 | 参考文献 |
|---|---|---|
GELU | GELU(x) | Gaussian Error Linear Units |
GEGLU | x * GELU(gate) | GLU Variants |
SwiGLU | x * SiLU(gate) | GLU Variants |
ApproximateGELU | x * sigmoid(1.702x) | GELU Approximation |
前馈网络 (ffn.py)¶
FeedForward¶
可配置的前馈网络,支持多种激活函数:
from telefuser.ops.ffn import FeedForward
# 标准 FFN(4倍扩展)
ffn = FeedForward(dim=1024, mult=4, activation_fn="geglu")
# 自定义隐藏维度
ffn = FeedForward(dim=1024, inner_dim=4096, activation_fn="swiglu")
# 带 dropout
ffn = FeedForward(dim=1024, dropout=0.1, final_dropout=True)
支持的激活函数¶
| 激活函数名 | 描述 |
|---|---|
"gelu" | 标准 GELU |
"gelu-approximate" | tanh 近似 GELU |
"geglu" | 门控 GELU |
"geglu-approximate" | 近似门控 GELU |
"swiglu" | 门控 SiLU |
"linear-silu" | 线性投影 + SiLU |
使用示例¶
import torch
from telefuser.ops.ffn import FeedForward
# 创建 FFN
ffn = FeedForward(
dim=1024, # 输入/输出维度
mult=4, # 隐藏层扩展倍数
dropout=0.0, # dropout 概率
activation_fn="geglu", # 激活函数
bias=True, # 是否使用偏置
)
# 前向传播
x = torch.randn(2, 256, 1024) # (batch, seq, dim)
output = ffn(x)
print(output.shape) # (2, 256, 1024)
归一化层 (normalization.py)¶
RMSNorm¶
Root Mean Square Layer Normalization,比 LayerNorm 更高效:
from telefuser.ops.normalization import RMSNorm
# 创建 RMSNorm
norm = RMSNorm(dim=1024, eps=1e-5, elementwise_affine=True)
# 前向传播
output = norm(hidden_states)
性能优化: - CUDA 上优先使用 tf_kernel.rmsnorm(最佳性能) - 回退到 Triton 内核 - 非CUDA张量使用 PyTorch 实现
LayerNorm¶
带 Triton 内核优化的 Layer Normalization:
from telefuser.ops.normalization import LayerNorm
# 创建 LayerNorm
norm = LayerNorm(dim=1024, eps=1e-6, elementwise_affine=True, bias=True)
# 前向传播
output = norm(hidden_states)
性能优化: - CUDA 上使用 Triton 内核 - 非CUDA张量回退到 nn.functional.layer_norm
AdaLayerNormContinuous¶
自适应层归一化,支持连续条件输入:
from telefuser.ops.normalization import AdaLayerNormContinuous
# 创建自适应归一化
ada_norm = AdaLayerNormContinuous(
embedding_dim=1024, # 归一化维度
conditioning_embedding_dim=256, # 条件嵌入维度
elementwise_affine=True,
norm_type="layer_norm", # 或 "rms_norm"
)
# 前向传播
x = torch.randn(2, 256, 1024)
cond = torch.randn(2, 256)
output = ada_norm(x, cond)
modulate 函数¶
调制函数,用于自适应归一化:
from telefuser.ops.normalization import modulate
# 应用调制: x * (1 + scale) + shift
output = modulate(x, shift, scale)
性能优化:CUDA 上使用 Triton 内核的 fused_scale_shift。
归一化层对照表¶
| 类名 | 描述 | 核心优化 |
|---|---|---|
RMSNorm | RMS 归一化 | tf_kernel > Triton > PyTorch |
LayerNorm | 层归一化 | Triton > PyTorch |
AdaLayerNormContinuous | 自适应层归一化 | 内部使用 LayerNorm 或 RMSNorm |
量化线性层 (quantized_linear.py)¶
LinearFP8¶
FP8 量化的线性层,用于内存高效推理:
import torch.nn as nn
from telefuser.ops.quantized_linear import LinearFP8
# 从现有 Linear 层创建
original_linear = nn.Linear(1024, 4096)
fp8_linear = LinearFP8(original_linear, data_type=torch.float8_e4m3fn)
# 前向传播
x = torch.randn(2, 256, 1024, device="cuda")
output = fp8_linear(x)
后端支持: - 优先使用 tf_kernel(最佳性能) - 回退到 vLLM 的 FP8 内核
模型量化工具¶
from telefuser.ops.quantized_linear import replace_linear_layers, convert_params_to_buffers
# 替换所有 Linear 层为 FP8 版本
replace_linear_layers(model, quant_type=torch.float8_e4m3fn)
# 将非 FP8 参数转换为 buffer(减少内存开销)
model = convert_params_to_buffers(model)
完整量化示例¶
import torch
import torch.nn as nn
from telefuser.ops.quantized_linear import replace_linear_layers, convert_params_to_buffer
# 加载模型
model = load_my_model()
model = model.to("cuda")
# 替换 Linear 层为 FP8 版本
replace_linear_layers(model, quant_type=torch.float8_e4m3fn)
# 转换参数为 buffer
model = convert_params_to_buffers(model)
# 推理
with torch.no_grad():
output = model(input_tensor)
注意力模块 (attention/)¶
注意力模块提供统一的密集和稀疏注意力实现。详细文档请参考 Attention 实现指南。
快速参考¶
from telefuser.ops.attention import attention, long_context_attention
from telefuser.core.config import AttentionConfig, AttnImplType
# 密集注意力
config = AttentionConfig.dense_attention(AttnImplType.FLASH_ATTN_2)
output = attention(q, k, v, attention_config=config)
# 稀疏注意力(径向)
config = AttentionConfig.radial_attention(dense_timesteps=40)
output = attention(q, k, v, attention_config=config, sparse_state=sparse_state)
# 长上下文注意力(分布式)
output = long_context_attention(q, k, v, device_mesh=device_mesh)
模块结构¶
| 文件 | 描述 |
|---|---|
attention_impl.py | 统一注意力实现,支持多种后端 |
radial_attention_core.py | 径向稀疏注意力核心 |
local_sparse_attn.py | 局部窗口稀疏注意力 |
sparse_attention.py | 稀疏注意力接口 |
支持的注意力后端¶
| 后端 | 类型 | 依赖 |
|---|---|---|
TORCH_SDPA | 密集 | PyTorch 2.0+ |
TORCH_CUDNN | 密集 | cuDNN |
FLASH_ATTN_2/3/4 | 密集 | flash-attn |
SAGE_ATTN_* | 密集 | sageattention |
RADIAL_ATTN | 稀疏 | flashinfer / sageattention |
LOCAL_SPARSE_ATTN | 稀疏 | block_sparse_attn |
在新模型中使用 Ops¶
示例:自定义 Transformer Block¶
import torch
import torch.nn as nn
from telefuser.ops.normalization import RMSNorm, modulate
from telefuser.ops.ffn import FeedForward
from telefuser.ops.attention import attention
from telefuser.core.config import AttentionConfig, AttnImplType
class MyTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
attention_config: AttentionConfig = None,
):
super().__init__()
self.norm1 = RMSNorm(dim)
self.norm2 = RMSNorm(dim)
# 注意力
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
self.num_heads = num_heads
self.attention_config = attention_config or AttentionConfig.dense_attention(
AttnImplType.FLASH_ATTN_2
)
# FFN
self.ffn = FeedForward(
dim=dim,
mult=mlp_ratio,
activation_fn="geglu",
)
# 自适应调制参数
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim),
)
def forward(self, x, cond):
# 自适应调制
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(cond).chunk(6, dim=1)
# 注意力残差
x = x + gate_msa.unsqueeze(1) * self.attention(
modulate(self.norm1(x), shift_msa, scale_msa)
)
# FFN 残差
x = x + gate_mlp.unsqueeze(1) * self.ffn(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
return x
def attention(self, x):
B, S, D = x.shape
qkv = self.qkv(x).reshape(B, S, 3, self.num_heads, D // self.num_heads)
q, k, v = qkv.unbind(2)
output = attention(
q, k, v,
attention_config=self.attention_config,
input_layout="BSND",
output_layout="BSND",
)
return self.proj(output.flatten(2))
示例:使用量化层¶
import torch
import torch.nn as nn
from telefuser.ops.quantized_linear import LinearFP8, replace_linear_layers
class MyQuantizedModel(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.linear1 = nn.Linear(dim, dim * 4)
self.linear2 = nn.Linear(dim * 4, dim)
def forward(self, x):
x = self.linear1(x)
x = torch.nn.functional.gelu(x)
x = self.linear2(x)
return x
# 创建并量化模型
model = MyQuantizedModel(dim=1024).cuda()
replace_linear_layers(model, quant_type=torch.float8_e4m3fn)
性能优化建议¶
1. 选择合适的注意力后端¶
| GPU | 推荐后端 |
|---|---|
| H100/B100+ | FLASH_ATTN_4 或 SAGE_ATTN_2_8_8_SM90 |
| A100/RTX 4090 | FLASH_ATTN_2 或 SAGE_ATTN_2_8_16 |
| 其他 CUDA GPU | TORCH_SDPA 或 FLASH_ATTN_2 |
2. 使用 FP8 量化减少显存¶
# 对于大模型推理
replace_linear_layers(model, quant_type=torch.float8_e4m3fn)
model = convert_params_to_buffers(model)
3. 长视频使用稀疏注意力¶
# 径向注意力可减少 50%+ 显存
config = AttentionConfig.radial_attention(
dense_timesteps=40, # 早期时间步使用密集注意力
decay_factor=1.0,
)
4. 分布式训练使用长上下文注意力¶
# Ulysses 序列并行
from telefuser.distributed import create_device_mesh_from_config
from telefuser.core.config import ParallelConfig
config = ParallelConfig(device_ids=[0, 1, 2, 3], sp_ulysses_degree=4)
device_mesh = create_device_mesh_from_config(config)
# 在模型中启用
dit.enable_usp()
相关文档¶
- Attention 实现指南 - 注意力模块详细文档
- 添加新模型 - 如何添加新模型
- 并行处理 - 分布式训练指南