Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions gemma/diffusion/_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ def sample_next_canvas(
cache: _config.Cache | None,
params: _common.Params,
rng: PRNGKey,
full_attention_mask: Bool['*B CacheLength'] | None = None,
) -> Tokens:
"""Samples a complete denoised canvas from an initial noisy canvas.

Expand Down Expand Up @@ -391,6 +392,7 @@ def sample_next_canvas(
canvas_length=canvas_length,
cache_length=cache_length,
num_valid_tokens=samples_in_cache,
full_attention_mask=full_attention_mask,
)

block_local_mask = _make_block_local_attention_mask(
Expand Down Expand Up @@ -510,6 +512,7 @@ def _sample_step(
cache=cache,
params=params,
rng=sample_rng,
full_attention_mask=state.full_attention_mask,
)

canvas, batch_has_stop_token = _truncate_canvas_at_stop_tokens(
Expand Down Expand Up @@ -654,6 +657,7 @@ def _make_global_attention_mask(
canvas_length: int,
cache_length: int | None,
num_valid_tokens: Int['*B'] | None,
full_attention_mask: Bool['*B CacheLength'] | None = None,
) -> Bool['*B CanvasLength CacheLength']:
"""Create attention mask for the diffusion sampler.

Expand All @@ -669,6 +673,7 @@ def _make_global_attention_mask(
cache_length: The length of the cache. If None, no cache is used.
num_valid_tokens: The number of valid tokens in the cache. Required if
cache_length is not None.
full_attention_mask: The full attention mask for prompt and cache.

Returns:
The attention mask.
Expand All @@ -684,6 +689,8 @@ def _make_global_attention_mask(

total_valid = jnp.minimum(num_valid_tokens + canvas_length, cache_length)
mask = jnp.arange(cache_length)[None, :] < total_valid[:, None]
if full_attention_mask is not None:
mask = mask & full_attention_mask

return jnp.broadcast_to(
mask[:, None, :], (batch_size, canvas_length, cache_length)
Expand Down
54 changes: 54 additions & 0 deletions gemma/diffusion/_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,60 @@ def test_make_global_attention_mask_batched_edge_cases(self):
with self.subTest(name=name):
np.testing.assert_array_equal(mask[i], expected)

def test_make_global_attention_mask_with_full_attention_mask(self):
  """Tests global attention mask combined with a prompt full_attention_mask."""
  canvas_length = 3
  cache_length = 8

  # Per-batch padding pattern over the cache; False marks a padded slot.
  #   batch 0 -> pad at index 2
  #   batch 1 -> pads at indices 1 and 3
  full_attention_mask = jnp.array(
      [
          [True, True, False, True, True, True, True, True],
          [True, False, True, False, True, True, True, True],
      ],
      dtype=jnp.bool_,
  )

  # Both batch entries hold 4 valid cache tokens, so
  # total_valid = min(4 + 3, 8) = 7 and the length-only mask would be
  # [1, 1, 1, 1, 1, 1, 1, 0] for each of them.
  mask = _sampler._make_global_attention_mask(
      batch_size=2,
      canvas_length=canvas_length,
      cache_length=cache_length,
      num_valid_tokens=jnp.array([4, 4], dtype=jnp.int32),
      full_attention_mask=full_attention_mask,
  )

  # AND-ing the length-only mask with each padding pattern yields one row per
  # batch entry; that row is expected to repeat for every canvas position.
  #   batch 0: [1,1,1,1,1,1,1,0] & [1,1,0,1,1,1,1,1] = [1,1,0,1,1,1,1,0]
  #   batch 1: [1,1,1,1,1,1,1,0] & [1,0,1,0,1,1,1,1] = [1,0,1,0,1,1,1,0]
  expected_rows = (
      [1, 1, 0, 1, 1, 1, 1, 0],
      [1, 0, 1, 0, 1, 1, 1, 0],
  )
  for batch_index, row in enumerate(expected_rows):
    expected = jnp.array([row] * canvas_length, dtype=jnp.bool_)
    np.testing.assert_array_equal(mask[batch_index], expected)

def test_make_causal_attention_mask_no_cache(self):
"""Tests that the causal mask is lower-triangular when no cache is used.

Expand Down
Loading