From 6aade9927aebd036766caeee8f814c69fa0accef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Sat, 9 May 2026 11:52:31 +0800 Subject: [PATCH 001/104] wip --- cookbook/rl/grpo_condensed.py | 0 pyproject.toml | 1 + src/twinkle/data_format/sampling.py | 2 +- src/twinkle/dataset/base.py | 7 + src/twinkle/metric/__init__.py | 1 + src/twinkle/metric/grpo.py | 257 ++++++++++++++++++ .../sampler/vllm_sampler/vllm_engine.py | 5 +- .../sampler/vllm_sampler/vllm_sampler.py | 7 +- src/twinkle/template/__init__.py | 5 + src/twinkle/template/base.py | 126 ++++++++- src/twinkle/template/qwen.py | 81 ++++++ src/twinkle/template/qwen3_5_vl.py | 48 +++- src/twinkle_agentic/__init__.py | 0 src/twinkle_agentic/data_format/__init__.py | 0 src/twinkle_agentic/data_format/chunk.py | 11 + 15 files changed, 520 insertions(+), 31 deletions(-) create mode 100644 cookbook/rl/grpo_condensed.py create mode 100644 src/twinkle/metric/grpo.py create mode 100644 src/twinkle/template/qwen.py create mode 100644 src/twinkle_agentic/__init__.py create mode 100644 src/twinkle_agentic/data_format/__init__.py create mode 100644 src/twinkle_agentic/data_format/chunk.py diff --git a/cookbook/rl/grpo_condensed.py b/cookbook/rl/grpo_condensed.py new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index 85ede352..964a7548 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ docs = [ packages = [ { include = "twinkle", from = "src" }, { include = "twinkle_client", from = "src" }, + { include = "twinkle_agentic", from = "src" }, ] [build-system] diff --git a/src/twinkle/data_format/sampling.py b/src/twinkle/data_format/sampling.py index 687030f8..e5884351 100644 --- a/src/twinkle/data_format/sampling.py +++ b/src/twinkle/data_format/sampling.py @@ -172,7 +172,7 @@ class SampledSequence: """A single sampled sequence with tokens and logprobs.""" stop_reason: StopReason tokens: List[int] - logprobs: Optional[List[float]] = None + logprobs: 
Optional[List[List[Tuple[int, float]]]] = None decoded: str = None new_input_feature: InputFeature = None diff --git a/src/twinkle/dataset/base.py b/src/twinkle/dataset/base.py index 501be7dd..db75c47e 100644 --- a/src/twinkle/dataset/base.py +++ b/src/twinkle/dataset/base.py @@ -15,6 +15,13 @@ from twinkle.utils import construct_class, processing_lock +try: + import multiprocess + multiprocess.set_start_method('spawn', force=True) +except RuntimeError: + pass + + @dataclass class DatasetMeta: """ diff --git a/src/twinkle/metric/__init__.py b/src/twinkle/metric/__init__.py index 59d5bbeb..ccdcb228 100644 --- a/src/twinkle/metric/__init__.py +++ b/src/twinkle/metric/__init__.py @@ -3,5 +3,6 @@ from .base import Metric from .completion_and_reward import CompletionRewardMetric from .dpo import DPOMetric +from .grpo import GRPOMetric from .loss import LossMetric from .train_metric import TrainMetric diff --git a/src/twinkle/metric/grpo.py b/src/twinkle/metric/grpo.py new file mode 100644 index 00000000..2f63e26c --- /dev/null +++ b/src/twinkle/metric/grpo.py @@ -0,0 +1,257 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import math +from typing import Any, Dict, List, Optional, Union +from twinkle.data_format import InputFeature, ModelOutput +from .base import Metric + + +def _align_logps_to_mask( + ragged: Any, + mask: 'torch.Tensor', # noqa: F821 + dtype: 'torch.dtype', # noqa: F821 +) -> Optional['torch.Tensor']: # noqa: F821 + import torch + + device = mask.device + batch_size, seq_len = mask.shape + + if isinstance(ragged, torch.Tensor): + t = ragged.to(device=device, dtype=dtype) + if t.shape == (batch_size, seq_len): + return t + # Fall through to the list path (row-wise scatter). 
+ ragged = [t[i] for i in range(min(batch_size, t.shape[0]))] + + if not isinstance(ragged, (list, tuple)): + return None + + result = torch.zeros((batch_size, seq_len), dtype=dtype, device=device) + for i, sample in enumerate(ragged): + if i >= batch_size: + break + pos = mask[i].nonzero(as_tuple=True)[0] + if len(pos) == 0: + continue + if isinstance(sample, (int, float)): + result[i, pos] = float(sample) + continue + vals = torch.as_tensor(sample, dtype=dtype, device=device).flatten() + n = min(len(pos), int(vals.numel())) + if n > 0: + result[i, pos[:n]] = vals[:n] + return result + + +class GRPOMetric(Metric): + + def __init__( + self, + device_mesh=None, + process_group=None, + ignore_index: int = -100, + temperature: float = 1.0, + **kwargs, + ): + super().__init__(device_mesh, process_group, **kwargs) + self.has_old = None + self.n_tokens = None + self.sum_approx_kl = None + self.sum_diff = None + self.sum_old = None + self.sum_new_sq = None + self.sum_new = None + self.ignore_index = ignore_index + self.temperature = float(temperature) + self.reset() + + def reset(self): + self.sum_new: float = 0.0 + self.sum_new_sq: float = 0.0 + self.sum_old: float = 0.0 + self.sum_diff: float = 0.0 + self.sum_approx_kl: float = 0.0 + self.n_tokens: int = 0 + self.has_old: bool = False + + @staticmethod + def _as_mb_list(logps_val) -> Optional[List]: + import torch + if logps_val is None: + return None + if isinstance(logps_val, list): + return logps_val or None + if torch.is_tensor(logps_val): + if logps_val.numel() == 0: + return None + return [logps_val] + return None + + def _accumulate_mb( + self, + labels: 'torch.Tensor', + logps: 'torch.Tensor', + old_slice: Any, + ) -> int: + """Reduce one microbatch into ``self.sum_*`` counters. + + Returns ``labels.shape[0]`` so the caller can advance the + ``old_logps`` slicing cursor even when the microbatch had zero + generated tokens (e.g. fully-masked prompt-only batch). 
+ """ + import torch + + if labels.dim() == 1: + labels = labels.unsqueeze(0) + if not torch.is_tensor(logps) or logps.numel() == 0: + return labels.shape[0] + if labels.device != logps.device: + labels = labels.to(logps.device) + + # Safety-align seq_len (SP / packed edge cases may leave a + # small off-by-one between labels and logps within a mb). + if logps.shape[-1] != labels.shape[-1]: + m = min(logps.shape[-1], labels.shape[-1]) + logps = logps[..., :m] + labels = labels[..., :m] + # Safety-align num_seq (mb-local; normally matches exactly). + if logps.shape[0] != labels.shape[0]: + n = min(logps.shape[0], labels.shape[0]) + logps = logps[:n] + labels = labels[:n] + + mask = (labels != self.ignore_index) + n_tok = int(mask.sum().item()) + num_seq = labels.shape[0] + if n_tok == 0: + return num_seq + + # Recover T=1 log-probs if user told us the sampler temperature. + # At T=1 this is a no-op (temperature field defaults to 1.0). + # NOTE: new and old logps receive the same multiplier, so for T != 1 + # ``logp_diff`` is scaled by T as well and ``approx_kl`` changes with it. + scale = self.temperature + logps_f = logps.float() + if scale != 1.0: + logps_f = logps_f * scale + mask_f = mask.float() + + self.n_tokens += n_tok + self.sum_new += float((logps_f * mask_f).sum().item()) + self.sum_new_sq += float(((logps_f ** 2) * mask_f).sum().item()) + + if old_slice is None: + return num_seq + + aligned = _align_logps_to_mask(old_slice, mask, logps_f.dtype) + if aligned is None: + return num_seq + old_f = aligned.float() + if scale != 1.0: + old_f = old_f * scale + + d = logps_f - old_f # new - old + self.sum_old += float((old_f * mask_f).sum().item()) + self.sum_diff += float((d * mask_f).sum().item()) + # Schulman K3 estimator of KL(old || new): + # samples x ~ old, r(x) = new(x) / old(x), + # k3 = r - 1 - log(r) = exp(new - old) - (new - old) - 1.
+ kl = torch.exp(d) - d - 1.0 + self.sum_approx_kl += float((kl * mask_f).sum().item()) + self.has_old = True + return num_seq + + def accumulate( + self, + inputs: Union[InputFeature, List[InputFeature]], + outputs: ModelOutput, + *, + old_logps: Any = None, + **kwargs, + ): + import torch + if outputs is None: + return + assert 'logps' in outputs + logps_val = outputs.get('logps') + logps_list = self._as_mb_list(logps_val) + inputs_list = inputs if isinstance(inputs, list) else [inputs] + + if (torch.is_tensor(logps_val) and len(inputs_list) > 1 + and all(isinstance(i, dict) and i.get('labels') is not None + for i in inputs_list)): + label_tensors = [torch.as_tensor(i['labels']) for i in inputs_list] + seq_lens = {t.shape[-1] for t in label_tensors} + if len(seq_lens) == 1: + merged = torch.cat(label_tensors, dim=0) + inputs_list = [{'labels': merged}] + + flat_old: Optional[List] = None + if old_logps is not None and isinstance(old_logps, (list, tuple)): + flat_old = list(old_logps) + + cursor = 0 + n_mb = min(len(inputs_list), len(logps_list)) + for mb_idx in range(n_mb): + mb_input = inputs_list[mb_idx] + if not isinstance(mb_input, dict): + continue + labels = mb_input.get('labels') + if labels is None: + continue + import torch + labels = torch.as_tensor(labels) + + logps_mb = logps_list[mb_idx] + + if flat_old is not None: + num_seq_est = (labels.shape[0] if labels.dim() >= 2 else 1) + old_slice = flat_old[cursor:cursor + num_seq_est] + elif old_logps is not None and hasattr(old_logps, 'shape'): + # Uncommon: aligned global tensor. Only honour when it + # exactly matches the single-mb shape; otherwise drop. 
+ import torch as _torch # noqa: F811 + old_slice = old_logps if (_torch.is_tensor(old_logps) and old_logps.shape + == logps_mb.shape) else None + else: + old_slice = None + + advanced = self._accumulate_mb(labels, logps_mb, old_slice) + cursor += advanced + + def calculate(self) -> Dict[str, Any]: + local = [{ + 'sum_new': self.sum_new, + 'sum_new_sq': self.sum_new_sq, + 'sum_old': self.sum_old, + 'sum_diff': self.sum_diff, + 'sum_kl': self.sum_approx_kl, + 'n': self.n_tokens, + 'has_old': self.has_old, + }] + all_results = self.gather_results(local) + + n_total = sum(r['n'] for r in all_results) + if n_total == 0: + self.reset() + return {} + + sum_new = sum(r['sum_new'] for r in all_results) + sum_new_sq = sum(r['sum_new_sq'] for r in all_results) + mean_new = sum_new / n_total + var_new = max(0.0, sum_new_sq / n_total - mean_new * mean_new) + + results: Dict[str, Any] = { + 'train/policy_confidence': f'{math.exp(mean_new):.4f}', + 'train/mean_new_logp': f'{mean_new:.4f}', + 'train/logp_std': f'{math.sqrt(var_new):.4f}', + } + if any(r['has_old'] for r in all_results): + mean_old = sum(r['sum_old'] for r in all_results) / n_total + mean_diff = sum(r['sum_diff'] for r in all_results) / n_total + mean_kl = sum(r['sum_kl'] for r in all_results) / n_total + results['train/mean_old_logp'] = f'{mean_old:.4f}' + results['train/logp_diff_mean'] = f'{mean_diff:+.4f}' + results['train/approx_kl'] = f'{mean_kl:.6f}' + + self.reset() + return results diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index d037f1cd..4892b616 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -199,6 +199,7 @@ async def sample(self, *, multi_modal_data: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, + disable_lora: bool = False, **kwargs) -> SampleResponse: """ Sample completions from the model. 
@@ -244,7 +245,9 @@ async def sample(self, 'False — LoRA will be ignored for this request') lora_request = None - if lora_request is None and self._synced_lora_request is not None: + if disable_lora: + lora_request = None + elif lora_request is None and self._synced_lora_request is not None: # RL training path: use the LoRA synced via CheckpointEngine. # The request object is cached after the first ``list_loras`` # check to avoid per-request RPC overhead. diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index cca376f3..c6353e49 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -25,7 +25,7 @@ import os import threading from typing import Any, Dict, List, Optional, Type, Union - +from copy import copy from twinkle import DeviceMesh, get_logger, remote_class, remote_function, requires from twinkle.checkpoint_engine import CheckpointEngineMixin from twinkle.data_format import InputFeature, SampledSequence, SampleResponse, SamplingParams, Trajectory @@ -216,6 +216,7 @@ async def _sample_single( *, multi_modal_data: Optional[Dict[str, Any]] = None, logprobs_only: bool = False, + disable_lora: bool = False, ) -> SampleResponse: """Sample a single input asynchronously. @@ -237,6 +238,7 @@ async def _sample_single( lora_request=lora_request, multi_modal_data=multi_modal_data, mm_processor_kwargs=feat.get('mm_processor_kwargs'), + disable_lora=disable_lora, ) if 'input_ids' not in feat or multi_modal_data: @@ -288,6 +290,7 @@ def sample( adapter_path: Optional[str] = None, *, return_encoded: bool = False, + use_base_model: bool = False, ) -> List[SampleResponse]: """Sample responses for given inputs. 
@@ -325,6 +328,7 @@ def sample( is_trajectory = 'input_ids' not in inputs_list[0] logprobs_only = False if sampling_params.max_tokens == 0: + sampling_params = copy(sampling_params) sampling_params.max_tokens = 1 logprobs_only = True assert not is_trajectory, 'Logprobs only not supported for Trajectory inputs' @@ -360,6 +364,7 @@ async def _sample_all(): lora_request=lora_request, multi_modal_data=multi_modal_data, logprobs_only=logprobs_only, + disable_lora=use_base_model, ) for feat, multi_modal_data in zip(encoded_inputs, multi_modal_data_list) ] return await asyncio.gather(*tasks) diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py index 324ce7ac..9f10dcf8 100644 --- a/src/twinkle/template/__init__.py +++ b/src/twinkle/template/__init__.py @@ -1,3 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import Template from .qwen3_5_vl import Qwen3_5Template +from .tool_call_parser import ( + QWEN_TOOL_CALL_PARSER, + QwenToolCallParser, + ToolCallParser, +) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 5784ddae..7d8451da 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -1,5 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import inspect +import json + import numpy as np import os from collections.abc import Mapping @@ -10,6 +12,7 @@ from twinkle.hub import HubOperation from twinkle.utils import load_image, to_device from .utils import TokenizeByRound, transfer_to_standard_message +from .. 
import remote_class if TYPE_CHECKING: import torch @@ -21,6 +24,7 @@ AudioInput = Union[str, np.ndarray, 'torch.Tensor'] +@remote_class() class Template: # Placeholder tokens in user text @@ -36,6 +40,7 @@ def __init__(self, default_system: Optional[str] = None, enable_thinking: bool = True, **kwargs): + self.model_id = model_id model_id = HubOperation.download_model(model_id, ignore_model=True) if os.path.exists(os.path.join(model_id, 'preprocessor_config.json')): from transformers import AutoProcessor @@ -63,6 +68,29 @@ def __init__(self, self._roll_labels, # roll labels ] + def parse_tool_call(self, decoded: str) -> List[Dict[str, Any]]: + """Parse tool calls from the assistant's decoded output. + + Dispatches by model family on ``self.model_id``; the actual + wire-format logic lives in :mod:`.tool_call_parser`. + """ + mid = (self.model_id or '').lower() + if 'qwen' in mid: + from .qwen import QwenTemplate + return QwenTemplate.parse(self, decoded) + # TODO: Other models (Llama3, OpenAI JSON, …) — add a parser in + # ``tool_call_parser.py`` and extend this dispatch. 
+ return [] + + def clean_tool_call(self, decoded: str) -> str: + """Strip family-specific tool-call markup from assistant text.""" + mid = (self.model_id or '').lower() + if 'qwen' in mid: + from .qwen import QwenTemplate + return QwenTemplate.clean(self, decoded) + # TODO: Other models + return (decoded or '').rstrip() + @property def tokenizer(self): tokenizer = self.processor @@ -458,7 +486,16 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo k: v for k, v in b.items() if v is not None } for b in msg['content'] if isinstance(b, dict)] - tools = [dict(tool) for tool in trajectory.get('tools', [])] + + tool_calls = msg.get('tool_calls') + if isinstance(tool_calls, list) and tool_calls: + msg['tool_calls'] = [ + Template._normalize_tool_call_for_template(tool_call) for tool_call in tool_calls + ] + tools = [ + Template._normalize_tool_for_template(tool) + for tool in trajectory.get('tools', []) + ] # Use inspect to get apply_chat_template signature params sig = inspect.signature(self.processor.apply_chat_template) @@ -511,6 +548,65 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo **kwargs) return inputs + @staticmethod + def _parse_arguments(args: Any) -> Any: + if isinstance(args, str): + try: + parsed = json.loads(args) + return parsed + except (TypeError, ValueError): + return {} + return args + + @staticmethod + def _normalize_tool_call_for_template(tc: Any) -> Any: + if not isinstance(tc, dict): + return tc + # Already OpenAI-nested: ensure arguments is a mapping. + if isinstance(tc.get('function'), dict) and 'name' in tc['function']: + fn = dict(tc['function']) + if 'arguments' in fn: + fn['arguments'] = Template._parse_arguments(fn['arguments']) + out = dict(tc) + out['function'] = fn + out.setdefault('type', 'function') + return out + # Already flat OpenAI (``name`` at top-level): just normalize arguments. 
+ if 'name' in tc and 'tool_name' not in tc: + out = dict(tc) + if 'arguments' in out: + out['arguments'] = Template._parse_arguments(out['arguments']) + return out + # Twinkle shape: lift ``tool_name`` to ``function.name``. + name = tc.get('tool_name') + if not name: + return tc + return { + 'type': 'function', + 'function': { + 'name': name, + 'arguments': Template._parse_arguments(tc.get('arguments', {})), + }, + } + + @staticmethod + def _normalize_tool_for_template(tool: Any) -> Any: + if not isinstance(tool, dict): + return tool + if isinstance(tool.get('function'), dict) and 'name' in tool['function']: + return tool + if 'name' in tool and 'tool_name' not in tool: + return tool + name = tool.get('tool_name') + if not name: + return tool + fn: Dict[str, Any] = {'name': name} + if 'description' in tool: + fn['description'] = tool['description'] + if 'parameters' in tool: + fn['parameters'] = Template._parse_arguments(tool['parameters']) + return {'type': 'function', 'function': fn} + def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs) -> InputFeature: """Encode a single trajectory's messages into InputFeature.""" labels = None @@ -661,24 +757,28 @@ def batch_encode( # Process List[Trajectory] trajectories = self._invoke_pre_pipeline(trajectories) - - # Use thread pool for parallel encoding - from concurrent.futures import ThreadPoolExecutor - from functools import partial - encode_fn = partial( - self._encode_messages, - add_generation_prompt=add_generation_prompt, - **kwargs, - ) - with ThreadPoolExecutor() as executor: - output = list(executor.map(encode_fn, trajectories)) - + output = [ + self._encode_messages(t, add_generation_prompt=add_generation_prompt, **kwargs) + for t in trajectories + ] output = self._invoke_post_pipeline(output) if _transfer: output = self.map_row_to_col(output) return output + def format_trajectory(self, trajectory: Trajectory, + add_default_system: bool = False) -> Trajectory: + 
current = [trajectory] + for pipeline in self.pre_pipeline: + if not add_default_system and pipeline == self._add_default_system: + continue + next_batch = [] + for traj in current: + next_batch.extend(pipeline(traj)) + current = next_batch + return current[0] + def check(self, trajectory: Trajectory) -> Optional[Trajectory]: encoded = None try: diff --git a/src/twinkle/template/qwen.py b/src/twinkle/template/qwen.py new file mode 100644 index 00000000..852a5399 --- /dev/null +++ b/src/twinkle/template/qwen.py @@ -0,0 +1,81 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import json +import re +from typing import Any, Dict, List + +from twinkle import remote_class +from twinkle.template import Template + + +@remote_class() +class QwenTemplate(Template): + + _BLOCK_RE = re.compile( + r'<tool_call>\s*([\s\S]*?)\s*(?:</tool_call>|\Z)') + _FUNCTION_RE = re.compile(r'<function=([^>]+)>([\s\S]*?)</function>') + _PARAMETER_RE = re.compile( + r'<parameter=([^>]+)>\s*([\s\S]*?)\s*</parameter>') + _STRIP_RE = re.compile(r'<tool_call>[\s\S]*?(?:</tool_call>|\Z)') + + def parse(self, decoded: str) -> List[Dict[str, Any]]: + calls: List[Dict[str, Any]] = [] + for block_m in self._BLOCK_RE.finditer(decoded or ''): + block = block_m.group(1) + func_m = self._FUNCTION_RE.search(block) + if func_m: + args: Dict[str, Any] = {} + for pm in self._PARAMETER_RE.finditer(func_m.group(2)): + key = pm.group(1).strip() + val = pm.group(2).strip() + try: + args[key] = json.loads(val) + except (json.JSONDecodeError, ValueError): + args[key] = val + calls.append({ + 'tool_name': func_m.group(1).strip(), + 'arguments': args, + }) + continue + # JSON fallback: ``{"name": ..., "arguments": ...}`` inside the block.
+ try: + data = json.loads(block) + except json.JSONDecodeError: + continue + name = data.get('name') or data.get('tool_name', '') + if not name: + continue + args = data.get('arguments', {}) + if isinstance(args, str): + try: + args = json.loads(args) if args.strip() else {} + except json.JSONDecodeError: + args = {} + calls.append({ + 'tool_name': name, + 'arguments': args if isinstance(args, dict) else {}, + }) + return calls + + def clean(self, decoded: str) -> str: + return self._STRIP_RE.sub('', decoded or '').rstrip() + + def parse_tool_call(self, decoded: str) -> List[Dict[str, Any]]: + """Parse tool calls from the assistant's decoded output. + + Dispatches by model family on ``self.model_id``; the actual + wire-format logic lives in :mod:`.tool_call_parser`. + """ + mid = (self.model_id or '').lower() + if 'qwen' in mid: + return self.parse(decoded) + # TODO: Other models (Llama3, OpenAI JSON, …) — add a parser in + # ``tool_call_parser.py`` and extend this dispatch. + return [] + + def clean_tool_call(self, decoded: str) -> str: + """Strip family-specific tool-call markup from assistant text.""" + mid = (self.model_id or '').lower() + if 'qwen' in mid: + return self.clean(decoded) + # TODO: Other models + return (decoded or '').rstrip() diff --git a/src/twinkle/template/qwen3_5_vl.py b/src/twinkle/template/qwen3_5_vl.py index 22799bab..71ee202b 100644 --- a/src/twinkle/template/qwen3_5_vl.py +++ b/src/twinkle/template/qwen3_5_vl.py @@ -3,17 +3,37 @@ import torch from copy import copy from PIL import Image -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Callable from twinkle import remote_class, requires from twinkle.data_format import InputFeature -from twinkle.template import Template from twinkle.template.base import ImageInput, VideoInput +from twinkle.template.qwen import QwenTemplate from twinkle.template.utils import get_inputs_embeds_hf +_ROPE_INDEX_CACHE: Dict[str, Callable] = {} + + +def 
_build_rope_index_func(config) -> Callable: + arch = config.architectures[0] + fn = _ROPE_INDEX_CACHE.get(arch) + if fn is not None: + return fn + import transformers + with torch.device('meta'): + model_cls = getattr(transformers, arch) + dummy_model = model_cls(config) + for _, sub_module in dummy_model.named_modules(): + if hasattr(sub_module, 'get_rope_index'): + _ROPE_INDEX_CACHE[arch] = sub_module.get_rope_index + return sub_module.get_rope_index + raise NotImplementedError( + f'Module {dummy_model.__class__.__name__} has no get_rope_index method!') + + @remote_class() -class Qwen3_5Template(Template): +class Qwen3_5Template(QwenTemplate): """ Processor for Qwen VL series. @@ -26,18 +46,16 @@ def __init__(self, *args, **kwargs): self._patch_size: Optional[int] = None self._merge_size: Optional[int] = None self._init_vision_config() - with torch.device('meta'): - import transformers - model_cls = self.config.architectures[0] - model_cls = getattr(transformers, model_cls) - self.dummy_model = model_cls(self.config) - self.rope_index_func = self.get_rope_index() - - def get_rope_index(self): - for _, sub_module in self.dummy_model.named_modules(): - if hasattr(sub_module, 'get_rope_index'): - return sub_module.get_rope_index - raise NotImplementedError(f'Module {self.dummy_model.__class__.__name__} has no get_rope_index method!') + + @property + def rope_index_func(self) -> Callable: + """Lazily resolve the rope-index function via a module-level cache. + + Kept off ``self`` so the template's ``__dict__`` stays free of + ``nn.Module`` state, which in turn keeps ``dill.dumps(template)`` + deterministic for HF datasets fingerprinting. 
+ """ + return _build_rope_index_func(self.config) def _init_vision_config(self): """Initialize vision config from processor.""" diff --git a/src/twinkle_agentic/__init__.py b/src/twinkle_agentic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/twinkle_agentic/data_format/__init__.py b/src/twinkle_agentic/data_format/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/twinkle_agentic/data_format/chunk.py b/src/twinkle_agentic/data_format/chunk.py new file mode 100644 index 00000000..e51a7c8a --- /dev/null +++ b/src/twinkle_agentic/data_format/chunk.py @@ -0,0 +1,11 @@ +import sys +from dataclasses import dataclass +from itertools import groupby +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +if sys.version_info[:2] <= (3, 11): + # Pydantic requirements. + from typing_extensions import TypedDict +else: + from typing import TypedDict + From 99394a2985e5de0f8ca3446b76095a7dc30c8491 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Sat, 9 May 2026 13:12:15 +0800 Subject: [PATCH 002/104] wip --- src/twinkle/template/base.py | 2 +- src/twinkle_agentic/chunker/__init__.py | 0 src/twinkle_agentic/chunker/base.py | 10 + src/twinkle_agentic/condenser/__init__.py | 0 src/twinkle_agentic/data_format/chunk.py | 11 - src/twinkle_agentic/data_format/chunks.py | 105 ++++++++++ src/twinkle_agentic/reward/__init__.py | 0 src/twinkle_agentic/reward/f1.py | 191 ++++++++++++++++++ src/twinkle_agentic/rollout/__init__.py | 0 src/twinkle_agentic/tools/__init__.py | 0 src/twinkle_agentic/tools/base.py | 16 ++ .../tools/extract_condensed.py | 7 + src/twinkle_agentic/tools/tool_manager.py | 60 ++++++ 13 files changed, 390 insertions(+), 12 deletions(-) create mode 100644 src/twinkle_agentic/chunker/__init__.py create mode 100644 src/twinkle_agentic/chunker/base.py create mode 100644 src/twinkle_agentic/condenser/__init__.py delete mode 100644 src/twinkle_agentic/data_format/chunk.py create mode 100644 
src/twinkle_agentic/data_format/chunks.py create mode 100644 src/twinkle_agentic/reward/__init__.py create mode 100644 src/twinkle_agentic/reward/f1.py create mode 100644 src/twinkle_agentic/rollout/__init__.py create mode 100644 src/twinkle_agentic/tools/__init__.py create mode 100644 src/twinkle_agentic/tools/base.py create mode 100644 src/twinkle_agentic/tools/extract_condensed.py create mode 100644 src/twinkle_agentic/tools/tool_manager.py diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 7d8451da..66368a32 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -12,7 +12,7 @@ from twinkle.hub import HubOperation from twinkle.utils import load_image, to_device from .utils import TokenizeByRound, transfer_to_standard_message -from .. import remote_class +from twinkle import remote_class if TYPE_CHECKING: import torch diff --git a/src/twinkle_agentic/chunker/__init__.py b/src/twinkle_agentic/chunker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/twinkle_agentic/chunker/base.py b/src/twinkle_agentic/chunker/base.py new file mode 100644 index 00000000..a506c75b --- /dev/null +++ b/src/twinkle_agentic/chunker/base.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + +from twinkle.data_format import Trajectory + + +class Chunker(ABC): + + @abstractmethod + def __call__(self, trajectory: Trajectory) -> Chunks: + raise NotImplementedError \ No newline at end of file diff --git a/src/twinkle_agentic/condenser/__init__.py b/src/twinkle_agentic/condenser/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/twinkle_agentic/data_format/chunk.py b/src/twinkle_agentic/data_format/chunk.py deleted file mode 100644 index e51a7c8a..00000000 --- a/src/twinkle_agentic/data_format/chunk.py +++ /dev/null @@ -1,11 +0,0 @@ -import sys -from dataclasses import dataclass -from itertools import groupby -from typing import Any, Dict, List, Literal, Optional, Tuple, Union - -if 
sys.version_info[:2] <= (3, 11): - # Pydantic requirements. - from typing_extensions import TypedDict -else: - from typing import TypedDict - diff --git a/src/twinkle_agentic/data_format/chunks.py b/src/twinkle_agentic/data_format/chunks.py new file mode 100644 index 00000000..04e78d7f --- /dev/null +++ b/src/twinkle_agentic/data_format/chunks.py @@ -0,0 +1,105 @@ +import sys +from dataclasses import dataclass +from itertools import groupby +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +if sys.version_info[:2] <= (3, 11): + # Pydantic requirements. + from typing_extensions import TypedDict +else: + from typing import TypedDict + +_MULTIMODAL_TYPES = ('image', 'video', 'audio') +_MEDIA_BUCKETS = (('images', 'image'), ('videos', 'video'), ('audios', 'audio')) + + +class Chunk(TypedDict, total=False): + + type: Literal['text', 'image', 'video', 'audio'] + content: Union[str, Any] + raw: Union[str, Any] + role: str + + +@dataclass +class Chunks: + + chunks: List[Chunk] + + def to_trajectory( + self, + block_wrapper: Optional[Tuple[str, str]] = ('', ''), + ) -> Dict[str, Any]: + media: Dict[str, List[Any]] = {t: [] for t in _MULTIMODAL_TYPES} + bound: List[Chunk] = [] + wrap_counter = 0 + for c in self.chunks: + if c.get('type') in _MULTIMODAL_TYPES and not isinstance(c.get('raw'), dict): + media[c['type']].append(c.get('content')) + continue + if block_wrapper and c.get('type') == 'text': + raw = c.get('raw') + is_condensed = isinstance(raw, dict) and raw.get('condensed') + content = c.get('content') + if (is_condensed and isinstance(content, str) and content + and c.get('role') != 'tool'): + wrap_counter += 1 + prefix = block_wrapper[0].format(n=wrap_counter) + suffix = block_wrapper[1].format(n=wrap_counter) + c = {**c, 'content': f'{prefix}{content}{suffix}'} + bound.append(c) + + # Merge consecutive same-role chunks into one message via groupby. 
+ messages = [ + self._group_to_message(role, list(grp)) + for role, grp in groupby(bound, key=lambda c: c.get('role') or 'user') + ] + + trajectory: Dict[str, Any] = {'messages': messages} + for plural, singular in _MEDIA_BUCKETS: + if media[singular]: + trajectory[plural] = media[singular] + return trajectory + + @staticmethod + def _group_to_message(role: str, group: List[Chunk]) -> Dict[str, Any]: + """Fold a same-role run of chunks into one :class:`Message`. + + Preserves the intra-group order so mixed text / image / video / audio + parts round-trip back into OpenAI-style structured ``content``. + """ + reasoning: List[str] = [] + parts: List[Dict[str, Any]] = [] + tool_calls: List[Dict[str, Any]] = [] + tool_call_id: Optional[str] = None + has_media = False + + for c in group: + t, raw, content = c.get('type'), c.get('raw'), c.get('content') + kind = raw.get('kind') if isinstance(raw, dict) else None + # Any chunk in the group may carry the shared ``tool_call_id``. + if isinstance(raw, dict) and raw.get('tool_call_id') and tool_call_id is None: + tool_call_id = raw['tool_call_id'] + + if t == 'text' and kind == 'reasoning_content' and content: + reasoning.append(content) + elif t == 'text' and kind == 'tool_call' and isinstance(raw.get('tool_call'), dict): + tool_calls.append(dict(raw['tool_call'])) + elif t == 'text' and content: + parts.append({'type': 'text', 'text': content}) + elif t in _MULTIMODAL_TYPES and isinstance(raw, dict): + has_media = True + # Drop condenser-only markers, keep the original part shape. 
+ parts.append({k: v for k, v in raw.items() if k != 'condensed'} + or {'type': t, t: content}) + + msg: Dict[str, Any] = {'role': role} + if reasoning: + msg['reasoning_content'] = '\n\n'.join(reasoning) + if parts: + msg['content'] = parts if has_media else '\n\n'.join(p['text'] for p in parts) + if tool_calls: + msg['tool_calls'] = tool_calls + if tool_call_id is not None: + msg['tool_call_id'] = tool_call_id + return msg diff --git a/src/twinkle_agentic/reward/__init__.py b/src/twinkle_agentic/reward/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/twinkle_agentic/reward/f1.py b/src/twinkle_agentic/reward/f1.py new file mode 100644 index 00000000..ed30e82f --- /dev/null +++ b/src/twinkle_agentic/reward/f1.py @@ -0,0 +1,191 @@ +import re +import string +from typing import List, Dict, Any, Tuple +from collections import Counter + +from twinkle.reward import Reward + +_BOXED_MARKER = '\\boxed{' + + +def _extract_final_answer(completion: str) -> str: + if not completion: + return '' + out = '' + idx = 0 + while True: + i = completion.find(_BOXED_MARKER, idx) + if i == -1: + break + j = i + len(_BOXED_MARKER) + depth = 1 + while j < len(completion) and depth > 0: + c = completion[j] + if c == '{': + depth += 1 + elif c == '}': + depth -= 1 + j += 1 + if depth == 0: + out = completion[i + len(_BOXED_MARKER): j - 1].strip() + idx = j + else: + # Unbalanced trailing marker — stop, keep last good match. 
+ break + return out + + +def _last_assistant_text(traj: Dict[str, Any]) -> str: + for msg in reversed(traj.get('messages', [])): + if msg.get('role') != 'assistant': + continue + content = msg.get('content') or '' + if isinstance(content, str): + return content + return '\n'.join( + p.get('text', '') for p in content + if isinstance(p, dict) and p.get('type') == 'text') + return '' + + +def _stem(tok: str) -> str: + from nltk.stem import PorterStemmer + return PorterStemmer().stem(tok) if len(tok) >= 4 and tok.isalpha() else tok + + +def _normalize_answer(s: str) -> str: + s = (s or '').lower() + s = ''.join(ch for ch in s if ch not in set(string.punctuation)) + s = re.sub(r'\b(a|an|the)\b', ' ', s) + return ' '.join(_stem(t) for t in s.split()) + + +def _f1_score(prediction: str, gold: str) -> Tuple[float, float]: + filler_tokens: frozenset = frozenset([ + 'long', 'tall', 'high', 'wide', 'deep', 'heavy', 'old', 'large', + 'small', 'big', 'short', 'away', 'ago', 'approximately', 'about', + 'around', 'over', 'under', 'below', 'above', 'total', 'roughly', + 'nearly', 'almost', 'exactly', + ]) + pred_tokens = _normalize_answer(prediction).split() + gold_tokens = _normalize_answer(gold).split() + if not pred_tokens or not gold_tokens: + em = float(pred_tokens == gold_tokens) + return em, em + em = float(pred_tokens == gold_tokens) + common = Counter(pred_tokens) & Counter(gold_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0.0, em + p = num_same / len(pred_tokens) + r = num_same / len(gold_tokens) + f1 = 2 * p * r / (p + r) + + pred_set, gold_set = set(pred_tokens), set(gold_tokens) + if gold_set < pred_set: + extras = pred_set - gold_set + if all(t.isdigit() or t in filler_tokens for t in extras): + return 1.0, em + if pred_set < gold_set: + missing = gold_set - pred_set + if all(t in filler_tokens for t in missing): + return 1.0, em + return f1, em + + +class HotpotQAF1Reward(Reward): + + def __init__(self, answer_pattern=None): + if 
isinstance(answer_pattern, str): + answer_pattern = re.compile(answer_pattern) + self._answer_pattern = answer_pattern + + def _extract(self, completion: str) -> str: + balanced = _extract_final_answer(completion) + if balanced: + return balanced + if self._answer_pattern is None: + return '' + matches = self._answer_pattern.findall(completion or '') + if not matches: + return '' + last = matches[-1] + if isinstance(last, tuple): + last = last[0] if last else '' + return (last or '').strip() + + def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: + rewards = [] + for traj in trajectories: + gold = '' + for key, val in traj.get('user_data', []) or []: + if key == 'ground_truth': + gold = val or '' + break + pred = self._extract(_last_assistant_text(traj)) + f1, _ = _f1_score(pred, gold) + rewards.append(f1) + return rewards + + +class HotpotQACoTReward(Reward): + _STEP_LINE_RE = re.compile(r'(?im)^\s*step\s*(\d+)\s*[.:]') + _HAS_BOXED_RE = re.compile(r'\\boxed\{') + + def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: + rewards: List[float] = [] + for t in trajectories: + msgs = t.get('messages', []) or [] + + # Newline-joined so ``^`` line anchors work even when + # multiple assistant turns exist. 
+ assistant_text = '\n'.join( + m.get('content', '') or '' + for m in msgs + if m.get('role') == 'assistant' and isinstance(m.get('content'), str) + ) + + if not self._HAS_BOXED_RE.search(assistant_text): + rewards.append(0.0) + continue + + steps: set = set() + for match in self._STEP_LINE_RE.finditer(assistant_text): + try: + steps.add(int(match.group(1))) + except ValueError: + continue + + n = len(steps) + # 0 → 0.0, 1 → 0.25, 2 → 0.5, 3 → 0.75, 4+ → 1.0 + rewards.append(min(1.0, n * 0.25)) + + return rewards + + +class HotpotQAToolExploreReward(Reward): + + def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: + rewards: List[float] = [] + for t in trajectories: + msgs = t.get('messages', []) or [] + n_msgs = len(msgs) + success = False + for i, m in enumerate(msgs): + if m.get('role') != 'assistant' or not m.get('tool_calls'): + continue + # Scan subsequent consecutive ``tool`` messages and keep + # the first non-ERROR one. + j = i + 1 + while j < n_msgs and msgs[j].get('role') == 'tool': + content = msgs[j].get('content') or '' + text = content if isinstance(content, str) else str(content) + if text.strip() and not text.lstrip().startswith('ERROR'): + success = True + break + j += 1 + if success: + break + rewards.append(1.0 if success else 0.0) + return rewards + diff --git a/src/twinkle_agentic/rollout/__init__.py b/src/twinkle_agentic/rollout/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/twinkle_agentic/tools/__init__.py b/src/twinkle_agentic/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/twinkle_agentic/tools/base.py b/src/twinkle_agentic/tools/base.py new file mode 100644 index 00000000..aa9a151e --- /dev/null +++ b/src/twinkle_agentic/tools/base.py @@ -0,0 +1,16 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+from abc import ABC, abstractmethod +from typing import Any, Dict + +from twinkle.data_format.message import Tool as ToolInfo + + +class Tool(ABC): + + @abstractmethod + def __call__(self, tool_name: str, arguments: Dict[str, Any]) -> str: + raise NotImplementedError + + @abstractmethod + def tool_info(self) -> ToolInfo: + raise NotImplementedError diff --git a/src/twinkle_agentic/tools/extract_condensed.py b/src/twinkle_agentic/tools/extract_condensed.py new file mode 100644 index 00000000..f9505ea3 --- /dev/null +++ b/src/twinkle_agentic/tools/extract_condensed.py @@ -0,0 +1,7 @@ +from .base import Tool + + +class ExtractCondensed(Tool): + + # Extract the condensed block + pass \ No newline at end of file diff --git a/src/twinkle_agentic/tools/tool_manager.py b/src/twinkle_agentic/tools/tool_manager.py new file mode 100644 index 00000000..4996569c --- /dev/null +++ b/src/twinkle_agentic/tools/tool_manager.py @@ -0,0 +1,60 @@ +import json +from typing import List, Optional, Dict, Union, Any + +from fastmcp.utilities.inspect import ToolInfo + +from twinkle.data_format import ToolCall +from twinkle_agentic.tools.base import Tool + + +class ToolManager: + + def __init__(self, tools: Dict[str, Tool]): + self._tools = tools + + def register(self, tool: Tool): + info = tool.tool_info() + name = info.get('tool_name') if isinstance(info, dict) else None + if not name: + raise ValueError( + f'tool {type(tool).__name__} must expose a non-empty ' + f'tool_info()["tool_name"]') + self._tools[name] = tool + + def unregister(self, name: str) -> Optional[Tool]: + return self._tools.pop(name, None) + + def names(self) -> List[str]: + return list(self._tools) + + def tool_infos(self) -> List[ToolInfo]: + return [t.tool_info() for t in self._tools.values()] + + def __call__(self, tool_call: Union[ToolCall, Dict[str, Any]]) -> str: + if not isinstance(tool_call, dict): + return f'Error: tool_call must be an object, got {type(tool_call).__name__}.' 
+ name = tool_call.get('tool_name') + if not name: + return 'Error: tool_call missing "tool_name".' + if (tool := self._tools.get(name)) is None: + available = ', '.join(sorted(self._tools)) or '(none)' + return f'Error: unknown tool {name!r}. Available: {available}.' + + raw_args = tool_call.get('arguments') + if raw_args is None: + args: Dict[str, Any] = {} + elif isinstance(raw_args, str): + try: + args = json.loads(raw_args) if raw_args.strip() else {} + except json.JSONDecodeError as e: + return f'Error: invalid JSON in arguments: {e}' + elif isinstance(raw_args, dict): + args = raw_args + else: + return (f'Error: "arguments" must be a JSON string or object, ' + f'got {type(raw_args).__name__}.') + + try: + return str(tool(name, args)) + except Exception as e: # noqa + return f'Error: tool {name!r} raised {type(e).__name__}: {e}' \ No newline at end of file From 27cd090aaae4a35256e18df68a548cf838b54cae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Sat, 9 May 2026 13:19:08 +0800 Subject: [PATCH 003/104] wip --- src/twinkle_agentic/chunker/base.py | 1 + src/twinkle_agentic/chunker/native.py | 22 +++++++++++++++++++++ src/twinkle_agentic/condenser/base.py | 10 ++++++++++ src/twinkle_agentic/condenser/keyword.py | 11 +++++++++++ src/twinkle_agentic/condenser/model.py | 11 +++++++++++ src/twinkle_agentic/data_format/__init__.py | 1 + src/twinkle_agentic/reward/__init__.py | 1 + src/twinkle_agentic/reward/f1.py | 6 +++--- src/twinkle_agentic/rollout/base.py | 10 ++++++++++ src/twinkle_agentic/rollout/multi_turn.py | 7 +++++++ 10 files changed, 77 insertions(+), 3 deletions(-) create mode 100644 src/twinkle_agentic/chunker/native.py create mode 100644 src/twinkle_agentic/condenser/base.py create mode 100644 src/twinkle_agentic/condenser/keyword.py create mode 100644 src/twinkle_agentic/condenser/model.py create mode 100644 src/twinkle_agentic/rollout/base.py create mode 100644 src/twinkle_agentic/rollout/multi_turn.py diff --git 
a/src/twinkle_agentic/chunker/base.py b/src/twinkle_agentic/chunker/base.py index a506c75b..e446fc35 100644 --- a/src/twinkle_agentic/chunker/base.py +++ b/src/twinkle_agentic/chunker/base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from twinkle.data_format import Trajectory +from twinkle_agentic.data_format import Chunks class Chunker(ABC): diff --git a/src/twinkle_agentic/chunker/native.py b/src/twinkle_agentic/chunker/native.py new file mode 100644 index 00000000..b9a44031 --- /dev/null +++ b/src/twinkle_agentic/chunker/native.py @@ -0,0 +1,22 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Rule-based trajectory chunker: splits Trajectory into Chunks.""" +from __future__ import annotations + +import json +import re +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional + +from twinkle.data_format import Trajectory +from twinkle.data_format.message import Message, ToolCall +from twinkle.template import Template +from twinkle_agentic.data_format import Chunks +from twinkle_agentic.data_format import Chunk + +from .base import Chunker + + + +class NativeChunker(Chunker): + + def __call__(self, trajectory: Trajectory) -> Chunks: + pass \ No newline at end of file diff --git a/src/twinkle_agentic/condenser/base.py b/src/twinkle_agentic/condenser/base.py new file mode 100644 index 00000000..f69fc518 --- /dev/null +++ b/src/twinkle_agentic/condenser/base.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + +from twinkle_agentic.data_format import Chunks + + +class Condenser(ABC): + + @abstractmethod + def __call__(self, chunks: Chunks, **kwargs) -> Chunks: + raise NotImplementedError \ No newline at end of file diff --git a/src/twinkle_agentic/condenser/keyword.py b/src/twinkle_agentic/condenser/keyword.py new file mode 100644 index 00000000..c4b1e14c --- /dev/null +++ b/src/twinkle_agentic/condenser/keyword.py @@ -0,0 +1,11 @@ +from abc import abstractmethod + +from twinkle_agentic.condenser.base import 
Condenser +from twinkle_agentic.data_format import Chunks + + +class KeywordCondenser(Condenser): + + @abstractmethod + def __call__(self, chunks: Chunks, **kwargs) -> Chunks: + pass \ No newline at end of file diff --git a/src/twinkle_agentic/condenser/model.py b/src/twinkle_agentic/condenser/model.py new file mode 100644 index 00000000..b70371ed --- /dev/null +++ b/src/twinkle_agentic/condenser/model.py @@ -0,0 +1,11 @@ +from abc import abstractmethod + +from twinkle_agentic.condenser.base import Condenser +from twinkle_agentic.data_format import Chunks + + +class ModelCondenser(Condenser): + + @abstractmethod + def __call__(self, chunks: Chunks, **kwargs) -> Chunks: + pass \ No newline at end of file diff --git a/src/twinkle_agentic/data_format/__init__.py b/src/twinkle_agentic/data_format/__init__.py index e69de29b..9cf61751 100644 --- a/src/twinkle_agentic/data_format/__init__.py +++ b/src/twinkle_agentic/data_format/__init__.py @@ -0,0 +1 @@ +from .chunks import Chunks, Chunk diff --git a/src/twinkle_agentic/reward/__init__.py b/src/twinkle_agentic/reward/__init__.py index e69de29b..6d979d74 100644 --- a/src/twinkle_agentic/reward/__init__.py +++ b/src/twinkle_agentic/reward/__init__.py @@ -0,0 +1 @@ +from .f1 import F1Reward, CoTReward, ToolExploreReward diff --git a/src/twinkle_agentic/reward/f1.py b/src/twinkle_agentic/reward/f1.py index ed30e82f..a9faf081 100644 --- a/src/twinkle_agentic/reward/f1.py +++ b/src/twinkle_agentic/reward/f1.py @@ -93,7 +93,7 @@ def _f1_score(prediction: str, gold: str) -> Tuple[float, float]: return f1, em -class HotpotQAF1Reward(Reward): +class F1Reward(Reward): def __init__(self, answer_pattern=None): if isinstance(answer_pattern, str): @@ -128,7 +128,7 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: return rewards -class HotpotQACoTReward(Reward): +class CoTReward(Reward): _STEP_LINE_RE = re.compile(r'(?im)^\s*step\s*(\d+)\s*[.:]') _HAS_BOXED_RE = re.compile(r'\\boxed\{') @@ -163,7 +163,7 
@@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: return rewards -class HotpotQAToolExploreReward(Reward): +class ToolExploreReward(Reward): def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: rewards: List[float] = [] diff --git a/src/twinkle_agentic/rollout/base.py b/src/twinkle_agentic/rollout/base.py new file mode 100644 index 00000000..5b078001 --- /dev/null +++ b/src/twinkle_agentic/rollout/base.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + +from twinkle.data_format import Trajectory + + +class Rollout(ABC): + + @abstractmethod + def __call__(self, trajectory: Trajectory, **kwargs) -> Trajectory: + raise NotImplementedError() diff --git a/src/twinkle_agentic/rollout/multi_turn.py b/src/twinkle_agentic/rollout/multi_turn.py new file mode 100644 index 00000000..f88a6d9a --- /dev/null +++ b/src/twinkle_agentic/rollout/multi_turn.py @@ -0,0 +1,7 @@ +from twinkle.data_format import Trajectory +from .base import Rollout + +class MultiTurnRollout(Rollout): + + def __call__(self, trajectory: Trajectory, **kwargs) -> Trajectory: + \ No newline at end of file From 9e31c0717be33d370999b0bb46ec68932ac1db56 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 9 May 2026 20:31:11 +0800 Subject: [PATCH 004/104] fix --- src/twinkle/template/__init__.py | 6 +- src/twinkle_agentic/chunker/__init__.py | 4 + src/twinkle_agentic/chunker/native.py | 265 +++++- src/twinkle_agentic/condenser/__init__.py | 5 + src/twinkle_agentic/condenser/keyword.py | 514 ++++++++++- src/twinkle_agentic/condenser/model.py | 347 +++++++- src/twinkle_agentic/data_format/chunks.py | 1 + src/twinkle_agentic/rollout/base.py | 3 +- src/twinkle_agentic/rollout/multi_turn.py | 466 +++++++++- .../rollout/multi_turn_condense.py | 112 +++ .../tools/extract_condensed.py | 155 +++- src/twinkle_agentic/tools/tool_manager.py | 35 +- .../twinkle_agentic/test_extract_condensed.py | 433 +++++++++ 
.../twinkle_agentic/test_keyword_condenser.py | 488 +++++++++++ tests/twinkle_agentic/test_model_condenser.py | 559 ++++++++++++ .../test_multi_turn_rollout.py | 826 ++++++++++++++++++ tests/twinkle_agentic/test_native_chunker.py | 432 +++++++++ 17 files changed, 4619 insertions(+), 32 deletions(-) create mode 100644 src/twinkle_agentic/rollout/multi_turn_condense.py create mode 100644 tests/twinkle_agentic/test_extract_condensed.py create mode 100644 tests/twinkle_agentic/test_keyword_condenser.py create mode 100644 tests/twinkle_agentic/test_model_condenser.py create mode 100644 tests/twinkle_agentic/test_multi_turn_rollout.py create mode 100644 tests/twinkle_agentic/test_native_chunker.py diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py index 9f10dcf8..b3bfb448 100644 --- a/src/twinkle/template/__init__.py +++ b/src/twinkle/template/__init__.py @@ -1,8 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import Template +from .qwen import QwenTemplate from .qwen3_5_vl import Qwen3_5Template -from .tool_call_parser import ( - QWEN_TOOL_CALL_PARSER, - QwenToolCallParser, - ToolCallParser, -) diff --git a/src/twinkle_agentic/chunker/__init__.py b/src/twinkle_agentic/chunker/__init__.py index e69de29b..f826a645 100644 --- a/src/twinkle_agentic/chunker/__init__.py +++ b/src/twinkle_agentic/chunker/__init__.py @@ -0,0 +1,4 @@ +from .base import Chunker +from .native import NativeChunker + +__all__ = ['Chunker', 'NativeChunker'] diff --git a/src/twinkle_agentic/chunker/native.py b/src/twinkle_agentic/chunker/native.py index b9a44031..ad059987 100644 --- a/src/twinkle_agentic/chunker/native.py +++ b/src/twinkle_agentic/chunker/native.py @@ -1,22 +1,271 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Rule-based trajectory chunker: splits Trajectory into Chunks.""" +"""Rule-based trajectory chunker. 
+ +Only the *first* ``user`` message is split into multiple text chunks +(capped at ``chunk_size`` characters, using a recursive separator list +that prefers paragraphs > lines > sentences > words). Every other +message is decomposed part-by-part *without* further splitting, so the +resulting :class:`Chunks` round-trips back to the original trajectory +via :meth:`Chunks.to_trajectory` (for non-split messages). + +The chunker never marks chunks as ``condensed`` — that is the +condenser's job. Consequently :meth:`Chunks.to_trajectory` will not +wrap any chunk with ``...`` when called directly on +a chunker output. +""" from __future__ import annotations -import json import re -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence from twinkle.data_format import Trajectory -from twinkle.data_format.message import Message, ToolCall -from twinkle.template import Template -from twinkle_agentic.data_format import Chunks -from twinkle_agentic.data_format import Chunk +from twinkle_agentic.data_format import Chunk, Chunks from .base import Chunker +# Recursive separator list, coarsest → finest. The empty string at the +# end forces a hard character cut when nothing finer fits. +_DEFAULT_SEPARATORS: tuple = ( + '\n\n', '\n', + '。', '.', '.', + '!', '!', + '?', '?', + ';', ';', + ',', ',', + ' ', + '', +) + +_MULTIMODAL_TYPES = ('image', 'video', 'audio') + +_SplitFn = Optional[Callable[[str], List[str]]] + class NativeChunker(Chunker): + """Character-level recursive chunker for trajectories. + + Args: + chunk_size: Soft upper bound (in characters) for every emitted + text chunk. Must be positive. + separators: Ordered separator list. The chunker tries each + separator in turn; any piece still larger than + ``chunk_size`` is re-split with the next one. A terminal + ``''`` (hard character cut) is appended automatically if + missing so the algorithm is guaranteed to terminate. 
+ passage_boundary_re: Optional regex (compiled with + ``re.MULTILINE``) whose matches act as **hard, non-mergeable** + passage boundaries on the first user message. The regex + match is preserved at the start of the next piece (so + ``''.join(pieces) == text``). Pieces that are already + ``<= chunk_size`` are emitted as-is and are **never merged** + across boundaries; only pieces that still exceed + ``chunk_size`` fall back to the normal recursive split + merge. + This is how you keep e.g. HotpotQA passages atomic per + ````. + """ + def __init__( + self, + chunk_size: int = 1024, + separators: Optional[Sequence[str]] = None, + passage_boundary_re: Optional[str] = None, + ): + if chunk_size <= 0: + raise ValueError(f'chunk_size must be positive, got {chunk_size}') + self.chunk_size = chunk_size + seps = tuple(separators) if separators is not None else _DEFAULT_SEPARATORS + if '' not in seps: + seps += ('',) + self.separators = seps + self.passage_boundary_re: Optional[re.Pattern] = ( + re.compile(passage_boundary_re, re.MULTILINE) + if passage_boundary_re else None) + + # ------------------------------------------------------------------ + # public entry + # ------------------------------------------------------------------ def __call__(self, trajectory: Trajectory) -> Chunks: - pass \ No newline at end of file + chunks: List[Chunk] = [] + first_user_done = False + # ``round`` is 1-indexed at the first user message. Any messages + # emitted before that (e.g., leading ``system``) carry round 0. 
+ round_idx = 0 + for msg in trajectory.get('messages') or []: + is_user = msg.get('role') == 'user' + if is_user: + round_idx += 1 + split = (self._split_text + if is_user and not first_user_done else None) + if is_user: + first_user_done = True + for chunk in self._parts(msg, split): + chunk['round'] = round_idx + chunks.append(chunk) + return Chunks(chunks=chunks) + + # ------------------------------------------------------------------ + # message → chunks decomposition + # ------------------------------------------------------------------ + def _parts(self, message: Dict[str, Any], split: _SplitFn) -> Iterator[Chunk]: + role = message.get('role') or 'user' + tcid = message.get('tool_call_id') + + rc = message.get('reasoning_content') + if rc: + yield _text_chunk(role, rc, kind='reasoning_content', tool_call_id=tcid) + + content = message.get('content') + if isinstance(content, str): + yield from self._emit_text(role, content, split, tcid) + elif isinstance(content, list): + for part in content: + if not isinstance(part, dict): + continue + ptype = part.get('type') + if ptype == 'text': + yield from self._emit_text( + role, part.get('text') or '', split, tcid) + elif ptype in _MULTIMODAL_TYPES: + # Keep raw part so Chunks.to_trajectory can rebuild + # the original OpenAI-style entry verbatim. 
+ yield { # type: ignore[misc] + 'type': ptype, 'content': part.get(ptype), + 'raw': dict(part), 'role': role, + } + + for tc in message.get('tool_calls') or []: + yield _text_chunk(role, '', kind='tool_call', tool_call=tc, + tool_call_id=tcid) + + def _emit_text(self, role: str, text: str, split: _SplitFn, + tool_call_id: Optional[str]) -> Iterator[Chunk]: + if not text: + return + pieces = split(text) if split is not None else [text] + for piece in pieces: + if piece: + yield _text_chunk(role, piece, tool_call_id=tool_call_id) + + # ------------------------------------------------------------------ + # recursive text splitter + # ------------------------------------------------------------------ + def _split_text(self, text: str) -> List[str]: + if not text: + return [] + if self.passage_boundary_re is None: + if len(text) <= self.chunk_size: + return [text] + return self._merge(self._recursive_split(text, list(self.separators))) + # Force-split first; each forced piece is kept intact when it is + # already short enough, and is recursively re-split (but NOT + # merged with sibling passages) when it exceeds ``chunk_size``. + out: List[str] = [] + for piece in self._force_split(text): + if not piece: + continue + if len(piece) <= self.chunk_size: + out.append(piece) + else: + out.extend(self._merge( + self._recursive_split(piece, list(self.separators)))) + return out + + def _force_split(self, text: str) -> List[str]: + """Split ``text`` at every ``passage_boundary_re`` match; the + match itself sticks to the start of the **next** piece, so + ``''.join(_force_split(text)) == text``. 
+ """ + assert self.passage_boundary_re is not None + matches = list(self.passage_boundary_re.finditer(text)) + if not matches: + return [text] + out: List[str] = [] + prev = 0 + for m in matches: + start = m.start() + if start > prev: + out.append(text[prev:start]) + prev = start + if prev < len(text): + out.append(text[prev:]) + return out + + def _recursive_split(self, text: str, separators: List[str]) -> List[str]: + if len(text) <= self.chunk_size: + return [text] if text else [] + # Terminal: no more separators, or next one is the hard-cut sentinel. + if not separators or separators[0] == '': + return _hard_cut(text, self.chunk_size) + + sep, *rest = separators + out: List[str] = [] + for piece in _split_keep(text, sep): + if not piece: + continue + if len(piece) <= self.chunk_size: + out.append(piece) + else: + out.extend(self._recursive_split(piece, rest)) + return out + + def _merge(self, pieces: List[str]) -> List[str]: + """Greedy concatenation: small fragments fuse up to ``chunk_size`` + without exceeding it. Relative order is preserved. + """ + merged: List[str] = [] + buf = '' + for p in pieces: + if not p: + continue + if buf and len(buf) + len(p) > self.chunk_size: + merged.append(buf) + buf = '' + buf += p + if buf: + merged.append(buf) + return merged + + +# ---------------------------------------------------------------------- +# helpers +# ---------------------------------------------------------------------- +def _split_keep(text: str, sep: str) -> List[str]: + """``str.split(sep)`` but the separator stays glued to the end of + each left-hand piece, so ``''.join(result) == text``. 
+ """ + if not sep or sep not in text: + return [text] if text else [] + out: List[str] = [] + start, n = 0, len(sep) + while (i := text.find(sep, start)) != -1: + out.append(text[start:i + n]) + start = i + n + if start < len(text): + out.append(text[start:]) + return out + + +def _hard_cut(text: str, size: int) -> List[str]: + return [text[i:i + size] for i in range(0, len(text), size)] if text else [] + + +def _text_chunk( + role: str, + content: str, + *, + kind: Optional[str] = None, + tool_call: Any = None, + tool_call_id: Optional[str] = None, +) -> Chunk: + raw: Dict[str, Any] = {} + if kind is not None: + raw['kind'] = kind + if tool_call is not None: + raw['tool_call'] = tool_call + if tool_call_id is not None: + raw['tool_call_id'] = tool_call_id + chunk: Chunk = {'type': 'text', 'content': content, 'role': role} # type: ignore[assignment] + if raw: + chunk['raw'] = raw + return chunk diff --git a/src/twinkle_agentic/condenser/__init__.py b/src/twinkle_agentic/condenser/__init__.py index e69de29b..e7854500 100644 --- a/src/twinkle_agentic/condenser/__init__.py +++ b/src/twinkle_agentic/condenser/__init__.py @@ -0,0 +1,5 @@ +from .base import Condenser +from .keyword import KeywordCondenser +from .model import ModelCondenser + +__all__ = ['Condenser', 'KeywordCondenser', 'ModelCondenser'] diff --git a/src/twinkle_agentic/condenser/keyword.py b/src/twinkle_agentic/condenser/keyword.py index c4b1e14c..14d49631 100644 --- a/src/twinkle_agentic/condenser/keyword.py +++ b/src/twinkle_agentic/condenser/keyword.py @@ -1,11 +1,517 @@ -from abc import abstractmethod +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Extractive, spaCy-driven passage condenser. 
+ +For each eligible chunk, produces a compact summary with three slots:: + + Open: + Rel: (subject | verb | object); (subject | verb | object | prep obj) + More: kw1, kw2, kw3 + +Strictly bounded by ``ceil(len(input) / compression_ratio)`` characters +for every chunk that passes ``min_chars``. Chunks shorter than +``min_chars`` are passed through unchanged (pre-filter). +""" +from __future__ import annotations + +import math +import re +import threading +from typing import Any, Dict, FrozenSet, List, Optional, Sequence, Tuple from twinkle_agentic.condenser.base import Condenser -from twinkle_agentic.data_format import Chunks +from twinkle_agentic.data_format import Chunk, Chunks + +# --------------------------------------------------------------------------- +# spaCy lazy loader (one model per process, thread-safe) +# --------------------------------------------------------------------------- +_SPACY_MODELS: Dict[str, Any] = {} +_SPACY_LOCK = threading.Lock() + + +def _load_spacy(name: str): + nlp = _SPACY_MODELS.get(name) + if nlp is not None: + return nlp + with _SPACY_LOCK: + nlp = _SPACY_MODELS.get(name) + if nlp is not None: + return nlp + try: + import spacy + except ImportError as e: + raise ImportError( + 'KeywordCondenser requires spaCy. Install with: ' + '`pip install spacy && python -m spacy download en_core_web_sm`' + ) from e + try: + nlp = spacy.load(name) + except OSError as e: + raise OSError( + f'spaCy model {name!r} not found. Download with: ' + f'`python -m spacy download {name}`' + ) from e + _SPACY_MODELS[name] = nlp + return nlp + + +# --------------------------------------------------------------------------- +# configuration-free constants +# --------------------------------------------------------------------------- +# Entity labels dropped from keyword candidates (low recall value). 
+_DROP_ENT_LABELS: FrozenSet[str] = frozenset( + {'CARDINAL', 'ORDINAL', 'PERCENT', 'QUANTITY'}) + +# Dependency labels that introduce sub-clauses / conjuncts we do NOT want +# to pull into a single noun-phrase span. +_DROP_NP_DEPS: FrozenSet[str] = frozenset( + {'relcl', 'acl', 'advcl', 'ccomp', 'xcomp', + 'conj', 'cc', 'appos', 'parataxis'}) + +# Tokens stripped from NP boundaries. +_LEADING_STRIP_POS: FrozenSet[str] = frozenset({'DET', 'PUNCT'}) + +# Tuple-slot separator. ``|`` avoids confusion when a slot itself +# contains a comma (e.g. ``"London, England"``). +_SLOT_SEP = ' | ' +_TRIPLE_SEP = '; ' + +_WORD_RE = re.compile(r'\w+', flags=re.UNICODE) + + +# --------------------------------------------------------------------------- +# NP / verb surface helpers +# --------------------------------------------------------------------------- +def _np_text(head) -> str: + """Return the noun-phrase text headed by ``head``. + + Keeps the contiguous span from the leftmost to the rightmost kept + token so internal punctuation (hyphens, apostrophes, slashes) is + preserved verbatim. Drops clausal / conjunct sub-trees and trims + leading determiners / possessive pronouns. + """ + # Collect subtree tokens, cutting off whole clausal children. + collected: List = [] + + def _walk(tok): + if tok is not head and tok.dep_ in _DROP_NP_DEPS: + return + collected.append(tok) + for child in tok.children: + _walk(child) + + _walk(head) + if not collected: + return head.text + collected.sort(key=lambda t: t.i) + + # Strip leading det/punct and possessive pronouns. + while collected and ( + collected[0].pos_ in _LEADING_STRIP_POS + or (collected[0].pos_ == 'PRON' and collected[0].dep_ == 'poss') + ): + collected.pop(0) + while collected and collected[-1].pos_ == 'PUNCT': + collected.pop() + if not collected: + return head.text + + start, end = collected[0].i, collected[-1].i + 1 + # If the kept tokens form a contiguous span, use the original text + # (preserves hyphens etc.). 
def _word_boundary_truncate(text: str, limit: int) -> str:
    """Truncate ``text`` to at most ``limit`` characters, preferring a space.

    Args:
        text: String to shorten.
        limit: Maximum length of the result. Values ``<= 0`` yield ``''``.

    Returns:
        ``text`` unchanged when it already fits. Otherwise the first
        ``limit`` characters, cut back to the last space when that space
        lies in the second half of the window (so a single long word is
        hard-cut rather than losing most of the budget). Trailing
        whitespace is removed unless that would leave nothing.
    """
    if limit <= 0:
        # A negative limit would otherwise slice from the end of ``text``
        # and return more than ``limit`` characters.
        return ''
    if len(text) <= limit:
        return text
    cut = text[:limit]
    if text[limit].isspace():
        # The window already ends exactly at a word boundary; keep the
        # final word instead of backing up to the previous space.
        return cut.rstrip() or cut
    sp = cut.rfind(' ')
    trimmed = cut[:sp] if sp >= limit // 2 else cut
    return trimmed.rstrip() or cut
