Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions gemma/gm/nn/gemma4/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,64 @@ class _Inputs:
per_layer_inputs: Float['B L P'] | None = None


def compute_valid_audio_soft_token_counts(
audio_lengths: Int['B num_clips'], audio_seq_length: int
) -> Int['B num_clips']:
"""Compute valid audio soft token counts from raw sample counts."""
# frame_length = int(round(16000 * 20.0 / 1000.0)) = 320
# hop_length = int(round(16000 * 10.0 / 1000.0)) = 160
# frame_size_for_unfold = 321
frame_size_for_unfold = 321
hop_length = 160

# Compute mel frames using integer division
num_mel_frames = (audio_lengths - frame_size_for_unfold) // hop_length + 1
t = num_mel_frames
# two 2x downsampling layers
for _ in range(2):
t_padded = t + 2
t = (t_padded - 3) // 2 + 1

return jnp.minimum(t, audio_seq_length)


def mask_padded_audio_tokens(
tokens: Int['B L'],
inputs_mask: Bool['B L'],
valid_counts: Int['B num_clips'],
audio_seq_length: int,
) -> Bool['B L']:
"""Modify inputs_mask to set padded audio token positions to False."""
# audio_mask: [B, L]
audio_mask = tokens == _token_utils.AUDIO_SOFT_TOKEN_PLACEHOLDER

# cumsum: [B, L]
cumsum = jnp.cumsum(audio_mask, axis=-1)
global_audio_index = (cumsum - 1) * audio_mask

# clip_idx: [B, L]
clip_idx = global_audio_index // audio_seq_length
# local_idx: [B, L]
local_idx = global_audio_index % audio_seq_length

# safe_clip_idx: [B, L]
safe_clip_idx = jnp.maximum(clip_idx, 0)

# expanded_valid_counts: [B, L]
expanded_valid_counts = jnp.take_along_axis(
valid_counts, safe_clip_idx, axis=1
)

# is_valid_audio: [B, L]
is_valid_audio = local_idx < expanded_valid_counts

# padded_audio_mask: [B, L]
padded_audio_mask = audio_mask & ~is_valid_audio

# Modify inputs_mask
return inputs_mask & ~padded_audio_mask


class Transformer(nn.Module):
"""Base transformer class.

Expand Down Expand Up @@ -447,6 +505,20 @@ def _encode_and_get_inputs(
)

inputs_mask = tokens != _PADDING_ID
if (
audio is not None
and audio_lengths is not None
and audio_soft_token_counts is not None
):
valid_counts = compute_valid_audio_soft_token_counts(
audio_lengths, audio_soft_token_counts[0]
)
inputs_mask = mask_padded_audio_tokens(
tokens,
inputs_mask,
valid_counts,
audio_soft_token_counts[0],
)
if positions is None:
positions = _pos_utils.build_positions_from_mask(inputs_mask)
if attention_mask is None:
Expand Down
95 changes: 95 additions & 0 deletions gemma/gm/nn/gemma4/_transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from gemma.gm.nn.gemma4 import _gemma4 as gemma4_models
from gemma.gm.nn.gemma4 import _transformer as gt
from gemma.gm.vision import _token_utils
import jax
import jax.numpy as jnp
import pytest
Expand Down Expand Up @@ -57,3 +58,97 @@ def test_text_only():
out, params = _get_output(model, tokens=tokens)
assert 'vision_encoder' not in params
assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed)


def test_compute_valid_audio_soft_token_counts():
# Test with dummy audio lengths
# audio_lengths has shape [B, num_clips]
# For 16000 Hz:
# 2.0s = 32000 samples.
# mel frames = (32000 - 321) // 160 + 1 = 31679 // 160 + 1 = 197 + 1 = 198
# mel frames.
# Downsampling 1:
# t = (198 + 2 - 3) // 2 + 1 = 197 // 2 + 1 = 98 + 1 = 99.
# Downsampling 2:
# t = (99 + 2 - 3) // 2 + 1 = 98 // 2 + 1 = 49 + 1 = 50.
# So for 32000 samples, it should return 50 soft tokens.

# For 7.5s = 120000 samples.
# mel frames = (120000 - 321) // 160 + 1 = 119679 // 160 + 1 = 747 + 1 =
# 748 mel frames.
# Downsampling 1: (748 + 2 - 3) // 2 + 1 = 373 + 1 = 374.
# Downsampling 2: (374 + 2 - 3) // 2 + 1 = 186 + 1 = 187.
# So 187 soft tokens.

audio_lengths = jnp.array([[32000, 120000]], dtype=jnp.int32)
counts = gt.compute_valid_audio_soft_token_counts(audio_lengths, 750)
assert jnp.array_equal(counts, jnp.array([[50, 187]], dtype=jnp.int32))


def test_mask_padded_audio_tokens():
# Placeholder token ID is usually _token_utils.AUDIO_SOFT_TOKEN_PLACEHOLDER
placeholder = _token_utils.AUDIO_SOFT_TOKEN_PLACEHOLDER

# tokens has shape [B, L]
# B=1, L=13
# We have 2 audio clips, max 4 tokens each (so self.audio_seq_length=4)
# Clip 0 has valid count 2 (e.g. audio_lengths corresponds to 2 soft tokens)
# Clip 1 has valid count 3

tokens = jnp.array(
[[
1,
2,
placeholder,
placeholder,
placeholder,
placeholder,
3,
4,
placeholder,
placeholder,
placeholder,
placeholder,
5,
]],
dtype=jnp.int32,
)

# inputs_mask initially all True except PAD (none here)
inputs_mask = jnp.ones((1, 13), dtype=jnp.bool_)

# audio_lengths corresponding to [2, 3] soft tokens.
audio_lengths = jnp.array([[961, 1601]], dtype=jnp.int32)

valid_counts = gt.compute_valid_audio_soft_token_counts(
audio_lengths, audio_seq_length=4
)
new_inputs_mask = gt.mask_padded_audio_tokens(
tokens, inputs_mask, valid_counts, audio_seq_length=4
)

# Expected mask:
# Clip 0 (indices 2..5): first 2 True, last 2 False -> [T, T, F, F]
# Clip 1 (indices 8..11): first 3 True, last 1 False -> [T, T, T, F]
# Others: True
# Expected: [T, T, T, T, F, F, T, T, T, T, T, F, T]
expected_mask = jnp.array(
[[
True,
True,
True,
True,
False,
False,
True,
True,
True,
True,
True,
False,
True,
]],
dtype=jnp.bool_,
)

assert jnp.array_equal(new_inputs_mask, expected_mask)
35 changes: 25 additions & 10 deletions gemma/gm/text/_gemma4_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,23 +160,37 @@ def sample(
if audio_lengths is None:
audio_lengths = [len(a) for a in audio]

for length in audio_lengths:
frame_length = int(round(self.audio_sample_rate * 20.0 / 1000.0))
hop_length = int(round(self.audio_sample_rate * 10.0 / 1000.0))
frame_size_for_unfold = frame_length + 1
num_mel_frames = (length - frame_size_for_unfold) // hop_length + 1
t = num_mel_frames
frame_length = int(round(self.audio_sample_rate * 20.0 / 1000.0))
hop_length = int(round(self.audio_sample_rate * 10.0 / 1000.0))
frame_size_for_unfold = frame_length + 1

if self.pad_length is not None:
audio_soft_token_counts = [self.audio_seq_length for _ in audio_lengths]
t = self.audio_seq_length
for _ in range(2):
t_padded = t + 2
t = (t_padded - 3) // 2 + 1
audio_soft_token_counts.append(min(t, self.audio_seq_length))
t = 2 * t
num_mel_frames = t
max_audio_len = (
(num_mel_frames - 1) * hop_length
+ frame_size_for_unfold
+ hop_length
- 1
)
else:
for length in audio_lengths:
num_mel_frames = (length - frame_size_for_unfold) // hop_length + 1
t = num_mel_frames
for _ in range(2):
t_padded = t + 2
t = (t_padded - 3) // 2 + 1
audio_soft_token_counts.append(min(t, self.audio_seq_length))
max_audio_len = max(len(a) for a in audio)

tokens = _token_utils.add_variable_extra_tokens_for_audio(
tokens,
soft_token_counts=audio_soft_token_counts,
)

max_audio_len = max(len(a) for a in audio)
padded_audio = np.zeros((len(audio), max_audio_len), dtype=np.float32)
for i, a in enumerate(audio):
padded_audio[i, : len(a)] = a
Expand Down Expand Up @@ -217,6 +231,7 @@ def sample(
audio=audio,
audio_lengths=audio_lengths,
audio_soft_token_counts=tuple(audio_soft_token_counts),
audio_seq_length=self.audio_seq_length,
)

if max_new_tokens and max_new_tokens > self.max_out_length:
Expand Down
50 changes: 45 additions & 5 deletions gemma/gm/text/_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from gemma.gm.data import _functional
from gemma.gm.nn import _config
from gemma.gm.nn import _transformer_like
from gemma.gm.nn.gemma4._transformer import compute_valid_audio_soft_token_counts
from gemma.gm.nn.gemma4._transformer import mask_padded_audio_tokens
from gemma.gm.text import _sampler_loop
from gemma.gm.text import _turn_utils
from gemma.gm.typing import _common
Expand Down Expand Up @@ -63,6 +65,7 @@ def prefill(
audio=None,
audio_lengths=None,
audio_soft_token_counts=None,
audio_seq_length: int | None = None,
) -> _sampler_loop.SamplingState:
"""Pre-fill the KV cache and initial model input.

Expand All @@ -83,6 +86,7 @@ def prefill(
audio: Audio input data or None.
audio_lengths: Lengths of audio inputs or None.
audio_soft_token_counts: Soft token counts for audio or None.
audio_seq_length: Max static audio sequence length.

Returns:
The initial state for the sampling loop.
Expand Down Expand Up @@ -223,19 +227,21 @@ def prefill(
new_used_cache_length = (
prev_turns.used_cache_length + input.length_with_mm - 1
)
cache = cache.set_end_index(new_used_cache_length)
if audio_seq_length is None and audio_lengths is not None:
if hasattr(model, 'config') and hasattr(model.config, 'audio_seq_length'):
audio_seq_length = model.config.audio_seq_length
else:
audio_seq_length = 750

# TODO(epot): The first token was predicted, so could use this, but would
# require to duplicate the logic of `_sample_step`, so leave this for later
# The `_sample_loop` will re-start from the last prompt token, so use `-1`
# as the first token is re-computed.
return _make_init_state(
input=input,
max_out_length=max_out_length,
new_used_cache_length=new_used_cache_length,
prev_turns=prev_turns,
cache=cache,
rng=rng,
audio_lengths=audio_lengths,
audio_seq_length=audio_seq_length,
)


Expand All @@ -247,17 +253,36 @@ def _make_init_state(
prev_turns: _turn_utils.PrevTurns,
cache: _cache_helper.Cache,
rng: PRNGKey,
audio_lengths: jax.Array | None = None,
audio_seq_length: int | None = None,
) -> _sampler_loop.SamplingState:
"""Initial state for the sampling loop."""

# The new last token position is shifted by the prompt length (after MM).
last_token_pos = input.last_token_pos + prev_turns.last_token_pos
if audio_lengths is not None and audio_seq_length is not None:
valid_counts = compute_valid_audio_soft_token_counts(
audio_lengths, audio_seq_length
)
padded_counts = jnp.sum(audio_seq_length - valid_counts, axis=-1)
last_token_pos = last_token_pos - padded_counts

# Pre-compute the full attention mask for the last step.
full_attention_mask = _make_full_attention_mask(
input=input,
prev_turns=prev_turns,
cache_length=cache.total_cache_length,
audio_lengths=audio_lengths,
audio_seq_length=audio_seq_length,
)

jax.debug.print(
'PREFILL DEBUG: last_token_pos={pos}, new_used_cache_length={cache_len},'
' mask_sum={mask_sum}, tokens_with_mm_len={tokens_len}',
pos=last_token_pos,
cache_len=new_used_cache_length,
mask_sum=full_attention_mask.sum(axis=-1),
tokens_len=input.tokens_with_mm.shape[-1],
)

return _sampler_loop.SamplingState(
Expand Down Expand Up @@ -374,6 +399,8 @@ def _make_full_attention_mask(
input: _types.Input, # pylint: disable=redefined-builtin
prev_turns: _turn_utils.PrevTurns,
cache_length: int,
audio_lengths: jax.Array | None = None,
audio_seq_length: int | None = None,
):
"""Pre-compute the full attention mask for the full `cache_length`.

Expand Down Expand Up @@ -404,13 +431,26 @@ def _make_full_attention_mask(
input: The input tokens.
prev_turns: The previous turns.
cache_length: The maximum length of the sequence.
audio_lengths: Padded audio lengths.
audio_seq_length: Max static audio sequence length.

Returns:
The full attention mask.
"""
# Mask out the padding tokens.
full_attention_mask = input.tokens_with_mm != _PADDING_ID

if audio_lengths is not None and audio_seq_length is not None:
valid_counts = compute_valid_audio_soft_token_counts(
audio_lengths, audio_seq_length
)
full_attention_mask = mask_padded_audio_tokens(
tokens=input.tokens_with_mm,
inputs_mask=full_attention_mask,
valid_counts=valid_counts,
audio_seq_length=audio_seq_length,
)

# Compute the full attention mask across turns.
if prev_turns:
full_attention_mask = jnp.concatenate(
Expand Down
Loading
Loading