diff --git a/gemma/diffusion/_sampler.py b/gemma/diffusion/_sampler.py index 71e4eb01..b2e9228c 100644 --- a/gemma/diffusion/_sampler.py +++ b/gemma/diffusion/_sampler.py @@ -354,6 +354,7 @@ def sample_next_canvas( cache: _config.Cache | None, params: _common.Params, rng: PRNGKey, + full_attention_mask: Bool['*B CacheLength'] | None = None, ) -> Tokens: """Samples a complete denoised canvas from an initial noisy canvas. @@ -391,6 +392,7 @@ def sample_next_canvas( canvas_length=canvas_length, cache_length=cache_length, num_valid_tokens=samples_in_cache, + full_attention_mask=full_attention_mask, ) block_local_mask = _make_block_local_attention_mask( @@ -510,6 +512,7 @@ def _sample_step( cache=cache, params=params, rng=sample_rng, + full_attention_mask=state.full_attention_mask, ) canvas, batch_has_stop_token = _truncate_canvas_at_stop_tokens( @@ -654,6 +657,7 @@ def _make_global_attention_mask( canvas_length: int, cache_length: int | None, num_valid_tokens: Int['*B'] | None, + full_attention_mask: Bool['*B CacheLength'] | None = None, ) -> Bool['*B CanvasLength CacheLength']: """Create attention mask for the diffusion sampler. @@ -669,6 +673,7 @@ def _make_global_attention_mask( cache_length: The length of the cache. If None, no cache is used. num_valid_tokens: The number of valid tokens in the cache. Required if cache_length is not None. + full_attention_mask: The full attention mask for prompt and cache. Returns: The attention mask. @@ -684,6 +689,8 @@ def _make_global_attention_mask( total_valid = jnp.minimum(num_valid_tokens + canvas_length, cache_length) mask = jnp.arange(cache_length)[None, :] < total_valid[:, None] + if full_attention_mask is not None: + mask = mask & full_attention_mask return jnp.broadcast_to( mask[:, None, :], (batch_size, canvas_length, cache_length) diff --git a/gemma/diffusion/_sampler_test.py b/gemma/diffusion/_sampler_test.py index b35db7e2..eff98b78 100644 --- a/gemma/diffusion/_sampler_test.py +++ b/gemma/diffusion/_sampler_test.py @@ -520,6 +520,60 @@ def test_make_global_attention_mask_batched_edge_cases(self): with self.subTest(name=name): np.testing.assert_array_equal(mask[i], expected) + def test_make_global_attention_mask_with_full_attention_mask(self): + """Tests global attention mask combined with a prompt full_attention_mask.""" + batch_size = 2 + canvas_length = 3 + cache_length = 8 + + # 4 valid cache tokens (so total_valid = min(4 + 3, 8) = 7) + num_valid_tokens = jnp.array([4, 4], dtype=jnp.int32) + + # Batch 0: has pad at index 2 (so mask is [1, 1, 0, 1, 1, 1, 1, 1]) + # Batch 1: has pad at index 1 and 3 (so mask is [1, 0, 1, 0, 1, 1, 1, 1]) + full_attention_mask = jnp.array( + [ + [True, True, False, True, True, True, True, True], + [True, False, True, False, True, True, True, True], + ], + dtype=jnp.bool_, + ) + + mask = _sampler._make_global_attention_mask( + batch_size=batch_size, + canvas_length=canvas_length, + cache_length=cache_length, + num_valid_tokens=num_valid_tokens, + full_attention_mask=full_attention_mask, + ) + + # Base mask (without full_attention_mask) for total_valid = 7 would be: + # [1, 1, 1, 1, 1, 1, 1, 0] for both batches. + # With full_attention_mask, it should be: + # Batch 0: [1, 1, 1, 1, 1, 1, 1, 0] & [1, 1, 0, 1, 1, 1, 1, 1] = [1, 1, 0, 1, 1, 1, 1, 0] + # Batch 1: [1, 1, 1, 1, 1, 1, 1, 0] & [1, 0, 1, 0, 1, 1, 1, 1] = [1, 0, 1, 0, 1, 1, 1, 0] + + expected_batch_0 = jnp.array( + [ + [1, 1, 0, 1, 1, 1, 1, 0], + [1, 1, 0, 1, 1, 1, 1, 0], + [1, 1, 0, 1, 1, 1, 1, 0], + ], + dtype=jnp.bool_, + ) + + expected_batch_1 = jnp.array( + [ + [1, 0, 1, 0, 1, 1, 1, 0], + [1, 0, 1, 0, 1, 1, 1, 0], + [1, 0, 1, 0, 1, 1, 1, 0], + ], + dtype=jnp.bool_, + ) + + np.testing.assert_array_equal(mask[0], expected_batch_0) + np.testing.assert_array_equal(mask[1], expected_batch_1) + def test_make_causal_attention_mask_no_cache(self): """Tests that the causal mask is lower-triangular when no cache is used.