diff --git a/nemo_rl/data/llm_message_utils.py b/nemo_rl/data/llm_message_utils.py index 29840b2b8d..a28c796fae 100644 --- a/nemo_rl/data/llm_message_utils.py +++ b/nemo_rl/data/llm_message_utils.py @@ -426,6 +426,26 @@ def get_first_index_that_differs(str1: str, str2: str) -> int: return min(len(str1), len(str2)) +def _get_longest_suffix_prefix_overlap(prev: str, cur: str) -> int: + """Find max overlap where suffix(prev) == prefix(cur).""" + max_overlap = min(len(prev), len(cur)) + for k in range(max_overlap, 0, -1): + if prev[-k:] == cur[:k]: + return k + return 0 + + +def _get_longest_token_suffix_prefix_overlap( + prev_ids: list[int], cur_ids: list[int] +) -> int: + """Find max overlap where suffix(prev_ids) == prefix(cur_ids).""" + max_overlap = min(len(prev_ids), len(cur_ids)) + for k in range(max_overlap, 0, -1): + if prev_ids[-k:] == cur_ids[:k]: + return k + return 0 + + def get_formatted_message_log( message_log: LLMMessageLogType, tokenizer: TokenizerType, @@ -452,6 +472,8 @@ def get_formatted_message_log( """ new_message_log: LLMMessageLogType = [] prev_formatted_message = "" + accumulated_text = "" + accumulated_token_ids: list[int] = [] message_log_strs: list[dict[str, str]] = cast( list[dict[str, str]], message_log ) # we just use the str:str parts here @@ -537,8 +559,12 @@ def _format_content_helper( formatted_message, ) - ## pull out the chunk corresponding to the current message - message_chunk = formatted_message[prev_message_len_no_eos:] + # Drop overlap with already-emitted text (for non-monotonic templates). + raw_message_chunk = formatted_message[prev_message_len_no_eos:] + overlap = _get_longest_suffix_prefix_overlap( + accumulated_text, raw_message_chunk + ) + message_chunk = raw_message_chunk[overlap:] # Debug: Print each message turn separately if debug: @@ -600,9 +626,22 @@ def _format_content_helper( new_message = message.copy() # extend this if statement to check for all(len(modality)) == 0 when adding other modalities if len(media_cur_message) == 0: - new_message["token_ids"] = tokenizer( - text=message_chunk, return_tensors="pt", add_special_tokens=False + # Tokenize in accumulated context, then keep only newly appended ids. + context_text = accumulated_text + message_chunk + context_ids = tokenizer( + text=context_text, return_tensors="pt", add_special_tokens=False )["input_ids"][0] + context_ids_list = cast(list[int], context_ids.tolist()) + token_overlap = _get_longest_token_suffix_prefix_overlap( + accumulated_token_ids, context_ids_list + ) + new_ids = context_ids_list[token_overlap:] + new_message["token_ids"] = torch.tensor( + new_ids, + dtype=context_ids.dtype, + device=context_ids.device, + ) + accumulated_token_ids = context_ids_list else: # extend the else statement to add other modalities (in this case, tokenizer will be a processor) media_kwargs = {} @@ -661,6 +700,7 @@ def _format_content_helper( new_message["content"].append(item) new_message_log.append(new_message) + accumulated_text += message_chunk prev_formatted_message = formatted_message return new_message_log diff --git a/tests/unit/data/test_llm_message_utils.py b/tests/unit/data/test_llm_message_utils.py index b39f117593..2f9cde5896 100644 --- a/tests/unit/data/test_llm_message_utils.py +++ b/tests/unit/data/test_llm_message_utils.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Callable +from typing import Any, Callable, cast import pytest import torch @@ -697,6 +697,106 @@ def test_get_first_index_that_differs(): assert get_first_index_that_differs("hello2", "hi1") == 1 +class _DummyReasoningTokenizer: + """Minimal tokenizer to reproduce #2821/#2844 behaviors in unit tests.""" + + bos_token = None + eos_token = "" + + def apply_chat_template( + self, + messages, + add_generation_prompt=False, + tokenize=False, + add_special_tokens=False, + **kwargs, + ): + assert not tokenize + out = "" + for idx, m in enumerate(messages): + role = m["role"] + content = m["content"] + if role == "user": + out += f"<|user|>{content}" + elif role == "assistant": + # Keep only on latest assistant turn (#2821 repro). + if idx == len(messages) - 1: + out += f"<|assistant|>{content}" + else: + out += f"<|assistant|>{content.replace('', '')}" + + if add_generation_prompt: + out += "<|assistant|>\n" + return out + + def _encode(self, text: str) -> list[int]: + # Toy sentencepiece behavior for #2844 repro: + # standalone "The" -> 900, in-context "\nThe" -> 901 + ids: list[int] = [] + i = 0 + while i < len(text): + if i == 0 and text.startswith("The", i): + ids.append(900) + i += 3 + continue + if text.startswith("\nThe", i): + ids.append(10) # '\n' + ids.append(901) + i += 4 + continue + ids.append(ord(text[i]) + 1000) + i += 1 + return ids + + def __call__(self, text, return_tensors="pt", add_special_tokens=False, **kwargs): + ids = self._encode(text) + return {"input_ids": torch.tensor([ids], dtype=torch.long)} + + +def test_get_formatted_message_log_no_duplication_for_non_monotonic_reasoning_template(): + tokenizer = _DummyReasoningTokenizer() + task_data_spec = TaskDataSpec(task_name="test") + msgs = [ + {"role": "user", "content": "q1"}, + {"role": "assistant", "content": "ANSWER1"}, + {"role": "user", "content": "q2"}, + {"role": "assistant", "content": "ANSWER2"}, + ] + + out = get_formatted_message_log( + msgs, + tokenizer, + task_data_spec, + add_bos_token=False, + add_eos_token=False, + add_generation_prompt=False, + ) + full = "".join(cast(str, m["content"]) for m in out) + assert full.count("ANSWER1") == 1 + assert full.count("ANSWER2") == 1 + + +def test_get_formatted_message_log_uses_context_tokenization_for_assistant_leading_token(): + tokenizer = _DummyReasoningTokenizer() + task_data_spec = TaskDataSpec(task_name="test") + msgs = [ + {"role": "user", "content": "What is 6 times 7?"}, + {"role": "assistant", "content": "The answer is 42."}, + ] + + out = get_formatted_message_log( + msgs, + tokenizer, + task_data_spec, + add_bos_token=False, + add_eos_token=False, + add_generation_prompt=True, + ) + assistant_ids = cast(torch.Tensor, out[-1]["token_ids"]) + # Must use in-context token (901), not standalone chunk token (900). + assert assistant_ids[0].item() == 901 + + def test_message_log_to_flat_messages_with_packed_images() -> None: from nemo_rl.data.multimodal_utils import PackedTensor