Latent Cache (Cross-Request Approximate Cache for Diffusion)
Latent cache reuses the intermediate latent from a previous inference when an incoming prompt is similar enough to a prompt already served, so the first N denoising steps can be skipped. The implementation lives in telefuser/cache_mem/ and is plugged into the FastAPI service through telefuser/service/cache/.
Latent Cache vs. Feature Cache
The two solve different problems:
| | Feature cache (see feature_cache.md) | Latent cache (this doc) |
|---|---|---|
| Granularity | Within a single inference, across timesteps | Across requests |
| Reuse key | Step index | Prompt embedding similarity |
| Acceleration | Skip approximable blocks | Skip the first N denoising steps |
| Module | telefuser/feature_cache/ | telefuser/cache_mem/ |
| Persistence | None (request lifetime only) | KV on disk / distributed store + vector DB + metadata |
Use feature cache to speed up one inference; use latent cache to speed up the next inference whose prompt is similar to a cached one. The two can be enabled at the same time without interfering.
Base Interface
Latent cache exposes two layers of interfaces: LatentCache is the pipeline / service-facing facade on top, and BaseCacheStrategy is the abstract base class underneath that decides hit logic and save behavior.
class BaseCacheStrategy(ABC):
    @abstractmethod
    async def lookup(self, prompt: str, task_type: str) -> CacheResult:
        """Query the cache and return whether a hit occurred and the cached latent state."""
        pass

    async def apply(self, result: CacheResult) -> CacheResult:
        """Post-process the lookup result (e.g. load the latent)."""
        return result

    @abstractmethod
    async def save(
        self,
        prompt: str,
        latent_states_dict: Dict[int, torch.Tensor],
        num_frames: int,
        task_type: str,
        saved_steps: List[int],
        embedding_video_frames: Optional[List[Any]] = None,
    ) -> None:
        """Write back the intermediate latent of this inference to the cache."""
        pass


class LatentCache:
    async def lookup(self, task_request: Any) -> CacheResult: ...

    async def save(
        self,
        task_request: Any,
        latent_states_dict: Dict[int, torch.Tensor],
        num_frames: int,
        final_step: int,
        saved_steps: List[int],
        embedding_video_frames: Optional[List[Any]] = None,
    ) -> None: ...

    def shutdown(self) -> None: ...

    def purge_by_prompt(self, prompt: str, collection: str = "whole") -> bool: ...
Key fields of CacheResult:
| Field | Type | Meaning |
|---|---|---|
| hit | bool | Whether a hit occurred |
| skip_step | int | Step to restart denoising from on hit (>0 means skipping the first N steps) |
| cache_type | str | Hit cache type, e.g. approximate_cache |
| similarity | float | Vector search / rerank score |
| latent_state | Tensor \| None | Cached latent tensor returned on hit |
| cached_prompt | str | Original prompt of the hit entry |
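The fields above map naturally onto a small dataclass. The following is a minimal sketch, not the real class: `CacheResultSketch`, its default values, and the `should_resume` helper are hypothetical, and `latent_state` is typed as `Any` so the example runs without torch.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class CacheResultSketch:
    """Hypothetical stand-in for CacheResult; the defaults are assumptions."""

    hit: bool = False
    skip_step: int = 0
    cache_type: str = ""
    similarity: float = 0.0
    latent_state: Optional[Any] = None  # a torch.Tensor in the real class
    cached_prompt: str = ""


def should_resume(result: CacheResultSketch) -> bool:
    """Resume from the cache only on a hit that carries a latent and skips at least one step."""
    return result.hit and result.latent_state is not None and result.skip_step > 0


miss = CacheResultSketch()
hit = CacheResultSketch(
    hit=True,
    skip_step=5,
    cache_type="approximate_cache",
    similarity=0.93,
    latent_state=[0.0],
    cached_prompt="a cat",
)
```

A caller would branch on `should_resume(result)` rather than on `hit` alone, so a hit without a usable latent still falls back to full denoising.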
Use in Model Forward
The pipeline forwards the latent_data injected by the service layer down to the denoise stage. The denoising loop uses skip_step to decide where to start, and snapshots intermediate latents at the steps listed in saved_steps:
# In the denoise stage (see telefuser/pipelines/wan_video/moe_dit_denoising.py)
cached_latent, effective_start_step, saved_steps = parse_latent_data(
    latent_data,
    expected_shape=tuple(latents.shape),
    total_steps=total_steps,
)
if cached_latent is not None:
    latents = cached_latent.to(device=latents.device, dtype=latents.dtype)

saved_steps_set = frozenset(saved_steps)
latent_states_dict: dict[int, torch.Tensor] = {}
for progress_id, timestep in enumerate(timesteps[effective_start_step:]):
    absolute_step = effective_start_step + progress_id
    # snapshot BEFORE scheduler.step: step k stores the latent that enters step k
    if absolute_step in saved_steps_set:
        latent_states_dict[absolute_step] = latents.detach().cpu()
    noise_pred = self.predict_noise_with_cfg(...)
    latents = self.scheduler.step(noise_pred, timesteps[absolute_step], latents)

# pipeline returns the payload alongside the latent so the service layer
# can write it back asynchronously
latent_payload = {
    "latent_states_dict": latent_states_dict,
    "saved_steps": saved_steps,
    "final_step": total_steps - 1,
}
return latents, latent_payload
parse_latent_data (telefuser/pipelines/wan_video/latent_data_utils.py) performs shape and range validation: if the shape mismatches or skip_step is out of range, the cache is silently dropped and the pipeline falls back to full denoising, so the main path is never poisoned by a bad cache entry.
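The fallback behaviour can be illustrated with a small stand-alone function. This is a sketch under assumptions, not the code in latent_data_utils.py: `parse_latent_data_sketch` and `FakeLatent` are hypothetical names, the dict keys follow the `latent_data` fields named in this doc, and filtering `saved_steps` to the valid range is an assumed detail.

```python
from typing import Any, Dict, List, Optional, Tuple


def parse_latent_data_sketch(
    latent_data: Optional[Dict[str, Any]],
    expected_shape: Tuple[int, ...],
    total_steps: int,
) -> Tuple[Optional[Any], int, List[int]]:
    """Return (cached_latent, start_step, saved_steps); fall back to step 0 on bad input."""
    if not latent_data:
        return None, 0, []
    # Assumption: snapshot steps outside this run's range are ignored.
    saved_steps = [s for s in latent_data.get("saved_steps", []) if 0 <= s < total_steps]
    cached = latent_data.get("cached_latent")
    skip_step = latent_data.get("skip_step", 0)
    if cached is None:
        return None, 0, saved_steps
    shape_ok = tuple(getattr(cached, "shape", ())) == tuple(expected_shape)
    step_ok = isinstance(skip_step, int) and 0 < skip_step < total_steps
    if not (shape_ok and step_ok):
        return None, 0, saved_steps  # drop the cache entry and run full denoising
    return cached, skip_step, saved_steps


class FakeLatent:
    """Stand-in for a tensor: only .shape is needed here."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape


good = FakeLatent((1, 16, 8, 8))
ok = parse_latent_data_sketch(
    {"cached_latent": good, "skip_step": 5, "saved_steps": [5, 10]}, (1, 16, 8, 8), 30
)
bad_shape = parse_latent_data_sketch(
    {"cached_latent": good, "skip_step": 5, "saved_steps": [5, 10]}, (1, 16, 4, 4), 30
)
bad_step = parse_latent_data_sketch(
    {"cached_latent": good, "skip_step": 40, "saved_steps": [5, 10]}, (1, 16, 8, 8), 30
)
```

In both failure cases the function still returns `saved_steps`, so a run that could not reuse a cached latent can still contribute new snapshots.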
Factory Function
The production path does not construct LatentCache directly. Instead, CacheServiceFactory builds a CacheService from CLI arguments and the CACHE_CONFIG declared in the pipeline file:
from telefuser.service.cache import CacheServiceFactory

cache_service = CacheServiceFactory.create_cache_service(
    ppl_file="examples/wan_video/wan22_14b_text_to_video_service.py",
    enable_latent_cache=True,
    cache_mode="read_write",  # "read_write" / "read_only" / "write_only"
)
create_cache_service does the following internally:
- Loads CACHE_CONFIG (a dict or a CacheConfig instance) from ppl_file as the base configuration.
- Overrides the final config with the CLI's enable_latent_cache / cache_mode.
- Initializes the cache log sink up front.
- Loads build_latent_data from ppl_file (must exist, otherwise it raises an error).
- Instantiates LatentCache(cache_dir, config) and wraps it inside CacheService.
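The override step can be pictured as a shallow merge in which CLI values win. The sketch below is an assumption about the shape of that logic, not the factory's code: `merge_cache_config` is a hypothetical helper, and treating `None` as "flag not passed" is an assumed convention.

```python
from typing import Any, Dict, Optional


def merge_cache_config(
    base: Dict[str, Any],
    enable_latent_cache: Optional[bool] = None,
    cache_mode: Optional[str] = None,
) -> Dict[str, Any]:
    """Copy the pipeline's CACHE_CONFIG and let explicitly passed CLI values win."""
    merged = dict(base)  # never mutate the dict declared in the pipeline file
    if enable_latent_cache is not None:
        merged["enable_latent_cache"] = enable_latent_cache
    if cache_mode is not None:
        merged["cache_mode"] = cache_mode
    return merged


pipeline_config = {"enable_latent_cache": True, "cache_mode": "write_only", "vector_dim": 2048}
effective = merge_cache_config(pipeline_config, cache_mode="read_write")
```

Here the pipeline file asks for `write_only`, but the service ends up in `read_write` because the CLI value takes precedence; every other field is inherited unchanged.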
Manual construction is also supported when needed:
from pathlib import Path

from telefuser.cache_mem.config import CacheConfig
from telefuser.cache_mem.latent_cache import LatentCache

config = CacheConfig(
    enable_latent_cache=True,
    latent_cache_dir="./latent_cache/wan22_t2v",
    cache_strategy_type="video_approximate",
    vector_dim=2048,
)
cache = LatentCache(Path(config.latent_cache_dir), config)
The strategy class is looked up in the registry via cache_strategy_type:
from telefuser.cache_mem.strategies import register_strategy, get_strategy_class
register_strategy("video_approximate", VideoBasedApproximateCache) # already registered by default
strategy_cls = get_strategy_class("video_approximate")
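A registry of this kind is usually a dictionary from name to class. The sketch below reuses the function names from the snippet above, but the bodies, the error message, and the `ExactPromptCache` class are assumptions made for illustration, not the contents of telefuser.cache_mem.strategies.

```python
from typing import Dict, Type

_STRATEGIES: Dict[str, Type] = {}


def register_strategy(name: str, cls: Type) -> None:
    """Map a cache_strategy_type string to a strategy class."""
    _STRATEGIES[name] = cls


def get_strategy_class(name: str) -> Type:
    """Resolve a strategy class, failing loudly on unknown names."""
    try:
        return _STRATEGIES[name]
    except KeyError:
        known = ", ".join(sorted(_STRATEGIES)) or "<none>"
        raise ValueError(f"unknown cache strategy {name!r}; registered: {known}") from None


class ExactPromptCache:  # hypothetical custom strategy, for illustration only
    pass


register_strategy("exact_prompt", ExactPromptCache)
```

With a registry like this, a custom strategy only needs to be registered before the config's cache_strategy_type is resolved.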
VideoBasedApproximateCache
The only production strategy implementation is VideoBasedApproximateCache, which combines:
- Prompt encoding: Qwen3-VL-Embedding encodes the prompt into a vector that is written to the vector store.
- Video encoding: during save, several frames of the generated video are encoded into the same vector space, used as the similarity basis for future hits.
- Optional rerank: when rerank_enabled is on, Qwen3-VL-Reranker performs cross-encoder reranking over the top-k candidates.
- Shared backend: when text and video embedding configs end up loading the same model on the same device, the two automatically share a single Qwen3VLEncoder instance, saving roughly 5 GB of GPU memory and one cold load.
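The shared-backend behaviour amounts to memoizing the encoder by model path and device. A minimal sketch, assuming that key: `EncoderPool` and `fake_loader` are hypothetical, and the loader returns a plain object instead of a real Qwen3VLEncoder.

```python
from typing import Callable, Dict, List, Tuple


class EncoderPool:
    """Hand out one encoder per (model_path, device_id) pair."""

    def __init__(self, loader: Callable[[str, int], object]) -> None:
        self._loader = loader
        self._loaded: Dict[Tuple[str, int], object] = {}

    def get(self, model_path: str, device_id: int) -> object:
        key = (model_path, device_id)
        if key not in self._loaded:
            self._loaded[key] = self._loader(model_path, device_id)  # cold load happens once
        return self._loaded[key]


load_calls: List[Tuple[str, int]] = []


def fake_loader(model_path: str, device_id: int) -> object:
    """Stand-in for loading model weights; records each cold load."""
    load_calls.append((model_path, device_id))
    return object()


pool = EncoderPool(fake_loader)
text_encoder = pool.get("Qwen/Qwen3-VL-Embedding-2B", 1)
video_encoder = pool.get("Qwen/Qwen3-VL-Embedding-2B", 1)
other_device = pool.get("Qwen/Qwen3-VL-Embedding-2B", 0)
```

The text and video encoders on device 1 are the same object, while the request for device 0 triggers a second load. This is why the example config sets text_embedding_device_id and video_embedding_device_id to the same value.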
How VideoBasedApproximateCache Works
Write Path
When a request finishes, the pipeline hands its latent_payload (containing the per-step latents plus video frames used for prompt similarity) to CacheService.save_latent_payload, which enqueues it onto the cache-save-worker background thread. The thread invokes LatentCache.save:
- Writes each step's latent to the KV store under a key shaped like f"{cache_id}_step{step}".
- Encodes the video frames with Qwen3-VL-Embedding and upserts the vector into the vector store (default collection name video).
- Registers cache_id → {prompt, saved_steps, size_mb, …} in metadata, persisting prompt_index.json and cache_meta.json.
If any step fails, all the latents / vectors / metadata that were already written are rolled back cleanly to avoid an inconsistent state.
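The write-then-rollback sequence can be sketched with three in-memory dictionaries standing in for the KV store, the vector store, and the metadata manager. `save_entry` and its parameters are hypothetical; only the key format comes from the description above, and a `None` latent is used here just to force a failure midway.

```python
from typing import Any, Dict, List


def save_entry(
    cache_id: str,
    prompt: str,
    latents: Dict[int, Any],
    vector: List[float],
    kv: Dict[str, Any],
    vectors: Dict[str, List[float]],
    metadata: Dict[str, Dict[str, Any]],
) -> None:
    """Write latents, vector and metadata; undo everything if any step raises."""
    written_keys: List[str] = []
    try:
        for step, latent in latents.items():
            if latent is None:
                raise ValueError(f"missing latent for step {step}")
            key = f"{cache_id}_step{step}"
            kv[key] = latent
            written_keys.append(key)
        vectors[cache_id] = vector
        metadata[cache_id] = {"prompt": prompt, "saved_steps": sorted(latents)}
    except Exception:
        # Roll back whatever was already written for this cache_id.
        for key in written_keys:
            kv.pop(key, None)
        vectors.pop(cache_id, None)
        metadata.pop(cache_id, None)
        raise


kv: Dict[str, Any] = {}
vectors: Dict[str, List[float]] = {}
metadata: Dict[str, Dict[str, Any]] = {}
save_entry("c1", "a red fox", {5: "L5", 10: "L10"}, [0.1, 0.2], kv, vectors, metadata)
try:
    save_entry("c2", "a blue fox", {5: "L5", 10: None}, [0.3, 0.4], kv, vectors, metadata)
except ValueError:
    pass  # c2 failed halfway; its step-5 latent has been removed again
```

After the failed second save, the stores contain only the entries for `c1`, which is the consistency property the rollback is meant to guarantee.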
Hit Path
When a new request arrives, CacheService.build_latent_data:
- Waits on vector_update_idle to make sure the vector upsert from the previous async save has been committed.
- Calls LatentCache.lookup: encodes the new prompt, queries the top-k approximate caches in the vector store, optionally reranks with Qwen3-VL-Reranker, and compares against the threshold to decide on a hit.
- On a hit, loads the latent tensor for skip_step from the KV store and wraps it into a CacheResult.
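The search-and-threshold part of the lookup can be sketched without any vector database. The example below swaps FAISS for a brute-force cosine scan and leaves out the reranker entirely; `cosine`, `lookup`, and the toy two-dimensional vectors are all illustrative assumptions.

```python
import math
from typing import Dict, List, Optional, Tuple


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def lookup(
    query: List[float],
    index: Dict[str, List[float]],
    top_k: int,
    threshold: float,
) -> Optional[Tuple[str, float]]:
    """Return (cache_id, score) of the best candidate, or None on a miss."""
    scored = sorted(((cosine(query, vec), cid) for cid, vec in index.items()), reverse=True)
    candidates = scored[:top_k]  # a reranker would rescore these before thresholding
    if not candidates or candidates[0][0] < threshold:
        return None
    score, cache_id = candidates[0]
    return cache_id, score


index = {"c1": [1.0, 0.0], "c2": [0.0, 1.0]}
near = lookup([0.9, 0.1], index, top_k=5, threshold=0.8)
far = lookup([0.7, 0.7], index, top_k=5, threshold=0.8)
```

The first query is close to `c1` and clears the threshold; the second sits halfway between both entries, scores about 0.71 against each, and is treated as a miss.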
The latent_data dict the pipeline receives includes cached_latent, skip_step, and saved_steps. The pipeline restarts the denoise loop at skip_step and snapshots this run's latents according to saved_steps — that is how the cache keeps growing.
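The resume-and-snapshot mechanics can be seen in a toy loop where the latent is a single float and each step halves it. This is only an illustration of the control flow: `denoise` is hypothetical, and in the real approximate cache the resumed run uses a different prompt, so its output would not match the original run exactly.

```python
from typing import Dict, List, Optional, Tuple


def denoise(
    latent: float,
    total_steps: int,
    saved_steps: List[int],
    start_step: int = 0,
    cached_latent: Optional[float] = None,
) -> Tuple[float, Dict[int, float]]:
    """Toy loop: each step halves the latent; snapshots are taken before the step runs."""
    if cached_latent is not None:
        latent = cached_latent
    snapshots: Dict[int, float] = {}
    for step in range(start_step, total_steps):
        if step in saved_steps:
            snapshots[step] = latent  # the latent that enters this step
        latent = latent * 0.5  # stand-in for noise prediction + scheduler.step
    return latent, snapshots


full, full_snaps = denoise(1024.0, total_steps=10, saved_steps=[2, 5, 8])
resumed, resumed_snaps = denoise(
    0.0, total_steps=10, saved_steps=[2, 5, 8], start_step=5, cached_latent=full_snaps[5]
)
```

The resumed run executes only steps 5 through 9 and can only snapshot steps at or after its start step, so a cache hit contributes snapshots for steps 5 and 8 but not for step 2.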
Cache Parameters
The core parameters used by VideoBasedApproximateCache:
| Parameter | Type | Description |
|---|---|---|
| key_steps | list | Step list at which the pipeline is asked to snapshot |
| video_similarity_threshold | float | Lower bound for a vector-search hit |
| rerank_enabled | bool | Whether to rerank the top-k with Qwen3-VL-Reranker |
| rerank_top_k | int | Number of candidates fed into rerank |
| rerank_score_threshold | float | Lower bound for a hit when rerank is enabled |
| video_embedding_max_frames | int | Max frames sampled when encoding video |
| video_vector_collection | str | FAISS collection name (default video) |
Which step to restart from after a hit is decided by _determine_skip_step: in the current implementation, it skips to step 5 when similarity is above the rerank threshold and 5 ∈ saved_steps; otherwise the lookup is treated as a miss. Override this method in a subclass to customize the skip policy.
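The policy described above fits in a few lines. The sketch below is written as free functions rather than a method override, and everything beyond the "step 5 if similar enough and available" rule is an assumption: the signatures, returning `None` for a miss, and the tiered variant are hypothetical.

```python
from typing import List, Optional


def determine_skip_step(
    similarity: float,
    saved_steps: List[int],
    threshold: float,
    target_step: int = 5,
) -> Optional[int]:
    """Return the step to resume from, or None to treat the lookup as a miss."""
    if similarity > threshold and target_step in saved_steps:
        return target_step
    return None


def tiered_skip_step(similarity: float, saved_steps: List[int], threshold: float) -> Optional[int]:
    """Hypothetical custom policy: skip further when the match is stronger."""
    tiers = [(threshold + 0.05, 10), (threshold, 5)]
    for min_score, step in tiers:
        if similarity > min_score and step in saved_steps:
            return step
    return None
```

A tiered policy like the second function trades output fidelity for speed: a stronger match justifies reusing a later latent, at the cost of inheriting more of the cached video's structure.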
Using VideoBasedApproximateCache
Declaring CACHE_CONFIG in the pipeline file is enough to enable it (CacheServiceFactory picks it up automatically at service startup):
# examples/wan_video/wan22_14b_text_to_video_service.py
import os

CACHE_CONFIG = dict(
    enable_latent_cache=True,
    latent_cache_dir=os.getenv("TELEFUSER_LATENT_CACHE_DIR", "./latent_cache/wan22_t2v"),
    cache_mode="write_only",
    kv_store_type="local_file",
    vector_store_type="faiss",
    # Qwen3-VL-Embedding-2B hidden_size=2048; must match the vector_store dim.
    vector_dim=2048,
    key_steps=[5, 10, 15, 20, 25],
    video_embedding_enabled=True,
    video_embedding_model_path=os.getenv("QWEN3VL_EMBEDDING_PATH", ""),
    video_embedding_max_frames=16,
    text_embedding_device_id=1,
    video_embedding_device_id=1,
    video_vector_collection="video",
    rerank_enabled=True,
    rerank_model_path=os.getenv("QWEN3VL_RERANKER_PATH", "/storage/model_zoo/Qwen3-VL-Reranker-2B"),
    rerank_device_id=int(os.getenv("TELEFUSER_RERANK_DEVICE_ID", "0")),
    rerank_top_k=5,
    rerank_score_threshold=0.85,
)
The pipeline file also has to provide two hooks the service layer relies on to wire the cache into the main path:
- build_latent_data(task_data: dict, cache_result=None) -> dict: converts cache_result into the latent_data dict the pipeline expects (with hit / skip_step / cached_latent / saved_steps).
- run_with_file(pipeline, **task_data) -> dict: feeds latent_data into the pipeline and returns latent_payload as part of the result so the service layer can write it back to the cache.
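A build_latent_data hook could look roughly like the sketch below. It is a guess at the shape, not the example script's code: the module-level `KEY_STEPS` constant is assumed to mirror key_steps in CACHE_CONFIG, and `SimpleNamespace` stands in for a real CacheResult.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Optional

KEY_STEPS: List[int] = [5, 10, 15, 20, 25]  # assumed to mirror CACHE_CONFIG["key_steps"]


def build_latent_data(task_data: dict, cache_result: Optional[Any] = None) -> Dict[str, Any]:
    """Translate a lookup result into the dict the denoise stage reads."""
    latent_data: Dict[str, Any] = {
        "hit": False,
        "skip_step": 0,
        "cached_latent": None,
        "saved_steps": list(KEY_STEPS),
    }
    if cache_result is not None and cache_result.hit and cache_result.latent_state is not None:
        latent_data.update(
            hit=True,
            skip_step=cache_result.skip_step,
            cached_latent=cache_result.latent_state,
        )
    return latent_data


cold = build_latent_data({"prompt": "a cat"})
warm = build_latent_data(
    {"prompt": "a cat"},
    SimpleNamespace(hit=True, skip_step=5, latent_state="latent-at-step-5"),
)
```

Both the cold and the warm path carry `saved_steps`, so a request that misses the cache still tells the pipeline which steps to snapshot.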
Example Scripts
| Pipeline | Script | Notes |
|---|---|---|
| Wan2.2 14B T2V (cache enabled) | examples/wan_video/wan22_14b_text_to_video_service.py | Full latent cache configuration example |
| Wan2.2 14B T2V (cache disabled) | examples/wan_video/wan22_14b_text_to_video_service_nocache.py | Same interface, cache off, control group |
Start the service:
telefuser serve examples/wan_video/wan22_14b_text_to_video_service.py \
    --port 8000 \
    --enable-latent-cache \
    --cache-mode read_write
CacheConfig Field Reference
The full definition lives in telefuser/cache_mem/config.py; the tables below show the defaults.
Basic
| Field | Default | Description |
|---|---|---|
| enable_latent_cache | False | Master switch; toggled by CLI --enable-latent-cache. |
| cache_mode | READ_WRITE | One of READ_WRITE / READ_ONLY / WRITE_ONLY. |
| latent_cache_dir | ./latent_cache | Root directory for storage, metadata, FAISS, and logs. |
| max_cache_size_gb | 10 | Soft eviction cap (LRU by last_access_time). |
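LRU eviction against a size cap can be sketched as follows. The metadata layout (a dict per cache_id with size_mb and last_access_time) is inferred from the field names in this doc, and `evict_lru` is a hypothetical helper rather than the metadata manager's method.

```python
from typing import Dict, List


def evict_lru(entries: Dict[str, Dict[str, float]], max_size_mb: float) -> List[str]:
    """Drop least-recently-accessed entries until the total size fits the cap."""
    total = sum(meta["size_mb"] for meta in entries.values())
    evicted: List[str] = []
    # sorted() builds a separate list, so deleting from entries inside the loop is safe.
    for cache_id in sorted(entries, key=lambda cid: entries[cid]["last_access_time"]):
        if total <= max_size_mb:
            break
        total -= entries[cache_id]["size_mb"]
        del entries[cache_id]
        evicted.append(cache_id)
    return evicted


entries = {
    "a": {"size_mb": 400.0, "last_access_time": 100.0},
    "b": {"size_mb": 400.0, "last_access_time": 300.0},
    "c": {"size_mb": 400.0, "last_access_time": 200.0},
}
evicted = evict_lru(entries, max_size_mb=900.0)
```

With 1200 MB stored and a 900 MB cap, only the oldest entry is removed. A real implementation would also delete the matching latents and vectors for each evicted cache_id.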
Logging
| Field | Default |
|---|---|
| cache_log_enabled | True |
| cache_log_dir | {latent_cache_dir}/logs |
| cache_log_level | DEBUG |
| cache_log_rotation | 100 MB |
| cache_log_retention | 7 days |
KV / Vector Backend
| Field | Default | Description |
|---|---|---|
| kv_store_type | local_file | Or fluxon (stub). |
| vector_store_type | faiss | Or qdrant (stub). |
| vector_dim | 2048 | Must match the embedder output dim. |
| faiss_index_dir | {latent_cache_dir}/faiss | |
| qdrant_url / qdrant_api_key | "" / None | Configure once Qdrant is wired up for real. |
| cache_strategy_type | video_approximate | Key in the strategy registry. |
Strategy and Embedding
| Field | Default | Description |
|---|---|---|
| key_steps | [0, 1, 2, 3, 4, 5] | Step list at which the pipeline is asked to snapshot. |
| lookup_mode | video | |
| video_embedding_enabled | True | |
| video_embedding_model_path | Qwen/Qwen3-VL-Embedding-2B | |
| video_embedding_max_frames | 16 | |
| video_embedding_fps | 1.0 | |
| text_embedding_model_path | "" | Empty means reuse the video embedder. |
| video_similarity_threshold | 0.10 | Lower bound for a vector-search hit. |
| rerank_enabled | False | When on, rerank the top-k with Qwen3-VL-Reranker. |
| rerank_top_k | 5 | |
| rerank_score_threshold | 0.90 | Lower bound for a hit when rerank is enabled. |
When the text and video embedding configurations end up loading the same model on the same device, VideoBasedApproximateCache lets them share a single Qwen3VLEncoder instance, saving roughly 5 GB of GPU memory and one cold load.
Async Save
| Field | Default | Description |
|---|---|---|
| save_async_enabled | True | Offload save onto the worker thread. |
| save_queue_size | 2 | 0 means unbounded. |
| save_on_full | drop | drop / sync / downgrade (downgrade is TODO). |
| save_queue_warn_threshold | 8 | Log a warning when queue depth exceeds this value. |
| vector_wait_warn_s | 2.0 | Log a warning when lookup waits on the vector barrier longer than this. |
| vector_wait_timeout_s | 120.0 | Give up the barrier after timeout and treat as miss. |
| flush_on_shutdown | True | CacheService.shutdown drains the queue first. |
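The interaction between save_queue_size and save_on_full can be sketched with the standard library's bounded queue. `submit_save` and its return strings are hypothetical, the worker thread that would drain the queue is omitted, and the downgrade policy is left out because it is not implemented yet.

```python
import queue
from typing import Callable, List


def submit_save(
    save_queue: "queue.Queue[dict]",
    payload: dict,
    save_on_full: str,
    save_sync: Callable[[dict], None],
) -> str:
    """Enqueue a payload; on a full queue either drop it or save inline."""
    try:
        save_queue.put_nowait(payload)
        return "queued"
    except queue.Full:
        if save_on_full == "sync":
            save_sync(payload)  # blocks the request thread, but nothing is lost
            return "saved_sync"
        return "dropped"


saved_inline: List[dict] = []
q: "queue.Queue[dict]" = queue.Queue(maxsize=2)  # mirrors save_queue_size=2
outcomes = [submit_save(q, {"id": i}, "drop", saved_inline.append) for i in range(3)]
outcomes.append(submit_save(q, {"id": 3}, "sync", saved_inline.append))
```

With nothing draining the queue, the third payload is dropped under `drop`, while the fourth is saved inline under `sync`. `queue.Queue(maxsize=0)` is unbounded in the standard library, which lines up with the "0 means unbounded" convention in the table.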
The Three Cache Modes
| Mode | Effect |
|---|---|
| READ_WRITE | Lookups can hit, and new latents are written back. The default. |
| READ_ONLY | Lookups can hit, but the cache is not updated. Useful during canary rollouts. |
| WRITE_ONLY | Lookups always miss; the service only accumulates cache entries. Common when warming up a cache against a benchmark. |
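The three modes reduce to two independent permissions, read and write. A minimal sketch, assuming string values that match the CLI spellings: this `CacheMode` enum and its `can_read` / `can_write` properties are illustrative and may not match the enum in telefuser/cache_mem/config.py.

```python
from enum import Enum


class CacheMode(Enum):
    READ_WRITE = "read_write"
    READ_ONLY = "read_only"
    WRITE_ONLY = "write_only"

    @property
    def can_read(self) -> bool:
        """Lookups are attempted in every mode except WRITE_ONLY."""
        return self is not CacheMode.WRITE_ONLY

    @property
    def can_write(self) -> bool:
        """Saves are persisted in every mode except READ_ONLY."""
        return self is not CacheMode.READ_ONLY


mode = CacheMode("write_only")
```

A service would check `mode.can_read` before calling lookup and `mode.can_write` before enqueueing a save, which keeps the mode logic out of the strategy itself.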
Architecture Overview
┌────────────────────────────────────────────────────────┐
│ LatentCache (facade) │
│ │
│ ├─ Strategy VideoBasedApproximateCache │
│ │ ├─ prompt_encoder Qwen3-VL-Embedding │
│ │ ├─ video_encoder Qwen3-VL-Embedding (shared) │
│ │ └─ reranker Qwen3-VL-Reranker (optional) │
│ │ │
│ ├─ KVStore LocalFileKVStore | FluxonKVStore* │
│ ├─ VectorStore FAISSVectorStore | QdrantStore* │
│ └─ MetadataManager LocalCacheMetadataManager │
└──────────▲─────────────────────────────────────────────┘
│ via CacheService (async writeback wrapper)
│
FastAPI request thread / pipeline
The Fluxon / Qdrant backends marked with * are still stubs (they raise NotImplementedError); the production path only goes through LocalFileKVStore + FAISSVectorStore.
CacheService owns the background async writeback thread plus a barrier called vector_update_idle, which prevents a lookup from reading a stale index before the previous save finishes its vector upsert.
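A barrier like this can be built on `threading.Event`. The sketch below assumes a single writer, as with the one background save thread described above; `save_vectors` and `lookup_allowed` are hypothetical names, and the timeout is shortened so the example runs quickly.

```python
import threading
from typing import Callable, List

vector_update_idle = threading.Event()
vector_update_idle.set()  # idle until a save starts


def save_vectors(upsert: Callable[[], None]) -> None:
    """Writer side: mark the index busy while the upsert runs."""
    vector_update_idle.clear()
    try:
        upsert()
    finally:
        vector_update_idle.set()  # always release, even if the upsert raises


def lookup_allowed(timeout_s: float) -> bool:
    """Reader side: True once the index is idle, False if the wait timed out (treat as a miss)."""
    return vector_update_idle.wait(timeout=timeout_s)


observed_during_upsert: List[bool] = []
save_vectors(lambda: observed_during_upsert.append(lookup_allowed(0.01)))
after_save = lookup_allowed(0.01)
```

A lookup attempted while the upsert is in flight times out and is treated as a miss, while one issued after the save completes proceeds immediately. With several concurrent writers a single event would not be enough, and a counter or lock would be needed instead.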