Skip to content

feat: add rolling cache turn pruning for ChatSampler (Issue #675)#703

Open
Paramveersingh-S wants to merge 1 commit into
google-deepmind:mainfrom
Paramveersingh-S:feat/chat-sampler-rolling-cache
Open

feat: add rolling cache turn pruning for ChatSampler (Issue #675)#703
Paramveersingh-S wants to merge 1 commit into
google-deepmind:mainfrom
Paramveersingh-S:feat/chat-sampler-rolling-cache

Conversation

@Paramveersingh-S

Copy link
Copy Markdown

Description

This PR addresses the Context Exhaustion issue outlined in Issue #675, specifically focusing on the ChatSampler crashing when long multi-turn conversations exceed the static 4096-token cache_length.

Since JAX arrays are statically compiled and dynamic jnp.roll sliding-window operations introduce significant compilation and latency overheads, this PR solves the issue at the orchestration layer by implementing Context Window Management (Turn Pruning) directly inside gemma/gm/text/_chat_sampler.py.

Key Changes

  • Automated Context Pruning: Added a _prune_context_to_fit mechanism to the ChatSampler.chat method. Before triggering the SamplerLoop, it calculates if used_cache + new_prompt_tokens + max_out_length > cache_length.
  • Eviction Strategy: If the context overflows, the sampler strategically pops the oldest User/Model conversation turn pair from self.turns while explicitly preserving the initial System prompt (if present).
  • Media History Tracking: Introduced history_images and history_audio properties to the ChatSampler. This ensures that when the context is pruned, the sampler can safely flush the static last_state KV Cache and execute a full re-prefill using the dynamically retained multimodal history without dropping user-provided media from active turns.
  • Unit Testing: Added _chat_sampler_test.py to statically verify the eviction constraints.

Fixes #675.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: Context Exhaustion and VRAM Spikes in KV Cache & SamplerLoop

1 participant