Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions nemo_rl/data/llm_message_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,26 @@ def get_first_index_that_differs(str1: str, str2: str) -> int:
return min(len(str1), len(str2))


def _get_longest_suffix_prefix_overlap(prev: str, cur: str) -> int:
"""Find max overlap where suffix(prev) == prefix(cur)."""
max_overlap = min(len(prev), len(cur))
for k in range(max_overlap, 0, -1):
if prev[-k:] == cur[:k]:
return k
return 0


def _get_longest_token_suffix_prefix_overlap(
prev_ids: list[int], cur_ids: list[int]
) -> int:
"""Find max overlap where suffix(prev_ids) == prefix(cur_ids)."""
max_overlap = min(len(prev_ids), len(cur_ids))
for k in range(max_overlap, 0, -1):
if prev_ids[-k:] == cur_ids[:k]:
return k
return 0


def get_formatted_message_log(
message_log: LLMMessageLogType,
tokenizer: TokenizerType,
Expand All @@ -452,6 +472,8 @@ def get_formatted_message_log(
"""
new_message_log: LLMMessageLogType = []
prev_formatted_message = ""
accumulated_text = ""
accumulated_token_ids: list[int] = []
message_log_strs: list[dict[str, str]] = cast(
list[dict[str, str]], message_log
) # we just use the str:str parts here
Expand Down Expand Up @@ -537,8 +559,12 @@ def _format_content_helper(
formatted_message,
)

## pull out the chunk corresponding to the current message
message_chunk = formatted_message[prev_message_len_no_eos:]
# Drop overlap with already-emitted text (for non-monotonic templates).
raw_message_chunk = formatted_message[prev_message_len_no_eos:]
overlap = _get_longest_suffix_prefix_overlap(
accumulated_text, raw_message_chunk
)
message_chunk = raw_message_chunk[overlap:]

# Debug: Print each message turn separately
if debug:
Expand Down Expand Up @@ -600,9 +626,22 @@ def _format_content_helper(
new_message = message.copy()
# extend this if statement to check for all(len(modality)) == 0 when adding other modalities
if len(media_cur_message) == 0:
new_message["token_ids"] = tokenizer(
text=message_chunk, return_tensors="pt", add_special_tokens=False
# Tokenize in accumulated context, then keep only newly appended ids.
context_text = accumulated_text + message_chunk
context_ids = tokenizer(
text=context_text, return_tensors="pt", add_special_tokens=False
)["input_ids"][0]
context_ids_list = cast(list[int], context_ids.tolist())
token_overlap = _get_longest_token_suffix_prefix_overlap(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I'd like to double-check: new_ids is derived by subtracting the token overlap with accumulated_token_ids. This assumes that tokenizing accumulated_text + message_chunk keeps the previous tokens as an exact prefix.
For a SentencePiece tokenizer without special role tokens (the #2844 setup), could a token merge across the turn boundary and cause the overlap to drift — pulling prior-turn tokens into the current turn's token_ids?

Could you verify on a real SentencePiece tokenizer that concat(per-turn token_ids) == tokenize(whole conversation) still holds? Will token drift happen at the boundary?

accumulated_token_ids, context_ids_list
)
new_ids = context_ids_list[token_overlap:]
new_message["token_ids"] = torch.tensor(
new_ids,
dtype=context_ids.dtype,
device=context_ids.device,
)
accumulated_token_ids = context_ids_list
else:
# extend the else statement to add other modalities (in this case, tokenizer will be a processor)
media_kwargs = {}
Expand Down Expand Up @@ -661,6 +700,7 @@ def _format_content_helper(
new_message["content"].append(item)

new_message_log.append(new_message)
accumulated_text += message_chunk
prev_formatted_message = formatted_message

return new_message_log
Expand Down
102 changes: 101 additions & 1 deletion tests/unit/data/test_llm_message_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.


from typing import Any, Callable
from typing import Any, Callable, cast

import pytest
import torch
Expand Down Expand Up @@ -697,6 +697,106 @@ def test_get_first_index_that_differs():
assert get_first_index_that_differs("hello2", "hi1") == 1


class _DummyReasoningTokenizer:
"""Minimal tokenizer to reproduce #2821/#2844 behaviors in unit tests."""

bos_token = None
eos_token = "</s>"

def apply_chat_template(
self,
messages,
add_generation_prompt=False,
tokenize=False,
add_special_tokens=False,
**kwargs,
):
assert not tokenize
out = ""
for idx, m in enumerate(messages):
role = m["role"]
content = m["content"]
if role == "user":
out += f"<|user|>{content}"
elif role == "assistant":
# Keep <think> only on latest assistant turn (#2821 repro).
if idx == len(messages) - 1:
out += f"<|assistant|>{content}"
else:
out += f"<|assistant|>{content.replace('<think></think>', '')}"

if add_generation_prompt:
out += "<|assistant|>\n"
return out

def _encode(self, text: str) -> list[int]:
# Toy sentencepiece behavior for #2844 repro:
# standalone "The" -> 900, in-context "\nThe" -> 901
ids: list[int] = []
i = 0
while i < len(text):
if i == 0 and text.startswith("The", i):
ids.append(900)
i += 3
continue
if text.startswith("\nThe", i):
ids.append(10) # '\n'
ids.append(901)
i += 4
continue
ids.append(ord(text[i]) + 1000)
i += 1
return ids

def __call__(self, text, return_tensors="pt", add_special_tokens=False, **kwargs):
ids = self._encode(text)
return {"input_ids": torch.tensor([ids], dtype=torch.long)}


def test_get_formatted_message_log_no_duplication_for_non_monotonic_reasoning_template():
    """Each assistant answer must appear exactly once in the concatenated output.

    The dummy tokenizer's chat template removes ``<think></think>`` from
    earlier assistant turns, so the rendering of a conversation prefix is not
    a prefix of the rendering of the longer conversation.
    """
    dummy_tokenizer = _DummyReasoningTokenizer()
    spec = TaskDataSpec(task_name="test")
    conversation = [
        {"role": "user", "content": "q1"},
        {"role": "assistant", "content": "<think></think>ANSWER1"},
        {"role": "user", "content": "q2"},
        {"role": "assistant", "content": "<think></think>ANSWER2"},
    ]

    formatted = get_formatted_message_log(
        conversation,
        dummy_tokenizer,
        spec,
        add_bos_token=False,
        add_eos_token=False,
        add_generation_prompt=False,
    )

    # Join the per-turn content back together and check for duplicated answers.
    joined = "".join(cast(str, turn["content"]) for turn in formatted)
    for answer in ("ANSWER1", "ANSWER2"):
        assert joined.count(answer) == 1


def test_get_formatted_message_log_uses_context_tokenization_for_assistant_leading_token():
    """The last turn's first token id must be the in-context id for ``The``.

    In the dummy tokenizer, ``The`` encodes to 900 at the very start of a
    string and to 901 when it follows a newline, so 901 shows the turn was
    tokenized in context rather than as a standalone chunk.
    """
    dummy_tokenizer = _DummyReasoningTokenizer()
    spec = TaskDataSpec(task_name="test")
    conversation = [
        {"role": "user", "content": "What is 6 times 7?"},
        {"role": "assistant", "content": "The answer is 42."},
    ]

    formatted = get_formatted_message_log(
        conversation,
        dummy_tokenizer,
        spec,
        add_bos_token=False,
        add_eos_token=False,
        add_generation_prompt=True,
    )

    last_turn = formatted[-1]
    leading_id = cast(torch.Tensor, last_turn["token_ids"])[0].item()
    # Must use in-context token (901), not standalone chunk token (900).
    assert leading_id == 901


def test_message_log_to_flat_messages_with_packed_images() -> None:
from nemo_rl.data.multimodal_utils import PackedTensor

Expand Down
Loading