Skip to content

CPU Offloading

CPU offloading is a memory optimization technique that reduces GPU VRAM usage by temporarily moving model weights from GPU to CPU memory during inference. TeleFuser provides multiple offloading strategies to balance memory usage and inference speed.

Offloading Strategies

TeleFuser supports four offloading strategies, configured via WeightOffloadType:

Strategy Description Use Case
NO_CPU_OFFLOAD No offloading, all weights stay in GPU GPU memory is sufficient
MODEL_CPU_OFFLOAD Offload entire model to CPU between stages Moderate memory constraints
SEQUENTIAL_CPU_OFFLOAD Layer-by-layer offloading during forward pass Severe memory constraints
ASYNC_CPU_OFFLOAD Asynchronous layerwise offloading with prefetching Best balance of speed and memory

Async CPU Offload

ASYNC_CPU_OFFLOAD is the recommended strategy for most scenarios. It uses AsyncOffloadManager to:

  • Offload per-layer weights from GPU to pinned CPU memory
  • Prefetch upcoming layers asynchronously using a dedicated CUDA stream
  • Overlap data transfer with computation for minimal latency

How Async Offload Works

Time ──────────────────────────────────────────────►

Layer 0: [Load]──[Compute]────────────────────────────
Layer 1:      [Async Load]──[Compute]─────────────────
Layer 2:           [Async Load]──[Compute]────────────
Layer 3:                [Async Load]──[Compute]───────

Data transfer (load) overlaps with computation, hiding latency

Key Parameters

Parameter Type Default Description
offload_type WeightOffloadType NO_CPU_OFFLOAD Offloading strategy
pin_cpu_memory bool True Use pinned memory for faster H2D transfer
offload_ratio float 1.0 Ratio of layers to offload (1.0 = all layers)
prefetch_size int 1 Number of layers to prefetch ahead
lazy_gpu_cache bool False Delay GPU buffer allocation until first use

Lazy GPU Cache

The lazy_gpu_cache parameter (added in recent versions) controls whether GPU buffers are pre-allocated during initialization:

  • lazy_gpu_cache=False (default): GPU buffer pool is allocated during initialization
  • lazy_gpu_cache=True: GPU buffer pool is allocated on first use (saves VRAM during initialization)

Use lazy_gpu_cache=True when: - GPU memory is extremely limited during pipeline initialization - You want to defer VRAM allocation until inference starts

Use allocate_gpu_cache() and cleanup_gpu_cache() for manual control:

# Example: Manual GPU cache management
from telefuser.offload.async_offload import AsyncOffloadManager

# Initialize with lazy_gpu_cache=True
manager = AsyncOffloadManager(layers, lazy_gpu_cache=True)

# Manually allocate when ready
manager.allocate_gpu_cache()

# After inference, release cache to free VRAM
manager.cleanup_gpu_cache()

Usage in Pipelines

Basic Configuration

from telefuser.core.config import OffloadConfig, WeightOffloadType
from telefuser.pipelines.wan_video.wan21_video import Wan21VideoPipelineConfig

# Create pipeline configuration
pipe_config = Wan21VideoPipelineConfig()

# Enable async offload for DiT (most memory-intensive component)
pipe_config.dit_config.offload_config = OffloadConfig(
    offload_type=WeightOffloadType.ASYNC_CPU_OFFLOAD,
    pin_cpu_memory=True,
    prefetch_size=1,
)

# Optionally enable offload for other stages
pipe_config.vae_config.offload_config.offload_type = WeightOffloadType.MODEL_CPU_OFFLOAD
pipe_config.text_encoding_config.offload_config.offload_type = WeightOffloadType.MODEL_CPU_OFFLOAD

WanVideo Example

Complete example for Wan2.1 video generation with CPU offloading:

import torch
from telefuser.core.config import (
    AttentionConfig,
    AttnImplType,
    OffloadConfig,
    WeightOffloadType,
)
from telefuser.core.module_manager import ModuleManager
from telefuser.pipelines.wan_video.wan21_video import (
    Wan21VideoPipeline,
    Wan21VideoPipelineConfig,
)

def get_pipeline(model_root, parallelism=1):
    """Initialize Wan2.1 pipeline with CPU offloading."""

    # Load models to CPU first
    module_manager = ModuleManager(device="cpu")
    module_manager.load_models(
        [f"{model_root}/Wan2.1_VAE.pth"],
        torch_dtype=torch.bfloat16,
    )
    module_manager.load_models(
        [[f"{model_root}/diffusion_pytorch_model.safetensors"]],
        torch_dtype=torch.bfloat16,
    )
    module_manager.load_models(
        [f"{model_root}/models_t5_umt5-xxl-enc-bf16.pth"],
        torch_dtype=torch.bfloat16,
    )

    # Create pipeline
    pipe = Wan21VideoPipeline(device="cuda", torch_dtype=torch.bfloat16)
    pipe_config = Wan21VideoPipelineConfig()

    # Configure attention
    pipe_config.dit_config.attention_config = AttentionConfig.dense_attention(
        AttnImplType.SAGE_ATTN_2_8_8
    )

    # Configure offloading for different stages
    # DiT: Use async layerwise offload (best for large transformer)
    pipe_config.dit_config.offload_config = OffloadConfig(
        offload_type=WeightOffloadType.ASYNC_CPU_OFFLOAD,
        pin_cpu_memory=True,
        offload_ratio=1.0,
        prefetch_size=1,
    )

    # VAE: Use model-level offload (simpler, less frequent transfer)
    pipe_config.vae_config.offload_config.offload_type = WeightOffloadType.MODEL_CPU_OFFLOAD

    # Text encoder: Use model-level offload
    pipe_config.text_encoding_config.offload_config.offload_type = WeightOffloadType.MODEL_CPU_OFFLOAD

    # Optional: Enable distributed inference
    if parallelism > 1:
        pipe_config.dit_config.parallel_config.device_ids = list(range(parallelism))
        pipe_config.dit_config.parallel_config.sp_ulysses_degree = 2
        pipe_config.enable_denoising_parallel = True

    # Initialize pipeline
    pipe.init(module_manager, pipe_config)
    return pipe

