
[Bug]: Context Exhaustion and VRAM Spikes in KV Cache & SamplerLoop #675

@prince-shakyaa

Description

A review of the gemma/gm/text module reveals two memory-management bottlenecks that limit inference performance and context scalability. They prevent the model from sustaining long multi-turn conversations and cause avoidable VRAM spikes during standard inference.

1. Lack of Rolling KV Cache in ChatSampler

Currently, ChatSampler uses a fixed KV cache size (cache_length = 4096). As a multi-turn session progresses, each new turn is appended to the cache and older tokens are never evicted.

  • Affected Component: gemma/gm/text/_chat_sampler.py (lines 125-126)
  • Impact: When the conversation context reaches the cache_length limit, the sampler hits a hard boundary, which leads to an out-of-memory (OOM) error or halts generation. This prevents the model from handling open-ended chat sessions or long document-processing tasks.
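A toy model of the append-only behavior described above may help: each turn's tokens (prompt plus reply) are appended to a fixed-size buffer and nothing is evicted, so the failure point depends only on the cumulative token count. This is an illustration, not ChatSampler's code; the function name and the per-turn token counts are hypothetical.

```python
CACHE_LENGTH = 4096  # cache_length value cited above for ChatSampler


def first_overflowing_turn(turn_lengths, cache_length=CACHE_LENGTH):
    """Return the 0-based index of the first turn that no longer fits in an
    append-only cache, or None if every turn fits.

    `turn_lengths` is a hypothetical list of token counts per turn.
    """
    used = 0
    for i, n_tokens in enumerate(turn_lengths):
        used += n_tokens  # tokens are only appended; nothing is ever evicted
        if used > cache_length:
            return i
    return None


print(first_overflowing_turn([500] * 10))  # 8
```

With ten turns of 500 tokens each, eight turns occupy 4,000 of the 4,096 slots and the ninth (index 8) pushes the total to 4,500, so every later turn fails no matter how short it is.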

2. Inefficient Logits Extraction in SamplerLoop

The SamplingState in SamplerLoop computes and retains the full vocabulary logits tensor (predicted_logits: Float['B max_out_length V']) throughout the generation loop.

  • Affected Component: gemma/gm/text/_sampler_loop.py (lines 62-65)
  • Impact: Storing the full logits distribution for every position of the output sequence over Gemma's large vocabulary (~256k tokens) causes large, avoidable VRAM spikes, because the buffer grows as B × max_out_length × V. This severely limits how far batch_size and max_out_length can scale.
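The B × max_out_length × V growth can be made concrete with a short calculation. The batch size, output length, float32 dtype, and the vocabulary rounded to 2**18 = 262,144 are illustrative assumptions, not values read from gemma (bfloat16 logits would halve the first figure). For comparison, it also sizes a buffer that keeps only the top-64 values and token ids per step.

```python
GIB = 1024 ** 3
MIB = 1024 ** 2


def dense_buffer_bytes(batch_size, max_out_length, last_dim, bytes_per_elem=4):
    """Bytes needed for a dense [batch_size, max_out_length, last_dim] tensor."""
    return batch_size * max_out_length * last_dim * bytes_per_elem


# Assumed settings: float32 logits, vocabulary rounded to 2**18 = 262,144.
full_bytes = dense_buffer_bytes(batch_size=8, max_out_length=2048, last_dim=262_144)
print(f"full logits: {full_bytes / GIB:.1f} GiB")  # full logits: 16.0 GiB

# Same run, keeping only top-64 values (float32) and token ids (int32) per step.
topk_bytes = dense_buffer_bytes(8, 2048, 64, bytes_per_elem=4 + 4)
print(f"top-k only:  {topk_bytes / MIB:.1f} MiB")  # top-k only:  8.0 MiB
```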

Expected Behavior

  • ChatSampler: The KV Cache should implement a rolling buffer (e.g., Sliding Window Attention / StreamingLLM) that retains the system prompt and the most recent context while safely evicting the oldest conversational turns.
  • SamplerLoop: Logits should be filtered down to top-k probabilities immediately after computation, or discarded entirely once the next_token is selected, thereby minimizing the memory footprint.
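The SamplerLoop expectation can be sketched in NumPy: reduce each step's [B, V] logits to the k largest values and their token ids, store only those, and let the full array go out of scope. `topk_step`, the random stand-in logits, and the greedy pick are hypothetical and not gemma's API; in JAX the corresponding primitive would be `jax.lax.top_k(operand, k)`, which returns the values and indices of the k largest entries along the last axis.

```python
import numpy as np


def topk_step(logits, k):
    """Reduce one step's logits [B, V] to the k largest entries per row.

    Returns (values [B, k], token_ids [B, k]) sorted in descending order.
    """
    # argpartition places the k largest entries (unordered) in the last k slots.
    idx = np.argpartition(logits, -k, axis=-1)[:, -k:]
    vals = np.take_along_axis(logits, idx, axis=-1)
    # Sort just those k entries in descending order.
    order = np.argsort(-vals, axis=-1)
    return (np.take_along_axis(vals, order, axis=-1),
            np.take_along_axis(idx, order, axis=-1).astype(np.int32))


B, V, max_out_length, k = 2, 1000, 5, 8
rng = np.random.default_rng(0)
top_vals = np.zeros((B, max_out_length, k), np.float32)  # [B, T, k], not [B, T, V]
top_ids = np.zeros((B, max_out_length, k), np.int32)
next_tokens = np.zeros((B, max_out_length), np.int32)

for t in range(max_out_length):
    logits = rng.standard_normal((B, V)).astype(np.float32)  # stand-in for model output
    vals, ids = topk_step(logits, k)
    top_vals[:, t], top_ids[:, t] = vals, ids
    next_tokens[:, t] = ids[:, 0]  # greedy pick; a sampler could draw from `vals` instead
    # The full [B, V] `logits` array is replaced next iteration; only [B, k] is kept.
```

Because the largest logit is always among the top k, greedy decoding and top-k sampling lose nothing by dropping the rest of the distribution.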

Proposed Solution

  1. Implement a Ring Buffer for KV Cache: Introduce a rolling cache property to dynamically flush the oldest cache_info indices.
  2. Optimize SamplingState: Deprecate the retention of full predicted_logits in favor of an ephemeral memory structure that only surfaces top-k sampling metrics.
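The ring-buffer proposal, combined with the expectation that the system prompt is retained, can be sketched with index arithmetic alone. The first `num_pinned` slots are never overwritten (the system prompt, or the initial "attention sink" tokens in StreamingLLM), and later tokens wrap around the remaining slots, overwriting the oldest. `PinnedRingCache` and its parameters are hypothetical; it stores token ids instead of per-layer key/value tensors and does not reflect gemma's actual `cache_info` layout.

```python
class PinnedRingCache:
    """Fixed-capacity cache: the first `num_pinned` tokens are kept forever, and
    later tokens live in a ring buffer that overwrites the oldest entry when full.

    Only token ids are stored; a real KV cache would write keys/values for every
    layer at the same slot indices.
    """

    def __init__(self, capacity, num_pinned):
        if not 0 <= num_pinned < capacity:
            raise ValueError("need 0 <= num_pinned < capacity")
        self.capacity = capacity
        self.num_pinned = num_pinned
        self.slots = [None] * capacity
        self.num_written = 0  # total tokens ever appended

    def _slot_for(self, n):
        """Slot index for the n-th appended token (0-based)."""
        if n < self.num_pinned:
            return n
        ring = self.capacity - self.num_pinned
        return self.num_pinned + (n - self.num_pinned) % ring

    def append(self, token):
        self.slots[self._slot_for(self.num_written)] = token
        self.num_written += 1

    def contents(self):
        """Retained tokens in original order: pinned prefix, then oldest to newest."""
        n = self.num_written
        if n <= self.capacity:
            return self.slots[:n]
        ring = self.capacity - self.num_pinned
        recent = [self.slots[self._slot_for(m)] for m in range(n - ring, n)]
        return self.slots[:self.num_pinned] + recent


cache = PinnedRingCache(capacity=6, num_pinned=2)
for token in range(10):
    cache.append(token)
print(cache.contents())  # [0, 1, 6, 7, 8, 9]
```

Memory stays bounded at `capacity` slots regardless of session length. A real implementation would also have to decide how positional encodings (e.g. RoPE positions) are assigned to the retained tokens after eviction, which this sketch ignores.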

Metadata

  • Assignees: No one assigned
  • Labels: No labels
  • Type: No type
  • Projects: No projects
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests
