def compute_valid_audio_soft_token_counts(
    audio_lengths: Int['B num_clips'],
    audio_seq_length: int,
    *,
    sample_rate: int = 16000,
) -> Int['B num_clips']:
  """Computes the number of valid audio soft tokens for each clip.

  The raw sample count is first converted to a number of mel frames, then
  reduced by two successive stride-2 steps, and finally clamped to the range
  `[0, audio_seq_length]`.

  Args:
    audio_lengths: Number of raw audio samples per clip.
    audio_seq_length: Upper bound on the number of soft tokens per clip.
    sample_rate: Audio sample rate in Hz. The framing uses a 20 ms window and
      a 10 ms hop; the default of 16000 gives the 320 / 160 sample values used
      previously.

  Returns:
    The number of valid soft tokens per clip, same shape as `audio_lengths`.
  """
  # 20 ms frame and 10 ms hop, expressed in samples (320 / 160 at 16 kHz).
  frame_length = int(round(sample_rate * 20.0 / 1000.0))
  hop_length = int(round(sample_rate * 10.0 / 1000.0))
  frame_size_for_unfold = frame_length + 1

  # Integer (floor) division: partial trailing frames are dropped.
  t = (audio_lengths - frame_size_for_unfold) // hop_length + 1

  # Two 2x downsampling layers.
  for _ in range(2):
    t_padded = t + 2
    t = (t_padded - 3) // 2 + 1

  # Clamp below at 0 so that negative / sentinel lengths can never yield a
  # negative count (callers compute `audio_seq_length - count` as the amount
  # of padding), and above at the static per-clip budget.
  return jnp.minimum(jnp.maximum(t, 0), audio_seq_length)


def mask_padded_audio_tokens(
    tokens: Int['B L'],
    inputs_mask: Bool['B L'],
    valid_counts: Int['B num_clips'],
    audio_seq_length: int,
) -> Bool['B L']:
  """Clears `inputs_mask` at padded audio soft-token positions.

  Audio placeholder tokens are numbered in order of appearance within each
  row. Placeholder number `i` is attributed to clip `i // audio_seq_length` at
  local offset `i % audio_seq_length`, i.e. every clip is taken to occupy
  exactly `audio_seq_length` consecutive placeholders. A placeholder whose
  local offset is not below the clip's valid count is treated as padding.

  Args:
    tokens: Token ids, compared against
      `_token_utils.AUDIO_SOFT_TOKEN_PLACEHOLDER`.
    inputs_mask: Mask to update.
    valid_counts: Number of valid soft tokens per clip.
    audio_seq_length: Number of placeholder slots attributed to each clip.

  Returns:
    `inputs_mask` with padded audio placeholder positions set to False. All
    other positions are unchanged.
  """
  # [B, L]: True where the token is an audio soft-token placeholder.
  is_audio = tokens == _token_utils.AUDIO_SOFT_TOKEN_PLACEHOLDER

  # [B, L]: 0-based running index of each placeholder within its row.
  # Multiplying by the mask zeroes non-audio positions, so the index is never
  # negative and needs no lower clamp.
  running_count = jnp.cumsum(is_audio, axis=-1)
  audio_index = (running_count - 1) * is_audio

  # [B, L]: which clip a placeholder belongs to, and its offset in that clip.
  clip_idx = audio_index // audio_seq_length
  local_idx = audio_index % audio_seq_length

  # [B, L]: valid count of the clip owning each position.
  # NOTE(review): assumes a row holds at most `num_clips * audio_seq_length`
  # placeholders; larger clip indices rely on `take_along_axis`'s default
  # out-of-bounds handling -- confirm this is the intended behavior.
  counts_per_position = jnp.take_along_axis(valid_counts, clip_idx, axis=1)

  # Padded audio positions: placeholders at or beyond their clip's valid count.
  is_padded_audio = is_audio & ~(local_idx < counts_per_position)

  return inputs_mask & ~is_padded_audio
def test_compute_valid_audio_soft_token_counts():
  # `audio_lengths` has shape [B, num_clips]; framing is for 16000 Hz audio.
  #
  # Clip 0: 2.0s = 32000 samples.
  #   mel frames: (32000 - 321) // 160 + 1 = 197 + 1 = 198
  #   downsample 1: (198 + 2 - 3) // 2 + 1 = 98 + 1 = 99
  #   downsample 2: (99 + 2 - 3) // 2 + 1 = 49 + 1 = 50
  #   -> 50 soft tokens.
  #
  # Clip 1: 7.5s = 120000 samples.
  #   mel frames: (120000 - 321) // 160 + 1 = 747 + 1 = 748
  #   downsample 1: (748 + 2 - 3) // 2 + 1 = 373 + 1 = 374
  #   downsample 2: (374 + 2 - 3) // 2 + 1 = 186 + 1 = 187
  #   -> 187 soft tokens.
  sample_counts = jnp.array([[32000, 120000]], dtype=jnp.int32)
  expected = jnp.array([[50, 187]], dtype=jnp.int32)

  actual = gt.compute_valid_audio_soft_token_counts(sample_counts, 750)

  assert jnp.array_equal(actual, expected)


def test_mask_padded_audio_tokens():
  placeholder = _token_utils.AUDIO_SOFT_TOKEN_PLACEHOLDER

  # B=1, L=13: two audio clips with 4 placeholder slots each
  # (audio_seq_length=4). Clip 0 has 2 valid soft tokens, clip 1 has 3.
  clip_slots = [placeholder] * 4
  row = [1, 2] + clip_slots + [3, 4] + clip_slots + [5]
  tokens = jnp.array([row], dtype=jnp.int32)

  # Nothing is masked initially (there are no PAD tokens in the row).
  inputs_mask = jnp.ones((1, 13), dtype=jnp.bool_)

  # Sample counts that map to [2, 3] soft tokens.
  audio_lengths = jnp.array([[961, 1601]], dtype=jnp.int32)

  valid_counts = gt.compute_valid_audio_soft_token_counts(
      audio_lengths, audio_seq_length=4
  )
  new_inputs_mask = gt.mask_padded_audio_tokens(
      tokens, inputs_mask, valid_counts, audio_seq_length=4
  )

  # Clip 0 (indices 2..5):  [T, T, F, F]
  # Clip 1 (indices 8..11): [T, T, T, F]
  # Every non-audio position stays True.
  expected_row = (
      [True, True]  # tokens 1, 2
      + [True, True, False, False]  # clip 0
      + [True, True]  # tokens 3, 4
      + [True, True, True, False]  # clip 1
      + [True]  # token 5
  )
  expected_mask = jnp.array([expected_row], dtype=jnp.bool_)

  assert jnp.array_equal(new_inputs_mask, expected_mask)