def _extract_triples(doc, n: int) -> List[Tuple[str, ...]]:
    """Collect up to ``n`` subject-verb-object tuples from ``doc``.

    Each ``VERB``/``AUX`` token with a non-pronoun subject contributes one
    tuple, in one of three shapes:

    * ``(subject, verb, object, "prep prep-object")``
    * ``(subject, verb, object)``
    * ``(subject, "verb prep", prep-object)``

    Verbs with neither an object nor a prepositional object are skipped.
    The verb slot uses ``_verb_surface`` (surface form, not lemma), and
    duplicates are detected on lower-cased lemmas.
    """
    if n <= 0:
        return []

    def _lemma(tok) -> str:
        return tok.lemma_.lower()

    triples: List[Tuple[str, ...]] = []
    seen_keys: set = set()
    for verb in (tok for sent in doc.sents for tok in sent):
        if verb.pos_ not in ('VERB', 'AUX'):
            continue
        subject = _first_child(verb, ('nsubj', 'nsubjpass', 'csubj'))
        if subject is None or subject.pos_ == 'PRON':
            continue
        direct = _first_child(verb, ('dobj', 'attr', 'oprd'))
        prep = _first_child(verb, ('prep',))
        prep_target = None if prep is None else _first_child(prep, ('pobj', 'pcomp'))

        subject_text = _np_text(subject)
        verb_text = _verb_surface(verb)

        if direct is None and prep_target is None:
            continue
        if direct is None:
            # No object: fold the preposition into the verb slot.
            triple = (subject_text, f'{verb_text} {prep.text}', _np_text(prep_target))
            key = (_lemma(subject),
                   f'{_lemma(verb)} {prep.text.lower()}',
                   _lemma(prep_target))
        elif prep_target is None:
            triple = (subject_text, verb_text, _np_text(direct))
            key = (_lemma(subject), _lemma(verb), _lemma(direct))
        else:
            triple = (subject_text, verb_text, _np_text(direct),
                      f'{prep.text} {_np_text(prep_target)}')
            key = (_lemma(subject), _lemma(verb), _lemma(direct),
                   f'{prep.text.lower()} {_lemma(prep_target)}')

        if key in seen_keys:
            continue
        seen_keys.add(key)
        triples.append(triple)
        if len(triples) >= n:
            return triples
    return triples
