Ops Module Documentation¶
This document introduces the ops module of TeleFuser, providing efficient neural network operator implementations for video generation.
Architecture Principles¶
TeleFuser follows a strict layered architecture for operations:
┌─────────────────────────────────────────────────────────────┐
│ models/ │
│ (DiT, VAE, text encoders - ONLY import from ops/) │
└─────────────────────────┬───────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ ops/ │
│ (Compile-aware dispatch: native for compile, kernel for │
│ eager mode. Base classes: CustomOp, CustomOpFunction) │
└─────────────────────────┬───────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ kernel/triton/ │
│ (Pure Triton kernels, custom ops. NOT directly used by │
│ models. May have torch.library.custom_op registration.) │
└─────────────────────────────────────────────────────────────┘
Key Rules¶
-
models/ layer MUST only import from
telefuser.ops/:# ✅ Correct from telefuser.ops.normalization import RMSNorm, LayerNorm, modulate from telefuser.ops.rotary import apply_rotary_emb from telefuser.ops.attention import attention # ❌ Wrong - Never import from kernel layer in models from telefuser.kernel.triton import apply_rotary_embedding from telefuser.kernel.triton import fused_scale_shift -
ops/ layer handles compile-aware dispatch:
-
kernel/triton/ contains pure Triton code:
- No compile-state checks (handled by ops layer)
- May use
torch.library.custom_opfor torch.compile compatibility - Only called by ops/ layer, never directly by models/
Why This Architecture?¶
- torch.compile compatibility: ops layer dispatches to native PyTorch when compiling, allowing Inductor to fuse operations across layers
- Performance: ops layer uses optimized Triton kernels in eager mode
- Separation of concerns: kernel layer focuses on pure kernel implementation, ops layer handles dispatch logic
torch.compile Strategy by Operator Type¶
TeleFuser uses a mixed strategy for torch.compile compatibility, optimizing based on operator characteristics:
| Operator Type | Strategy | Reason |
|---|---|---|
| Attention (High compute density) | @torch.compiler.disable | FlashAttention/SageAttention outperform native PyTorch significantly; fusion gains limited |
| RoPE (Medium compute density) | @torch.compiler.disable | Triton kernel outperforms native; subsequent Attention blocks fusion anyway |
| RMSNorm/LayerNorm (Low compute density) | Native in compile mode | Overhead-bound; Inductor can fuse with adjacent ops for better gains |
| modulate (Point operations) | Native in compile mode | Minimal compute; Inductor auto-fusion optimal |
Execution flow example:
Since Attention uses @torch.compiler.disable, fusion beyond RoPE is blocked anyway. Therefore, RoPE uses Triton kernel to maximize single-operator performance.
Overview¶
The telefuser/ops module contains the following core components:
| Module | Description | File |
|---|---|---|
| Activations | GELU, SiLU, GEGLU, SwiGLU, etc. | activations.py |
| Feed-Forward Networks | Configurable FFN implementations | ffn.py |
| Normalization Layers | RMSNorm, LayerNorm, AdaLayerNorm | normalization.py |
| Quantized Linear Layers | FP8 quantized linear layers | quantized_linear.py |
| Attention | Dense and sparse attention implementations | attention/ |
Activations (activations.py)¶
Standard Activation Functions¶
from telefuser.ops.activations import get_activation
# Get standard activation functions
silu = get_activation("silu")
gelu = get_activation("gelu")
mish = get_activation("mish")
FP32SiLU¶
FP32 version of SiLU activation for numerical stability:
from telefuser.ops.activations import FP32SiLU
activation = FP32SiLU()
output = activation(inputs) # Internally converts to FP32 for computation
Gated Linear Units¶
GELU¶
Standard GELU activation function with tanh approximation support:
from telefuser.ops.activations import GELU
# Exact GELU
gelu = GELU(dim_in=1024, dim_out=4096, approximate="none")
# Tanh approximate GELU (faster)
gelu_approx = GELU(dim_in=1024, dim_out=4096, approximate="tanh")
GEGLU¶
Gated GELU unit, splits input and applies gating:
from telefuser.ops.activations import GEGLU
geglu = GEGLU(dim_in=1024, dim_out=4096)
# Output: hidden_states * gelu(gate)
SwiGLU¶
Gated SiLU unit, similar to GEGLU but uses SiLU activation:
from telefuser.ops.activations import SwiGLU
swiglu = SwiGLU(dim_in=1024, dim_out=4096)
# Output: hidden_states * silu(gate)
ApproximateGELU¶
Fast GELU approximation using sigmoid function:
from telefuser.ops.activations import ApproximateGELU
approx_gelu = ApproximateGELU(dim_in=1024, dim_out=4096)
# Formula: x * sigmoid(1.702 * x)
Activation Functions Reference¶
| Class | Formula | Reference |
|---|---|---|
GELU | GELU(x) | Gaussian Error Linear Units |
GEGLU | x * GELU(gate) | GLU Variants |
SwiGLU | x * SiLU(gate) | GLU Variants |
ApproximateGELU | x * sigmoid(1.702x) | GELU Approximation |
Feed-Forward Networks (ffn.py)¶
FeedForward¶
Configurable feed-forward network supporting multiple activation functions:
from telefuser.ops.ffn import FeedForward
# Standard FFN (4x expansion)
ffn = FeedForward(dim=1024, mult=4, activation_fn="geglu")
# Custom hidden dimension
ffn = FeedForward(dim=1024, inner_dim=4096, activation_fn="swiglu")
# With dropout
ffn = FeedForward(dim=1024, dropout=0.1, final_dropout=True)
Supported Activation Functions¶
| Activation Name | Description |
|---|---|
"gelu" | Standard GELU |
"gelu-approximate" | Tanh approximate GELU |
"geglu" | Gated GELU |
"geglu-approximate" | Approximate gated GELU |
"swiglu" | Gated SiLU |
"linear-silu" | Linear projection + SiLU |
Usage Example¶
import torch
from telefuser.ops.ffn import FeedForward
# Create FFN
ffn = FeedForward(
dim=1024, # Input/output dimension
mult=4, # Hidden layer expansion multiplier
dropout=0.0, # Dropout probability
activation_fn="geglu", # Activation function
bias=True, # Whether to use bias
)
# Forward pass
x = torch.randn(2, 256, 1024) # (batch, seq, dim)
output = ffn(x)
print(output.shape) # (2, 256, 1024)
Normalization Layers (normalization.py)¶
RMSNorm¶
Root Mean Square Layer Normalization, more efficient than LayerNorm:
from telefuser.ops.normalization import RMSNorm
# Create RMSNorm
norm = RMSNorm(dim=1024, eps=1e-5, elementwise_affine=True)
# Forward pass
output = norm(hidden_states)
Performance Optimization: - Uses tf_kernel.rmsnorm on CUDA for best performance - Falls back to Triton kernel - Uses PyTorch implementation for non-CUDA tensors
LayerNorm¶
Layer Normalization with Triton kernel optimization:
from telefuser.ops.normalization import LayerNorm
# Create LayerNorm
norm = LayerNorm(dim=1024, eps=1e-6, elementwise_affine=True, bias=True)
# Forward pass
output = norm(hidden_states)
Performance Optimization: - Uses Triton kernel on CUDA - Falls back to nn.functional.layer_norm for non-CUDA tensors
AdaLayerNormContinuous¶
Adaptive layer normalization with continuous conditioning:
from telefuser.ops.normalization import AdaLayerNormContinuous
# Create adaptive normalization
ada_norm = AdaLayerNormContinuous(
embedding_dim=1024, # Normalization dimension
conditioning_embedding_dim=256, # Conditioning embedding dimension
elementwise_affine=True,
norm_type="layer_norm", # or "rms_norm"
)
# Forward pass
x = torch.randn(2, 256, 1024)
cond = torch.randn(2, 256)
output = ada_norm(x, cond)
modulate Function¶
Modulation function for adaptive normalization:
from telefuser.ops.normalization import modulate
# Apply modulation: x * (1 + scale) + shift
output = modulate(x, shift, scale)
Performance Optimization: Uses Triton kernel's fused_scale_shift on CUDA.
Normalization Layers Reference¶
| Class | Description | Kernel Optimization |
|---|---|---|
RMSNorm | RMS normalization | tf_kernel > Triton > PyTorch |
LayerNorm | Layer normalization | Triton > PyTorch |
AdaLayerNormContinuous | Adaptive layer normalization | Uses LayerNorm or RMSNorm internally |
Quantized Linear Layers (quantized_linear.py)¶
LinearFP8¶
FP8 quantized linear layer for memory-efficient inference:
import torch.nn as nn
from telefuser.ops.quantized_linear import LinearFP8
# Create from existing Linear layer
original_linear = nn.Linear(1024, 4096)
fp8_linear = LinearFP8(original_linear, data_type=torch.float8_e4m3fn)
# Forward pass
x = torch.randn(2, 256, 1024, device="cuda")
output = fp8_linear(x)
Backend Support: - Prioritizes tf_kernel for best performance - Falls back to vLLM FP8 kernels
Model Quantization Tools¶
from telefuser.ops.quantized_linear import replace_linear_layers, convert_params_to_buffers
# Replace all Linear layers with FP8 versions
replace_linear_layers(model, quant_type=torch.float8_e4m3fn)
# Convert non-FP8 parameters to buffers (reduce memory overhead)
model = convert_params_to_buffers(model)
Complete Quantization Example¶
import torch
import torch.nn as nn
from telefuser.ops.quantized_linear import replace_linear_layers, convert_params_to_buffers
# Load model
model = load_my_model()
model = model.to("cuda")
# Replace Linear layers with FP8 versions
replace_linear_layers(model, quant_type=torch.float8_e4m3fn)
# Convert parameters to buffers
model = convert_params_to_buffers(model)
# Inference
with torch.no_grad():
output = model(input_tensor)
Attention Module (attention/)¶
The attention module provides unified dense and sparse attention implementations. For detailed documentation, please refer to the Attention Implementation Guide.
Quick Reference¶
from telefuser.ops.attention import attention, long_context_attention
from telefuser.core.config import AttentionConfig, AttnImplType
# Dense attention
config = AttentionConfig.dense_attention(AttnImplType.FLASH_ATTN_2)
output = attention(q, k, v, attention_config=config)
# Sparse attention (radial)
config = AttentionConfig.radial_attention(dense_timesteps=40)
output = attention(q, k, v, attention_config=config, sparse_state=sparse_state)
# Long context attention (distributed)
output = long_context_attention(q, k, v, device_mesh=device_mesh)
Module Structure¶
| File | Description |
|---|---|
attention_impl.py | Unified attention implementation supporting multiple backends |
radial_attention_core.py | Radial sparse attention core |
local_sparse_attn.py | Local window sparse attention |
sparse_attention.py | Sparse attention interface |
Supported Attention Backends¶
| Backend | Type | Dependencies |
|---|---|---|
TORCH_SDPA | Dense | PyTorch 2.0+ |
TORCH_CUDNN | Dense | cuDNN |
FLASH_ATTN_2/3/4 | Dense | flash-attn |
SAGE_ATTN_* | Dense | sageattention |
RADIAL_ATTN | Sparse | flashinfer / sageattention |
LOCAL_SPARSE_ATTN | Sparse | block_sparse_attn |
Using Ops in New Models¶
Example: Custom Transformer Block¶
import torch
import torch.nn as nn
from telefuser.ops.normalization import RMSNorm, modulate
from telefuser.ops.ffn import FeedForward
from telefuser.ops.attention import attention
from telefuser.core.config import AttentionConfig, AttnImplType
class MyTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
attention_config: AttentionConfig = None,
):
super().__init__()
self.norm1 = RMSNorm(dim)
self.norm2 = RMSNorm(dim)
# Attention
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
self.num_heads = num_heads
self.attention_config = attention_config or AttentionConfig.dense_attention(
AttnImplType.FLASH_ATTN_2
)
# FFN
self.ffn = FeedForward(
dim=dim,
mult=mlp_ratio,
activation_fn="geglu",
)
# Adaptive modulation parameters
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim),
)
def forward(self, x, cond):
# Adaptive modulation
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(cond).chunk(6, dim=1)
# Attention residual
x = x + gate_msa.unsqueeze(1) * self.attention(
modulate(self.norm1(x), shift_msa, scale_msa)
)
# FFN residual
x = x + gate_mlp.unsqueeze(1) * self.ffn(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
return x
def attention(self, x):
B, S, D = x.shape
qkv = self.qkv(x).reshape(B, S, 3, self.num_heads, D // self.num_heads)
q, k, v = qkv.unbind(2)
output = attention(
q, k, v,
attention_config=self.attention_config,
input_layout="BSND",
output_layout="BSND",
)
return self.proj(output.flatten(2))
Example: Using Quantized Layers¶
import torch
import torch.nn as nn
from telefuser.ops.quantized_linear import LinearFP8, replace_linear_layers
class MyQuantizedModel(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.linear1 = nn.Linear(dim, dim * 4)
self.linear2 = nn.Linear(dim * 4, dim)
def forward(self, x):
x = self.linear1(x)
x = torch.nn.functional.gelu(x)
x = self.linear2(x)
return x
# Create and quantize model
model = MyQuantizedModel(dim=1024).cuda()
replace_linear_layers(model, quant_type=torch.float8_e4m3fn)
Performance Optimization Tips¶
1. Choose the Right Attention Backend¶
| GPU | Recommended Backend |
|---|---|
| H100/B100+ | FLASH_ATTN_4 or SAGE_ATTN_2_8_8_SM90 |
| A100/RTX 4090 | FLASH_ATTN_2 or SAGE_ATTN_2_8_16 |
| Other CUDA GPUs | TORCH_SDPA or FLASH_ATTN_2 |
2. Use FP8 Quantization to Reduce Memory¶
# For large model inference
replace_linear_layers(model, quant_type=torch.float8_e4m3fn)
model = convert_params_to_buffers(model)
3. Use Sparse Attention for Long Videos¶
# Radial attention can reduce 50%+ memory
config = AttentionConfig.radial_attention(
dense_timesteps=40, # Use dense attention for early timesteps
decay_factor=1.0,
)
4. Use Long Context Attention for Distributed Training¶
# Ulysses sequence parallelism
from telefuser.distributed import create_device_mesh_from_config
from telefuser.core.config import ParallelConfig
config = ParallelConfig(device_ids=[0, 1, 2, 3], sp_ulysses_degree=4)
device_mesh = create_device_mesh_from_config(config)
# Enable in model
dit.enable_usp()
Related Documentation¶
- Attention Implementation Guide - Detailed attention module documentation
- Adding New Models - How to add new models
- Parallel Processing - Distributed training guide