in audio] - for length in audio_lengths: - frame_length = int(round(self.audio_sample_rate * 20.0 / 1000.0)) - hop_length = int(round(self.audio_sample_rate * 10.0 / 1000.0)) - frame_size_for_unfold = frame_length + 1 - num_mel_frames = (length - frame_size_for_unfold) // hop_length + 1 - t = num_mel_frames + frame_length = int(round(self.audio_sample_rate * 20.0 / 1000.0)) + hop_length = int(round(self.audio_sample_rate * 10.0 / 1000.0)) + frame_size_for_unfold = frame_length + 1 + + if self.pad_length is not None: + audio_soft_token_counts = [self.audio_seq_length for _ in audio_lengths] + t = self.audio_seq_length for _ in range(2): - t_padded = t + 2 - t = (t_padded - 3) // 2 + 1 - audio_soft_token_counts.append(min(t, self.audio_seq_length)) + t = 2 * t + num_mel_frames = t + max_audio_len = ( + (num_mel_frames - 1) * hop_length + + frame_size_for_unfold + + hop_length + - 1 + ) + else: + for length in audio_lengths: + num_mel_frames = (length - frame_size_for_unfold) // hop_length + 1 + t = num_mel_frames + for _ in range(2): + t_padded = t + 2 + t = (t_padded - 3) // 2 + 1 + audio_soft_token_counts.append(min(t, self.audio_seq_length)) + max_audio_len = max(len(a) for a in audio) tokens = _token_utils.add_variable_extra_tokens_for_audio( tokens, soft_token_counts=audio_soft_token_counts, ) - max_audio_len = max(len(a) for a in audio) padded_audio = np.zeros((len(audio), max_audio_len), dtype=np.float32) for i, a in enumerate(audio): padded_audio[i, : len(a)] = a @@ -217,6 +231,7 @@ def sample( audio=audio, audio_lengths=audio_lengths, audio_soft_token_counts=tuple(audio_soft_token_counts), + audio_seq_length=self.audio_seq_length, ) if max_new_tokens and max_new_tokens > self.max_out_length: diff --git a/gemma/gm/text/_prefill.py b/gemma/gm/text/_prefill.py index 46432e68..bdcf62d9 100644 --- a/gemma/gm/text/_prefill.py +++ b/gemma/gm/text/_prefill.py @@ -21,6 +21,8 @@ from gemma.gm.data import _functional from gemma.gm.nn import _config from gemma.gm.nn 
import _transformer_like +from gemma.gm.nn.gemma4._transformer import compute_valid_audio_soft_token_counts +from gemma.gm.nn.gemma4._transformer import mask_padded_audio_tokens from gemma.gm.text import _sampler_loop from gemma.gm.text import _turn_utils from gemma.gm.typing import _common @@ -63,6 +65,7 @@ def prefill( audio=None, audio_lengths=None, audio_soft_token_counts=None, + audio_seq_length: int | None = None, ) -> _sampler_loop.SamplingState: """Pre-fill the KV cache and initial model input. @@ -83,6 +86,7 @@ def prefill( audio: Audio input data or None. audio_lengths: Lengths of audio inputs or None. audio_soft_token_counts: Soft token counts for audio or None. + audio_seq_length: Max static audio sequence length. Returns: The initial state for the sampling loop. @@ -223,12 +227,12 @@ def prefill( new_used_cache_length = ( prev_turns.used_cache_length + input.length_with_mm - 1 ) - cache = cache.set_end_index(new_used_cache_length) + if audio_seq_length is None and audio_lengths is not None: + if hasattr(model, 'config') and hasattr(model.config, 'audio_seq_length'): + audio_seq_length = model.config.audio_seq_length + else: + audio_seq_length = 750 - # TODO(epot): The first token was predicted, so could use this, but would - # require to duplicate the logic of `_sample_step`, so leave this for later - # The `_sample_loop` will re-start from the last prompt token, so use `-1` - # as the first token is re-computed. 
return _make_init_state( input=input, max_out_length=max_out_length, @@ -236,6 +240,8 @@ def prefill( prev_turns=prev_turns, cache=cache, rng=rng, + audio_lengths=audio_lengths, + audio_seq_length=audio_seq_length, ) @@ -247,17 +253,36 @@ def _make_init_state( prev_turns: _turn_utils.PrevTurns, cache: _cache_helper.Cache, rng: PRNGKey, + audio_lengths: jax.Array | None = None, + audio_seq_length: int | None = None, ) -> _sampler_loop.SamplingState: """Initial state for the sampling loop.""" # The new last token position is shifted by the prompt length (after MM). last_token_pos = input.last_token_pos + prev_turns.last_token_pos + if audio_lengths is not None and audio_seq_length is not None: + valid_counts = compute_valid_audio_soft_token_counts( + audio_lengths, audio_seq_length + ) + padded_counts = jnp.sum(audio_seq_length - valid_counts, axis=-1) + last_token_pos = last_token_pos - padded_counts # Pre-compute the full attention mask for the last step. full_attention_mask = _make_full_attention_mask( input=input, prev_turns=prev_turns, cache_length=cache.total_cache_length, + audio_lengths=audio_lengths, + audio_seq_length=audio_seq_length, + ) + + jax.debug.print( + 'PREFILL DEBUG: last_token_pos={pos}, new_used_cache_length={cache_len},' + ' mask_sum={mask_sum}, tokens_with_mm_len={tokens_len}', + pos=last_token_pos, + cache_len=new_used_cache_length, + mask_sum=full_attention_mask.sum(axis=-1), + tokens_len=input.tokens_with_mm.shape[-1], ) return _sampler_loop.SamplingState( @@ -374,6 +399,8 @@ def _make_full_attention_mask( input: _types.Input, # pylint: disable=redefined-builtin prev_turns: _turn_utils.PrevTurns, cache_length: int, + audio_lengths: jax.Array | None = None, + audio_seq_length: int | None = None, ): """Pre-compute the full attention mask for the full `cache_length`. @@ -404,6 +431,8 @@ def _make_full_attention_mask( input: The input tokens. prev_turns: The previous turns. cache_length: The maximum length of the sequence. 
def test_full_attention_mask_with_audio():
  placeholder = _token_utils.AUDIO_SOFT_TOKEN_PLACEHOLDER

  # 13 prompt tokens: two audio clips of 4 placeholder slots each, surrounded
  # by regular text tokens.
  clip_slots = [placeholder] * 4
  row = [1, 2] + clip_slots + [3, 4] + clip_slots + [5]
  tokens = jnp.array([row], dtype=jnp.int32)

  input_config = _types.InputConfig(
      support_images=True,
      num_tokens_per_image=100,
      special_tokens=gm.text.Gemma3Tokenizer.special_tokens,
  )
  input_ = _types.Input(text=tokens, images=None, config=input_config)

  # Sample counts that map to [2, 3] valid soft tokens.
  audio_lengths = jnp.array([[961, 1601]], dtype=jnp.int32)

  mask = _prefill._make_full_attention_mask(
      input=input_,
      prev_turns=_turn_utils.PrevTurns(last_state=None),
      cache_length=20,
      audio_lengths=audio_lengths,
      audio_seq_length=4,
  )

  # First 13 entries cover the prompt (padded audio slots are False); the
  # remaining 7 entries up to cache_length=20 are expected to be True.
  prompt_part = (
      [True, True]  # tokens 1, 2
      + [True, True, False, False]  # clip 0: 2 valid of 4
      + [True, True]  # tokens 3, 4
      + [True, True, True, False]  # clip 1: 3 valid of 4
      + [True]  # token 5
  )
  remainder = [True] * 7
  expected_mask = jnp.array([prompt_part + remainder], dtype=jnp.bool_)

  np.testing.assert_array_equal(mask, expected_mask)

  layer_state = {
      'k': jnp.zeros((1, 20, 8, 128)),
      'v': jnp.zeros((1, 20, 8, 128)),
      'positions': jnp.zeros((1, 20)),
      'end_index': jnp.zeros((1,)),
  }
  dummy_cache = _cache_helper.Cache({'layer_0': layer_state})

  init_state = _prefill._make_init_state(
      input=input_,
      max_out_length=10,
      new_used_cache_length=input_.length_with_mm - 1,
      prev_turns=_turn_utils.PrevTurns(last_state=None),
      cache=dummy_cache,
      rng=jax.random.PRNGKey(0),
      audio_lengths=audio_lengths,
      audio_seq_length=4,
  )

  # The last token position should be 9 (12 - 3 padded audio tokens).
  assert init_state.last_token_pos == 9
  # The init cache length should be 12 (13 - 1).
  assert init_state.init_cache_length == 12