+ - Subsumption dedup: drops a shorter form if a longer form + containing it is already kept (``"Nolan"`` dropped when + ``"Christopher Nolan"`` is present). + """ + if k <= 0: + return [] + counts: Dict[str, float] = {} + order: Dict[str, int] = {} + idx = 0 + + def _add(term: str, weight: float) -> None: + nonlocal idx + t = term.strip() + if len(t) < 2: + return + words = [w.lower() for w in _WORD_RE.findall(t)] + if not words: + return + if all(w in excluded_tokens for w in words): + return + if t not in order: + order[t] = idx + idx += 1 + counts[t] = counts.get(t, 0.0) + weight + + for ent in doc.ents: + if ent.label_ in _DROP_ENT_LABELS: + continue + _add(ent.text, weight=10.0) + for nc in doc.noun_chunks: + _add(_strip_leading_nc(nc), weight=1.0) + for tok in doc: + if tok.pos_ == 'PROPN' and not tok.is_stop: + _add(tok.text, weight=2.0) + + ranked = sorted(counts.keys(), key=lambda t: (-counts[t], order[t])) + + kept: List[str] = [] + kept_word_sets: List[FrozenSet[str]] = [] + for term in ranked: + words = frozenset(_WORD_RE.findall(term.lower())) + # Subsumed by any already-kept term (identical or proper subset). + if any(words == ws or words < ws for ws in kept_word_sets): + continue + # Also drop earlier-kept strict subsets of the current term. 
def _compose(opening: str, rel: str, kw: str) -> str:
    """Join the non-empty slots into the final summary text.

    Slots are emitted in the fixed order ``Open:``, ``Rel:``, ``More:``,
    one per line; empty slots are omitted entirely. Returns ``''`` when
    every slot is empty.
    """
    slots = ((opening, 'Open: '), (rel, 'Rel: '), (kw, 'More: '))
    return '\n'.join(f'{prefix}{value}' for value, prefix in slots if value)
class KeywordCondenser(Condenser):
    """Extractive, spaCy-driven passage condenser.

    Args:
        num_relations: Max number of
            ``(subject, verb, object[, prep-obj])`` tuples per chunk.
            Set to ``0`` to disable the ``Rel:`` slot.
        max_first_sentence_chars: Hard cap for the opening slot, applied
            before the global compression budget.
        num_keywords: Max keyword items per chunk. ``0`` disables ``More:``.
        compression_ratio: Target compression factor. Must be ``> 1``.
            The per-chunk budget is
            ``ceil(len(input) / compression_ratio)`` characters.
        spacy_model: spaCy pipeline name (default ``en_core_web_sm``).
        min_chars: Pre-filter. Chunks shorter than this are passed
            through **unchanged**; the ratio contract does not apply to
            them. Set to ``0`` to always compress.
        skip_roles: Roles whose chunks are never compressed. A single
            role may be given as a bare string.
        rounds: Optional set/list of conversation-turn numbers to
            compress. ``None`` (default) = no round-based filtering;
            when provided, chunks whose ``round`` is not in this set
            are passed through unchanged. Chunks that lack a ``round``
            field are also skipped when this filter is active.

    Raises:
        ValueError: If any numeric argument is out of range.

    Every produced chunk is marked with ``raw['condensed'] = True`` and
    keeps the uncompressed text under ``raw['original']``.

    Example:
        >>> from twinkle_agentic.chunker import NativeChunker
        >>> from twinkle_agentic.condenser.keyword import KeywordCondenser
        >>> chunker = NativeChunker(chunk_size=1024)
        >>> cond = KeywordCondenser(
        ...     num_relations=3, max_first_sentence_chars=160,
        ...     num_keywords=8, compression_ratio=4.0)
        >>> traj = {'messages': [{'role': 'user', 'content': long_passage}]}
        >>> chunks = cond(chunker(traj))
        >>> traj_compressed = chunks.to_trajectory()
    """

    def __init__(
        self,
        num_relations: int = 3,
        max_first_sentence_chars: int = 160,
        num_keywords: int = 8,
        compression_ratio: float = 4.0,
        spacy_model: str = 'en_core_web_sm',
        min_chars: int = 200,
        skip_roles: Sequence[str] = ('system', 'tool', 'assistant'),
        rounds: Optional[Sequence[int]] = None,
    ):
        if num_relations < 0:
            raise ValueError(f'num_relations must be >= 0, got {num_relations}')
        if num_keywords < 0:
            raise ValueError(f'num_keywords must be >= 0, got {num_keywords}')
        if max_first_sentence_chars < 0:
            raise ValueError(
                f'max_first_sentence_chars must be >= 0, got {max_first_sentence_chars}')
        if compression_ratio <= 1.0:
            raise ValueError(
                f'compression_ratio must be > 1, got {compression_ratio}')
        if min_chars < 0:
            raise ValueError(f'min_chars must be >= 0, got {min_chars}')

        # A bare string is itself a sequence of characters; without this
        # guard ``tuple('system')`` would become ('s', 'y', ...) and the
        # role filter would silently never match.
        if isinstance(skip_roles, str):
            skip_roles = (skip_roles,)

        self.num_relations = num_relations
        self.max_first_sentence_chars = max_first_sentence_chars
        self.num_keywords = num_keywords
        self.compression_ratio = float(compression_ratio)
        self.spacy_model = spacy_model
        self.min_chars = min_chars
        self.skip_roles = tuple(skip_roles)
        self.rounds = set(rounds) if rounds is not None else None

    # ------------------------------------------------------------------
    def __call__(self, chunks: Chunks, **kwargs) -> Chunks:
        """Return a new ``Chunks`` with every eligible chunk condensed.

        Ineligible chunks are passed through as the same objects; order
        is preserved. ``kwargs`` is accepted and ignored.
        """
        nlp = _load_spacy(self.spacy_model)
        out: List[Chunk] = []
        for c in chunks.chunks:
            if not self._should_condense(c):
                out.append(c)
                continue
            compressed = self._condense(c['content'], nlp)
            out.append(self._mark_condensed(c, compressed))
        return Chunks(chunks=out)

    # ------------------------------------------------------------------
    # selection policy
    # ------------------------------------------------------------------
    def _should_condense(self, chunk: Chunk) -> bool:
        """Decide whether ``chunk`` is a candidate for compression."""
        if chunk.get('type') != 'text':
            return False
        if chunk.get('role') in self.skip_roles:
            return False
        # A chunk without ``round`` yields ``None``, which is never in the
        # set, so it is skipped whenever the round filter is active.
        if self.rounds is not None and chunk.get('round') not in self.rounds:
            return False
        content = chunk.get('content')
        if not isinstance(content, str) or not content:
            return False
        if len(content) < self.min_chars:
            return False
        raw = chunk.get('raw') or {}
        if isinstance(raw, dict):
            # Chunker-emitted reasoning / tool-call text chunks carry a
            # non-empty ``kind`` marker; leave them alone.
            if raw.get('kind'):
                return False
            # Idempotency — don't re-condense already condensed chunks.
            if raw.get('condensed'):
                return False
        return True

    @staticmethod
    def _mark_condensed(chunk: Chunk, content: str) -> Chunk:
        """Return a shallow copy of ``chunk`` carrying the condensed text.

        The copy's ``raw`` dict gets ``condensed=True`` and, unless already
        present, ``original`` set to the pre-condensation content. The
        input chunk is not mutated.
        """
        new: Dict[str, Any] = dict(chunk)
        prior = new.get('raw')
        try:
            raw = dict(prior or {})
        except (TypeError, ValueError):
            # ``raw`` may legally be a non-mapping (e.g. a plain string),
            # and ``_should_condense`` lets such chunks through. Keep the
            # payload under ``raw['raw']`` instead of crashing.
            raw = {'raw': prior}
        raw.setdefault('original', new.get('content', ''))
        new['content'] = content
        raw['condensed'] = True
        new['raw'] = raw
        return new  # type: ignore[return-value]

    # ------------------------------------------------------------------
    # core extractive compression
    # ------------------------------------------------------------------
    def _condense(self, text: str, nlp) -> str:
        """Build the ``Open/Rel/More`` summary of ``text`` within budget."""
        budget = max(1, math.ceil(len(text) / self.compression_ratio))
        doc = nlp(text)
        opening = _extract_opening(doc, self.max_first_sentence_chars)
        # Words already present in the opening are not repeated as keywords.
        excluded = _word_tokens_lower(opening)
        triples = _extract_triples(doc, self.num_relations)
        keywords = _extract_keywords(doc, self.num_keywords, excluded)
        return _fit_under_budget(
            opening, triples, keywords, budget, fallback_text=text)
def _sampling_params_cls():
    """Return the ``SamplingParams`` class, importing it only when called.

    The import is deferred to call time so that importing this module
    does not itself import ``twinkle.data_format.sampling``.
    """
    import twinkle.data_format.sampling as sampling_module
    return sampling_module.SamplingParams
class ModelCondenser(Condenser):
    """Condenser that delegates compression to an LLM via a :class:`Sampler`.

    Args:
        sampler: A configured :class:`Sampler`. The sampler must already
            have a ``template`` set so it can encode ``Trajectory``
            inputs. The sampler is reused across chunks (batched).
        compression_ratio: Target factor, must be ``> 1``. For chunks
            that pass ``min_chars``,
            ``len(output) <= ceil(len(input) / compression_ratio)`` is
            enforced via post-sampling truncation (the model
            cannot be trusted to obey a soft word count).
        sampling_params: Override for per-call sampling. Defaults to
            greedy (temperature 0) with ``max_tokens`` derived from the
            budget.
        system_prompt: Override the default system prompt.
        user_prompt_template: Override the default user prompt.
            Supported placeholders: ``{budget}`` and ``{text}``.
        min_chars: Pre-filter. Chunks shorter than this are passed
            through unchanged (the ratio contract does not apply to
            them).
        skip_roles: Roles whose chunks are never compressed. A single
            role may be given as a bare string.
        rounds: Optional set/list of conversation-turn numbers to
            compress. ``None`` (default) = no round-based filtering;
            when provided, chunks whose ``round`` is not in this set
            are passed through unchanged. Chunks that lack a ``round``
            field are also skipped when this filter is active.
        batch_size: Max chunks per sampler call. Larger values amortize
            LLM prefill / worker-dispatch overhead.
        use_base_model: When ``True``, ``use_base_model=True`` is
            forwarded to :meth:`Sampler.sample`, asking it to compress
            with the frozen base model rather than the currently-synced
            LoRA adapter. Recommended when ``sampler`` is also the
            training policy; samplers that do not accept the keyword
            will raise ``TypeError``.

    Raises:
        ValueError: If ``sampler`` is ``None``, a numeric argument is out
            of range, or the user prompt template lacks a placeholder.

    Every produced chunk is marked with ``raw['condensed'] = True`` and
    keeps the uncompressed text under ``raw['original']``.

    Example:
        >>> from twinkle.sampler import vLLMSampler
        >>> sampler = vLLMSampler(model_id='Qwen/Qwen2.5-3B-Instruct',
        ...                       engine_args={'dtype': 'bfloat16'})
        >>> sampler.set_template('qwen2_5')
        >>> cond = ModelCondenser(sampler, compression_ratio=4.0)
        >>> compressed = cond(chunks)
    """

    DEFAULT_SYSTEM_PROMPT: str = _DEFAULT_SYSTEM_PROMPT
    DEFAULT_USER_PROMPT_TEMPLATE: str = _DEFAULT_USER_PROMPT_TEMPLATE

    def __init__(
        self,
        sampler: 'Sampler',
        compression_ratio: float = 4.0,
        *,
        sampling_params: Optional['SamplingParams'] = None,
        system_prompt: Optional[str] = None,
        user_prompt_template: Optional[str] = None,
        min_chars: int = 200,
        skip_roles: Sequence[str] = ('system', 'tool', 'assistant'),
        rounds: Optional[Sequence[int]] = None,
        batch_size: int = 8,
        use_base_model: bool = False,
    ):
        if sampler is None:
            raise ValueError('sampler is required')
        if compression_ratio <= 1.0:
            raise ValueError(
                f'compression_ratio must be > 1, got {compression_ratio}')
        if min_chars < 0:
            raise ValueError(f'min_chars must be >= 0, got {min_chars}')
        if batch_size <= 0:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')

        tpl = user_prompt_template or self.DEFAULT_USER_PROMPT_TEMPLATE
        if '{budget}' not in tpl or '{text}' not in tpl:
            raise ValueError(
                'user_prompt_template must contain both {budget} and {text}')

        # A bare string is itself a sequence of characters; without this
        # guard ``tuple('system')`` would become ('s', 'y', ...) and the
        # role filter would silently never match.
        if isinstance(skip_roles, str):
            skip_roles = (skip_roles,)

        self.sampler = sampler
        self.compression_ratio = float(compression_ratio)
        self.sampling_params = sampling_params
        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
        self.user_prompt_template = tpl
        self.min_chars = min_chars
        self.skip_roles = tuple(skip_roles)
        self.rounds = set(rounds) if rounds is not None else None
        self.batch_size = batch_size
        self.use_base_model = bool(use_base_model)

    # ------------------------------------------------------------------
    # entry
    # ------------------------------------------------------------------
    def __call__(self, chunks: Chunks, **kwargs) -> Chunks:
        """Return a new ``Chunks`` with every eligible chunk condensed.

        Eligible chunks are sent to the sampler in batches of at most
        ``batch_size``; ineligible chunks keep their position and
        identity. ``kwargs`` is accepted and ignored.

        Raises:
            RuntimeError: The sampler returned a different number of
                responses than inputs.
        """
        out: List[Chunk] = list(chunks.chunks)
        # (index into ``out``, chunk, character budget)
        jobs: List[Tuple[int, Chunk, int]] = []
        for i, c in enumerate(chunks.chunks):
            if not self._should_condense(c):
                continue
            text = c['content']
            budget = max(1, math.ceil(len(text) / self.compression_ratio))
            jobs.append((i, c, budget))

        for start in range(0, len(jobs), self.batch_size):
            batch = jobs[start:start + self.batch_size]
            trajectories = [
                self._build_trajectory(c['content'], b) for _, c, b in batch
            ]
            # One SamplingParams per batch, sized for its largest budget.
            sp = self._build_sampling_params(max(b for _, _, b in batch))
            sample_kwargs: Dict[str, Any] = {'sampling_params': sp}
            if self.use_base_model:
                sample_kwargs['use_base_model'] = True
            responses = self.sampler.sample(trajectories, **sample_kwargs)
            if len(responses) != len(batch):
                raise RuntimeError(
                    f'sampler returned {len(responses)} responses for '
                    f'{len(batch)} inputs')
            for (i, c, budget), resp in zip(batch, responses):
                raw_text = self._pick_decoded(resp)
                compressed = self._postprocess(raw_text, budget, c['content'])
                out[i] = self._mark_condensed(c, compressed)

        return Chunks(chunks=out)

    # ------------------------------------------------------------------
    # selection policy
    # ------------------------------------------------------------------
    def _should_condense(self, chunk: Chunk) -> bool:
        """Decide whether ``chunk`` is a candidate for compression."""
        if chunk.get('type') != 'text':
            return False
        if chunk.get('role') in self.skip_roles:
            return False
        # A chunk without ``round`` yields ``None``, which is never in the
        # set, so it is skipped whenever the round filter is active.
        if self.rounds is not None and chunk.get('round') not in self.rounds:
            return False
        content = chunk.get('content')
        if not isinstance(content, str) or not content:
            return False
        if len(content) < self.min_chars:
            return False
        raw = chunk.get('raw') or {}
        if isinstance(raw, dict):
            # Skip chunker-emitted reasoning / tool_call text chunks.
            if raw.get('kind'):
                return False
            # Idempotency — don't re-condense already condensed chunks.
            if raw.get('condensed'):
                return False
        return True

    @staticmethod
    def _mark_condensed(chunk: Chunk, content: str) -> Chunk:
        """Return a shallow copy of ``chunk`` carrying the condensed text.

        The copy's ``raw`` dict gets ``condensed=True`` and, unless already
        present, ``original`` set to the pre-condensation content. The
        input chunk is not mutated.
        """
        new: Dict[str, Any] = dict(chunk)
        prior = new.get('raw')
        try:
            raw = dict(prior or {})
        except (TypeError, ValueError):
            # ``raw`` may legally be a non-mapping (e.g. a plain string),
            # and ``_should_condense`` lets such chunks through. Keep the
            # payload under ``raw['raw']`` instead of crashing.
            raw = {'raw': prior}
        raw.setdefault('original', new.get('content', ''))
        new['content'] = content
        raw['condensed'] = True
        new['raw'] = raw
        return new  # type: ignore[return-value]

    # ------------------------------------------------------------------
    # prompt construction
    # ------------------------------------------------------------------
    def _build_trajectory(self, text: str, budget: int) -> 'Trajectory':
        """Build the two-message (system + user) compression prompt."""
        # Use str.replace to avoid .format() breaking on braces in text.
        # ``{text}`` is substituted last so placeholders that happen to
        # appear inside the passage are never expanded.
        user = (self.user_prompt_template
                .replace('{budget}', str(budget))
                .replace('{text}', text))
        return {  # type: ignore[return-value]
            'messages': [
                {'role': 'system', 'content': self.system_prompt},
                {'role': 'user', 'content': user},
            ],
        }

    def _build_sampling_params(self, budget: int) -> 'SamplingParams':
        """Return the caller's override, or greedy params sized to ``budget``."""
        if self.sampling_params is not None:
            return self.sampling_params
        # Rough heuristic: ~1 token per 2-3 English chars + headroom.
        max_new = max(64, int(budget * 0.8) + 64)
        return _sampling_params_cls()(temperature=0.0, max_tokens=max_new)

    # ------------------------------------------------------------------
    # response parsing & strict-budget clamping
    # ------------------------------------------------------------------
    @staticmethod
    def _pick_decoded(response) -> str:
        """Return the decoded text of the first sequence, or ``''``."""
        seqs = getattr(response, 'sequences', None) or []
        if not seqs:
            return ''
        decoded = getattr(seqs[0], 'decoded', None)
        return decoded or ''

    def _postprocess(self, raw: str, budget: int, original: str) -> str:
        """Normalise the model output and force it under ``budget`` chars.

        Tries, in order: the fully formatted sections; the sections with
        ``More`` then ``Key Facts`` dropped; a header-less section body
        clamped to the budget; and finally the clamped raw output (or the
        original passage when the output is empty).
        """
        text = _strip_code_fences(raw).strip()
        sections = _parse_markdown_sections(text)
        formatted = _format_sections(sections, fallback=text)
        if formatted and len(formatted) <= budget:
            return formatted
        # Progressive drop on a *copy*: More → Key Facts → Summary. Keep
        # the original ``sections`` intact for the body-only fallback.
        remaining = dict(sections)
        for drop in ('more', 'facts', 'summary'):
            remaining.pop(drop, None)
            reduced = _format_sections(remaining, fallback='')
            if reduced and len(reduced) <= budget:
                return reduced
        # Even "## Summary\n<something>" cannot fit — the header alone eats
        # the budget. Clamp the most informative *body* (no header) so the
        # user still gets meaningful content instead of dangling hash marks.
        for key in ('summary', 'facts', 'more'):
            body = sections.get(key)
            if body:
                clamped = _clamp_to_budget(body, budget)
                if clamped:
                    return clamped
        # No parsable sections at all — clamp the stripped raw text
        # (or the original passage as a last resort).
        return _clamp_to_budget(text or original, budget)
def _clamp_to_budget(text: str, budget: int) -> str:
    """Word-boundary truncate ``text`` to at most ``budget`` chars.

    Args:
        text: String to shorten.
        budget: Maximum length of the result.

    Returns:
        ``text`` unchanged when it already fits; ``''`` when ``budget`` is
        not positive. Otherwise the first ``budget`` characters, cut back
        to the last space when that space lies in the second half of the
        window (so one long word is hard-cut rather than wasting most of
        the budget). Trailing whitespace is removed unless that would
        leave nothing.
    """
    if len(text) <= budget:
        return text
    if budget <= 0:
        return ''
    cut = text[:budget]
    if text[budget].isspace():
        # The window already ends exactly at a word boundary; keep the
        # final word instead of backing up to the previous space.
        return cut.rstrip() or cut
    sp = cut.rfind(' ')
    trimmed = cut[:sp] if sp >= budget // 2 else cut
    return trimmed.rstrip() or cut