# Usage
model_root = "/path/to/Wan2.1-T2V-1.3B"
pipe = get_pipeline(model_root, parallelism=1)

# Generate video
video = pipe(
    prompt="A cat playing piano",
    num_inference_steps=40,
    num_frames=81,
    height=480,
    width=832,
)

Large Model Example (14B+)

For large models like Wan2.1-14B, offloading is essential:

# Configuration for Wan2.1-14B (720P)
pipe_config = Wan21VideoPipelineConfig()

# Use async offload with larger prefetch for better overlap
pipe_config.dit_config.offload_config = OffloadConfig(
    offload_type=WeightOffloadType.ASYNC_CPU_OFFLOAD,
    pin_cpu_memory=True,
    prefetch_size=2,  # Prefetch 2 layers ahead
    offload_ratio=1.0,
)

# Enable offloading for all auxiliary models
pipe_config.clip_config.offload_config.offload_type = WeightOffloadType.MODEL_CPU_OFFLOAD
pipe_config.vae_config.offload_config.offload_type = WeightOffloadType.MODEL_CPU_OFFLOAD
pipe_config.text_encoding_config.offload_config.offload_type = WeightOffloadType.MODEL_CPU_OFFLOAD

Performance Considerations

Memory vs Speed Trade-off

Strategy VRAM Savings Speed Impact Recommendation
NO_CPU_OFFLOAD None Fastest 24GB+ VRAM
MODEL_CPU_OFFLOAD High (~50%) Moderate 16-24GB VRAM
ASYNC_CPU_OFFLOAD High (~60-70%) Low 8-16GB VRAM
SEQUENTIAL_CPU_OFFLOAD Maximum Slowest <8GB VRAM

Tuning Prefetch Size

The prefetch_size parameter affects the overlap between data transfer and computation:

  • prefetch_size=1: Default, balanced for most models
  • prefetch_size=2+: Better overlap for larger layers, but more VRAM usage
# For very large layers (e.g., 14B models)
pipe_config.dit_config.offload_config.prefetch_size = 2

Pinned Memory

Setting pin_cpu_memory=True (default) uses page-locked memory for faster H2D transfers:

  • Enabled: Faster transfer, slightly more CPU memory usage
  • Disabled: Slower transfer, less CPU memory usage

Troubleshooting

Out of Memory During Initialization

If GPU OOM occurs during pipeline initialization:

# Use lazy_gpu_cache to defer buffer allocation
pipe_config.dit_config.offload_config = OffloadConfig(
    offload_type=WeightOffloadType.ASYNC_CPU_OFFLOAD,
    lazy_gpu_cache=True,  # Delay GPU buffer allocation
)

Slow Inference

If offloading causes significant slowdown:

  1. Increase prefetch size for better overlap
  2. Reduce offload_ratio to keep more layers resident
  3. Check CPU-GPU interconnect (PCIe bandwidth matters)
# Keep 20% of layers resident in GPU
pipe_config.dit_config.offload_config.offload_ratio = 0.8

CPU Memory Issues

If CPU memory is insufficient:

# Disable pinned memory
pipe_config.dit_config.offload_config.pin_cpu_memory = False

API Reference

OffloadConfig

@dataclass
class OffloadConfig:
    offload_type: WeightOffloadType = WeightOffloadType.NO_CPU_OFFLOAD
    pin_cpu_memory: bool = True
    offload_ratio: float = 1.0
    prefetch_size: int = 1

AsyncOffloadManager

class AsyncOffloadManager:
    def __init__(
        self,
        layers: torch.nn.ModuleList,
        device: torch.device | None = None,
        *,
        enabled: bool = True,
        pin_cpu_memory: bool = True,
        offload_ratio: float = 1,
        prefetch_size: int = 1,
        lazy_gpu_cache: bool = False,
    ) -> None

    def allocate_gpu_cache(self) -> None:
        """Manually allocate GPU cache."""

    def cleanup_gpu_cache(self) -> None:
        """Release GPU cache."""

    def disable_offload(self) -> None:
        """Disable offloading and load all layers."""

    def enable_offload(self) -> None:
        """Re-enable offloading."""