def _to_plain(obj: Any) -> Any:
    """Recursively convert numpy arrays/scalars to plain Python values.

    ``np.ndarray`` becomes a (nested) list via ``tolist()``; numpy integer,
    floating and bool scalars become ``int`` / ``float`` / ``bool``. Dicts,
    lists and tuples are rebuilt with their items converted; tuples keep
    their concrete type, including namedtuples. Anything else is returned
    unchanged.

    Per the original author's note this mirrors
    ``vllm_sampler._convert_ndarray_to_list`` but is kept local to avoid
    depending on a private symbol.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, dict):
        return {k: _to_plain(v) for k, v in obj.items()}
    if isinstance(obj, tuple):
        items = [_to_plain(x) for x in obj]
        if hasattr(obj, '_fields'):
            # namedtuple constructors take one positional argument per
            # field, not a single iterable.
            return type(obj)(*items)
        return type(obj)(items)
    if isinstance(obj, list):
        return [_to_plain(x) for x in obj]
    return obj
+ * ``tool_manager``: either a single :class:`ToolManager` (applied to + every trajectory) or a list of ``ToolManager`` aligned 1:1 with + ``trajectories`` (used by :class:`MultiTurnCondenseRollout` to + attach a trajectory-bound ``ExtractCondensed``). + + The class intentionally has no knowledge of condensers/chunkers; they are + applied upstream (on the trajectory before rollout) or downstream + (on the returned messages). + """ + + def __init__( + self, + sampler, + template: Template, + tool_manager: ToolManager, + sampling_params: Optional[SamplingParams] = None, + max_turns: int = 6, + trace_path: Optional[str] = None, + ): + if template is None: + raise ValueError('MultiTurnRollout requires a local Template instance') + if tool_manager is None: + raise ValueError('MultiTurnRollout requires a ToolManager') + if max_turns < 1: + raise ValueError(f'max_turns must be >= 1, got {max_turns}') + self.sampler = sampler + self.template = template + self.tool_manager = tool_manager + self.sampling_params = sampling_params or SamplingParams() + self.max_turns = max_turns + # When set, every turn writes one JSONL record per active + # trajectory to ``trace_path``. The file is truncated at + # construction time (matching the behaviour of the legacy + # ``_make_dump_rollout_trace`` hook); subsequent writes append. + # Errors during trace writing are swallowed on purpose so + # observability can never break a training step. + self.trace_path = trace_path + if self.trace_path: + try: + # Truncate up front so repeated rollouts start from an + # empty file. Using a context manager here would be + # equivalent; explicit ``close()`` is clearer. + f = open(self.trace_path, 'w', encoding='utf-8') + f.close() + except OSError: + # If we can't even create the file, disable tracing + # silently rather than crashing the training job. 
+ self.trace_path = None + + if self.sampling_params.num_samples != 1: + raise ValueError( + f'MultiTurnRollout currently supports num_samples=1 only, ' + f'got {self.sampling_params.num_samples}') + assert self.template.truncation_strategy != 'split', ( + "MultiTurnRollout does not support truncation_strategy='split'; " + 'use left/right/raise on the template.') + + def __call__(self, trajectories: List[Trajectory], **kwargs) -> List[Trajectory]: + if isinstance(trajectories, dict): + raise TypeError( + 'MultiTurnRollout.__call__ expects a List[Trajectory]; ' + 'wrap a single trajectory as [trajectory].') + trajectories = list(trajectories) + n = len(trajectories) + if n == 0: + return [] + + sampling_params = kwargs.get('sampling_params', self.sampling_params) + tool_managers = self._resolve_tool_managers( + kwargs.get('tool_manager', self.tool_manager), n) + + # 1. Encode each trajectory once; ``pifs[i]`` is the live per-turn + # state for trajectory ``i``. + pifs: List[Dict[str, Any]] = [] + for traj in trajectories: + pif = self.template.encode(traj, add_generation_prompt=True) + pif = _to_plain(pif) + pif.setdefault('messages', list(traj.get('messages', []))) + pifs.append(pif) + + all_logprobs: List[List[Any]] = [[] for _ in range(n)] + stop_reasons: List[Optional[str]] = [None] * n + turns: List[int] = [0] * n + truncated: List[bool] = [False] * n + done: List[bool] = [False] * n + + for _ in range(self.max_turns): + active = [i for i in range(n) if not done[i]] + if not active: + break + + # 2. One batched sample call for all currently-live trajectories. 
+ batch_pifs = [pifs[i] for i in active] + resps = self.sampler.sample(batch_pifs, sampling_params=sampling_params) + resps = self._unwrap_response_list(resps, len(active)) + + pending_bridges: List[tuple] = [] # (global_idx, tool_messages) + trace_rows: List[Dict[str, Any]] = [] # buffered per-turn records + for local_idx, global_idx in enumerate(active): + turns[global_idx] += 1 + seq = resps[local_idx].sequences[0] + + if seq.new_input_feature is None or 'input_ids' not in seq.new_input_feature: + raise RuntimeError( + f'Sampler returned a SampledSequence without ' + f'new_input_feature.input_ids at batch index ' + f'{local_idx} (trajectory {global_idx}); ' + f'cannot continue multi-turn.') + + pifs[global_idx] = _to_plain(dict(seq.new_input_feature)) + if seq.logprobs is not None: + if len(seq.logprobs) != len(seq.tokens): + raise RuntimeError( + f'logprobs length ({len(seq.logprobs)}) does not ' + f'match sampled token count ({len(seq.tokens)}) ' + f'at turn {turns[global_idx]} ' + f'(trajectory {global_idx})') + all_logprobs[global_idx].extend(seq.logprobs) + stop_reasons[global_idx] = seq.stop_reason + + # 3. 
Termination conditions + if seq.stop_reason == 'length': + done[global_idx] = True + trace_rows.append(self._trace_row( + turn=turns[global_idx], + global_idx=global_idx, + n=n, + seq=seq, + tool_calls=None, + done=True, + truncated=False, + pif=pifs[global_idx])) + continue + + tool_calls = self.template.parse_tool_call(seq.decoded or '') + if not tool_calls: + done[global_idx] = True + trace_rows.append(self._trace_row( + turn=turns[global_idx], + global_idx=global_idx, + n=n, + seq=seq, + tool_calls=tool_calls, + done=True, + truncated=False, + pif=pifs[global_idx])) + continue + + if turns[global_idx] >= self.max_turns: + truncated[global_idx] = True + done[global_idx] = True + trace_rows.append(self._trace_row( + turn=turns[global_idx], + global_idx=global_idx, + n=n, + seq=seq, + tool_calls=tool_calls, + done=True, + truncated=True, + pif=pifs[global_idx])) + continue + + # 4. Dispatch tools per trajectory (uses this trajectory's + # tool_manager, which may be a trajectory-bound clone). + tool_messages = [{ + 'role': 'tool', + 'content': tool_managers[global_idx](tc), + } for tc in tool_calls] + pending_bridges.append((global_idx, tool_messages)) + trace_rows.append(self._trace_row( + turn=turns[global_idx], + global_idx=global_idx, + n=n, + seq=seq, + tool_calls=tool_calls, + done=False, + truncated=False, + pif=pifs[global_idx])) + + # Extend pif with bridge tokens for every trajectory that has + # outstanding tool turns. Done serially: bridge computation is + # a cheap decode-diff-encode on python strings / token lists. + for global_idx, tool_messages in pending_bridges: + pifs[global_idx] = self._extend_with_bridge( + pifs[global_idx], tool_messages) + + # Flush this turn's trace records (one JSONL line each). This + # happens AFTER bridge extension so a post-turn consumer sees + # the final pif length for the turn. + if self.trace_path and trace_rows: + self._write_trace(trace_rows) + + # 5. 
Merge pif fields into each trajectory dict at TOP LEVEL so + # downstream consumers (VLLMSampler with ``'input_ids' in inputs``) + # see an encoded InputFeature and skip re-encoding. + outs: List[Trajectory] = [] + for i, traj in enumerate(trajectories): + out = dict(traj) + out.update(pifs[i]) + out['messages'] = list(pifs[i].get('messages') or out.get('messages', [])) + out['logprobs'] = all_logprobs[i] if all_logprobs[i] else None + out['turns'] = turns[i] + out['stop_reason'] = stop_reasons[i] + out['truncated'] = truncated[i] + outs.append(out) + return outs + + # ------------------------------------------------------------------ private + + @staticmethod + def _resolve_tool_managers(arg, n: int) -> List[ToolManager]: + """Broadcast a single ``ToolManager`` or validate a per-trajectory list.""" + if isinstance(arg, list): + if len(arg) != n: + raise ValueError( + f'per-call tool_manager list length ({len(arg)}) does ' + f'not match number of trajectories ({n})') + return list(arg) + return [arg] * n + + @staticmethod + def _trace_row( + *, + turn: int, + global_idx: int, + n: int, + seq, + tool_calls, + done: bool, + truncated: bool, + pif: Dict[str, Any], + ) -> Dict[str, Any]: + """Build one per-trajectory trace record for the current turn. + + Deliberately flat + JSON-friendly. ``decoded`` is truncated-safe + (it's just a string). ``trainable_tokens`` is the count of labels + not equal to -100 so far, i.e. GRPO-loss-eligible positions. 
+ """ + labels = pif.get('labels') or [] + trainable = sum(1 for l in labels if l != -100) + return { + 'ts': time.time(), + 'turn': int(turn), + 'batch_size': int(n), + 'trajectory_idx': int(global_idx), + 'stop_reason': getattr(seq, 'stop_reason', None), + 'decoded': getattr(seq, 'decoded', '') or '', + 'tool_call_count': 0 if not tool_calls else len(tool_calls), + 'done': bool(done), + 'truncated': bool(truncated), + 'input_ids_len': len(pif.get('input_ids') or []), + 'trainable_tokens': trainable, + } + + def _write_trace(self, rows: List[Dict[str, Any]]) -> None: + """Append trace rows as JSONL. Errors are swallowed by design. + + Observability must never break training -- any I/O or encoding + problem is silently ignored so a disk-full / permission issue + doesn't take down the optimisation loop. + """ + if not self.trace_path or not rows: + return + try: + lines = [ + json.dumps(r, ensure_ascii=False, default=str) + for r in rows] + with open(self.trace_path, 'a', encoding='utf-8') as f: + f.write('\n'.join(lines) + '\n') + except Exception: + pass + + @staticmethod + def _unwrap_response_list(resps, expected: int) -> List[SampleResponse]: + """Validate that the sampler returned ``expected`` ``SampleResponse``s, + one per input in the batch. 
+ """ + if not isinstance(resps, list): + raise TypeError( + f'expected List[SampleResponse] from sampler.sample (batched ' + f'call), got {type(resps).__name__}') + if len(resps) != expected: + raise RuntimeError( + f'sampler returned {len(resps)} responses for a batch of ' + f'{expected} trajectories; expected one per input.') + for i, r in enumerate(resps): + if not isinstance(r, SampleResponse): + raise TypeError( + f'expected SampleResponse at batch index {i}, got ' + f'{type(r).__name__}') + if not r.sequences: + raise RuntimeError( + f'SampleResponse at batch index {i} has no sequences') + return resps + + def _extend_with_bridge( + self, + pif: Dict[str, Any], + tool_messages: List[Dict[str, Any]], + ) -> Dict[str, Any]: + """Append tool messages and the next generation prompt as -100 bridge. + + Strategy: decode the CURRENT pif input_ids back to a string, render + the canonical chat-template string for ``messages + tool_messages`` + with ``add_generation_prompt=True``, diff at the STRING level, and + tokenize ONLY the delta. This avoids retokenising history (which would + drift through the ``decode(tokens, skip_special_tokens=True)`` round + trip that ``concat_input_feature`` does). 
+ """ + tokenizer = self.template.tokenizer + + messages_before = list(pif.get('messages') or []) + messages_after = messages_before + list(tool_messages) + + current_text = tokenizer.decode(pif['input_ids'], skip_special_tokens=False) + s_after = tokenizer.apply_chat_template( + messages_after, tokenize=False, add_generation_prompt=True) + + bridge_text = self._compute_bridge_text(current_text, s_after) + if not bridge_text: + raise RuntimeError( + 'Bridge text computation returned empty string; ' + 'tool turn would add no tokens (template misconfiguration?).') + + bridge_ids = tokenizer.encode(bridge_text, add_special_tokens=False) + if not bridge_ids: + raise RuntimeError( + f'Bridge text tokenised to empty id list: {bridge_text!r}') + + new_pif = self._append_bridge_tokens(pif, bridge_ids) + new_pif['messages'] = messages_after + return new_pif + + @staticmethod + def _compute_bridge_text(current_text: str, s_after: str) -> str: + """Return the suffix of ``s_after`` beyond ``current_text``. + + Handles the case where ``current_text`` has trailing whitespace that + the canonical chat_template rendering already consumed (e.g. the + assistant ``<|im_end|>`` is emitted by vLLM without a trailing ``\\n`` + while the chat template always appends one between messages). + """ + if s_after.startswith(current_text): + return s_after[len(current_text):] + # Tolerate trailing whitespace mismatch at the boundary. + ct_stripped = current_text.rstrip() + if s_after.startswith(ct_stripped): + return s_after[len(ct_stripped):] + raise RuntimeError( + 'Cannot align decoded pif text with canonical chat_template output. ' + f'current_text tail: {current_text[-80:]!r}; ' + f's_after at same offset: ' + f'{s_after[max(0, len(current_text) - 80):len(current_text) + 80]!r}') + + def _append_bridge_tokens( + self, + pif: Dict[str, Any], + bridge_ids: List[int], + ) -> Dict[str, Any]: + """Append bridge tokens with labels = -100. 
+ + Mirrors the unroll-append-reroll pattern of + :meth:`Template.concat_input_feature` so that ``labels`` semantics + stay consistent with the sampler-produced pif. + + Shallow copy is deliberately used: every mutation below is a + top-level key reassignment, never an in-place change to nested + tensors. Multimodal payloads (``images``, ``pixel_values``, + ``image_grid_thw`` ...) are shared by reference so we avoid + re-copying image buffers every turn. + """ + result = dict(pif) + + input_ids = list(result['input_ids']) + labels = list(result.get('labels') or []) + # labels arrive in output/shifted order (post _roll_labels). Unroll by + # one position (shift right by 1) to get back to input order. + if labels: + if len(labels) != len(input_ids): + raise RuntimeError( + f'labels length ({len(labels)}) != input_ids length ' + f'({len(input_ids)}); cannot safely append bridge tokens.') + labels = labels[-1:] + labels[:-1] + else: + labels = [-100] * len(input_ids) + + input_ids = input_ids + list(bridge_ids) + labels = labels + [-100] * len(bridge_ids) + + result['input_ids'] = input_ids + result['labels'] = labels + + if 'mm_token_type_ids' in result: + import torch + mm = result['mm_token_type_ids'] + if not isinstance(mm, torch.Tensor): + mm = torch.as_tensor(mm) + pad = torch.zeros((mm.shape[0], len(bridge_ids)), + dtype=mm.dtype, device=mm.device) + result['mm_token_type_ids'] = torch.cat([mm, pad], dim=1) - def __call__(self, trajectory: Trajectory, **kwargs) -> Trajectory: - \ No newline at end of file + # Replay the post pipeline: refresh attention_mask / position_ids / + # length and re-roll labels back into output/shifted order. 
+ refreshed = self.template._invoke_post_pipeline([result])[0] + result.update(refreshed) + return _to_plain(result) diff --git a/src/twinkle_agentic/rollout/multi_turn_condense.py b/src/twinkle_agentic/rollout/multi_turn_condense.py new file mode 100644 index 00000000..155ff9e0 --- /dev/null +++ b/src/twinkle_agentic/rollout/multi_turn_condense.py @@ -0,0 +1,112 @@ +from typing import Any, Dict, List, Optional + +from twinkle.data_format import Trajectory +from twinkle.data_format.sampling import SamplingParams +from twinkle.template.base import Template + +from twinkle_agentic.chunker.base import Chunker +from twinkle_agentic.condenser.base import Condenser +from twinkle_agentic.tools.extract_condensed import ExtractCondensed, TOOL_NAME as EXTRACT_TOOL_NAME +from twinkle_agentic.tools.tool_manager import ToolManager +from .multi_turn import MultiTurnRollout + + +class MultiTurnCondenseRollout(MultiTurnRollout): + """Multi-turn rollout with trajectory compression + on-demand recovery. + + Pipeline per trajectory in the batch: + 1. ``chunker(trajectory)`` splits the incoming trajectory into chunks. + 2. ``condenser(chunks, **condenser_kwargs)`` rewrites selected text + chunks with compressed stand-ins, marking them ``raw.condensed=True`` + and stashing the original under ``raw.original``. + 3. ``chunks.to_trajectory()`` rebuilds a trajectory where every + condensed chunk is wrapped in ``...`` markers. + 4. A trajectory-scoped :class:`ExtractCondensed` tool is registered on + a per-trajectory clone of :attr:`tool_manager`, so the model can + recover the original text of any block by its number. + 5. The batch of compressed trajectories + a parallel list of + per-trajectory tool managers are handed to + :meth:`MultiTurnRollout.__call__`, which drives the sample/tool + loop (one batched ``sampler.sample`` per turn). 
+ + The per-call tool manager is cloned via :meth:`ToolManager.copy`; the + shared ``self.tool_manager`` is never mutated, so concurrent rollouts on + the same instance are safe. + + Constructor accepts any :class:`Chunker` / :class:`Condenser` pair, so + plug-in chunkers (e.g. ``NativeChunker``) and condensers (e.g. + ``KeywordCondenser``, ``ModelCondenser``) compose freely. + """ + + def __init__( + self, + sampler, + template: Template, + tool_manager: ToolManager, + chunker: Chunker, + condenser: Condenser, + sampling_params: Optional[SamplingParams] = None, + max_turns: int = 6, + condenser_kwargs: Optional[Dict[str, Any]] = None, + trace_path: Optional[str] = None, + ): + super().__init__( + sampler=sampler, + template=template, + tool_manager=tool_manager, + sampling_params=sampling_params, + max_turns=max_turns, + trace_path=trace_path, + ) + if chunker is None: + raise ValueError( + 'MultiTurnCondenseRollout requires a Chunker instance') + if condenser is None: + raise ValueError( + 'MultiTurnCondenseRollout requires a Condenser instance') + if EXTRACT_TOOL_NAME in tool_manager.names(): + # We reserve the name because we register a trajectory-bound + # ExtractCondensed per trajectory; a pre-existing registration + # would be silently overwritten on the clone, which is confusing. + raise ValueError( + f'tool_manager already registers {EXTRACT_TOOL_NAME!r}; ' + f'MultiTurnCondenseRollout registers a trajectory-bound ' + f'ExtractCondensed per call and would shadow the existing ' + f'one. 
Remove it from the shared manager or rename it.') + self.chunker = chunker + self.condenser = condenser + self.condenser_kwargs = dict(condenser_kwargs or {}) + + def __call__(self, trajectories: List[Trajectory], **kwargs) -> List[Trajectory]: + if isinstance(trajectories, dict): + raise TypeError( + 'MultiTurnCondenseRollout.__call__ expects a ' + 'List[Trajectory]; wrap a single trajectory as [trajectory].') + trajectories = list(trajectories) + if not trajectories: + return [] + + compressed_list: List[Trajectory] = [] + tool_managers: List[ToolManager] = [] + for traj in trajectories: + # 1-2. Chunk + condense this trajectory. + chunks = self.chunker(traj) + chunks = self.condenser(chunks, **self.condenser_kwargs) + compressed = chunks.to_trajectory() + for k, v in traj.items(): + compressed.setdefault(k, v) + compressed_list.append(compressed) + + # 4. Per-trajectory tool manager: clone + inject ExtractCondensed + # bound to THIS trajectory's chunks. Never mutate + # self.tool_manager. + call_tm = self.tool_manager.copy() + call_tm.register(ExtractCondensed(chunks)) + tool_managers.append(call_tm) + + # 5. Delegate to the parent batch loop. A caller-supplied + # ``tool_manager`` would be surprising here (we already built + # the list) -- drop it to avoid ambiguity. + kwargs.pop('tool_manager', None) + return super().__call__( + compressed_list, tool_manager=tool_managers, **kwargs) diff --git a/src/twinkle_agentic/tools/extract_condensed.py b/src/twinkle_agentic/tools/extract_condensed.py index f9505ea3..b9fa980f 100644 --- a/src/twinkle_agentic/tools/extract_condensed.py +++ b/src/twinkle_agentic/tools/extract_condensed.py @@ -1,7 +1,158 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+import json +from typing import Any, Dict, List, Optional + +from twinkle.data_format.message import Tool as ToolInfo +from twinkle_agentic.data_format import Chunks + from .base import Tool +TOOL_NAME = 'extract_condensed' + + class ExtractCondensed(Tool): + """Return the original text behind a ```` compressed segment. + + Args: + chunks: The :class:`Chunks` object emitted by a condenser + (post-compression). Each condensed chunk should carry + ``raw.original`` holding the pre-compression text; if that + snapshot is missing the block is still enumerated (so + numbering stays aligned with ````) but the tool + returns an explicit error on lookup rather than silently + handing back the compressed stand-in. + + The block enumeration rule mirrors :meth:`Chunks.to_trajectory` + exactly: only text chunks with ``raw.condensed=True``, + ``role != 'tool'`` and non-empty content are indexed, in chunk + order, starting from ``1``. This guarantees the block numbers this + tool accepts match the ```` tags the model actually sees. + """ + + def __init__(self, chunks: Chunks): + self._blocks: Dict[int, Optional[str]] = {} + counter = 0 + for c in chunks.chunks: + if c.get('type') != 'text': + continue + content = c.get('content') + if not isinstance(content, str) or not content: + continue + if c.get('role') == 'tool': + continue + raw = c.get('raw') + if not (isinstance(raw, dict) and raw.get('condensed')): + continue + counter += 1 + original = raw.get('original') + self._blocks[counter] = ( + original if isinstance(original, str) and original else None) + + # ------------------------------------------------------------------ + # Tool interface + # ------------------------------------------------------------------ + def tool_info(self) -> ToolInfo: + return { + 'tool_name': TOOL_NAME, + 'description': ( + 'Recover the full, uncompressed text of one or more ' + 'previously condensed passages, identified by their ' + ' tags. 
Use this tool whenever you need to ' + 're-read the original detail of compressed blocks.'), + 'parameters': json.dumps({ + 'blocks': ('int OR list[int], the 1-indexed block number(s) ' + 'N appearing inside .... ' + 'Pass a single int to expand one block, or a ' + 'list of ints to expand several in one call ' + '(e.g. 3 or [1, 3, 5]).'), + }), + } + + def __call__(self, tool_name: str, arguments: Dict[str, Any]) -> str: + if not isinstance(arguments, dict): + return (f'Error: arguments must be an object, got ' + f'{type(arguments).__name__}.') + # Accept the new preferred name ``blocks`` first, fall back to the + # legacy singular ``block`` for backward compatibility with callers + # that were built against the int-only interface. + if 'blocks' in arguments: + raw = arguments['blocks'] + key = 'blocks' + elif 'block' in arguments: + raw = arguments['block'] + key = 'block' + else: + return 'Error: missing required argument "blocks".' + + # Normalise to a list of integers. Single int / str-int → 1-element + # list; list/tuple → validate every element. Preserve order, + # deduplicate while keeping first occurrence. + if isinstance(raw, (list, tuple)): + items = list(raw) + else: + items = [raw] + + seen: Dict[int, None] = {} + parsed: List[int] = [] + for i, item in enumerate(items): + # ``bool`` subclasses ``int`` (``int(True) == 1``) and ``float`` + # coerces silently (``int(1.9) == 1``); reject both up front. + if isinstance(item, bool) or isinstance(item, float): + return (f'Error: "{key}" item at position {i} must be an ' + f'integer, got {type(item).__name__} {item!r}.') + try: + n = int(item) + except (TypeError, ValueError): + return (f'Error: "{key}" item at position {i} must be an ' + f'integer, got {item!r}.') + if n in seen: + continue + seen[n] = None + parsed.append(n) + + if not parsed: + return f'Error: "{key}" must contain at least one block number.' 
+ + # Single-block path preserves the legacy bare-text return shape so + # existing callers / prompts keep working unchanged. + if len(parsed) == 1 and not isinstance(raw, (list, tuple)): + return self._lookup_one(parsed[0]) + + # Multi-block path wraps each result in ... so + # the model can tell them apart in the returned tool message. + parts: List[str] = [] + for n in parsed: + value = self._lookup_one(n) + parts.append(f'\n{value}\n') + return '\n\n'.join(parts) + + def _lookup_one(self, n: int) -> str: + """Return the original text for block ``n`` or an ``Error: ...`` string.""" + if n not in self._blocks: + available = ', '.join(str(k) for k in sorted(self._blocks)) + return (f'Error: block {n} not found. ' + f'Available blocks: {available or "(none)"}.') + value = self._blocks[n] + if value is None: + return (f'Error: block {n} has no original-text snapshot. ' + f'The upstream condenser must populate raw.original ' + f'before registering ExtractCondensed.') + return value + + # ------------------------------------------------------------------ + # Introspection helpers (handy for debugging / tests) + # ------------------------------------------------------------------ + @property + def blocks(self) -> List[int]: + """Sorted list of block indices available to this tool.""" + return sorted(self._blocks) + + def __len__(self) -> int: + return len(self._blocks) - # Extract the condensed block - pass \ No newline at end of file + def __contains__(self, n: Any) -> bool: + try: + return int(n) in self._blocks + except (TypeError, ValueError): + return False diff --git a/src/twinkle_agentic/tools/tool_manager.py b/src/twinkle_agentic/tools/tool_manager.py index 4996569c..61bb115b 100644 --- a/src/twinkle_agentic/tools/tool_manager.py +++ b/src/twinkle_agentic/tools/tool_manager.py @@ -1,16 +1,36 @@ import json -from typing import List, Optional, Dict, Union, Any - -from fastmcp.utilities.inspect import ToolInfo - +from typing import List, Optional, Dict, Union, 
Any, Iterable from twinkle.data_format import ToolCall +from twinkle.data_format.message import Tool as ToolInfo from twinkle_agentic.tools.base import Tool class ToolManager: - def __init__(self, tools: Dict[str, Tool]): - self._tools = tools + def __init__( + self, + tools: Optional[Union[Dict[str, Tool], Iterable[Tool]]] = None, + ): + if tools is None: + self._tools: Dict[str, Tool] = {} + return + if isinstance(tools, dict): + self._tools = dict(tools) + return + if isinstance(tools, (list, tuple)): + self._tools = {} + for t in tools: + info = t.tool_info() if hasattr(t, 'tool_info') else None + name = info.get('tool_name') if isinstance(info, dict) else None + if not name: + raise ValueError( + f'tool {type(t).__name__} must expose a non-empty ' + f'tool_info()["tool_name"]') + self._tools[name] = t + return + raise TypeError( + f'ToolManager expects dict | Iterable[Tool] | None; ' + f'got {type(tools).__name__}') def register(self, tool: Tool): info = tool.tool_info() @@ -27,6 +47,9 @@ def unregister(self, name: str) -> Optional[Tool]: def names(self) -> List[str]: return list(self._tools) + def copy(self) -> 'ToolManager': + return ToolManager(dict(self._tools)) + def tool_infos(self) -> List[ToolInfo]: return [t.tool_info() for t in self._tools.values()] diff --git a/tests/twinkle_agentic/test_extract_condensed.py b/tests/twinkle_agentic/test_extract_condensed.py new file mode 100644 index 00000000..e8325134 --- /dev/null +++ b/tests/twinkle_agentic/test_extract_condensed.py @@ -0,0 +1,433 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Unit tests for :class:`twinkle_agentic.tools.extract_condensed.ExtractCondensed`. 
+ +Covers: +- block-index enumeration matches :meth:`Chunks.to_trajectory` exactly +- retrieval returns pre-compression text when ``raw.original`` is present +- missing ``raw.original`` yields an error, never the compressed ``content`` +- bad / missing arguments produce actionable error strings (no exceptions) +- tool metadata is complete and JSON-serializable +- integration with :class:`ToolManager` +- end-to-end: KeywordCondenser → Chunks → ExtractCondensed round-trips +""" +from __future__ import annotations + +import json + +import pytest + +from twinkle_agentic.data_format import Chunks +from twinkle_agentic.tools.extract_condensed import ( + TOOL_NAME, ExtractCondensed) +from twinkle_agentic.tools.tool_manager import ToolManager + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- +def _condensed(content, *, original=None, role='user', round_idx=1): + raw = {'condensed': True} + if original is not None: + raw['original'] = original + ch = {'type': 'text', 'role': role, 'content': content, 'raw': raw, + 'round': round_idx} + return ch + + +def _plain(content, *, role='user'): + return {'type': 'text', 'role': role, 'content': content} + + +# --------------------------------------------------------------------------- +# block enumeration parity with Chunks.to_trajectory +# --------------------------------------------------------------------------- +def test_blocks_indexed_from_1_in_document_order(): + chunks = Chunks(chunks=[ + _condensed('cmp1', original='orig one'), + _condensed('cmp2', original='orig two'), + _condensed('cmp3', original='orig three'), + ]) + tool = ExtractCondensed(chunks) + assert tool.blocks == [1, 2, 3] + assert len(tool) == 3 + assert 1 in tool and 3 in tool and 4 not in tool + + +def test_non_condensed_text_chunks_are_not_indexed(): + chunks = Chunks(chunks=[ + _plain('system prelude', role='system'), # not condensed + 
_condensed('cmp1', original='orig one'), + _plain('user follow-up'), # not condensed + _condensed('cmp2', original='orig two'), + ]) + tool = ExtractCondensed(chunks) + assert tool.blocks == [1, 2] + assert tool(TOOL_NAME, {'block': 1}) == 'orig one' + assert tool(TOOL_NAME, {'block': 2}) == 'orig two' + + +def test_tool_role_condensed_chunks_are_skipped(): + # Mirrors Chunks.to_trajectory: role=='tool' is NEVER wrapped, even + # if marked condensed, so it must not consume a block index either. + chunks = Chunks(chunks=[ + _condensed('cmp_user', original='user orig', role='user'), + _condensed('cmp_tool', original='tool orig', role='tool'), + _condensed('cmp_asst', original='asst orig', role='assistant'), + ]) + tool = ExtractCondensed(chunks) + # Only the user + assistant blocks count. + assert tool.blocks == [1, 2] + assert tool(TOOL_NAME, {'block': 1}) == 'user orig' + assert tool(TOOL_NAME, {'block': 2}) == 'asst orig' + + +def test_empty_content_condensed_chunks_are_skipped(): + chunks = Chunks(chunks=[ + _condensed('', original=''), # empty, skipped + _condensed('cmp', original='orig'), + ]) + tool = ExtractCondensed(chunks) + assert tool.blocks == [1] + assert tool(TOOL_NAME, {'block': 1}) == 'orig' + + +def test_non_text_chunks_ignored(): + chunks = Chunks(chunks=[ + {'type': 'image', 'content': 'image bytes', + 'raw': {'type': 'image', 'image': 'x'}, 'role': 'user'}, + _condensed('cmp', original='orig text'), + ]) + tool = ExtractCondensed(chunks) + assert tool.blocks == [1] + assert tool(TOOL_NAME, {'block': 1}) == 'orig text' + + +# --------------------------------------------------------------------------- +# retrieval semantics +# --------------------------------------------------------------------------- +def test_returns_original_when_present(): + chunks = Chunks(chunks=[_condensed('CMP', original='THE ORIGINAL')]) + tool = ExtractCondensed(chunks) + assert tool(TOOL_NAME, {'block': 1}) == 'THE ORIGINAL' + + +def 
test_missing_original_returns_error_not_compressed_content(): + # Contract: ExtractCondensed returns the *original* text. When the + # upstream pipeline forgot to snapshot it, the tool MUST fail loud + # rather than silently handing back the compressed stand-in, which + # would deceive the LLM into thinking it had recovered the source. + chunks = Chunks(chunks=[_condensed('CMP', original=None)]) + tool = ExtractCondensed(chunks) + # The block is still enumerated so numbering stays aligned. + assert tool.blocks == [1] + out = tool(TOOL_NAME, {'block': 1}) + assert out.startswith('Error:') + assert 'no original-text snapshot' in out + # And crucially, the compressed stand-in is NOT leaked. + assert 'CMP' not in out + + +def test_original_empty_string_also_reports_missing_snapshot(): + chunks = Chunks(chunks=[_condensed('CMP', original='')]) + tool = ExtractCondensed(chunks) + out = tool(TOOL_NAME, {'block': 1}) + assert out.startswith('Error:') + assert 'no original-text snapshot' in out + + +# --------------------------------------------------------------------------- +# bad input handling (never raises) +# --------------------------------------------------------------------------- +def test_missing_block_argument_returns_error_string(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('cmp', original='orig')])) + out = tool(TOOL_NAME, {}) + assert out.startswith('Error: missing required argument') + + +def test_non_integer_block_returns_error_string(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('cmp', original='orig')])) + for bad in ('abc', [], {}, None): + out = tool(TOOL_NAME, {'block': bad}) + assert out.startswith('Error:'), (bad, out) + + +def test_bool_block_is_rejected_not_coerced_to_int(): + # ``bool`` is a subclass of ``int`` so ``int(True) == 1``. Without + # an explicit guard, ``{'block': True}`` would silently retrieve + # block 1 -- a nasty footgun if an LLM stringifies a truthy flag. 
+ tool = ExtractCondensed(Chunks(chunks=[ + _condensed('cmp', original='orig1')])) + out_true = tool(TOOL_NAME, {'block': True}) + assert out_true.startswith('Error:') and 'bool' in out_true + out_false = tool(TOOL_NAME, {'block': False}) + assert out_false.startswith('Error:') and 'bool' in out_false + # Sanity: the real integer 1 still works. + assert tool(TOOL_NAME, {'block': 1}) == 'orig1' + + +def test_float_block_is_rejected_not_silently_truncated(): + # ``int(1.9) == 1`` would silently round a float down; reject it. + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('cmp', original='orig1')])) + out = tool(TOOL_NAME, {'block': 1.9}) + assert out.startswith('Error:') and 'float' in out + # And floats that happen to be integer-valued are also rejected to + # keep the contract simple. + out2 = tool(TOOL_NAME, {'block': 1.0}) + assert out2.startswith('Error:') + + +def test_non_dict_arguments_returns_error_not_attribute_error(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('cmp', original='orig')])) + # Bypass ToolManager and feed a non-dict directly; must not raise. 
+ out = tool(TOOL_NAME, 'not a dict') # type: ignore[arg-type] + assert out.startswith('Error:') + + +def test_out_of_range_block_returns_error_with_available_list(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('cmp1', original='orig1'), + _condensed('cmp2', original='orig2'), + ])) + out = tool(TOOL_NAME, {'block': 99}) + assert 'block 99 not found' in out + assert 'Available blocks: 1, 2' in out + + +def test_empty_tool_reports_no_blocks_available(): + tool = ExtractCondensed(Chunks(chunks=[ + _plain('nothing condensed')])) + out = tool(TOOL_NAME, {'block': 1}) + assert 'Available blocks: (none)' in out + + +def test_integer_strings_are_accepted(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('cmp', original='orig')])) + assert tool(TOOL_NAME, {'block': '1'}) == 'orig' + + +# --------------------------------------------------------------------------- +# multi-block expansion (``blocks`` accepts int OR list[int]) +# --------------------------------------------------------------------------- +def test_blocks_int_equivalent_to_legacy_block_arg(): + # Passing ``{'blocks': N}`` (single int under the new name) must + # behave identically to the legacy ``{'block': N}`` path: bare text, + # no wrapper. + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('cmp1', original='orig one')])) + assert tool(TOOL_NAME, {'blocks': 1}) == 'orig one' + assert tool(TOOL_NAME, {'blocks': 1}) == tool(TOOL_NAME, {'block': 1}) + + +def test_blocks_list_wraps_each_result_in_block_tags(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('cmp1', original='orig one'), + _condensed('cmp2', original='orig two'), + _condensed('cmp3', original='orig three'), + ])) + out = tool(TOOL_NAME, {'blocks': [1, 3]}) + # Both blocks present, each wrapped, separated by a blank line. + assert '\norig one\n' in out + assert '\norig three\n' in out + assert '' not in out + # Order respects input order. 
+ assert out.index('') < out.index('') + + +def test_blocks_list_preserves_order_over_sorting(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('c1', original='a'), + _condensed('c2', original='b'), + _condensed('c3', original='c'), + ])) + out = tool(TOOL_NAME, {'blocks': [3, 1, 2]}) + # Output order must follow the caller's order, not numeric order. + assert out.index('') < out.index('') < out.index('') + + +def test_blocks_list_deduplicates_preserving_first_occurrence(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('c1', original='a'), + _condensed('c2', original='b'), + ])) + out = tool(TOOL_NAME, {'blocks': [1, 2, 1, 2, 1]}) + # Each block appears exactly once. + assert out.count('') == 1 + assert out.count('') == 1 + # And the first occurrence pins the order. + assert out.index('') < out.index('') + + +def test_blocks_list_with_single_element_still_wraps(): + # Explicit list form is a commitment to multi-block semantics even + # if only one element is present -- wrap it so the caller (or + # downstream sanitizer) can treat list-form results uniformly. 
+ tool = ExtractCondensed(Chunks(chunks=[ + _condensed('c1', original='orig a')])) + out = tool(TOOL_NAME, {'blocks': [1]}) + assert out == '\norig a\n' + + +def test_blocks_list_string_integers_accepted(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('c1', original='a'), + _condensed('c2', original='b'), + ])) + out = tool(TOOL_NAME, {'blocks': ['1', '2']}) + assert '\na\n' in out + assert '\nb\n' in out + + +def test_blocks_list_rejects_bool_and_float_per_element(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('c1', original='a'), + _condensed('c2', original='b'), + ])) + out_bool = tool(TOOL_NAME, {'blocks': [1, True]}) + assert out_bool.startswith('Error:') and 'bool' in out_bool + out_float = tool(TOOL_NAME, {'blocks': [1, 2.5]}) + assert out_float.startswith('Error:') and 'float' in out_float + + +def test_blocks_list_missing_blocks_embed_error_inline(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('c1', original='orig one')])) + out = tool(TOOL_NAME, {'blocks': [1, 99]}) + # Valid block returns its content; missing one returns an error + # string inside its own wrapper so the caller can tell + # which one failed without the tool itself raising. + assert '\norig one\n' in out + assert '' in out + assert 'block 99 not found' in out + + +def test_blocks_empty_list_returns_error(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('c1', original='a')])) + out = tool(TOOL_NAME, {'blocks': []}) + assert out.startswith('Error:') + assert 'at least one block number' in out + + +def test_prefers_blocks_over_legacy_block_when_both_present(): + # Undefined which wins in theory; we declare ``blocks`` takes + # precedence so callers can migrate incrementally. 
+ tool = ExtractCondensed(Chunks(chunks=[ + _condensed('c1', original='NEW'), + _condensed('c2', original='LEGACY'), + ])) + out = tool(TOOL_NAME, {'blocks': 1, 'block': 2}) + assert out == 'NEW' + + +# --------------------------------------------------------------------------- +# tool_info metadata +# --------------------------------------------------------------------------- +def test_tool_info_shape_and_serializability(): + tool = ExtractCondensed(Chunks(chunks=[])) + info = tool.tool_info() + assert info['tool_name'] == TOOL_NAME == 'extract_condensed' + assert 'description' in info and info['description'] + # parameters must be a JSON string that loads back cleanly. + params = json.loads(info['parameters']) + # Preferred parameter name is ``blocks`` (supports int OR list[int]). + assert 'blocks' in params + assert 'int' in params['blocks'] and 'list' in params['blocks'] + + +# --------------------------------------------------------------------------- +# ToolManager integration +# --------------------------------------------------------------------------- +def test_register_with_tool_manager_and_dispatch(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('cmp1', original='orig one'), + _condensed('cmp2', original='orig two'), + ])) + mgr = ToolManager({}) + mgr.register(tool) + assert TOOL_NAME in mgr.names() + + # dict-form arguments + out = mgr({'tool_name': TOOL_NAME, 'arguments': {'block': 2}}) + assert out == 'orig two' + + # JSON-string-form arguments (OpenAI-style) + out = mgr({'tool_name': TOOL_NAME, 'arguments': '{"block": 1}'}) + assert out == 'orig one' + + +def test_manager_reports_error_on_unknown_block_without_raising(): + tool = ExtractCondensed(Chunks(chunks=[ + _condensed('cmp1', original='orig one')])) + mgr = ToolManager({}) + mgr.register(tool) + out = mgr({'tool_name': TOOL_NAME, 'arguments': '{"block": 999}'}) + assert out.startswith('Error:') + + +# --------------------------------------------------------------------------- +# 
end-to-end: round-trip with KeywordCondenser (uses raw.original) +# --------------------------------------------------------------------------- +_SPACY_OK = True +try: + import spacy # noqa: F401 + spacy.load('en_core_web_sm') +except Exception: + _SPACY_OK = False + + +LONG_PASSAGE = ( + 'Christopher Nolan was born on 30 July 1970 in London. ' + 'He is a British-American film director, producer and screenwriter. ' + 'His film Inception (2010) is a science-fiction heist movie. ' + 'Inception grossed over 829 million dollars worldwide.' +) + + +@pytest.mark.skipif(not _SPACY_OK, reason='en_core_web_sm not available') +def test_end_to_end_with_keyword_condenser_returns_original(): + from twinkle_agentic.condenser.keyword import KeywordCondenser + + pre = Chunks(chunks=[ + {'type': 'text', 'role': 'user', 'content': LONG_PASSAGE}]) + post = KeywordCondenser(compression_ratio=4.0, min_chars=50)(pre) + + # The condenser should have left behind an ``original`` snapshot. + assert post.chunks[0]['raw']['condensed'] is True + assert post.chunks[0]['raw']['original'] == LONG_PASSAGE + assert len(post.chunks[0]['content']) < len(LONG_PASSAGE) + + tool = ExtractCondensed(post) + assert tool.blocks == [1] + assert tool(TOOL_NAME, {'block': 1}) == LONG_PASSAGE + + +@pytest.mark.skipif(not _SPACY_OK, reason='en_core_web_sm not available') +def test_end_to_end_block_indices_match_to_trajectory_wrapping(): + from twinkle_agentic.condenser.keyword import KeywordCondenser + + pre = Chunks(chunks=[ + {'type': 'text', 'role': 'user', + 'content': LONG_PASSAGE, 'round': 1}, + {'type': 'text', 'role': 'assistant', + 'content': LONG_PASSAGE + ' Assistant elaboration.', 'round': 1}, + ]) + # skip_roles default excludes assistant → only first chunk condensed. + post = KeywordCondenser(compression_ratio=4.0, min_chars=50)(pre) + tool = ExtractCondensed(post) + + # Exactly one wrapped block. + assert tool.blocks == [1] + # The trajectory wrapper agrees: block_1 exists, block_2 does not. 
+ traj = post.to_trajectory() + rendered = ''.join( + m['content'] if isinstance(m.get('content'), str) else '' + for m in traj['messages']) + assert '' in rendered and '' in rendered + assert '' not in rendered + # And the tool returns the correct original. + assert tool(TOOL_NAME, {'block': 1}) == LONG_PASSAGE diff --git a/tests/twinkle_agentic/test_keyword_condenser.py b/tests/twinkle_agentic/test_keyword_condenser.py new file mode 100644 index 00000000..c4e5642e --- /dev/null +++ b/tests/twinkle_agentic/test_keyword_condenser.py @@ -0,0 +1,488 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Unit tests for :class:`twinkle_agentic.condenser.keyword.KeywordCondenser`. + +Covers: +- strict compression-ratio enforcement (``len(output) <= ceil(len(input)/ratio)``) +- opening / relations / keywords slot extraction +- budget-priority fallback (drop keywords → drop relations → truncate opening) +- role / min_chars / kind filtering +- ``raw.condensed=True`` marker + block wrapping via ``Chunks.to_trajectory`` +- pass-through of non-text / short / skipped chunks +- constructor validation +""" +from __future__ import annotations + +import math + +import pytest + +# Module-level skip if spaCy or the small English model are unavailable. +spacy = pytest.importorskip('spacy') +try: + spacy.load('en_core_web_sm') +except OSError: + pytest.skip('en_core_web_sm not available', allow_module_level=True) + +from twinkle_agentic.chunker.native import NativeChunker +from twinkle_agentic.condenser.keyword import KeywordCondenser +from twinkle_agentic.data_format import Chunks + + +# A realistic multi-sentence passage; long enough to exercise the three +# output slots and the compression budget. +LONG_PASSAGE = ( + 'Christopher Nolan was born on 30 July 1970 in London. ' + 'He is a British-American film director, producer and screenwriter. ' + 'His film Inception (2010) is a science-fiction heist movie starring ' + 'Leonardo DiCaprio. 
Inception grossed over 829 million dollars worldwide ' + 'and received eight Academy Award nominations, winning four. ' + 'Nolan also directed The Dark Knight trilogy and Interstellar in 2014.' +) + + +def _user_chunk(text, role='user'): + return {'role': role, 'type': 'text', 'content': text} + + +def _wrap(*chunks): + return Chunks(chunks=list(chunks)) + + +# --------------------------------------------------------------------------- +# constructor validation +# --------------------------------------------------------------------------- +@pytest.mark.parametrize('kw', [ + {'num_relations': -1}, + {'num_keywords': -1}, + {'max_first_sentence_chars': -1}, + {'compression_ratio': 1.0}, + {'compression_ratio': 0.5}, + {'min_chars': -1}, +]) +def test_invalid_config_raises(kw): + with pytest.raises(ValueError): + KeywordCondenser(**kw) + + +# --------------------------------------------------------------------------- +# compression-ratio contract (STRICT upper bound) +# --------------------------------------------------------------------------- +@pytest.mark.parametrize('ratio', [2.0, 3.0, 4.0, 6.0, 10.0]) +def test_compression_ratio_is_strictly_enforced(ratio): + cond = KeywordCondenser( + num_relations=3, max_first_sentence_chars=160, + num_keywords=8, compression_ratio=ratio, min_chars=50) + src = _user_chunk(LONG_PASSAGE) + out = cond(_wrap(src)).chunks + assert len(out) == 1 + compressed = out[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / ratio) + assert len(compressed) <= budget, ( + f'ratio={ratio}: got len={len(compressed)} > budget={budget}') + assert compressed, 'output must be non-empty' + + +def test_extreme_ratio_keeps_output_non_empty_and_bounded(): + cond = KeywordCondenser(compression_ratio=100.0, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks + compressed = out[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / 100.0) + assert 0 < len(compressed) <= budget + + +# 
--------------------------------------------------------------------------- +# raw.condensed marker + block wrapping +# --------------------------------------------------------------------------- +def test_marks_condensed_and_wraps_in_block_tags(): + cond = KeywordCondenser(compression_ratio=4.0, min_chars=50) + chunks = cond(_wrap(_user_chunk(LONG_PASSAGE))) + assert chunks.chunks[0]['raw']['condensed'] is True + traj = chunks.to_trajectory() + # Exactly one compressed passage → block_1 wrap. + user_content = traj['messages'][0]['content'] + assert '' in user_content and '' in user_content + + +def test_multiple_chunks_numbered_sequentially_starting_from_1(): + cond = KeywordCondenser(compression_ratio=4.0, min_chars=50) + passages = [_user_chunk(LONG_PASSAGE) for _ in range(3)] + chunks = cond(_wrap(*passages)) + traj = chunks.to_trajectory() + content = traj['messages'][0]['content'] + for i in (1, 2, 3): + assert f'' in content and f'' in content + assert '' not in content + + +# --------------------------------------------------------------------------- +# slot extraction (opening / relations / keywords) +# --------------------------------------------------------------------------- +def test_opening_relations_keywords_present_when_budget_allows(): + # Generous budget → all three slots should appear. + # LONG_PASSAGE is ~390 chars; full markup is ~370 chars, so we + # need a ratio close to 1.0 to keep every slot. + cond = KeywordCondenser( + num_relations=3, max_first_sentence_chars=160, num_keywords=8, + compression_ratio=1.05, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + assert out.startswith('Open: ') + assert '\nRel: ' in out + assert '\nMore: ' in out + # At least one of the primary entities should survive in keywords. 
+ assert 'Nolan' in out or 'Inception' in out + + +def test_opening_first_sentence_respects_max_chars(): + cond = KeywordCondenser( + num_relations=0, max_first_sentence_chars=20, num_keywords=0, + compression_ratio=1.1, min_chars=10) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + # Opening slot is trimmed to <= 20 chars + opening_line = out.split('\n', 1)[0] + assert opening_line.startswith('Open: ') + opening_text = opening_line[len('Open: '):] + assert len(opening_text) <= 20 + + +def test_relations_use_triple_or_quadruple_syntax(): + cond = KeywordCondenser( + num_relations=5, max_first_sentence_chars=10, + num_keywords=0, compression_ratio=1.1, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + # We expect at least one '(a | b | c)' or '(a | b | c | d)' pattern. + assert '(' in out and ')' in out + # Parentheses must balance. + assert out.count('(') == out.count(')') + # Pipe-delimited slots (avoids ',' collision with slot-internal commas). + assert ' | ' in out + + +def test_verb_surface_preserved_not_lemma(): + """Triples keep surface form with auxiliaries: 'was born' not 'bear'.""" + cond = KeywordCondenser( + num_relations=3, max_first_sentence_chars=10, + num_keywords=0, compression_ratio=1.1, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + # Auxiliary preserved. + assert 'was born' in out or 'was released' in out or 'is' in out + # Bare lemma of 'born' must NOT appear as the verb slot. 
+ assert '| bear |' not in out and '| bear on |' not in out + + +def test_internal_hyphens_preserved_in_np(): + """NP text keeps 'science-fiction' / 'British-American' hyphens.""" + cond = KeywordCondenser( + num_relations=5, max_first_sentence_chars=10, + num_keywords=0, compression_ratio=1.1, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + assert 'science-fiction' in out or 'British-American' in out + + +def test_pronoun_subject_triples_skipped(): + """Unresolved pronoun subjects (He/She/It) are noise and dropped.""" + cond = KeywordCondenser( + num_relations=5, max_first_sentence_chars=10, + num_keywords=0, compression_ratio=1.1, min_chars=50) + # LONG_PASSAGE has 'He is a British-American film director...' + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + assert '(He |' not in out and '(he |' not in out + + +def test_cardinal_entities_filtered_from_keywords(): + cond = KeywordCondenser( + num_relations=0, num_keywords=10, + max_first_sentence_chars=0, compression_ratio=1.1, min_chars=50) + passage = ( + 'Alpha earned 100 medals. Beta scored 200 points. Gamma made 300 attempts. ' + 'Delta received 400 votes. Epsilon collected 500 tokens. Zeta passed 600 miles.' + ) + out = cond(_wrap(_user_chunk(passage))).chunks[0]['content'] + for num in ('100', '200', '300', '400', '500', '600'): + assert num not in out, f'pure CARDINAL {num!r} leaked into keywords' + + +def test_keyword_subsumption_prefers_longer_form(): + """'Nolan' is dropped when 'Christopher Nolan' is already kept.""" + cond = KeywordCondenser( + num_relations=0, max_first_sentence_chars=10, num_keywords=8, + compression_ratio=1.05, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + more_line = next((ln for ln in out.splitlines() if ln.startswith('More: ')), '') + kws = [k.strip() for k in more_line[len('More: '):].split(',') if k.strip()] + # No keyword may be a token-subset of another kept keyword. 
+ import re + sets = [frozenset(re.findall(r'\w+', k.lower())) for k in kws] + for i, a in enumerate(sets): + for j, b in enumerate(sets): + if i != j: + assert not a < b, ( + f'{kws[i]!r} is subsumed by {kws[j]!r} but kept') + + +def test_keyword_exclusion_is_token_level_not_substring(): + """A keyword is only excluded if ALL its words appear in the opening. + + Substring-based exclusion would wrongly drop 'Starfleet' because + 'star' appears inside other tokens; token-level exclusion keeps it. + """ + cond = KeywordCondenser( + num_relations=0, max_first_sentence_chars=60, num_keywords=5, + compression_ratio=1.1, min_chars=50) + passage = ( + 'The Starfleet Academy trains officers for deep-space missions. ' + 'Captain Kirk graduated there in 2251. Starfleet operates many vessels.' + ) + out = cond(_wrap(_user_chunk(passage))).chunks[0]['content'] + # 'Starfleet' shouldn't be dropped just because 'star' is a substring + # of something in the opening. + assert 'Starfleet' in out or 'Kirk' in out + + +def test_opening_truncation_at_word_boundary(): + """When opening exceeds max_chars, cut at the last whole word.""" + cond = KeywordCondenser( + num_relations=0, max_first_sentence_chars=25, num_keywords=0, + compression_ratio=1.1, min_chars=10) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + opening = out.split('\n', 1)[0][len('Open: '):] + assert len(opening) <= 25 + # Must not end mid-word: last char is a word char AND original passage + # contains the exact trimmed string as a prefix of the first sentence. + first_sent = LONG_PASSAGE.split('.', 1)[0] + assert first_sent.startswith(opening) + # The char after the trimmed prefix in the source should be a space + # (i.e. we really did stop on a word boundary). + if len(opening) < len(first_sent): + assert first_sent[len(opening)] == ' ' + + +def test_budget_is_filled_greedily_with_triples_and_keywords(): + """At a moderate ratio, output should include MORE than just opening. 
+ + Regression test for the old priority-drop logic that collapsed to + opening-only whenever the full composition exceeded budget. + """ + cond = KeywordCondenser( + num_relations=3, max_first_sentence_chars=80, + num_keywords=8, compression_ratio=2.0, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / 2.0) + assert len(out) <= budget + # At ratio=2.0 we MUST retain at least one relation AND at least one keyword. + assert '\nRel: ' in out + assert '\nMore: ' in out + + +def test_budget_too_small_falls_back_to_raw_truncation(): + """Even at absurd ratios, output is non-empty and bounded.""" + cond = KeywordCondenser( + num_relations=3, num_keywords=5, max_first_sentence_chars=160, + compression_ratio=200.0, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / 200.0) + assert 0 < len(out) <= budget + + +def test_num_relations_zero_suppresses_slot(): + cond = KeywordCondenser( + num_relations=0, num_keywords=5, compression_ratio=1.2, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + assert '\nRel: ' not in out + + +def test_num_keywords_zero_suppresses_slot(): + cond = KeywordCondenser( + num_relations=3, num_keywords=0, compression_ratio=1.2, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + assert '\nMore: ' not in out + + +# --------------------------------------------------------------------------- +# budget priority: drop keywords → drop relations → truncate opening +# --------------------------------------------------------------------------- +def test_tight_budget_drops_keywords_first(): + # Pick a ratio that is just tight enough to force one slot to go. + # Full output len ≈ 200+; opening+relations alone ≈ 120. 
+ cond = KeywordCondenser( + num_relations=2, max_first_sentence_chars=80, + num_keywords=8, compression_ratio=3.0, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / 3.0) + assert len(out) <= budget + assert out.startswith('Open: ') + + +def test_very_tight_budget_falls_back_to_opening_only(): + # Ratio large enough that only the opening slot can fit. + # Keep max_first_sentence_chars small so it does fit. + cond = KeywordCondenser( + num_relations=5, max_first_sentence_chars=40, + num_keywords=8, compression_ratio=8.0, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / 8.0) + assert len(out) <= budget + # Either opening-only or further truncated — both fine. + assert out.startswith('Open') or len(out) <= budget + + +# --------------------------------------------------------------------------- +# selection policy +# --------------------------------------------------------------------------- +def test_skip_roles_default_preserves_system_tool_assistant(): + cond = KeywordCondenser(compression_ratio=4.0, min_chars=50) + src = _wrap( + _user_chunk(LONG_PASSAGE, role='system'), + _user_chunk(LONG_PASSAGE, role='assistant'), + _user_chunk(LONG_PASSAGE, role='tool'), + _user_chunk(LONG_PASSAGE, role='user'), + ) + out = cond(src).chunks + # First three pass through untouched. + for i in range(3): + assert out[i]['content'] == LONG_PASSAGE + assert (out[i].get('raw') or {}).get('condensed') is not True + # Fourth gets condensed. 
+ assert out[3]['raw']['condensed'] is True + assert len(out[3]['content']) < len(LONG_PASSAGE) + + +def test_custom_skip_roles(): + cond = KeywordCondenser( + compression_ratio=4.0, min_chars=50, skip_roles=()) + src = _wrap(_user_chunk(LONG_PASSAGE, role='assistant')) + out = cond(src).chunks + assert out[0]['raw']['condensed'] is True + + +def test_short_content_passes_through(): + cond = KeywordCondenser(compression_ratio=4.0, min_chars=500) + src = _user_chunk(LONG_PASSAGE) # shorter than 500 + out = cond(_wrap(src)).chunks + assert out[0]['content'] == LONG_PASSAGE + assert (out[0].get('raw') or {}).get('condensed') is not True + + +def test_non_text_chunk_passes_through(): + cond = KeywordCondenser(compression_ratio=4.0, min_chars=1) + src = {'type': 'image', 'content': 'http://x/y.png', + 'role': 'user', 'raw': {'type': 'image', 'image': 'http://x/y.png'}} + out = cond(_wrap(src)).chunks + assert out[0] == src + + +def test_reasoning_and_tool_call_kind_chunks_pass_through(): + cond = KeywordCondenser(compression_ratio=4.0, min_chars=50) + reasoning = { + 'type': 'text', 'role': 'assistant', 'content': LONG_PASSAGE, + 'raw': {'kind': 'reasoning_content'}, + } + # Assistant role would already be skipped, but the kind-filter must + # hold even if role is user. 
+ tool_call = { + 'type': 'text', 'role': 'user', 'content': LONG_PASSAGE, + 'raw': {'kind': 'tool_call', 'tool_call': {'tool_name': 'x', 'arguments': '{}'}}, + } + out = cond(_wrap(reasoning, tool_call)).chunks + assert (out[0].get('raw') or {}).get('condensed') is not True + assert (out[1].get('raw') or {}).get('condensed') is not True + + +def test_empty_content_is_untouched(): + cond = KeywordCondenser(compression_ratio=4.0, min_chars=0) + src = _user_chunk('') + out = cond(_wrap(src)).chunks + assert out[0] == src + + +# --------------------------------------------------------------------------- +# integration with NativeChunker + to_trajectory round-trip +# --------------------------------------------------------------------------- +def test_chunker_then_condenser_produces_block_numbered_output(): + chunker = NativeChunker(chunk_size=300) + cond = KeywordCondenser(compression_ratio=4.0, min_chars=50) + + passages = '\n\n'.join( + f'[{i}] Title_{i}: ' + LONG_PASSAGE for i in range(1, 4)) + user_text = f'Question: who directed Inception?\n\nContext:\n\n{passages}' + traj = {'messages': [ + {'role': 'system', 'content': 'You are a helpful agent.'}, + {'role': 'user', 'content': user_text}, + ]} + chunks = cond(chunker(traj)) + back = chunks.to_trajectory() + + # System untouched; user got multiple condensed blocks. + assert back['messages'][0]['content'] == 'You are a helpful agent.' + user_content = back['messages'][1]['content'] + assert '' in user_content + # Each block must be strictly smaller than its source chunk. 
+ assert len(user_content) < len(user_text) + + +def test_condenser_preserves_chunk_order_and_count(): + cond = KeywordCondenser(compression_ratio=4.0, min_chars=50) + src_chunks = _wrap( + _user_chunk('short', role='user'), + _user_chunk(LONG_PASSAGE, role='user'), + _user_chunk(LONG_PASSAGE, role='system'), + ) + out = cond(src_chunks).chunks + assert len(out) == 3 + assert out[0]['content'] == 'short' # too short + assert out[1]['raw']['condensed'] is True # condensed + assert out[2]['content'] == LONG_PASSAGE # skipped role + + +# --------------------------------------------------------------------------- +# idempotency: running condenser twice is safe +# --------------------------------------------------------------------------- +def test_condenser_is_idempotent_on_already_condensed_output(): + cond = KeywordCondenser(compression_ratio=4.0, min_chars=50) + once = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0] + # Second pass must be a no-op: content identical, raw marker kept. + twice = cond(_wrap(once)).chunks[0] + assert twice['raw']['condensed'] is True + assert twice['content'] == once['content'] + # And a third pass must also be stable. + thrice = cond(_wrap(twice)).chunks[0] + assert thrice['content'] == once['content'] + + +# --------------------------------------------------------------------------- +# round-based selection filter +# --------------------------------------------------------------------------- +def _round_chunk(text, round_idx, role='user'): + return {'role': role, 'type': 'text', 'content': text, 'round': round_idx} + + +def test_rounds_filter_only_compresses_first_user_turn(): + cond = KeywordCondenser(compression_ratio=4.0, min_chars=50, + rounds=[1]) + out = cond(_wrap( + _round_chunk(LONG_PASSAGE, 1), + _round_chunk(LONG_PASSAGE + ' extra.', 2), + )).chunks + # Round 1 compressed. + assert out[0]['raw']['condensed'] is True + assert len(out[0]['content']) < len(LONG_PASSAGE) + # Round 2 passed through unchanged. 
+ assert out[1]['content'].endswith(' extra.') + assert not (out[1].get('raw') or {}).get('condensed') + + +def test_rounds_filter_excludes_chunks_without_round_field(): + cond = KeywordCondenser(compression_ratio=4.0, min_chars=50, + rounds=[1]) + # Chunk missing ``round`` must be treated as non-matching. + plain = _user_chunk(LONG_PASSAGE) + out = cond(_wrap(plain)).chunks[0] + assert out['content'] == LONG_PASSAGE + assert not (out.get('raw') or {}).get('condensed') + + +def test_rounds_filter_default_none_preserves_legacy_behavior(): + cond = KeywordCondenser(compression_ratio=4.0, min_chars=50) + # No rounds set; chunks without ``round`` are still compressed. + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0] + assert out['raw']['condensed'] is True + assert len(out['content']) < len(LONG_PASSAGE) diff --git a/tests/twinkle_agentic/test_model_condenser.py b/tests/twinkle_agentic/test_model_condenser.py new file mode 100644 index 00000000..26f4970a --- /dev/null +++ b/tests/twinkle_agentic/test_model_condenser.py @@ -0,0 +1,559 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Unit + integration tests for :class:`twinkle_agentic.condenser.model.ModelCondenser`. + +Unit tests use a deterministic mock :class:`Sampler` so the suite runs +without GPUs / vLLM. The final block contains an opt-in integration +test that spins up a real ``Qwen/Qwen2.5-3B-Instruct`` sampler on a +single GPU; enable it with:: + + TWINKLE_TEST_REAL_SAMPLER=1 pytest tests/twinkle_agentic/test_model_condenser.py +""" +from __future__ import annotations + +import math +import os +from typing import Callable, List + +import pytest + +# Import directly from the submodule to avoid the (currently broken) +# ``twinkle.sampler.__init__`` import chain in this workspace. 
+from twinkle.data_format.sampling import ( + SampledSequence, + SampleResponse, + SamplingParams, +) + +from twinkle_agentic.condenser.model import ( + ModelCondenser, + _clamp_to_budget, + _parse_markdown_sections, + _strip_code_fences, +) +from twinkle_agentic.data_format import Chunks + + +# --------------------------------------------------------------------------- +# fixtures / helpers +# --------------------------------------------------------------------------- +LONG_PASSAGE = ( + 'Christopher Nolan was born on 30 July 1970 in London. ' + 'He is a British-American film director, producer and screenwriter. ' + 'His film Inception (2010) is a science-fiction heist movie starring ' + 'Leonardo DiCaprio. Inception grossed over 829 million dollars worldwide ' + 'and received eight Academy Award nominations, winning four. ' + 'Nolan also directed The Dark Knight trilogy and Interstellar in 2014.' +) + + +def _user_chunk(text, role='user'): + return {'role': role, 'type': 'text', 'content': text} + + +def _wrap(*chunks): + return Chunks(chunks=list(chunks)) + + +class _MockSampler: + """Deterministic duck-typed sampler. Calls ``responder(passage)`` per input. + + We do NOT subclass :class:`twinkle.sampler.base.Sampler` to avoid + dragging the workspace's currently-broken template init-chain into + the test module. ``ModelCondenser`` only touches + ``sampler.sample(...)``, so duck-typing is sufficient. 
+ """ + + def __init__(self, responder: Callable[[str], str]): + self._responder = responder + self.template = object() # truthy placeholder, never inspected + self.engine = None + self.calls: List[dict] = [] + + def sample( + self, + inputs, + sampling_params=None, + adapter_name='', + *, + num_samples=1, + ) -> List[SampleResponse]: + inputs_list = inputs if isinstance(inputs, list) else [inputs] + out: List[SampleResponse] = [] + for traj in inputs_list: + user_msg = next(m for m in traj['messages'] if m['role'] == 'user') + prompt = user_msg['content'] + marker = 'Passage:\n' + idx = prompt.rfind(marker) + passage = prompt[idx + len(marker):] if idx >= 0 else prompt + decoded = self._responder(passage) + self.calls.append({ + 'passage': passage, + 'sampling_params': sampling_params, + }) + out.append(SampleResponse(sequences=[ + SampledSequence(stop_reason='stop', tokens=[], decoded=decoded) + ])) + return out + + +def _well_formed_markdown(passage: str) -> str: + """A standard three-section markdown response.""" + return ( + '## Summary\n' + 'Christopher Nolan is a British-American director born in London in 1970.\n\n' + '## Key Facts\n' + '- Nolan directed Inception (2010) starring Leonardo DiCaprio.\n' + '- Inception grossed over 829 million dollars worldwide.\n' + '- Nolan also directed The Dark Knight trilogy and Interstellar.\n\n' + '## More\n' + 'Nolan, Inception, Leonardo DiCaprio, Interstellar, London, 1970' + ) + + +# --------------------------------------------------------------------------- +# constructor validation +# --------------------------------------------------------------------------- +def test_requires_sampler(): + with pytest.raises(ValueError): + ModelCondenser(sampler=None) + + +@pytest.mark.parametrize('kw', [ + {'compression_ratio': 1.0}, + {'compression_ratio': 0.5}, + {'min_chars': -1}, + {'batch_size': 0}, + {'user_prompt_template': 'no placeholders'}, + {'user_prompt_template': 'only {budget} placeholder'}, + 
{'user_prompt_template': 'only {text} placeholder'}, +]) +def test_invalid_config_raises(kw): + with pytest.raises(ValueError): + ModelCondenser(_MockSampler(_well_formed_markdown), **kw) + + +# --------------------------------------------------------------------------- +# pure helper smoke tests +# --------------------------------------------------------------------------- +def test_parse_markdown_sections_basic(): + text = _well_formed_markdown('') + secs = _parse_markdown_sections(text) + assert set(secs.keys()) == {'summary', 'facts', 'more'} + assert 'Christopher Nolan' in secs['summary'] + assert 'Leonardo DiCaprio' in secs['facts'] + assert 'Interstellar' in secs['more'] + + +def test_parse_markdown_sections_handles_header_variants(): + text = ( + '# summary\nfoo\n\n### KEY FACT\n- bar\n\n## more\nkw1, kw2' + ) + secs = _parse_markdown_sections(text) + assert secs == {'summary': 'foo', 'facts': '- bar', 'more': 'kw1, kw2'} + + +def test_parse_markdown_sections_empty_input(): + assert _parse_markdown_sections('') == {} + + +def test_strip_code_fences(): + wrapped = '```markdown\n## Summary\nhi\n```' + assert _strip_code_fences(wrapped) == '## Summary\nhi' + # No fence → returned as-is. + plain = '## Summary\nhi' + assert _strip_code_fences(plain) == plain + + +def test_clamp_to_budget_word_boundary(): + assert _clamp_to_budget('hello world foo', 12) == 'hello world' + # Budget larger than text → untouched. + assert _clamp_to_budget('short', 100) == 'short' + # Budget 0 → empty. 
+ assert _clamp_to_budget('anything', 0) == '' + + +# --------------------------------------------------------------------------- +# strict compression-ratio enforcement +# --------------------------------------------------------------------------- +@pytest.mark.parametrize('ratio', [2.0, 3.0, 4.0, 6.0, 10.0]) +def test_compression_ratio_is_strictly_enforced(ratio): + cond = ModelCondenser( + _MockSampler(_well_formed_markdown), + compression_ratio=ratio, + min_chars=50, + ) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / ratio) + assert len(out) <= budget, ( + f'ratio={ratio}: got len={len(out)} > budget={budget}') + assert out, 'output must be non-empty' + + +def test_misbehaving_model_output_is_still_clamped(): + """Even when the LLM exceeds the budget, output must fit.""" + overflow = lambda _p: _well_formed_markdown('') * 5 # noqa: E731 + cond = ModelCondenser( + _MockSampler(overflow), compression_ratio=3.0, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / 3.0) + assert len(out) <= budget + + +def test_extreme_ratio_still_bounded_and_non_empty(): + cond = ModelCondenser( + _MockSampler(_well_formed_markdown), + compression_ratio=200.0, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / 200.0) + assert 0 < len(out) <= budget + # Regression: at a budget too small to hold even "## Summary\n", the + # condenser must fall back to a non-empty *body* substring instead of + # returning dangling hash marks like "##" or "## ". 
+ assert out.strip('#').strip(), ( + f'extreme-ratio output degenerated to markdown markers: {out!r}') + + +# --------------------------------------------------------------------------- +# structural output quality +# --------------------------------------------------------------------------- +def test_well_formed_output_keeps_three_sections_at_generous_budget(): + cond = ModelCondenser( + _MockSampler(_well_formed_markdown), + compression_ratio=1.1, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + assert '## Summary' in out + assert '## Key Facts' in out + assert '## More' in out + # Primary entities survive in some form. + assert 'Nolan' in out or 'Inception' in out + + +def test_tight_budget_drops_more_first(): + # Craft a response where dropping 'More' yields <=130 chars but keeping + # all three is over budget. + def responder(_p): + return ( + '## Summary\nA short sentence.\n\n' + '## Key Facts\n- Fact one here.\n- Fact two here.\n\n' + '## More\n' + ('x, ' * 60) # ~180 chars + ) + cond = ModelCondenser( + _MockSampler(responder), compression_ratio=3.5, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / 3.5) + assert len(out) <= budget + assert '## Summary' in out + assert '## More' not in out + + +def test_very_tight_budget_keeps_only_summary(): + def responder(_p): + return ( + '## Summary\nA short sentence.\n\n' + '## Key Facts\n- Fact one.\n- Fact two.\n- Fact three.\n\n' + '## More\n' + ('kw, ' * 80) + ) + cond = ModelCondenser( + _MockSampler(responder), compression_ratio=10.0, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / 10.0) + assert len(out) <= budget + # Summary should survive, the other two slots must not. 
+ assert '## Summary' in out + assert '## Key Facts' not in out + assert '## More' not in out + + +def test_garbled_model_output_fallback_is_clamped(): + """When the model response has NO recognizable sections, fall back + to clamped raw text (never empty).""" + garbled = lambda _p: 'this is some unstructured blob ' * 10 # noqa: E731 + cond = ModelCondenser( + _MockSampler(garbled), compression_ratio=4.0, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / 4.0) + assert 0 < len(out) <= budget + assert 'unstructured' in out + + +def test_code_fenced_output_is_unwrapped(): + wrapped = lambda _p: '```markdown\n' + _well_formed_markdown('') + '\n```' # noqa: E731 + cond = ModelCondenser( + _MockSampler(wrapped), compression_ratio=1.5, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + # After unwrapping, header is at the start (no leading ```). + assert not out.startswith('```') + assert out.startswith('## Summary') + + +# --------------------------------------------------------------------------- +# raw.condensed marker + block wrapping +# --------------------------------------------------------------------------- +def test_marks_condensed_and_wraps_in_block_tags(): + cond = ModelCondenser( + _MockSampler(_well_formed_markdown), + compression_ratio=4.0, min_chars=50) + chunks = cond(_wrap(_user_chunk(LONG_PASSAGE))) + assert chunks.chunks[0]['raw']['condensed'] is True + traj = chunks.to_trajectory() + user_content = traj['messages'][0]['content'] + assert '' in user_content and '' in user_content + + +def test_multiple_chunks_numbered_sequentially(): + cond = ModelCondenser( + _MockSampler(_well_formed_markdown), + compression_ratio=4.0, min_chars=50, batch_size=2) + passages = [_user_chunk(LONG_PASSAGE) for _ in range(3)] + chunks = cond(_wrap(*passages)) + traj = chunks.to_trajectory() + content = traj['messages'][0]['content'] + for i in (1, 2, 3): + assert f'' 
in content and f'' in content + assert '' not in content + + +# --------------------------------------------------------------------------- +# selection policy +# --------------------------------------------------------------------------- +def test_skip_roles_default_preserves_system_tool_assistant(): + sampler = _MockSampler(_well_formed_markdown) + cond = ModelCondenser(sampler, compression_ratio=4.0, min_chars=50) + src = _wrap( + _user_chunk(LONG_PASSAGE, role='system'), + _user_chunk(LONG_PASSAGE, role='assistant'), + _user_chunk(LONG_PASSAGE, role='tool'), + _user_chunk(LONG_PASSAGE, role='user'), + ) + out = cond(src).chunks + for i in range(3): + assert out[i]['content'] == LONG_PASSAGE + assert (out[i].get('raw') or {}).get('condensed') is not True + assert out[3]['raw']['condensed'] is True + # Sampler saw only the user chunk. + assert len(sampler.calls) == 1 + + +def test_custom_skip_roles_empty_tuple(): + cond = ModelCondenser( + _MockSampler(_well_formed_markdown), + compression_ratio=4.0, min_chars=50, skip_roles=()) + src = _wrap(_user_chunk(LONG_PASSAGE, role='assistant')) + out = cond(src).chunks + assert out[0]['raw']['condensed'] is True + + +def test_short_content_passes_through(): + sampler = _MockSampler(_well_formed_markdown) + cond = ModelCondenser(sampler, compression_ratio=4.0, min_chars=500) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks + assert out[0]['content'] == LONG_PASSAGE + assert (out[0].get('raw') or {}).get('condensed') is not True + assert sampler.calls == [] + + +def test_non_text_chunk_passes_through(): + sampler = _MockSampler(_well_formed_markdown) + cond = ModelCondenser(sampler, compression_ratio=4.0, min_chars=1) + img = {'type': 'image', 'content': 'http://x/y.png', 'role': 'user', + 'raw': {'type': 'image', 'image': 'http://x/y.png'}} + out = cond(_wrap(img)).chunks + assert out[0] == img + assert sampler.calls == [] + + +def test_reasoning_kind_chunk_passes_through(): + sampler = 
_MockSampler(_well_formed_markdown) + cond = ModelCondenser(sampler, compression_ratio=4.0, min_chars=50) + reasoning = { + 'type': 'text', 'role': 'user', 'content': LONG_PASSAGE, + 'raw': {'kind': 'reasoning_content'}, + } + out = cond(_wrap(reasoning)).chunks + assert (out[0].get('raw') or {}).get('condensed') is not True + assert sampler.calls == [] + + +def test_already_condensed_chunk_is_not_reprocessed(): + sampler = _MockSampler(_well_formed_markdown) + cond = ModelCondenser(sampler, compression_ratio=4.0, min_chars=50) + once = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0] + assert once['raw']['condensed'] is True + sampler.calls.clear() + twice = cond(_wrap(once)).chunks[0] + # No second sampler call — idempotent. + assert sampler.calls == [] + assert twice == once + + +# --------------------------------------------------------------------------- +# batching & ordering +# --------------------------------------------------------------------------- +def test_batching_respects_batch_size(): + sampler = _MockSampler(_well_formed_markdown) + cond = ModelCondenser(sampler, compression_ratio=4.0, min_chars=50, + batch_size=2) + src = _wrap(*[_user_chunk(LONG_PASSAGE) for _ in range(5)]) + out = cond(src).chunks + assert len(out) == 5 + for c in out: + assert c['raw']['condensed'] is True + assert len(sampler.calls) == 5 # 5 chunks total + + +def test_order_preserved_with_mixed_chunks(): + sampler = _MockSampler(_well_formed_markdown) + cond = ModelCondenser(sampler, compression_ratio=4.0, min_chars=50, + batch_size=2) + src = _wrap( + _user_chunk('short', role='user'), # too short + _user_chunk(LONG_PASSAGE, role='user'), # condensed + _user_chunk(LONG_PASSAGE, role='system'), # skipped role + _user_chunk(LONG_PASSAGE, role='user'), # condensed + ) + out = cond(src).chunks + assert len(out) == 4 + assert out[0]['content'] == 'short' + assert out[1]['raw']['condensed'] is True + assert out[2]['content'] == LONG_PASSAGE + assert (out[2].get('raw') or 
{}).get('condensed') is not True + assert out[3]['raw']['condensed'] is True + + +# --------------------------------------------------------------------------- +# prompt robustness +# --------------------------------------------------------------------------- +def test_braces_in_text_do_not_break_prompt_formatting(): + sampler = _MockSampler(_well_formed_markdown) + cond = ModelCondenser(sampler, compression_ratio=4.0, min_chars=50) + text = ('The JSON config was {"model": "Qwen", "temperature": 0.7}. ' + * 5) + out = cond(_wrap(_user_chunk(text))).chunks[0] + assert out['raw']['condensed'] is True + # Prompt contained the raw text verbatim. + assert sampler.calls[0]['passage'].strip().startswith( + 'The JSON config was {"model":') + + +def test_prompt_mentions_budget_in_user_message(): + sampler = _MockSampler(_well_formed_markdown) + cond = ModelCondenser(sampler, compression_ratio=3.0, min_chars=50) + cond(_wrap(_user_chunk(LONG_PASSAGE))) + expected_budget = math.ceil(len(LONG_PASSAGE) / 3.0) + # The mock recorded the prompt passage; we check the sampling_params + # carries a reasonable max_tokens (derived from budget). 
+ assert sampler.calls[0]['sampling_params'].max_tokens >= expected_budget // 2 + + +def test_custom_sampling_params_is_forwarded(): + sampler = _MockSampler(_well_formed_markdown) + custom = SamplingParams(temperature=0.3, max_tokens=256) + cond = ModelCondenser(sampler, compression_ratio=4.0, min_chars=50, + sampling_params=custom) + cond(_wrap(_user_chunk(LONG_PASSAGE))) + assert sampler.calls[0]['sampling_params'] is custom + + +# --------------------------------------------------------------------------- +# semantic preservation (mock-level sanity) +# --------------------------------------------------------------------------- +def test_semantic_preservation_against_budget(): + """Under a moderate ratio, important entities appear in the output.""" + cond = ModelCondenser( + _MockSampler(_well_formed_markdown), + compression_ratio=2.0, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / 2.0) + assert len(out) <= budget + # At ratio=2.0 we should still carry key entities. 
+ hits = sum(1 for ent in ( + 'Nolan', 'Inception', 'Leonardo DiCaprio', 'London' + ) if ent in out) + assert hits >= 2 + + +# --------------------------------------------------------------------------- +# integration test (opt-in; requires single GPU + vLLM + Qwen model) +# --------------------------------------------------------------------------- +INTEGRATION_ENABLED = bool(os.environ.get('TWINKLE_TEST_REAL_SAMPLER')) +INTEGRATION_MODEL = os.environ.get( + 'TWINKLE_TEST_MODEL', 'Qwen/Qwen2.5-3B-Instruct') + + +@pytest.mark.skipif( + not INTEGRATION_ENABLED, + reason='Set TWINKLE_TEST_REAL_SAMPLER=1 to run the real-model integration test', +) +def test_integration_real_qwen_sampler_end_to_end(): + """End-to-end test with a real Qwen sampler on a single GPU.""" + vllm = pytest.importorskip('vllm') # noqa: F841 + from twinkle.sampler.vllm_sampler.vllm_sampler import vLLMSampler + + sampler = vLLMSampler( + model_id=INTEGRATION_MODEL, + engine_args={ + 'dtype': 'bfloat16', + 'gpu_memory_utilization': 0.7, + 'max_model_len': 4096, + 'enforce_eager': True, + }, + ) + try: + sampler.set_template('qwen2_5') + except Exception: + # Fall back to 'auto' template detection if the named one + # isn't registered in this build. + sampler.set_template('default') + + cond = ModelCondenser(sampler, compression_ratio=4.0, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0]['content'] + budget = math.ceil(len(LONG_PASSAGE) / 4.0) + + # Strict compression ratio holds end-to-end. + assert 0 < len(out) <= budget, f'len(out)={len(out)} budget={budget}' + # At least one key entity should survive. 
+ assert any( + ent in out for ent in ('Nolan', 'Inception', 'London', 'Leonardo')) + + +# --------------------------------------------------------------------------- +# round-based selection filter +# --------------------------------------------------------------------------- +def _round_chunk(text, round_idx, role='user'): + return {'role': role, 'type': 'text', 'content': text, 'round': round_idx} + + +def test_rounds_filter_only_compresses_first_user_turn(): + sampler = _MockSampler(_well_formed_markdown) + cond = ModelCondenser(sampler, compression_ratio=4.0, + min_chars=50, rounds=[1]) + out = cond(_wrap( + _round_chunk(LONG_PASSAGE, 1), + _round_chunk(LONG_PASSAGE + ' extra.', 2), + )).chunks + # Only one sampler call happened — for round 1. + assert len(sampler.calls) == 1 + # Round 1 compressed. + assert out[0]['raw']['condensed'] is True + # Round 2 untouched. + assert out[1]['content'].endswith(' extra.') + assert not (out[1].get('raw') or {}).get('condensed') + + +def test_rounds_filter_excludes_chunks_without_round_field(): + sampler = _MockSampler(_well_formed_markdown) + cond = ModelCondenser(sampler, compression_ratio=4.0, + min_chars=50, rounds=[1]) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0] + # No call because the chunk had no ``round`` field. 
+ assert sampler.calls == [] + assert out['content'] == LONG_PASSAGE + assert not (out.get('raw') or {}).get('condensed') + + +def test_rounds_filter_default_none_preserves_legacy_behavior(): + sampler = _MockSampler(_well_formed_markdown) + cond = ModelCondenser(sampler, compression_ratio=4.0, min_chars=50) + out = cond(_wrap(_user_chunk(LONG_PASSAGE))).chunks[0] + assert out['raw']['condensed'] is True + assert len(sampler.calls) == 1 diff --git a/tests/twinkle_agentic/test_multi_turn_rollout.py b/tests/twinkle_agentic/test_multi_turn_rollout.py new file mode 100644 index 00000000..04879aa7 --- /dev/null +++ b/tests/twinkle_agentic/test_multi_turn_rollout.py @@ -0,0 +1,826 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Unit tests for :class:`twinkle_agentic.rollout.multi_turn.MultiTurnRollout`. + +Focus: + - Control flow: no-tool / with-tool / length-stop / max-turns truncation + - Label alignment: trainable positions count == total sampled tokens + - Logprobs alignment: flat list length == trainable count + - Output structure: pif fields merged at TOP LEVEL (input_ids present ⇒ + VLLMSampler will skip re-encoding on a second pass) + - Input validation: constructor rejects bad config + - Defensive asserts: labels/input_ids length mismatch and logprobs + length mismatch both raise RuntimeError + - Shallow-copy safety: extra trajectory fields (e.g. ``images``) flow + through without deep copy + +The tests are self-contained — they use a char-level fake tokenizer, a +fake Template that replays the real ``concat_input_feature`` and post +pipeline semantics, and a fake Sampler that queues scripted responses. 
+""" +from __future__ import annotations + +import copy +import json +import re +from typing import Any, Dict, List, Optional + +import pytest + +from twinkle.data_format.sampling import ( + SampleResponse, SampledSequence, SamplingParams, +) +from twinkle_agentic.rollout.multi_turn import MultiTurnRollout +from twinkle_agentic.tools.base import Tool +from twinkle_agentic.tools.tool_manager import ToolManager + + +# ============================================================================= +# Fakes +# ============================================================================= +class FakeTokenizer: + """Char-level tokenizer with atomic special tokens. + + Guarantees ``decode(encode(s)) == s`` for any mix of raw chars and + registered specials. This is what makes the decode-diff-encode alignment + strategy in MultiTurnRollout.__extend_with_bridge work in the test. + """ + SPECIALS = ('<|im_start|>', '<|im_end|>') + + def __init__(self) -> None: + self._s2i: Dict[str, int] = {} + self._i2s: Dict[int, str] = {} + for s in self.SPECIALS: + self._add(s) + + def _add(self, tok: str) -> int: + if tok not in self._s2i: + i = len(self._s2i) + self._s2i[tok] = i + self._i2s[i] = tok + return self._s2i[tok] + + def encode(self, text: str, add_special_tokens: bool = False) -> List[int]: + ids: List[int] = [] + i = 0 + while i < len(text): + matched = False + for sp in self.SPECIALS: + if text.startswith(sp, i): + ids.append(self._add(sp)) + i += len(sp) + matched = True + break + if not matched: + ids.append(self._add(text[i])) + i += 1 + return ids + + def decode(self, ids: List[int], skip_special_tokens: bool = False) -> str: + specials = set(self.SPECIALS) + toks = [self._i2s[int(i)] for i in ids] + if skip_special_tokens: + toks = [t for t in toks if t not in specials] + return ''.join(toks) + + def apply_chat_template( + self, + messages: List[Dict[str, Any]], + tokenize: bool = False, + add_generation_prompt: bool = False, + **_, + ): + s = '' + for m in messages: + 
role = m['role'] + content = m['content'] + s += f'<|im_start|>{role}\n{content}<|im_end|>\n' + if add_generation_prompt: + s += '<|im_start|>assistant\n' + if tokenize: + return self.encode(s) + return s + + +class FakeTemplate: + """Minimal Template that mirrors the parts MultiTurnRollout touches.""" + model_id = 'qwen-fake' + truncation_strategy = 'right' + + def __init__(self, tokenizer: FakeTokenizer) -> None: + self.tokenizer = tokenizer + + # --- the public API used by MultiTurnRollout ---------------------------- + def encode(self, trajectory: Dict[str, Any], add_generation_prompt: bool = False) -> Dict[str, Any]: + messages = trajectory.get('messages', []) + s = self.tokenizer.apply_chat_template( + messages, tokenize=False, + add_generation_prompt=add_generation_prompt) + input_ids = self.tokenizer.encode(s, add_special_tokens=False) + pif: Dict[str, Any] = dict(trajectory) # preserve top-level fields + pif['input_ids'] = input_ids + pif['labels'] = [-100] * len(input_ids) # inference mode + return self._invoke_post_pipeline([pif])[0] + + def _invoke_post_pipeline(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + out = [] + for pif in inputs: + pif = dict(pif) + input_ids = list(pif['input_ids']) + labels = list(pif.get('labels') or []) + if labels: + if len(labels) != len(input_ids): + raise RuntimeError( + f'FakeTemplate post_pipeline: labels({len(labels)}) ' + f'!= input_ids({len(input_ids)})') + # np.roll(labels, -1): shift LEFT by 1 (output/shifted order) + labels = labels[1:] + labels[:1] + pif['input_ids'] = input_ids + pif['labels'] = labels + pif['attention_mask'] = [1] * len(input_ids) + pif['position_ids'] = list(range(len(input_ids))) + pif['length'] = len(input_ids) + out.append(pif) + return out + + def parse_tool_call(self, decoded: str) -> List[Dict[str, Any]]: + matches = re.findall(r'\s*([\s\S]*?)\s*', decoded or '') + results: List[Dict[str, Any]] = [] + for m in matches: + try: + d = json.loads(m) + except 
json.JSONDecodeError: + continue + name = d.get('name') or d.get('tool_name') + if not name: + continue + results.append({ + 'tool_name': name, + 'arguments': d.get('arguments', {}), + }) + return results + + # --- Used by the fake sampler to mirror real concat_input_feature ------- + def concat_input_feature(self, pif: Dict[str, Any], new_tokens: List[int]) -> Dict[str, Any]: + result = copy.deepcopy(pif) + prompt_ids = list(result['input_ids']) + labels = list(result.get('labels') or []) + if labels: + # Unroll (shift RIGHT by 1): reverse the post_pipeline roll + labels = labels[-1:] + labels[:-1] + else: + labels = [-100] * len(prompt_ids) + input_ids = prompt_ids + list(new_tokens) + labels = labels + list(new_tokens) # assistant tokens trainable + result['input_ids'] = input_ids + result['labels'] = labels + result = self._invoke_post_pipeline([result])[0] + # Append assistant message with the decoded response (no special toks) + response_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True) + messages = list(result.get('messages') or []) + messages.append({'role': 'assistant', 'content': response_text}) + result['messages'] = messages + return result + + +class FakeSampler: + """Queue-driven sampler that mirrors VLLMSampler output shape.""" + + def __init__(self, template: FakeTemplate) -> None: + self.template = template + self._queue: List[Dict[str, Any]] = [] + self.sample_calls = 0 + + def queue( + self, + response_text: str, + stop_reason: str = 'stop', + logprobs: Optional[List[Any]] = None, + append_im_end: bool = True, + ) -> None: + """``response_text`` is the model output (may contain …). + ``<|im_end|>`` is appended to the encoded tokens when ``append_im_end``. 
+ ``seq.decoded`` is the raw response WITHOUT the trailing <|im_end|> + (matches vLLM's common behaviour).""" + raw = response_text + ('<|im_end|>' if append_im_end else '') + tokens = self.template.tokenizer.encode(raw, add_special_tokens=False) + self._queue.append({ + 'tokens': tokens, + 'decoded': response_text, + 'stop_reason': stop_reason, + 'logprobs': logprobs, + }) + + def sample(self, pifs, sampling_params=None): + # Batched contract: accept a list of pifs, return one + # SampleResponse per input, in order. A single-pif dict is also + # accepted for backwards compatibility with older call sites. + if isinstance(pifs, dict): + pifs = [pifs] + assert isinstance(pifs, list), ( + f'FakeSampler.sample expects a list, got {type(pifs).__name__}') + responses: List[SampleResponse] = [] + for pif in pifs: + assert self._queue, 'FakeSampler queue exhausted — scripted turns' + r = self._queue.pop(0) + self.sample_calls += 1 + new_pif = self.template.concat_input_feature(pif, r['tokens']) + seq = SampledSequence( + stop_reason=r['stop_reason'], + tokens=r['tokens'], + logprobs=r['logprobs'], + decoded=r['decoded'], + new_input_feature=new_pif, + ) + responses.append(SampleResponse(sequences=[seq])) + return responses + + +class EchoTool(Tool): + """Echoes its arguments as a JSON string.""" + + def __init__(self, name: str = 'search'): + self._name = name + + def __call__(self, tool_name: str, arguments: Dict[str, Any]) -> str: + return f'echo[{tool_name}]:{json.dumps(arguments, sort_keys=True)}' + + def tool_info(self): + return { + 'tool_name': self._name, + 'description': 'echo test tool', + 'parameters': '{}', + } + + +# ============================================================================= +# Fixtures +# ============================================================================= +@pytest.fixture +def tokenizer(): + return FakeTokenizer() + + +@pytest.fixture +def template(tokenizer): + return FakeTemplate(tokenizer) + + +@pytest.fixture +def 
sampler(template): + return FakeSampler(template) + + +@pytest.fixture +def tool_manager(): + mgr = ToolManager({}) + mgr.register(EchoTool('search')) + return mgr + + +@pytest.fixture +def make_rollout(sampler, template, tool_manager): + def _make(max_turns: int = 4, sampling_params: Optional[SamplingParams] = None): + return MultiTurnRollout( + sampler=sampler, + template=template, + tool_manager=tool_manager, + sampling_params=sampling_params or SamplingParams(), + max_turns=max_turns, + ) + return _make + + +# ============================================================================= +# Helpers +# ============================================================================= +def _count_trainable(labels: List[int]) -> int: + return sum(1 for l in labels if l != -100) + + +def _user_traj(text: str = 'hi') -> Dict[str, Any]: + return {'messages': [{'role': 'user', 'content': text}]} + + +def _tool_call_text(name: str, arguments: Dict[str, Any]) -> str: + return '' + json.dumps( + {'name': name, 'arguments': arguments}) + '' + + +# ============================================================================= +# Tests: control flow +# ============================================================================= +def test_single_turn_natural_stop(make_rollout, sampler): + """Model answers directly, no tool call → 1 turn, stop_reason='stop'.""" + sampler.queue('Hello there.', stop_reason='stop') + rollout = make_rollout(max_turns=4) + out = rollout([_user_traj()])[0] + + assert out['turns'] == 1 + assert out['stop_reason'] == 'stop' + assert out['truncated'] is False + assert sampler.sample_calls == 1 + + # Output must carry pif fields at TOP LEVEL so downstream sampler/model + # sees `input_ids` and skips re-encoding. 
+ assert 'input_ids' in out + assert 'labels' in out + assert 'attention_mask' in out + assert 'position_ids' in out + assert len(out['input_ids']) == len(out['labels']) + assert len(out['input_ids']) == len(out['attention_mask']) + + +def test_single_turn_length_stop(make_rollout, sampler): + """stop_reason='length' exits immediately without tool-call parsing.""" + sampler.queue(_tool_call_text('search', {'q': 'x'}), stop_reason='length') + rollout = make_rollout(max_turns=4) + out = rollout([_user_traj()])[0] + + # Even though the decoded text contains a , length stop must + # short-circuit BEFORE we parse / dispatch tools. + assert out['turns'] == 1 + assert out['stop_reason'] == 'length' + assert out['truncated'] is False + assert sampler.sample_calls == 1 + # No tool message should have been appended. + roles = [m['role'] for m in out['messages']] + assert 'tool' not in roles + + +def test_two_turns_one_tool_call(make_rollout, sampler): + """Turn 1 emits tool_call, turn 2 stops normally.""" + sampler.queue(_tool_call_text('search', {'q': 'weather'}), stop_reason='stop') + sampler.queue('The weather is sunny.', stop_reason='stop') + rollout = make_rollout(max_turns=4) + out = rollout([_user_traj('What is the weather?')])[0] + + assert out['turns'] == 2 + assert out['stop_reason'] == 'stop' + assert out['truncated'] is False + assert sampler.sample_calls == 2 + + roles = [m['role'] for m in out['messages']] + assert roles == ['user', 'assistant', 'tool', 'assistant'] + + # Tool response content must be what EchoTool returned (exact contract). 
+ tool_msg = out['messages'][2] + assert tool_msg['content'] == 'echo[search]:{"q": "weather"}' + + +def test_multiple_tool_calls_one_turn(make_rollout, sampler): + """Model emits TWO tool calls in one assistant turn → two tool messages.""" + decoded = (_tool_call_text('search', {'q': 'a'}) + + _tool_call_text('search', {'q': 'b'})) + sampler.queue(decoded, stop_reason='stop') + sampler.queue('Done.', stop_reason='stop') + rollout = make_rollout(max_turns=4) + out = rollout([_user_traj()])[0] + + assert out['turns'] == 2 + roles = [m['role'] for m in out['messages']] + assert roles == ['user', 'assistant', 'tool', 'tool', 'assistant'] + + +def test_max_turns_truncation(make_rollout, sampler): + """Model keeps emitting tool_calls past max_turns → truncated=True.""" + # 3 consecutive turns, all emitting tool_calls. + for i in range(5): + sampler.queue(_tool_call_text('search', {'q': f'q{i}'}), stop_reason='stop') + rollout = make_rollout(max_turns=3) + out = rollout([_user_traj()])[0] + + assert out['turns'] == 3 + assert out['truncated'] is True + assert sampler.sample_calls == 3 + # messages: user + (assistant + tool) × 3 = 7 + roles = [m['role'] for m in out['messages']] + assert roles.count('assistant') == 3 + # The last turn was cut off BEFORE the tool message was appended (bridge + # wouldn't help with no next generation) → 2 tool messages, not 3. 
+ assert roles.count('tool') == 2 + + +def test_max_turns_natural_stop_at_ceiling(make_rollout, sampler): + """Natural stop exactly on turn = max_turns → truncated=False.""" + sampler.queue(_tool_call_text('search', {'q': 'x'}), stop_reason='stop') + sampler.queue('Final answer.', stop_reason='stop') + rollout = make_rollout(max_turns=2) + out = rollout([_user_traj()])[0] + + assert out['turns'] == 2 + assert out['stop_reason'] == 'stop' + assert out['truncated'] is False + + +# ============================================================================= +# Tests: label & logprobs alignment +# ============================================================================= +def test_trainable_count_matches_total_sampled_tokens(make_rollout, sampler, tokenizer): + """The output's non-(-100) label count must equal ∑ len(seq.tokens) + over all turns. This is the load-bearing invariant for GRPO's loss mask.""" + text1 = _tool_call_text('search', {'q': 'x'}) + text2 = 'ok' + sampler.queue(text1, stop_reason='stop') + sampler.queue(text2, stop_reason='stop') + rollout = make_rollout(max_turns=4) + out = rollout([_user_traj()])[0] + + # Total sampled tokens across turns (each turn appends <|im_end|>): + n1 = len(tokenizer.encode(text1 + '<|im_end|>')) + n2 = len(tokenizer.encode(text2 + '<|im_end|>')) + expected_trainable = n1 + n2 + + assert _count_trainable(out['labels']) == expected_trainable + + +def test_logprobs_concatenated_across_turns(make_rollout, sampler, tokenizer): + """all_logprobs = concat(per-turn logprobs) with length == #trainable.""" + text1 = _tool_call_text('search', {'q': 'x'}) + text2 = 'ok' + # Build sentinel logprobs for each sampled token so we can verify order. 
+ toks1 = tokenizer.encode(text1 + '<|im_end|>') + toks2 = tokenizer.encode(text2 + '<|im_end|>') + lp1 = [[(tid, -0.1 * idx)] for idx, tid in enumerate(toks1)] + lp2 = [[(tid, -0.2 * idx)] for idx, tid in enumerate(toks2)] + + sampler.queue(text1, stop_reason='stop', logprobs=lp1) + sampler.queue(text2, stop_reason='stop', logprobs=lp2) + rollout = make_rollout(max_turns=4) + out = rollout([_user_traj()])[0] + + assert out['logprobs'] is not None + assert out['logprobs'] == lp1 + lp2 + assert len(out['logprobs']) == _count_trainable(out['labels']) + + +def test_logprobs_none_when_sampler_omits(make_rollout, sampler): + """If no turn carried logprobs, output['logprobs'] is None (not []). + Prevents GRPO from thinking logprobs are available but empty.""" + sampler.queue('bye', stop_reason='stop') + rollout = make_rollout(max_turns=2) + out = rollout([_user_traj()])[0] + assert out['logprobs'] is None + + +def test_logprobs_length_mismatch_raises(make_rollout, sampler, tokenizer): + """If sampler returns logprobs whose length ≠ token count, we raise.""" + text = 'hello' + toks = tokenizer.encode(text + '<|im_end|>') + bad_lp = [[(toks[0], -0.1)]] # length 1, tokens length > 1 + sampler.queue(text, stop_reason='stop', logprobs=bad_lp) + rollout = make_rollout(max_turns=2) + + with pytest.raises(RuntimeError, match='logprobs length'): + rollout([_user_traj()]) + + +# ============================================================================= +# Tests: output structure +# ============================================================================= +def test_pif_fields_merged_at_top_level(make_rollout, sampler): + """`input_ids` at top level ⇒ VLLMSampler will skip re-encoding.""" + sampler.queue('bye', stop_reason='stop') + rollout = make_rollout(max_turns=2) + out = rollout([_user_traj()])[0] + + # These are the fields a downstream sampler / model.forward consumes. 
+ for k in ('input_ids', 'labels', 'attention_mask', 'position_ids', 'length'): + assert k in out, f'{k} missing from top-level output' + # And NOT nested under user_data. + assert 'input_feature' not in (out.get('user_data') or {}) + + +def test_extra_trajectory_fields_pass_through(make_rollout, sampler): + """Non-encoding fields like ``images`` / ``tools`` flow through. + + We only check that the fields are preserved by VALUE (not identity), + because the real ``concat_input_feature`` does ``copy.deepcopy(pif)`` + internally — that is the sampler's concern, not this rollout's. + """ + traj = _user_traj() + traj['images'] = ['/path/to/img.png'] + traj['tools'] = [{'tool_name': 'search', 'description': '', 'parameters': '{}'}] + + sampler.queue('ok', stop_reason='stop') + rollout = make_rollout(max_turns=2) + out = rollout([traj])[0] + + assert out['images'] == ['/path/to/img.png'] + assert out['tools'] == traj['tools'] + + +# ============================================================================= +# Tests: constructor validation +# ============================================================================= +def test_rejects_none_template(sampler, tool_manager): + with pytest.raises(ValueError, match='Template'): + MultiTurnRollout(sampler=sampler, template=None, + tool_manager=tool_manager) + + +def test_rejects_none_tool_manager(sampler, template): + with pytest.raises(ValueError, match='ToolManager'): + MultiTurnRollout(sampler=sampler, template=template, + tool_manager=None) + + +def test_rejects_bad_max_turns(sampler, template, tool_manager): + with pytest.raises(ValueError, match='max_turns'): + MultiTurnRollout(sampler=sampler, template=template, + tool_manager=tool_manager, max_turns=0) + + +def test_rejects_num_samples_gt_1(sampler, template, tool_manager): + with pytest.raises(ValueError, match='num_samples'): + MultiTurnRollout( + sampler=sampler, template=template, tool_manager=tool_manager, + sampling_params=SamplingParams(num_samples=2)) + + 
+# ============================================================================= +# Tests: defensive guards +# ============================================================================= +def test_missing_new_input_feature_raises(template, tool_manager): + class BrokenSampler: + def sample(self, pifs, sampling_params=None): + if isinstance(pifs, dict): + pifs = [pifs] + seq = SampledSequence( + stop_reason='stop', tokens=[], logprobs=None, + decoded='', new_input_feature=None) + return [SampleResponse(sequences=[seq]) for _ in pifs] + + rollout = MultiTurnRollout( + sampler=BrokenSampler(), template=template, + tool_manager=tool_manager) + with pytest.raises(RuntimeError, match='new_input_feature'): + rollout([_user_traj()]) + + +def test_empty_sampler_response_raises(template, tool_manager): + class EmptySampler: + def sample(self, pifs, sampling_params=None): + return [] + + rollout = MultiTurnRollout( + sampler=EmptySampler(), template=template, + tool_manager=tool_manager) + # Batched contract: 0 responses for a batch of 1 → mismatch error. 
+ with pytest.raises(RuntimeError, match='0 responses'): + rollout([_user_traj()]) + + +def test_sample_response_no_sequences_raises(template, tool_manager): + class NoSeqSampler: + def sample(self, pifs, sampling_params=None): + if isinstance(pifs, dict): + pifs = [pifs] + return [SampleResponse(sequences=[]) for _ in pifs] + + rollout = MultiTurnRollout( + sampler=NoSeqSampler(), template=template, + tool_manager=tool_manager) + with pytest.raises(RuntimeError, match='no sequences'): + rollout([_user_traj()]) + + +# ============================================================================= +# Tests: batched / parallel rollout +# ============================================================================= +def test_empty_batch_returns_empty_list(make_rollout): + rollout = make_rollout(max_turns=2) + assert rollout([]) == [] + + +def test_batch_single_turn_two_trajectories(make_rollout, sampler): + """Two trajectories finish on turn 1 → one batched sample call.""" + sampler.queue('answer-A', stop_reason='stop') + sampler.queue('answer-B', stop_reason='stop') + rollout = make_rollout(max_turns=3) + outs = rollout([_user_traj('Q-A'), _user_traj('Q-B')]) + + assert len(outs) == 2 + # Exactly ONE batched turn, not two — the counter below is per-input. + assert sampler.sample_calls == 2 # one per item, still one turn + # FakeSampler counts per-input; the critical batching invariant is + # that MultiTurnRollout only calls sampler.sample ONCE per turn. We + # enforce this via the queue ordering + single turn. + for out in outs: + assert out['turns'] == 1 + assert out['stop_reason'] == 'stop' + assert out['truncated'] is False + + +def test_batch_different_termination_turns(make_rollout, sampler): + """Trajectory A finishes on turn 1; trajectory B needs a tool turn. + + Turn 1 batch: [A: 'done-A' stop, B: tool_call stop] → A parked. + Turn 2 batch: [B: 'done-B' stop] → only B live. 
+ """ + sampler.queue('done-A', stop_reason='stop') # A turn 1 + sampler.queue(_tool_call_text('search', {'q': 'b'}), # B turn 1 + stop_reason='stop') + sampler.queue('done-B', stop_reason='stop') # B turn 2 + rollout = make_rollout(max_turns=4) + outs = rollout([_user_traj('Q-A'), _user_traj('Q-B')]) + + assert len(outs) == 2 + # A: 1 turn, no tool. B: 2 turns, one tool. + assert outs[0]['turns'] == 1 + assert outs[1]['turns'] == 2 + roles_a = [m['role'] for m in outs[0]['messages']] + roles_b = [m['role'] for m in outs[1]['messages']] + assert 'tool' not in roles_a + assert roles_b == ['user', 'assistant', 'tool', 'assistant'] + + +def test_batch_per_trajectory_tool_manager(make_rollout, sampler, template): + """A list of ``tool_manager`` aligned with trajectories is honoured: + each trajectory dispatches through its OWN manager.""" + tm_a = ToolManager({}) + tm_a.register(EchoTool('search')) + + class TagTool(Tool): + def __init__(self, tag): + self._tag = tag + def __call__(self, tool_name, arguments): + return f'tagged[{self._tag}]:{json.dumps(arguments, sort_keys=True)}' + def tool_info(self): + return {'tool_name': 'search', 'description': '', 'parameters': '{}'} + + tm_b = ToolManager({}) + tm_b.register(TagTool('B')) + + sampler.queue(_tool_call_text('search', {'q': 'x'}), stop_reason='stop') + sampler.queue(_tool_call_text('search', {'q': 'y'}), stop_reason='stop') + sampler.queue('done-A', stop_reason='stop') + sampler.queue('done-B', stop_reason='stop') + + rollout = MultiTurnRollout( + sampler=sampler, template=template, + tool_manager=tm_a, # default (unused when per-call list supplied) + max_turns=4) + outs = rollout([_user_traj('A'), _user_traj('B')], + tool_manager=[tm_a, tm_b]) + + assert outs[0]['messages'][2]['content'] == 'echo[search]:{"q": "x"}' + assert outs[1]['messages'][2]['content'] == 'tagged[B]:{"q": "y"}' + + +def test_batch_tool_manager_list_length_mismatch(make_rollout, tool_manager): + rollout = make_rollout(max_turns=2) + with 
pytest.raises(ValueError, match='tool_manager list length'): + rollout([_user_traj('A'), _user_traj('B')], + tool_manager=[tool_manager]) # length 1 vs 2 trajectories + + +def test_single_trajectory_dict_rejected(make_rollout): + """A single ``Trajectory`` (dict) is NOT accepted — caller must wrap.""" + rollout = make_rollout(max_turns=2) + with pytest.raises(TypeError, match='List\\[Trajectory\\]'): + rollout(_user_traj()) + + +# ============================================================================= +# Tests: trace_path (JSONL per-turn observability) +# ============================================================================= +def test_trace_path_writes_one_record_per_turn_natural_stop( + tmp_path, sampler, template, tool_manager): + """Single-turn natural stop: trace file has exactly one JSON line.""" + trace = tmp_path / 'trace.jsonl' + rollout = MultiTurnRollout( + sampler=sampler, template=template, + tool_manager=tool_manager, + max_turns=4, trace_path=str(trace)) + sampler.queue('final answer', stop_reason='stop') + + outs = rollout([_user_traj('hello')]) + assert len(outs) == 1 + + lines = [l for l in trace.read_text().splitlines() if l] + assert len(lines) == 1 + rec = json.loads(lines[0]) + assert rec['turn'] == 1 + assert rec['batch_size'] == 1 + assert rec['trajectory_idx'] == 0 + assert rec['stop_reason'] == 'stop' + assert rec['decoded'] == 'final answer' + assert rec['tool_call_count'] == 0 + assert rec['done'] is True + assert rec['truncated'] is False + assert rec['trainable_tokens'] > 0 + + +def test_trace_path_captures_tool_turn_and_completion( + tmp_path, sampler, template, tool_manager): + """Two-turn rollout: one tool turn (done=False) then completion.""" + trace = tmp_path / 'trace.jsonl' + rollout = MultiTurnRollout( + sampler=sampler, template=template, + tool_manager=tool_manager, + max_turns=4, trace_path=str(trace)) + sampler.queue(_tool_call_text('search', {'q': 'x'})) + sampler.queue('done', stop_reason='stop') + + 
rollout([_user_traj('hello')]) + + lines = [l for l in trace.read_text().splitlines() if l] + assert len(lines) == 2 + turn1 = json.loads(lines[0]) + turn2 = json.loads(lines[1]) + + assert turn1['turn'] == 1 + assert turn1['tool_call_count'] == 1 + assert turn1['done'] is False + assert turn1['truncated'] is False + + assert turn2['turn'] == 2 + assert turn2['tool_call_count'] == 0 + assert turn2['done'] is True + # input_ids length must monotonically increase across turns. + assert turn2['input_ids_len'] > turn1['input_ids_len'] + + +def test_trace_path_truncates_file_on_construction( + tmp_path, sampler, template, tool_manager): + """Constructor opens the file in 'w' mode — stale data is wiped.""" + trace = tmp_path / 'trace.jsonl' + trace.write_text('STALE CONTENT SHOULD BE GONE\n') + assert trace.read_text() == 'STALE CONTENT SHOULD BE GONE\n' + + sampler.queue('ok', stop_reason='stop') + rollout = MultiTurnRollout( + sampler=sampler, template=template, + tool_manager=tool_manager, + max_turns=2, trace_path=str(trace)) + # After construction the file is empty (we truncate eagerly). + assert trace.read_text() == '' + + rollout([_user_traj('hi')]) + content = trace.read_text() + assert 'STALE' not in content + assert content.strip() # at least one record written + + +def test_trace_path_batch_emits_one_record_per_active_trajectory( + tmp_path, sampler, template, tool_manager): + """Batched rollout: each turn emits N active records (not N_total).""" + trace = tmp_path / 'trace.jsonl' + rollout = MultiTurnRollout( + sampler=sampler, template=template, + tool_manager=tool_manager, + max_turns=4, trace_path=str(trace)) + # Traj 0: stops turn 1. Traj 1: tool-calls turn 1, stops turn 2. + # Responses are consumed in batch order per turn. 
+ sampler.queue('done0', stop_reason='stop') # t1-A + sampler.queue(_tool_call_text('search', {'q': 'y'})) # t1-B + sampler.queue('done1', stop_reason='stop') # t2-B (B only) + + rollout([_user_traj('A'), _user_traj('B')]) + + lines = [json.loads(l) for l in trace.read_text().splitlines() if l] + assert len(lines) == 3 + # Turn 1 has both trajectories. + turn1 = [r for r in lines if r['turn'] == 1] + turn2 = [r for r in lines if r['turn'] == 2] + assert sorted(r['trajectory_idx'] for r in turn1) == [0, 1] + # Turn 2 has only trajectory 1 (trajectory 0 already done). + assert [r['trajectory_idx'] for r in turn2] == [1] + # batch_size is the ORIGINAL batch count (2), not active count. + assert all(r['batch_size'] == 2 for r in lines) + + +def test_trace_path_none_disables_tracing( + tmp_path, sampler, template, tool_manager): + """Default ``trace_path=None`` never touches the filesystem.""" + trace = tmp_path / 'never.jsonl' + assert not trace.exists() + + rollout = MultiTurnRollout( + sampler=sampler, template=template, + tool_manager=tool_manager, max_turns=2) + sampler.queue('ok', stop_reason='stop') + rollout([_user_traj('hi')]) + + assert rollout.trace_path is None + assert not trace.exists() + + +def test_trace_path_truncation_marked_on_max_turns( + tmp_path, sampler, template, tool_manager): + """The final record of a max-turns truncation has truncated=True.""" + trace = tmp_path / 'trunc.jsonl' + rollout = MultiTurnRollout( + sampler=sampler, template=template, + tool_manager=tool_manager, + max_turns=2, trace_path=str(trace)) + # Two tool-call turns -> the second hits max_turns cap. 
+ sampler.queue(_tool_call_text('search', {'q': 'a'})) + sampler.queue(_tool_call_text('search', {'q': 'b'})) + + rollout([_user_traj('hi')]) + + lines = [json.loads(l) for l in trace.read_text().splitlines() if l] + assert len(lines) == 2 + assert lines[0]['truncated'] is False and lines[0]['done'] is False + assert lines[1]['truncated'] is True and lines[1]['done'] is True diff --git a/tests/twinkle_agentic/test_native_chunker.py b/tests/twinkle_agentic/test_native_chunker.py new file mode 100644 index 00000000..dc1cacc8 --- /dev/null +++ b/tests/twinkle_agentic/test_native_chunker.py @@ -0,0 +1,432 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Unit tests for :class:`twinkle_agentic.chunker.native.NativeChunker`. + +Focus: chunk-size boundaries, separator priority, first-user-only scope, +lossless ``''.join`` of split outputs, and edge cases (empty, multimodal, +tool-calls, invalid config). +""" +from __future__ import annotations + +import pytest + +from twinkle_agentic.chunker.native import ( + NativeChunker, _hard_cut, _split_keep, +) +from twinkle_agentic.data_format import Chunks + + +def _u(content, role='user'): + return {'role': role, 'content': content} + + +def _join(chunks, type_='text'): + return ''.join(c['content'] for c in chunks if c.get('type') == type_) + + +# --------------------------------------------------------------------------- +# chunk_size boundaries +# --------------------------------------------------------------------------- +def test_under_chunk_size_returns_single_chunk(): + ch = NativeChunker(chunk_size=100) + out = ch({'messages': [_u('hello world')]}).chunks + assert len(out) == 1 + assert out[0]['content'] == 'hello world' + assert out[0]['role'] == 'user' + assert out[0]['type'] == 'text' + + +def test_exact_chunk_size_not_split(): + ch = NativeChunker(chunk_size=10) + out = ch({'messages': [_u('a' * 10)]}).chunks + assert [c['content'] for c in out] == ['a' * 10] + + +def 
test_one_over_chunk_size_is_split(): + ch = NativeChunker(chunk_size=10) + out = ch({'messages': [_u('a' * 11)]}).chunks + # No separator matches → hard cut; merge won't fuse (10+1 > 10) + assert len(out) == 2 + assert all(len(c['content']) <= 10 for c in out) + assert _join(out) == 'a' * 11 + + +def test_all_chunks_respect_size_limit_on_realistic_input(): + ch = NativeChunker(chunk_size=20) + text = ('hello world. ' * 50).strip() + out = ch({'messages': [_u(text)]}).chunks + assert all(len(c['content']) <= 20 for c in out) + assert _join(out) == text + + +def test_large_text_split_is_lossless_and_bounded(): + ch = NativeChunker(chunk_size=64) + text = 'The quick brown fox jumps over the lazy dog. ' * 100 + out = ch({'messages': [_u(text)]}).chunks + assert _join(out) == text + assert all(len(c['content']) <= 64 for c in out) + + +# --------------------------------------------------------------------------- +# separator priority (coarsest available wins) +# --------------------------------------------------------------------------- +def test_paragraph_split_preferred_over_sentence(): + ch = NativeChunker(chunk_size=40) + text = 'P1 sentence one. P1 sentence two.\n\nP2 sentence one. P2 sentence two.' + out = ch({'messages': [_u(text)]}).chunks + assert _join(out) == text + assert all(len(c['content']) <= 40 for c in out) + # Because paragraph boundary (18 + 2) and (35) both fit in 40, we + # expect at most 2 chunks (one per paragraph, possibly merged). + assert len(out) <= 2 + + +def test_newline_split_used_when_no_paragraph(): + ch = NativeChunker(chunk_size=10) + text = 'line1\nline2\nline3\nline4' + out = ch({'messages': [_u(text)]}).chunks + assert _join(out) == text + assert all(len(c['content']) <= 10 for c in out) + + +def test_sentence_split_used_when_no_newline(): + ch = NativeChunker(chunk_size=10) + text = 'foo bar b. qux qa bc. abc d.' 
+ out = ch({'messages': [_u(text)]}).chunks + assert _join(out) == text + assert all(len(c['content']) <= 10 for c in out) + + +def test_chinese_sentence_separator(): + ch = NativeChunker(chunk_size=8) + text = '你好世界。这是测试。再见朋友。' + out = ch({'messages': [_u(text)]}).chunks + assert _join(out) == text + assert all(len(c['content']) <= 8 for c in out) + + +def test_custom_separator_list_only(): + ch = NativeChunker(chunk_size=10, separators=['|']) + text = 'aaa|bbb|ccccccccc|dd' + out = ch({'messages': [_u(text)]}).chunks + assert _join(out) == text + assert all(len(c['content']) <= 10 for c in out) + + +def test_empty_string_sentinel_appended_automatically(): + # User omits '' → chunker must still make progress on unsplittable text + ch = NativeChunker(chunk_size=3, separators=['|']) + text = 'abcdefghij' # no '|' at all + out = ch({'messages': [_u(text)]}).chunks + assert _join(out) == text + assert all(len(c['content']) <= 3 for c in out) + + +# --------------------------------------------------------------------------- +# first-user-only constraint +# --------------------------------------------------------------------------- +def test_only_first_user_message_is_split(): + ch = NativeChunker(chunk_size=10) + long = 'a' * 100 + traj = {'messages': [ + {'role': 'system', 'content': long}, + {'role': 'user', 'content': long}, # ← split + {'role': 'assistant', 'content': long}, + {'role': 'user', 'content': long}, # ← pass-through + {'role': 'tool', 'content': long, 'tool_call_id': 'c1'}, + ]} + out = ch(traj).chunks + + # Count chunks per message by position. + system_chunks = [c for c in out if c['role'] == 'system'] + assistant_chunks = [c for c in out if c['role'] == 'assistant'] + tool_chunks = [c for c in out if c['role'] == 'tool'] + user_chunks = [c for c in out if c['role'] == 'user'] + + assert len(system_chunks) == 1 + assert len(assistant_chunks) == 1 + assert len(tool_chunks) == 1 + # First user is split into many + second user pass-through (1 chunk). 
+ assert len(user_chunks) > 2 + # And the second user chunk sits at the end of the user_chunks group + # only after the first-user splits. + assert user_chunks[-1]['content'] == long + + +def test_system_and_assistant_content_not_split(): + ch = NativeChunker(chunk_size=5) + long = 'abcdefghijklmn' + traj = {'messages': [ + {'role': 'system', 'content': long}, + {'role': 'assistant', 'content': long}, + ]} + out = ch(traj).chunks + assert len(out) == 2 + assert out[0]['content'] == long + assert out[1]['content'] == long + + +def test_trajectory_without_user_message_produces_no_split(): + ch = NativeChunker(chunk_size=5) + long = 'abcdefghij' + traj = {'messages': [ + {'role': 'system', 'content': long}, + {'role': 'assistant', 'content': long}, + ]} + out = ch(traj).chunks + assert all(len(c['content']) == len(long) for c in out) + + +# --------------------------------------------------------------------------- +# decomposition of special message parts +# --------------------------------------------------------------------------- +def test_reasoning_content_becomes_own_chunk(): + ch = NativeChunker(chunk_size=100) + traj = {'messages': [ + _u('hi'), + {'role': 'assistant', + 'reasoning_content': 'think step', + 'content': 'answer'}, + ]} + out = ch(traj).chunks + # user(hi) + assistant.reasoning + assistant.content + assert len(out) == 3 + assert out[1]['raw']['kind'] == 'reasoning_content' + assert out[1]['content'] == 'think step' + assert out[2]['content'] == 'answer' + assert 'raw' not in out[2] or 'kind' not in out[2].get('raw', {}) + + +def test_tool_calls_become_empty_text_chunks_with_kind(): + ch = NativeChunker(chunk_size=100) + traj = {'messages': [ + _u('hi'), + {'role': 'assistant', 'content': 'calling', + 'tool_calls': [ + {'tool_name': 'foo', 'arguments': '{}'}, + {'tool_name': 'bar', 'arguments': '{"x":1}'}, + ]}, + ]} + out = ch(traj).chunks + tc_chunks = [c for c in out if c.get('raw', {}).get('kind') == 'tool_call'] + assert len(tc_chunks) == 2 + 
assert tc_chunks[0]['raw']['tool_call']['tool_name'] == 'foo' + assert tc_chunks[1]['raw']['tool_call']['tool_name'] == 'bar' + # Empty content on tool_call chunks. + assert all(c['content'] == '' for c in tc_chunks) + + +def test_tool_message_preserves_tool_call_id(): + ch = NativeChunker(chunk_size=100) + traj = {'messages': [ + _u('hi'), + {'role': 'tool', 'content': 'result', 'tool_call_id': 'call-42'}, + ]} + out = ch(traj).chunks + tool_chunk = out[-1] + assert tool_chunk['role'] == 'tool' + assert tool_chunk['raw']['tool_call_id'] == 'call-42' + + +def test_multimodal_content_preserved_on_first_user(): + ch = NativeChunker(chunk_size=5) + traj = {'messages': [{ + 'role': 'user', + 'content': [ + {'type': 'text', 'text': 'describe this image'}, + {'type': 'image', 'image': 'http://x/y.png'}, + ], + }]} + out = ch(traj).chunks + text_chunks = [c for c in out if c['type'] == 'text'] + image_chunks = [c for c in out if c['type'] == 'image'] + assert len(image_chunks) == 1 + assert image_chunks[0]['content'] == 'http://x/y.png' + assert image_chunks[0]['raw'] == {'type': 'image', 'image': 'http://x/y.png'} + # Text part was split; concatenation is lossless. 
+ assert _join(text_chunks) == 'describe this image' + assert all(len(c['content']) <= 5 for c in text_chunks) + + +# --------------------------------------------------------------------------- +# edge cases +# --------------------------------------------------------------------------- +def test_empty_trajectory(): + ch = NativeChunker(chunk_size=10) + assert ch({'messages': []}).chunks == [] + assert ch({}).chunks == [] + + +def test_empty_content_string_produces_no_chunks(): + ch = NativeChunker(chunk_size=10) + assert ch({'messages': [_u('')]}).chunks == [] + + +@pytest.mark.parametrize('bad', [0, -1, -999]) +def test_invalid_chunk_size_raises(bad): + with pytest.raises(ValueError): + NativeChunker(chunk_size=bad) + + +def test_chunk_size_one_hard_cuts_all_chars(): + ch = NativeChunker(chunk_size=1) + text = 'abc' + out = ch({'messages': [_u(text)]}).chunks + assert [c['content'] for c in out] == ['a', 'b', 'c'] + + +def test_whitespace_only_text_is_preserved_losslessly(): + ch = NativeChunker(chunk_size=3) + text = ' \n\n \n' + out = ch({'messages': [_u(text)]}).chunks + assert _join(out) == text + assert all(len(c['content']) <= 3 for c in out) + + +# --------------------------------------------------------------------------- +# HotpotQA-shaped realistic payload +# --------------------------------------------------------------------------- +def test_hotpotqa_like_passage_layout(): + ch = NativeChunker(chunk_size=80) + passages = '\n\n'.join( + f'[{i}] Title_{i}: ' + 'This is sentence. ' * 6 + for i in range(1, 6) + ) + user_text = f'Question: who wrote it?\n\nContext:\n\n{passages}' + out = ch({'messages': [ + {'role': 'system', 'content': 'sys'}, + _u(user_text), + ]}).chunks + # System message is not split. + assert out[0]['role'] == 'system' and out[0]['content'] == 'sys' + # User text reconstructs losslessly. 
+ user_chunks = [c for c in out if c['role'] == 'user'] + assert _join(user_chunks) == user_text + assert all(len(c['content']) <= 80 for c in user_chunks) + + +# --------------------------------------------------------------------------- +# to_trajectory integration (non-split messages round-trip cleanly) +# --------------------------------------------------------------------------- +def test_non_split_messages_roundtrip_through_to_trajectory(): + ch = NativeChunker(chunk_size=1024) + traj = {'messages': [ + {'role': 'system', 'content': 'sys'}, + {'role': 'user', 'content': 'short question'}, + {'role': 'assistant', 'content': 'answer', + 'tool_calls': [{'tool_name': 'foo', 'arguments': '{}'}]}, + {'role': 'tool', 'content': 'result', 'tool_call_id': 'c1'}, + ]} + chunks = ch(traj) + back = chunks.to_trajectory(block_wrapper=None) + msgs = back['messages'] + assert msgs[0] == {'role': 'system', 'content': 'sys'} + assert msgs[1]['role'] == 'user' + assert msgs[1]['content'] == 'short question' + assert msgs[2]['role'] == 'assistant' + assert msgs[2]['content'] == 'answer' + assert msgs[2]['tool_calls'] == [{'tool_name': 'foo', 'arguments': '{}'}] + assert msgs[3]['role'] == 'tool' + assert msgs[3]['content'] == 'result' + assert msgs[3]['tool_call_id'] == 'c1' + + +# --------------------------------------------------------------------------- +# helper-level tests (white-box, catches regressions in primitives) +# --------------------------------------------------------------------------- +def test_split_keep_is_lossless(): + cases = [ + ('', '|'), + ('abc', '|'), + ('a|b|c', '|'), + ('|abc|', '|'), + ('|||', '|'), + ('aa..bb.', '.'), + ('hello', ''), # empty separator → single piece + ] + for text, sep in cases: + parts = _split_keep(text, sep) + assert ''.join(parts) == text, (text, sep, parts) + + +def test_hard_cut_bounds_and_lossless(): + for text, size in [('', 3), ('a', 3), ('abcde', 3), ('abcdef', 3)]: + parts = _hard_cut(text, size) + assert ''.join(parts) 
== text + assert all(len(p) <= size for p in parts) + + +def test_split_keep_keeps_separator_suffix(): + assert _split_keep('aa.bb.cc', '.') == ['aa.', 'bb.', 'cc'] + assert _split_keep('aa\n\nbb\n\ncc', '\n\n') == ['aa\n\n', 'bb\n\n', 'cc'] + + +# --------------------------------------------------------------------------- +# separator ordering / priority contract +# --------------------------------------------------------------------------- +def test_prefers_paragraph_boundary_over_period_when_both_fit(): + # Two paragraphs. Each fits in 40. The whole thing (47) does not. + ch = NativeChunker(chunk_size=40) + text = 'para one sentence. more.\n\npara two sentence.' + assert len(text) > 40 + out = ch({'messages': [_u(text)]}).chunks + # Chunker should split at '\n\n', not inside a paragraph. + assert out[0]['content'].endswith('\n\n') + assert _join(out) == text + + +# --------------------------------------------------------------------------- +# round numbering +# --------------------------------------------------------------------------- +def test_round_starts_at_zero_for_pre_user_system(): + ch = NativeChunker(chunk_size=1024) + out = ch({'messages': [ + {'role': 'system', 'content': 'you are helpful'}, + _u('hello'), + ]}).chunks + assert [c['round'] for c in out] == [0, 1] + + +def test_round_increments_on_each_user_message(): + ch = NativeChunker(chunk_size=1024) + out = ch({'messages': [ + _u('first user'), + {'role': 'assistant', 'content': 'first reply'}, + _u('second user'), + {'role': 'assistant', 'content': 'second reply'}, + _u('third user'), + ]}).chunks + rounds = [c['round'] for c in out] + # assistant msgs inherit the round of the preceding user turn. 
+ assert rounds == [1, 1, 2, 2, 3] + + +def test_round_covers_tool_responses_between_users(): + ch = NativeChunker(chunk_size=1024) + out = ch({'messages': [ + _u('query'), + {'role': 'assistant', 'content': 'calling tool'}, + {'role': 'tool', 'content': 'tool result', 'tool_call_id': 'x'}, + {'role': 'assistant', 'content': 'final'}, + ]}).chunks + assert {c['round'] for c in out} == {1} + + +def test_round_preserved_when_first_user_is_split(): + ch = NativeChunker(chunk_size=20) + long_user = 'hello world. ' * 10 # gets split + out = ch({'messages': [ + {'role': 'system', 'content': 'sys'}, + _u(long_user), + {'role': 'assistant', 'content': 'ack'}, + _u('again'), + ]}).chunks + # All pieces of the split first user share round=1, system is round=0, + # assistant inherits round=1, second user is round=2. + by_role = {} + for c in out: + by_role.setdefault(c.get('role'), []).append(c['round']) + assert set(by_role.get('system', [])) == {0} + assert set(by_role.get('assistant', [])) == {1} + # Multiple user chunks from the split share round=1. 
+ assert by_role['user'].count(1) >= 2 + assert by_role['user'][-1] == 2 From 33b8b32574a8c99d2edd4148cef30e9a9d289d70 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 9 May 2026 21:14:10 +0800 Subject: [PATCH 005/104] fix --- cookbook/rl/grpo_condensed.py | 445 +++++++++++++++++++++++++ src/twinkle_agentic/condenser/model.py | 102 +----- 2 files changed, 462 insertions(+), 85 deletions(-) diff --git a/cookbook/rl/grpo_condensed.py b/cookbook/rl/grpo_condensed.py index e69de29b..53a061d0 100644 --- a/cookbook/rl/grpo_condensed.py +++ b/cookbook/rl/grpo_condensed.py @@ -0,0 +1,445 @@ +import json +import os +import re +from typing import Any, Dict, List, Optional + +import swanlab +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, DeviceGroup, get_logger +from twinkle.advantage import GRPOAdvantage +from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.data_format import Message, SamplingParams, Trajectory +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.metric import CompletionRewardMetric +from twinkle.model import TransformersModel +from twinkle.preprocessor.base import Preprocessor +from twinkle.processor import InputProcessor +from twinkle.sampler import vLLMSampler +from twinkle.template import Qwen3_5Template +from twinkle_agentic.chunker.native import NativeChunker +from twinkle_agentic.condenser import ModelCondenser +from twinkle_agentic.reward import F1Reward, CoTReward, ToolExploreReward +from twinkle_agentic.rollout.multi_turn_condense import MultiTurnCondenseRollout +from twinkle_agentic.tools.tool_manager import ToolManager + +logger = get_logger() + +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') +USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) + +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS + +NUM_GENERATIONS = 
int(os.environ.get('NUM_GENERATIONS', 8)) +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) +LEARNING_RATE = float(os.environ.get('LR', 1e-5)) +NUM_EPOCHS = int(os.environ.get('NUM_EPOCHS', 10)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 0)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) +MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) +ADAPTER_NAME = 'default' +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000)) +LORA_RANK = int(os.environ.get('LORA_RANK', 16)) + +MAX_TURNS = int(os.environ.get('MAX_TURNS', 6)) +CHUNK_SIZE = int(os.environ.get('CHUNK_SIZE', 1024)) + +HOTPOTQA_NUM_PROC = int(os.environ.get('HOTPOTQA_NUM_PROC', 16)) +HOTPOTQA_MAX_LENGTH = int(os.environ.get('HOTPOTQA_MAX_LENGTH', 64000)) + +# Reward weights +F1_REWARD_WEIGHT = float(os.environ.get('F1_REWARD_WEIGHT', 1.0)) +COT_REWARD_WEIGHT = float(os.environ.get('COT_REWARD_WEIGHT', 0.5)) +TOOL_BONUS_WEIGHT = float(os.environ.get('TOOL_BONUS_WEIGHT', 0.1)) + +WRONG_IDS_FILE = os.environ.get('WRONG_IDS_FILE', '') + +_ROLLOUT_TRACE_PATH = os.environ.get('ROLLOUT_TRACE_PATH', 'rollout_trace.jsonl') + +SYSTEM_PROMPT = """You are a careful multi-hop QA assistant. + +## Compressed Context +The context you receive is **compressed**. Each paragraph is wrapped in \ +... and displayed as a Markdown summary with three sections: +- **Summary**: one-sentence overview of the block +- **Key Facts**: bulleted salient facts +- **More**: keywords hinting at details hidden in the full text + +Because the context is compressed, critical details may not be immediately \ +visible. You are strongly encouraged to call the `extract_condensed` tool \ +to expand blocks that likely contain the answer. + +## Workflow + +### Phase 1 — Scan and Decide +Step 1: Read each block's Summary and Key facts to get an overview. 
+Step 2: Check the More keywords to judge whether hidden details are needed. +Step 3: Decide which blocks to expand, then call `extract_condensed`. + +### Phase 2 — Reason and Answer +After the tool returns the full text, continue stepping through the evidence: +Step N: From block X, I learn that [fact A]. +Step N+1: From block Y, I need to call `extract_condensed` to get more information, because this block is related to... +Step N+2: Combining these, the answer is ... +\\boxed{answer} + +You may call `extract_condensed` several times to expand more blocks if the information is not enough, only answer the question if you are sure about the facts. +The `blocks` parameter accepts either a single integer (e.g. `3`) or a list of integers (e.g. `[1, 3]`) to expand several blocks in one call. + +## Tool Call Format + + + +[1, 3] + + + + +## Output Format +End your final response with \\boxed{answer}, e.g. \\boxed{Delhi}. +Keep the boxed text short: a name, entity, date, or "yes"/"no". +Answers not inside \\boxed{} will not be scored.""" + + +_F1_REWARD: Optional[F1Reward] = F1Reward() +_COT_REWARD: Optional[CoTReward] = CoTReward() +_TOOL_EXPLORE_REWARD: Optional[ToolExploreReward] = ToolExploreReward() + + +def compute_rewards(trajectories: List[Dict[str, Any]]): + f1 = _F1_REWARD(trajectories) + cot = _COT_REWARD(trajectories) + tool_explore = _TOOL_EXPLORE_REWARD(trajectories) + total = [ + F1_REWARD_WEIGHT * a + COT_REWARD_WEIGHT * c + TOOL_BONUS_WEIGHT * te + for a, c, te in zip(f1, cot, tool_explore) + ] + return total, f1, cot, tool_explore + + +class HotpotQAProcessor(Preprocessor): + def __init__(self, system: str = SYSTEM_PROMPT, levels=None): + self.system = system + self.levels = levels + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = [r for r in rows if r is not None] + rows = self.map_row_to_col(rows) + return rows + + @staticmethod + def 
_format_context(context: Dict[str, Any]) -> str: + titles = context.get('title', []) or [] + sentences = context.get('sentences', []) or [] + lines = [] + for i, (title, sents) in enumerate(zip(titles, sentences), start=1): + if isinstance(sents, list): + body = ' '.join(s.strip() for s in sents if s and s.strip()) + else: + body = str(sents).strip() + lines.append(f'[{i}] {title}: {body}') + return '\n\n'.join(lines) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]: + if self.levels is not None and (row.get('level') or '').strip().lower() not in self.levels: + return None + question = row['question'] + answer = row.get('answer', '') or '' + context_block = self._format_context(row.get('context', {}) or {}) + user_msg = f'Question: {question}\n\nContext:\n\n{context_block}' + messages = [ + Message(role='system', content=self.system), + Message(role='user', content=user_msg), + ] + return Trajectory(messages=messages, user_data=[('ground_truth', answer.strip())]) + + +def create_hotpotqa_dataset() -> Dataset: + dataset = Dataset() + dataset.add_dataset(DatasetMeta( + 'hf://hotpotqa/hotpot_qa', subset_name='fullwiki', split='train')) + + _wrong_ids_path = WRONG_IDS_FILE.strip() + if _wrong_ids_path: + with open(_wrong_ids_path, 'r', encoding='utf-8') as fh: + _ids = frozenset(ln.strip() for ln in fh if ln.strip()) + if _ids: + _key = next(iter(dataset.datasets.keys())) + _before = len(dataset.datasets[_key]) + dataset.datasets[_key] = dataset.datasets[_key].filter( + lambda row: row.get('id') in _ids) + dataset.dataset = dataset.datasets[_key] + logger.info(f'[WRONG_IDS_FILE] {_wrong_ids_path}: {_before} -> {len(dataset.dataset)} rows') + + dataset.set_template( + 'Qwen3_5Template', model_id=MODEL_ID, max_length=HOTPOTQA_MAX_LENGTH, + truncation_strategy='delete', enable_thinking=False) + _HOTPOTQA_COLS = ['id', 'question', 'answer', 'type', 'level', + 'supporting_facts', 'context'] + dataset.map(HotpotQAProcessor(system=SYSTEM_PROMPT, 
levels=['hard']), remove_columns=_HOTPOTQA_COLS) + return dataset + + +# Matches a LaTeX ``\boxed{...}`` final-answer marker — used to flag +# rollouts that never committed an answer. Brace-balanced is overkill for +# a logging heuristic; a non-greedy ``[^}]*`` is good enough. +_BOXED_RE = re.compile(r'\\boxed\{[^}]*\}') + + +def _last_assistant_text(trajectory: Dict[str, Any]) -> Optional[str]: + """Return the text of the last ``assistant`` message, or ``None``.""" + for m in reversed(trajectory.get('messages', [])): + if m.get('role') == 'assistant': + return m.get('content') + return None + + +def _compute_rollout_diagnostics( + trajectories: List[Dict[str, Any]], + n_turns_per_rollout: List[int], + per_rollout_completion_length: List[int], +) -> Dict[str, float]: + """Aggregate rollout diagnostics for swanlab logging. + + All inputs are already flat: + * ``trajectories[i]`` is the merged trajectory dict returned by + :class:`MultiTurnCondenseRollout` (contains ``messages``, + ``input_ids``, ``labels``, ``turns`` at top level). + * ``n_turns_per_rollout[i] == trajectories[i]['turns']``. + * ``per_rollout_completion_length[i]`` == number of trainable + tokens in the trajectory (labels != -100). + """ + out: Dict[str, float] = {} + if n_turns_per_rollout: + out['avg_turns'] = sum(n_turns_per_rollout) / len(n_turns_per_rollout) + + # ``non_trainable_tokens`` is the longest non-trainable prefix across + # the batch: ``len(input_ids) - sum(1 for l in labels if l != -100)``. + # Tracks how much the condensed context + system prompt is eating the + # context budget (it does NOT equal the first-turn prompt length + # because multi-turn runs also contribute non-trainable tokens from + # the ``tool`` observations between assistant turns). 
+ _max_non_trainable = 0 + for t, comp_len in zip(trajectories, per_rollout_completion_length): + ids = t.get('input_ids') or [] + non_trainable = max(0, len(ids) - int(comp_len or 0)) + if non_trainable > _max_non_trainable: + _max_non_trainable = non_trainable + out['non_trainable_tokens'] = _max_non_trainable + + if trajectories: + tool_counts = [ + sum(len(m.get('tool_calls') or []) + for m in t.get('messages', []) if m.get('role') == 'assistant') + for t in trajectories] + out['avg_tool_calls'] = sum(tool_counts) / len(tool_counts) + out['tool_use_rate'] = sum(1 for c in tool_counts if c > 0) / len(tool_counts) + n_no_boxed = sum( + 0 if _BOXED_RE.search(_last_assistant_text(t) or '') else 1 + for t in trajectories) + out['no_boxed_rate'] = n_no_boxed / len(trajectories) + return out + + +def main(): + swanlab.init(project='twinkle') + + device_groups = [ + DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), + ] + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, + groups=device_groups, lazy_collect=False) + + logger.info('Building HotpotQA dataset') + _prebuilt_dataset = create_hotpotqa_dataset() + logger.info('Dataset ready: %d rows', len(_prebuilt_dataset)) + + GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS + batches_per_epoch = max(1, len(_prebuilt_dataset) // GLOBAL_BATCH_SIZE) + EXPECTED_AVG_TURNS = int(os.environ.get('EXPECTED_AVG_TURNS', 3)) + optim_steps_per_batch = max(1, (GLOBAL_BATCH_SIZE * NUM_GENERATIONS * EXPECTED_AVG_TURNS + + MINI_BATCH_SIZE - 1) // MINI_BATCH_SIZE) + steps_per_epoch = batches_per_epoch * optim_steps_per_batch + derived_total_steps = NUM_EPOCHS * steps_per_epoch + total_steps = min(MAX_STEPS, derived_total_steps) if MAX_STEPS > 0 
else derived_total_steps + logger.info('Training horizon: %d steps (%d epochs × %d batches × %d steps/batch)', + total_steps, NUM_EPOCHS, batches_per_epoch, optim_steps_per_batch) + + lora_config = LoraConfig( + target_modules='all-linear', r=LORA_RANK, + lora_alpha=LORA_RANK * 2, lora_dropout=0.05) + + if USE_MEGATRON: + from twinkle.model.megatron import MegatronModel + model = MegatronModel( + model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model', + mixed_precision='bf16', variable_seq_lengths=True) + else: + model = TransformersModel( + model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model') + + model.add_adapter_to_model(ADAPTER_NAME, lora_config, + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + if USE_MEGATRON: + model.set_optimizer('default', lr=LEARNING_RATE) + model.set_lr_scheduler('default', lr_decay_steps=total_steps, max_lr=LEARNING_RATE) + else: + model.set_optimizer('AdamW', lr=LEARNING_RATE) + model.set_lr_scheduler('CosineAnnealingLR', T_max=total_steps, eta_min=0) + + model.set_loss('GRPOLoss', epsilon=0.2) + model.set_processor(InputProcessor, padding_free=True) + model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False, max_length=HOTPOTQA_MAX_LENGTH) + + model.add_metric('GRPOMetric', is_training=True) + + sampler = vLLMSampler( + model_id=MODEL_ID, + engine_args={ + 'gpu_memory_utilization': 0.8, 'max_model_len': 32768, + 'max_lora_rank': 32, 'enable_lora': True, + 'enable_tower_connector_lora': True, + }, + device_mesh=sampler_mesh, remote_group='sampler') + sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False, max_length=HOTPOTQA_MAX_LENGTH) + rollout_template = Qwen3_5Template( + MODEL_ID, max_length=HOTPOTQA_MAX_LENGTH, enable_thinking=False) + + ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) + # ``passage_boundary_re`` keeps each HotpotQA passage (``[N] Title: ...``) + # atomic inside a single chunk — short passages are emitted as-is + # 
and are NEVER merged across boundaries, so every ```` + # after condensation corresponds to exactly one passage. + chunker = NativeChunker( + chunk_size=CHUNK_SIZE, + passage_boundary_re=r'^\[\d+\]\s+') + condenser = ModelCondenser( + sampler=sampler, + compression_ratio=4.0, + sampling_params=SamplingParams( + max_tokens=256, num_samples=1, temperature=0.4, top_p=0.9), + # HotpotQA passages are often short; 50 keeps almost all passages + # eligible for compression while still skipping single-sentence + # blurbs that compress poorly. + min_chars=50, + # Compress with the frozen base model so the training LoRA + # cannot drift the summarization policy mid-training (closed-loop + # drift). + use_base_model=True, + ) + + dataloader = DataLoader( + dataset=lambda: _prebuilt_dataset, + batch_size=GLOBAL_BATCH_SIZE, min_batch_size=GLOBAL_BATCH_SIZE) + + advantage_fn = GRPOAdvantage() + metrics = CompletionRewardMetric() + sampling_params = SamplingParams( + max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, + temperature=1.0, top_p=0.95, stop=['']) + + rollout = MultiTurnCondenseRollout( + sampler=sampler, + template=rollout_template, + tool_manager=ToolManager(), + chunker=chunker, + condenser=condenser, + sampling_params=sampling_params, + max_turns=MAX_TURNS, + trace_path=_ROLLOUT_TRACE_PATH or None, + ) + + optim_step = 0 + logger.info('Starting HotpotQA GRPO training (LLM condenser variant)') + + def _epoch_cycle(dl, n_epochs): + for ep in range(1, n_epochs + 1): + logger.info(f'=== Epoch {ep}/{n_epochs} (step={optim_step}/{total_steps}) ===') + for batch in dl: + yield batch + + for batch in _epoch_cycle(dataloader, NUM_EPOCHS): + if optim_step >= total_steps: + break + + metrics.reset() + expand_prompts = [p for prompt in batch for p in [prompt] * NUM_GENERATIONS] + + ckpt_manager.sync_weights(merge_and_sync=False) + sampler.reset_prefix_cache() + + # Batched multi-turn rollout with chunk+condense pre-processing. 
+ # Each returned trajectory is a flat dict containing ``messages``, + # ``input_ids``, ``labels``, ``attention_mask``, ``position_ids``, + # ``turns``, ``logprobs``, ``stop_reason``, ``truncated``. + all_trajectories: List[Dict[str, Any]] = rollout(expand_prompts) + n_turns_per_rollout = [int(t.get('turns') or 0) for t in all_trajectories] + per_rollout_completion_length = [ + sum(1 for l in (t.get('labels') or []) if l != -100) + for t in all_trajectories] + + total_rewards, f1_rewards, cot_rewards, tool_explore_rewards = \ + compute_rewards(all_trajectories) + + metrics.accumulate( + completion_lengths=per_rollout_completion_length, + rewards={'total': total_rewards, 'f1': f1_rewards, + 'cot': cot_rewards, 'tool_explore': tool_explore_rewards}) + + rollout_advantages = advantage_fn( + total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist() + + all_input_data: List[Any] = [] + all_old_logps: List[List[float]] = [] + advantages: List[float] = [] + for t, adv in zip(all_trajectories, rollout_advantages): + all_input_data.append(t) + all_old_logps.append([lp[0][1] for lp in (t.get('logprobs') or [])]) + advantages.append(adv) + + total_completions = len(all_input_data) + aligned_completions = (total_completions // MODEL_GPUS) * MODEL_GPUS + if aligned_completions < total_completions: + logger.info( + '[dp-align] dropping %d tail sample(s): total=%d -> aligned=%d (dp=%d)', + total_completions - aligned_completions, + total_completions, aligned_completions, MODEL_GPUS) + for mb_start in range(0, aligned_completions, MINI_BATCH_SIZE): + mb_end = min(mb_start + MINI_BATCH_SIZE, aligned_completions) + model.forward_backward( + inputs=all_input_data[mb_start:mb_end], + old_logps=all_old_logps[mb_start:mb_end], + advantages=advantages[mb_start:mb_end], + micro_batch_size=MICRO_BATCH_SIZE) + model.clip_grad_and_step() + optim_step += 1 + if optim_step >= total_steps: + break + if optim_step % SAVE_STEPS == 0: + 
model.save(f'hotpotqa-grpo-tools-llmcondense-checkpoint-{optim_step}') + + log_dict = metrics.calculate() + log_dict.update(model.calculate_metric(is_training=True)) + log_dict.update(_compute_rollout_diagnostics( + all_trajectories, n_turns_per_rollout, per_rollout_completion_length)) + swanlab.log(log_dict) + metrics.reset() + logger.info(f'[Step {optim_step}/{total_steps}] {log_dict}') + + logger.info(f'Training completed. optim_steps={optim_step}') + model.save('hotpotqa-grpo-tools-llmcondense-final') + + +if __name__ == '__main__': + main() diff --git a/src/twinkle_agentic/condenser/model.py b/src/twinkle_agentic/condenser/model.py index 404cde2d..f4cae3a5 100644 --- a/src/twinkle_agentic/condenser/model.py +++ b/src/twinkle_agentic/condenser/model.py @@ -12,7 +12,7 @@ import math import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from twinkle_agentic.condenser.base import Condenser from twinkle_agentic.data_format import Chunk, Chunks @@ -27,11 +27,6 @@ def _sampling_params_cls(): from twinkle.data_format.sampling import SamplingParams return SamplingParams -# Markdown headers emitted by the condenser. -_SUMMARY_HEADER = '## Summary' -_FACTS_HEADER = '## Key Facts' -_MORE_HEADER = '## More' - _DEFAULT_SYSTEM_PROMPT = ( 'You are a precise text compression assistant. Summarize the user' ' passage into the required markdown structure without inventing' @@ -164,15 +159,19 @@ def __call__(self, chunks: Chunks, **kwargs) -> Chunks: trajectories = [ self._build_trajectory(c['content'], b) for _, c, b in batch ] + actual_len = len(trajectories) + # Pad to batch_size so distributed samplers (DP slice) never + # receive fewer inputs than expected. 
+ if actual_len < self.batch_size and actual_len > 0: + pad_traj = trajectories[-1] + trajectories.extend( + [pad_traj] * (self.batch_size - actual_len)) sp = self._build_sampling_params(max(b for _, _, b in batch)) sample_kwargs: Dict[str, Any] = {'sampling_params': sp} if self.use_base_model: sample_kwargs['use_base_model'] = True responses = self.sampler.sample(trajectories, **sample_kwargs) - if len(responses) != len(batch): - raise RuntimeError( - f'sampler returned {len(responses)} responses for ' - f'{len(batch)} inputs') + responses = responses[:actual_len] for (i, c, budget), resp in zip(batch, responses): raw_text = self._pick_decoded(resp) compressed = self._postprocess(raw_text, budget, c['content']) @@ -249,88 +248,21 @@ def _pick_decoded(response) -> str: return decoded or '' def _postprocess(self, raw: str, budget: int, original: str) -> str: + """Strip code fences and clamp to budget via word-boundary truncation. + + The model is prompted to produce structured markdown (## Summary, + ## Key Facts, ## More). We trust the output as-is and only enforce + the character budget — no section parsing or re-formatting. + """ text = _strip_code_fences(raw).strip() - sections = _parse_markdown_sections(text) - formatted = _format_sections(sections, fallback=text) - if formatted and len(formatted) <= budget: - return formatted - # Progressive drop on a *copy*: More → Key Facts → Summary. Keep - # the original ``sections`` intact for the body-only fallback. - remaining = dict(sections) - for drop in ('more', 'facts', 'summary'): - remaining.pop(drop, None) - reduced = _format_sections(remaining, fallback='') - if reduced and len(reduced) <= budget: - return reduced - # Even "## Summary\n" cannot fit — the header alone eats the - # budget. Clamp the most informative *body* (no header) so the user - # still gets meaningful content instead of dangling hash marks. 
- for key in ('summary', 'facts', 'more'): - body = sections.get(key) - if body: - clamped = _clamp_to_budget(body, budget) - if clamped: - return clamped - # No parsable sections at all — clamp the stripped raw text - # (or the original passage as a last resort). + if text and len(text) <= budget: + return text return _clamp_to_budget(text or original, budget) # --------------------------------------------------------------------------- # helpers (pure functions) # --------------------------------------------------------------------------- -_SECTION_RE = re.compile( - r'^[ \t]*#{1,6}[ \t]*(?P
summary|key[ \t]*facts?|more)[ \t]*$', - re.IGNORECASE | re.MULTILINE, -) -_SECTION_KEYS = { - 'summary': 'summary', - 'key fact': 'facts', - 'key facts': 'facts', - 'keyfact': 'facts', - 'keyfacts': 'facts', - 'more': 'more', -} -_HEADER_ORDER: Tuple[Tuple[str, str], ...] = ( - ('summary', _SUMMARY_HEADER), - ('facts', _FACTS_HEADER), - ('more', _MORE_HEADER), -) - - -def _parse_markdown_sections(text: str) -> Dict[str, str]: - """Extract ``{summary, facts, more}`` sections from ``text``. - - Last-writer wins on duplicate headers (e.g. the model repeats - ``## Summary`` twice — we keep the later body). - """ - if not text: - return {} - matches = list(_SECTION_RE.finditer(text)) - out: Dict[str, str] = {} - for i, m in enumerate(matches): - header = re.sub(r'\s+', ' ', m.group('header').strip().lower()) - key = _SECTION_KEYS.get(header) - if key is None: - continue - start = m.end() - end = matches[i + 1].start() if i + 1 < len(matches) else len(text) - body = text[start:end].strip() - if body: - out[key] = body - return out - - -def _format_sections(sections: Dict[str, str], *, fallback: str = '') -> str: - parts = [ - f'{header}\n{sections[key]}' for key, header in _HEADER_ORDER - if sections.get(key) - ] - if parts: - return '\n\n'.join(parts) - return fallback - - def _strip_code_fences(text: str) -> str: """Unwrap a leading/trailing triple-backtick fence if present.""" stripped = text.strip() From bbed39d04a390f72a701590cadf8507681b298fa Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 10 May 2026 17:52:28 +0800 Subject: [PATCH 006/104] fix --- cookbook/rl/grpo_condensed.py | 57 +- src/twinkle_agentic/condenser/model.py | 586 +++++++++++++----- .../rollout/multi_turn_condense.py | 2 + tests/twinkle_agentic/test_model_condenser.py | 59 +- 4 files changed, 500 insertions(+), 204 deletions(-) diff --git a/cookbook/rl/grpo_condensed.py b/cookbook/rl/grpo_condensed.py index 53a061d0..0467365c 100644 --- a/cookbook/rl/grpo_condensed.py +++ 
b/cookbook/rl/grpo_condensed.py @@ -64,33 +64,46 @@ SYSTEM_PROMPT = """You are a careful multi-hop QA assistant. -## Compressed Context -The context you receive is **compressed**. Each paragraph is wrapped in \ -... and displayed as a Markdown summary with three sections: -- **Summary**: one-sentence overview of the block -- **Key Facts**: bulleted salient facts -- **More**: keywords hinting at details hidden in the full text - -Because the context is compressed, critical details may not be immediately \ -visible. You are strongly encouraged to call the `extract_condensed` tool \ -to expand blocks that likely contain the answer. +## Context Format (Mixed) +The context you receive is a **mix of two forms**: + +1. **Compressed blocks** — long passages wrapped in `...`, \ + displayed as a Markdown digest in **telegraphic style** (no \ + articles / "is" / "are"; colons and commas mean "is" / "has") \ + with up to three sections: + - **Summary**: one short phrase (≤ 15 words), NOT a full sentence + - **Key Facts**: up to 4 short bullets (each ≤ 10 words) + - **More**: 5–8 comma-separated keywords hinting at details hidden in the full text + Reading example: `India: 7th largest by area. Borders: Pakistan, \ + China.` means "India is the 7th largest country by area and \ + shares borders with Pakistan and China." +2. **Raw passages** — short passages shown inline as plain text (e.g. \ + `[K] Title: ...`) **without** any `` wrapping. These are already \ + the full text; nothing is hidden. + +Only the ``-wrapped blocks are compressed and can be expanded. \ +Do **not** try to extract raw passages — they have no block id and are \ +already complete. ## Workflow ### Phase 1 — Scan and Decide -Step 1: Read each block's Summary and Key facts to get an overview. -Step 2: Check the More keywords to judge whether hidden details are needed. -Step 3: Decide which blocks to expand, then call `extract_condensed`. 
+Step 1: Read each compressed block's Summary and Key Facts, and read raw \ +passages directly, to get an overview. +Step 2: For compressed blocks, check the More keywords to judge whether \ +hidden details are needed. +Step 3: Decide which compressed blocks to expand, then call \ +`extract_condensed` with their block ids. Raw passages need no extraction. ### Phase 2 — Reason and Answer After the tool returns the full text, continue stepping through the evidence: -Step N: From block X, I learn that [fact A]. +Step N: From block X (or raw passage [K]), I learn that [fact A]. Step N+1: From block Y, I need to call `extract_condensed` to get more information, because this block is related to... Step N+2: Combining these, the answer is ... \\boxed{answer} You may call `extract_condensed` several times to expand more blocks if the information is not enough, only answer the question if you are sure about the facts. -The `blocks` parameter accepts either a single integer (e.g. `3`) or a list of integers (e.g. `[1, 3]`) to expand several blocks in one call. +The `blocks` parameter accepts either a single integer (e.g. `3`) or a list of integers (e.g. `[1, 3]`) to expand several blocks in one call. Only pass ids that actually appear as `` in the context. ## Tool Call Format @@ -326,16 +339,11 @@ def main(): passage_boundary_re=r'^\[\d+\]\s+') condenser = ModelCondenser( sampler=sampler, - compression_ratio=4.0, + compression_ratio=2.0, sampling_params=SamplingParams( - max_tokens=256, num_samples=1, temperature=0.4, top_p=0.9), - # HotpotQA passages are often short; 50 keeps almost all passages - # eligible for compression while still skipping single-sentence - # blurbs that compress poorly. - min_chars=50, - # Compress with the frozen base model so the training LoRA - # cannot drift the summarization policy mid-training (closed-loop - # drift). 
+ max_tokens=1024, num_samples=1, temperature=0.4, top_p=0.9), + min_chars=200, + template=rollout_template, use_base_model=True, ) @@ -348,7 +356,6 @@ def main(): sampling_params = SamplingParams( max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, temperature=1.0, top_p=0.95, stop=['']) - rollout = MultiTurnCondenseRollout( sampler=sampler, template=rollout_template, diff --git a/src/twinkle_agentic/condenser/model.py b/src/twinkle_agentic/condenser/model.py index f4cae3a5..0870823f 100644 --- a/src/twinkle_agentic/condenser/model.py +++ b/src/twinkle_agentic/condenser/model.py @@ -1,95 +1,220 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """LLM-backed passage condenser. -Delegates compression to a :class:`twinkle.sampler.base.Sampler`. For -each eligible chunk, builds a compression prompt, samples from the -LLM, parses the markdown response into ``## Summary / ## Key Facts / -## More`` sections, and strictly clamps the final output to -``ceil(len(input) / compression_ratio)`` characters via progressive -section-drop + word-boundary truncation. +Pipeline +-------- +``Chunks`` → filter eligible chunks → batched ``Sampler.sample(...)`` → +strip code fences → boundary-aware character-budget clamp → ``Chunks`` +with ``raw.condensed=True`` (so :meth:`Chunks.to_trajectory` later +wraps them in ````). + +The compression prompt asks for up to three markdown sections +(``## Summary / ## Key Facts / ## More``) written in **telegraphic +style** (no articles / copulas / filler) with per-section length +hints. Telegraphic output is ~2–3× denser than natural-prose summaries +and is critical under tight compression ratios. The output is **not** +parsed — sections pass through verbatim. The character budget is a +safety net only; the prompt encourages the model to self-shorten and +drop ``## More`` first, so truncation rarely needs to fire. 
""" from __future__ import annotations import math import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple) from twinkle_agentic.condenser.base import Condenser from twinkle_agentic.data_format import Chunk, Chunks if TYPE_CHECKING: # only used for type hints, keep runtime deps minimal - from twinkle.data_format import SamplingParams, Trajectory - from twinkle.sampler.base import Sampler + from twinkle.data_format import SamplingParams, Trajectory # noqa: F401 + from twinkle.sampler.base import Sampler # noqa: F401 + + +_SECTION_SCHEMA = ( + 'Purpose: produce a compact retrieval index. The reader skims it to' + ' decide whether — and on what topic — to fetch the full text.' + ' Every token must carry unique, non-recoverable information.\n\n' + 'Output EXACTLY this skeleton — never rename, merge, or add sections;' + ' stop immediately after the Topics line:\n\n' + '## Summary\n' + '<≤{summary_words} words. Subject + full naming hierarchy' + ' (family→genus→species; person→role→era; org→function→head).' + ' Identity and classification ONLY.\n' + ' PROHIBITED in Summary: any number, rank ("7th largest",' + ' "most populous", "oldest"), size, area, range, or border fact.' + ' Every such item must move to Key Facts, no exceptions.>\n\n' + '## Key Facts\n' + '<0–{max_bullets} bullets, ≤{bullet_words} words each,' + ' non-redundant with Summary. Priority:\n' + ' (1) Verbatim numbers copied from the passage' + ' ("3287263 km² area", "7516.6 km coastline").\n' + ' (2) "N