Sequential CPU Offload

For scenarios requiring fine-grained VRAM management, TeleFuser provides enable_sequential_cpu_offload - a layer-by-layer offloading mechanism that wraps individual modules with smart state management.

Three-State System

Each wrapped module operates in one of three states:

State Value Data Location Description
Offload 0 offload_device (usually CPU) Default state, minimal VRAM usage
Onload 1 onload_device (usually GPU) Loaded but may use different dtype
Keep 2 computation_device (GPU) Pinned in GPU for repeated use

State Transition Flow

┌─────────────────────────────────────────────────────────────────┐
│                      Forward Pass Flow                           │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  if state == 2 (Keep):                                          │
│      → Use weights directly (fastest)                           │
│                                                                  │
│  elif onload config == computation config:                      │
│      → Use weights directly (no conversion needed)              │
│                                                                  │
│  elif vram_limit is set and GPU has free memory:                │
│      → Call keep() to promote to state 2                        │
│      → Use weights directly                                     │
│                                                                  │
│  else:                                                          │
│      → cast_to() temporary copy to GPU                          │
│      → Compute and release (state unchanged)                    │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

Usage

from telefuser.offload import enable_sequential_cpu_offload, AutoWrappedLinear

# Define which modules to wrap
module_map = {
    torch.nn.Linear: AutoWrappedLinear,
}

# Configure dtype and device for each state
module_config = {
    "offload_dtype": torch.float32,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}

# Enable sequential offloading
enable_sequential_cpu_offload(
    model,
    module_map=module_map,
    module_config=module_config,
    vram_limit=20.0,  # GB - promotes to Keep state when VRAM available
)

Parameters

Parameter Type Default Description
model nn.Module - Model to enable offloading on
module_map dict - Mapping from source module type to wrapper class
module_config dict - Configuration for dtype/device in each state
max_num_param int/None None Parameter threshold for using overflow config
overflow_module_config dict/None None Alternative config for layers exceeding threshold
vram_limit float/None None VRAM limit (GB) for automatic state promotion

Module Configuration

The module_config dictionary controls data placement:

module_config = {
    # Offload state (state=0) - minimal VRAM
    "offload_dtype": torch.float32,    # FP32 for CPU storage
    "offload_device": "cpu",            # Keep on CPU

    # Onload state (state=1) - ready for use
    "onload_dtype": torch.bfloat16,     # Lower precision for GPU
    "onload_device": "cuda",            # Load to GPU

    # Computation state (state=2) - actual computation
    "computation_dtype": torch.bfloat16,  # Must match onload for promotion
    "computation_device": "cuda",         # Must be GPU
}

Available Wrappers

Wrapper Source Module Description
AutoWrappedModule nn.Module Generic wrapper for any module
AutoWrappedLinear nn.Linear Optimized Linear layer with LoRA support
WanAutoCastLayerNorm nn.LayerNorm LayerNorm with automatic mixed precision

Layer-wise Configuration

Different configurations for different parameter sizes:

# Standard config for most layers
base_config = {
    "offload_dtype": torch.float32,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}

# Config for large layers (always stay on CPU)
overflow_config = {
    "offload_dtype": torch.float32,
    "offload_device": "cpu",
    "onload_dtype": torch.float32,
    "onload_device": "cpu",  # Never load to GPU
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}

enable_sequential_cpu_offload(
    model,
    module_map={nn.Linear: AutoWrappedLinear},
    module_config=base_config,
    max_num_param=1_000_000_000,  # 1B parameter threshold
    overflow_module_config=overflow_config,
    vram_limit=22.0,
)

Manual State Control

After enabling, you can manually control module states:

# Manual state transitions
for module in model.modules():
    if hasattr(module, 'offload'):
        module.offload()   # Force to state 0 (CPU)
        module.onload()    # Force to state 1 (onload device)
        module.keep()      # Force to state 2 (GPU)

# Check current state
if hasattr(module, 'state'):
    print(module.state)  # 0=offload, 1=onload, 2=keep

vram_limit Behavior

The vram_limit parameter controls automatic state promotion:

Setting Behavior
None (default) Conservative mode - never promotes to Keep state, always uses temporary cast
20.0 When VRAM usage < 20GB, promotes frequently used modules to Keep state

Recommendation: Always set vram_limit for production use to improve performance.

API Reference

def enable_sequential_cpu_offload(
    model: torch.nn.Module,
    module_map: dict,
    module_config: dict,
    max_num_param: int | None = None,
    overflow_module_config: dict | None = None,
    vram_limit: float | None = None,
) -> None

class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
    def __init__(
        self,
        module: torch.nn.Linear,
        offload_dtype,
        offload_device,
        onload_dtype,
        onload_device,
        computation_dtype,
        computation_device,
        vram_limit,
        name: str = "",
    )

    def offload(self) -> None:   # Switch to state 0
    def onload(self) -> None:    # Switch to state 1
    def keep(self) -> None:      # Switch to state 2

References

  • The async offloading implementation is adapted from SGLang's layerwise offload utility.
  • The sequential CPU offloading implementation is adapted from DiffSynth-Studio.