def _hash_id(prefix: str, content: str) -> str:
    """Build a deterministic row id of the form ``{prefix}__{digest16}``.

    ``digest16`` is the first 16 hex characters of the MD5 digest of the
    UTF-8 encoded ``content``, so identical content always yields the same id.
    The digest is truncated to 64 bits: collisions are unlikely for text
    corpora of this size, but they are not impossible.
    """
    payload = content.encode('utf-8')
    short_digest = hashlib.md5(payload).hexdigest()[:16]
    return f'{prefix}__{short_digest}'
class GithubCodeProcessor(Preprocessor):
    """github-code row → ``{id, source, messages}``, sampled uniformly by code length.

    ``[length_min, length_max)`` is split into ``n_buckets`` equal-width buckets,
    each with a quota of ``target // n_buckets`` rows (at least 1). A row is
    dropped when its length is outside the range or its bucket is already full,
    which yields roughly ``target`` rows whose lengths are evenly spread.

    The bucket counters live on the instance, so the quota only holds while a
    single instance sees every batch (single-process batched map); with
    ``num_proc > 1`` the counters are not shared and the sampling breaks down.
    """

    def __init__(self, target: int = 30000, length_min: int = 500,
                 length_max: int = 40000, n_buckets: int = 30):
        """Store the length range and derive the per-bucket quota.

        Raises:
            ValueError: if ``n_buckets`` is not positive.
        """
        if n_buckets <= 0:
            # Used to surface as an opaque ZeroDivisionError for 0, and as an
            # empty counter list (every lookup failing) for negative values.
            raise ValueError(f'n_buckets must be positive, got {n_buckets}')
        self.length_min = length_min
        self.length_max = length_max
        self.n_buckets = n_buckets
        self.bucket_quota = max(1, target // n_buckets)
        self.bucket_count = [0] * n_buckets

    def _bucket(self, n: int) -> int:
        """Return the bucket index for code length ``n``, or ``-1`` if out of range."""
        if n < self.length_min or n >= self.length_max:
            return -1
        span = self.length_max - self.length_min
        # Multiply before dividing and use floor division so bucket edges are
        # exact. The float form ``int(offset / span * n_buckets)`` can land one
        # bucket low when the quotient is not representable
        # (e.g. 29 / 100 * 100 == 28.999999999999996).
        idx = int((n - self.length_min) * self.n_buckets // span)
        return min(idx, self.n_buckets - 1)

    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        """Keep the rows whose code length falls into a bucket that still has quota."""
        out: List[Dict[str, Any]] = []
        for row in self.map_col_to_row(rows):
            code = row.get('code') or ''
            if not isinstance(code, str):
                continue
            b = self._bucket(len(code))
            if b < 0 or self.bucket_count[b] >= self.bucket_quota:
                continue
            # Count the row against its bucket before emitting it.
            self.bucket_count[b] += 1
            lang = row.get('language') or 'unknown'
            out.append({
                'id': _hash_id(f'github_code__{lang}', code),
                'source': 'github-code',
                'messages': [{'role': 'assistant', 'content': code}],
            })
        return self.map_row_to_col(out, keys=['id', 'source', 'messages'])
"""CJK chars count double; threshold 500 ≈ 500 Chinese chars ≈ 1000 Latin chars.""" + cjk = sum(1 for c in text if '\u4e00' <= c <= '\u9fff' or '\u3000' <= c <= '\u303f') + return cjk * 2 + (len(text) - cjk) + + +def _extract_content(msg: dict) -> str: + """Extract text content from a message dict, handling multimodal list-content.""" + content = msg.get('content') + if isinstance(content, list): + content = '\n'.join( + p.get('text', '') if isinstance(p, dict) else str(p) for p in content) + if not isinstance(content, str): + return '' + return content.strip() + + +class PassageExplodeProcessor(Preprocessor): + """Explode multi-turn messages into individual long passages for compression distillation.""" + + def __init__(self, source: str): + self.source = source + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + messages = row.get('messages') + if isinstance(messages, str): + try: + messages = json.loads(messages) + except (ValueError, TypeError): + continue + if not isinstance(messages, list): + continue + for msg in messages: + if not isinstance(msg, dict): + continue + role = msg.get('role') or '' + if role == 'system': + continue + content = _extract_content(msg) + if not content or _effective_len(content) < _MIN_PASSAGE_LEN: + continue + out.append({ + 'id': _hash_id(self.source, content), + 'source': self.source, + 'messages': [{'role': 'assistant', 'content': content}], + }) + return self.map_row_to_col(out, keys=['id', 'source', 'messages']) + + +# ===== Reasoning / CoT datasets — explode query and assistant separately ===== +_THINK_RE = re.compile(r'(.*?)', re.DOTALL) + + +class CotExplodeProcessor(Preprocessor): + """Base for CoT datasets: explode query and full assistant content as separate passages.""" + + def _extract_rows(self, rows: List[Dict[str, Any]]) -> List[tuple]: + """Subclass returns list of (query, cot, response) tuples.""" + 
class OpusReasoningProcessor(CotExplodeProcessor):
    """Map Opus-4.6-Reasoning rows: ``problem`` → query, ``thinking`` → cot, ``solution`` → response."""

    def _extract_rows(self, rows):
        """Yield ``(query, cot, response, source)`` for rows with a query and a response."""
        source = 'Opus-4.6-Reasoning-3000x-filtered'
        for record in rows:
            # Missing or None fields collapse to '' before stripping.
            question, thinking, answer = (
                (record.get(column) or '').strip()
                for column in ('problem', 'thinking', 'solution'))
            if question and answer:
                yield question, thinking, answer, source
reasoning field.""" + + def _extract_rows(self, rows): + for row in rows: + messages = row.get('messages') + if not isinstance(messages, list): + continue + query = '' + assistant_text = '' + reasoning = '' + for msg in messages: + if not isinstance(msg, dict): + continue + role = msg.get('role') or '' + content = msg.get('content') or '' + if not isinstance(content, str): + continue + if role == 'user' and not query: + query = content.strip() + elif role == 'assistant' and not assistant_text: + assistant_text = content.strip() + reasoning = (msg.get('reasoning') or '').strip() + break + if not query or not assistant_text: + continue + cot = reasoning + if not cot: + m = _THINK_RE.search(assistant_text) + if m: + cot = m.group(1).strip() + assistant_text = assistant_text[m.end():].strip() + response = assistant_text if not reasoning else _THINK_RE.sub('', assistant_text).strip() + if not response: + continue + yield query, cot, response, 'claude-opus-4.6-10000x' + + +# -- angrygiraffe-claude-opus-4.6-4.7-reasoning-8.7k -- +ANGRYGIRAFFE_REPO = 'ms://hf/angrygiraffe-claude-opus-4.6-4.7-reasoning-8.7k' + + +class AngrygiraffeOpusReasoningProcessor(CotExplodeProcessor): + """messages (OpenAI format) → extract first user/assistant, split tag.""" + + def _extract_rows(self, rows): + for row in rows: + messages = row.get('messages') + if not isinstance(messages, list): + continue + query = '' + assistant_text = '' + for msg in messages: + if not isinstance(msg, dict): + continue + role = msg.get('role') or '' + content = msg.get('content') or '' + if not isinstance(content, str): + continue + if role == 'user' and not query: + query = content.strip() + elif role == 'assistant' and not assistant_text: + assistant_text = content.strip() + break + if not query or not assistant_text: + continue + m = _THINK_RE.search(assistant_text) + if m: + cot = m.group(1).strip() + response = assistant_text[m.end():].strip() + else: + cot = '' + response = assistant_text + if not 
# Default number of input rows taken from each source.
_BASE_SIZES = {
    'tiny_textbooks': 10000,
    'musique': 1000,
    'github_code': 30000,
    'competition_math': 7500,
    'toucan': 10000,
    'swe_smith': 1000,
    'cn_r1_distill': 10000,
    'opus_reasoning': 3000,
    'claude_opus': 10000,
    'angrygiraffe': 20000,
}


def _scaled_sizes(total: Optional[int]) -> Dict[str, int]:
    """Return per-source row budgets, optionally rescaled towards ``total``.

    With ``total=None`` a copy of ``_BASE_SIZES`` is returned. Otherwise every
    entry is multiplied by ``total / sum(_BASE_SIZES.values())``, rounded to the
    nearest integer and clamped to at least 1.
    """
    if total is None:
        return dict(_BASE_SIZES)
    factor = total / sum(_BASE_SIZES.values())
    scaled: Dict[str, int] = {}
    for name, base in _BASE_SIZES.items():
        scaled[name] = max(1, int(round(base * factor)))
    return scaled
# Captures the reasoning enclosed in a ``<think>...</think>`` block. DOTALL lets
# the lazy group span newlines so a multi-line chain of thought matches as one
# unit. Without the tags the pattern matches the empty string at position 0,
# which makes every extracted ``cot`` empty.
_THINK_RE = re.compile(r'<think>(.*?)</think>', re.DOTALL)
class CodeXThinkingProcessor(Preprocessor):
    """CodeX-2M-Thinking row → ``{id, source, query, cot, response}``.

    Reads ``input`` (the question) and ``output`` (a think block followed by the
    answer). ``output`` is split into ``cot`` (the group captured by
    ``_THINK_RE``) and ``response`` (the text after the match). Rows with an
    empty ``input``/``output``, no ``_THINK_RE`` match, or an empty
    ``cot``/``response`` are dropped.
    """

    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        records: List[Dict[str, Any]] = []
        for item in self.map_col_to_row(rows):
            question = (item.get('input') or '').strip()
            raw_output = (item.get('output') or '').strip()
            if not (question and raw_output):
                continue
            match = _THINK_RE.search(raw_output)
            if match is None:
                continue
            reasoning = match.group(1).strip()
            answer = raw_output[match.end():].strip()
            if reasoning and answer:
                records.append(dict(
                    id=_hash_id('codex_think', f'{question}\n{answer}'),
                    source='CodeX-2M-Thinking',
                    query=question,
                    cot=reasoning,
                    response=answer,
                ))
        return self.map_row_to_col(records)
class LIMOProcessor(Preprocessor):
    """LIMO-v2 row → ``{id, source, query, cot, response}``.

    Reads ``question`` and ``solution``. When ``_THINK_RE`` matches the
    solution, the captured group becomes ``cot`` and the text after the match
    becomes ``response``; otherwise the whole solution is the response and
    ``cot`` stays empty. Rows with an empty question, solution or response are
    dropped.
    """

    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        records: List[Dict[str, Any]] = []
        for item in self.map_col_to_row(rows):
            question = (item.get('question') or '').strip()
            solution = (item.get('solution') or '').strip()
            if not (question and solution):
                continue
            match = _THINK_RE.search(solution)
            # No think block: the entire solution is kept as the answer.
            reasoning = match.group(1).strip() if match else ''
            answer = solution[match.end():].strip() if match else solution
            if not answer:
                continue
            records.append(dict(
                id=_hash_id('limo', f'{question}\n{answer}'),
                source='LIMO-v2',
                query=question,
                cot=reasoning,
                response=answer,
            ))
        return self.map_row_to_col(records)
class OpusReasoningProcessor(Preprocessor):
    """Opus-4.6-Reasoning-3000x-filtered row → ``{id, source, query, cot, response}``.

    The three input columns map directly: ``problem`` → query, ``thinking`` →
    cot, ``solution`` → response. When a ``cot`` is present, any ``_THINK_RE``
    match is removed from the response. Rows with an empty query or response
    are dropped.
    """

    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        records: List[Dict[str, Any]] = []
        for item in self.map_col_to_row(rows):
            question = (item.get('problem') or '').strip()
            reasoning = (item.get('thinking') or '').strip()
            answer = (item.get('solution') or '').strip()
            if not (question and answer):
                continue
            if reasoning:
                # The reasoning has its own column; strip any copy embedded
                # in the answer text.
                answer = _THINK_RE.sub('', answer).strip()
                if not answer:
                    continue
            records.append(dict(
                id=_hash_id('opus_reasoning', f'{question}\n{answer}'),
                source='Opus-4.6-Reasoning-3000x-filtered',
                query=question,
                cot=reasoning,
                response=answer,
            ))
        return self.map_row_to_col(records)
class AngrygiraffeOpusReasoningProcessor(Preprocessor):
    """angrygiraffe/claude-opus-4.6-4.7-reasoning-8.7k row → ``{id, source, query, cot, response}``.

    Reads ``messages`` (a list of ``{role, content}`` dicts). The first
    non-empty user message is the query and the first assistant message is
    split with ``_THINK_RE`` into cot/response; only that first exchange is
    used. Without a match the whole assistant text is the response and ``cot``
    is empty.
    """

    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        records: List[Dict[str, Any]] = []
        for item in self.map_col_to_row(rows):
            turns = item.get('messages')
            if not isinstance(turns, list):
                continue
            question = ''
            reply = ''
            for turn in turns:
                if not isinstance(turn, dict):
                    continue
                speaker = turn.get('role') or ''
                text = turn.get('content') or ''
                if not isinstance(text, str):
                    continue
                if speaker == 'user' and not question:
                    question = text.strip()
                elif speaker == 'assistant' and not reply:
                    # Stop scanning at the first assistant turn.
                    reply = text.strip()
                    break
            if not (question and reply):
                continue
            match = _THINK_RE.search(reply)
            reasoning = match.group(1).strip() if match else ''
            answer = reply[match.end():].strip() if match else reply
            if not answer:
                continue
            records.append(dict(
                id=_hash_id('angrygiraffe_opus', f'{question}\n{answer}'),
                source='angrygiraffe-claude-opus-4.6-4.7-reasoning-8.7k',
                query=question,
                cot=reasoning,
                response=answer,
            ))
        return self.map_row_to_col(records)
class ToMessagesProcessor(Preprocessor):
    """Convert ``{id, source, query, cot, response}`` rows → ``{id, source, messages}``.

    Each kept row becomes a two-message conversation: ``query`` as the user
    turn and ``cot`` as the assistant turn (also copied into that message's
    ``reasoning_content`` field). Rows without a ``cot`` are dropped.
    """

    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        rows = self.map_col_to_row(rows)
        out: List[Dict[str, Any]] = []
        for row in rows:
            cot = row.get('cot') or ''
            if not cot:
                continue
            query = row.get('query') or ''
            # NOTE(review): the row's ``response`` is not part of the emitted
            # messages; only ``cot`` is used as assistant content. Confirm
            # that this is intended.
            assistant_content = f'{cot}'
            out.append({
                'id': row.get('id', ''),
                'source': row.get('source', ''),
                'messages': [
                    {'role': 'user', 'content': query},
                    {'role': 'assistant', 'content': assistant_content,
                     'reasoning_content': cot},
                ],
            })
        return self.map_row_to_col(out, keys=['id', 'source', 'messages'])
if __name__ == '__main__':
    import os

    # The dropped-row log lives next to this script; remove any stale file
    # left by a previous run.
    dropped_log = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'dropped.jsonl')
    if os.path.exists(dropped_log):
        os.remove(dropped_log)
    # Hand the path to get_dataset: without it the function falls back to an
    # empty dropped_log_path and the file prepared above is never used.
    dataset = get_dataset(dropped_log=dropped_log, load_from_cache_file=False)
    print(len(dataset))
class MuSiQueProcessor(Preprocessor):
    """MuSiQue-Ans → Trajectory adapter.

    Fields read from each row: ``question``, ``paragraphs`` (dicts with
    ``title``, ``paragraph_text``, ``is_supporting``), ``answer``,
    ``answer_aliases`` and ``answerable``.

    Each kept row becomes a ``Trajectory`` with a system + user message and
    ``user_data`` made of ``('ground_truth', ...)`` pairs (answer first, then
    aliases, de-duplicated) followed by ``('sf_title', ...)`` pairs for the
    titles of supporting paragraphs.
    """

    def __init__(self, system: str):
        self.system = system

    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        trajectories = []
        for raw in self.map_col_to_row(rows):
            converted = self.preprocess(raw)
            if converted is not None:
                trajectories.append(converted)
        return self.map_row_to_col(trajectories)

    @staticmethod
    def _format_context(paragraphs: List[Dict[str, Any]]) -> str:
        """Join non-empty paragraphs as ``title: body`` blocks separated by blank lines."""
        blocks = []
        for para in paragraphs or []:
            heading = (para.get('title') or '').strip()
            text = (para.get('paragraph_text') or '').strip()
            if text:
                blocks.append(f'{heading}: {text}' if heading else text)
        return '\n\n'.join(blocks)

    def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]:
        """Convert one row, or return ``None`` when it is unanswerable or incomplete."""
        if row.get('answerable') is False:
            return None
        question = (row.get('question') or '').strip()
        if not question:
            return None

        # Answer first, then aliases; dict keys de-duplicate while keeping order.
        primary = (row.get('answer') or '').strip()
        alias_list = row.get('answer_aliases') or []
        gold = [g for g in dict.fromkeys([primary, *alias_list]) if g]
        if not gold:
            return None

        paragraphs = row.get('paragraphs') or []
        context_block = self._format_context(paragraphs)
        user_msg = f'Question: {question}\n\nContext:\n\n{context_block}'

        # Ordered, de-duplicated titles of the supporting paragraphs.
        supporting: Dict[str, None] = {}
        for para in paragraphs:
            if not para.get('is_supporting'):
                continue
            title = (para.get('title') or '').strip()
            if title:
                supporting[title] = None

        user_data = [('ground_truth', g) for g in gold]
        user_data += [('sf_title', t) for t in supporting]
        return Trajectory(
            messages=[
                Message(role='system', content=self.system),
                Message(role='user', content=user_msg),
            ],
            user_data=user_data,
        )
def parse_args(argv: Optional[List[str]] = None):
    """Parse the command-line options of the native-vs-condensed QA evaluation.

    Args:
        argv: Argument strings to parse. ``None`` (the default) keeps the
            previous behavior: argparse reads ``sys.argv[1:]``. Passing an
            explicit list lets callers and tests drive the parser directly.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse on missing required options
            (``--mode``, ``--dataset``) or invalid choices.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--mode', choices=['native', 'condensed'], required=True)
    p.add_argument('--dataset', required=True,
                   help='Eval set jsonl. HotpotQA or MuSiQue-Ans schema (see --dataset_format).')
    p.add_argument('--dataset_format', choices=['hotpotqa', 'musique'], default='musique',
                   help='Schema of --dataset. MuSiQue-Ans (default) is harder multi-hop and OOD vs training.')
    p.add_argument('--model_id', default='ms://Qwen/Qwen3.5-4B')
    p.add_argument('--lora_path', default=None,
                   help='Optional LoRA adapter on top of model_id (e.g. trained QA LoRA).')
    p.add_argument('--condenser_lora', default='ms://twinkle-kit/Qwen3.5-4B-Condenser')
    p.add_argument('--limit', type=int, default=500)
    p.add_argument('--num_gpus', type=int, default=4)
    p.add_argument('--batch_size', type=int, default=8)
    p.add_argument('--max_model_len', type=int, default=32768)
    p.add_argument('--max_new_tokens', type=int, default=2048)
    p.add_argument('--max_turns', type=int, default=4)
    p.add_argument('--max_trajectory_tokens', type=int, default=8192)
    p.add_argument('--chunk_size', type=int, default=1024)
    p.add_argument('--temperature', type=float, default=0.0)
    p.add_argument('--out_dir', default='eval_out')
    p.add_argument('--seed', type=int, default=42)
    return p.parse_args(argv)
def extract_boxed(text: str) -> Optional[str]:
    r"""Return the stripped inner text of the last usable ``\boxed{...}`` marker.

    The closing brace is found by brace balancing, so nested groups such as
    ``\boxed{\text{Delhi}}`` are returned whole instead of being cut at the
    first ``}``. Markers are tried from the last occurrence backwards:

    * if the marker has a balanced closing brace, its content is returned;
    * otherwise, if any ``}`` follows it, the content up to that first ``}``
      is returned (the behavior of a plain ``[^}]*`` match);
    * otherwise the marker is skipped and the previous one is tried.

    Args:
        text: Model response to scan; may be empty or ``None``.

    Returns:
        The answer text, or ``None`` when ``text`` is empty or no
        ``\boxed{`` is followed by a closing brace.
    """
    if not text:
        return None
    marker = '\\boxed{'
    start = text.rfind(marker)
    while start != -1:
        body_start = start + len(marker)
        depth = 1
        for pos in range(body_start, len(text)):
            ch = text[pos]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    return text[body_start:pos].strip()
        # Unbalanced braces: degrade to "everything up to the first '}'".
        close = text.find('}', body_start)
        if close != -1:
            return text[body_start:close].strip()
        # This marker was never closed; try the one before it.
        start = text.rfind(marker, 0, start)
    return None
def main():
    """Run one evaluation pass (native or condensed) and write its results.

    Steps, in order:
      1. Parse CLI options and create ``<out_dir>/<mode>_<run_id>/``.
      2. Initialize twinkle in ray mode with a single ``sampler`` device group.
      3. Build the dataset, the vLLM sampler and the rollout:
         ``MultiTurnCondenseRollout`` (chunker + condenser) in condensed mode,
         otherwise ``MultiTurnRollout`` with ``max_turns=1``.
      4. For every sample, take the last assistant text, extract the boxed
         answer, score it with ``best_f1_em`` and append one JSON line to
         ``predictions.jsonl``.
      5. Write aggregate metrics to ``summary.json`` and log them.

    The predictions file is managed by a ``with`` block so it is closed even
    if the rollout or scoring raises part-way through the run.
    """
    args = parse_args()
    run_id = time.strftime('%Y%m%d_%H%M%S') + '_' + uuid.uuid4().hex[:6]
    out_dir = os.path.join(args.out_dir, f'{args.mode}_{run_id}')
    os.makedirs(out_dir, exist_ok=True)

    if args.lora_path:
        # NOTE(review): --lora_path is only echoed into summary.json below; nothing
        # in this function hands it to the sampler or rollout. Warn so a run is not
        # mistaken for an adapter evaluation. TODO: wire it in once the intended
        # sampler/rollout parameter is confirmed.
        logger.warning('--lora_path=%s is recorded in the summary but is not applied '
                       'to the sampler or rollout by this script', args.lora_path)

    device_groups = [DeviceGroup(name='sampler', ranks=list(range(args.num_gpus)), device_type='GPU')]
    sampler_mesh = DeviceMesh.from_sizes(world_size=args.num_gpus, dp_size=args.num_gpus)
    twinkle.initialize(mode='ray', nproc_per_node=args.num_gpus,
                       groups=device_groups, lazy_collect=False)

    system = CONDENSED_SYSTEM_PROMPT if args.mode == 'condensed' else NATIVE_SYSTEM_PROMPT
    ds = build_dataset(args.dataset, args.dataset_format, args.model_id,
                       args.max_model_len, args.limit, system)
    logger.info('Eval dataset: %d rows from %s (mode=%s, format=%s)',
                len(ds), args.dataset, args.mode, args.dataset_format)

    sampler = vLLMSampler(
        model_id=args.model_id,
        engine_args={
            'gpu_memory_utilization': 0.85, 'max_model_len': args.max_model_len,
            'max_lora_rank': 32, 'enable_lora': True,
            'enable_tower_connector_lora': True, 'max_loras': 5,
            'seed': args.seed,
        },
        device_mesh=sampler_mesh, remote_group='sampler')
    sampler.set_template('Qwen3_5Template', model_id=args.model_id,
                         enable_thinking=False, max_length=args.max_model_len)
    template = Qwen3_5Template(args.model_id, max_length=args.max_model_len, enable_thinking=False)

    # stop=[''] only matters for condensed mode where the model issues tool calls
    sampling_params = SamplingParams(
        max_tokens=args.max_new_tokens, num_samples=1,
        temperature=args.temperature, top_p=0.95,
        stop=[''] if args.mode == 'condensed' else None,
    )

    if args.mode == 'condensed':
        chunker = NativeChunker(chunk_size=args.chunk_size, passage_boundary_re=r'(?<=\n\n)')
        # Chunk-level extraction of the question line; \A anchor avoids matching "Question:" inside passages.
        _q_re = re.compile(r'\AQuestion:\s*(.+)')

        def _q_from_chunk(chunk):
            """Return the question text when ``chunk`` is a text chunk starting with ``Question:``."""
            c = chunk.get('content')
            if chunk.get('type') != 'text' or not isinstance(c, str):
                return None
            m = _q_re.search(c)
            return m.group(1).strip() if m else None

        condenser = ModelCondenser(
            sampler=sampler, compression_ratio=2.0,
            sampling_params=SamplingParams(max_tokens=1024, num_samples=1,
                                           temperature=0.4, top_p=0.9),
            min_chars=200, template=template,
            lora_path=args.condenser_lora, skip_pattern=r'^Question:',
            related_query=_q_from_chunk,
        )
        rollout = MultiTurnCondenseRollout(
            sampler=sampler, template=template, tool_manager=ToolManager(),
            chunker=chunker, condenser=condenser,
            sampling_params=sampling_params,
            max_turns=args.max_turns, max_trajectory_tokens=args.max_trajectory_tokens,
        )
    else:
        # max_turns=1, no tools: reduces to single-turn QA over the full original context
        rollout = MultiTurnRollout(
            sampler=sampler, template=template, tool_manager=ToolManager(),
            sampling_params=sampling_params,
            max_turns=1, max_trajectory_tokens=args.max_trajectory_tokens,
        )

    dataloader = DataLoader(dataset=ds, batch_size=args.batch_size,
                            min_batch_size=1, shuffle=False)

    pred_path = os.path.join(out_dir, 'predictions.jsonl')

    agg = Counter()
    sums = {'f1': 0.0, 'em': 0.0,
            'prompt_tok': 0, 'comp_tok': 0, 'orig_ctx_tok': 0,
            'turns': 0, 'tool_calls': 0}
    t0 = time.time()

    # Context manager guarantees the predictions file is flushed and closed
    # even when a batch fails mid-run.
    with open(pred_path, 'w', encoding='utf-8') as pf:
        for batch in dataloader:
            trajs = rollout(batch)

            for src, traj in zip(batch, trajs):
                text = _last_assistant_text(traj) or ''
                pred = extract_boxed(text) or ''
                golds = [v for k, v in (src.user_data or []) if k == 'ground_truth' and v]

                scores = best_f1_em(pred, golds)
                ids = traj.get('input_ids') or []
                # Labels equal to -100 are treated as non-completion positions.
                comp_tok = sum(1 for label in (traj.get('labels') or []) if label != -100)
                prompt_tok = max(0, len(ids) - comp_tok)
                tool_calls = _count_tool_calls(traj)
                turns = int(traj.get('turns') or 1)
                no_boxed = _BOXED_RE.search(text) is None

                # Original (uncondensed) context size — feed only the user msg, not the system prompt,
                # so the compression ratio stays comparable across modes.
                orig_user = _user_text(src.messages)
                orig_ctx_tok = len(template.tokenizer.encode(orig_user)) if orig_user else 0

                agg['n'] += 1
                agg['no_box'] += int(no_boxed)
                agg['tool_use'] += int(tool_calls > 0)
                sums['f1'] += scores['f1']
                sums['em'] += scores['em']
                sums['prompt_tok'] += prompt_tok
                sums['comp_tok'] += comp_tok
                sums['orig_ctx_tok'] += orig_ctx_tok
                sums['turns'] += turns
                sums['tool_calls'] += tool_calls

                pf.write(json.dumps({
                    'pred': pred,
                    'gold': golds,
                    'f1': scores['f1'],
                    'em': scores['em'],
                    'prompt_tok': prompt_tok,
                    'comp_tok': comp_tok,
                    'orig_ctx_tok': orig_ctx_tok,
                    'tool_calls': tool_calls,
                    'turns': turns,
                    'no_boxed': no_boxed,
                    'response': text,
                }, ensure_ascii=False) + '\n')

            logger.info('[eval] %d / %d processed', agg['n'], len(ds))

    wall = time.time() - t0
    # Avoid division by zero in the averages when no sample was processed.
    n = max(1, agg['n'])
    summary = {
        'mode': args.mode,
        'dataset_format': args.dataset_format,
        'model_id': args.model_id,
        'lora_path': args.lora_path,
        'condenser_lora': args.condenser_lora if args.mode == 'condensed' else None,
        'dataset': args.dataset,
        'n_samples': agg['n'],
        # quality
        'f1': sums['f1'] / n,
        'em': sums['em'] / n,
        'no_boxed_rate': agg['no_box'] / n,
        # cost
        'avg_prompt_tokens': sums['prompt_tok'] / n,
        'avg_completion_tokens': sums['comp_tok'] / n,
        'avg_orig_context_tokens': sums['orig_ctx_tok'] / n,
        'compression_ratio': (sums['prompt_tok'] / sums['orig_ctx_tok']
                              if sums['orig_ctx_tok'] else None),
        # tool / multi-turn behavior
        'avg_turns': sums['turns'] / n,
        'avg_tool_calls': sums['tool_calls'] / n,
        'tool_use_rate': agg['tool_use'] / n,
        # wall
        'wall_time_sec': wall,
        'samples_per_sec': agg['n'] / wall if wall > 0 else 0.0,
    }
    with open(os.path.join(out_dir, 'summary.json'), 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    logger.info('Done. Output: %s', out_dir)
    logger.info('Summary: %s', json.dumps(summary, indent=2, ensure_ascii=False))
+# Compare against eval_condensed_compressed.sh on identical --dataset / --limit / --model_id. +set -euo pipefail + +DATASET="${DATASET:-/mnt/data/yzhao/datasets/musique_ans_v1.0_dev.jsonl}" +MODEL_ID="${MODEL_ID:-ms://Qwen/Qwen3.5-4B}" +LIMIT="${LIMIT:-500}" +NUM_GPUS="${NUM_GPUS:-4}" +OUT_DIR="${OUT_DIR:-eval_out}" + +CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3} \ +python cookbook/exp/eval_condensed.py \ + --mode native \ + --dataset_format musique \ + --dataset "${DATASET}" \ + --model_id "${MODEL_ID}" \ + --limit "${LIMIT}" \ + --num_gpus "${NUM_GPUS}" \ + --batch_size 8 \ + --max_model_len 32768 \ + --max_new_tokens 2048 \ + --max_trajectory_tokens 8192 \ + --temperature 0.0 \ + --out_dir "${OUT_DIR}" diff --git a/cookbook/exp/grpo_baseline.py b/cookbook/exp/grpo_baseline.py new file mode 100644 index 00000000..237f9b06 --- /dev/null +++ b/cookbook/exp/grpo_baseline.py @@ -0,0 +1,593 @@ +"""HotpotQA GRPO baseline — full context, no chunking, no compression, no tools. + +This is the **control group** for ``grpo_condensed.py``. 
Both scripts share: + * dataset (HotpotQA fullwiki, hard split) + * preprocessing (``HotpotQAProcessor`` with ``[K] Title: ...`` passages) + * GRPO infra (model / sampler / device mesh / hyperparams) + * rollout class (``MultiTurnRollout`` from ``multi_turn.py``) + +The only differences are intentional: + * no ``NativeChunker`` / ``ModelCondenser`` (full passages go in verbatim) + * no tools registered (``ToolManager()`` is empty) + * ``max_turns=1`` so the rollout is effectively single-turn + * simplified system prompt (no ```` / ``extract_condensed`` syntax) + * ``F1Reward + CoTReward`` only (no ``ToolExploreReward``) + * traces → ``rollout_trace_baseline.jsonl`` + * checkpoints prefixed ``hotpotqa-grpo-baseline-*`` + +Keeping the same ``MultiTurnRollout`` code path on both sides means any +training-loop-level discrepancy between the two runs is attributable to +the chunk+condense pipeline, not to differences in rollout plumbing. +""" + +import math +import os +import re +from typing import Any, Dict, List, Optional + +import swanlab +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, DeviceGroup, get_logger +from twinkle.advantage import GRPOAdvantage +from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.data_format import Message, SamplingParams, Trajectory +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.metric import CompletionRewardMetric +from twinkle.model import TransformersModel +from twinkle.preprocessor.base import Preprocessor +from twinkle.processor import InputProcessor +from twinkle.sampler import vLLMSampler +from twinkle.template import Qwen3_5Template +from twinkle_agentic.reward import F1Reward, CoTReward +from twinkle_agentic.rollout.multi_turn import MultiTurnRollout +from twinkle_agentic.tools.tool_manager import ToolManager + +logger = get_logger() + +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') +USE_MEGATRON = 
bool(int(os.environ.get('USE_MEGATRON', '1'))) + +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS + +NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) +LEARNING_RATE = float(os.environ.get('LR', 1e-5)) +NUM_EPOCHS = int(os.environ.get('NUM_EPOCHS', 1)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 0)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) +MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) +ADAPTER_NAME = 'default' +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000)) +LORA_RANK = int(os.environ.get('LORA_RANK', 16)) + +# Single-turn baseline; tools are not registered, but we keep MultiTurnRollout +# to share the rollout code path with the condensed variant. ``max_turns=1`` +# guarantees the loop runs exactly one sampling pass per trajectory. +MAX_TURNS = int(os.environ.get('MAX_TURNS', 1)) + +HOTPOTQA_NUM_PROC = int(os.environ.get('HOTPOTQA_NUM_PROC', 16)) +HOTPOTQA_MAX_LENGTH = int(os.environ.get('HOTPOTQA_MAX_LENGTH', 64000)) + +F1_REWARD_WEIGHT = float(os.environ.get('F1_REWARD_WEIGHT', 1.0)) +COT_REWARD_WEIGHT = float(os.environ.get('COT_REWARD_WEIGHT', 0.2)) + +# KL penalty coefficient; 0 disables KL (and skips the ref forward pass entirely). +KL_BETA = float(os.environ.get('KL_BETA', 0.02)) + +# Entropy bonus coefficient; 0 disables entropy compute path. +ENTROPY_COEF = float(os.environ.get('ENTROPY_COEF', 0.0)) + +# CISPO token-level IS clamp thresholds (asymmetric: 0.2 / 0.28). +CISPO_EPS_LOW = float(os.environ.get('CISPO_EPS_LOW', 0.2)) +CISPO_EPS_HIGH = float(os.environ.get('CISPO_EPS_HIGH', 0.2)) + +# High-KL token capture: top-K per microbatch dumped into log_dict['_high_kl_records']. 0 = disabled. 
class HotpotQAProcessor(Preprocessor):
    """Turn reannotated HotpotQA rows into system+user ``Trajectory`` objects.

    * Rows whose ``verdict`` is ``'drop'`` (case/whitespace-insensitive) are removed.
    * ``question_fixed`` replaces ``question`` whenever it is truthy.
    * Context passages are rendered as ``[K] Title: ...`` entries, 1-indexed,
      separated by blank lines.
    * Gold answers are attached as ``('ground_truth', answer)`` user-data pairs.
    """

    def __init__(self, system: str = SYSTEM_PROMPT):
        # System prompt placed in front of every produced conversation.
        self.system = system

    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        """Convert a column-oriented batch, dropping rows ``preprocess`` rejects."""
        converted = (self.preprocess(record) for record in self.map_col_to_row(rows))
        kept = [traj for traj in converted if traj is not None]
        return self.map_row_to_col(kept)

    @staticmethod
    def _format_context(context: Dict[str, Any]) -> str:
        """Render parallel ``title`` / ``sentences`` lists as numbered passages."""
        passage_titles = context.get('title', []) or []
        passage_sentences = context.get('sentences', []) or []

        def _body(sents: Any) -> str:
            # A non-list payload is stringified as-is; a list is joined after
            # dropping empty / whitespace-only sentences.
            if not isinstance(sents, list):
                return str(sents).strip()
            return ' '.join(s.strip() for s in sents if s and s.strip())

        entries = [
            f'[{number}] {title}: {_body(sents)}'
            for number, (title, sents) in enumerate(
                zip(passage_titles, passage_sentences), start=1)
        ]
        return '\n\n'.join(entries)

    def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]:
        """Build the trajectory for one row, or ``None`` for a dropped row."""
        verdict = (row.get('verdict') or '').strip().lower()
        if verdict == 'drop':
            return None

        question = row.get('question_fixed') or row['question']

        raw_answers = row.get('answers')
        if isinstance(raw_answers, list) and raw_answers:
            stripped = (str(a).strip() for a in raw_answers)
            golds = [s for s in stripped if s]
        else:
            # No usable list: fall back to the single ``answer`` field.
            single = (row.get('answer', '') or '').strip()
            golds = [single] if single else []

        context_block = self._format_context(row.get('context', {}) or {})
        user_msg = f'Question: {question}\n\nContext:\n\n{context_block}'
        return Trajectory(
            messages=[
                Message(role='system', content=self.system),
                Message(role='user', content=user_msg),
            ],
            user_data=[('ground_truth', g) for g in golds],
        )
['id', 'question', 'question_fixed', 'answers', + 'original_answer', 'type', 'level', 'verdict', + 'reasoning', 'supporting_facts', 'context'] + dataset.map(HotpotQAProcessor(system=SYSTEM_PROMPT), + remove_columns=_HOTPOTQA_COLS) + return dataset + + +# Matches a LaTeX ``\boxed{...}`` final-answer marker — used to flag +# rollouts that never committed an answer. Brace-balanced is overkill for +# a logging heuristic; a non-greedy ``[^}]*`` is good enough. +_BOXED_RE = re.compile(r'\\boxed\{[^}]*\}') + +# Pulls the leading number out of pre-formatted metric strings such as +# ``'0.03 iters/s'`` / ``'1.000000e-05'`` / ``'30 seconds'`` emitted by +# ``TrainMetric`` and ``GRPOMetric``. We use this in ``_coerce_for_swanlab`` +# so swanlab can build line charts instead of dropping those keys with a +# ``failed to create chart for key '...': invalid value type`` warning. +_LEADING_NUMBER_RE = re.compile(r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?') + + +def _coerce_for_swanlab(log_dict: Dict[str, Any]) -> Dict[str, Any]: + """Cast string-valued metrics to float for swanlab line charts. + + ``TrainMetric.calculate()`` and ``GRPOMetric.calculate()`` return + pre-formatted strings (``'0.03 iters/s'``, ``'1.000000e-05'``, + ``'30 seconds'``, ``'0.8321'``). swanlab cannot build a line chart + from a string value and emits one warning per key per step. We extract + the leading number where possible; keys whose value can't be parsed + as a scalar are left as-is so they still show up in the text log. + """ + coerced: Dict[str, Any] = {} + for k, v in log_dict.items(): + if isinstance(v, bool) or isinstance(v, (int, float)): + coerced[k] = v + continue + if isinstance(v, str): + m = _LEADING_NUMBER_RE.search(v) + if m: + try: + coerced[k] = float(m.group()) + continue + except ValueError: + pass + coerced[k] = v + return coerced + + +def _last_assistant_text(trajectory: Dict[str, Any]) -> Optional[str]: + """Return the text of the last ``assistant`` message, or ``None``. 
+ + ``content`` can be ``str`` | ``None`` | ``dict`` (single multimodal + part) | ``list[dict]`` (multiple parts). The downstream caller feeds + this into ``_BOXED_RE.search(...)``, so we collapse the visible text + into a single string and ignore non-text parts (images etc.). + """ + for m in reversed(trajectory.get('messages', [])): + if m.get('role') != 'assistant': + continue + c = m.get('content') + if c is None: + return None + if isinstance(c, str): + return c + if isinstance(c, dict): + return c.get('text') if c.get('type') == 'text' else None + if isinstance(c, list): + parts = [p.get('text') or '' for p in c + if isinstance(p, dict) and p.get('type') == 'text'] + return '\n'.join(parts) if parts else None + return str(c) + return None + + +def _compute_rollout_diagnostics( + trajectories: List[Dict[str, Any]], + n_turns_per_rollout: List[int], + per_rollout_completion_length: List[int], + f1_rewards: Optional[List[float]] = None, + old_logps: Optional[List[List[float]]] = None, +) -> Dict[str, float]: + """Aggregate rollout diagnostics for swanlab logging. + + Stripped-down version of the condensed variant's diagnostics — without + chunking we only care about (a) the longest non-trainable prefix + (system prompt + full passages), and (b) whether the rollout produced + a `\\boxed{}` final answer at all. ``avg_turns`` is logged for symmetry + even though it should be exactly 1.0 with ``MAX_TURNS=1``. 
+ """ + out: Dict[str, float] = {} + if n_turns_per_rollout: + out['avg_turns'] = sum(n_turns_per_rollout) / len(n_turns_per_rollout) + + _max_non_trainable = 0 + for t, comp_len in zip(trajectories, per_rollout_completion_length): + ids = t.get('input_ids') or [] + non_trainable = max(0, len(ids) - int(comp_len or 0)) + if non_trainable > _max_non_trainable: + _max_non_trainable = non_trainable + out['non_trainable_tokens'] = _max_non_trainable + + if trajectories: + n_no_boxed = sum( + 0 if _BOXED_RE.search(_last_assistant_text(t) or '') else 1 + for t in trajectories) + out['no_boxed_rate'] = n_no_boxed / len(trajectories) + + def _content_chars(c: Any) -> int: + if not c: + return 0 + if isinstance(c, str): + return len(c) + if isinstance(c, dict): + if c.get('type') == 'text': + return len(c.get('text') or '') + return 0 + if isinstance(c, list): + total = 0 + for part in c: + if isinstance(part, dict) and part.get('type') == 'text': + total += len(part.get('text') or '') + elif isinstance(part, str): + total += len(part) + return total + # Unknown shape -- fall back to ``str()`` length rather than + # crashing, so a template quirk never breaks metric logging. 
+ return len(str(c)) + + msg_chars_total, prompt_chars, asst_chars = [], [], [] + for t in trajectories: + total_i = prompt_i = asst_i = 0 + for m in (t.get('messages') or []): + role = m.get('role') + if role == 'system': + continue + n = _content_chars(m.get('content')) + total_i += n + if role in ('user', 'tool'): + prompt_i += n + elif role == 'assistant': + asst_i += n + msg_chars_total.append(total_i) + prompt_chars.append(prompt_i) + asst_chars.append(asst_i) + out['avg_chars_total_no_sys'] = sum(msg_chars_total) / len(msg_chars_total) + out['avg_chars_prompt_no_sys'] = sum(prompt_chars) / len(prompt_chars) + out['avg_chars_assistant'] = sum(asst_chars) / len(asst_chars) + + if f1_rewards is not None and old_logps is not None and f1_rewards: + per_traj_mean = [(sum(lp) / len(lp)) if lp else 0.0 for lp in old_logps] + pos_logp = [m for m, f1 in zip(per_traj_mean, f1_rewards) if f1 > 0] + zero_logp = [m for m, f1 in zip(per_traj_mean, f1_rewards) if f1 <= 0] + out['f1_correct_rate'] = len(pos_logp) / len(f1_rewards) + out['f1_zero_rate'] = len(zero_logp) / len(f1_rewards) + out['mean_old_logp_f1_pos'] = (sum(pos_logp) / len(pos_logp)) if pos_logp else 0.0 + out['mean_old_logp_f1_zero'] = (sum(zero_logp) / len(zero_logp)) if zero_logp else 0.0 + out['policy_confidence_f1_pos'] = math.exp(out['mean_old_logp_f1_pos']) + out['policy_confidence_f1_zero'] = math.exp(out['mean_old_logp_f1_zero']) + return out + + +def main(): + swanlab.init(project='twinkle') + + device_groups = [ + DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), + ] + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, + groups=device_groups, lazy_collect=False) + + logger.info('Building HotpotQA dataset (baseline, 
full context)') + _prebuilt_dataset = create_hotpotqa_dataset() + logger.info('Dataset ready: %d rows', len(_prebuilt_dataset)) + + GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS + batches_per_epoch = max(1, len(_prebuilt_dataset) // GLOBAL_BATCH_SIZE) + # Single-turn baseline: every rollout produces exactly one assistant + # turn, so the per-batch optim-step count equals + # ceil(GLOBAL_BATCH_SIZE * NUM_GENERATIONS / MINI_BATCH_SIZE). + optim_steps_per_batch = max(1, (GLOBAL_BATCH_SIZE * NUM_GENERATIONS + + MINI_BATCH_SIZE - 1) // MINI_BATCH_SIZE) + steps_per_epoch = batches_per_epoch * optim_steps_per_batch + derived_total_steps = NUM_EPOCHS * steps_per_epoch + total_steps = min(MAX_STEPS, derived_total_steps) if MAX_STEPS > 0 else derived_total_steps + logger.info('Training horizon: %d steps (%d epochs × %d batches × %d steps/batch)', + total_steps, NUM_EPOCHS, batches_per_epoch, optim_steps_per_batch) + + lora_config = LoraConfig( + target_modules='all-linear', r=LORA_RANK, + lora_alpha=LORA_RANK * 2, lora_dropout=0.05) + + if USE_MEGATRON: + from twinkle.model.megatron import MegatronModel + model = MegatronModel( + model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model', + mixed_precision='bf16', variable_seq_lengths=True) + else: + model = TransformersModel( + model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model') + + model.add_adapter_to_model(ADAPTER_NAME, lora_config, + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + if USE_MEGATRON: + model.set_optimizer('default', lr=LEARNING_RATE) + model.set_lr_scheduler('default', lr_decay_steps=total_steps, max_lr=LEARNING_RATE) + else: + model.set_optimizer('AdamW', lr=LEARNING_RATE) + model.set_lr_scheduler('CosineAnnealingLR', T_max=total_steps, eta_min=0) + + model.set_loss('GRPOLoss', epsilon=CISPO_EPS_LOW, epsilon_high=CISPO_EPS_HIGH, + beta=KL_BETA, entropy_coef=ENTROPY_COEF) + model.set_processor(InputProcessor, padding_free=True) + 
model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False, max_length=HOTPOTQA_MAX_LENGTH) + + model.add_metric('GRPOMetric', is_training=True, + epsilon=CISPO_EPS_LOW, epsilon_high=CISPO_EPS_HIGH, + top_k_kl=HIGH_KL_TOPK) + + sampler = vLLMSampler( + model_id=MODEL_ID, + engine_args={ + 'gpu_memory_utilization': 0.8, 'max_model_len': 32768, + 'max_lora_rank': 32, 'enable_lora': True, + 'enable_tower_connector_lora': True, + }, + device_mesh=sampler_mesh, remote_group='sampler') + sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False, max_length=HOTPOTQA_MAX_LENGTH) + rollout_template = Qwen3_5Template( + MODEL_ID, max_length=HOTPOTQA_MAX_LENGTH, enable_thinking=False) + + ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) + + dataloader = DataLoader( + dataset=lambda: _prebuilt_dataset, + batch_size=GLOBAL_BATCH_SIZE, min_batch_size=GLOBAL_BATCH_SIZE) + + advantage_fn = GRPOAdvantage() + metrics = CompletionRewardMetric() + sampling_params = SamplingParams( + max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, + temperature=1.0, top_p=0.95) + + def _trace_should_store(traj): + return True + + def _trace_is_success(traj): + return _F1_REWARD([traj])[0] > 0.0 + + rollout = MultiTurnRollout( + sampler=sampler, + template=rollout_template, + tool_manager=ToolManager(), + sampling_params=sampling_params, + max_turns=MAX_TURNS, + trace_dir=_ROLLOUT_TRACE_DIR or None, + trace_callback=_trace_should_store, + success_callback=_trace_is_success, + ) + + optim_step = 0 + logger.info('Starting HotpotQA GRPO baseline (no chunk / no condense / no tools)') + + def _epoch_cycle(dl, n_epochs): + for ep in range(1, n_epochs + 1): + logger.info(f'=== Epoch {ep}/{n_epochs} (step={optim_step}/{total_steps}) ===') + for batch in dl: + yield batch + + for batch in _epoch_cycle(dataloader, NUM_EPOCHS): + if optim_step >= total_steps: + break + + # Single source of truth for the step shown in swanlab / logger / 
rollout-trace filename. + batch_step = optim_step + + metrics.reset() + expand_prompts = [p for prompt in batch for p in [prompt] * NUM_GENERATIONS] + + ckpt_manager.sync_weights(merge_and_sync=False) + sampler.reset_prefix_cache() + + # Single batched rollout: each trajectory produces exactly one + # assistant turn (tools are unregistered, ``max_turns=1``). + all_trajectories: List[Dict[str, Any]] = rollout(expand_prompts) + n_turns_per_rollout = [int(t.get('turns') or 0) for t in all_trajectories] + per_rollout_completion_length = [ + sum(1 for l in (t.get('labels') or []) if l != -100) + for t in all_trajectories] + + total_rewards, f1_rewards, cot_rewards = compute_rewards(all_trajectories) + + rollout_advantages = advantage_fn( + total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist() + + all_f1_labels: List[bool] = [f > 0 for f in f1_rewards] + n_pos = sum(1 for p in all_f1_labels if p) + n_neg = sum(1 for p in all_f1_labels if not p) + pos_with_neg_adv = sum(1 for p, a in zip(all_f1_labels, rollout_advantages) if p and a < 0) + neg_with_pos_adv = sum(1 for p, a in zip(all_f1_labels, rollout_advantages) if not p and a > 0) + + all_old_logps: List[List[float]] = [ + [lp[0][1] for lp in (t.get('logprobs') or [])] for t in all_trajectories] + + # Skip homogeneous groups where gradient signal is meaningless + f1_pos_rate = n_pos / len(f1_rewards) if f1_rewards else 0.5 + if f1_pos_rate > 0.9 or f1_pos_rate < 0.1: + logger.info('[skip-homogeneous] f1_pos_rate=%.3f, skipping training update', f1_pos_rate) + metrics.accumulate( + completion_lengths=per_rollout_completion_length, + rewards={'total': total_rewards, 'f1': f1_rewards, 'cot': cot_rewards}) + log_dict = metrics.calculate() + log_dict.update(_compute_rollout_diagnostics( + all_trajectories, n_turns_per_rollout, per_rollout_completion_length, + f1_rewards=f1_rewards, old_logps=all_old_logps)) + log_dict['skipped'] = True + log_dict['pos_neg_adv_rate'] = pos_with_neg_adv / n_pos if n_pos else 
0.0 + log_dict['neg_pos_adv_rate'] = neg_with_pos_adv / n_neg if n_neg else 0.0 + log_dict['adv_max'] = max(rollout_advantages) if rollout_advantages else 0.0 + log_dict['adv_min'] = min(rollout_advantages) if rollout_advantages else 0.0 + swanlab.log(_coerce_for_swanlab(log_dict), step=batch_step) + metrics.reset() + logger.info(f'[Step {batch_step}/{total_steps}] [SKIPPED] {log_dict}') + optim_step += optim_steps_per_batch + continue + + metrics.accumulate( + completion_lengths=per_rollout_completion_length, + rewards={'total': total_rewards, 'f1': f1_rewards, 'cot': cot_rewards}) + + all_input_data: List[Any] = list(all_trajectories) + advantages: List[float] = list(rollout_advantages) + + total_completions = len(all_input_data) + aligned_completions = (total_completions // MODEL_GPUS) * MODEL_GPUS + if aligned_completions < total_completions: + logger.info( + '[dp-align] dropping %d tail sample(s): total=%d -> aligned=%d (dp=%d)', + total_completions - aligned_completions, + total_completions, aligned_completions, MODEL_GPUS) + for mb_start in range(0, aligned_completions, MINI_BATCH_SIZE): + mb_end = min(mb_start + MINI_BATCH_SIZE, aligned_completions) + mb_inputs = all_input_data[mb_start:mb_end] + # Reference log-probs for KL: same policy with LoRA disabled (= base model). 
+ ref_logps = None + if KL_BETA > 0.0: + ref_outputs = model.forward_only(inputs=mb_inputs, disable_lora=True) + ref_logps = ref_outputs.get('logps') if isinstance(ref_outputs, dict) else getattr(ref_outputs, 'logps', None) + model.forward_backward( + inputs=mb_inputs, + old_logps=all_old_logps[mb_start:mb_end], + advantages=advantages[mb_start:mb_end], + ref_logps=ref_logps, + positive_mask=all_f1_labels[mb_start:mb_end], + micro_batch_size=MICRO_BATCH_SIZE) + model.clip_grad_and_step() + optim_step += 1 + if optim_step >= total_steps: + break + if optim_step % SAVE_STEPS == 0: + model.save(f'hotpotqa-grpo-baseline-checkpoint-{optim_step}') + + log_dict = metrics.calculate() + log_dict.update(model.calculate_metric(is_training=True)) + log_dict.update(_compute_rollout_diagnostics( + all_trajectories, n_turns_per_rollout, per_rollout_completion_length, + f1_rewards=f1_rewards, old_logps=all_old_logps)) + log_dict['pos_neg_adv_rate'] = pos_with_neg_adv / n_pos if n_pos else 0.0 + log_dict['neg_pos_adv_rate'] = neg_with_pos_adv / n_neg if n_neg else 0.0 + log_dict['adv_max'] = max(rollout_advantages) if rollout_advantages else 0.0 + log_dict['adv_min'] = min(rollout_advantages) if rollout_advantages else 0.0 + # Pop high-KL token records before swanlab.log: list-of-dict won't render as a chart. 
+ _hk = log_dict.pop('_high_kl_records', None) + if _hk: + _tok = rollout_template.tokenizer + for r in _hk: + gsi = r.get('gsi') + tid = all_trajectories[gsi].get('id') if gsi is not None and 0 <= gsi < len(all_trajectories) else None + try: + tok_text = _tok.decode([r['token_id']]) + except Exception: + tok_text = None + logger.info( + '[high-kl] step=%d gsi=%s tid=%s pos=%s tok=%r kl=%.4f r=%.4f lp_new=%.4f lp_old=%.4f', + batch_step, gsi, tid, r.get('pos'), tok_text, + r.get('kl'), r.get('ratio'), r.get('logp_new'), r.get('logp_old')) + swanlab.log(_coerce_for_swanlab(log_dict), step=batch_step) + metrics.reset() + logger.info(f'[Step {batch_step}/{total_steps}] {log_dict}') + + logger.info(f'Training completed. optim_steps={optim_step}') + model.save('hotpotqa-grpo-baseline-final') + + +if __name__ == '__main__': + main() diff --git a/cookbook/exp/grpo_condensed.py b/cookbook/exp/grpo_condensed.py new file mode 100644 index 00000000..83eb49ac --- /dev/null +++ b/cookbook/exp/grpo_condensed.py @@ -0,0 +1,955 @@ +import copy +import math +import os +import re +from typing import Any, Dict, List, Optional + +import torch +import swanlab +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, DeviceGroup, get_logger +from twinkle.advantage import GRPOAdvantage +from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.data_format import Message, SamplingParams, Trajectory +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.metric import CompletionRewardMetric +from twinkle.model import TransformersModel +from twinkle.preprocessor.base import Preprocessor +from twinkle.processor import InputProcessor +from twinkle.sampler import vLLMSampler +from twinkle.template import Qwen3_5Template +from twinkle_agentic.chunker.native import NativeChunker +from twinkle_agentic.condenser import ModelCondenser +from twinkle_agentic.reward import F1Reward, CoTReward, ToolExploreReward 
+from twinkle_agentic.rollout.multi_turn_condense import MultiTurnCondenseRollout +from twinkle_agentic.tools.tool_manager import ToolManager + +logger = get_logger() + +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') +USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0'))) + +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS + +NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) +LEARNING_RATE = float(os.environ.get('LR', 1e-5)) +NUM_EPOCHS = int(os.environ.get('NUM_EPOCHS', 1)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 0)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) +MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) +ADAPTER_NAME = 'default' +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000)) +LORA_RANK = int(os.environ.get('LORA_RANK', 16)) + +MAX_TURNS = int(os.environ.get('MAX_TURNS', 4)) +MAX_TRAJECTORY_TOKENS = int(os.environ.get('MAX_TRAJECTORY_TOKENS', 8192)) +CHUNK_SIZE = int(os.environ.get('CHUNK_SIZE', 1024)) + +HOTPOTQA_NUM_PROC = int(os.environ.get('HOTPOTQA_NUM_PROC', 16)) +HOTPOTQA_MAX_LENGTH = int(os.environ.get('HOTPOTQA_MAX_LENGTH', 64000)) + +F1_REWARD_WEIGHT = float(os.environ.get('F1_REWARD_WEIGHT', 1.0)) +COT_REWARD_WEIGHT = float(os.environ.get('COT_REWARD_WEIGHT', 0)) +TOOL_BONUS_WEIGHT = float(os.environ.get('TOOL_BONUS_WEIGHT', 0.0)) +TOOL_BONUS_F1_THRESHOLD = float( + os.environ.get('TOOL_BONUS_F1_THRESHOLD', 0.5)) + +# KL penalty coefficient; 0 disables KL (and skips the ref forward pass entirely). +# CISPO is token-level and DOES support per-token KL — small positive value (e.g. 0.005) recommended as anchor. 
# [EXP-ORACLE] Staged hint injection. The hint is spliced onto the Question
# line so that the condenser's skip_pattern keeps it uncompressed.
def _oracle_hint_stage(step: int, total_steps: int) -> int:
    """Return the oracle-hint stage for ``step``.

    Stage codes: 0 = explicit titles, 1 = vague count, 2 = no hint.

    The progress-based schedule (first third of ``total_steps`` -> 0, second
    third -> 1, remainder -> 2) is currently switched off, so both arguments
    are ignored and every step is reported as stage 0.
    """
    return 0


def _make_oracle_hint_callback(total_steps: int):
    """Build the ``post_compress_callback`` that injects an oracle hint.

    The returned callable takes ``(compressed, chunks, **kwargs)``. It reads
    the supporting-fact titles from ``compressed['user_data']`` (pairs keyed
    ``'sf_title'``), numbers the condensed chunks in ``chunks.chunks`` from 1,
    and inserts a hint between the ``Question:`` line and ``\\n\\nContext:`` of
    the first user message. ``compressed`` is mutated in place and returned.

    Hint by stage (stage comes from ``_oracle_hint_stage``, using
    ``kwargs['global_step']``):
      0 -- ids of the condensed blocks whose original text starts with a
           supporting title; if none match, the number of supporting titles
           with a note that those passages are uncompressed.
      1 -- a generic "some compressed block(s)" sentence (no ids, no count).
      2 -- no hint; ``compressed`` is returned untouched.

    Args:
        total_steps: forwarded to ``_oracle_hint_stage``.
    """
    question_context_re = re.compile(r'(Question:\s*.+?)(\n\nContext:)', re.DOTALL)

    def _is_condensed_block(chunk) -> bool:
        # Only non-tool text chunks with non-empty string content that the
        # condenser marked as condensed receive a block id.
        if chunk.get('type') != 'text' or chunk.get('role') == 'tool':
            return False
        content = chunk.get('content')
        if not isinstance(content, str) or not content:
            return False
        raw = chunk.get('raw')
        return isinstance(raw, dict) and bool(raw.get('condensed'))

    def _supporting_block_ids(chunks, titles):
        # A passage is rendered as "Title: body", so "Title:" is the prefix
        # that identifies it.
        prefixes = tuple(f'{title}:' for title in titles)
        matched = []
        next_id = 0
        for chunk in chunks.chunks:
            if not _is_condensed_block(chunk):
                continue
            next_id += 1
            original = chunk.get('raw').get('original', '')
            if isinstance(original, str) and original.startswith(prefixes):
                matched.append(next_id)
        return matched

    def _callback(compressed, chunks, **kwargs):
        stage = _oracle_hint_stage(kwargs.get('global_step', 0), total_steps)
        if stage == 2:
            return compressed

        titles = {
            value for key, value in (compressed.get('user_data') or [])
            if key == 'sf_title' and value
        }
        if not titles:
            return compressed

        block_ids = _supporting_block_ids(chunks, titles)

        if stage != 0:
            hint = ('\n[Oracle Hint] Some compressed block(s) contain the supporting facts; '
                    'call `extract_condensed` to expand them if you need more detail information.')
        elif block_ids:
            ids_str = ', '.join(map(str, block_ids))
            hint = (f'\n[Oracle Hint] Block {ids_str} contain(s) the supporting facts. '
                    'Call `extract_condensed` to expand them if you need more detail information.')
        else:
            count = len(titles)
            word = {1: 'One', 2: 'Two', 3: 'Three'}.get(count, str(count))
            hint = (f'\n[Oracle Hint] {word} short passage(s) contain the supporting facts; '
                    'they are uncompressed — read them directly.')

        def _inject(match):
            # A function replacement keeps backslashes in ``hint`` literal.
            return match.group(1) + hint + match.group(2)

        first_user = next(
            (m for m in (compressed.get('messages') or []) if m.get('role') == 'user'),
            None)
        if first_user is not None:
            content = first_user.get('content')
            if isinstance(content, str):
                first_user['content'] = question_context_re.sub(_inject, content, count=1)
            elif isinstance(content, list):
                # Multi-part content: only the first text part is rewritten.
                for part in content:
                    if isinstance(part, dict) and part.get('type') == 'text':
                        part['text'] = question_context_re.sub(
                            _inject, part.get('text') or '', count=1)
                        break
        return compressed

    return _callback
\ +Block ids `N` are 1-based and assigned in the order compressed blocks \ +appear in the context, so they are always contiguous (``, \ +``, ``, ...). Raw passages have no block id and cannot \ +be extracted — they are already complete. + +## Workflow + +### Phase 1 — Scan and Decide +Step 1: Read each compressed block's Summary, and read raw \ +passages directly, to get an overview. +Step 2: For compressed blocks, check the More keywords to judge whether \ +hidden details are needed. +Step 3: Decide which compressed blocks to expand, then call \ +`extract_condensed` with their block ids. Raw passages need no extraction. + +### Phase 2 — Reason and Answer +After the tool returns the full text, continue stepping through the evidence: +Step N: From block X (or the raw passage titled "..."), I learn that [fact A]. +Step N+1: From block Y, I need to call `extract_condensed` to get more information, because this block is related to... +Step N+2: Combining these, the answer is ... +\\boxed{answer} + +You may call `extract_condensed` several times to expand more blocks if the information is not enough, only answer the question if you are sure about the facts. +The `blocks` parameter accepts **exactly one integer** per call (e.g. `3`); lists are rejected. Expand additional blocks by issuing separate `extract_condensed` calls, one per block. Only pass ids that actually appear as `` in the context, and do **not** request the same block twice — its text is already in the conversation after the first expansion. + +## Tool Call Format + + + +3 + + + + +## Output Format +End your final response with \\boxed{answer}, e.g. \\boxed{Delhi}. +Keep the boxed text short: a name, entity, date, or "yes"/"no". 
def compute_rewards(trajectories: List[Dict[str, Any]]):
    """Score a batch of trajectories with the module-level reward objects.

    The F1 reward is binarised to 1.0 / 0.0 against ``F1_BINARY_THRESHOLD``
    when that threshold is positive; otherwise the raw F1 values are used.

    Returns:
        A 4-tuple ``(total, f1, cot, tool_explore)`` of per-trajectory lists,
        where ``total`` is the weighted sum using ``F1_REWARD_WEIGHT``,
        ``COT_REWARD_WEIGHT`` and ``TOOL_BONUS_WEIGHT``.
    """
    raw_f1 = _F1_REWARD(trajectories)
    if F1_BINARY_THRESHOLD > 0:
        f1_scores = [1.0 if score >= F1_BINARY_THRESHOLD else 0.0 for score in raw_f1]
    else:
        f1_scores = raw_f1
    cot_scores = _COT_REWARD(trajectories)
    explore_scores = _TOOL_EXPLORE_REWARD(trajectories)
    combined = [
        F1_REWARD_WEIGHT * f1 + COT_REWARD_WEIGHT * cot + TOOL_BONUS_WEIGHT * explore
        for f1, cot, explore in zip(f1_scores, cot_scores, explore_scores)
    ]
    return combined, f1_scores, cot_scores, explore_scores


class HotpotQAProcessor(Preprocessor):
    """Turn raw HotpotQA rows into system + user ``Trajectory`` objects.

    Rows whose ``verdict`` is ``drop`` are discarded. Gold answers and the
    de-duplicated supporting-fact titles travel in ``user_data``.
    """

    def __init__(self, system: str = SYSTEM_PROMPT):
        # System prompt placed at the head of every trajectory.
        self.system = system

    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        """Convert a column-oriented batch, dropping rows ``preprocess`` rejects."""
        # map_col_to_row / map_row_to_col come from the Preprocessor base
        # class (not visible here); presumably they convert between the
        # column-oriented batch and a list of row dicts.
        kept = []
        for row in self.map_col_to_row(rows):
            trajectory = self.preprocess(row)
            if trajectory is not None:
                kept.append(trajectory)
        return self.map_row_to_col(kept)

    @staticmethod
    def _format_context(context: Dict[str, Any]) -> str:
        """Render passages as ``Title: body`` paragraphs separated by blank lines."""

        def _body(sents: Any) -> str:
            if not isinstance(sents, list):
                return str(sents).strip()
            # Skip empty / whitespace-only sentences before joining.
            return ' '.join(s.strip() for s in sents if s and s.strip())

        titles = context.get('title') or []
        sentence_groups = context.get('sentences') or []
        return '\n\n'.join(
            f'{title}: {_body(sents)}' for title, sents in zip(titles, sentence_groups))

    def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]:
        """Build the trajectory for one row, or ``None`` when it is marked ``drop``."""
        verdict = (row.get('verdict') or '').strip().lower()
        if verdict == 'drop':
            return None

        # Prefer the re-annotated question when one is present.
        question = row.get('question_fixed') or row['question']

        answers = row.get('answers')
        if isinstance(answers, list) and answers:
            gold = [text for text in (str(a).strip() for a in answers) if text]
        else:
            single = (row.get('answer', '') or '').strip()
            gold = [single] if single else []

        context_block = self._format_context(row.get('context', {}) or {})
        user_msg = f'Question: {question}\n\nContext:\n\n{context_block}'

        # [EXP-ORACLE] supporting-fact titles ride along in user_data so the
        # rollout can inject a block hint after compression.
        supporting = row.get('supporting_facts') or {}
        unique_titles = list(dict.fromkeys(t for t in (supporting.get('title') or []) if t))

        user_data = [('ground_truth', g) for g in gold]
        user_data += [('sf_title', t) for t in unique_titles]
        return Trajectory(
            messages=[
                Message(role='system', content=self.system),
                Message(role='user', content=user_msg),
            ],
            user_data=user_data)
def create_hotpotqa_dataset() -> Dataset:
    """Load ``DATASET_PATH`` and map it through ``HotpotQAProcessor``.

    The template is configured before mapping, and the listed raw columns are
    removed as part of the map call.
    """
    hotpot = Dataset()
    hotpot.add_dataset(DatasetMeta(DATASET_PATH))
    logger.info('[dataset] loaded %s: %d rows', DATASET_PATH, len(hotpot))

    template_options = {
        'model_id': MODEL_ID,
        'max_length': HOTPOTQA_MAX_LENGTH,
        'truncation_strategy': 'delete',
        'enable_thinking': False,
    }
    hotpot.set_template('Qwen3_5Template', **template_options)

    # Raw columns handed to ``remove_columns`` when mapping.
    raw_columns = [
        'id', 'question', 'question_fixed', 'answers',
        'original_answer', 'type', 'level', 'verdict',
        'reasoning', 'supporting_facts', 'context',
    ]
    processor = HotpotQAProcessor(system=SYSTEM_PROMPT)
    hotpot.map(processor, remove_columns=raw_columns)
    return hotpot
# Number pattern applied to the START of pre-formatted metric strings such as
# ``'0.03 iters/s'`` / ``'1.000000e-05'`` / ``'30 seconds'``.
_LEADING_NUMBER_RE = re.compile(r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?')


def _coerce_for_swanlab(log_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``log_dict`` with numeric-looking strings cast to float.

    swanlab cannot chart string values, so a string that *begins* with a
    number (after leading whitespace) is replaced by that number:
    ``'0.03 iters/s'`` -> ``0.03``, ``'30 seconds'`` -> ``30.0``.

    Strings that merely contain digits later on (``'ckpt-500'``,
    ``'epoch 3 of 10'``) are left untouched; previously the first digit run
    anywhere in the string was taken, so ``'ckpt-500'`` became ``-500.0``.

    Args:
        log_dict: metric name -> value.

    Returns:
        A new dict; ints, floats and bools pass through unchanged, and any
        value that is not a number-led string is kept as-is.
    """
    coerced: Dict[str, Any] = {}
    for key, value in log_dict.items():
        # bool is a subclass of int, so this also covers True / False.
        if isinstance(value, (int, float)):
            coerced[key] = value
            continue
        if isinstance(value, str):
            # ``match`` anchors at position 0; lstrip keeps ' 30 seconds' working.
            found = _LEADING_NUMBER_RE.match(value.lstrip())
            if found:
                try:
                    coerced[key] = float(found.group())
                    continue
                except ValueError:
                    # Fall through and keep the original string.
                    pass
        coerced[key] = value
    return coerced
def _last_assistant_text(trajectory: Dict[str, Any]) -> Optional[str]:
    """Return the visible text of the last ``assistant`` message, or ``None``.

    Only the most recent assistant message is inspected. Its ``content`` may be:
      * ``None``        -> ``None``
      * ``str``         -> returned unchanged
      * ``dict``        -> its ``'text'`` when ``type == 'text'``, else ``None``
      * ``list[dict]``  -> the ``'text'`` of every text part joined with
                           newlines (a missing text counts as ``''``), or
                           ``None`` when there is no text part
      * anything else   -> ``str(content)``

    A missing or ``None`` ``messages`` entry yields ``None`` instead of
    raising (``messages`` is set to ``None`` elsewhere in this file).
    """
    for message in reversed(trajectory.get('messages') or []):
        if message.get('role') != 'assistant':
            continue
        content = message.get('content')
        if content is None:
            return None
        if isinstance(content, str):
            return content
        if isinstance(content, dict):
            return content.get('text') if content.get('type') == 'text' else None
        if isinstance(content, list):
            # Non-text parts (images etc.) are ignored.
            texts = [part.get('text') or '' for part in content
                     if isinstance(part, dict) and part.get('type') == 'text']
            return '\n'.join(texts) if texts else None
        return str(content)
    return None
def _content_char_count(content: Any) -> int:
    """Count the text characters in a message ``content`` value.

    Handles ``str``, a single ``{'type': 'text', ...}`` dict, and a list of
    parts (text dicts and bare strings are counted, other parts ignored).
    Unknown shapes fall back to ``len(str(content))`` so that a template
    quirk never breaks metric logging.
    """
    if not content:
        return 0
    if isinstance(content, str):
        return len(content)
    if isinstance(content, dict):
        return len(content.get('text') or '') if content.get('type') == 'text' else 0
    if isinstance(content, list):
        total = 0
        for part in content:
            if isinstance(part, dict) and part.get('type') == 'text':
                total += len(part.get('text') or '')
            elif isinstance(part, str):
                total += len(part)
        return total
    return len(str(content))


def _compute_rollout_diagnostics(
    trajectories: List[Dict[str, Any]],
    n_turns_per_rollout: List[int],
    per_rollout_completion_length: List[int],
    f1_rewards: Optional[List[float]] = None,
    old_logps: Optional[List[List[float]]] = None,
) -> Dict[str, float]:
    """Aggregate per-batch rollout diagnostics for logging.

    Args:
        trajectories: flat trajectory dicts; ``messages`` and ``input_ids``
            are read from each.
        n_turns_per_rollout: turn count per trajectory.
        per_rollout_completion_length: trainable-token count per trajectory
            (labels != -100, as computed by the caller).
        f1_rewards: optional per-trajectory F1 reward.
        old_logps: optional per-trajectory list of rollout log-probs.

    Returns:
        A dict that may contain:
          * ``avg_turns`` (only when ``n_turns_per_rollout`` is non-empty)
          * ``non_trainable_tokens``: max over the batch of
            ``len(input_ids) - completion_length`` (clamped at 0)
          * ``avg_tool_calls``, ``tool_use_rate``, ``no_boxed_rate`` and the
            ``avg_chars_*`` keys (only when ``trajectories`` is non-empty)
          * ``f1_correct_rate`` / ``f1_zero_rate`` (when both ``f1_rewards``
            and ``old_logps`` are given and ``f1_rewards`` is non-empty)
          * ``mean_old_logp_f1_pos`` / ``policy_confidence_f1_pos`` and the
            ``..._f1_zero`` pair, each emitted ONLY when its group has at
            least one trajectory. An empty group used to report mean 0.0 and
            therefore confidence ``exp(0) == 1.0``, a misleading spike that
            appears on exactly the homogeneous batches.
    """
    out: Dict[str, float] = {}
    if n_turns_per_rollout:
        out['avg_turns'] = sum(n_turns_per_rollout) / len(n_turns_per_rollout)

    # Largest count of non-trainable tokens in the batch: tracks how much of
    # the context budget the prompt, condensed context and tool observations
    # consume.
    max_non_trainable = 0
    for traj, comp_len in zip(trajectories, per_rollout_completion_length):
        ids = traj.get('input_ids') or []
        max_non_trainable = max(max_non_trainable, len(ids) - int(comp_len or 0))
    out['non_trainable_tokens'] = max_non_trainable

    if trajectories:
        n_traj = len(trajectories)

        # ``messages`` may be None, hence ``or []`` throughout.
        tool_counts = [
            sum(len(m.get('tool_calls') or [])
                for m in (traj.get('messages') or []) if m.get('role') == 'assistant')
            for traj in trajectories]
        out['avg_tool_calls'] = sum(tool_counts) / n_traj
        out['tool_use_rate'] = sum(1 for c in tool_counts if c > 0) / n_traj

        missing_boxed = 0
        for traj in trajectories:
            # Guard falsy ``messages`` here so this function does not depend
            # on the helper tolerating None.
            last_text = _last_assistant_text(traj) if traj.get('messages') else None
            if not _BOXED_RE.search(last_text or ''):
                missing_boxed += 1
        out['no_boxed_rate'] = missing_boxed / n_traj

        # Character counts exclude the system prompt; user + tool messages
        # count as "prompt".
        total_chars, prompt_chars, assistant_chars = [], [], []
        for traj in trajectories:
            total_i = prompt_i = asst_i = 0
            for m in (traj.get('messages') or []):
                role = m.get('role')
                if role == 'system':
                    continue
                n_chars = _content_char_count(m.get('content'))
                total_i += n_chars
                if role in ('user', 'tool'):
                    prompt_i += n_chars
                elif role == 'assistant':
                    asst_i += n_chars
            total_chars.append(total_i)
            prompt_chars.append(prompt_i)
            assistant_chars.append(asst_i)
        out['avg_chars_total_no_sys'] = sum(total_chars) / n_traj
        out['avg_chars_prompt_no_sys'] = sum(prompt_chars) / n_traj
        out['avg_chars_assistant'] = sum(assistant_chars) / n_traj

    if f1_rewards is not None and old_logps is not None and f1_rewards:
        # A trajectory without logprobs contributes a mean of 0.0 (unchanged).
        per_traj_mean = [(sum(lp) / len(lp)) if lp else 0.0 for lp in old_logps]
        pos_logp = [m for m, f1 in zip(per_traj_mean, f1_rewards) if f1 > 0]
        zero_logp = [m for m, f1 in zip(per_traj_mean, f1_rewards) if f1 <= 0]
        out['f1_correct_rate'] = len(pos_logp) / len(f1_rewards)
        out['f1_zero_rate'] = len(zero_logp) / len(f1_rewards)
        if pos_logp:
            out['mean_old_logp_f1_pos'] = sum(pos_logp) / len(pos_logp)
        if zero_logp:
            out['mean_old_logp_f1_zero'] = sum(zero_logp) / len(zero_logp)
        if pos_logp:
            out['policy_confidence_f1_pos'] = math.exp(out['mean_old_logp_f1_pos'])
        if zero_logp:
            out['policy_confidence_f1_zero'] = math.exp(out['mean_old_logp_f1_zero'])
    return out
def _build_oracle_inputs(
    mb_inputs: List[Dict[str, Any]],
    f1_labels: List[bool],
    template,
) -> Optional[List[Dict[str, Any]]]:
    """Build oracle-context inputs at the token level for the per-token bonus.

    For every F1-positive sample the prompt is replaced by a short oracle
    prompt (system prompt + question + supporting titles / answer) while the
    response tokens are kept:

      1. Find ``first_trainable``, the first index with ``labels != -100``.
         (Per the original design notes, ``input_ids[first_trainable]`` is
         the last prefix token and ``labels[first_trainable]`` the first
         response target, because of the next-token shift.)
      2. Encode ``[system, oracle_user]`` with ``add_generation_prompt=True``.
      3. ``input_ids = oracle_prefix + input_ids[first_trainable + 1:]``.
      4. ``labels = [-100] * (len(oracle_prefix) - 1) + labels[first_trainable:]``
         so the last prefix position predicts the first response token.

    A sample is passed through unchanged (same object) when it is
    F1-negative, has no ``sf_title`` / ``ground_truth`` entries in
    ``user_data``, has empty ``labels`` / ``input_ids``, has no trainable
    position at all, or has no ``Question:`` line at the start of its first
    user message.

    Args:
        mb_inputs: encoded trajectories of one mini-batch.
        f1_labels: per-sample flag, True for F1-positive samples.
        template: object exposing ``encode(trajectory, add_generation_prompt=...)``
            returning a mapping with ``'input_ids'``.

    Returns:
        A list aligned with ``mb_inputs``, or ``None`` when no sample was
        rewritten.

    Raises:
        ValueError: if the spliced ``input_ids`` and ``labels`` differ in
            length, which happens when the source ``input_ids`` and
            ``labels`` are not the same length.
    """
    _q_line_re = re.compile(r'Question:\s*(.+?)(?:\n|$)', re.DOTALL)
    oracle_inputs = []
    any_modified = False

    for inp, is_pos in zip(mb_inputs, f1_labels):
        if not is_pos:
            oracle_inputs.append(inp)
            continue

        user_data = inp.get('user_data') or []
        sf_titles = [v for k, v in user_data if k == 'sf_title' and v]
        gts = [v for k, v in user_data if k == 'ground_truth' and v]
        if not sf_titles and not gts:
            oracle_inputs.append(inp)
            continue

        labels = inp.get('labels') or []
        input_ids = inp.get('input_ids') or []
        if not labels or not input_ids:
            oracle_inputs.append(inp)
            continue

        # 1. Find the first trainable position.
        first_trainable = next(
            (i for i, label in enumerate(labels) if label != -100), None)
        if first_trainable is None:
            # Every label is masked: there is no response to splice. Pass
            # the sample through like the other guards instead of relying on
            # an ``assert`` (stripped under ``python -O``).
            oracle_inputs.append(inp)
            continue

        # 2. Extract the question from the first user message only.
        question = None
        for m in (inp.get('messages') or []):
            if m.get('role') != 'user':
                continue
            c = m.get('content')
            if isinstance(c, str):
                text = c
            elif isinstance(c, list):
                text = next(
                    (p.get('text') for p in c
                     if isinstance(p, dict) and p.get('type') == 'text'), '')
            else:
                text = ''
            q_match = _q_line_re.match(text or '')
            if q_match:
                question = q_match.group(1).strip()
            break

        if not question:
            oracle_inputs.append(inp)
            continue

        # 3. Build the oracle user message (question + oracle hints only).
        hint_parts = []
        if sf_titles:
            hint_parts.append('Supporting passages: ' + ', '.join(f'"{t}"' for t in sf_titles))
        if gts:
            hint_parts.append('Answer: ' + '; '.join(gts))
        hint_parts.append('You must call `extract_condensed` to read the right original passage from the condensed block with thinking steps, and give the final correct answer')
        oracle_suffix = '\n[Oracle Context] ' + '. '.join(hint_parts) + '.'
        oracle_user_content = f'Question: {question}{oracle_suffix}'

        oracle_msgs = [
            Message(role='system', content=SYSTEM_PROMPT),
            Message(role='user', content=oracle_user_content),
        ]

        # 4. Encode the oracle prefix (ends with the assistant header).
        oracle_feature = template.encode(
            Trajectory(messages=oracle_msgs), add_generation_prompt=True)
        oracle_prefix_ids = list(oracle_feature['input_ids'])

        # 5. Splice: oracle prefix + response tokens.
        response_tokens = list(input_ids[first_trainable + 1:])
        response_labels = list(labels[first_trainable:])

        oracle_input_ids = oracle_prefix_ids + response_tokens
        # The last position of the oracle prefix predicts the first response token.
        oracle_labels = [-100] * (len(oracle_prefix_ids) - 1) + response_labels

        if len(oracle_input_ids) != len(oracle_labels):
            raise ValueError(
                'oracle splice produced misaligned sequences: '
                f'{len(oracle_input_ids)} input_ids vs {len(oracle_labels)} labels '
                f'(source had {len(input_ids)} input_ids and {len(labels)} labels)')
        seq_len = len(oracle_input_ids)

        # Start from the original keys to keep a collator-compatible shape.
        oi = dict(inp)
        oi['input_ids'] = oracle_input_ids
        oi['labels'] = oracle_labels
        oi['attention_mask'] = [1] * seq_len
        oi['messages'] = None
        oi['length'] = seq_len
        # Mirror a 3-D position_ids layout when the original input has one.
        orig_pos = inp.get('position_ids')
        if isinstance(orig_pos, torch.Tensor) and orig_pos.dim() == 3:
            n_dims = orig_pos.shape[0]
            pos_range = torch.arange(seq_len).unsqueeze(0).unsqueeze(0)
            oi['position_ids'] = pos_range.expand(n_dims, 1, seq_len)
        else:
            oi['position_ids'] = list(range(seq_len))
        if 'mm_token_type_ids' in inp:
            oi['mm_token_type_ids'] = torch.zeros(1, seq_len)
        oracle_inputs.append(oi)
        any_modified = True

    return oracle_inputs if any_modified else None
def _compute_token_bonus(
    oracle_logps: Any,
    old_logps: List[List[float]],
    f1_labels: List[bool],
    oracle_inputs: List[Dict[str, Any]],
) -> List[List[float]]:
    """Compute the per-token bonus ``oracle_logp - rollout_logp``.

    For each sample the oracle log-probs are filtered down to the positions
    where ``oracle_inputs[i]['labels'] != -100`` and then subtracted
    element-wise from the rollout log-probs in ``old_logps[i]``.

    Args:
        oracle_logps: a 2-D tensor (one row per sample, possibly padded) or a
            sequence of per-sample rows (tensor / list / tuple).
        old_logps: per-sample rollout log-probs of the response tokens.
        f1_labels: per-sample flag; F1-negative samples get an all-zero bonus.
        oracle_inputs: per-sample dicts whose ``'labels'`` mark the response
            positions.

    Returns:
        One list per sample, the same length as the matching ``old_logps``
        entry (empty when that entry is empty or ``None``).

    Raises:
        ValueError: if the number of unmasked oracle positions differs from
            ``len(old_logps[i])`` for an F1-positive sample. This is an
            explicit check rather than an ``assert`` so it still fires under
            ``python -O``.
    """
    import torch

    if isinstance(oracle_logps, torch.Tensor):
        oracle_logps = oracle_logps.float().cpu()

    bonus = []
    for i, (is_pos, old_lp) in enumerate(zip(f1_labels, old_logps)):
        if not is_pos or not old_lp:
            # Zero bonus for negatives; the conditional also covers old_lp=None.
            bonus.append([0.0] * len(old_lp) if old_lp else [])
            continue

        expected = len(old_lp)
        oracle_labels = oracle_inputs[i].get('labels') or []

        # Keep only the trainable (unmasked) positions of the oracle row.
        if isinstance(oracle_logps, torch.Tensor):
            orc_row = oracle_logps[i]
            mask = torch.tensor([label != -100 for label in oracle_labels], dtype=torch.bool)
            # The row may be padded past the label length, or shorter than it.
            seq_len = min(len(mask), orc_row.numel())
            orc_valid = orc_row[:seq_len][mask[:seq_len]].tolist()
        else:
            orc_row = oracle_logps[i] if i < len(oracle_logps) else []
            if isinstance(orc_row, torch.Tensor):
                orc_row = orc_row.float().cpu().tolist()
            elif not isinstance(orc_row, (list, tuple)):
                orc_row = []
            orc_valid = [v for v, label in zip(orc_row, oracle_labels) if label != -100]

        if len(orc_valid) != expected:
            raise ValueError(
                f'sample {i}: {len(orc_valid)} oracle log-probs at trainable positions '
                f'but {expected} rollout log-probs; oracle inputs are misaligned')
        bonus.append([o - r for o, r in zip(orc_valid, old_lp)])
    return bonus
DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, + groups=device_groups, lazy_collect=False) + + logger.info('Building HotpotQA dataset') + _prebuilt_dataset = create_hotpotqa_dataset() + logger.info('Dataset ready: %d rows', len(_prebuilt_dataset)) + + GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS + batches_per_epoch = max(1, len(_prebuilt_dataset) // GLOBAL_BATCH_SIZE) + optim_steps_per_batch = max(1, (GLOBAL_BATCH_SIZE * NUM_GENERATIONS + + MINI_BATCH_SIZE - 1) // MINI_BATCH_SIZE) + steps_per_epoch = batches_per_epoch * optim_steps_per_batch + derived_total_steps = NUM_EPOCHS * steps_per_epoch + total_steps = min(MAX_STEPS, derived_total_steps) if MAX_STEPS > 0 else derived_total_steps + logger.info('Training horizon: %d steps (%d epochs × %d batches × %d steps/batch)', + total_steps, NUM_EPOCHS, batches_per_epoch, optim_steps_per_batch) + + lora_config = LoraConfig( + target_modules='all-linear', r=LORA_RANK, + lora_alpha=LORA_RANK * 2, lora_dropout=0.05) + + if USE_MEGATRON: + from twinkle.model.megatron import MegatronModel + model = MegatronModel( + model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model', + mixed_precision='bf16', variable_seq_lengths=True) + else: + model = TransformersModel( + model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model') + + model.add_adapter_to_model(ADAPTER_NAME, lora_config, + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + if INIT_LORA_PATH: + model.load(INIT_LORA_PATH, adapter_name=ADAPTER_NAME) + logger.info('Loaded cold-start LoRA from %s', INIT_LORA_PATH) + if USE_MEGATRON: + model.set_optimizer('default', lr=LEARNING_RATE) + model.set_lr_scheduler('default', lr_decay_steps=total_steps, max_lr=LEARNING_RATE) + else: + model.set_optimizer('AdamW', lr=LEARNING_RATE) + model.set_lr_scheduler('CosineAnnealingLR', 
T_max=total_steps, eta_min=0) + + model.set_loss('GRPOLoss', epsilon=CISPO_EPS_LOW, epsilon_high=CISPO_EPS_HIGH, + beta=KL_BETA, entropy_coef=ENTROPY_COEF, token_bonus_coef=ORACLE_BONUS_COEF) + model.set_processor(InputProcessor, padding_free=True) + model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False, max_length=HOTPOTQA_MAX_LENGTH) + + model.add_metric('GRPOMetric', is_training=True, + epsilon=CISPO_EPS_LOW, epsilon_high=CISPO_EPS_HIGH, + top_k_kl=HIGH_KL_TOPK) + + sampler = vLLMSampler( + model_id=MODEL_ID, + engine_args={ + 'gpu_memory_utilization': 0.8, 'max_model_len': 32768, + 'max_lora_rank': 32, 'enable_lora': True, + 'enable_tower_connector_lora': True, + 'max_loras': 5 + }, + device_mesh=sampler_mesh, remote_group='sampler') + sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False, max_length=HOTPOTQA_MAX_LENGTH) + rollout_template = Qwen3_5Template( + MODEL_ID, max_length=HOTPOTQA_MAX_LENGTH, enable_thinking=False) + + ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) + chunker = NativeChunker( + chunk_size=CHUNK_SIZE, + passage_boundary_re=r'(?<=\n\n)', + ) + # ``\A`` anchor: prevents a ``Question:`` line inside a passage from being misread as the query. 
+ _question_re = re.compile(r'\AQuestion:\s*(.+)') + + def _extract_question(chunk): + content = chunk.get('content') + if chunk.get('type') != 'text' or not isinstance(content, str): + return None + m = _question_re.search(content) + return m.group(1).strip() if m else None + + condenser = ModelCondenser( + sampler=sampler, + compression_ratio=2.0, + sampling_params=SamplingParams( + max_tokens=1024, num_samples=1, temperature=0.4, top_p=0.9), + min_chars=200, + template=rollout_template, + lora_path='ms://twinkle-kit/Qwen3.5-4B-Condenser', + skip_pattern=r'^Question:', + related_query=_extract_question, + ) + + dataloader = DataLoader( + dataset=lambda: _prebuilt_dataset, + batch_size=GLOBAL_BATCH_SIZE, min_batch_size=GLOBAL_BATCH_SIZE) + + advantage_fn = GRPOAdvantage() + metrics = CompletionRewardMetric() + sampling_params = SamplingParams( + max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, + temperature=1.0, top_p=0.95, + stop=['']) + + def _trace_should_store(traj): + return _F1_REWARD([traj])[0] == 0.0 + + def _trace_is_success(traj): + return _F1_REWARD([traj])[0] > 0.0 + + rollout = MultiTurnCondenseRollout( + sampler=sampler, + template=rollout_template, + tool_manager=ToolManager(), + chunker=chunker, + condenser=condenser, + sampling_params=sampling_params, + max_turns=MAX_TURNS, + max_trajectory_tokens=MAX_TRAJECTORY_TOKENS, + trace_dir=_ROLLOUT_TRACE_DIR or None, + trace_callback=_trace_should_store, + success_callback=_trace_is_success, + post_compress_callback=( + _make_oracle_hint_callback(total_steps) if ORACLE_HINT else None), + ) + + optim_step = 0 + logger.info('Starting HotpotQA GRPO training (LLM condenser variant)') + + def _epoch_cycle(dl, n_epochs): + for ep in range(1, n_epochs + 1): + logger.info(f'=== Epoch {ep}/{n_epochs} (step={optim_step}/{total_steps}) ===') + for batch in dl: + yield batch + + for batch in _epoch_cycle(dataloader, NUM_EPOCHS): + if optim_step >= total_steps: + break + + # Single source of truth for the step 
shown in swanlab / logger / rollout-trace filename. + # Equals the number of optimizer updates already completed when this rollout was sampled. + batch_step = optim_step + + metrics.reset() + expand_prompts = [p for prompt in batch for p in [prompt] * NUM_GENERATIONS] + + ckpt_manager.sync_weights(merge_and_sync=False) + sampler.reset_prefix_cache() + + # Batched multi-turn rollout with chunk+condense pre-processing. + # Each returned trajectory is a flat dict containing ``messages``, + # ``input_ids``, ``labels``, ``attention_mask``, ``position_ids``, + # ``turns``, ``logprobs``, ``stop_reason``, ``truncated``. + all_trajectories: List[Dict[str, Any]] = rollout(expand_prompts, global_step=batch_step) + n_turns_per_rollout = [int(t.get('turns') or 0) for t in all_trajectories] + per_rollout_completion_length = [ + sum(1 for l in (t.get('labels') or []) if l != -100) + for t in all_trajectories] + + total_rewards, f1_rewards, cot_rewards, tool_explore_rewards = \ + compute_rewards(all_trajectories) + + rollout_advantages = advantage_fn( + total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist() + + all_f1_labels: List[bool] = [f > 0 for f in f1_rewards] + n_pos = sum(1 for p in all_f1_labels if p) + n_neg = sum(1 for p in all_f1_labels if not p) + pos_with_neg_adv = sum(1 for p, a in zip(all_f1_labels, rollout_advantages) if p and a < 0) + neg_with_pos_adv = sum(1 for p, a in zip(all_f1_labels, rollout_advantages) if not p and a > 0) + + # Skip homogeneous groups where gradient signal is meaningless + f1_pos_rate = n_pos / len(f1_rewards) if f1_rewards else 0.5 + if f1_pos_rate > 0.9 or f1_pos_rate < 0.1: + logger.info('[skip-homogeneous] f1_pos_rate=%.3f, skipping training update', f1_pos_rate) + metrics.accumulate( + completion_lengths=per_rollout_completion_length, + rewards={'total': total_rewards, 'f1': f1_rewards, + 'cot': cot_rewards, 'tool_explore': tool_explore_rewards}) + log_dict = metrics.calculate() + 
log_dict.update(_compute_rollout_diagnostics( + all_trajectories, n_turns_per_rollout, per_rollout_completion_length, + f1_rewards=f1_rewards, old_logps=[[lp[0][1] for lp in (t.get('logprobs') or [])] for t in all_trajectories])) + log_dict['skipped'] = True + log_dict['pos_neg_adv_rate'] = pos_with_neg_adv / n_pos if n_pos else 0.0 + log_dict['neg_pos_adv_rate'] = neg_with_pos_adv / n_neg if n_neg else 0.0 + log_dict['adv_max'] = max(rollout_advantages) if rollout_advantages else 0.0 + log_dict['adv_min'] = min(rollout_advantages) if rollout_advantages else 0.0 + swanlab.log(_coerce_for_swanlab(log_dict), step=batch_step) + metrics.reset() + logger.info(f'[Step {batch_step}/{total_steps}] [SKIPPED] {log_dict}') + optim_step += optim_steps_per_batch + continue + + metrics.accumulate( + completion_lengths=per_rollout_completion_length, + rewards={'total': total_rewards, 'f1': f1_rewards, + 'cot': cot_rewards, 'tool_explore': tool_explore_rewards}) + + all_input_data: List[Any] = [] + all_old_logps: List[List[float]] = [] + advantages: List[float] = [] + for t, adv in zip(all_trajectories, rollout_advantages): + all_input_data.append(t) + all_old_logps.append([lp[0][1] for lp in (t.get('logprobs') or [])]) + advantages.append(adv) + + total_completions = len(all_input_data) + aligned_completions = (total_completions // MODEL_GPUS) * MODEL_GPUS + if aligned_completions < total_completions: + logger.info( + '[dp-align] dropping %d tail sample(s): total=%d -> aligned=%d (dp=%d)', + total_completions - aligned_completions, + total_completions, aligned_completions, MODEL_GPUS) + for mb_start in range(0, aligned_completions, MINI_BATCH_SIZE): + mb_end = min(mb_start + MINI_BATCH_SIZE, aligned_completions) + mb_inputs = all_input_data[mb_start:mb_end] + # Reference log-probs for KL: same policy model with LoRA adapter disabled (= base model). + # Skipped when KL_BETA == 0 to save one extra forward per mini-batch. 
+ ref_logps = None + if KL_BETA > 0.0: + ref_outputs = model.forward_only(inputs=mb_inputs, disable_lora=True) + ref_logps = ref_outputs.get('logps') if isinstance(ref_outputs, dict) else getattr(ref_outputs, 'logps', None) + # [EXP-ORACLE] per-token bonus: forward with oracle context, diff against rollout logps + mb_token_bonus = None + if ORACLE_BONUS_COEF > 0.0: + mb_oracle_inputs = _build_oracle_inputs( + mb_inputs, all_f1_labels[mb_start:mb_end], rollout_template) + if mb_oracle_inputs is not None: + oracle_outputs = model.forward_only(inputs=mb_oracle_inputs) + oracle_logps = oracle_outputs.get('logps') if isinstance(oracle_outputs, dict) else getattr(oracle_outputs, 'logps', None) + if oracle_logps is not None: + mb_token_bonus = _compute_token_bonus( + oracle_logps, all_old_logps[mb_start:mb_end], + all_f1_labels[mb_start:mb_end], mb_oracle_inputs) + model.forward_backward( + inputs=mb_inputs, + old_logps=all_old_logps[mb_start:mb_end], + advantages=advantages[mb_start:mb_end], + ref_logps=ref_logps, + token_bonus=mb_token_bonus, + positive_mask=all_f1_labels[mb_start:mb_end], + micro_batch_size=MICRO_BATCH_SIZE) + model.clip_grad_and_step() + optim_step += 1 + if optim_step >= total_steps: + break + if optim_step % SAVE_STEPS == 0: + model.save(f'hotpotqa-grpo-tools-llmcondense-checkpoint-{optim_step}') + + log_dict = metrics.calculate() + log_dict.update(model.calculate_metric(is_training=True)) + log_dict.update(_compute_rollout_diagnostics( + all_trajectories, n_turns_per_rollout, per_rollout_completion_length, + f1_rewards=f1_rewards, old_logps=all_old_logps)) + log_dict['pos_neg_adv_rate'] = pos_with_neg_adv / n_pos if n_pos else 0.0 + log_dict['neg_pos_adv_rate'] = neg_with_pos_adv / n_neg if n_neg else 0.0 + log_dict['adv_max'] = max(rollout_advantages) if rollout_advantages else 0.0 + log_dict['adv_min'] = min(rollout_advantages) if rollout_advantages else 0.0 + # Pop high-KL token records before swanlab.log: list-of-dict won't render as a chart. 
+ _hk = log_dict.pop('_high_kl_records', None) + if _hk: + _tok = rollout_template.tokenizer + for r in _hk: + gsi = r.get('gsi') + tid = all_trajectories[gsi].get('id') if gsi is not None and 0 <= gsi < len(all_trajectories) else None + try: + tok_text = _tok.decode([r['token_id']]) + except Exception: + tok_text = None + logger.info( + '[high-kl] step=%d gsi=%s tid=%s pos=%s tok=%r kl=%.4f r=%.4f lp_new=%.4f lp_old=%.4f', + batch_step, gsi, tid, r.get('pos'), tok_text, + r.get('kl'), r.get('ratio'), r.get('logp_new'), r.get('logp_old')) + swanlab.log(_coerce_for_swanlab(log_dict), step=batch_step) + metrics.reset() + logger.info(f'[Step {batch_step}/{total_steps}] {log_dict}') + + logger.info(f'Training completed. optim_steps={optim_step}') + model.save('hotpotqa-grpo-tools-llmcondense-final') + + +if __name__ == '__main__': + main() diff --git a/cookbook/exp/make_condensed_sft.py b/cookbook/exp/make_condensed_sft.py new file mode 100644 index 00000000..3b9855ac --- /dev/null +++ b/cookbook/exp/make_condensed_sft.py @@ -0,0 +1,945 @@ +"""Cold-start SFT dataset builder for the condensed multi-hop QA task. + +Pipeline per HotpotQA distractor row: + 1. Build the standard system + user-with-context trajectory using the + production ``SYSTEM_PROMPT`` and ``_format_context`` from + ``cookbook/rl/grpo_condensed.py`` so the offline data matches what + the policy sees at training/inference time. + 2. Run the production ``NativeChunker`` + ``ModelCondenser`` on the + row to produce ``...`` compressed text. + 3. **Validation pass** (super-LLM, ``enable_thinking=True``, no oracle, + no tools): judge whether the question / supporting_facts / GT are + well-formed against the raw passages; return strict JSON + ``{"verdict": "ok"|"fix"|"drop", ...}`` with fixed SF + GT when + applicable. ``drop`` skips the row. + 4. **Oracle rollout pass** via :class:`APIMultiTurnRollout` with a + trajectory-bound :class:`ExtractCondensed` tool. 
The oracle hint + (SF titles + GT) is injected into the system prompt **only for + the API call**; it is stripped before saving. The model emits + OpenAI-shape ``tool_calls`` for ``extract_condensed``, the rollout + dispatches them through :class:`ToolManager` and feeds back the + pre-compression passage text as a ``tool`` message, looping until + the model finalises with ``\\boxed{...}`` or hits ``MAX_TURNS``. + 5. Accept iff F1(boxed, used_gt) >= ``F1_ACCEPT_THRESHOLD``. On miss, + retry once with a higher temperature. + 6. Convert OpenAI-shape ``tool_calls`` into the textual + ``N`` + format consumed by the training chat template (mirrors + ``grpo_condensed.SYSTEM_PROMPT`` L232-239), restore the clean + system prompt, and emit one JSONL line. + +Run:: + + python cookbook/rl/make_condensed_sft.py \\ + --output hotpotqa_sft_coldstart.jsonl \\ + --model --api-key $KEY --base-url $URL \\ + --total 9000 --easy 1500 --medium 3000 --hard 4500 \\ + --concurrency 16 --seed 42 \\ + --condenser-model-id ms://Qwen/Qwen3.5-4B \\ + --condenser-lora ms://twinkle-kit/Qwen3.5-4B-Condenser +""" +from __future__ import annotations + +import argparse +import json +import os +import random +import re +import sys +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from twinkle.data_format.sampling import SamplingParams +from twinkle.sampler import vLLMSampler +from twinkle.template import Qwen3_5Template +from twinkle_agentic.chunker.native import NativeChunker +from twinkle_agentic.condenser import ModelCondenser +from twinkle_agentic.data_format import Chunks +from twinkle_agentic.protocol.openai import OpenAI +from twinkle_agentic.reward.f1 import _extract_final_answer, _f1_score +from twinkle_agentic.rollout import APIMultiTurnRollout +from twinkle_agentic.tools.extract_condensed import ExtractCondensed +from twinkle_agentic.tools.tool_manager import ToolManager + + +# 
-------------------------------------------------------------------------- +# Constants mirrored from grpo_condensed.py so the SFT data matches the +# runtime contract byte-for-byte. Re-import would pull the whole training +# module; copying these few strings keeps the builder standalone. +# -------------------------------------------------------------------------- +SYSTEM_PROMPT = """You are a careful multi-hop QA assistant. + +## Context Format (Mixed) +The context you receive is a **mix of two forms**: + +1. **Compressed blocks** — long passages wrapped in `...`, \ +displayed as a Markdown digest in **telegraphic style** (no \ +articles / "is" / "are"; colons and commas mean "is" / "has") \ +with two sections: + - **Summary**: overview plus facts strongly related to the question, stated explicitly. + - **More**: a collapsed INDEX of category keywords hinting at extra details hidden in the full text (call `extract_condensed` to see them). + Reading example: `India: 7th largest by area. Borders: Pakistan, \ +China.` means "India is the 7th largest country by area and \ +shares borders with Pakistan and China." +2. **Raw passages** — short passages shown inline as plain text (`Title: \ +body`) **without** any `` wrapping. These are already the full \ +text; nothing is hidden. + +Only the ``-wrapped blocks are compressed and can be expanded. \ +Block ids `N` are 1-based and assigned in the order compressed blocks \ +appear in the context, so they are always contiguous (``, \ +``, ``, ...). Raw passages have no block id and cannot \ +be extracted — they are already complete. + +## Workflow + +### Phase 1 — Scan and Decide +Step 1: Read each compressed block's Summary, and read raw \ +passages directly, to get an overview. +Step 2: For compressed blocks, check the More keywords to judge whether \ +hidden details are needed. +Step 3: Decide which compressed blocks to expand, then call \ +`extract_condensed` with their block ids. Raw passages need no extraction. 
+ +### Phase 2 — Reason and Answer +After the tool returns the full text, continue stepping through the evidence: +Step N: From block X (or the raw passage titled "..."), I learn that [fact A]. +Step N+1: From block Y, I need to call `extract_condensed` to get more information, because this block is related to... +Step N+2: Combining these, the answer is ... +\\boxed{answer} + +You may call `extract_condensed` several times to expand more blocks if the information is not enough, only answer the question if you are sure about the facts. +The `blocks` parameter accepts **exactly one integer** per call (e.g. `3`); lists are rejected. Expand additional blocks by issuing separate `extract_condensed` calls, one per block. Only pass ids that actually appear as `` in the context, and do **not** request the same block twice — its text is already in the conversation after the first expansion. + +## Tool Call Format + + + +3 + + + + +## Output Format +End your final response with \\boxed{answer}, e.g. \\boxed{Delhi}. +Keep the boxed text short: a name, entity, date, or "yes"/"no". +Answers not inside \\boxed{} will not be scored.""" + + +# Oracle suffix appended ONLY for API generation; stripped before save. +_ORACLE_HINT_TEMPLATE = ( + '\n\n## Oracle hint (PRIVATE — do NOT quote verbatim)\n' + 'The following supporting-fact titles and ground-truth answer are ' + 'provided to make your final answer reliable. 
Use them as a signpost ' + 'while you reason from the context; your final `\\boxed{{...}}` MUST ' + 'paraphrase the ground truth using evidence from the blocks (after ' + 'expanding compressed blocks when needed), not just echo it.\n' + 'Supporting facts (titles): {sf}\n' + 'Ground truth: {gt}\n' + 'You MUST still call `extract_condensed` on EVERY compressed block ' + 'whose Summary or More keywords touch any supporting-fact title, even ' + 'if the Summary already seems to state the answer — the compressed ' + 'Summary occasionally loses pronoun referents or attribution and the ' + 'raw passage is the authoritative source.' +) + + +VALIDATION_SYSTEM = ( + 'You are a HotpotQA annotation auditor. Read the raw passages, the ' + 'question, the supplied supporting-fact titles and the supplied ' + 'ground-truth answer. Decide whether this row is usable for training ' + 'a multi-hop QA model.\n\n' + 'Pathologies to catch (drop or fix):\n' + ' - question template leakage: the question literally contains the ' + 'answer, references a passage id, or is malformed;\n' + ' - subject/answer mismatch: the GT does not actually answer the ' + 'question given the passages (e.g. the question asks about an event ' + 'X but GT is from a sibling event Y);\n' + ' - GT entity not present in any passage AND not directly inferable ' + 'by a 2-hop bridge from the passages;\n' + ' - supporting-fact titles obviously incomplete for a 2-hop question.\n' + '\n' + 'Return STRICT JSON ONLY (no markdown fence, no preamble) with this ' + 'exact shape:\n' + ' {"verdict": "ok"|"fix"|"drop", "reason": "", ' + '"fixed_supporting_facts": ["", ...], ' + '"fixed_ground_truth": "<short answer>"}\n' + 'Use verdict "ok" when the supplied SF + GT are correct (then ' + '"fixed_supporting_facts" and "fixed_ground_truth" MAY be empty). 
' + 'Use verdict "fix" when the question is answerable but SF or GT are ' + 'wrong/incomplete -- fill the fixed fields with the corrected values, ' + 'titles drawn verbatim from the passage titles below. Use verdict ' + '"drop" when the question itself is invalid or unanswerable from the ' + 'given passages.' +) + + +VALIDATION_USER_TEMPLATE = ( + 'Question: {question}\n' + '\n' + 'Supplied supporting-fact titles: {sf}\n' + 'Supplied ground truth: {gt}\n' + '\n' + 'Passage titles (verbatim):\n{titles}\n' + '\n' + 'Passages (raw, uncompressed):\n\n{passages}' +) + + +# JSON Schema for the OpenAI API; the in-process ExtractCondensed tool's +# tool_info() emits a free-form description that the OpenAI SDK rejects. +EXTRACT_CONDENSED_TOOL: Dict[str, Any] = { + 'type': 'function', + 'function': { + 'name': 'extract_condensed', + 'description': ( + 'Recover the full, uncompressed text of ONE previously ' + 'condensed passage, identified by its <block_N> tag. Use ' + 'this tool whenever you need to re-read the original detail ' + 'of a compressed block. Each call expands exactly one block; ' + 'issue separate calls for additional blocks, and do not ' + 'request the same block twice.'), + 'parameters': { + 'type': 'object', + 'properties': { + 'blocks': { + 'type': 'integer', + 'description': ( + 'The 1-indexed block number N appearing inside ' + '<block_N>...</block_N>. Exactly one block per ' + 'call (e.g. 3); lists are rejected.'), + }, + }, + 'required': ['blocks'], + }, + }, +} + + +F1_ACCEPT_THRESHOLD: float = 0.5 +ROLLOUT_MAX_TURNS: int = 8 +ROLLOUT_MAX_TOKENS: int = 2048 +VALIDATION_MAX_TOKENS: int = 1024 +ROLLOUT_TEMPERATURE_LADDER: Tuple[float, ...] = (0.4, 0.7) + + +# -------------------------------------------------------------------------- +# Trajectory + chunk helpers (mirror HotpotQAProcessor + production prompt). 
# --------------------------------------------------------------------------
def _format_passage(title: str, sentences: Any) -> str:
    """Render one passage as ``'<title>: <body>'``.

    A list of sentences is stripped, filtered of blank entries and joined
    with single spaces; any other value is stringified and stripped.
    """
    if isinstance(sentences, list):
        body = ' '.join(s.strip() for s in sentences if s and s.strip())
    else:
        body = str(sentences).strip()
    return f'{title}: {body}'


def _format_context(titles: List[str], sentences_list: List[Any]) -> str:
    """Join the formatted passages with one blank line between them.

    ``zip`` pairs titles with sentence lists, so surplus entries on either
    side are silently ignored.
    """
    return '\n\n'.join(
        _format_passage(t, s) for t, s in zip(titles, sentences_list))


def _build_initial_trajectory(row: Dict[str, Any]) -> Dict[str, Any]:
    """Build the pre-compression trajectory dict the chunker expects.

    The result holds a system message (``SYSTEM_PROMPT``) and one user
    message of the form ``Question: ...`` followed by the formatted context.
    Raises ``KeyError`` when ``row`` has no ``'question'`` key.
    """
    ctx = row.get('context') or {}
    titles = list(ctx.get('title') or [])
    sentences_list = list(ctx.get('sentences') or [])
    user_msg = (
        f"Question: {row['question']}\n\n"
        f'Context:\n\n{_format_context(titles, sentences_list)}')
    return {
        'messages': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': user_msg},
        ],
    }


# ``\A`` anchors at the very start of the chunk, so a ``Question:`` line that
# appears later inside a passage is never picked up as the query.
_QUESTION_RE = re.compile(r'\AQuestion:\s*(.+)')


def _extract_question_from_chunk(chunk):
    """Return the question text of a text chunk that starts with ``Question:``.

    Returns ``None`` for non-text chunks, non-string content, or content that
    does not begin with the ``Question:`` prefix. Only the first line is
    captured because ``.`` does not match newlines.
    """
    content = chunk.get('content')
    if chunk.get('type') != 'text' or not isinstance(content, str):
        return None
    m = _QUESTION_RE.search(content)
    return m.group(1).strip() if m else None


# --------------------------------------------------------------------------
# Per-batch compression (re-use MultiTurnCondenseRollout's batching trick:
# merge all per-row chunks into ONE Chunks so the sampler sees a packed batch).
# --------------------------------------------------------------------------
def compress_rows(
    rows: List[Dict[str, Any]],
    chunker: NativeChunker,
    condenser: ModelCondenser,
) -> List[Tuple[Dict[str, Any], Chunks]]:
    """Return ``[(compressed_trajectory_dict, per_row_Chunks), ...]``.

    Every row is chunked on its own, all chunks are concatenated into a
    single ``Chunks`` for one condenser call, and the condensed result is
    sliced back per row using the recorded cumulative boundaries.

    Per the original author's notes, the trajectory produced by
    ``Chunks.to_trajectory`` already carries the block wrapping in its user
    message, and the per-row ``Chunks`` keep ``raw.original`` snapshots so
    ``ExtractCondensed`` can return the pre-compression text.
    """
    if not rows:
        return []
    initial = [_build_initial_trajectory(r) for r in rows]
    per_row_chunks = [chunker(t) for t in initial]
    merged_list: List[Any] = []
    boundaries: List[int] = []
    for ck in per_row_chunks:
        merged_list.extend(ck.chunks)
        # Cumulative end offset of this row inside the merged list.
        boundaries.append(len(merged_list))
    merged = condenser(Chunks(chunks=merged_list))
    out: List[Tuple[Dict[str, Any], Chunks]] = []
    start = 0
    # NOTE(review): slicing by the pre-condense boundaries assumes the
    # condenser keeps chunk count and order unchanged -- confirm.
    for end in boundaries:
        slc = Chunks(chunks=list(merged.chunks[start:end]))
        out.append((slc.to_trajectory(), slc))
        start = end
    return out


# --------------------------------------------------------------------------
# Stage 1: validation pass.
# --------------------------------------------------------------------------
_JSON_FENCE_RE = re.compile(r'```(?:json)?\s*\n(.*?)\n```', re.DOTALL)


def _extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Best-effort JSON parse: strip a Markdown fence, then decode the first object.

    Every ``{`` in the candidate text is tried in order with
    ``json.JSONDecoder.raw_decode`` and the first position that decodes to a
    dict wins. Using the real decoder (instead of counting braces) keeps
    braces inside JSON strings and stray ``}`` characters in a preamble from
    hiding an otherwise valid object.

    Returns ``None`` for empty input or when no object can be decoded.
    """
    if not text:
        return None
    candidate = text.strip()
    fence = _JSON_FENCE_RE.search(candidate)
    if fence:
        candidate = fence.group(1).strip()
    decoder = json.JSONDecoder()
    pos = candidate.find('{')
    while pos != -1:
        try:
            obj, _end = decoder.raw_decode(candidate, pos)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return obj
        pos = candidate.find('{', pos + 1)
    return None


def validate_row(
    api: OpenAI, row: Dict[str, Any], original_gt: List[str], sf_titles: List[str],
) -> Optional[Dict[str, Any]]:
    """Return parsed JSON verdict, or ``None`` on unrecoverable parse failure.

    The auditor prompt is built from the row's question, the supplied
    supporting-fact titles / ground truth and the raw passages. Up to two
    API calls are made: a reply without a recognised ``verdict`` triggers the
    second attempt, whereas an API exception is logged to stderr and ends
    the validation immediately with ``None``.
    """
    ctx = row.get('context') or {}
    titles = list(ctx.get('title') or [])
    sentences_list = list(ctx.get('sentences') or [])
    passages = _format_context(titles, sentences_list)
    user = VALIDATION_USER_TEMPLATE.format(
        question=row['question'],
        sf=json.dumps(sf_titles, ensure_ascii=False),
        gt=json.dumps(original_gt, ensure_ascii=False),
        titles='\n'.join(f'- {t}' for t in titles),
        passages=passages,
    )
    trajectory = {
        'messages': [
            {'role': 'system', 'content': VALIDATION_SYSTEM},
            {'role': 'user', 'content': user},
        ],
    }
    sp = SamplingParams(
        temperature=0.0, max_tokens=VALIDATION_MAX_TOKENS, num_samples=1)
    for attempt in range(2):
        try:
            reply = api(
                trajectory, sp, extra_body={'enable_thinking': True})
        except Exception as exc:
            sys.stderr.write(f'[validate] row={row.get("id")} attempt={attempt} api error: {exc}\n')
            return None
        content = reply.get('content') or ''
        parsed = _extract_json_object(content)
        if parsed and parsed.get('verdict') in ('ok', 'fix', 'drop'):
            return parsed
    return None


def resolve_validation(
    verdict: Dict[str, Any], original_gt: List[str], sf_titles: List[str],
) -> Tuple[List[str], List[str]]:
    """Pick the ``(ground_truths, supporting_fact_titles)`` to use downstream.

    Only a ``'fix'`` verdict overrides the originals. Its
    ``fixed_ground_truth`` may be a string or a list of strings and its
    ``fixed_supporting_facts`` must be a list; any fixed field that is
    missing, malformed or blank falls back to the corresponding original.
    """
    v = verdict.get('verdict')
    if v == 'fix':
        fixed_gt = verdict.get('fixed_ground_truth') or ''
        fixed_sf = verdict.get('fixed_supporting_facts') or []
        gt_list: List[str] = []
        if isinstance(fixed_gt, list):
            gt_list = [str(x).strip() for x in fixed_gt if str(x).strip()]
        elif isinstance(fixed_gt, str) and fixed_gt.strip():
            gt_list = [fixed_gt.strip()]
        if not gt_list:
            gt_list = original_gt
        sf_list = (
            [str(x).strip() for x in fixed_sf if str(x).strip()]
            if isinstance(fixed_sf, list) else sf_titles)
        if not sf_list:
            sf_list = sf_titles
        return gt_list, sf_list
    return original_gt, sf_titles


#
+# -------------------------------------------------------------------------- +def _oracle_system_prompt(sf_titles: List[str], gt_list: List[str]) -> str: + sf_render = ', '.join(repr(t) for t in sf_titles) if sf_titles else '(none)' + gt_render = ' | '.join(gt_list) if gt_list else '(unknown)' + return SYSTEM_PROMPT + _ORACLE_HINT_TEMPLATE.format( + sf=sf_render, gt=gt_render) + + +def _build_oracle_trajectory( + compressed_traj: Dict[str, Any], + sf_titles: List[str], + gt_list: List[str], +) -> Dict[str, Any]: + """Replace the system message with the oracle-suffixed variant and + attach the JSON-schema tools field consumed by the OpenAI API.""" + oracle_sp = _oracle_system_prompt(sf_titles, gt_list) + out_messages: List[Dict[str, Any]] = [] + sys_inserted = False + for m in compressed_traj.get('messages') or []: + if m.get('role') == 'system' and not sys_inserted: + out_messages.append({'role': 'system', 'content': oracle_sp}) + sys_inserted = True + else: + out_messages.append(dict(m)) + if not sys_inserted: + out_messages.insert(0, {'role': 'system', 'content': oracle_sp}) + return { + 'messages': out_messages, + 'tools': [EXTRACT_CONDENSED_TOOL], + } + + +def _make_tool_manager(chunks: Chunks) -> ToolManager: + """One ToolManager + ExtractCondensed per trajectory; the tool keeps + a ``_already_expanded`` set, so reusing across trials would lie to + the model on retry.""" + tm = ToolManager() + tm.register(ExtractCondensed(chunks)) + return tm + + +# -------------------------------------------------------------------------- +# Stage 3 + 4: F1 acceptance + conversion to training-runtime format. 
# --------------------------------------------------------------------------
def boxed_f1(boxed: str, gt_list: List[str]) -> float:
    """Return the best F1 of ``boxed`` against any ground truth (0.0 if either is empty)."""
    if not boxed or not gt_list:
        return 0.0
    # ``_f1_score`` returns a tuple; its first element is used as the score.
    return max(_f1_score(boxed, g)[0] for g in gt_list)


def _last_assistant_text(messages: List[Dict[str, Any]]) -> str:
    """Return the content of the last assistant message whose content is a string, else ``''``."""
    for m in reversed(messages):
        if m.get('role') == 'assistant' and isinstance(m.get('content'), str):
            return m['content']
    return ''


def _format_tool_call_text(blocks: int) -> str:
    """Render one ``extract_condensed`` call in the textual tool-call format."""
    return (
        '<tool_call>\n'
        '<function=extract_condensed>\n'
        '<parameter=blocks>\n'
        f'{blocks}\n'
        '</parameter>\n'
        '</function>\n'
        '</tool_call>'
    )


def _tool_call_block_id(tool_call: Dict[str, Any]) -> Optional[int]:
    """Extract the integer block id from one OpenAI-shape tool call.

    ``arguments`` may be a JSON string or an already-decoded mapping; the id
    is read from ``blocks`` (or ``block`` as a fallback key). Returns
    ``None`` when the arguments are not valid JSON, do not decode to an
    object, or the id cannot be converted to ``int``.
    """
    fn = tool_call.get('function') or {}
    args_raw = fn.get('arguments')
    try:
        args = (
            json.loads(args_raw) if isinstance(args_raw, str)
            else (args_raw or {}))
    except json.JSONDecodeError:
        return None
    # Valid JSON is not necessarily an object (e.g. ``"3"`` or ``"[1]"``).
    if not isinstance(args, dict):
        return None
    blocks_val = args.get('blocks', args.get('block'))
    try:
        return int(blocks_val)
    except (TypeError, ValueError, OverflowError):
        return None


def convert_to_runtime_messages(
    api_messages: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """OpenAI tool_calls -> textual <tool_call> format consumed by the
    training chat template.

    * The first system message is replaced by the clean ``SYSTEM_PROMPT``,
      which drops the oracle suffix.
    * Assistant messages with ``tool_calls`` become plain-text messages: the
      right-stripped content followed by one rendered call per usable tool
      call, joined by blank lines. Malformed tool calls are skipped.
    * Tool messages keep only ``role`` and ``content``.
    * Any other message is reduced to its ``role`` / ``content`` keys.
    """
    out: List[Dict[str, Any]] = []
    sys_done = False
    for m in api_messages:
        role = m.get('role')
        if role == 'system' and not sys_done:
            out.append({'role': 'system', 'content': SYSTEM_PROMPT})
            sys_done = True
            continue
        if role == 'assistant':
            content = m.get('content') or ''
            tool_calls = m.get('tool_calls') or []
            if tool_calls:
                pieces = [content.rstrip()] if content else []
                for tc in tool_calls:
                    n = _tool_call_block_id(tc)
                    if n is None:
                        continue
                    pieces.append(_format_tool_call_text(n))
                text = '\n\n'.join(p for p in pieces if p)
                out.append({'role': 'assistant', 'content': text})
            else:
                out.append({'role': 'assistant', 'content': content})
            continue
        if role == 'tool':
            out.append({'role': 'tool', 'content': m.get('content') or ''})
            continue
        out.append({k: v for k, v in m.items() if k in ('role', 'content')})
    return out


def trajectory_achieved_ratio(chunks: Chunks) -> float:
    """Return compressed/original character ratio over condensed text chunks.

    Only text chunks whose ``raw`` dict has a truthy ``condensed`` flag and
    string ``original`` / ``content`` values are counted. The ratio is
    rounded to 4 decimals; 0.0 is returned when nothing was condensed.
    """
    total_src = 0
    total_cmp = 0
    for c in chunks.chunks:
        if c.get('type') != 'text':
            continue
        raw = c.get('raw')
        if not (isinstance(raw, dict) and raw.get('condensed')):
            continue
        original = raw.get('original')
        compressed = c.get('content')
        if isinstance(original, str) and isinstance(compressed, str):
            total_src += len(original)
            total_cmp += len(compressed)
    return round(total_cmp / total_src, 4) if total_src else 0.0


def build_record(
    row: Dict[str, Any],
    runtime_messages: List[Dict[str, Any]],
    chunks: Chunks,
    verdict: Dict[str, Any],
    original_gt: List[str],
    used_gt: List[str],
    used_sf: List[str],
    boxed: str,
    f1: float,
    num_tool_calls: int,
) -> Dict[str, Any]:
    """Assemble one output JSONL record for an accepted trajectory.

    The record carries the runtime-format messages and tool schema at the top
    level, plus a ``meta`` dict with provenance: validation verdict, original
    and used ground truths / supporting facts, the raw passages, the achieved
    compression ratio, the boxed answer and its F1 (rounded to 4 decimals).
    Raises ``KeyError`` when ``row`` has no ``'id'``.
    """
    ctx = row.get('context') or {}
    titles = list(ctx.get('title') or [])
    sentences_list = list(ctx.get('sentences') or [])
    raw_passages = [
        {
            'title': t,
            'sentences': list(s) if isinstance(s, list) else [str(s)],
        }
        for t, s in zip(titles, sentences_list)
    ]
    sf_full = row.get('supporting_facts') or {}
    return {
        'id': row['id'],
        'level': row.get('level'),
        'type': row.get('type'),
        'messages': runtime_messages,
        'tools': [EXTRACT_CONDENSED_TOOL],
        'meta': {
            'num_tool_calls': num_tool_calls,
            'achieved_ratio': trajectory_achieved_ratio(chunks),
            'validation_verdict': verdict.get('verdict'),
            'validation_reason': verdict.get('reason'),
            'original_question': row.get('question'),
            'original_answer': row.get('answer'),
            'original_gt': original_gt,
            'used_gt': used_gt,
            'used_supporting_facts': used_sf,
            'original_supporting_facts': {
                'title': list(sf_full.get('title') or []),
                'sent_id': list(sf_full.get('sent_id') or []),
            },
            'original_passages': raw_passages,
            'f1': round(f1, 4),
            'boxed': boxed,
        },
    }


#
# --------------------------------------------------------------------------
# Per-batch pipeline orchestration.
# --------------------------------------------------------------------------
def _extract_original_gt_sf(row: Dict[str, Any]) -> Tuple[List[str], List[str]]:
    """Return ``(ground_truth_answers, supporting_fact_titles)`` for ``row``.

    Answers come from the multi-form ``answers`` list when it yields at least
    one non-blank entry; otherwise the single ``answer`` field is used. Every
    value is coerced with ``str`` before stripping, so a non-string ``answer``
    no longer raises. Titles are de-duplicated in first-seen order and falsy
    titles are dropped.
    """
    original_gt: List[str] = []
    answers = row.get('answers')
    if isinstance(answers, list):
        original_gt = [s for s in (str(a).strip() for a in answers) if s]
    if not original_gt:
        # Fall back to the single-answer field, also when ``answers`` only
        # held blank entries.
        single = str(row.get('answer') or '').strip()
        if single:
            original_gt = [single]
    sf = row.get('supporting_facts') or {}
    sf_titles = list(dict.fromkeys(t for t in (sf.get('title') or []) if t))
    return original_gt, sf_titles


def _validate_in_parallel(
    api: OpenAI, batch: List[Dict[str, Any]], pool: ThreadPoolExecutor,
) -> Tuple[List[Optional[Dict[str, Any]]], List[Tuple[List[str], List[str]]]]:
    """Run ``validate_row`` for every row in parallel (one OpenAI call each).

    Returns two lists aligned with ``batch``: the verdicts and the
    ``(original_gt, sf_titles)`` payload submitted for each row. A row whose
    validation raises gets a ``None`` verdict (which ``process_batch`` skips)
    instead of aborting the whole batch.
    """
    futures = []
    payloads: List[Tuple[List[str], List[str]]] = []
    for row in batch:
        original_gt, sf_titles = _extract_original_gt_sf(row)
        payloads.append((original_gt, sf_titles))
        futures.append(pool.submit(
            validate_row, api, row, original_gt, sf_titles))
    verdicts: List[Optional[Dict[str, Any]]] = []
    for row, future in zip(batch, futures):
        try:
            verdicts.append(future.result())
        except Exception as exc:
            # Isolate the failure to this row; the others keep their verdicts.
            sys.stderr.write(
                f'[validate] id={row.get("id")} crashed: {exc}\n')
            verdicts.append(None)
    return verdicts, payloads


def _num_tool_calls(messages: List[Dict[str, Any]]) -> int:
    """Count the ``tool_calls`` entries across all assistant messages."""
    return sum(
        len(m.get('tool_calls') or [])
        for m in messages if m.get('role') == 'assistant')


def process_batch(
    api: OpenAI,
    rollout: APIMultiTurnRollout,
    batch: List[Dict[str, Any]],
    chunker: NativeChunker,
    condenser: ModelCondenser,
    validation_pool: ThreadPoolExecutor,
) -> List[Dict[str, Any]]:
    """Validate -> compress -> rollout (T-ladder) -> accept.

    Returns the list of accepted JSONL records for the batch. A crash in the
    compression step yields ``[]``; a crash in a rollout rung returns whatever
    was accepted on earlier rungs.
    """
    if not batch:
        return []
    # 1. Validation in parallel.
    verdicts, payloads = _validate_in_parallel(api, batch, validation_pool)

    survivors_meta: List[Dict[str, Any]] = []
    for row, verdict, (original_gt, sf_titles) in zip(batch, verdicts, payloads):
        if verdict is None or verdict.get('verdict') == 'drop':
            continue
        if not original_gt:
            continue
        used_gt, used_sf = resolve_validation(verdict, original_gt, sf_titles)
        if not used_gt:
            continue
        survivors_meta.append({
            'row': row, 'verdict': verdict,
            'original_gt': original_gt,
            'used_gt': used_gt, 'used_sf': used_sf,
        })
    if not survivors_meta:
        return []

    # 2. Compress survivors (one packed batch through ModelCondenser).
    survivor_rows = [m['row'] for m in survivors_meta]
    try:
        compressed = compress_rows(survivor_rows, chunker, condenser)
    except Exception as exc:
        sys.stderr.write(f'[compress] batch crashed: {exc}\n')
        return []

    # 3. Build oracle trajectories and keep each survivor's chunks.
    # NOTE(review): the zip assumes compress_rows returns one entry per
    # survivor, in order -- confirm against compress_rows.
    trajs: List[Dict[str, Any]] = []
    chunks_list: List[Chunks] = []
    for meta, (compressed_traj, chunks) in zip(survivors_meta, compressed):
        trajs.append(_build_oracle_trajectory(
            compressed_traj, meta['used_sf'], meta['used_gt']))
        chunks_list.append(chunks)

    # 4. Temperature ladder. Each rung gets fresh ExtractCondensed tools so
    #    a retry does not see the previous attempt's already-expanded set.
    accepted: List[Dict[str, Any]] = []
    pending_idx = list(range(len(trajs)))
    for temperature in ROLLOUT_TEMPERATURE_LADDER:
        if not pending_idx:
            break
        sp = SamplingParams(
            temperature=temperature, max_tokens=ROLLOUT_MAX_TOKENS, num_samples=1)
        run_trajs = [trajs[i] for i in pending_idx]
        run_tms = [_make_tool_manager(chunks_list[i]) for i in pending_idx]
        try:
            outs = rollout(
                run_trajs, tool_manager=run_tms, sampling_params=sp)
        except Exception as exc:
            sys.stderr.write(f'[rollout] batch crashed at T={temperature}: {exc}\n')
            return accepted
        next_pending: List[int] = []
        for local_pos, traj_idx in enumerate(pending_idx):
            out_traj = outs[local_pos]
            if out_traj.get('stop_reason') == 'api_error':
                continue  # hard-drop API failures, do not retry
            messages = out_traj.get('messages') or []
            boxed = _extract_final_answer(_last_assistant_text(messages))
            meta = survivors_meta[traj_idx]
            f1 = boxed_f1(boxed, meta['used_gt'])
            if f1 >= F1_ACCEPT_THRESHOLD:
                runtime_messages = convert_to_runtime_messages(messages)
                accepted.append(build_record(
                    row=meta['row'],
                    runtime_messages=runtime_messages,
                    chunks=chunks_list[traj_idx],
                    verdict=meta['verdict'],
                    original_gt=meta['original_gt'],
                    used_gt=meta['used_gt'],
                    used_sf=meta['used_sf'],
                    boxed=boxed, f1=f1,
                    num_tool_calls=_num_tool_calls(messages)))
            else:
                # Below threshold: retry on the next rung of the ladder.
                next_pending.append(traj_idx)
        pending_idx = next_pending
    return accepted
# --------------------------------------------------------------------------
# Stratified sampling + resume.
# --------------------------------------------------------------------------
LEVELS: Tuple[str, str, str] = ('easy', 'medium', 'hard')


def stratified_sample(
    ds, per_level: Dict[str, int], seed: int,
) -> List[Dict[str, Any]]:
    """Draw ``per_level[lv]`` rows for each level in ``LEVELS`` and shuffle.

    ``ds['level']`` must give the level of every row and ``ds[i]`` the row at
    integer index ``i``. Rows whose level is not in ``LEVELS`` are ignored.
    Sampling is deterministic for a given ``seed``.

    Raises:
        RuntimeError: if a level has fewer rows than requested.
        KeyError: if ``per_level`` lacks one of ``LEVELS``.
    """
    rng = random.Random(seed)
    buckets: Dict[str, List[int]] = {lv: [] for lv in LEVELS}
    for i, lv in enumerate(ds['level']):
        if lv in buckets:
            buckets[lv].append(i)
    picked: List[int] = []
    for lv in LEVELS:
        need = per_level[lv]
        pool = buckets[lv]
        if len(pool) < need:
            raise RuntimeError(
                f'level={lv} has only {len(pool)} rows, need {need}')
        picked.extend(rng.sample(pool, need))
    rng.shuffle(picked)
    return [ds[int(i)] for i in picked]


def load_done_ids(path: str) -> set:
    """Collect the truthy ``id`` values already written to the JSONL ``path``.

    Returns an empty set when the file does not exist. Lines that are not
    valid JSON, or that decode to something other than an object, are skipped.
    """
    if not os.path.exists(path):
        return set()
    done = set()
    with open(path, 'r', encoding='utf-8') as fh:
        for line in fh:
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(obj, dict):
                # Valid JSON but not a record (e.g. a list or a bare string).
                continue
            rid = obj.get('id')
            if rid:
                done.add(rid)
    return done


def apply_reannotation_overlay(
    rows: List[Dict[str, Any]], path: str,
) -> List[Dict[str, Any]]:
    """Drop verdict=drop ids; overlay ``question_fixed`` and multi-form ``answers``.

    The validation stage in ``process_batch`` still runs on every survivor
    because the audit ran on a different HF subset (fullwiki) than this
    builder's default (distractor) and passage contexts differ.

    Input rows are never mutated: an overridden row is shallow-copied first.
    The summary written to stderr reports the number of rows actually dropped
    from ``rows``, not the number of drop ids listed in the file.
    """
    overrides: Dict[str, Dict[str, Any]] = {}
    drop_ids: set = set()
    with open(path, 'r', encoding='utf-8') as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(obj, dict):
                continue
            rid = obj.get('id')
            if not rid:
                continue
            if obj.get('verdict') == 'drop':
                drop_ids.add(rid)
            else:
                overrides[rid] = obj
    out: List[Dict[str, Any]] = []
    overridden = 0
    for row in rows:
        rid = row.get('id')
        if rid in drop_ids:
            continue
        ov = overrides.get(rid)
        if ov is not None:
            row = dict(row)
            qfix = (ov.get('question_fixed') or '').strip()
            if qfix:
                row['question'] = qfix
            ans = [str(a).strip() for a in (ov.get('answers') or []) if str(a).strip()]
            if ans:
                row['answers'] = ans
            overridden += 1
        out.append(row)
    dropped = len(rows) - len(out)
    sys.stderr.write(
        f'[REANNOTATED] {path}: {len(rows)} -> {len(out)} rows '
        f'(dropped={dropped}, overridden={overridden})\n')
    return out
# --------------------------------------------------------------------------
# CLI + main loop.
# --------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the trajectory builder.

    ``--output`` and ``--model`` are required. ``--api-key``, ``--base-url``,
    ``--reannotated`` and ``--condenser-model-id`` take their defaults from the
    ``OPENAI_API_KEY``, ``OPENAI_BASE_URL``, ``REANNOTATED_FILE`` and
    ``MODEL_ID`` environment variables respectively.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--output', required=True)
    parser.add_argument('--model', required=True,
                        help='Super-LLM model name (OpenAI-protocol).')
    parser.add_argument('--api-key', default=os.environ.get('OPENAI_API_KEY'))
    parser.add_argument('--base-url', default=os.environ.get('OPENAI_BASE_URL'))
    parser.add_argument('--total', type=int, default=12000)
    parser.add_argument('--easy', type=int, default=2000)
    parser.add_argument('--medium', type=int, default=4000)
    parser.add_argument('--hard', type=int, default=6000)
    parser.add_argument('--concurrency', type=int, default=16)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--reannotated', default=os.environ.get('REANNOTATED_FILE', ''),
                        help='Path to wrong_ids_reannotated.jsonl. Drops verdict=drop ids and '
                             'overlays question_fixed + multi-form answers. Validation stage '
                             'still runs because the audit was on a different HF subset.')
    parser.add_argument('--hf-subset', default='distractor')
    parser.add_argument('--hf-split', default='train')
    parser.add_argument('--condenser-model-id',
                        default=os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B'))
    parser.add_argument('--condenser-lora',
                        default='ms://twinkle-kit/Qwen3.5-4B-Condenser')
    parser.add_argument('--chunk-size', type=int, default=1024)
    parser.add_argument('--hotpotqa-max-length', type=int, default=64000)
    parser.add_argument('--compress-batch-size', type=int, default=32,
                        help='How many rows to feed to ModelCondenser at once.')
    parser.add_argument('--gpu-memory-utilization', type=float, default=0.8)
    return parser.parse_args()


def build_condenser(args: argparse.Namespace) -> Tuple[NativeChunker, ModelCondenser]:
    """Build the ``(chunker, condenser)`` pair from the parsed CLI options.

    The vLLM sampler is created with LoRA enabled and a context length of at
    least 8192 tokens; thinking is disabled on both the sampler template and
    the rollout template. An empty ``--condenser-lora`` is passed on as
    ``None``.
    """
    sampler = vLLMSampler(
        model_id=args.condenser_model_id,
        engine_args={
            'gpu_memory_utilization': args.gpu_memory_utilization,
            'max_model_len': max(8192, args.hotpotqa_max_length),
            'max_lora_rank': 32,
            'enable_lora': True,
            'max_loras': 2,
        },
    )
    sampler.set_template(
        'Qwen3_5Template', model_id=args.condenser_model_id,
        enable_thinking=False, max_length=args.hotpotqa_max_length)
    rollout_template = Qwen3_5Template(
        args.condenser_model_id, max_length=args.hotpotqa_max_length,
        enable_thinking=False)
    chunker = NativeChunker(
        chunk_size=args.chunk_size,
        passage_boundary_re=r'(?<=\n\n)',
    )
    condenser = ModelCondenser(
        sampler=sampler,
        compression_ratio=2.0,
        sampling_params=SamplingParams(
            max_tokens=1024, num_samples=1, temperature=0.4, top_p=0.9),
        min_chars=200,
        template=rollout_template,
        lora_path=args.condenser_lora or None,
        skip_pattern=r'^Question:',
        related_query=_extract_question_from_chunk,
    )
    return chunker, condenser


def main() -> None:
    """Sample rows, run the per-batch pipeline and append accepted records.

    Rows whose id is already present in ``--output`` are skipped (resume).
    Records are appended and flushed after every batch; a batch that raises
    is logged to stderr and skipped.

    Raises:
        ValueError: if ``--easy + --medium + --hard`` differs from ``--total``.
    """
    args = parse_args()
    if args.easy + args.medium + args.hard != args.total:
        raise ValueError(
            f'--easy + --medium + --hard ({args.easy + args.medium + args.hard}) '
            f'must equal --total ({args.total})')
    per_level = {'easy': args.easy, 'medium': args.medium, 'hard': args.hard}

    # Create the output directory up front so a missing folder fails (or is
    # fixed) before the dataset and the vLLM engine are loaded.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    sys.stderr.write(
        f'Loading hotpotqa/hotpot_qa:{args.hf_subset}:{args.hf_split}...\n')
    ds = load_dataset(
        'hotpotqa/hotpot_qa', args.hf_subset, split=args.hf_split)

    rows = stratified_sample(ds, per_level=per_level, seed=args.seed)
    if args.reannotated.strip():
        rows = apply_reannotation_overlay(rows, args.reannotated.strip())
    done = load_done_ids(args.output)
    sys.stderr.write(f'Resume: {len(done)} rows already emitted.\n')
    pending = [r for r in rows if r['id'] not in done]
    sys.stderr.write(f'Pending: {len(pending)} / {len(rows)}\n')

    chunker, condenser = build_condenser(args)
    api = OpenAI(
        model=args.model, api_key=args.api_key, base_url=args.base_url)

    # APIMultiTurnRollout itself owns the per-trajectory thread pool. The
    # validation phase runs on a separate pool of equal size; both phases
    # are network-bound so we never need more threads than ``concurrency``.
    rollout = APIMultiTurnRollout(
        api=api,
        tool_manager=ToolManager(),  # placeholder; per-call list overrides
        sampling_params=SamplingParams(
            temperature=ROLLOUT_TEMPERATURE_LADDER[0],
            max_tokens=ROLLOUT_MAX_TOKENS, num_samples=1),
        max_turns=ROLLOUT_MAX_TURNS,
        concurrency=args.concurrency,
        extra_body={'enable_thinking': False},
    )

    write_lock = threading.Lock()
    accepted_total = 0
    seen_total = 0

    # The context managers close the output file and shut the pool down on
    # every exit path, including a failure while the pool is being created.
    with open(args.output, 'a', encoding='utf-8') as out_fh, \
            ThreadPoolExecutor(max_workers=args.concurrency) as validation_pool:
        for start in range(0, len(pending), args.compress_batch_size):
            batch = pending[start:start + args.compress_batch_size]
            seen_total += len(batch)
            try:
                records = process_batch(
                    api, rollout, batch, chunker, condenser,
                    validation_pool)
            except Exception as exc:
                sys.stderr.write(
                    f'[batch {start}-{start + len(batch)}] crashed: {exc}\n')
                continue
            with write_lock:
                for record in records:
                    out_fh.write(
                        json.dumps(record, ensure_ascii=False) + '\n')
                out_fh.flush()
            accepted_total += len(records)
            sys.stderr.write(
                f'[progress] seen={seen_total}/{len(pending)} '
                f'accepted={accepted_total} '
                f'(+{len(records)} from this batch)\n')

    sys.stderr.write(
        f'Done. accepted={accepted_total} total_pending={len(pending)}\n')


if __name__ == '__main__':
    main()
SHAPE — each query is one short imperative or interrogative sentence (e.g. \ +"List all public method signatures with parameter and return types", "What race \ +conditions does this code contain?"). +2. DISTINCT — reject any pair whose answers would substantially overlap; \ +rephrasings of the same information need do NOT count as separate queries. +3. SKILL FOR KNOWLEDGE — when the source reads as tutorial / experience / \ +how-to / domain knowledge, ALWAYS include exactly one skill-style query asking \ +what the reader can accomplish with it and how to apply it (phrased in the \ +source language). +4. ANSWERABLE — skip queries the source cannot actually answer, and skip \ +trivial queries that would just reproduce the source verbatim. +5. SCALE — short / single-purpose → 1; medium → 2; rich / multi-topic → 3–4. \ +Do not pad. +6. LANGUAGE — query language MUST match the source language. +7. OUTPUT — a single JSON array of strings; no preamble, no code fences, \ +nothing else.\ +""" + +QUERY_GEN_USER = "Analyze the following text and return a JSON array of queries.\n\n{text}" + +COMPRESS_SYSTEM = """\ +You are a compression assistant. For the (query, source) pair, emit a Markdown \ +answer with TWO sections, designed to pair with the `extract_compressed` tool: \ +the reader absorbs `## Summary` directly, then calls `extract_compressed` \ +on any topic-key listed under `## More` to recover its \ +fuller content. + + `## Summary` — extreme-density text the reader reads directly. + `## More` — a topic index whose keys are valid arguments \ +to `extract_compressed` for recovering material not captured inline. + +Together the two sections must form a COMPLETE, NON-DISTORTING inventory of the \ +source for the query — nothing essential lost, nothing implied that the source \ +does not support. NO preamble, NO meta-commentary, NO code fences wrapping the \ +whole output. 
+ +Output skeleton: + +## Summary +Topic: <what the source is about + scope, one line> +<dense body answering the query> + +## More +- <topic-key>: <one-line hint of what is revealed when expanded> +- ... + +Format selection for the inline body (pick the MOST COMPACT form per query, mix \ +when helpful): +- Interface / signature → code notation directly: `func(a:int)->str` +- Factual / entity → telegraphic prose; drop function words; ":" for "is", "," \ +for "has" +- Skill / how-to / usage → lead with `Use when: <trigger>`; numbered telegraphic \ +steps `1.do X 2.then Y`; close with `Output: <result>` when relevant +- Procedural → numbered short steps +- Analytical / design → hierarchical bullets with abbreviations + +`## Summary` rules: +1. TOPIC LINE — line 1 is ALWAYS `Topic: <subject — scope>`, even when the \ +query is narrow. Anchors both the reader and the tool. +2. DENSITY — every token in the body carries query-relevant signal; cut filler. +3. PRIMARY-COMPLETE — never silently drop a fact essential to answering the \ +query. Anything cut for length MUST appear as a key under \ +`## More`. +4. NON-MISLEADING — phrasing must not let the reader infer anything the source \ +does not support; partial truths that mislead are worse than honest omissions \ +flagged in the index. +5. SELF-CONTAINED — the reader can act on the answer without re-opening the source. +6. FAITHFUL — only content the source supports; no fabrication, no extrapolation. +7. LANGUAGE — match the source language. +8. NO outer code fences around the whole answer; no meta-commentary. + +`## More` rules (MANDATORY — this section is never omitted): +1. FORMAT — each bullet is `- <topic-key>: <one-line hint>`: + • topic-key — short, unambiguous, grounded in source vocabulary so the \ +`extract_compressed` tool can locate the aspect (e.g. `decorators`, \ +`error handling`, `pitfalls`). 
+ • hint — tells WHAT the reader gains by expanding (concrete numbers, code \ +listings, secondary cases, edge details, related context, …); do NOT restate \ +the inline answer. +2. CRITERION — each bullet names an aspect that EXISTS in the source but is \ +NOT fully captured inline. Material that genuinely fits inline without \ +distortion MUST NOT be duplicated here. +3. FAITHFUL — hints must be grounded in the source; never speculate or invent. +4. ORDER — by relevance to the query, then by importance. +5. EMPTY CASE — if the source is so short / single-purpose that everything \ +fits inline, write a single line `- (none)`. + +Examples: + +Query: List all public method signatures with parameter and return types +Source: (a Python HTTP client class with retry decorator, structured logging, \ +and request helpers) +## Summary +Topic: Python HTTP client class — public surface of retried request helpers. +retry_request(url:str, max_retries:int=3, timeout:float=10.0) -> Response +fetch_json(endpoint:str, params:dict|None=None) -> dict +post_data(endpoint:str, payload:dict, headers:dict|None=None) -> Response + +## More +- decorators: @retry config — exponential backoff (base=2.0, max=60s) +- logging: structured per-request logs with request_id and latency_ms +- private helpers: _build_headers, _parse_error — not in public surface +─── +Query: What can this passage help you accomplish, and how to use it? +Source: (a tutorial on configuring Linux cgroups v2 caps for a systemd service) +## Summary +Topic: Linux cgroups v2 — per-service CPU / memory caps via systemd slice units. +Use when: needing per-service CPU/memory caps on systemd hosts. +1.create slice unit /etc/systemd/system/<name>.slice with CPUQuota=, MemoryMax= +2.attach service via Slice=<name>.slice in [Service] +3.systemctl daemon-reload + restart service +4.verify: systemctl status <svc> shows Tasks/CPU/Memory inside slice +Output: hard caps enforced by kernel cgroup v2. 
+ +## More +- pitfalls: cgroup v1/v2 mode detection, MemorySwapMax behavior on OOM +- delegation: Delegate=yes for nested controllers in container managers +- examples: nginx and postgres slice templates with concrete numeric caps +- diagnostics: systemd-cgls / systemd-cgtop walkthrough +─── +Query: 总结这段代码的错误和改进经验 +Source: (一段有 race condition 和未关闭资源的 Go 代码) +## Summary +Topic: Go HTTP fetch 循环 — 并发写共享 map + 未关闭响应体导致的稳定性缺陷。 +1.race: 并发写 map 未锁 → sync.RWMutex 或 sync.Map +2.泄漏: resp.Body 未 Close → 请求后立即 defer resp.Body.Close() +3.吞错: err 未检查 → 每处 err!=nil 必处理或上抛 + +## More +- (none) + +Now begin.\ +""" + +COMPRESS_USER = "## Query\n{query}\n\n## Source\n{text}" + +# Short system prompt embedded in emitted SFT samples — the long COMPRESS_SYSTEM +# is for data generation only; training samples carry only the binding contract. +COMPRESS_SYSTEM_TRAIN = """\ +You are a compression assistant. For the (query, source) pair, emit a Markdown \ +answer with TWO sections, designed to pair with the `extract_compressed` tool: \ +the reader absorbs `## Summary` directly, then calls `extract_compressed` \ +on any topic-key listed under `## More` to recover its \ +fuller content. + +Output skeleton: + +## Summary +Topic: <subject — scope, one line> +<dense body answering the query> + +## More +- <topic-key>: <one-line hint of what is revealed when expanded> +- ... + +Rules: +1. Line 1 of `## Summary` is ALWAYS `Topic: ...`. +2. Body is maximally dense; every token carries query-relevant signal. +3. Never silently drop a fact — anything cut for length MUST appear as a key \ +under `## More` (do not duplicate inline material here). +4. No fabrication, no extrapolation, no misleading partial truths. +5. Match the source language. No outer code fences, no meta-commentary.\ +""" + +# Fixed queries — used directly (no Phase-1 LLM generation) for a proportion of items. +FIXED_QUERY_NEED = ( + 'What problem does this passage address, and what skill or method is needed? 
def _extract_json_array(text: str) -> Optional[List[str]]:
    """Best-effort extraction of a JSON array of strings from LLM output.

    Tries the whole stripped text first when it starts with ``[``, then the
    widest ``[...]`` span found anywhere in it. Returns ``None`` when neither
    candidate decodes to a list containing only strings.
    """
    stripped = text.strip()
    candidates: List[str] = []
    if stripped.startswith('['):
        candidates.append(stripped)
    bracketed = re.search(r'\[.*\]', stripped, re.DOTALL)
    if bracketed is not None:
        candidates.append(bracketed.group())
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, list) and all(isinstance(el, str) for el in parsed):
            return parsed
    return None


def generate_queries(api: OpenAI, text: str) -> List[str]:
    """Phase 1: ask the LLM which queries can be asked about ``text``.

    Makes up to two attempts; the second only happens when the first reply
    could not be parsed into a non-empty JSON string array. An API exception
    is logged to stderr and ends the call with ``[]``.
    """
    prompt = {
        'messages': [
            {'role': 'system', 'content': QUERY_GEN_SYSTEM},
            {'role': 'user', 'content': QUERY_GEN_USER.format(text=text)},
        ]
    }
    params = SamplingParams(temperature=0.7, max_tokens=1024)
    attempt = 0
    while attempt < 2:
        try:
            reply = api(prompt, params, extra_body={'enable_thinking': True})
        except Exception as exc:
            sys.stderr.write(f'[query_gen] error: {exc}\n')
            return []
        parsed = _extract_json_array(reply.get('content') or '')
        if parsed:
            return parsed
        if attempt == 0:
            sys.stderr.write('[query_gen] retry: failed to parse JSON array\n')
        attempt += 1
    return []


def compress_for_query(api: OpenAI, text: str, query: str,
                       thinking_budget: int = 1024) -> Optional[str]:
    """Phase 2: compress ``text`` w.r.t. ``query``.

    Returns the compressed Markdown, with a whole-answer code fence stripped
    when present, or ``None``. A reply is accepted only if it has both a
    ``## Summary`` and a ``## More`` heading; an empty or malformed reply is
    retried once. An API exception is logged to stderr and ends the call.
    """
    prompt = {
        'messages': [
            {'role': 'system', 'content': COMPRESS_SYSTEM},
            {'role': 'user', 'content': COMPRESS_USER.format(query=query, text=text)},
        ]
    }
    params = SamplingParams(temperature=0.3, max_tokens=16384)
    for attempt in range(2):
        try:
            # NOTE(review): thinking_budget is sent together with
            # enable_thinking=False -- confirm the backend honours the budget
            # in that mode.
            reply = api(prompt, params, extra_body={
                'enable_thinking': False,
                'thinking_budget': thinking_budget,
            })
        except Exception as exc:
            sys.stderr.write(f'[compress] error: {exc}\n')
            return None
        answer = (reply.get('content') or '').strip()
        if not answer:
            retry_note = '[compress] retry: empty response\n'
        else:
            # Strip whole-answer code fence if present.
            fenced = re.match(r'^```[a-zA-Z]*\n(.*?)\n```\s*$', answer, re.DOTALL)
            if fenced:
                answer = fenced.group(1).strip()
            has_summary = re.search(r'(?im)^##\s*Summary\b', answer) is not None
            has_more = re.search(r'(?im)^##\s*More\b', answer) is not None
            if has_summary and has_more:
                return answer
            retry_note = '[compress] retry: missing required sections\n'
        if attempt == 0:
            sys.stderr.write(retry_note)
    return None


def _query_hash(query: str) -> str:
    """Stable short hash of a query string — embedded in sample id for resume.

    The first 8 hex digits of the MD5 of the stripped, UTF-8 encoded query.
    """
    digest = hashlib.md5(query.strip().encode('utf-8'))
    return digest.hexdigest()[:8]
def _build_sft_sample(sample_id: str, source: str, query: str,
                      text: str, compressed: str) -> Dict[str, Any]:
    """Build one SFT sample dict for a ``(query, text)`` pair and its answer."""
    return {
        'id': sample_id,
        'source': source,
        'query': query,
        'original_len': len(text),
        'compressed_len': len(compressed),
        'original_tokens': 0,
        'compressed_tokens': 0,
        'messages': [
            {'role': 'system', 'content': COMPRESS_SYSTEM_TRAIN},
            {'role': 'user', 'content': COMPRESS_USER.format(query=query, text=text)},
            {'role': 'assistant', 'content': compressed},
        ],
        # Stashed for sparse tokenization on main thread; popped before write.
        '__src': text,
        '__cmp': compressed,
    }


def process_item(
    api: OpenAI,
    item: Dict[str, Any],
    done_sample_ids: Optional[Set[str]] = None,
    thinking_budget: int = 1024,
    fixed_query_ratio: float = FIXED_QUERY_RATIO,
) -> List[Dict[str, Any]]:
    """Run both phases on one dataset item. Returns list of SFT samples.

    Input rows come from ``dataset.py`` (single assistant message) or
    ``dataset_think.py`` (user query + assistant with reasoning_content).
    For thinking-data rows, ``FIXED_QUERY_NEED`` is applied to the query
    and ``FIXED_QUERY_SKILL`` to the CoT, skipping Phase-1 generation.

    For plain rows the fixed-vs-generated decision is drawn first, so no
    Phase-1 request is spent on an item that ends up using the fixed queries.
    At most the first two generated queries are used, and a query whose sample
    id was already produced in this call is not emitted twice.

    ``done_sample_ids`` (full sample ids already on disk for this item)
    lets resume skip queries that were already emitted, keyed by query
    content hash so a phase-1 reorder still resolves correctly.

    Returns ``[]`` when the item has no ``id``, when the text is shorter than
    100 characters, or when no query yields a valid compression.
    """
    done = done_sample_ids or set()
    item_id = item.get('id')
    if not item_id:
        return []
    source = item.get('source', 'unknown')

    # Detect thinking-data: user message + assistant with reasoning_content.
    # Only the first user message and the first assistant message are read.
    user_query = ''
    cot_text = ''
    assistant_text = ''
    for m in item.get('messages') or []:
        if not isinstance(m, dict):
            continue
        role = m.get('role', '')
        if role == 'user' and not user_query:
            user_query = (m.get('content') or '').strip()
        elif role == 'assistant':
            cot_text = (m.get('reasoning_content') or '').strip()
            assistant_text = (m.get('content') or '').strip()
            break

    if user_query and cot_text:
        # Thinking-data path: compress query and CoT separately with fixed queries.
        pairs = [(user_query, FIXED_QUERY_NEED), (cot_text, FIXED_QUERY_SKILL)]
    else:
        # Plain-data path: single assistant message.
        if len(assistant_text) < 100:
            return []
        if random.random() < fixed_query_ratio:
            # Mix in fixed queries for a proportion of items.
            queries = list(FIXED_QUERIES)
        else:
            queries = generate_queries(api, assistant_text)[:2]
        pairs = [(assistant_text, q) for q in queries]

    samples: List[Dict[str, Any]] = []
    emitted: Set[str] = set()
    for text, query in pairs:
        if len(text) < 100:
            continue
        sample_id = f'{item_id}__{_query_hash(query)}'
        if sample_id in done or sample_id in emitted:
            continue
        compressed = compress_for_query(
            api, text, query, thinking_budget=thinking_budget)
        if not compressed:
            continue
        emitted.add(sample_id)
        samples.append(
            _build_sft_sample(sample_id, source, query, text, compressed))
    return samples


# ═══════════════════════════════════════════════════════════════════════════════
# I/O helpers
# ═══════════════════════════════════════════════════════════════════════════════

def iter_input(path: str) -> Iterator[Dict[str, Any]]:
    """Stream JSONL dataset row-by-row (no full-file load).

    Blank lines, lines that are not valid JSON, and lines that decode to
    something other than an object are skipped.
    """
    with open(path, 'r', encoding='utf-8') as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(row, dict):
                yield row
def _ensure_script_dir_on_path() -> None:
    """Put this script's directory at the front of ``sys.path`` once."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    if script_dir not in sys.path:
        sys.path.insert(0, script_dir)


def iter_dataset_py(total: Optional[int], load_from_cache_file: bool) -> Iterator[Dict[str, Any]]:
    """Stream rows directly from ``dataset.py::get_dataset`` without any JSONL hop."""
    # Lazy import: dataset.py triggers HF / ModelScope downloads at module load.
    _ensure_script_dir_on_path()
    from dataset import get_dataset
    hf = get_dataset(total=total, load_from_cache_file=load_from_cache_file)
    sys.stderr.write(f'Loaded dataset.py::get_dataset: {len(hf)} rows\n')
    for row in hf:
        yield row


def iter_dataset_think_py(total: Optional[int], load_from_cache_file: bool) -> Iterator[Dict[str, Any]]:
    """Stream rows from ``dataset_think.py::get_dataset`` (query + CoT data)."""
    _ensure_script_dir_on_path()
    from dataset_think import get_dataset
    hf = get_dataset(total=total, load_from_cache_file=load_from_cache_file)
    sys.stderr.write(f'Loaded dataset_think.py::get_dataset: {len(hf)} rows\n')
    for row in hf:
        yield row


def load_done_sample_ids(path: str) -> Set[str]:
    """Collect already-written full sample ids (``base__hash``) for resume.

    Returns an empty set when ``path`` does not exist. Lines that are not
    valid JSON, or that decode to something other than an object, are skipped.
    """
    if not os.path.exists(path):
        return set()
    done: Set[str] = set()
    with open(path, 'r', encoding='utf-8') as fh:
        for line in fh:
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(obj, dict):
                # Valid JSON but not a sample record (list, null, string...).
                continue
            sid = obj.get('id', '')
            if sid:
                done.add(sid)
    return done
def main() -> None:
    """Build the condenser SFT dataset and append it to ``--output``.

    Items are streamed from ``--input`` (JSONL) or from the dataset modules
    selected by ``--source``, processed by a thread pool with a bounded number
    of in-flight tasks, and written from the main thread. Sample ids already
    present in ``--output`` are passed to the workers so resumed runs skip
    them. Every ``--tokenize-every``-th written sample gets real token counts;
    a value ``<= 0`` disables the probe.
    """
    parser = argparse.ArgumentParser(
        description='Two-phase query-diverse condenser dataset builder.')
    parser.add_argument('--input', default=None,
                        help='Optional JSONL override; default uses dataset.py::get_dataset')
    parser.add_argument('--output', required=True,
                        help='Output JSONL file for SFT samples')
    parser.add_argument('--total', type=int, default=0,
                        help='Total input rows for proportional scaling in dataset.py (0 = base sizes)')
    parser.add_argument('--no-cache', action='store_true',
                        help='Disable load_from_cache_file when calling dataset.py::get_dataset')
    parser.add_argument('--model', required=True,
                        help='API model name')
    parser.add_argument('--api-key', default=os.environ.get('OPENAI_API_KEY'))
    parser.add_argument('--base-url', default=os.environ.get('OPENAI_BASE_URL'))
    parser.add_argument('--concurrency', type=int, default=32,
                        help='Number of parallel workers')
    parser.add_argument('--limit', type=int, default=0,
                        help='Max items to process (0 = all)')
    parser.add_argument('--thinking-budget', type=int, default=1024,
                        help='Max thinking tokens for phase-2 compress (shorter = faster, cheaper)')
    parser.add_argument('--tokenizer', default='Qwen/Qwen3.5-4B',
                        help='HF/ModelScope tokenizer id for sparse token-ratio probe')
    parser.add_argument('--tokenize-every', type=int, default=1000,
                        help='Tokenize one sample every N writes; others get tokens=0 '
                             '(N <= 0 disables the probe)')
    parser.add_argument('--fixed-query-ratio', type=float, default=FIXED_QUERY_RATIO,
                        help='Proportion of plain-data items using fixed queries instead of LLM-generated ones')
    parser.add_argument('--source', choices=['think', 'plain', 'both'], default='think',
                        help='Data source: think=dataset_think.py (query+CoT), plain=dataset.py, both=chain both')
    args = parser.parse_args()

    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    done_sample_ids = load_done_sample_ids(args.output)
    # Group done sample ids by base item id so each worker only sees its slice.
    done_per_item: Dict[str, Set[str]] = {}
    for sid in done_sample_ids:
        if '__' in sid:
            base = sid.rsplit('__', 1)[0]
            done_per_item.setdefault(base, set()).add(sid)
    sys.stderr.write(
        f'Resume: {len(done_sample_ids)} samples on disk across '
        f'{len(done_per_item)} items.\n')

    api = OpenAI(model=args.model, api_key=args.api_key, base_url=args.base_url)

    from modelscope import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)

    def iter_pending() -> Iterator[Dict[str, Any]]:
        """Yield input rows that carry an id, stopping after ``--limit`` rows."""
        if args.input:
            source_iter = iter_input(args.input)
        else:
            import itertools
            sources = []
            if args.source in ('plain', 'both'):
                sources.append(iter_dataset_py(
                    total=args.total or None,
                    load_from_cache_file=not args.no_cache,
                ))
            if args.source in ('think', 'both'):
                sources.append(iter_dataset_think_py(
                    total=args.total or None,
                    load_from_cache_file=not args.no_cache,
                ))
            source_iter = itertools.chain(*sources)
        emitted = 0
        for it in source_iter:
            iid = it.get('id')
            if not iid:
                sys.stderr.write('[skip] row missing "id" field\n')
                continue
            if args.limit > 0 and emitted >= args.limit:
                return
            yield it
            emitted += 1

    write_lock = threading.Lock()
    items_done = 0
    items_failed = 0
    samples_emitted = 0

    items_iter = iter_pending()
    in_flight: Dict[Any, str] = {}
    # Sliding window: keep ~2x concurrency tasks queued so the pool never starves.
    window = max(args.concurrency * 2, args.concurrency + 4)
    # A non-positive interval would divide by zero below; treat it as "off".
    probe_every = args.tokenize_every if args.tokenize_every > 0 else 0

    # The context managers close the output file and the progress bar on every
    # exit path.
    with open(args.output, 'a', encoding='utf-8') as out_fh, \
            tqdm(desc='condense', unit='item', dynamic_ncols=True) as pbar:
        with ThreadPoolExecutor(max_workers=args.concurrency) as ex:
            exhausted = False
            while True:
                # Top the window up before waiting on results.
                while not exhausted and len(in_flight) < window:
                    try:
                        it = next(items_iter)
                    except StopIteration:
                        exhausted = True
                        break
                    iid = it['id']
                    fut = ex.submit(
                        process_item, api, it, done_per_item.get(iid),
                        args.thinking_budget, args.fixed_query_ratio,
                    )
                    in_flight[fut] = iid
                if not in_flight:
                    break
                finished, _ = wait(list(in_flight.keys()), return_when=FIRST_COMPLETED)
                for fut in finished:
                    iid = in_flight.pop(fut)
                    try:
                        samples = fut.result()
                    except Exception as exc:
                        sys.stderr.write(f'[item {iid}] crashed: {exc}\n')
                        items_failed += 1
                        pbar.update(1)
                        continue
                    if not samples:
                        # Also reached when every sample of the item was
                        # already on disk (resume).
                        items_failed += 1
                        pbar.update(1)
                        continue
                    with write_lock:
                        for s in samples:
                            src = s.pop('__src', '')
                            cmp_text = s.pop('__cmp', '')
                            samples_emitted += 1
                            if probe_every and (samples_emitted - 1) % probe_every == 0:
                                s['original_tokens'] = len(tokenizer(src).input_ids)
                                s['compressed_tokens'] = len(tokenizer(cmp_text).input_ids)
                            out_fh.write(json.dumps(s, ensure_ascii=False) + '\n')
                        out_fh.flush()
                    items_done += 1
                    pbar.set_postfix(
                        done=items_done, failed=items_failed,
                        samples=samples_emitted, refresh=False,
                    )
                    pbar.update(1)

    sys.stderr.write(
        f'Done. items_done={items_done}, samples={samples_emitted}, '
        f'failed={items_failed}\n')


if __name__ == '__main__':
    main()
+ +The original HotpotQA dataset has annotation issues: + - GT doesn't match the question type (asks "where", GT gives a name) + - Partial/incomplete answers for multi-hop questions + - Single form when multiple valid forms exist (e.g. "2" vs "two") + - Question itself malformed (wrong question word, truncation, presupposition + mismatch with the answer type) + +This script: + 1. Loads HotpotQA fullwiki train split. + 2. By default, draws a stratified sample per level (--easy/--medium/--hard, + default 2000/4000/6000) with every ID in wrong_ids.txt force-included. + Pass --only-forced to re-annotate ONLY the IDs listed in wrong_ids.txt + (the 340 known-bad cases). + 3. For each row, sends question + full context + original GT to a super-LLM. + 4. The LLM emits one of four verdicts and (when applicable) a multi-form + answer list and/or a repaired question: + - keep: original Q + A are both correct + - fix_answer: Q is fine; A is wrong/incomplete + - fix_question: Q is malformed but repairable into a well-formed Q + that the same passages answer with the same gold facts + - drop: Q cannot be repaired without changing the fact, OR + passages do not support any answer + 5. Outputs ONE JSONL file containing all rows (including drop). Each row has + verdict, question, question_fixed, answers, reasoning. Downstream filters + by verdict. 
+ +Run (re-clean wrong_ids.txt only, default): + python reannotate_groundtruth.py \ + --model qwen-max --api-key $OPENAI_API_KEY \ + --base-url https://dashscope.aliyuncs.com/compatible-mode/v1 \ + --output hotpotqa_reannotated_wrong.jsonl --concurrency 16 +""" +import argparse +import json +import os +import random +import re +import sys +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from twinkle.data_format.sampling import SamplingParams +from twinkle_agentic.protocol.openai import OpenAI + + +VERIFY_SYSTEM = """You are a dataset quality auditor for a multi-hop QA benchmark (HotpotQA). + +Given a Question, supporting Context passages, and the dataset's Original Answer, output ONE of four verdicts and a multi-form answer list grounded in the passages. + +VERDICTS +- "keep": original question + original answer are both correct. +- "fix_answer": question is fine; original answer is wrong/incomplete. +- "fix_question": question is malformed (wrong question word, broken grammar, truncated, or presupposition mismatch with the answer type) but can be REPAIRED into a well-formed question that the SAME passages answer with the SAME gold facts. +- "drop": question cannot be repaired without changing the underlying fact, OR the passages do not support any answer. + +MULTI-FORM ANSWER RULES (apply to keep / fix_answer / fix_question) +1. Output ALL acceptable surface forms whenever applicable: + - Number variants: arabic + english word + hyphen-prefix form (e.g. "3", "three", "three-door", "3-door") + - Range variants: start, end, and full range string (e.g. "1901", "1902", "1901-1902", "1901-2") + - Location variants: city / state-or-province / country (e.g. "Everett", "Washington", "WA", "United States") + - Person variants: legal name / nickname / full name (e.g. 
"Allan", "Heywood", "Allan Stewart Konigsberg") + - Entity-role pairs for role-of-X questions: BOTH the role AND the entity (e.g. "chauffeur", "Hitler's chauffeur") + - Show-vs-character pairs for best-known-for questions: BOTH the show AND the character (e.g. "M*A*S*H", "Major Frank Burns") + - Common abbreviations (e.g. "NYC", "New York City", "New York") + - With/without titles (e.g. "Dr. Smith", "Smith") + - Different date formats if applicable (e.g. "July 4, 1776", "4 July 1776") +2. Each answer is SHORT (a name, entity, number, date, or yes/no). +3. yes/no answers MUST be lowercase ["yes"] or ["no"]. +4. Do NOT hallucinate. Every answer must be grounded in the provided passages. + +QUESTION REWRITE RULES (verdict = fix_question) +1. question_fixed MUST be answerable by the SAME passages and yield the SAME factual answer as the original gold facts. +2. Allowed edits: swap question word (Where -> Did / Who / What), repair grammar, complete truncation, align question word with the answer type. +3. FORBIDDEN: changing intent, injecting the answer into the question, adding facts not in the passages. +4. If you cannot satisfy these constraints, downgrade to "drop". + +DROP RULES (verdict = drop) +- answers MUST be [] and question_fixed MUST be null. + +OUTPUT FORMAT (JSON only, no markdown fence, no explanation) +{"verdict": "keep|fix_answer|fix_question|drop", "question_fixed": "..." | null, "answers": ["..."], "reasoning": "one sentence"}""" + +VERIFY_USER = """## Question +{question} + +## Original Answer (may be wrong) +{original_answer} + +## Supporting Passages +{context} + +## Task +Audit the row per the system rules. Pick exactly one verdict (keep / fix_answer / fix_question / drop), produce the multi-form answers list (or [] for drop), and write a one-sentence reasoning. If verdict=fix_question, also produce question_fixed; otherwise set it to null. 
+Return a single JSON object only.""" + + +LEVELS: Tuple[str, str, str] = ('easy', 'medium', 'hard') + + +def _format_context(context: Dict[str, Any]) -> str: + titles = context.get('title', []) or [] + sentences = context.get('sentences', []) or [] + lines = [] + for i, (title, sents) in enumerate(zip(titles, sentences), start=1): + if isinstance(sents, list): + body = ' '.join(s.strip() for s in sents if s and s.strip()) + else: + body = str(sents).strip() + lines.append(f'[{i}] {title}: {body}') + return '\n\n'.join(lines) + + +_JSON_RE = re.compile(r'\{[^{}]*"verdict"\s*:\s*"[^"]+"[^{}]*"answers"\s*:\s*\[.*?\][^{}]*\}', re.DOTALL) + +_VALID_VERDICTS = ('keep', 'fix_answer', 'fix_question', 'drop') + + +def _parse_response(text: str) -> Optional[Dict[str, Any]]: + text = text.strip() + if text.startswith('```'): + first_nl = text.find('\n') + last_fence = text.rfind('```') + if first_nl != -1 and last_fence > first_nl: + text = text[first_nl + 1:last_fence].strip() + try: + obj = json.loads(text) + if isinstance(obj, dict) and 'answers' in obj: + return obj + except json.JSONDecodeError: + pass + m = _JSON_RE.search(text) + if m: + try: + return json.loads(m.group(0)) + except json.JSONDecodeError: + pass + return None + + +def _validate_verdict( + verdict: Optional[str], answers: List[str], + qfix: Optional[str], original_question: str, +) -> bool: + if verdict not in _VALID_VERDICTS: + return False + if verdict == 'drop': + return not answers and qfix is None + if not answers: + return False + if verdict == 'fix_question': + return bool(qfix) and qfix.strip() != original_question.strip() + return qfix is None + + +def verify_answer( + api: OpenAI, model: str, row: Dict[str, Any], +) -> Optional[Dict[str, Any]]: + question = row['question'] + original_answer = row.get('answer', '') or '' + context_str = _format_context(row.get('context', {}) or {}) + + user_content = VERIFY_USER.format( + question=question, + original_answer=original_answer, + 
context=context_str) + + trajectory = { + 'messages': [ + {'role': 'system', 'content': VERIFY_SYSTEM}, + {'role': 'user', 'content': user_content}, + ] + } + sp = SamplingParams(temperature=0.1, max_tokens=512) + + for attempt in range(3): + try: + reply = api(trajectory, sp, extra_body={'enable_thinking': True}) + except Exception as exc: + sys.stderr.write(f'[verify] {row["id"]}: API error: {exc}\n') + if attempt < 2: + continue + return None + + content = reply.get('content') or '' + parsed = _parse_response(content) + if parsed: + verdict = parsed.get('verdict') + answers_raw = parsed.get('answers') + answers = ( + [str(a).strip() for a in answers_raw if str(a).strip()] + if isinstance(answers_raw, list) else []) + qfix_raw = parsed.get('question_fixed') + qfix = (qfix_raw.strip() or None) if isinstance(qfix_raw, str) else None + if _validate_verdict(verdict, answers, qfix, question): + return { + 'id': row['id'], + 'verdict': verdict, + 'question': question, + 'question_fixed': qfix, + 'original_answer': original_answer, + 'answers': answers, + 'reasoning': parsed.get('reasoning', ''), + 'level': row.get('level', ''), + 'type': row.get('type', ''), + 'context': row.get('context', {}), + 'supporting_facts': row.get('supporting_facts', {}), + } + sys.stderr.write( + f'[verify retry {attempt+1}] {row["id"]}: ' + f'parse failed, content={content[:200]!r}\n') + + sys.stderr.write(f'[verify drop] {row["id"]}: all attempts failed\n') + return None + + +def stratified_sample_with_forced( + ds, per_level: Dict[str, int], forced_ids: frozenset, seed: int, +) -> List[Dict[str, Any]]: + rng = random.Random(seed) + buckets: Dict[str, List[int]] = {lv: [] for lv in LEVELS} + forced_indices: List[int] = [] + forced_levels: Dict[str, int] = {lv: 0 for lv in LEVELS} + + for i in range(len(ds)): + row_id = ds[i]['id'] + level = (ds[i].get('level') or '').strip().lower() + if row_id in forced_ids: + forced_indices.append(i) + if level in forced_levels: + forced_levels[level] += 
def load_done_ids(path: str) -> set:
    """Collect the ids already written to an output JSONL file, for resume.

    Each line is parsed independently. Lines that are not valid JSON (blank
    lines, or a torn final line left by an interrupted run) are skipped, as
    are lines whose JSON value is not an object and records whose ``id`` is
    missing, falsy, or an unhashable container.

    Args:
        path: Path to the JSONL output file; it may not exist yet.

    Returns:
        The set of ``id`` values found; empty when the file does not exist.
    """
    done: set = set()
    if not os.path.exists(path):
        return done
    with open(path, 'r', encoding='utf-8') as fh:
        for line in fh:
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A syntactically valid line may still be a bare scalar or list;
            # calling .get on it would abort the whole resume.
            if not isinstance(obj, dict):
                continue
            rid = obj.get('id')
            # Lists/dicts cannot be set members; treat them as a bad record.
            if not rid or isinstance(rid, (list, dict)):
                continue
            done.add(rid)
    return done
parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--wrong-ids', default='cookbook/rl/wrong_ids.txt') + parser.add_argument('--hf-subset', default='fullwiki') + parser.add_argument('--hf-split', default='train') + parser.add_argument( + '--only-forced', action=argparse.BooleanOptionalAction, default=False, + help='If set, re-annotate ONLY IDs in --wrong-ids; default is stratified sampling with wrong_ids force-included.') + args = parser.parse_args() + + forced_ids: frozenset = frozenset() + if args.wrong_ids and os.path.exists(args.wrong_ids): + with open(args.wrong_ids, 'r', encoding='utf-8') as fh: + forced_ids = frozenset(ln.strip() for ln in fh if ln.strip()) + sys.stderr.write(f'Forced IDs loaded: {len(forced_ids)}\n') + + if args.only_forced and not forced_ids: + raise ValueError( + f'--only-forced is set but no IDs loaded from {args.wrong_ids!r}') + + sys.stderr.write( + f'Loading hotpotqa/hotpot_qa:{args.hf_subset}:{args.hf_split}...\n') + ds = load_dataset( + 'hotpotqa/hotpot_qa', args.hf_subset, split=args.hf_split) + + if args.only_forced: + rows = select_forced_only(ds, forced_ids=forced_ids, seed=args.seed) + sys.stderr.write( + f'Selected {len(rows)} rows (only-forced mode, ' + f'requested={len(forced_ids)})\n') + else: + if args.easy + args.medium + args.hard != args.total: + raise ValueError( + f'--easy + --medium + --hard ({args.easy + args.medium + args.hard}) ' + f'must equal --total ({args.total})') + per_level = {'easy': args.easy, 'medium': args.medium, 'hard': args.hard} + rows = stratified_sample_with_forced( + ds, per_level=per_level, forced_ids=forced_ids, seed=args.seed) + sys.stderr.write( + f'Selected {len(rows)} rows (stratified per_level={per_level}, ' + f'forced={len(forced_ids)})\n') + + done = load_done_ids(args.output) + sys.stderr.write(f'Resume: {len(done)} rows already done, skipping.\n') + pending = [row for row in rows if row['id'] not in done] + sys.stderr.write(f'Pending: {len(pending)} / 
{len(rows)}\n') + + api = OpenAI( + model=args.model, api_key=args.api_key, base_url=args.base_url) + + write_lock = threading.Lock() + out_fh = open(args.output, 'a', encoding='utf-8') + rows_done = 0 + rows_failed = 0 + try: + with ThreadPoolExecutor(max_workers=args.concurrency) as ex: + futures = { + ex.submit(verify_answer, api, args.model, row): row['id'] + for row in pending + } + for fut in as_completed(futures): + rid = futures[fut] + try: + result = fut.result() + except Exception as exc: + sys.stderr.write(f'[row {rid}] crashed: {exc}\n') + rows_failed += 1 + continue + if result is None: + rows_failed += 1 + continue + with write_lock: + out_fh.write( + json.dumps(result, ensure_ascii=False) + '\n') + out_fh.flush() + rows_done += 1 + if rows_done % 100 == 0: + sys.stderr.write( + f'[progress] done={rows_done} ' + f'failed={rows_failed}\n') + finally: + out_fh.close() + + sys.stderr.write( + f'Done. rows_done={rows_done}, failed={rows_failed}, ' + f'total_pending={len(pending)}\n') + + +if __name__ == '__main__': + main() diff --git a/cookbook/exp/train_condenser_ddp.py b/cookbook/exp/train_condenser_ddp.py new file mode 100644 index 00000000..99723578 --- /dev/null +++ b/cookbook/exp/train_condenser_ddp.py @@ -0,0 +1,100 @@ +"""Ray LoRA SFT for the condenser model on condense_300K. 
def train():
    """Fine-tune the condenser model over a Ray-backed device mesh.

    Initializes twinkle in Ray mode with ``DP_SIZE`` GPU ranks, builds the
    dataset and a shuffled dataloader, configures AdamW with a cosine-warmup
    schedule sized in optimizer steps, then iterates ``NUM_EPOCHS`` times over
    the dataloader. Metrics are logged every ``LOG_INTERVAL`` micro-steps, a
    ``step_<n>`` checkpoint is written every 4000 micro-steps, and
    ``last_checkpoint`` is written once at the end.

    The step counter runs across epochs rather than restarting per epoch, so
    it matches the ``len(dataloader) * NUM_EPOCHS`` denominator in the log
    line and ``step_<n>`` checkpoints from a later epoch never overwrite those
    of an earlier one.
    """
    device_groups = [DeviceGroup(name='model', ranks=DP_SIZE, device_type='GPU')]
    # NOTE(review): DP_SIZE is used as the world size here while the mesh is
    # hard-coded to dp_size=4 x fsdp_size=2 — keep the two in sync if changed.
    model_mesh = DeviceMesh.from_sizes(world_size=DP_SIZE, dp_size=4, fsdp_size=2)
    twinkle.initialize(mode='ray', nproc_per_node=DP_SIZE, groups=device_groups, global_device_mesh=model_mesh)

    dataset = build_dataset()
    dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

    # NOTE(review): no LoRA adapter is attached in this function although the
    # module docstring describes LoRA SFT — confirm whether that is intended.
    model = TransformersModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model')

    model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE)
    total_micro_steps = len(dataloader) * NUM_EPOCHS
    total_optim_steps = total_micro_steps // GRADIENT_ACCUMULATION_STEPS
    model.set_lr_scheduler(
        scheduler_cls='CosineWarmupScheduler', num_warmup_steps=50, num_training_steps=total_optim_steps)

    logger.info(get_device_placement())
    logger.info(model.get_train_configs())
    logger.info(f'Total micro-steps: {total_micro_steps}, optim steps: {total_optim_steps}')

    global_step = 0
    for _ in range(NUM_EPOCHS):
        for batch in dataloader:
            model.forward_backward(inputs=batch)
            # Called on every micro-step; presumably the model applies the
            # optimizer update only every GRADIENT_ACCUMULATION_STEPS calls —
            # TODO confirm against the model implementation.
            model.clip_grad_and_step(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
            if global_step % LOG_INTERVAL == 0:
                metric = model.calculate_metric(is_training=True)
                logger.info(f'Step {global_step}/{total_micro_steps}, metric: {metric}')
            if global_step % 4000 == 0:
                model.save(f'step_{global_step}', output_dir=OUTPUT_DIR)
            global_step += 1
    model.save('last_checkpoint', output_dir=OUTPUT_DIR)
+ - Ranks 6-7 (``condenser_model``): Trainable condenser with LoRA for self-improvement. + +When the condenser sampler truncates (stop_reason='length'), an external OpenAI- +compatible API produces the correct compression. The failure is logged as SFT +training data. A background thread retrains the condenser on accumulated failures +mixed with condense_300K, then syncs weights back to the sampler. + +Launch: + python cookbook/exp/train_embedding_lora_ddp.py +""" +import hashlib +import json +import os +import random +import re +import sys +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional + +import swanlab +import torch + +import twinkle +from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger +from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.data_format import InputFeature, SamplingParams +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import InfonceLoss +from twinkle.metric import EmbeddingMetric +from twinkle.model import TransformersModel +from twinkle.processor import InputProcessor +from twinkle.sampler import vLLMSampler +from twinkle.template import Template +from twinkle.utils.parallel import PosixFileLock +from twinkle_agentic.protocol.openai import OpenAI as OpenAIClient + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from dataset_think import get_dataset # noqa: E402 + +logger = get_logger() + +# -- Backend selection -------------------------------------------------------- +BACKEND: Literal['transformers', 'megatron'] = 'transformers' + +# Condenser (online compression + LoRA self-improvement); embedding model trains LoRA on top of MODEL_ID. 
+CONDENSE_MODEL_ID = os.environ.get('CONDENSE_MODEL_ID', 'ms://twinkle-kit/Qwen3.5-4B-CM-v2') +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') +TEMPLATE_NAME = 'Qwen3_5Template' + +# -- GPU placement (8 total) -------------------------------------------------- +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +CONDENSER_SAMPLER_GPUS = int(os.environ.get('CONDENSER_SAMPLER_GPUS', 2)) +CONDENSER_MODEL_GPUS = int(os.environ.get('CONDENSER_MODEL_GPUS', 2)) +NUM_GPUS = MODEL_GPUS + CONDENSER_SAMPLER_GPUS + CONDENSER_MODEL_GPUS + +# -- Embedding training hyper-params ------------------------------------------ +EMB_MAX_LENGTH = 8192 +HARD_NEGATIVES = None +TEMPERATURE = 0.03 + +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 32)) +LEARNING_RATE = 1e-5 +GRADIENT_ACCUMULATION_STEPS = 1 +LOG_INTERVAL = 2 +SAVE_INTERVAL = 4000 +NUM_EPOCHS = 2 + +TOTAL_SAMPLES: Optional[int] = None + +# -- Online-compression knobs ------------------------------------------------- +# Below this length, condenser fabricates content for open-ended short prompts; +# query passes through as qr verbatim and cot rows are dropped from training. 
+MIN_TEXT_CHARS = 256 +DATASET_MAX_TOKENS = 32768 +COMPRESS_TEMPERATURE = 0.2 +COMPRESS_TOP_P = 0.5 +COMPRESS_MAX_MODEL_LEN = 32768 + +# -- OpenAI API fallback for truncated compressions --------------------------- +COMPRESS_API_KEY = os.environ.get('COMPRESS_API_KEY', '') +COMPRESS_BASE_URL = os.environ.get('COMPRESS_BASE_URL', 'https://dashscope.aliyuncs.com/compatible-mode/v1') +COMPRESS_MODEL = os.environ.get('COMPRESS_MODEL', 'qwen3.7-max') + +# -- Condenser retraining knobs ----------------------------------------------- +CONDENSER_DATASET_ID = 'ms://twinkle-kit/condense_300K' +CONDENSER_RETRAIN_SAMPLES = 128 +CONDENSER_RETRAIN_EPOCHS = 3 +CONDENSER_RETRAIN_LR = 1e-5 + +# -- Output paths ------------------------------------------------------------- +OUTPUT_DIR = f'./output/embedding_lora_{BACKEND}' +RESPONSE_LOG = os.environ.get('RESPONSE_LOG', f'./output/embedding_lora_{BACKEND}/responses.jsonl') +FAILURE_LOG = os.environ.get('FAILURE_LOG', f'./output/embedding_lora_{BACKEND}/failures.jsonl') + + +# ============================================================================= +# Prompts (from make_condenser_dataset.py — "## Summary" format) +# ============================================================================= + +COMPRESS_SYSTEM = """\ +You are a compression and summary assistant. For the (query, source) pair, emit a Markdown \ +answer with TWO sections, designed to pair with the `extract_compressed` tool: \ +the reader absorbs `## Summary` directly, then calls `extract_compressed` \ +on any topic-key listed under `## More` to recover its \ +fuller content. + + `## Summary` — extreme-density text the reader reads directly. + `## More` — a topic index whose keys are valid arguments \ +to `extract_compressed` for recovering material not captured inline. + +Together the two sections must form a COMPLETE, NON-DISTORTING inventory of the \ +source for the query — nothing essential lost, nothing implied that the source \ +does not support. 
NO preamble, NO meta-commentary, NO code fences wrapping the \ +whole output. + +Output skeleton: + +## Summary +Topic: <what the source is about + scope, one line> +<dense body answering the query> + +## More +- <topic-key>: <one-line hint of what is revealed when expanded> +- ... + +Format selection for the inline body (pick the MOST COMPACT form per query, mix \ +when helpful): +- Interface / signature → code notation directly: `func(a:int)->str` +- Factual / entity → telegraphic prose; drop function words; ":" for "is", "," \ +for "has" +- Skill / how-to / usage → lead with `Use when: <trigger>`; numbered telegraphic \ +steps `1.do X 2.then Y`; close with `Output: <result>` when relevant +- Procedural → numbered short steps +- Analytical / design → hierarchical bullets with abbreviations + +`## Summary` rules: +1. TOPIC LINE — line 1 is ALWAYS `Topic: <subject — scope>`, even when the \ +query is narrow. Anchors both the reader and the tool. +2. DENSITY — every token in the body carries query-relevant signal; cut filler. +3. PRIMARY-COMPLETE — never silently drop a fact essential to answering the \ +query. Anything cut for length MUST appear as a key under \ +`## More`. +4. NON-MISLEADING — phrasing must not let the reader infer anything the source \ +does not support; partial truths that mislead are worse than honest omissions \ +flagged in the index. +5. SELF-CONTAINED — the reader can act on the answer without re-opening the source. +6. FAITHFUL — only content the source supports; no fabrication, no extrapolation. +7. LANGUAGE — match the source language. +8. NO outer code fences around the whole answer; no meta-commentary. + +`## More` rules (MANDATORY — this section is never omitted): +1. FORMAT — each bullet is `- <topic-key>: <one-line hint>`: + • topic-key — short, unambiguous, grounded in source vocabulary so the \ +`extract_compressed` tool can locate the aspect (e.g. `decorators`, \ +`error handling`, `pitfalls`). 
+ • hint — tells WHAT the reader gains by expanding (concrete numbers, code \ +listings, secondary cases, edge details, related context, …); do NOT restate \ +the inline answer. +2. CRITERION — each bullet names an aspect that EXISTS in the source but is \ +NOT fully captured inline. Material that genuinely fits inline without \ +distortion MUST NOT be duplicated here. +3. FAITHFUL — hints must be grounded in the source; never speculate or invent. +4. ORDER — by relevance to the query, then by importance. +5. EMPTY CASE — if the source is so short / single-purpose that everything \ +fits inline, write a single line `- (none)`. + +Now begin.\ +""" + +COMPRESS_USER = ( + 'Downstream model will read your compressed block to decide whether to ' + 'expand it. Compress faithfully: preserve the passage topic + core facts. ' + 'Do NOT invent facts. Do NOT drop major facts. Do NOT write meta-commentary ' + 'about the Query (never write "Query info: absent", "no X mention", etc.); ' + 'if the passage does not address the Query, still summarize the passage. ' + 'CRITICAL LANGUAGE RULE: detect the dominant language of the Passage ' + '(NOT the Query, NOT this instruction) and write the ENTIRE output in that ' + 'same language; English passage → English output, Chinese passage → ' + 'Chinese output, Japanese passage → Japanese output. NEVER translate, ' + 'NEVER mix languages, NEVER copy these instructions into the output.\n\n' + '## Query (ordering hint only — still summarize the whole passage)\n{query}\n\n' + '## Passage\n{text}') + + +# ============================================================================= +# Logging helpers +# ============================================================================= + +_response_lock: Optional[PosixFileLock] = None +_failure_lock: Optional[PosixFileLock] = None + +# Monotonic global sample id; per-batch index would alias across batches. 
+_sample_counter = 0 +_sample_counter_lock = threading.Lock() + + +def _next_sample_id() -> int: + global _sample_counter + with _sample_counter_lock: + sid = _sample_counter + _sample_counter += 1 + return sid + + +def _log_responses(query_resp_text: str, cot_resp_text: str, idx: int, + query_raw: str = '', cot_raw: str = ''): + global _response_lock + if _response_lock is None: + os.makedirs(os.path.dirname(RESPONSE_LOG) or '.', exist_ok=True) + _response_lock = PosixFileLock(RESPONSE_LOG + '.lock') + + record = { + 'idx': idx, + 'query_raw': query_raw, + 'cot_raw': cot_raw, + 'query_compressed': query_resp_text, + 'cot_compressed': cot_resp_text, + } + line = json.dumps(record, ensure_ascii=False, default=str) + '\n' + with _response_lock: + with open(RESPONSE_LOG, 'a', encoding='utf-8') as f: + f.write(line) + + +def _log_failure(source_text: str, query: str, compressed: str, batch_idx: int): + global _failure_lock + if _failure_lock is None: + os.makedirs(os.path.dirname(FAILURE_LOG) or '.', exist_ok=True) + _failure_lock = PosixFileLock(FAILURE_LOG + '.lock') + + qhash = hashlib.md5(query.strip().encode('utf-8')).hexdigest()[:8] + record = { + 'id': f'{batch_idx}__{qhash}', + 'source': 'online_failure', + 'query': query, + 'original_len': len(source_text), + 'compressed_len': len(compressed), + 'messages': [ + {'role': 'system', 'content': COMPRESS_SYSTEM}, + {'role': 'user', 'content': COMPRESS_USER.format(query=query, text=source_text)}, + {'role': 'assistant', 'content': compressed}, + ], + } + line = json.dumps(record, ensure_ascii=False, default=str) + '\n' + with _failure_lock: + with open(FAILURE_LOG, 'a', encoding='utf-8') as f: + f.write(line) + + +# ============================================================================= +# Model builders +# ============================================================================= + +def build_model(device_mesh: DeviceMesh): + if BACKEND == 'transformers': + model = TransformersModel( + model_id=MODEL_ID, + 
def setup_optimizer(model, total_steps: int):
    """Attach the optimizer and LR scheduler matching ``BACKEND`` to ``model``.

    For ``'transformers'`` this is AdamW with ``CosineWarmupScheduler``
    (200 warmup steps); for ``'megatron'`` the backend's ``'default'``
    optimizer and scheduler (50 warmup steps). Both use ``LEARNING_RATE`` and
    span ``total_steps``.

    Args:
        model: Object exposing ``set_optimizer`` and ``set_lr_scheduler``.
        total_steps: Number of training steps the schedule should cover.

    Raises:
        ValueError: If ``BACKEND`` is neither ``'transformers'`` nor ``'megatron'``.
    """
    if BACKEND not in ('transformers', 'megatron'):
        raise ValueError(f'Unknown BACKEND={BACKEND!r}')

    if BACKEND == 'transformers':
        optimizer_name = 'AdamW'
        schedule = {
            'scheduler_cls': 'CosineWarmupScheduler',
            'num_warmup_steps': 200,
            'num_training_steps': total_steps,
        }
    else:
        optimizer_name = 'default'
        schedule = {
            'scheduler_cls': 'default',
            'lr_warmup_steps': 50,
            'lr_decay_steps': total_steps,
        }

    model.set_optimizer(optimizer_cls=optimizer_name, lr=LEARNING_RATE)
    model.set_lr_scheduler(**schedule)
Compress into a standardized procedure for retrieval.') + + +def _extract_query_cot(row: Dict[str, Any]): + messages = row.get('messages') or [] + query, cot = '', '' + for m in messages: + if not isinstance(m, dict): + continue + role = m.get('role') or '' + if role == 'user' and not query: + query = (m.get('content') or '').strip() + elif role == 'assistant': + cot = (m.get('reasoning_content') or '').strip() + break + return query, cot + + +def _build_compress_prompts(rows: List[Dict[str, Any]]) -> tuple: + """Build prompts for compressing both query and cot per row. + + Returns (prompts, valid_indices, raw_pairs, prompt_queries, passthrough) where: + - prompts: flat-interleaved [query_0, cot_0, query_1, cot_1, ...]; ``None`` means + passthrough (use raw text directly, do not call sampler) + - valid_indices: which rows passed the min-length filter + - raw_pairs: [(query, cot), ...] + - prompt_queries: the query string used for each prompt (for failure logging) + - passthrough: parallel to prompts; non-None text means "use this verbatim as qc" + """ + prompts: List[Optional[Dict[str, Any]]] = [] + valid_indices: List[int] = [] + raw_pairs: List[tuple] = [] + prompt_queries: List[str] = [] + passthrough: List[Optional[str]] = [] + for i, row in enumerate(rows): + query, cot = _extract_query_cot(row) + if not query or len(cot) < MIN_TEXT_CHARS: + continue + valid_indices.append(i) + raw_pairs.append((query, cot)) + # Short query bypasses condenser to avoid skeleton-induced hallucination. 
+ if len(query) < MIN_TEXT_CHARS: + prompts.append(None) + passthrough.append(query) + else: + user = COMPRESS_USER.format(query=EMBED_QUERY_Q, text=query) + prompts.append({'messages': [ + {'role': 'system', 'content': COMPRESS_SYSTEM}, + {'role': 'user', 'content': user}, + ]}) + passthrough.append(None) + prompt_queries.append(EMBED_QUERY_Q) + user = COMPRESS_USER.format(query=EMBED_QUERY_COT, text=cot) + prompts.append({'messages': [ + {'role': 'system', 'content': COMPRESS_SYSTEM}, + {'role': 'user', 'content': user}, + ]}) + prompt_queries.append(EMBED_QUERY_COT) + passthrough.append(None) + return prompts, valid_indices, raw_pairs, prompt_queries, passthrough + + +def _get_first_feature(decoded_text: str, template: Template, role: str) -> Optional[Dict[str, Any]]: + if not decoded_text: + return None + if role == 'anchor': + feat = template.encode({'messages': [ + {'role': 'user', 'content': decoded_text}, + {'role': 'assistant', 'content': 'Match the correct response here.'}, + ]}) + feat['labels'] = [1] + else: + feat = template.encode({'messages': [ + {'role': 'user', 'content': 'Match the correct query here.'}, + {'role': 'assistant', 'content': decoded_text}, + ]}) + feat['labels'] = [0] + return feat + + +# ============================================================================= +# OpenAI API fallback +# ============================================================================= + +def _is_truncated_compression(text: str) -> bool: + """Detect structurally incomplete output that vLLM may report as stop_reason='stop'. + + The condenser sometimes emits a chat-template token mid-skeleton (which we then + strip), so the visible text ends mid-sentence even though stop_reason!='length'. + The COMPRESS_SYSTEM skeleton mandates a `## More` section ending in a bullet list; + its absence is an unambiguous truncation signal. 
+ """ + if not text or not text.strip(): + return True + if '## More' not in text or '## Summary' not in text: + return True + after_more = text.split('## More', 1)[1].strip() + if not after_more: + return True + last_line = after_more.splitlines()[-1].strip() + if not (last_line.startswith('-') or last_line.endswith(')')): + return True + return False + + +def _api_compress(api_client: OpenAIClient, prompt: Dict[str, Any]) -> Optional[str]: + """Call external API to compress when vLLM truncates.""" + trajectory = {'messages': prompt['messages']} + # Cap max_tokens to leave ample prompt headroom inside the API model context. + sp = SamplingParams(temperature=0.2, max_tokens=8192) + try: + reply = api_client(trajectory, sp, extra_body={'enable_thinking': False}) + except Exception as exc: + logger.warning(f'[api_fallback] error: {exc}') + return None + content = (reply.get('content') or '').strip() + if not content: + return None + # Strip outer code fence if present + m = re.match(r'^```[a-zA-Z]*\n(.*?)\n```\s*$', content, re.DOTALL) + if m: + content = m.group(1).strip() + return content + + +# ============================================================================= +# Condenser Retrainer (background thread) +# ============================================================================= + +class CondenserRetrainer: + """Async condenser self-improvement: retrains from failures, syncs to sampler.""" + + def __init__(self, condenser_model, ckpt_manager: CheckpointEngineManager, + condenser_sampler): + self._model = condenser_model + self._ckpt_manager = ckpt_manager + self._sampler = condenser_sampler + self._signal = threading.Event() + self._stop = threading.Event() + self._thread = threading.Thread(target=self._loop, daemon=True) + self._condense_300k_cache = None + self._retrain_count = 0 + # Prevents sample() and sync_weights() from running concurrently + self.sampler_lock = threading.Lock() + + def start(self): + self._thread.start() + + def stop(self): + 
self._stop.set() + self._signal.set() + self._thread.join(timeout=10) + + def notify_failure(self): + self._signal.set() + + def _loop(self): + while not self._stop.is_set(): + self._signal.wait(timeout=60) + if self._stop.is_set(): + break + if not self._signal.is_set(): + continue + self._signal.clear() + try: + self._retrain_and_sync() + except Exception as exc: + logger.error(f'[condenser_retrain] crashed: {exc}') + + def _retrain_and_sync(self): + # Retrain + sync temporarily disabled; failures.jsonl is written directly by _log_failure. + pass + + +# ============================================================================= +# Main training +# ============================================================================= + +def train(): + # -------- Device groups (3 groups) ---------------------------------------- + device_groups = [ + DeviceGroup(name='model', + ranks=list(range(MODEL_GPUS)), + device_type='GPU'), + DeviceGroup(name='condenser_sampler', + ranks=list(range(MODEL_GPUS, MODEL_GPUS + CONDENSER_SAMPLER_GPUS)), + device_type='GPU'), + DeviceGroup(name='condenser_model', + ranks=list(range(MODEL_GPUS + CONDENSER_SAMPLER_GPUS, NUM_GPUS)), + device_type='GPU'), + ] + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + condenser_sampler_mesh = DeviceMesh.from_sizes( + world_size=CONDENSER_SAMPLER_GPUS, dp_size=CONDENSER_SAMPLER_GPUS) + condenser_model_mesh = DeviceMesh.from_sizes( + world_size=CONDENSER_MODEL_GPUS, dp_size=1, fsdp_size=CONDENSER_MODEL_GPUS) + + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups) + + # -------- Data ----------------------------------------------------------- + dataset = get_dataset(total=TOTAL_SAMPLES, load_from_cache_file=True) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True) + total_forward_steps = len(dataloader) * NUM_EPOCHS + optimizer_steps = total_forward_steps // GRADIENT_ACCUMULATION_STEPS + + # -------- Embedding model (4 
GPU) ---------------------------------------- + model = build_model(model_mesh) + model.set_processor(InputProcessor) + model.set_loss(InfonceLoss, temperature=TEMPERATURE, use_batch=True, + hard_negatives=HARD_NEGATIVES) + setup_optimizer(model, optimizer_steps) + model.add_metric(EmbeddingMetric, is_training=True) + + # -------- Condenser sampler (2 GPU, vLLM) -------------------------------- + emb_template = Template(model_id=MODEL_ID, max_length=EMB_MAX_LENGTH, enable_thinking=False) + # Special tokens come from the condenser tokenizer because the leak we strip is in its decoded output. + condenser_template = Template(model_id=CONDENSE_MODEL_ID, max_length=DATASET_MAX_TOKENS, + enable_thinking=False) + _special_tokens = set(condenser_template.processor.all_special_tokens) + condenser_sampler = vLLMSampler( + model_id=CONDENSE_MODEL_ID, + engine_args={ + 'gpu_memory_utilization': 0.8, + 'max_model_len': COMPRESS_MAX_MODEL_LEN, + }, + device_mesh=condenser_sampler_mesh, + remote_group='condenser_sampler', + ) + condenser_sampler.set_template( + TEMPLATE_NAME, model_id=CONDENSE_MODEL_ID, enable_thinking=False, + truncation_strategy='delete', max_length=DATASET_MAX_TOKENS) + compress_params = SamplingParams( + max_tokens=8192, + temperature=COMPRESS_TEMPERATURE, + top_p=COMPRESS_TOP_P, + num_samples=1, + ) + + # -------- Condenser model (2 GPU, trainable full-param) ------------------- + condenser_model = TransformersModel( + model_id=CONDENSE_MODEL_ID, + device_mesh=condenser_model_mesh, + remote_group='condenser_model', + ) + condenser_model.set_optimizer(optimizer_cls='AdamW', lr=CONDENSER_RETRAIN_LR) + + # -------- CheckpointEngineManager: condenser_model → condenser_sampler --- + condenser_ckpt_manager = CheckpointEngineManager( + model=condenser_model, sampler=condenser_sampler) + condenser_ckpt_manager.sync_weights() + + # -------- Background retrainer ------------------------------------------- + retrainer = CondenserRetrainer(condenser_model, 
condenser_ckpt_manager, + condenser_sampler) + retrainer.start() + + # -------- OpenAI API client for fallback --------------------------------- + api_client = OpenAIClient( + model=COMPRESS_MODEL, + api_key=COMPRESS_API_KEY, + base_url=COMPRESS_BASE_URL, + ) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info(f'Total forward steps: {total_forward_steps}, optimizer steps: {optimizer_steps}') + + swanlab.init(project='twinkle', config={ + 'backend': BACKEND, + 'model_id': MODEL_ID, + 'condense_model_id': CONDENSE_MODEL_ID, + 'batch_size': BATCH_SIZE, + 'lr': LEARNING_RATE, + 'temperature': TEMPERATURE, + 'emb_max_length': EMB_MAX_LENGTH, + 'DATASET_MAX_TOKENS': DATASET_MAX_TOKENS, + }) + + # -------- Train loop ----------------------------------------------------- + def _sample_batch(raw_batch): + """Compress via vLLM sampler; fall back to API on truncation.""" + compress_prompts, valid_indices, raw_pairs, prompt_queries, passthrough = \ + _build_compress_prompts(raw_batch) + if not compress_prompts: + return None + + # Only submit non-passthrough prompts to the sampler. 
+ sampler_input = [p for p in compress_prompts if p is not None] + sampler_pos = [ri for ri, p in enumerate(compress_prompts) if p is not None] + if sampler_input: + with retrainer.sampler_lock: + sampler_responses = condenser_sampler.sample(sampler_input, compress_params) + else: + sampler_responses = [] + responses = [None] * len(compress_prompts) + for resp, pos in zip(sampler_responses, sampler_pos): + responses[pos] = resp + + # Extract decoded texts; detect truncations and fall back to API + decoded_texts: List[str] = [] + for ri in range(len(compress_prompts)): + if passthrough[ri] is not None: + decoded_texts.append(passthrough[ri]) + continue + resp = responses[ri] + seq = resp.sequences[0] if resp and resp.sequences else None + text = '' + if seq and seq.stop_reason != 'length' and seq.decoded: + text = seq.decoded + for tok in _special_tokens: + text = text.replace(tok, '') + text = text.rstrip() + + # Premature-EOS: model emits chat-template token mid-skeleton, vLLM reports + # stop_reason='stop' but the stripped text is structurally incomplete. + needs_fallback = (not seq or seq.stop_reason == 'length' + or _is_truncated_compression(text)) + if not needs_fallback: + decoded_texts.append(text) + continue + + api_result = _api_compress(api_client, compress_prompts[ri]) + # Skip logging when the API itself produced truncated output: an incomplete + # gold answer would teach the condenser to imitate broken outputs. 
+ if api_result and not _is_truncated_compression(api_result): + decoded_texts.append(api_result) + pair_idx = ri // 2 + q_raw, c_raw = raw_pairs[pair_idx] + source_text = q_raw if ri % 2 == 0 else c_raw + _log_failure(source_text, prompt_queries[ri], api_result, + valid_indices[pair_idx]) + retrainer.notify_failure() + else: + decoded_texts.append('') + + # Build embedding features from decoded texts + emb_features: List[Dict[str, Any]] = [] + for i in range(0, len(decoded_texts), 2): + q_text = decoded_texts[i] + c_text = decoded_texts[i + 1] + q_raw, c_raw = raw_pairs[i // 2] + _log_responses(q_text, c_text, _next_sample_id(), + query_raw=q_raw, cot_raw=c_raw) + feat_q = _get_first_feature(q_text, emb_template, role='anchor') + feat_c = _get_first_feature(c_text, emb_template, role='positive') + if feat_q and feat_c: + emb_features.append(feat_q) + emb_features.append(feat_c) + + if len(emb_features) < 4: + return None + return emb_features + + cur_step = 0 + prefetch_executor = ThreadPoolExecutor(max_workers=1) + for epoch in range(NUM_EPOCHS): + batch_iter = iter(dataloader) + prefetch_future = None + first_batch = next(batch_iter, None) + if first_batch is not None: + prefetch_future = prefetch_executor.submit(_sample_batch, first_batch) + + for raw_batch in batch_iter: + emb_features = prefetch_future.result() if prefetch_future else None + prefetch_future = prefetch_executor.submit(_sample_batch, raw_batch) + + if emb_features is None: + continue + + model.forward_backward(inputs=emb_features, task='embedding') + model.clip_grad_and_step(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + cur_step += 1 + + if cur_step % LOG_INTERVAL == 0: + metric = model.calculate_metric(is_training=True) + logger.info( + f'Epoch {epoch} Step {cur_step}/{total_forward_steps}, metric: {metric}') + log_dict = {} + for k, v in metric.items(): + if not v: + continue + try: + log_dict[k] = float(v) + except (ValueError, TypeError): + pass + log_dict['epoch'] = epoch + 
swanlab.log(log_dict, step=cur_step) + if cur_step % SAVE_INTERVAL == 0: + save_checkpoint(model, f'step_{cur_step}') + + # # Drain last prefetched batch + # if prefetch_future is not None: + # emb_features = prefetch_future.result() + # if emb_features is not None: + # model.forward_backward(inputs=emb_features, task='embedding') + # model.clip_grad_and_step() + # cur_step += 1 + + prefetch_executor.shutdown(wait=False) + retrainer.stop() + save_checkpoint(model, 'last-checkpoint') + + +if __name__ == '__main__': + train() diff --git a/cookbook/exp/train_extract_ddp.py b/cookbook/exp/train_extract_ddp.py new file mode 100644 index 00000000..38d3c1f5 --- /dev/null +++ b/cookbook/exp/train_extract_ddp.py @@ -0,0 +1,119 @@ +"""DDP LoRA SFT for the policy on hotpotqa_distractor_reannotated_sft_12k.jsonl. + +The JSONL is the output of ``cookbook/rl/make_condensed_sft.py``: each row +already carries ``messages`` (system / user / assistant with textual +``<tool_call>`` blocks / tool) plus an OpenAI-shape ``tools`` schema, ready +for ``Qwen3_5Template`` to render. ``enable_thinking=False`` matches the +RL runtime contract. + +Launch: + torchrun --nproc_per_node=8 cookbook/rl/train_condensed_sft_ddp.py +""" +from pathlib import Path + +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel + +logger = get_logger() + +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +DATASET_PATH = str( + Path(__file__).resolve().parent.parent.parent + / 'hotpotqa_distractor_reannotated_sft_12k.jsonl') +TEMPLATE_NAME = 'Qwen3_5Template' +# Multi-hop with compressed context + multi-turn extract_condensed CoT; +# raw audit: most samples land well under 16k after condensation. 
+MAX_LENGTH = 32000 + +DP_SIZE = 8 +BATCH_SIZE = 16 +LEARNING_RATE = 1e-4 +GRADIENT_ACCUMULATION_STEPS = 2 +LOG_INTERVAL = 20 +NUM_EPOCHS = 2 + +OUTPUT_DIR = './output/condensed_sft_ddp' +RESUME_FROM_CHECKPOINT = None +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False +ADAPTER_NAME = 'default' + +device_mesh = DeviceMesh.from_sizes(dp_size=DP_SIZE) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + + +def build_dataset(num_samples: int = None) -> Dataset: + meta_kwargs = {} + if num_samples is not None: + meta_kwargs['data_slice'] = range(num_samples) + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, **meta_kwargs)) + # ``truncation_strategy='delete'`` drops overlong rows instead of slicing — + # a sliced multi-turn trajectory would lose `\boxed{}` and break SFT signal. + dataset.set_template( + TEMPLATE_NAME, + model_id=MODEL_ID, + max_length=MAX_LENGTH, + truncation_strategy='delete', + enable_thinking=False) + dataset.encode(load_from_cache_file=True, num_proc=16) + return dataset + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + model.save( + checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def train(): + dataset = build_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + + model = TransformersModel(model_id=MODEL_ID, ddp_config={'find_unused_parameters': True}) + model.model._no_split_modules = {'Qwen3_5DecoderLayer'} + + lora_config = LoraConfig(r=16, lora_alpha=32, target_modules='all-linear') + model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE) + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=50, + num_training_steps=len(dataloader) * NUM_EPOCHS // 
GRADIENT_ACCUMULATION_STEPS) + + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + kwargs = {'adapter_name': ADAPTER_NAME} if ADAPTER_NAME else {} + progress = model.resume_from_checkpoint( + str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader) * NUM_EPOCHS}') + + optimizer_group = model.optimizer_group[ADAPTER_NAME] + + for epoch in range(NUM_EPOCHS): + for batch in dataloader: + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + cur_step = optimizer_group.cur_step + if cur_step % LOG_INTERVAL == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Epoch {epoch} Step {cur_step}/{len(dataloader) * NUM_EPOCHS}, metric: {metric}') + save_checkpoint(model, f'epoch-{epoch}', dataloader) + save_checkpoint(model, 'last-checkpoint', dataloader) + + +if __name__ == '__main__': + train() diff --git a/cookbook/exp/train_streaming_sft.py b/cookbook/exp/train_streaming_sft.py new file mode 100644 index 00000000..2bdc58ff --- /dev/null +++ b/cookbook/exp/train_streaming_sft.py @@ -0,0 +1,383 @@ +"""Streaming SFT with QualityPreprocessor on a streaming IterableDataset (Ray mode). + +Architecture (8 GPUs single-node): + GPU 0-3: LoRA SFT training (4x DP) + GPU 4-7: vLLMSampler Ray actor (same model, for QualityPreprocessor) + +QualityPreprocessor phases (intent, IFD, refine) use SamplerBackend +which calls vLLMSampler directly via Ray (no HTTP overhead). 
+ +Two output files are produced: + - trained_data.jsonl: write-through of rows actually consumed by training + - dropped_data.jsonl: rows dropped by QualityPreprocessor (with step annotation) + +Launch: + python cookbook/exp/train_streaming_sft.py +""" +import json +import os +from pathlib import Path +from typing import Any, Dict, Iterator, List +from functools import partial +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset +from twinkle.dataset.base import DatasetMeta +from twinkle.model import TransformersModel +from twinkle.sampler import vLLMSampler +from twinkle.template import Qwen3_5Template +from twinkle_agentic.preprocessor import ( + QualityPreprocessor, SamplerBackend, + IntentClassifier, ResponseRefiner, ScoreFilter, + HardFilter, RefuseFilter, AgentTraceFilter, DeadLoopFilter, TokenSoupFilter, MessageSanityFilter, + WordRepeatFilter, CharRepeatFilter, SpecialCharsFilter, AlphanumericFilter, + FlaggedWordsFilter, MinHashDedupFilter, PIIPresidioFilter, +) +from twinkle_agentic.preprocessor.score_filter import ( + ChrMinScorer, PassNScorer, ParaphraseScorer, +) + +logger = get_logger() + +# ── Model ──────────────────────────────────────────────────────────────────── +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +MODEL_LOCAL_PATH = os.environ.get('MODEL_LOCAL_PATH', 'Qwen/Qwen3.5-4B') +TEMPLATE_NAME = 'Qwen3_5Template' +MAX_LENGTH = 40000 + +# ── GPU allocation ─────────────────────────────────────────────────────────── +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS + +# ── Training ───────────────────────────────────────────────────────────────── +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) +LEARNING_RATE = float(os.environ.get('LR', 1e-4)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRAD_ACCUM', 8)) 
LOG_INTERVAL = 1
SAVE_INTERVAL = 500
NUM_STEPS = int(os.environ.get('NUM_STEPS', 5000))

# ── Output ───────────────────────────────────────────────────────────────────
OUTPUT_DIR = './output/streaming_sft'
TRAINED_DATA_PATH = os.path.join(OUTPUT_DIR, 'trained_data.jsonl')
DROPPED_DATA_PATH = os.path.join(OUTPUT_DIR, 'dropped_data.jsonl')
ADAPTER_NAME = 'default'

# ── Data source ──────────────────────────────────────────────────────────────
CSV_PATH = os.environ.get(
    'CSV_PATH', '/mnt/workspace/yzhao/tastelikefeet/bc/ds_csv/data/20260531.csv')
DATASET_TOTAL = int(os.environ.get('DATASET_TOTAL', 1000))  # 0 = full materialized dataset
# Worker count for HF Dataset.map(num_proc=N); spawn start method is forced in twinkle.dataset.base.
MAP_NUM_PROC = int(os.environ.get('MAP_NUM_PROC', 1))


def _canonicalize_tool_call(tc: Any) -> Dict[str, Any]:
    """Coerce ``tool_calls[i]`` to a fixed-schema dict for stable Arrow inference.

    Keeps ``function.arguments`` as the OpenAI-native JSON string so every row
    sees a uniform ``string`` field; any string→dict decoding is the
    chat_template's concern (see ``Template._apply_chat_template``).

    The decoded form is enforced to be a JSON object so the chat_template's
    ``|items`` filter never receives list/scalar/null — those originate from
    dirty CSV rows and are coerced to ``{}`` here, the ingestion boundary.

    Args:
        tc: One raw tool-call entry; anything that is not a dict is treated as
            an empty tool call.

    Returns:
        ``{'id': str, 'type': str, 'function': {'name': str, 'arguments': str}}``
        where ``arguments`` is always the JSON encoding of an object.
    """
    tc = tc if isinstance(tc, dict) else {}
    fn = tc.get('function') if isinstance(tc.get('function'), dict) else {}
    args = fn.get('arguments')
    if isinstance(args, dict):
        args_str = json.dumps(args, ensure_ascii=False)
    elif isinstance(args, str) and args.strip():
        try:
            decoded = json.loads(args)
        except json.JSONDecodeError:
            decoded = {}
        # Valid JSON that is not an object (list / scalar / null) is dirty data.
        if not isinstance(decoded, dict):
            decoded = {}
        args_str = json.dumps(decoded, ensure_ascii=False)
    else:
        args_str = '{}'
    return {
        'id': str(tc.get('id') or ''),
        'type': str(tc.get('type') or 'function'),
        'function': {
            'name': str(fn.get('name') or ''),
            'arguments': args_str,
        },
    }


def _stream_csv_rows(csv_path: str, max_rows: int = 0) -> Iterator[Dict[str, Any]]:
    """Stream the custom CSV: each line is `ts,model,req_id,messages_json` (no quoting).

    The first 3 fields are scalar; the remainder of the line is a JSON array of
    chat messages, possibly containing commas — so we split on the first 3 commas only.
    ``max_rows`` caps the yielded rows at ingestion time so Arrow never materializes
    the unused tail.

    Malformed input is skipped rather than raised: lines that are not valid
    UTF-8, have fewer than 4 fields, or whose payload is not a JSON array are
    dropped whole; array entries that are not objects are dropped individually.

    Args:
        csv_path: Path of the CSV file to read.
        max_rows: Maximum number of rows to yield; ``0`` means no cap.

    Yields:
        ``{'id', 'source', 'messages', 'user_data'}`` dicts, where each message
        is ``{'role', 'content', 'tool_calls', 'tool_call_id'}``.
    """
    source = Path(csv_path).stem  # loop-invariant, computed once
    emitted = 0
    with open(csv_path, 'rb') as f:
        for raw in f:
            try:
                line = raw.decode('utf-8').rstrip('\n').rstrip('\r')
            except UnicodeDecodeError:
                # Undecodable line: skip it instead of aborting the whole stream.
                continue
            if not line:
                continue
            parts = line.split(',', 3)
            if len(parts) < 4:
                continue
            ts, _model, req_id, msgs_raw = parts
            try:
                raw_msgs = json.loads(msgs_raw)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an array (object / scalar / null) is a dirty row.
            if not isinstance(raw_msgs, list):
                continue
            messages: List[Dict[str, Any]] = []
            for m in raw_msgs:
                if not isinstance(m, dict):
                    continue
                role = m.get('role', '')
                content = m.get('content')
                # User content arrives as [{'type':'text','text':...}, ...]; flatten to plain string.
                if isinstance(content, list):
                    content = ''.join(
                        p['text'] for p in content
                        if isinstance(p, dict) and p.get('type') == 'text'
                        and isinstance(p.get('text'), str))
                if content is None:
                    content = ''
                if not isinstance(content, str):
                    continue
                raw_tcs = m.get('tool_calls') if role == 'assistant' else None
                # Only a real list is iterated; a stray string would otherwise be
                # walked character by character and produce bogus empty tool calls.
                if isinstance(raw_tcs, list):
                    tc_list = [_canonicalize_tool_call(tc) for tc in raw_tcs]
                else:
                    tc_list = []
                if role == 'assistant':
                    if not content and not tc_list:
                        continue
                    if m.get('reasoning_content'):
                        content = f"<think>{m['reasoning_content']}</think>{content}"
                elif role == 'tool':
                    # Tool messages are kept even when their content is empty.
                    pass
                elif not content:
                    continue
                # tool_calls stored as JSON string (empty -> ''): keeps Arrow schema as a
                # stable Value(string) regardless of empty-list / heterogeneous-struct shards.
                # Template._apply_chat_template decodes it back to list before jinja render.
                messages.append({
                    'role': role,
                    'content': content,
                    'tool_calls': json.dumps(tc_list, ensure_ascii=False) if tc_list else '',
                    'tool_call_id': str(m.get('tool_call_id') or '') if role == 'tool' else '',
                })
            if not messages:
                continue
            yield {
                'id': f'csv__{ts}__{req_id}',
                'source': source,
                'messages': messages,
                'user_data': {},
            }
            emitted += 1
            if max_rows and emitted >= max_rows:
                break


# ── QualityPreprocessor config ───────────────────────────────────────────────
SENSITIVE_WORDS_FILE = str(
    Path(__file__).resolve().parent.parent.parent / 'sensitive_words.txt')
# chr_min cutoff: keep round if chr_min < threshold (low chr_min = hard).
CHR_MIN_THRESHOLD = float(os.environ.get('CHR_MIN_THRESHOLD', 0.5))
REFINE_TEMPERATURE = float(os.environ.get('REFINE_TEMPERATURE', 0.6))
REFINE_MAX_TOKENS = int(os.environ.get('REFINE_MAX_TOKENS', 4096))

# ── Pass@4 LLM-as-judge (grades each diagnostic rollout vs GT) ───────────────
# Set JUDGE_MODEL='' to disable; otherwise judge runs over every diagnostic round.
+JUDGE_MODEL = os.environ.get('JUDGE_MODEL', 'qwen3.7-max') +JUDGE_BASE_URL = os.environ.get('JUDGE_BASE_URL', 'https://dashscope.aliyuncs.com/compatible-mode/v1') +JUDGE_API_KEY = os.environ.get('JUDGE_API_KEY', 'EMPTY') +JUDGE_TEMPERATURE = float(os.environ.get('JUDGE_TEMPERATURE', 0.3)) +JUDGE_MAX_TOKENS = int(os.environ.get('JUDGE_MAX_TOKENS', 32000)) +JUDGE_MAX_WORKERS = int(os.environ.get('JUDGE_MAX_WORKERS', 16)) + + +def build_dataset(backend: SamplerBackend) -> Dataset: + """Materialize the local CSV, convert to SFT messages format, run QualityPreprocessor. + + Switched from streaming IterableDataset to in-memory Dataset so HF + `Dataset.map(num_proc=N)` can parallelize the QualityPreprocessor pipeline. + """ + os.makedirs(OUTPUT_DIR, exist_ok=True) + + # Custom CSV format (commas inside JSON) — feed framework via callable, not csv loader. + meta = DatasetMeta( + dataset_id=Path(CSV_PATH).stem, + data=partial(_stream_csv_rows, csv_path=CSV_PATH, max_rows=DATASET_TOTAL), + ) + dataset = Dataset(meta) + # template kept for future re-enablement of ScoreFilter; unused in current pipeline. + _ = Qwen3_5Template(model_id=MODEL_ID, max_length=MAX_LENGTH, + truncation_strategy='delete', + enable_thinking=False) + + qp = QualityPreprocessor( + pipeline=[ + # Phase 1-5: deterministic structural filters + HardFilter(min_user_chars_cjk=14, min_user_chars=24), + RefuseFilter(), + # Tag agent rollouts (Cline / OpenClaw / Claude Code) so DeadLoop + # / sanity rules can adapt instead of mass-dropping them. + AgentTraceFilter(), + DeadLoopFilter(), + MessageSanityFilter(sensitive_words_file='.temp/sensitive_words.txt'), + # Phase 8-10: repetition & character quality + WordRepeatFilter(), + CharRepeatFilter(), + SpecialCharsFilter(max_ratio=0.6), + # TokenSoupFilter samples head only — signals are uniform/statistical, no need to scan multi-MB tool payloads. 
+ TokenSoupFilter(max_chars=8000), + AlphanumericFilter(), + FlaggedWordsFilter(), + # MinHashDedupFilter(), + IntentClassifier(), + # ScoreFilter temporarily disabled — reuses Ray vLLMSampler backend + # which is incompatible with HF Dataset.map(num_proc>1) workers. + # ScoreFilter( + # template=template, + # backend=backend, + # scorers=[ + # ChrMinScorer(), + # # PassNScorer( + # # backend=backend, + # # judge_model=JUDGE_MODEL or None, + # # judge_base_url=JUDGE_BASE_URL, + # # judge_api_key=JUDGE_API_KEY, + # # n=4, + # # min_pass=0, + # # sample_temperature=0.7, + # # sample_max_tokens=4096, + # # judge_temperature=JUDGE_TEMPERATURE, + # # judge_max_tokens=JUDGE_MAX_TOKENS, + # # judge_max_workers=JUDGE_MAX_WORKERS, + # # ), + # # ParaphraseScorer( + # # backend=backend, + # # template=template, + # # ), + # ], + # # trace_dir=os.path.join(OUTPUT_DIR, 'score_traces'), + # ), + # PIIPresidioFilter(languages=('en', 'zh')), + # Phase 13: response refinement + # ResponseRefiner( + # backend=backend, + # temperature=REFINE_TEMPERATURE, + # max_tokens=REFINE_MAX_TOKENS, + # max_workers=8, + # ), + ], + dropped_log_path=DROPPED_DATA_PATH, + ) + dataset.map(qp, num_proc=MAP_NUM_PROC, load_from_cache_file=False) + + dataset.set_template( + TEMPLATE_NAME, + model_id=MODEL_ID, + max_length=MAX_LENGTH, + truncation_strategy='delete', + enable_thinking=False, + ) + dataset.encode(num_proc=MAP_NUM_PROC, load_from_cache_file=False) + + return dataset + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + model.save( + checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def train(): + # ── Ray mode: GPUs 0-3 for training, GPUs 4-7 for vLLMSampler ──────────── + device_groups = [ + DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + DeviceGroup(name='sampler', 
ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU', gpus_per_worker=2), + ] + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS // 2, fsdp_size=2) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS // 2, tp_size=2) + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) + + # ── vLLMSampler on GPUs 4-7 (Ray actor, no HTTP overhead) ──────────────── + sampler = vLLMSampler( + model_id=MODEL_ID, + engine_args={ + 'gpu_memory_utilization': 0.6, + 'max_model_len': MAX_LENGTH, + }, + device_mesh=sampler_mesh, + remote_group='sampler', + ) + sampler.set_template(TEMPLATE_NAME, model_id=MODEL_ID) + backend = SamplerBackend(sampler) + logger.info(f'vLLMSampler ready on GPUs {MODEL_GPUS}-{NUM_GPUS - 1}') + + # ── Dataset with full QualityPreprocessor (uses SamplerBackend) ─────────── + dataset = build_dataset(backend) + dataloader = DataLoader( + dataset=dataset, + batch_size=BATCH_SIZE, + ) + + # ── Model (LoRA on 4 GPUs) ──────────────────────────────────────────────── + model = TransformersModel( + model_id=MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + ) + + lora_config = LoraConfig(r=16, lora_alpha=32, target_modules='all-linear') + model.add_adapter_to_model( + ADAPTER_NAME, lora_config, + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE) + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=50, + num_training_steps=NUM_STEPS) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {NUM_STEPS}, model GPUs: {MODEL_GPUS}, sampler GPUs: {SAMPLER_GPUS}') + + for cur_step, batch in enumerate(dataloader): + + print([len(m['input_ids']) for m in batch]) + if cur_step == 17: + print() + + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + + if cur_step % LOG_INTERVAL == 0: + 
metric = model.calculate_metric(is_training=True) + logger.info(f'Step {cur_step}/{NUM_STEPS}, metric: {metric}') + + if cur_step % SAVE_INTERVAL == 0: + save_checkpoint(model, f'step-{cur_step}', dataloader) + + if cur_step >= NUM_STEPS: + break + + save_checkpoint(model, 'last-checkpoint', dataloader) + logger.info(f'Training complete. Trained data saved to: {TRAINED_DATA_PATH}') + logger.info(f'Dropped data saved to: {DROPPED_DATA_PATH}') + + +if __name__ == '__main__': + train() diff --git a/cookbook/rl/short_math_grpo.py b/cookbook/rl/short_math_grpo.py index 5e107b0a..17c16349 100644 --- a/cookbook/rl/short_math_grpo.py +++ b/cookbook/rl/short_math_grpo.py @@ -4,10 +4,12 @@ Uses short reasoning format: shorter thinking gets higher format reward. Answer extracted from \\boxed{} or #### format. """ +import math import os import re -from typing import List, Tuple, Dict, Any +from typing import List, Tuple, Dict, Any, Optional +import swanlab from peft import LoraConfig import twinkle @@ -23,6 +25,7 @@ from twinkle.reward import GSM8KAccuracyReward from twinkle.reward.base import Reward from twinkle.sampler import vLLMSampler +from twinkle.template import Qwen3_5Template from twinkle.preprocessor.llm import GSM8KProcessor logger = get_logger() @@ -47,9 +50,18 @@ SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000)) LORA_RANK = int(os.environ.get('LORA_RANK', 16)) +GSM8K_MAX_LENGTH = int(os.environ.get('GSM8K_MAX_LENGTH', 4096)) + +KL_BETA = float(os.environ.get('KL_BETA', 0.0)) +ENTROPY_COEF = float(os.environ.get('ENTROPY_COEF', 0.0)) +CISPO_EPS_LOW = float(os.environ.get('CISPO_EPS_LOW', 0.2)) +CISPO_EPS_HIGH = float(os.environ.get('CISPO_EPS_HIGH', 0.2)) +HIGH_KL_TOPK = int(os.environ.get('HIGH_KL_TOPK', 0)) + SYSTEM_PROMPT = ('You are a helpful math assistant. 
Solve the problem with minimal but correct reasoning ' 'and put your final answer within \\boxed{}.') + # ========== Reward Functions ========== class GSM8KBrevityReward(Reward): """Brevity reward: rewards shorter completions that contain a valid answer. @@ -88,7 +100,8 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: def create_gsm8k_dataset(): dataset = Dataset() dataset.add_dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete', enable_thinking=False) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=GSM8K_MAX_LENGTH, + truncation_strategy='delete', enable_thinking=False) dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) dataset.encode(add_generation_prompt=True) return dataset @@ -106,8 +119,52 @@ def compute_rewards( return total_rewards, brevity_rewards, accuracy_rewards +# ========== Diagnostics ========== +_LEADING_NUMBER_RE = re.compile(r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?') + + +def _coerce_for_swanlab(log_dict: Dict[str, Any]) -> Dict[str, Any]: + """Cast string-valued metrics to float for swanlab line charts.""" + coerced: Dict[str, Any] = {} + for k, v in log_dict.items(): + if isinstance(v, bool) or isinstance(v, (int, float)): + coerced[k] = v + continue + if isinstance(v, str): + m = _LEADING_NUMBER_RE.search(v) + if m: + try: + coerced[k] = float(m.group()) + continue + except ValueError: + pass + coerced[k] = v + return coerced + + +def _logp_split_diagnostics( + accuracy_rewards: List[float], + old_logps: List[List[float]], +) -> Dict[str, float]: + """Split mean old-logp by accuracy outcome (pos vs zero).""" + out: Dict[str, float] = {} + if not accuracy_rewards or not old_logps: + return out + per_traj_mean = [(sum(lp) / len(lp)) if lp else 0.0 for lp in old_logps] + pos_logp = [m for m, a in zip(per_traj_mean, accuracy_rewards) if a > 0] + zero_logp = [m for 
m, a in zip(per_traj_mean, accuracy_rewards) if a <= 0] + out['acc_correct_rate'] = len(pos_logp) / len(accuracy_rewards) + out['mean_old_logp_acc_pos'] = (sum(pos_logp) / len(pos_logp)) if pos_logp else 0.0 + out['mean_old_logp_acc_zero'] = (sum(zero_logp) / len(zero_logp)) if zero_logp else 0.0 + out['policy_confidence_acc_pos'] = math.exp(out['mean_old_logp_acc_pos']) + out['policy_confidence_acc_zero'] = math.exp(out['mean_old_logp_acc_zero']) + return out + + # ========== Main ========== def main(): + swanlab.init(project='twinkle') + device_groups = [ DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), @@ -117,7 +174,6 @@ def main(): sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) - # Since we are training on text-only data, we avoid using 'all-linear' which would include the ViT layers. 
lora_config = LoraConfig( target_modules='all-linear', r=LORA_RANK, @@ -149,16 +205,21 @@ def main(): model.set_optimizer('AdamW', lr=LEARNING_RATE) model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) - model.set_loss('GRPOLoss', epsilon=0.2) + model.set_loss('GRPOLoss', epsilon=CISPO_EPS_LOW, epsilon_high=CISPO_EPS_HIGH, + beta=KL_BETA, entropy_coef=ENTROPY_COEF) model.set_processor(InputProcessor, padding_free=True) model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False) + model.add_metric('GRPOMetric', is_training=True, + epsilon=CISPO_EPS_LOW, epsilon_high=CISPO_EPS_HIGH, + top_k_kl=HIGH_KL_TOPK) + sampler = vLLMSampler( model_id=MODEL_ID, engine_args={ 'gpu_memory_utilization': 0.8, 'max_model_len': 8192, - 'max_lora_rank': 32, # save as lora_config + 'max_lora_rank': 32, 'enable_lora': True, 'enable_tower_connector_lora': True, }, @@ -166,6 +227,7 @@ def main(): remote_group='sampler', ) sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False) + rollout_template = Qwen3_5Template(MODEL_ID, max_length=GSM8K_MAX_LENGTH, enable_thinking=False) ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) @@ -180,7 +242,10 @@ def main(): advantage_fn = GRPOAdvantage() metrics = CompletionRewardMetric() - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, temperature=1.0, top_p=0.95) + sampling_params = SamplingParams( + max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, + temperature=1.0, top_p=0.95, + include_stop_str_in_output=True) optim_step = 0 logger.info('Starting GSM8K GRPO training (short reasoning)') @@ -190,21 +255,17 @@ def main(): if optim_step >= MAX_STEPS: break + batch_step = optim_step + metrics.reset() expand_prompts = [] for prompt in batch: expand_prompts.extend([prompt] * NUM_GENERATIONS) - # enable_lora=True used with ckpt_manager.sync_weights(merge_and_sync=False) - # meaning only sync lora weights, if merge_and_sync=True, - # 
lora will be merged into the base model and sync all weights to vLLM ckpt_manager.sync_weights(merge_and_sync=False) sampler.reset_prefix_cache() - sample_responses = sampler.sample( - expand_prompts, - sampling_params, - ) + sample_responses = sampler.sample(expand_prompts, sampling_params) all_input_data: List[Dict[str, Any]] = [] all_old_logps: List[List[float]] = [] @@ -218,6 +279,15 @@ def main(): total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(all_input_data) + rollout_advantages = advantage_fn( + total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist() + + all_acc_labels: List[bool] = [a > 0 for a in accuracy_rewards] + n_pos = sum(1 for p in all_acc_labels if p) + n_neg = sum(1 for p in all_acc_labels if not p) + pos_with_neg_adv = sum(1 for p, a in zip(all_acc_labels, rollout_advantages) if p and a < 0) + neg_with_pos_adv = sum(1 for p, a in zip(all_acc_labels, rollout_advantages) if not p and a > 0) + metrics.accumulate( completion_lengths=all_completion_lengths, rewards={ @@ -227,19 +297,32 @@ def main(): }, ) - advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist() - total_completions = len(all_input_data) - for mb_start in range(0, total_completions, MINI_BATCH_SIZE): - mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions) + aligned_completions = (total_completions // MODEL_GPUS) * MODEL_GPUS + if aligned_completions < total_completions: + logger.info( + '[dp-align] dropping %d tail sample(s): total=%d -> aligned=%d (dp=%d)', + total_completions - aligned_completions, + total_completions, aligned_completions, MODEL_GPUS) + + for mb_start in range(0, aligned_completions, MINI_BATCH_SIZE): + mb_end = min(mb_start + MINI_BATCH_SIZE, aligned_completions) mb_inputs = all_input_data[mb_start:mb_end] mb_old_logps = all_old_logps[mb_start:mb_end] - mb_advantages = advantages[mb_start:mb_end] + mb_advantages = rollout_advantages[mb_start:mb_end] + mb_pos_mask = 
all_acc_labels[mb_start:mb_end] + + ref_logps = None + if KL_BETA > 0.0: + ref_outputs = model.forward_only(inputs=mb_inputs, disable_lora=True) + ref_logps = ref_outputs.get('logps') if isinstance(ref_outputs, dict) else getattr(ref_outputs, 'logps', None) model.forward_backward( inputs=mb_inputs, old_logps=mb_old_logps, advantages=mb_advantages, + ref_logps=ref_logps, + positive_mask=mb_pos_mask, micro_batch_size=MICRO_BATCH_SIZE, ) model.clip_grad_and_step() @@ -252,8 +335,29 @@ def main(): log_dict = metrics.calculate() log_dict.update(model.calculate_metric(is_training=True)) + log_dict.update(_logp_split_diagnostics(accuracy_rewards, all_old_logps)) + log_dict['pos_neg_adv_rate'] = pos_with_neg_adv / n_pos if n_pos else 0.0 + log_dict['neg_pos_adv_rate'] = neg_with_pos_adv / n_neg if n_neg else 0.0 + log_dict['adv_max'] = max(rollout_advantages) if rollout_advantages else 0.0 + log_dict['adv_min'] = min(rollout_advantages) if rollout_advantages else 0.0 + + _hk = log_dict.pop('_high_kl_records', None) + if _hk: + _tok = rollout_template.tokenizer + for r in _hk: + gsi = r.get('gsi') + try: + tok_text = _tok.decode([r['token_id']]) + except Exception: + tok_text = None + logger.info( + '[high-kl] step=%d gsi=%s pos=%s tok=%r kl=%.4f r=%.4f lp_new=%.4f lp_old=%.4f', + batch_step, gsi, r.get('pos'), tok_text, + r.get('kl'), r.get('ratio'), r.get('logp_new'), r.get('logp_old')) + + swanlab.log(_coerce_for_swanlab(log_dict), step=batch_step) metrics.reset() - logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}') + logger.info(f'[Step {batch_step}/{MAX_STEPS}] {log_dict}') logger.info(f'Training completed. optim_steps={optim_step}') model.save('math-grpo-final') diff --git a/cookbook/sample/sample.py b/cookbook/sample/sample.py index b56460ea..da27efda 100644 --- a/cookbook/sample/sample.py +++ b/cookbook/sample/sample.py @@ -1,108 +1,802 @@ -""" -Standalone inference example using Ray + vLLMSampler with LoRA adapter. 
+"""使用 Qwen3.5-4B-Condenser LoRA 对三类原始数据进行压缩的示例。 -This script demonstrates how to: -1. Initialize Twinkle with Ray for distributed inference -2. Create a vLLMSampler with LoRA enabled on dedicated GPUs -3. Load a LoRA adapter from a local checkpoint path -4. Send prompts (Trajectory format) and collect generated responses +三个场景: + 1. Python 代码(短) + 2. 长度约 5120 字符的中文新闻文本 + 3. 含混杂字符的网页 HTML 代码 -Usage: - # Single GPU inference - SAMPLER_GPUS=1 python sample.py +除代码外的所有自然语言均为中文。压缩 LoRA 默认指向 ModelScope 上的 +``ms://twinkle-kit/Qwen3.5-4B-Condenser``,即与 ``cookbook/exp/grpo_condensed.py`` +所用 condenser 一致;可通过环境变量 ``LORA_PATH`` 覆盖。 - # Multi-GPU inference (tensor parallel) - SAMPLER_GPUS=2 python sample.py +启动方式:: - # Use a different model / adapter - MODEL_ID=/path/to/model LORA_PATH=/path/to/adapter SAMPLER_GPUS=1 python sample.py + SAMPLER_GPUS=1 python cookbook/sample/sample.py + SAMPLER_GPUS=2 python cookbook/sample/sample.py # 张量并行 """ import os -from typing import List, Dict, Any +from typing import Any, Dict, List import twinkle -from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger from twinkle.data_format import SamplingParams from twinkle.sampler import vLLMSampler logger = get_logger() -MODEL_ID = os.environ.get('MODEL_ID', 'Qwen/Qwen3.5-4B') -LORA_PATH = os.environ.get('LORA_PATH', '/path/to/lora') +# MODEL_ID = os.environ.get('MODEL_ID', 'output/condenser_ddp/step_44000') +MODEL_ID = 'Qwen/Qwen3.5-4B' +LORA_PATH = os.environ.get('LORA_PATH', 'ms://twinkle-kit/Qwen3.5-4B-Condenser') SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 1)) +# ────────────────────────────────────────────────────────────────────── +# Condenser 提示词(与训练时严格对齐,保留英文原文以匹配 LoRA 训练分布) +# ────────────────────────────────────────────────────────────────────── +CONDENSER_SYSTEM = """You are a text compression assistant. 
A downstream model will read your compressed output to decide whether the detail it needs is inside this block; if yes, it will fetch and read the original passage. + +Downstream model workflow: +Read your compressed output -> Decide whether needed info is in this block -> If yes -> Fetch original. + +Therefore your compression MUST NOT lose major information from the source. + +Output format: + +```text +## Summary +Overview plus facts STRONGLY RELATED to the Query, stated explicitly. + +## More +A collapsed index; expansion required to see specific information. +``` + +Rules: +1. Telegraphic style — drop function words ("the", "a", "is", "are", "of", ...); colons and commas mean "is" / "has". +2. Summary MUST contain the passage's primary topic + 2–4 concrete core facts drawn from the source (entities, numbers, dates, relations). If a Query is given, order Query-relevant facts first, but STILL include other core facts within the budget. A Query is an ORDERING HINT, NOT a filter. +3. Summary MUST NOT be meta-commentary about the Query. Forbidden patterns: "no X mention", "Query info: absent", "passage covers Y only", "does not contain ...", "no relevant info", or summaries that are only abstract category words like "structure/order/usage" with no facts. If the passage is unrelated to the Query, you still summarize the passage normally. +4. More is an INDEX of category keywords, NOT inline data. Enumerate what CAN be recovered from the source (e.g. "birthplace, death place, age"); do NOT paste dates/numbers/names inline. Make sure all category of useful facts are introduced here. +5. Output language MUST match the source language. +6. Do NOT fabricate. Do NOT omit major information. Any fact not in the source MUST NOT appear in your output. + +Now begin. +""" + +CONDENSER_USER = ( + 'Downstream model will read your compressed block to decide whether to ' + 'expand it. Compress faithfully: preserve the passage topic + core facts. ' + 'Do NOT invent facts. 
Do NOT drop major facts. Do NOT write meta-commentary ' + 'about the Query (never write "Query info: absent", "no X mention", etc.); ' + 'if the passage does not address the Query, still summarize the passage.\n\n' + '## Query (ordering hint only — still summarize the whole passage)\n{query}\n\n' + '## Target length\n' + 'Compress AS MUCH AS faithfully possible. HARD CEILING: {budget} chars ' + '(~50% of the source). If core facts fit in far fewer chars, output fewer. ' + 'Never exceed the ceiling.\n\n' + '## Passage\n{text}') + + +# ────────────────────────────────────────────────────────────────────── +# 场景 1:Python 代码片段(Dijkstra 单源最短路) +# ────────────────────────────────────────────────────────────────────── +PY_QUERY = '这段代码的功能、方法名、出入参是什么?其他人如何调用?' +PY_PASSAGE = '''import heapq +from typing import Dict, List, Tuple + + +def dijkstra(graph: Dict[str, List[Tuple[str, float]]], src: str) -> Dict[str, float]: + """Single-source shortest path on a non-negative weighted graph. + + Args: + graph: adjacency list, ``graph[u] = [(v, w), ...]`` with ``w >= 0``. + src: source node id; must be a key of ``graph``. + + Returns: + Mapping from node id to its shortest distance from ``src``; + unreachable nodes get ``math.inf``. + + Time: O((V + E) log V) via a binary heap. + Space: O(V) for the distance map and the priority queue. 
+ """ + dist: Dict[str, float] = {node: float('inf') for node in graph} + dist[src] = 0.0 + heap: List[Tuple[float, str]] = [(0.0, src)] + visited: set = set() + while heap: + d, u = heapq.heappop(heap) + if u in visited: + continue + visited.add(u) + if d > dist[u]: + continue + for v, w in graph.get(u, []): + if w < 0: + raise ValueError(f'negative weight on edge {u}->{v}: {w}') + alt = d + w + if alt < dist[v]: + dist[v] = alt + heapq.heappush(heap, (alt, v)) + return dist + + +if __name__ == '__main__': + g = { + 'A': [('B', 1.0), ('C', 4.0)], + 'B': [('C', 2.0), ('D', 5.0)], + 'C': [('D', 1.0)], + 'D': [], + } + print(dijkstra(g, 'A')) +''' + + +# ────────────────────────────────────────────────────────────────────── +# 场景 2:长篇中文新闻(约 5120 字符) +# ────────────────────────────────────────────────────────────────────── +NEWS_QUERY = '本次峰会可以学习到什么,总结出什么经验?' +NEWS_PASSAGE = """2026年5月10日上午,为期三天的“全球低空经济创新峰会暨城市级低空示范走廊启用仪式”在合肥滨湖国际会展中心闭幕。会议由国家发展改革委、工业和信息化部、中国民用航空局共同主办,安徽省人民政府承办,吸引了来自三十六个国家和地区的一千二百余名代表,其中包括十七位省部级官员、四十二家飞行器整机企业代表、九十一家产业链上下游企业、二十三家科研院所及七家国际行业协会。会议公布了《低空经济创新发展指数(2026)》、《城市低空运行规则白皮书(试行版)》和《低空安全能力评估通用框架》等三份核心文件,明确将合肥、深圳、苏州、广州、成都、青岛六城列为首批“低空经济综合改革试验区”,并宣布合肥滨湖至庐阳总长四十六公里的“环城低空示范走廊”当日正式投入运行。 + +按照规划,环城低空示范走廊由九条主干航线和十六条支线组成,主干航线最低离地高度一百二十米、最高三百米,支线最低六十米、最高一百八十米;全网部署一百八十二个固定起降点和六十座可移动塔台,覆盖政务、医疗、应急、物流、低空旅游、城市巡检等十类典型业务场景。走廊采用“一张网、两套链、三层防”的运行架构,统一接入安徽省低空运行管理平台,平台部署三百台分布式边缘节点和两套异地灾备数据中心,单日峰值并发架次设计能力为六千架,支持十秒级动态空域调整与三十毫秒级冲突告警。运行首日上午即完成首班医疗血液配送、首班跨区低空通勤、首班高速公路应急救援与首班低空观光飞行等四项标志性任务。 + +国家发展改革委副主任周楠在主旨演讲中表示,低空经济作为我国正在加快培育的战略性新兴产业,2025年市场规模已突破六千八百亿元,年复合增长率连续三年保持在百分之三十二以上;按照《低空经济创新发展指数(2026)》预测,到2030年市场规模有望突破三万亿元,将带动直接就业岗位约二百二十万个、间接就业八百万个。她强调,下一阶段的政策重点将集中在三件事上:一是推动空域分类改革落地,将三百米以下空域使用审批权限下放至省级;二是建立全国统一的低空飞行身份认证体系,以“一码通飞”形式整合飞行器编号、运营资质、保险信息;三是加快建设“低空气象-通信-导航-监视”四张网,2027年前在三十座中心城市完成基础设施全覆盖。 + 
+中国民用航空局副局长邵岩晖介绍,新版《城市低空运行规则白皮书(试行版)》对运行主体提出了五项硬性要求:飞行器须取得型号合格证或试飞许可、运行人须建立安全管理体系并通过年度审核、机长须持有相应等级有效执照、第三方责任险保额单架不得低于人民币五百万元、城市核心区运行须接入城市低空数据共享平台。白皮书首次明确了无人机与有人机融合运行规则,规定融合空域内电子围栏精度优于五米、上传频率不低于每秒十次、应急避让响应时间不大于八百毫秒。试行版将在合肥、深圳、苏州先行实施六个月,2027年1月起在六个综合改革试验区全面推广。 + +本次峰会期间,共有四十八家整机与零部件企业进行了集中签约,签约总金额三百一十七亿元人民币。其中,亿航智能与合肥市政府就eVTOL航空枢纽建设达成战略合作,未来三年将在合肥落地两座垂直起降中心、一座飞行器维修工厂;峰飞航空V2000CG无人货运飞机宣布与京东物流、顺丰速运组建“低空干线物流联盟”,2026年内开通合肥-武汉、合肥-南京两条三百公里级日常货运航线;时的科技E20 eVTOL正式获得中国民用航空局型号合格审定(TC)受理通知,成为国内第二个进入TC审定阶段的国产载人eVTOL机型;中国电信、中国移动联合发布“低空通信定制网络”,提供基于5G-A的厘米级定位与十毫秒级时延切片服务,首批接入合肥、深圳、苏州三市示范走廊。 + +中国科学技术大学、北京航空航天大学、南京航空航天大学、中国航发湖南动力机械研究所等四家单位联合发布了五项关键技术成果。其中,中科大研制的“星臂II号”分布式电推进系统单机连续可靠工作时长突破六千小时,能量密度达到每千克四百二十瓦时;北航团队公布国内首套适用于城市楼宇间复杂气流环境的“激流-3”自主感知与避障算法,已在合肥CBD连续完成八千架次实飞验证,避障成功率达百分之九十九点九七;南航发布的“穹顶”机载多源融合定位单元在GPS拒止环境下可实现亚米级定位,精度优于现有民用产品三倍。 + +国际合作方面,中欧双方在峰会上签署《低空运行互认合作备忘录》,约定2027年起对各自认证的两吨级以下载人电动飞行器互相承认型号合格证基础部分,争议技术指标通过联合评审解决。中国与阿联酋、新加坡、巴西、德国、日本五国民航主管部门签署双边谅解备忘录,覆盖低空气象数据互通、跨境物流走廊试点、飞行员资质互认三个方向。世界经济论坛代表在致辞中评价,合肥示范走廊是“迄今为止全球规模最大、运营规则最完整的城市级低空融合试验场”。 + +为保障示范走廊安全运行,安徽省专门组建了“低空安全联合运行中心”,由民航华东空管局、安徽省公安厅、应急管理厅、气象局以及合肥市政府五方常态派员,实行7×24小时值守。运行中心配备六十四套全向相控阵雷达、九十二套低空ADS-B接收机和一百二十组光电跟踪设备,可对覆盖空域内大于0.05平方米的目标进行毫秒级追踪;同时部署了五十架次自动巡查无人机和两架有人机巡查直升机,对低慢小目标实行混合编队拦截。运行首日,中心累计处置告警事件二十三起,其中误闯入九起、设备失联六起、超高飞行四起、外部气象突变三起、未授权改航一起,全部在三分钟内完成处置。 + +针对普通市民关心的应用场景,主办方在滨湖国际会展中心外侧搭建了占地约一万二千平方米的“低空生活体验区”。市民可通过现场或“合肥低空”小程序预约低空观光(合肥环城线,单程二十分钟,票价人民币二百九十八元)、低空通勤(滨湖至合肥南站,单程八分钟,票价九十八元)、低空配送(三公里内三十分钟达,订单费十二元)和低空应急医疗演示等四项体验。仪式当日预约平台一上线即满负荷运转,截至当天下午五点,累计提交订单超过一万一千笔,其中观光类占百分之六十四、通勤类占百分之二十二、配送类占百分之十三。 + +投融资方面,峰会同期举行的“低空产业投资人之夜”披露:2025年我国低空领域股权融资总额突破七百八十亿元,同比增长百分之七十二,融资轮次主要集中在A轮至C轮,平均单笔金额一点二亿元;其中飞行器整机、电池电机电控、空管软件三类标的占比分别为百分之四十一、百分之二十三、百分之十八。安徽省产业投资集团联合中国国新基金、中信产业基金、深创投、合肥兴泰金融等十家机构发起设立总规模二百亿元的“低空经济母基金”,首期规模六十亿元,重点投向飞行器适航取证、低空通信导航、城市运营平台与高端材料四个方向,单项目投资上限三亿元,预计三年内完成对外投放。 + 
+人才培养方面,国家民航局、教育部、人力资源和社会保障部联合发布《低空飞行人员培养行动计划(2026-2030)》,明确到2030年累计培养低空领域专业人才四十万人,其中eVTOL机长六万人、地面运行控制员八万人、无人机系统工程师十二万人、低空气象与运行支持人员六万人、产业链高端研发人员八万人。中国民航大学、中国民用航空飞行学院、合肥工业大学、深圳职业技术大学等十六所院校将于2026年秋季学期同步开设“低空运行与管理”本科专业和“无人飞行器系统工程”研究生方向,前两年招生总规模约六千八百人,并实行校企双导师制。 + +法治保障方面,《中华人民共和国低空空域使用管理条例(草案)》已于4月底完成第二轮社会公开征求意见,预计2026年下半年提交全国人大常委会审议。条例(草案)首次以法律形式确立“分类分级、责任清晰、动态管理”的空域使用原则,明确300米以下非管制空域备案准入、300米至1000米管制空域许可准入;规定运行人对所造成的人身、财产损害承担无过错责任,强制责任险最低保额按机型分为单架人民币三百万、五百万、一千万三档;对违规飞行的行政处罚上限从原《民用航空法》的人民币十万元提高至人民币一百万元,构成犯罪的依法追究刑事责任。 + +民众体验环节中,本报记者亲自试乘了由亿航EH216-S执飞的合肥环城观光航线。从滨湖国际会展中心垂直起降平台起飞,飞行器在二十秒内攀升至一百八十米高度,随后沿环城西线向北巡航,途经合肥南站、政务区、合肥植物园等地标,全程巡航速度九十公里每小时,最高速度一百一十公里每小时,舱内噪音实测六十六分贝、相当于普通会议室水平;地面起降阶段振动幅度小于零点二G,乘坐感受平稳。值得一提的是,飞行器全程由地面无人值守,机舱内仅有四枚乘客座位与一台显示飞行参数和航线进度的十英寸触控屏,乘客可一键切换中文、英文、日文三种语音解说。 + +技术展望部分,多位专家在分论坛中表达共识:未来五年制约低空经济规模化的关键不是飞行器性能,而是“运行密度天花板”——即在城市核心区如何把单位空域、单位时间内的安全飞行架次密度从当前的十架次每平方公里每小时提升到二百架次。中国工程院院士王建宇指出,要突破这一瓶颈必须解决三个核心问题:一是低慢小目标的全天候、全气象、全城域感知;二是冲突探测与解脱算法在高密度场景下的实时性,目标响应时间需压缩到二十毫秒以内;三是空地一体化通信网络的可用性,必须达到5个9(99.999%)的可靠度。预计这些核心技术将在“十四五”末取得阶段性突破,并在“十五五”实现产业化推广。 + +区域协同与产业布局方面,本次峰会同期发布了《中国低空经济产业布局白皮书(2026)》,首次以全国六十三个重点城市为样本,对产业链上下游进行了详尽画像。白皮书揭示,现阶段我国低空产业已初步形成“三极多点”的空间格局:长三角以合肥、南京、苏州、上海、杭州五市为核心,重点发展eVTOL整机、高端重载无人机与运营平台,产业营收占全国百分之三十五;珠三角以深圳、广州为核心,重点发展消费级与商业级无人机、低空物流,产业营收占百分之二十九;成渝地区以成都、重庆为核心,重点发展航空发动机、错复材料与航电系统,占百分之十三;其余千亿级“多点”包括青岛、沈阳、西安、武汉、长沙五个区域中心。白皮书同时提示,中西部低空产业发展仍存在不平衡问题,需重点加强边境州、边境县及山区、草原地区的低空基础设施补齐,预计这些区域有3200多个乡镇需新增低空起降点。为此,发改委将联合农业农村部、国务院应急管理部启动专项补助,中央财政三年内安排专项资金一百五十亿元。 + 
+闭幕仪式上,安徽省政府宣布将在2026年内追加投资八十亿元用于二期走廊建设,二期工程将向北延伸至淮南、向西延伸至六安,总长度从四十六公里扩展到一百九十六公里,预计2027年底前贯通。下一步合肥还将牵头编制《长三角低空一体化运行总体方案》,推动沪苏浙皖四省市在2028年前实现“一码通飞、一卡结算、一平台调度”的跨省域低空运行体系。中国民用航空局表示,将在2026年第三季度发布《国家低空经济发展中长期规划(2026-2035)》,明确未来十年的总体目标、重点任务、保障措施与考核机制。本次峰会的全部技术文件、签约项目清单和示范走廊运行实时数据将在“中国低空经济服务网”同步公开。据主办方最后公布的统计,为期三天的峰会共举办主论坛一场、高端对话三场、专题分论坛二十三场、企业发布会十六场,现场展示飞行器与装备总计一百八十八架(套),其中eVTOL三十二架、重载无人货运机二十七架、高端作业型无人机六十五架、空管与低空助航装备五十四套;累计进场观众超过二十三万人次,现场达成预订订单十二万五千余笔,为合肥市仪式期间酒店入住率、餐饮营业额分别带来同比百分之四十二与百分之三十六的增长。与会代表普遍评价,本届峰会首次实现了“会议、试点、产业、民生”四者同场并进,将原本存在于不同部门、不同会议、不同后续跨年的不同工作压缩为一个集中阶段,明显提高了这轮低空经济发展的政策准备度与社会可见度。据估计,在后续三个月内,首批六个“低空经济综合改革试验区”均将起步运行走廊与试点业务,预计2027年上半年可被看到首轮可复制、可推广的运行经验与产业样本。 +""" + + +# ────────────────────────────────────────────────────────────────────── +# 场景 3:含混杂字符的网页 HTML(电商商品详情页) +# ────────────────────────────────────────────────────────────────────── +HTML_QUERY = '这段html代码的结构如何?如何使用js如何对接?' +HTML_PASSAGE = """<!DOCTYPE html> +<html lang="zh-CN" data-spm="product-detail"> +<head> + <meta charset="UTF-8"> + <meta name="viewport" content="width=device-width,initial-scale=1.0"> + <title>云端Air Pro 13 笔记本电脑(2026款)| TechMart 数码旗舰店 + + + + + + + + + +
+

云端Air Pro 13 2026款 · 全网首发

+

品牌:云端 (Yunduan) | 型号:YDA-PRO-13-2026 | 颜色:星空银 / 深空灰 / 沙漠金

+
+ ¥12,499.00 + ¥9,999.00 + 立省 ¥2,500,限时48小时 +
+
    +
  • 处理器:自研 M3-Pro,14核CPU @ 3.6GHz / 18核GPU / 16核NPU(35 TOPS)
  • +
  • 内存:18GB LPDDR5X-7500(统一内存架构)
  • +
  • 存储:512GB / 1TB / 2TB NVMe SSD(最高读取 7,400 MB/s)
  • +
  • 屏幕:13.6″ Liquid Retina,2,560×1,664,600 nits 峰值,DCI-P3 100%
  • +
  • 电池:72Wh,续航最长 18h(本地视频播放)
  • +
  • 重量:1.24 kg | 厚度:11.3 mm
  • +
  • 接口:2× Thunderbolt 5、1× HDMI 2.1、1× MagSafe 3、1× 3.5mm 耳机
  • +
  • 无线:Wi-Fi 7(802.11be)、蓝牙 5.4、UWB
  • +
+

赠品(前 100 名下单):原装 65W GaN 电源、Type-C→HDMI 2.1 转换线、防泼溅键盘膜、AppleCare+ 1 年延保。

+
!! 限时优惠:叠加云端校园券 ¥500 + 以旧换新最高补贴 ¥1,200 !!
+
+

常见问题

+

Q: 是否支持 Windows 11 ARM 双系统? A: 不支持,但可通过 Parallels Desktop 19 虚拟运行。

+

Q: 发货时效? A: 现货 24 小时内发出,安徽/江苏/浙江/上海次日达。

+
+ +
+
© 2026 TechMart Inc. 沪ICP备2021xxxx号 · 京公网安备 31010102xxxxxx号
+ + +""" + + +# ────────────────────────────────────────────────────────────────────── +# 场景 4:复杂异常处理 Python 代码(支付下单处理器) +# ────────────────────────────────────────────────────────────────────── +# 故意混入多种异常处理风格:自定义异常树、链式抛出、bare except 反模式、 +# 资源未关闭、重试与回退、上下文管理器、suppress、finally 重写返回值等。 +EXCEPTIONS_QUERY = ( + '这段支付下单代码的异常处理设计了哪些模式、踩了哪些反模式坑、' + '可以总结出哪些最佳实践和教训?') +EXCEPTIONS_PASSAGE = '''import json +import logging +import time +from contextlib import suppress +from typing import Optional + +import requests + +logger = logging.getLogger(__name__) + + +# ---- Domain exception hierarchy ---- +class PaymentError(Exception): + """Base for all payment-domain errors.""" + + +class TransientPaymentError(PaymentError): + """Retryable: timeout, 5xx, network flap.""" + + +class PermanentPaymentError(PaymentError): + """Non-retryable: 4xx, invalid card, fraud block.""" + + +class IdempotencyConflict(PermanentPaymentError): + """Idempotency-Key reused with a different request body.""" + + +class OrderRepository: + def __init__(self, conn): + self.conn = conn # NOTE: caller owns the connection lifetime. + + def begin(self): + self.conn.execute('BEGIN') + + def commit(self): + self.conn.execute('COMMIT') + + def rollback(self): + # Anti-pattern guard: swallow rollback errors to not mask the original. + with suppress(Exception): + self.conn.execute('ROLLBACK') + + def mark_paid(self, order_id: str, txn_id: str): + self.conn.execute( + 'UPDATE orders SET status=?, txn_id=? WHERE id=?', + ('PAID', txn_id, order_id)) + + +def _call_gateway(url: str, body: dict, idem_key: str, timeout: float = 3.0): + """Single HTTP call. 
Translates transport errors into the domain hierarchy.""" + try: + resp = requests.post( + url, json=body, timeout=timeout, + headers={'Idempotency-Key': idem_key}) + except requests.Timeout as e: + raise TransientPaymentError(f'gateway timeout: {url}') from e + except requests.ConnectionError as e: + raise TransientPaymentError(f'gateway unreachable: {url}') from e + + if 500 <= resp.status_code < 600: + raise TransientPaymentError(f'gateway 5xx: {resp.status_code}') + if resp.status_code == 409: + # Same key, different body — caller bug, never retry. + raise IdempotencyConflict(f'idem-key reused: {idem_key}') + if 400 <= resp.status_code < 500: + raise PermanentPaymentError( + f'gateway 4xx: {resp.status_code} body={resp.text[:200]}') + + try: + return resp.json() + except json.JSONDecodeError as e: + # Server claimed 2xx but body is junk; treat as transient — gateway bug. + raise TransientPaymentError('gateway returned non-JSON 2xx') from e + + +def charge_with_retry(url: str, body: dict, idem_key: str, + max_attempts: int = 4) -> dict: + """Exponential backoff. ONLY retries TransientPaymentError.""" + last_exc: Optional[BaseException] = None + for attempt in range(1, max_attempts + 1): + try: + return _call_gateway(url, body, idem_key) + except TransientPaymentError as e: + last_exc = e + sleep_s = min(2 ** (attempt - 1), 8) + logger.warning('charge attempt %d/%d failed: %s; sleep %ss', + attempt, max_attempts, e, sleep_s) + time.sleep(sleep_s) + # PermanentPaymentError intentionally propagates immediately. + assert last_exc is not None # for type-checker; loop guarantees this. + raise TransientPaymentError( + f'exhausted {max_attempts} attempts') from last_exc + + +def place_order(repo: OrderRepository, order_id: str, body: dict, + gateway_url: str) -> bool: + """End-to-end order placement. Returns True on success. 
+ + Lessons embedded in this body: + - Idempotency-Key is derived from order_id (NOT a random uuid per attempt) + so retries hit the gateway as the same logical request. + - Catch broad exceptions ONLY at the outermost trust boundary, never + inside the loop. + - The bare-except below (legacy debugger pattern) IS A BUG — it suppresses + KeyboardInterrupt and SystemExit; left here intentionally to be flagged. + """ + idem_key = f'order:{order_id}' + repo.begin() + try: + receipt = charge_with_retry(gateway_url, body, idem_key) + repo.mark_paid(order_id, receipt['txn_id']) + repo.commit() + return True + except IdempotencyConflict: + # Loud: indicates a programming error upstream. + repo.rollback() + logger.exception('idempotency conflict on %s', order_id) + raise + except PermanentPaymentError as e: + # Expected business failure: rollback and surface a typed error. + repo.rollback() + logger.warning('order %s rejected by gateway: %s', order_id, e) + return False + except TransientPaymentError as e: + # Retries already exhausted; do not swallow. + repo.rollback() + logger.error('order %s transient failure: %s', order_id, e) + raise + except: # noqa: E722 -- ANTI-PATTERN, intentionally left for review. + # Catches KeyboardInterrupt / SystemExit / MemoryError too. Bad. + repo.rollback() + logger.exception('unexpected failure on %s', order_id) + return False + finally: + # Anti-pattern: returning from finally would swallow exceptions; we DO NOT + # return here. Only release locks / log timing. 
+ logger.debug('place_order(%s) finished', order_id) +''' + + +# ────────────────────────────────────────────────────────────────────── +# 场景 5:混合服务日志(正常 / 不规则 / 异常 三类掺杂) +# ────────────────────────────────────────────────────────────────────── +# 目标:考察压缩模型能否在大量噪声中突出真正的故障信号。 +# Summary 应聚焦异常(ERROR/FATAL/堆栈),常规心跳/健康检查应被压成索引词。 +LOGS_QUERY = ( + '这堆服务日志里发生了哪些独立故障?要求把每一条 ERROR/FATAL/异常都作为独立条目列出来,' + '附上其专属标识(订单号 ORD-xxx / 退款键 refund:xxx / Pod 名 / 主机名 / 证书 CN / ' + 'PagerDuty 单号 / Kafka topic+partition / trace_id 等),同名不同实例必须分别列出,不得合并;' + '正常心跳和健康检查只需在末尾用一两句索引带过,不要展开。') +LOGS_PASSAGE = '''2026-05-28T03:14:00.001Z INFO [api-gw-7f9c] heartbeat ok rss=412MB cpu=3.1% +2026-05-28T03:14:00.118Z INFO [api-gw-7f9c] GET /v1/health 200 1ms +2026-05-28T03:14:00.402Z INFO [order-svc-12] POST /orders 200 27ms user=u_88231 amount=199.00 +2026-05-28T03:14:00.512Z INFO [api-gw-7f9c] GET /v1/products?cat=3 200 9ms +2026-05-28T03:14:00.690Z DEBUG [cache-3a1] redis.get key=sess:9ab miss=false ttl=512s +2026-05-28T03:14:00.731Z INFO [order-svc-09] POST /orders 200 18ms user=u_88232 amount=12.49 +May 28 03:14:00 host-edge-03 kernel: [13929847.221] TCP: request_sock_TCP: Possible SYN flooding on port 443. Sending cookies. 
+2026-05-28T03:14:00.812Z INFO [search-svc] query took=7ms hits=104 q="鼠标" +{"ts":"2026-05-28T03:14:00.901Z","lvl":"info","svc":"recom","msg":"model v17 served","qps":3120,"p99_ms":42} +2026-05-28T03:14:01.005Z DEBUG [cache-3a1] redis.get key=sess:abc miss=false ttl=287s +2026-05-28T03:14:01.122Z INFO [api-gw-7f9c] GET /v1/products?cat=12 200 14ms +172.18.4.21 - - [28/May/2026:03:14:01 +0000] "GET /static/app.js HTTP/1.1" 200 84217 "-" "Mozilla/5.0" rt=0.003 +172.18.4.22 - - [28/May/2026:03:14:01 +0000] "GET /static/app.css HTTP/1.1" 304 0 "-" "Mozilla/5.0" rt=0.001 +2026-05-28T03:14:01.480Z DEBUG [order-svc-12] cart.compute total=199.00 items=3 user=u_88231 promo=SPRING10 +2026-05-28T03:14:01.611Z INFO [audit] write event=login user=u_88245 ip=203.0.113.44 ua=ios/9.2.1 +2026-05-28T03:14:01.799Z WARN [order-svc-12] payment.gateway latency=812ms threshold=500ms attempt=1/3 +2026-05-28T03:14:01.901Z WARN [payment-svc] retry-budget remain=87/100 window=60s +03:14:02 (??) [????] partial frame: \x7f\x45\x4c\x46\x02\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00 ... <2174 bytes dropped, recovery=skip> +[bgworker] !! 一中一英插错位 !! sync orders snapshot 开始, shard=7 records=124882 +2026-05-28T03:14:02.044Z INFO [order-svc-12] POST /orders 200 31ms user=u_88245 amount=49.90 +2026-05-28T03:14:02.211Z INFO [api-gw-7f9c] GET /v1/health 200 1ms +2026-05-28T03:14:02.310Z WARN [payment-svc] gateway timeout retry_after=2s url=https://pay.acme.io/charge idem=order:ORD-44871 +2026-05-28T03:14:02.402Z DEBUG [search-svc] parsed query bm25_terms=[laptop,gaming] filters={"price":[null,2000]} +2026-05-28T03:14:02.510Z INFO [recom] shard-warmup done shard=5 took=88ms +2026-05-28T03:14:02.612Z ERROR [payment-svc] charge failed: TransientPaymentError: gateway timeout: https://pay.acme.io/charge + Traceback (most recent call last): + File "payments/client.py", line 88, in _call_gateway + resp = requests.post(url, json=body, timeout=3.0, ...) 
+ File "requests/api.py", line 115, in post + raise Timeout("HTTPSConnectionPool: Read timed out.") + requests.exceptions.Timeout: HTTPSConnectionPool(host=\'pay.acme.io\', port=443): Read timed out. (read timeout=3.0) +2026-05-28T03:14:02.613Z INFO [payment-svc] retry attempt=2/3 backoff=2s key=order:ORD-44871 +2026-05-28T03:14:02.701Z INFO [api-gw-7f9c] GET /v1/products/9001 200 11ms +2026-05-28T03:14:02.812Z INFO [order-svc-09] POST /orders 200 22ms user=u_88251 amount=320.00 +2026-05-28T03:14:02.901Z DEBUG [cache-3a1] redis.scan cursor=0 count=200 match=sess:* +2026-05-28T03:14:03.000Z INFO [api-gw-7f9c] heartbeat ok rss=414MB cpu=3.4% +2026-05-28T03:14:03.044Z INFO [order-svc-12] GET /orders/ORD-44870 200 4ms +???GARBLED??? ½üÏí hé shì ?? mq.consumer offset=49281234 lag=??? 中间被截断 +May 28 03:14:03 host-db-master postgres[2174]: LOG: checkpoint starting: time +2026-05-28T03:14:03.244Z INFO [recom] a/b experiment exp_id=ex-991 traffic=10% bucket=v17 +[2026/05/28 03:14:03.488] log格式错位 -- 商品库存同步开始 batch=512 +2026-05-28T03:14:03.601Z INFO [order-svc-12] POST /orders 200 19ms user=u_88260 amount=4.99 +2026-05-28T03:14:03.812Z INFO [api-gw-7f9c] GET /v1/products?cat=8 200 7ms +2026-05-28T03:14:03.901Z DEBUG [cache-3a1] redis.get key=sess:def miss=false ttl=600s +2026-05-28T03:14:04.117Z INFO [inventory] sync done batch=512 ok=512 fail=0 took=629ms +2026-05-28T03:14:04.230Z INFO [api-gw-7f9c] GET /v1/health 200 1ms +2026-05-28T03:14:04.401Z INFO [order-svc-12] POST /orders 200 24ms user=u_88263 amount=88.00 +2026-05-28T03:14:04.602Z ERROR [payment-svc] charge failed: PermanentPaymentError: gateway 4xx: 402 body={"code":"INSUFFICIENT_FUNDS","order":"ORD-44871","user":"u_88251","reason":"card balance below required amount","trace_id":"pgw-c1a2b3d4"} +2026-05-28T03:14:04.604Z WARN [order-svc-12] order ORD-44871 rejected, status=PAYMENT_FAILED user=u_88251 +2026-05-28T03:14:04.707Z INFO [audit] write event=order.rejected order=ORD-44871 reason=insufficient_funds 
+2026-05-28T03:14:04.811Z INFO [recom] model v17 served qps=3140 p99_ms=44 +2026-05-28T03:14:04.901Z INFO [api-gw-7f9c] POST /v1/cart 200 12ms user=u_88277 +2026-05-28T03:14:05.001Z INFO [api-gw-7f9c] GET /v1/health 200 1ms +2026-05-28T03:14:05.103Z DEBUG [search-svc] parsed query bm25_terms=[laptop] filters={} +2026-05-28T03:14:05.220Z INFO [order-svc-09] POST /orders 200 16ms user=u_88281 amount=33.30 +2026-05-28T03:14:05.330Z INFO [search-svc] query took=12ms hits=37 q="laptop" +2026-05-28T03:14:05.501Z INFO [api-gw-7f9c] GET /v1/products?cat=5 200 9ms +2026-05-28T03:14:05.612Z DEBUG [cache-3a1] redis.del key=sess:xyz removed=1 +2026-05-28T03:14:05.778Z DEBUG [cache-3a1] redis.set key=sess:def ttl=600s size=384B +>>>>> RAW FRAME @ tcp://10.0.6.4:9092 - kafka consumer rebalance event - members=[c1,c2,c3,c4] generation=812 +2026-05-28T03:14:05.901Z INFO [audit] write event=cart.update user=u_88290 items=4 +2026-05-28T03:14:06.012Z INFO [order-svc-12] POST /orders 200 22ms user=u_88260 amount=12.00 +=== unstructured dump @ 03:14:06.4 === conn_pool active=48/64 idle=11 waiting=5 (warn>40) db=orders_master +2026-05-28T03:14:06.412Z WARN [db-pool] pool nearly exhausted: active=48 max=64 waiting=5 db=orders_master +2026-05-28T03:14:06.500Z WARN [db-pool] slow query detected took=812ms sql="UPDATE orders SET status=$1 WHERE shard=$2 AND ts<$3" rows=4188 +2026-05-28T03:14:06.612Z INFO [recom] shard-warmup done shard=6 took=92ms +2026-05-28T03:14:06.711Z INFO [search-svc] query took=8ms hits=21 q="keyboard" +2026-05-28T03:14:06.901Z DEBUG [cache-3a1] redis.get key=sess:0a1 miss=true ttl=0s +2026-05-28T03:14:07.001Z INFO [api-gw-7f9c] heartbeat ok rss=418MB cpu=4.0% +2026-05-28T03:14:07.118Z INFO [order-svc-12] POST /orders 200 28ms user=u_88277 amount=88.50 +2026-05-28T03:14:07.232Z INFO [api-gw-7f9c] GET /v1/orders/ORD-44871 200 5ms +2026-05-28T03:14:07.402Z WARN [order-svc-09] retry payment user=u_88251 attempt=2 idem=order:ORD-44871 +2026-05-28T03:14:07.612Z ERROR 
[order-svc-09] retry blocked: idempotency-conflict, charge already permanently failed +2026-05-28T03:14:07.910Z ERROR [order-svc-12] DB write failed: psycopg2.errors.UniqueViolation: duplicate key value violates unique constraint "orders_pkey" + DETAIL: Key (id)=(ORD-44875) already exists. + CONTEXT: COPY orders, line 1 + Traceback (most recent call last): + File "order/repo.py", line 142, in insert + cur.execute(SQL_INSERT, payload) + File "psycopg2/cursor.py", line 234, in execute + self._execute_impl(query, vars) + psycopg2.errors.UniqueViolation: duplicate key value violates unique constraint "orders_pkey" + -- query: INSERT INTO orders(id, user_id, amount, status, shard, ts) VALUES ($1,$2,$3,$4,$5,now()) + -- params: ('ORD-44875', 'u_88278', 88.50, 'NEW', 7) +2026-05-28T03:14:07.912Z ERROR [order-svc-12] POST /orders 500 187ms user=u_88278 trace_id=4f1c9a2b +2026-05-28T03:14:07.998Z INFO [audit] write event=order.failed order=ORD-44875 reason=duplicate_key +???? 5月28日 03:14:08 中间件告警: db-pool active=51 (中文告警) ?? 原始编码 GBK? 
+2026-05-28T03:14:08.044Z INFO [search-svc] query took=14ms hits=66 q="mouse" +2026-05-28T03:14:08.117Z INFO [api-gw-7f9c] POST /v1/login 200 33ms user=u_88290 +2026-05-28T03:14:08.232Z DEBUG [cache-3a1] redis.expire key=sess:5cd ttl=900s ok=1 +2026-05-28T03:14:08.330Z INFO [api-gw-7f9c] GET /v1/products/12 200 6ms +2026-05-28T03:14:08.444Z WARN [k8s-probe] livenessProbe pod=worker-22 statusCode=200 took=18ms (slow_threshold=10ms) +2026-05-28T03:14:08.601Z INFO [recom] model v17 served qps=3155 p99_ms=46 +2026-05-28T03:14:08.722Z DEBUG [order-svc-12] cart.compute total=88.50 items=2 user=u_88278 promo=- +2026-05-28T03:14:08.811Z INFO [api-gw-7f9c] GET /v1/products?cat=2 200 8ms +2026-05-28T03:14:08.991Z WARN [jvm-gc] G1 Old Gen pause=412ms heap_after=6.8G/8G # crosses 80% headroom, full-gc=0 +2026-05-28T03:14:09.001Z INFO [api-gw-7f9c] heartbeat ok rss=421MB cpu=4.2% +2026-05-28T03:14:09.122Z WARN [jvm-gc] G1 Old Gen pause=508ms heap_after=7.2G/8G # heap pressure increasing +????-??-??T??:??:??.???Z ??? [????] 
(timestamp parse failed) raw="flush queue depth=131072 lag=4.2s svc=worker-22" +2026-05-28T03:14:09.401Z WARN [worker-22] allocation slow: requested=512MB available=178MB triggering full GC +2026-05-28T03:14:09.601Z WARN [worker-22] full GC initiated: heap=7.8G/8G live=7.6G +2026-05-28T03:14:09.700Z FATAL [worker-22] java.lang.OutOfMemoryError: Java heap space + at com.acme.order.Aggregator.fold(Aggregator.java:88) + at com.acme.order.Aggregator.fold(Aggregator.java:71) + at com.acme.order.Aggregator.run(Aggregator.java:42) + at com.acme.metric.WindowSink.flush(WindowSink.java:204) + at com.acme.metric.WindowSink$Worker.runOnce(WindowSink.java:158) + at com.acme.metric.WindowSink$Worker.run(WindowSink.java:121) + at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136) + at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635) + at java.base/java.lang.Thread.run(Thread.java:840) + Suppressed: java.lang.OutOfMemoryError: GC overhead limit exceeded + at com.acme.metric.RollupBuffer.append(RollupBuffer.java:312) + ... 
8 more +2026-05-28T03:14:09.702Z FATAL [worker-22] process exiting, dumping heap to /var/log/heap/worker-22-1748400849.hprof (size≈6.7GB) +May 28 03:14:09 host-edge-03 kernel: [13929856.881] Out of memory: Killed process 39087 (java) total-vm:9421248kB, anon-rss:7842316kB, file-rss:412kB, shmem-rss:0kB, UID:1000 pgtables:16448kB oom_score_adj:0 +2026-05-28T03:14:10.001Z ERROR [supervisor] child worker-22 exited code=137 (OOMKilled) restart_in=5s +2026-05-28T03:14:10.122Z WARN [api-gw-7f9c] upstream worker-22 marked DOWN, routing to worker-{19,20,21,23} +2026-05-28T03:14:10.220Z INFO [k8s-probe] pod worker-22 phase=Failed reason=OOMKilled lastTerm="node memory pressure" +2026-05-28T03:14:10.301Z WARN [k8s-probe] node node-edge-03 conditions: MemoryPressure=true (since 2026-05-28T03:14:08Z) +2026-05-28T03:14:10.422Z INFO [audit] write event=worker.crashed worker=worker-22 reason=oom +2026-05-28T03:14:10.611Z INFO [recom] shard-rebalance triggered cause=worker-22-down shards-moved=2 +2026-05-28T03:14:10.812Z INFO [api-gw-7f9c] GET /v1/health 200 1ms +2026-05-28T03:14:11.000Z INFO [api-gw-7f9c] heartbeat ok rss=423MB cpu=4.5% +2026-05-28T03:14:11.118Z INFO [order-svc-12] POST /orders 200 33ms user=u_88290 amount=320.00 +2026-05-28T03:14:11.244Z DEBUG [cache-3a1] redis.get key=sess:7ef miss=false ttl=120s +2026-05-28T03:14:11.401Z INFO [order-svc-09] POST /orders 200 19ms user=u_88299 amount=7.20 +2026-05-28T03:14:11.512Z INFO [api-gw-7f9c] POST /v1/cart 200 11ms user=u_88301 +=========================== TRUNCATED LOG SECTION (~14KB removed: 217 routine entries, 0 ERROR/FATAL) =========================== +2026-05-28T03:14:11.660Z INFO [search-svc] query took=9ms hits=12 q="keyboard" +2026-05-28T03:14:11.802Z DEBUG [order-svc-12] cart.compute total=320.00 items=5 user=u_88290 promo=VIP20 +2026-05-28T03:14:11.901Z INFO [recom] model v17 served qps=3120 p99_ms=43 +2026-05-28T03:14:12.001Z INFO [api-gw-7f9c] GET /v1/health 200 1ms +2026-05-28T03:14:12.122Z INFO 
[api-gw-7f9c] GET /v1/products?cat=1 200 7ms +2026-05-28T03:14:12.244Z WARN [auth-svc] kafka producer in-flight=80/100 approaching limit +2026-05-28T03:14:12.402Z ERROR [auth-svc] kafka publish failed topic=user.login partition=3 err=NotLeaderForPartitionException broker=broker-2:9092 retry=1/5 + java.util.concurrent.ExecutionException: org.apache.kafka.common.errors.NotLeaderForPartitionException: This server is not the leader for that topic-partition. + at org.apache.kafka.clients.producer.internals.FutureRecordMetadata.valueOrError(FutureRecordMetadata.java:101) + at org.apache.kafka.clients.producer.KafkaProducer$FutureFailure.(KafkaProducer.java:1356) +2026-05-28T03:14:12.404Z WARN [auth-svc] fallback to broker-1:9092, in-flight=84 will be retried +2026-05-28T03:14:12.522Z INFO [auth-svc] metadata refresh ok partitions=24 leaders={0:b1, 1:b1, 2:b3, 3:b1, ...} +2026-05-28T03:14:12.611Z INFO [auth-svc] kafka publish ok topic=user.login partition=3 offset=49281999 took=43ms +2026-05-28T03:14:12.722Z INFO [audit] write event=login user=u_88310 ip=198.51.100.7 ua=android/12.0 +2026-05-28T03:14:12.901Z DEBUG [cache-3a1] redis.get key=sess:bbf miss=true ttl=0s +2026-05-28T03:14:13.001Z INFO [api-gw-7f9c] heartbeat ok rss=419MB cpu=4.0% +2026-05-28T03:14:13.140Z INFO [supervisor] spawn child cmd="/opt/acme/bin/worker --id=22" cwd=/opt/acme env_count=87 +2026-05-28T03:14:13.220Z INFO [order-svc-12] POST /orders 200 21ms user=u_88311 amount=6.50 +2026-05-28T03:14:13.401Z WARN [tls] certificate expiring soon cn=*.acme.io days_remaining=11 +2026-05-28T03:14:13.500Z INFO [supervisor] worker-22 restarted pid=39112 took=3.4s +2026-05-28T03:14:13.612Z INFO [k8s-probe] pod worker-22 phase=Running readiness=true +2026-05-28T03:14:13.722Z INFO [recom] shard-rebalance done shards-moved=2 took=2.9s +2026-05-28T03:14:13.811Z INFO [api-gw-7f9c] upstream worker-22 marked UP, weight=0.2 (warm-up) +2026-05-28T03:14:13.902Z INFO [search-svc] query took=10ms hits=4 q="mechanical keyboard 
rgb" +2026-05-28T03:14:14.001Z INFO [api-gw-7f9c] GET /v1/health 200 1ms +2026-05-28T03:14:14.118Z INFO [order-svc-12] POST /orders 200 26ms user=u_88301 amount=15.00 +2026-05-28T03:14:14.244Z DEBUG [cache-3a1] redis.set key=sess:ccc ttl=600s size=412B +2026-05-28T03:14:14.401Z INFO [audit] write event=order.created order=ORD-44888 user=u_88301 +2026-05-28T03:14:14.512Z WARN [s3-uploader] partial upload: file=invoice/INV-44888.pdf parts=3/5 retrying +2026-05-28T03:14:14.602Z INFO [payment-svc] circuit-breaker pay.acme.io state=HALF_OPEN probes=1 +2026-05-28T03:14:14.701Z INFO [payment-svc] probe ok latency=121ms status=200 +2026-05-28T03:14:14.812Z INFO [s3-uploader] upload complete file=invoice/INV-44888.pdf parts=5/5 took=298ms +2026-05-28T03:14:14.901Z INFO [payment-svc] probe ok latency=98ms status=200 +2026-05-28T03:14:15.001Z INFO [api-gw-7f9c] heartbeat ok rss=420MB cpu=3.9% +2026-05-28T03:14:15.122Z INFO [payment-svc] circuit-breaker pay.acme.io state=CLOSED probes_ok=3/3 reset +2026-05-28T03:14:15.244Z INFO [order-svc-09] POST /orders 200 17ms user=u_88322 amount=58.80 +May 28 03:14:15 host-db-master postgres[2174]: LOG: checkpoint complete: wrote 4188 buffers (10.2%); 0 WAL file(s) added, 0 removed, 0 recycled; write=12.108 s, sync=0.044 s, total=12.169 s +2026-05-28T03:14:15.401Z WARN [k8s-probe] node node-edge-03 conditions: MemoryPressure=false (recovered) +2026-05-28T03:14:15.512Z INFO [recom] model v17 served qps=3105 p99_ms=41 +2026-05-28T03:14:15.611Z INFO [audit] write event=worker.restored worker=worker-22 took=5.4s +2026-05-28T03:14:15.722Z INFO [api-gw-7f9c] GET /v1/health 200 1ms +2026-05-28T03:14:15.811Z INFO [order-svc-12] POST /orders 200 23ms user=u_88330 amount=42.10 +2026-05-28T03:14:15.901Z DEBUG [cache-3a1] redis.get key=sess:9ab miss=false ttl=480s +--- begin frontend / edge / cdn segment --- +2026-05-28T03:14:16.012Z INFO [cdn-edge-hkg] HIT https://static.acme.io/app.abc123.js status=200 bytes=84217 cache=HIT pop=hkg1 colo=HKG 
client_ip=203.0.113.55 +2026-05-28T03:14:16.044Z INFO [cdn-edge-hkg] MISS https://static.acme.io/chunk-44ad.js status=404 bytes=0 cache=MISS pop=hkg1 origin_status=404 client_ip=203.0.113.55 referer=https://shop.acme.io/p/yda-pro-13-2026 +[browser/chrome 124] [Console] GET https://static.acme.io/chunk-44ad.js net::ERR_ABORTED 404 (Not Found) +[browser/chrome 124] [Console] ChunkLoadError: Loading chunk 44ad failed. + (error: Error: Loading chunk 44ad failed. + at HTMLScriptElement.l (https://static.acme.io/app.abc123.js:1:14082) + at Object.next (https://static.acme.io/app.abc123.js:1:9011) + at https://static.acme.io/app.abc123.js:1:9119) + (missing: https://static.acme.io/chunk-44ad.js) + trigger: page=https://shop.acme.io/p/yda-pro-13-2026 user-agent="Mozilla/5.0 (Macintosh; Intel Mac OS X 14_4_1) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.4.1 Safari/605.1.15" +2026-05-28T03:14:16.118Z INFO [api-gw-7f9c] GET /v1/products?cat=4 200 8ms +[browser/chrome 124] [Warning] [Violation] Forced reflow while executing JavaScript took 51ms +[browser/chrome 124] [Warning] ResizeObserver loop completed with undelivered notifications. 
(occurs 17x in last 200ms) +{"csp-report":{"document-uri":"https://shop.acme.io/p/yda-pro-13-2026","referrer":"","violated-directive":"script-src-elem","effective-directive":"script-src-elem","original-policy":"default-src 'self'; script-src 'self' https://*.acme.io; img-src 'self' data: https://*.acmecdn.net; report-uri /csp-report","disposition":"enforce","blocked-uri":"https://evil-tracker.example.cn/beacon.js","line-number":1,"column-number":1,"source-file":"https://shop.acme.io/p/yda-pro-13-2026","status-code":200,"script-sample":""}} +2026-05-28T03:14:16.260Z WARN [waf-edge] blocked rule="OWASP CRS 941100 XSS" client_ip=198.51.100.91 url=https://shop.acme.io/search?q=%3Cscript%3Ealert(1)%3C/script%3E action=block ja3=t13d1516h2_8daaf6152771_b186095e22b6 +2026-05-28T03:14:16.330Z INFO [cdn-edge-hkg] HIT https://static.acme.io/app.abc123.css status=200 bytes=18234 cache=HIT pop=hkg1 +<<>> \x00\x00\x00\x10heartbeat\x00ack\x00 ... (binary noise, length=4096) +2026-05-28T03:14:16.401Z ERROR [frontend-bff] React hydration mismatch: server rendered "$199.00" client computed "$249.00" + component: + Stack: at ProductPrice (https://static.acme.io/chunks/product.js:1:18840) + at section (https://static.acme.io/chunks/product.js:1:18012) + at ProductPage (https://static.acme.io/chunks/product.js:1:23104) + at Suspense (https://static.acme.io/chunks/react-dom.production.min.js:21:4128) + cause: stale ISR snapshot served while promo SPRING10 toggled to VIP20 + page: /p/yda-pro-13-2026 rev_id=2026052800917 +2026-05-28T03:14:16.502Z WARN [frontend-bff] cache-revalidate triggered key=page:/p/yda-pro-13-2026 reason=hydration-mismatch +May 28 03:14:16 host-edge-03 nginx[8120]: 198.51.100.7 - - [28/May/2026:03:14:16 +0000] "GET /static/sourcemaps/app.abc123.js.map HTTP/1.1" 404 153 "-" "Mozilla/5.0" +May 28 03:14:16 host-edge-03 nginx[8120]: 198.51.100.7 - - [28/May/2026:03:14:16 +0000] "GET /assets/logo-v3.svg HTTP/1.1" 200 4218 "-" "Mozilla/5.0" +2026-05-28T03:14:16.701Z INFO 
[api-gw-7f9c] heartbeat ok rss=425MB cpu=4.6% +2026-05-28T03:14:16.812Z DEBUG [cache-3a1] redis.mget keys=12 hit=11 miss=1 +2026-05-28T03:14:16.901Z INFO [search-svc] query took=11ms hits=58 q="耳机" +500 Internal Server Error

500 Internal Server Error


nginx/1.25.4
--(sourced from upstream POST /v1/recommend captured by edge probe)-- +2026-05-28T03:14:17.001Z ERROR [recom] POST /v1/recommend 500 12ms user=u_88301 trace_id=8e4c1a7d cause="ConnectionResetError(104, 'Connection reset by peer')" + Traceback (most recent call last): + File "recom/server.py", line 211, in handle + feats = self.feat_store.fetch(user_id) + File "recom/feat_store.py", line 88, in fetch + resp = self._sess.post(f'{self.url}/batch', json={'uid':user_id}, timeout=0.8) + ConnectionResetError: [Errno 104] Connection reset by peer + http_url: http://feat-store-internal.acme.local:7780/batch + upstream_pod: feat-store-7d8c9b9b9c-xq2pl ip=10.42.4.219 zone=us-east-1b +2026-05-28T03:14:17.044Z WARN [api-gw-7f9c] upstream 502 from recom POST /v1/recommend latency=13ms client_ip=203.0.113.55 user=u_88301 -> served stale fallback +2026-05-28T03:14:17.118Z INFO [order-svc-12] POST /orders 200 22ms user=u_88340 amount=199.00 +2026-05-28T03:14:17.220Z INFO [audit] write event=feed.stale-served reason=recom-5xx user=u_88301 +{"ts":"2026-05-28T03:14:17.330Z","lvl":"warn","svc":"sentry-relay","event":{"type":"transaction","transaction":"GET /p/[sku]","contexts":{"trace":{"trace_id":"4f1c9a2b...","span_id":"a1b2c3d4","status":"internal_error"}},"tags":{"release":"shop@2026.05.28-rc4","environment":"prod","runtime":"node:20.11"},"breadcrumbs":[{"cat":"navigation","msg":"/->/p/yda-pro-13-2026"},{"cat":"console","level":"error","msg":"ChunkLoadError: Loading chunk 44ad failed."},{"cat":"fetch","msg":"GET /v1/recommend 502"}],"truncated":true}} +2026-05-28T03:14:17.401Z INFO [api-gw-7f9c] GET /v1/health 200 1ms +2026-05-28T03:14:17.512Z INFO [supervisor] spawn child cmd="/opt/acme/bin/recom --shard=4" pid=39220 +2026-05-28T03:14:17.611Z ERROR [worker-19] PaymentError: refund failed: idem-key reused with mismatched body, key=refund:ORD-44801 prev_amount=199.00 new_amount=189.00 user=u_87990 + Traceback (most recent call last): + File "refund/handler.py", line 47, in run + 
receipt = client.refund(order_id=oid, amount=amt, idempotency_key=key) + File "payments/client.py", line 191, in refund + raise IdempotencyConflict(f'idem-key reused: {key}') + payments.exceptions.IdempotencyConflict: idem-key reused: refund:ORD-44801 +2026-05-28T03:14:17.701Z WARN [audit] write event=refund.conflict order=ORD-44801 prev=199.00 new=189.00 op=u_87990 -- requires manual reconciliation +[grafana-alert] firing: name="recom_5xx_rate" labels={service="recom", env="prod"} value=0.071 threshold=0.02 since="2026-05-28T03:14:17Z" runbook="https://wiki.acme.io/runbook/recom-5xx" +[slack-webhook] POST https://hooks.slack.com/services/T0000/B0000/REDACTED 200 142ms payload={"channel":"#prod-alerts","text":":fire: recom 5xx 7.1% (>2%)","attachments":[{"fields":[{"title":"trace","value":"4f1c9a2b"}]}]} +2026-05-28T03:14:17.812Z INFO [api-gw-7f9c] POST /v1/cart 200 12ms user=u_88340 +2026-05-28T03:14:17.901Z DEBUG [cache-3a1] redis.get key=sess:7ef miss=false ttl=80s +2026-05-28T03:14:18.001Z INFO [api-gw-7f9c] heartbeat ok rss=427MB cpu=4.7% +2026-05-28T03:14:18.117Z INFO [order-svc-09] POST /orders 200 19ms user=u_88345 amount=22.00 +[browser/chrome 124] [Console] Mixed Content: The page at 'https://shop.acme.io/p/yda-pro-13-2026' was loaded over HTTPS, but requested an insecure XMLHttpRequest endpoint 'http://legacy-pixel.acme.cn/track'. This request has been blocked; the content must be served over HTTPS. +[browser/safari 17] [Console] [Error] Failed to load resource: The certificate for this server is invalid. You might be connecting to a server that is pretending to be "img-cdn.acme.cn" which could put your confidential information at risk. 
(asset/hero-banner-2026.jpg, line 0) +2026-05-28T03:14:18.244Z ERROR [edge-mtls] handshake failed: x509: certificate has expired or is not yet valid: current time 2026-05-28T03:14:18Z is after 2026-05-25T00:00:00Z host=img-cdn.acme.cn client=cdn-edge-hkg +2026-05-28T03:14:18.330Z WARN [tls] rotate-now task triggered for cn=img-cdn.acme.cn expired_3d_ago=true (escalating: PagerDuty P2 page-id=PD-2026-0589) +2026-05-28T03:14:18.402Z INFO [api-gw-7f9c] GET /v1/products?cat=7 200 9ms +2026-05-28T03:14:18.512Z DEBUG [search-svc] parsed query bm25_terms=[耳机,默认服购] filters={"price":[0,500]} +2026-05-28T03:14:18.601Z INFO [order-svc-12] POST /orders 200 24ms user=u_88349 amount=88.00 +2026-05-28T03:14:18.712Z WARN [auth-svc] jwt verification slow: kid=ks-2025-09 took=412ms (jwks fetch fallback) +2026-05-28T03:14:18.811Z ERROR [auth-svc] jwks fetch failed: dial tcp: lookup auth.acme.io on 169.254.20.10:53: read udp 169.254.20.10:53: i/o timeout retry=2/3 +2026-05-28T03:14:18.901Z INFO [auth-svc] jwks fetch ok from cache (stale=true age=311s) keys=4 +2026-05-28T03:14:19.001Z INFO [api-gw-7f9c] heartbeat ok rss=429MB cpu=4.9% +2026-05-28T03:14:19.118Z INFO [audit] write event=login user=u_88349 ip=192.0.2.18 ua=chrome/124.0 +[prometheus/scrape] target=http://recom-7:9090/metrics state=DOWN err="context deadline exceeded" (15s) consecutive_failures=3 -> marking unhealthy +[prometheus/scrape] target=http://order-svc-12:9090/metrics state=UP scrape_duration=78ms samples=3144 +2026-05-28T03:14:19.244Z WARN [recom] shard-rebalance still in progress moves=1/2 elapsed=8.6s expected=<5s +2026-05-28T03:14:19.330Z ERROR [recom] shard-rebalance stalled shard=4 reason="primary candidate worker-19 over budget rss=5.2G/4G" +2026-05-28T03:14:19.401Z INFO [supervisor] admission decision: deny worker-19 promotion, fallback to worker-21 +2026-05-28T03:14:19.512Z INFO [recom] shard-rebalance promote shard=4 -> worker-21 +2026-05-28T03:14:19.611Z INFO [recom] shard-rebalance done shards-moved=2 
took=11.2s (slow) +??base64?? eyJlbnYiOiJwcm9kIiwic3ZjIjoiYW5hbHl0aWNzIiwiYmF0Y2giOlsiZTEiLCJlMiIsImUzIiwiZTQiXX0=??end?? +2026-05-28T03:14:19.812Z INFO [search-svc] query took=15ms hits=2 q=" " (empty after trim) +2026-05-28T03:14:19.901Z DEBUG [cache-3a1] redis.get key=sess:newuser miss=true ttl=0s -> initialize +2026-05-28T03:14:20.001Z INFO [api-gw-7f9c] heartbeat ok rss=430MB cpu=4.5% +2026-05-28T03:14:20.122Z WARN [websocket] conn closed code=1011 reason="internal server error" path=/ws/notify user=u_88301 dur=43s msgs_sent=7 msgs_recv=2 +2026-05-28T03:14:20.244Z INFO [websocket] reconnect from u_88301 backoff=1s attempt=1 +2026-05-28T03:14:20.330Z ERROR [api-gw-7f9c] client TLS abort: tls: client offered only unsupported versions: [301 302] client_ip=185.220.101.34 ja3=00000000000000000000000000000000 (likely scanner) +2026-05-28T03:14:20.401Z WARN [waf-edge] rate-limit bucket exceeded ip=185.220.101.34 rule=ip-burst limit=120/min observed=482 action=block ttl=600s +2026-05-28T03:14:20.512Z INFO [api-gw-7f9c] POST /v1/cart 200 14ms user=u_88349 +2026-05-28T03:14:20.611Z INFO [order-svc-12] POST /orders 200 21ms user=u_88349 amount=12.50 +2026-05-28T03:14:20.701Z INFO [api-gw-7f9c] GET /v1/health 200 1ms +2026-05-28T03:14:20.812Z DEBUG [search-svc] parsed query bm25_terms=[键盘] filters={} +[browser/chrome 124] [Console] Uncaught (in promise) TypeError: Cannot read properties of undefined (reading 'price') + at PriceTag (https://static.acme.io/chunks/cart.js:1:9211) + at renderWithHooks (https://static.acme.io/chunks/react-dom.production.min.js:14:7332) + at updateFunctionComponent (https://static.acme.io/chunks/react-dom.production.min.js:14:9818) + caused by: backend returned items[2] without `price` field (cart=u_88349, order_draft=ORD-DRAFT-993) + reported via window.onerror -> /csp-report (envelope_id=env-2026052803-7791) +2026-05-28T03:14:21.001Z INFO [api-gw-7f9c] heartbeat ok rss=431MB cpu=4.6% +2026-05-28T03:14:21.118Z WARN [cart-svc] defensive: 
missing field items[2].price in draft ORD-DRAFT-993, defaulting to 0.00 (will be rejected at checkout) +2026-05-28T03:14:21.220Z ERROR [cart-svc] checkout blocked: 1 item has price=0.00, order=ORD-DRAFT-993 user=u_88349 trace_id=ab12cd34 +2026-05-28T03:14:21.330Z INFO [audit] write event=checkout.blocked order=ORD-DRAFT-993 reason=zero_price_item user=u_88349 +2026-05-28T03:14:21.401Z INFO [api-gw-7f9c] GET /v1/products?cat=11 200 9ms +2026-05-28T03:14:21.512Z DEBUG [cache-3a1] redis.set key=draft:ORD-DRAFT-993 ttl=1800s size=1208B +2026-05-28T03:14:21.611Z INFO [recom] model v17 served qps=3088 p99_ms=47 +2026-05-28T03:14:21.722Z INFO [order-svc-09] POST /orders 200 16ms user=u_88353 amount=6.00 +2026-05-28T03:14:21.812Z WARN [s3-uploader] 503 from s3 bucket=invoice-prod, signed-url=https://s3.amazonaws.com/invoice-prod/INV-44889.pdf?...(redacted) retry=1/5 +2026-05-28T03:14:21.901Z INFO [s3-uploader] upload ok bucket=invoice-prod INV-44889.pdf parts=5/5 took=412ms retry=2 +2026-05-28T03:14:22.001Z INFO [api-gw-7f9c] heartbeat ok rss=429MB cpu=4.5% +========================== END EDGE/FRONTEND SEGMENT ============================ +2026-05-28T03:14:22.122Z INFO [api-gw-7f9c] GET /v1/health 200 1ms +2026-05-28T03:14:22.244Z INFO [order-svc-12] POST /orders 200 23ms user=u_88360 amount=15.00 +2026-05-28T03:14:22.330Z DEBUG [cache-3a1] redis.get key=sess:9ab miss=false ttl=440s +2026-05-28T03:14:22.401Z INFO [search-svc] query took=9ms hits=14 q="不响应的鼠标" +2026-05-28T03:14:22.512Z INFO [api-gw-7f9c] POST /v1/login 200 31ms user=u_88361 +2026-05-28T03:14:22.611Z ERROR [serverless] cold-start fn=invoice-render runtime=node20 init_ms=1820 (budget=400) -> emit hot-pool warm-up +2026-05-28T03:14:22.711Z WARN [serverless] fn=invoice-render concurrency=128 throttled=4 region=cn-shanghai +2026-05-28T03:14:22.812Z INFO [api-gw-7f9c] GET /v1/products?cat=2 200 8ms +2026-05-28T03:14:22.911Z INFO [audit] write event=invoice.requested order=ORD-44888 user=u_88301 +curl --trace - 
http://feat-store-internal.acme.local:7780/batch ## debug capture +== 0000: 50 4f 53 54 20 2f 62 61 74 63 68 20 48 54 54 50 POST /batch HTTP +== 0010: 2f 31 2e 31 0d 0a 48 6f 73 74 3a 20 66 65 61 74 /1.1..Host: feat +== ... handshake stuck for 3s, peer RST +== CONN-RESET errno=104 +2026-05-28T03:14:23.001Z INFO [api-gw-7f9c] heartbeat ok rss=430MB cpu=4.7% +2026-05-28T03:14:23.122Z ERROR [recom] POST /v1/recommend 500 9ms cause=ConnectionResetError trace_id=ce9f8123 user=u_88361 (2nd in 6s) +2026-05-28T03:14:23.244Z WARN [supervisor] feat-store: 3 RST in 10s -> isolate pod feat-store-7d8c9b9b9c-xq2pl ip=10.42.4.219 +2026-05-28T03:14:23.330Z INFO [supervisor] drain pod feat-store-7d8c9b9b9c-xq2pl grace=15s +2026-05-28T03:14:23.401Z INFO [api-gw-7f9c] GET /v1/health 200 1ms +2026-05-28T03:14:23.512Z INFO [order-svc-12] POST /orders 200 22ms user=u_88370 amount=199.00 +2026-05-28T03:14:23.611Z INFO [recom] feat-store fallback to local-cache hit_rate=0.91 (degraded) +2026-05-28T03:14:23.722Z INFO [recom] POST /v1/recommend 200 12ms fallback=local-cache user=u_88370 +2026-05-28T03:14:23.812Z DEBUG [cache-3a1] redis.get key=sess:abc miss=false ttl=210s +2026-05-28T03:14:23.901Z INFO [search-svc] query took=12ms hits=44 q="laptop bag" +2026-05-28T03:14:24.001Z INFO [api-gw-7f9c] heartbeat ok rss=431MB cpu=4.5% +2026-05-28T03:14:24.118Z INFO [supervisor] drain complete pod=feat-store-7d8c9b9b9c-xq2pl, scheduling replacement +2026-05-28T03:14:24.244Z INFO [supervisor] spawn pod feat-store-7d8c9b9b9d-rx7nm ip=10.42.4.230 +2026-05-28T03:14:24.330Z INFO [k8s-probe] readinessProbe pod=feat-store-7d8c9b9b9d-rx7nm statusCode=200 took=12ms +2026-05-28T03:14:24.401Z INFO [recom] feat-store endpoint refreshed, fallback off +2026-05-28T03:14:24.512Z INFO [recom] POST /v1/recommend 200 8ms user=u_88370 (recovered) +2026-05-28T03:14:24.611Z INFO [audit] write event=feat-store.replaced old=xq2pl new=rx7nm took=1.4s +2026-05-28T03:14:24.722Z WARN [grafana-alert] resolved: 
name="recom_5xx_rate" labels={service="recom", env="prod"} value=0.004 since_resolve="2026-05-28T03:14:24Z" duration=7s +2026-05-28T03:14:24.812Z INFO [api-gw-7f9c] POST /v1/cart 200 11ms user=u_88370 +2026-05-28T03:14:24.901Z DEBUG [cache-3a1] redis.set key=sess:newuser ttl=600s size=256B +2026-05-28T03:14:25.001Z INFO [api-gw-7f9c] heartbeat ok rss=429MB cpu=4.4% +''' + + +# ────────────────────────────────────────────────────────────────────── +# 组装 prompts +# ────────────────────────────────────────────────────────────────────── def build_prompts() -> List[Dict[str, Any]]: - """Build a list of Trajectory dicts (messages format) as prompts.""" - prompts = [ - { - 'messages': [ - {'role': 'system', 'content': 'You are a helpful assistant.'}, - {'role': 'user', 'content': 'What is the capital of France?'}, - ] - }, - { - 'messages': [ - {'role': 'system', 'content': 'You are a helpful assistant.'}, - {'role': 'user', 'content': 'Write a short poem about the moon.'}, - ] - }, - { - 'messages': [ - {'role': 'user', 'content': 'Solve: 2x + 3 = 11. What is x?'}, - ] - }, + """构造五个场景的 Trajectory dict 列表。""" + cases = [ + ('Python 代码', PY_QUERY, PY_PASSAGE), + ('中文长篇新闻', NEWS_QUERY, NEWS_PASSAGE), + ('网页 HTML', HTML_QUERY, HTML_PASSAGE), + ('Python 异常处理', EXCEPTIONS_QUERY, EXCEPTIONS_PASSAGE), + ('混合服务日志', LOGS_QUERY, LOGS_PASSAGE), ] + prompts: List[Dict[str, Any]] = [] + for tag, query, passage in cases: + # 50% 硬上限,与训练时一致 + budget = max(1, len(passage) // 2) + user_msg = CONDENSER_USER.format(query=query, budget=budget, text=passage) + prompts.append({ + 'tag': tag, + 'src_len': len(passage), + 'budget': budget, + 'messages': [ + {'role': 'system', 'content': CONDENSER_SYSTEM}, + {'role': 'user', 'content': user_msg}, + ], + }) return prompts def main(): - # ── 1. Initialize Twinkle with Ray ────────────────────────────────── + # 1. 
初始化 Twinkle + Ray device_groups = [ - DeviceGroup(name='sampler', ranks=list(range(SAMPLER_GPUS)), device_type='GPU', gpus_per_worker=SAMPLER_GPUS), + DeviceGroup(name='sampler', + ranks=list(range(SAMPLER_GPUS)), + device_type='GPU', + gpus_per_worker=SAMPLER_GPUS), ] sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, tp_size=SAMPLER_GPUS) twinkle.initialize(mode='ray', nproc_per_node=SAMPLER_GPUS, groups=device_groups) - # ── 2. Create vLLMSampler with LoRA enabled ──────────────────────── + # 2. 构造 vLLMSampler,max_model_len 需容纳 5120 字符级原文 + 系统提示 + 输出 sampler = vLLMSampler( model_id=MODEL_ID, engine_args={ 'gpu_memory_utilization': 0.7, - 'max_model_len': 4096, - 'enable_lora': True, + 'max_model_len': 32768, + 'enable_lora': False, 'max_loras': 1, 'max_lora_rank': 32, - 'enable_tower_connector_lora': True, + # 'enable_tower_connector_lora': True, }, device_mesh=sampler_mesh, remote_group='sampler', ) - sampler.set_template('Qwen3_5Template', model_id=MODEL_ID) + sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False, max_length=32768) logger.info(get_device_placement()) - # ── 3. Configure sampling parameters ──────────────────────────────── + # 3. 采样参数:压缩任务用偏低温度,避免幻觉 sampling_params = SamplingParams( - max_tokens=2018, - temperature=0.7, + max_tokens=32768, + temperature=0.1, top_p=0.9, num_samples=1, ) - # ── 4. Run inference ──────────────────────────────────────────────── + # 4. 推理 prompts = build_prompts() - logger.info(f'Sampling {len(prompts)} prompts with model {MODEL_ID} ...') + logger.info(f'共 {len(prompts)} 个压缩场景,模型 {MODEL_ID},LoRA {LORA_PATH} ...') - responses = sampler.sample(prompts, sampling_params, adapter_path=LORA_PATH) + responses = sampler.sample( + [{'messages': p['messages']} for p in prompts], + sampling_params, + # adapter_path=LORA_PATH, + ) - # ── 5. Print results ──────────────────────────────────────────────── + # 5. 
输出结果 for i, response in enumerate(responses): + meta = prompts[i] for seq in response.sequences: - text = sampler.template.tokenizer.decode(seq.tokens, skip_special_tokens=True) - logger.info(f'\n{"="*60}\nPrompt {i}: {prompts[i]["messages"][-1]["content"]}\n{"─"*60}\n{text}\n') + # strip chat-template close tag that leaks through decode + text = seq.decoded.replace('<|im_end|>', '').rstrip() + logger.info( + f'\n{"=" * 60}\n' + f'场景 {i + 1}:{meta["tag"]}(原文 {meta["src_len"]} 字符,硬上限 {meta["budget"]} 字符)\n' + f'{"-" * 60}\n' + f'压缩结果({len(text)} 字符):\n{text}\n') - logger.info('Done.') + logger.info('全部场景压缩完成。') if __name__ == '__main__': diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 450906c5..2dfcb276 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -5,44 +5,26 @@ import twinkle from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor logger = get_logger() +args = CLI.from_args() -MODEL_ID = 'ms://Qwen/Qwen3.5-4B' -DATASET_ID = 'ms://swift/self-cognition' -TEMPLATE_NAME = 'Qwen3_5Template' -MODEL_NAME = 'twinkle大模型' -MODEL_AUTHOR = 'ModelScope社区' -FSDP_SIZE = 2 -DP_SIZE = 4 -BATCH_SIZE = 8 -LEARNING_RATE = 1e-4 -GRADIENT_ACCUMULATION_STEPS = 2 -LOG_INTERVAL = 20 -EVAL_INTERVAL = 40 -EVAL_SAMPLES = 100 -TRAIN_SAMPLES = 1000 - -OUTPUT_DIR = './output/fsdp2' -RESUME_FROM_CHECKPOINT = None -RESUME_ONLY_MODEL = False -IGNORE_DATA_SKIP = False -ADAPTER_NAME = 'default' - -# Construct a device_mesh -device_mesh = DeviceMesh.from_sizes(fsdp_size=FSDP_SIZE, dp_size=DP_SIZE) -# use torchrun mode -twinkle.initialize(mode='local', global_device_mesh=device_mesh) +device_mesh = DeviceMesh.from_sizes(fsdp_size=args.infra.fsdp_size, dp_size=args.infra.dp_size) 
+twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh) def build_dataset(num_samples: int) -> Dataset: - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples))) - dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR)) + dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(num_samples))) + dataset.set_template(args.template.template_cls, model_id=args.model.model_id) + dataset.map(SelfCognitionProcessor( + args.extra.get('model_name', 'twinkle大模型'), + args.extra.get('model_author', 'ModelScope社区'), + )) dataset.encode() return dataset @@ -50,15 +32,16 @@ def build_dataset(num_samples: int) -> Dataset: def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): model.save( checkpoint_name, - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, + output_dir=args.training.output_dir, + adapter_name=args.lora.adapter_name, save_optimizer=True, consumed_train_samples=dataloader.get_state()['consumed_train_samples'], ) def evaluate(model): - dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE) + eval_samples = args.training.eval_samples or 100 + dataloader = DataLoader(dataset=build_dataset(eval_samples), batch_size=args.training.batch_size) for batch in tqdm(dataloader): model.forward_only(inputs=batch) model.calculate_loss() @@ -66,52 +49,45 @@ def evaluate(model): def train(): - dataset = build_dataset(TRAIN_SAMPLES) - # Global batch size = 8, for GPUs, so 1 sample per GPU - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) - # Use a TransformersModel - model = TransformersModel(model_id=MODEL_ID) + train_samples = int(args.extra.get('train_samples', 1000)) + dataset = build_dataset(train_samples) + dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) + model = TransformersModel(model_id=args.model.model_id) 
model.model._no_split_modules = {'Qwen3_5DecoderLayer'} - lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') - - # Add a lora to model, with name `default` - model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) - # Add Optimizer for lora `default` - model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE) - # Add LRScheduler for lora `default` + lora_config = LoraConfig(**args.get_lora_args()) + model.add_adapter_to_model( + args.lora.adapter_name, lora_config, + gradient_accumulation_steps=args.training.gradient_accumulation_steps) + model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate) model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + scheduler_cls=args.scheduler.scheduler_cls, + num_warmup_steps=args.scheduler.num_warmup_steps, + num_training_steps=len(dataloader)) - if RESUME_FROM_CHECKPOINT: - checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() - kwargs = {} - if ADAPTER_NAME: - kwargs['adapter_name'] = ADAPTER_NAME + if args.training.resume_from_checkpoint: + checkpoint_path = Path(args.training.resume_from_checkpoint).expanduser().resolve() progress = model.resume_from_checkpoint( - str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) - if not IGNORE_DATA_SKIP: + str(checkpoint_path), + resume_only_model=args.training.resume_only_model, + adapter_name=args.lora.adapter_name) + if not args.training.ignore_data_skip: dataloader.resume_from_checkpoint(progress['consumed_train_samples']) logger.info(get_device_placement()) - # Print the training config logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') - optimizer_group = model.optimizer_group[ADAPTER_NAME] + optimizer_group = model.optimizer_group[args.lora.adapter_name] best_loss = float('inf') - # lora: 8G * 8 - # full: 18G * 8 + eval_interval 
= args.training.eval_interval or 40 for batch in dataloader: - # Do forward and backward model.forward_backward(inputs=batch) - # Step model.clip_grad_and_step() cur_step = optimizer_group.cur_step - if cur_step % LOG_INTERVAL == 0: - # Print metric + if cur_step % args.training.log_interval == 0: metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') - if cur_step > 0 and cur_step % EVAL_INTERVAL == 0: + if cur_step > 0 and cur_step % eval_interval == 0: metrics = evaluate(model) logger.info(f'Eval metric: {metrics}') metrics['step'] = cur_step diff --git a/cookbook/transformers/fsdp2.sh b/cookbook/transformers/fsdp2.sh index 93c531a9..bbe26962 100644 --- a/cookbook/transformers/fsdp2.sh +++ b/cookbook/transformers/fsdp2.sh @@ -1 +1,25 @@ -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2.py +#!/usr/bin/env bash +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash fsdp2.sh --batch-size 16 --lr 5e-5 + +CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7} \ + torchrun --nproc_per_node=8 fsdp2.py \ + --model-id ms://Qwen/Qwen3.5-4B \ + --dataset-id ms://swift/self-cognition \ + --template-cls Qwen3_5Template \ + --fsdp-size 2 \ + --dp-size 4 \ + --batch-size 8 \ + --lr 1e-4 \ + --gradient-accumulation-steps 2 \ + --log-interval 20 \ + --eval-interval 40 \ + --eval-samples 100 \ + --output-dir ./output/fsdp2 \ + --adapter-name default \ + --scheduler-cls CosineWarmupScheduler \ + --num-warmup-steps 5 \ + --train-samples 1000 \ + --model-name twinkle大模型 \ + --model-author ModelScope社区 \ + "$@" diff --git a/pyproject.toml b/pyproject.toml index 964a7548..26a7db55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ kernels = ["kernels"] megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]", "mcore_bridge"] vllm = ["vllm>=0.11"] ray = ["ray[serve]"] +pyodps = ["pyodps"] +datajuicer = ["py-data-juicer"] tinker = 
["tinker==0.14.0"] docs = [ "sphinx>=5.3.0,<6.0.0", diff --git a/src/twinkle/__init__.py b/src/twinkle/__init__.py index f64917a5..79e8f89d 100644 --- a/src/twinkle/__init__.py +++ b/src/twinkle/__init__.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from twinkle_client import init_tinker_client, init_twinkle_client - from .infra import get_device_placement, initialize, is_master, remote_class, remote_function + from .infra import get_device_placement, initialize, is_master, remote_class, remote_function, remote_generator from .utils import (GPU, NPU, DeviceGroup, DeviceMesh, Platform, Plugin, check_unsafe, exists, find_free_port, find_node_ip, framework_util, get_logger, requires, torch_util, trust_remote_code) from .version import __release_datetime__, __version__ @@ -16,7 +16,7 @@ 'framework_util', 'torch_util', 'exists', 'requires', 'Platform', 'GPU', 'NPU', 'find_node_ip', 'find_free_port', 'trust_remote_code', 'check_unsafe', 'DeviceMesh', 'Plugin', 'DeviceGroup', 'get_logger' ], - 'infra': ['initialize', 'remote_class', 'remote_function', 'get_device_placement', 'is_master'], + 'infra': ['initialize', 'remote_class', 'remote_function', 'remote_generator', 'get_device_placement', 'is_master'], } import sys diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py index cde5c519..3866cfb4 100644 --- a/src/twinkle/checkpoint_engine/manager.py +++ b/src/twinkle/checkpoint_engine/manager.py @@ -122,6 +122,9 @@ def sync_weights(self, merge_and_sync=True): if self._model_keys is None: if hasattr(self.sampler, 'get_state_keys'): self._model_keys = self.sampler.get_state_keys() + # remote_function with lazy_collect returns a callable + if callable(self._model_keys): + self._model_keys = self._model_keys() if self._model_keys is None: self._model_keys = [] diff --git a/src/twinkle/cli/__init__.py b/src/twinkle/cli/__init__.py new file mode 100644 index 00000000..03eadd72 --- /dev/null +++ b/src/twinkle/cli/__init__.py @@ -0,0 +1,52 @@ +# 
Copyright (c) ModelScope Contributors. All rights reserved. +from .cli import ( + CLI, + Args, + CheckpointArgs, + CLISource, + ConfigResolver, + ConfigSource, + DatasetArgs, + DotEnvSource, + EnvVarSource, + InfraArgs, + LoraArgs, + LossArgs, + ModelArgs, + OptimizerArgs, + RLArgs, + SamplerArgs, + SamplingArgs, + SchedulerArgs, + ServerArgs, + TemplateArgs, + TrainingArgs, + ValueCaster, + YamlSource, +) + +__all__ = [ + 'CLI', + 'Args', + 'ConfigSource', + 'ConfigResolver', + 'ValueCaster', + 'DotEnvSource', + 'EnvVarSource', + 'YamlSource', + 'CLISource', + 'ModelArgs', + 'LoraArgs', + 'DatasetArgs', + 'TemplateArgs', + 'TrainingArgs', + 'OptimizerArgs', + 'SchedulerArgs', + 'LossArgs', + 'SamplerArgs', + 'SamplingArgs', + 'InfraArgs', + 'ServerArgs', + 'RLArgs', + 'CheckpointArgs', +] diff --git a/src/twinkle/cli/cli.py b/src/twinkle/cli/cli.py new file mode 100644 index 00000000..7c51acc0 --- /dev/null +++ b/src/twinkle/cli/cli.py @@ -0,0 +1,615 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+from __future__ import annotations + +import os +import sys +from abc import ABC, abstractmethod +from dataclasses import dataclass, field, fields +from pathlib import Path +from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union + + +# ──────────────────────────────────────────────────────────────────────────────── +# Arg group dataclasses +# ──────────────────────────────────────────────────────────────────────────────── + + +@dataclass +class ModelArgs: + model_id: Optional[str] = field(default=None, metadata={'primary': True}) + model_cls: Optional[str] = None + tokenizer_id: Optional[str] = None + mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16' + strategy: Literal['accelerate', 'native_fsdp'] = field( + default='accelerate', metadata={'aliases': ('use_megatron',)}) + memory_efficient_init: bool = False + gradient_checkpointing: bool = True + trust_remote_code: bool = True + ddp_config: Optional[Dict[str, Any]] = None + fsdp_config: Optional[Dict[str, Any]] = None + grad_scaler_config: Optional[Dict[str, Any]] = None + + +@dataclass +class LoraArgs: + use_lora: bool = False + lora_r: int = 8 + lora_alpha: int = 32 + lora_dropout: float = 0.05 + lora_target_modules: Optional[List[str]] = None + adapter_name: str = 'default' + + +@dataclass +class DatasetArgs: + dataset_id: str = '' + subset_name: str = 'default' + split: str = 'train' + streaming: bool = False + num_proc: Optional[int] = None + data_slice: Optional[str] = None + revision: Optional[str] = None + + +@dataclass +class TemplateArgs: + template_cls: Optional[str] = None + model_id: Optional[str] = None + max_length: int = 8192 + truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise' + use_chat_template: bool = True + enable_thinking: bool = True + default_system: Optional[str] = None + + +@dataclass +class TrainingArgs: + max_steps: int = 200 + num_train_epochs: Optional[int] = None + batch_size: int = 8 + mini_batch_size: Optional[int] = 
None + micro_batch_size: int = 2 + gradient_accumulation_steps: int = 1 + output_dir: str = './output' + save_steps: int = 50 + save_total_limit: Optional[int] = None + log_interval: int = 10 + eval_interval: Optional[int] = None + eval_samples: Optional[int] = None + resume_from_checkpoint: Optional[str] = None + resume_only_model: bool = False + ignore_data_skip: bool = False + seed: int = field(default=42, metadata={'primary': True}) + full_determinism: bool = False + padding_free: bool = False + + +@dataclass +class OptimizerArgs: + optimizer_cls: str = 'AdamW' + learning_rate: float = field(default=1e-5, metadata={'aliases': ('lr',)}) + weight_decay: float = 0.0 + adam_beta1: float = 0.9 + adam_beta2: float = 0.999 + adam_epsilon: float = 1e-8 + max_grad_norm: float = 1.0 + + +@dataclass +class SchedulerArgs: + scheduler_cls: str = 'CosineAnnealingLR' + num_warmup_steps: int = 0 + num_training_steps: Optional[int] = None + t_max: Optional[int] = None + eta_min: float = 0.0 + lr_decay_steps: Optional[int] = None + max_lr: Optional[float] = None + + +@dataclass +class LossArgs: + loss_cls: str = 'GRPOLoss' + epsilon: float = 0.2 + epsilon_high: Optional[float] = None + beta: float = 0.0 + entropy_coef: float = 0.0 + ignore_index: int = -100 + + +@dataclass +class SamplerArgs: + sampler_type: str = 'vLLMSampler' + gpu_memory_utilization: float = 0.8 + max_model_len: Optional[int] = None + tensor_parallel_size: Optional[int] = None + enable_lora: bool = False + max_lora_rank: int = 32 + enforce_eager: bool = False + + +@dataclass +class SamplingArgs: + max_tokens: Optional[int] = field(default=None, metadata={'aliases': ('max_new_tokens',)}) + temperature: float = 1.0 + top_k: int = -1 + top_p: float = 1.0 + repetition_penalty: float = 1.0 + num_samples: int = 1 + logprobs: Optional[int] = None + seed: Optional[int] = None + stop: Optional[str] = None + + +@dataclass +class InfraArgs: + mode: Literal['local', 'ray'] = 'local' + nproc_per_node: int = 
field(default=8, metadata={'aliases': ('num_gpus',)}) + ncpu_proc_per_node: int = 8 + model_gpus: Optional[int] = None + sampler_gpus: Optional[int] = None + dp_size: Optional[int] = None + fsdp_size: Optional[int] = None + tp_size: Optional[int] = None + cp_size: Optional[int] = None + ep_size: Optional[int] = None + ulysses_size: Optional[int] = None + lazy_collect: bool = True + + +@dataclass +class ServerArgs: + config: Optional[str] = None + ray_namespace: str = 'twinkle_cluster' + host: str = '0.0.0.0' + port: int = 8000 + log_level: str = 'INFO' + + +@dataclass +class RLArgs: + num_generations: int = 8 + advantage_type: str = 'GRPOAdvantage' + advantage_scale: Literal['group', 'batch', 'none'] = 'group' + reward_fns: Optional[List[str]] = None + + +@dataclass +class CheckpointArgs: + save_optimizer: bool = True + merge_and_sync: bool = True + platform: str = 'GPU' + + +# ──────────────────────────────────────────────────────────────────────────────── +# ConfigSource hierarchy +# ──────────────────────────────────────────────────────────────────────────────── + + +class ConfigSource(ABC): + """Base class for all configuration sources.""" + + @abstractmethod + def load(self) -> Dict[str, Any]: + """Return raw key-value pairs from this source.""" + ... 
+ + +class DotEnvSource(ConfigSource): + + def __init__(self, path: Optional[Union[str, Path]] = None): + self._path = path + + def load(self) -> Dict[str, str]: + path = self._resolve_path() + if path is None: + return {} + result: Dict[str, str] = {} + with open(path) as f: + for line in f: + line = line.strip() + if not line or line.startswith('#'): + continue + if '=' not in line: + continue + key, _, value = line.partition('=') + result[key.strip()] = value.strip().strip('"').strip("'") + return result + + def _resolve_path(self) -> Optional[Path]: + if self._path is not None: + p = Path(self._path) + return p if p.is_file() else None + for name in ('.env', '.env.local'): + p = Path.cwd() / name + if p.is_file(): + return p + return None + + +class EnvVarSource(ConfigSource): + """Reads os.environ; recognizes TWINKLE_ prefix and any key known to the registry.""" + + def __init__(self, registry: 'ConfigRegistry'): + self._registry = registry + + def load(self) -> Dict[str, str]: + result: Dict[str, str] = {} + for key, value in os.environ.items(): + if key.startswith('TWINKLE_'): + result[key[8:]] = value + elif self._registry.resolve(key) is not None: + result[key] = value + return result + + +class YamlSource(ConfigSource): + + def __init__(self, path: Union[str, Path]): + self._path = Path(path) + + def load(self) -> Dict[str, Any]: + from omegaconf import OmegaConf + if not self._path.is_file(): + raise FileNotFoundError(f'Config file not found: {self._path}') + cfg = OmegaConf.load(self._path) + return OmegaConf.to_container(cfg, resolve=True) + + +class CLISource(ConfigSource): + + def __init__(self, argv: Optional[List[str]] = None): + self._argv = argv if argv is not None else sys.argv[1:] + + def load(self) -> Dict[str, Any]: + result: Dict[str, Any] = {} + i = 0 + argv = self._argv + while i < len(argv): + token = argv[i] + if not token.startswith('--'): + i += 1 + continue + token = token[2:] + if token.startswith('no_') or token.startswith('no-'): + 
result[token[3:]] = False + i += 1 + continue + if '=' in token: + key, _, value = token.partition('=') + result[key] = value + i += 1 + continue + if i + 1 < len(argv) and not argv[i + 1].startswith('--'): + result[token] = argv[i + 1] + i += 2 + else: + result[token] = True + i += 1 + return result + + +# ──────────────────────────────────────────────────────────────────────────────── +# ConfigRegistry: maps normalized keys to (group_name, field_name) +# ──────────────────────────────────────────────────────────────────────────────── + + +class ConfigRegistry: + """Introspects Args dataclass groups to build a case-insensitive key→field map.""" + + # Same field name in 2+ groups — the winning group must declare metadata={'primary': True} + + def __init__(self, groups: Dict[str, Any]): + self._field_map: Dict[str, Tuple[str, str]] = {} + self._alias_map: Dict[str, str] = {} + self._groups = groups + self._build(groups) + + def _build(self, groups: Dict[str, Any]) -> None: + owners: Dict[str, List[Tuple[str, bool]]] = {} + for group_name, group_obj in groups.items(): + for f in fields(group_obj): + is_primary = f.metadata.get('primary', False) + owners.setdefault(f.name.lower(), []).append((group_name, is_primary)) + for alias in f.metadata.get('aliases', ()): # field-local aliases + self._alias_map[alias.lower()] = f.name.lower() + for key, owner_list in owners.items(): + if len(owner_list) == 1: + self._field_map[key] = (owner_list[0][0], key) + continue + primaries = [g for g, p in owner_list if p] + if len(primaries) != 1: + all_groups = [g for g, _ in owner_list] + raise ValueError( + f'Field {key!r} exists in groups {all_groups}; ' + f"exactly one must declare metadata={{'primary': True}}, found {len(primaries)}") + self._field_map[key] = (primaries[0], key) + + def resolve(self, key: str) -> Optional[Tuple[str, str]]: + normalized = key.lower().replace('-', '_') + canonical = self._alias_map.get(normalized, normalized) + if canonical in self._field_map: + 
return self._field_map[canonical] + # prefix-based fallback: model_xxx → group=model, field=xxx + for group_name in self._groups: + prefix = group_name + '_' + if canonical.startswith(prefix): + stripped = canonical[len(prefix):] + if stripped and (group_name, stripped) in ( + (g, f.name) for g, obj in self._groups.items() for f in fields(obj) + ): + return (group_name, stripped) + return None + + def all_keys(self) -> Iterator[str]: + return iter(self._field_map) + + +# ──────────────────────────────────────────────────────────────────────────────── +# Args: unified container +# ──────────────────────────────────────────────────────────────────────────────── + + +@dataclass +class Args: + """Unified argument container. Access groups directly or via get_*_args() dicts.""" + + model: ModelArgs = field(default_factory=ModelArgs) + lora: LoraArgs = field(default_factory=LoraArgs) + dataset: DatasetArgs = field(default_factory=DatasetArgs) + template: TemplateArgs = field(default_factory=TemplateArgs) + training: TrainingArgs = field(default_factory=TrainingArgs) + optimizer: OptimizerArgs = field(default_factory=OptimizerArgs) + scheduler: SchedulerArgs = field(default_factory=SchedulerArgs) + loss: LossArgs = field(default_factory=LossArgs) + sampler: SamplerArgs = field(default_factory=SamplerArgs) + sampling: SamplingArgs = field(default_factory=SamplingArgs) + infra: InfraArgs = field(default_factory=InfraArgs) + server: ServerArgs = field(default_factory=ServerArgs) + rl: RLArgs = field(default_factory=RLArgs) + checkpoint: CheckpointArgs = field(default_factory=CheckpointArgs) + extra: Dict[str, Any] = field(default_factory=dict) + + def get_model_args(self) -> Dict[str, Any]: + d = self._to_dict(self.model) + if not d.get('model_id') and self.template.model_id: + d['model_id'] = self.template.model_id + return d + + def get_lora_args(self) -> Dict[str, Any]: + return { + 'target_modules': self.lora.lora_target_modules or 'all-linear', + 'r': self.lora.lora_r, + 
'lora_alpha': self.lora.lora_alpha, + 'lora_dropout': self.lora.lora_dropout, + } + + def get_dataset_args(self) -> Dict[str, Any]: + return self._to_dict(self.dataset) + + def get_template_args(self) -> Dict[str, Any]: + d = self._to_dict(self.template) + if not d.get('model_id') and self.model.model_id: + d['model_id'] = self.model.model_id + return d + + def get_training_args(self) -> Dict[str, Any]: + return self._to_dict(self.training) + + def get_optimizer_args(self) -> Dict[str, Any]: + d = self._to_dict(self.optimizer) + d['lr'] = d.pop('learning_rate', 1e-5) + return d + + def get_scheduler_args(self) -> Dict[str, Any]: + return self._to_dict(self.scheduler) + + def get_loss_args(self) -> Dict[str, Any]: + return self._to_dict(self.loss) + + def get_sampler_args(self) -> Dict[str, Any]: + return self._to_dict(self.sampler) + + def get_sampling_args(self) -> Dict[str, Any]: + return self._to_dict(self.sampling) + + def get_infra_args(self) -> Dict[str, Any]: + return self._to_dict(self.infra) + + def get_server_args(self) -> Dict[str, Any]: + return self._to_dict(self.server) + + def get_rl_args(self) -> Dict[str, Any]: + return self._to_dict(self.rl) + + def get_checkpoint_args(self) -> Dict[str, Any]: + return self._to_dict(self.checkpoint) + + def get(self, key: str, default: Any = None) -> Any: + for f in fields(self): + if f.name == 'extra': + continue + group = getattr(self, f.name) + if hasattr(group, key): + return getattr(group, key) + return self.extra.get(key, default) + + def __getitem__(self, key: str) -> Any: + val = self.get(key, _SENTINEL) + if val is _SENTINEL: + raise KeyError(key) + return val + + def to_dict(self) -> Dict[str, Any]: + result = {} + for f in fields(self): + if f.name == 'extra': + continue + result.update(self._to_dict(getattr(self, f.name))) + result.update(self.extra) + return result + + @staticmethod + def _to_dict(obj: Any) -> Dict[str, Any]: + return {f.name: getattr(obj, f.name) for f in fields(obj) if getattr(obj, 
f.name) is not None} + + +_SENTINEL = object() + + +# ──────────────────────────────────────────────────────────────────────────────── +# ValueCaster: type coercion +# ──────────────────────────────────────────────────────────────────────────────── + + +class ValueCaster: + + @staticmethod + def auto_cast(value: Any) -> Any: + if not isinstance(value, str): + return value + low = value.lower() + if low in ('true', 'yes', 'on'): + return True + if low in ('false', 'no', 'off'): + return False + if low in ('none', 'null', '~'): + return None + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + pass + if ',' in value: + return [ValueCaster.auto_cast(v.strip()) for v in value.split(',')] + return value + + @staticmethod + def coerce_to_field(obj: Any, field_name: str, value: Any) -> Any: + current = getattr(obj, field_name, None) + if current is None or value is None: + return value + target_type = type(current) + if target_type is bool: + if isinstance(value, bool): + return value + return ValueCaster.auto_cast(str(value)) + if target_type is int and not isinstance(value, int): + try: + return int(float(value)) if isinstance(value, str) else int(value) + except (ValueError, TypeError): + return value + if target_type is float and not isinstance(value, (int, float)): + try: + return float(value) + except (ValueError, TypeError): + return value + if target_type is list and isinstance(value, str): + return [v.strip() for v in value.split(',')] + return value + + +# ──────────────────────────────────────────────────────────────────────────────── +# ConfigResolver: merges sources +# ──────────────────────────────────────────────────────────────────────────────── + + +class ConfigResolver: + + def __init__(self, args: Args): + self._args = args + self._groups = { + f.name: getattr(args, f.name) + for f in fields(args) + if f.name != 'extra' + } + self._registry = ConfigRegistry(self._groups) + + @property + def 
registry(self) -> 'ConfigRegistry': + return self._registry + + def apply(self, source: Dict[str, Any], cast_strings: bool = False) -> None: + flat = self._flatten(source) + for raw_key, raw_value in flat.items(): + key = raw_key.lower().replace('-', '_') + value = ValueCaster.auto_cast(raw_value) if cast_strings else raw_value + # handle use_megatron alias + if key == 'use_megatron': + if ValueCaster.auto_cast(str(value)): + self._set('model', 'strategy', 'native_fsdp') + continue + resolved = self._registry.resolve(key) + if resolved: + group_name, field_name = resolved + group = self._groups[group_name] + coerced = ValueCaster.coerce_to_field(group, field_name, value) + setattr(group, field_name, coerced) + else: + self._args.extra[key] = value + + def _set(self, group_name: str, field_name: str, value: Any) -> None: + group = self._groups[group_name] + setattr(group, field_name, value) + + def _flatten(self, d: Any, prefix: str = '') -> Dict[str, Any]: + if not isinstance(d, dict): + return {prefix: d} if prefix else {} + result: Dict[str, Any] = {} + for key, value in d.items(): + full_key = f'{prefix}_{key}' if prefix else key + if isinstance(value, dict): + result.update(self._flatten(value, full_key)) + else: + result[full_key] = value + return result + + +# ──────────────────────────────────────────────────────────────────────────────── +# CLI: top-level entry point +# ──────────────────────────────────────────────────────────────────────────────── + + +class CLI: + """Unified configuration parser. + + Resolution order (later wins): + 1. Dataclass defaults + 2. .env file + 3. Environment variables (TWINKLE_ prefix or bare) + 4. YAML config file (--config / explicit) + 5. CLI overrides (--key value) + + All keys are case-insensitive and dash/underscore equivalent: + --model-id, MODEL_ID, TWINKLE_MODEL_ID, model_id: in .yaml all resolve the same. 
+ """ + + @staticmethod + def from_args( + argv: Optional[List[str]] = None, + env_file: Optional[Union[str, Path]] = None, + config_file: Optional[Union[str, Path]] = None, + ) -> Args: + args = Args() + resolver = ConfigResolver(args) + + # 1. .env + resolver.apply(DotEnvSource(env_file).load(), cast_strings=True) + + # 2. Environment variables + resolver.apply(EnvVarSource(resolver.registry).load(), cast_strings=True) + + # 3. CLI (first pass to extract --config) + cli_data = CLISource(argv).load() + yaml_path = config_file or cli_data.pop('config', None) + + # 4. YAML + if yaml_path: + resolver.apply(YamlSource(yaml_path).load(), cast_strings=False) + + # 5. CLI overrides (highest priority, values are strings from argv) + resolver.apply(cli_data, cast_strings=True) + + return args diff --git a/src/twinkle/dataset/base.py b/src/twinkle/dataset/base.py index d44856b7..90c8c3d8 100644 --- a/src/twinkle/dataset/base.py +++ b/src/twinkle/dataset/base.py @@ -1,11 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json as _json import os.path from collections.abc import Iterable, Mapping from dataclasses import dataclass from datasets import DatasetDict, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset -from torch.utils.data import Dataset as TorchDataset -from typing import Any, Callable, Dict, Type, Union - +from torch.utils.data import Dataset as TorchDataset, IterableDataset as TorchIterableDataset +from typing import Any, Callable, Dict, List, Optional, Type, Union +import threading +from queue import Queue +from twinkle.utils.parallel import PosixFileLock import twinkle from twinkle import preprocessor from twinkle.hub import HubOperation @@ -27,20 +30,34 @@ class DatasetMeta: The dataset meta-information, used to describe a dataset. 
""" # The dataset id or local path - dataset_id: str + dataset_id: str = '' # The subset name subset_name: str = 'default' # The split split: str = 'train' # Pick a data slice data_slice: Iterable = None + # In-memory / in-process data source. Supports: + # - List[Dict] (row-oriented, eager) + # - Dict[str, List] (column-oriented, eager) + # - Callable (generator function; routed to HF from_generator, + # streaming vs eager picked from `streaming` kwarg. + # Bind args via functools.partial.) + # - HFDataset / HFIterableDataset (already-constructed, passed through) + data: Any = None def get_id(self): + if self.data is not None: + return f'__memory_{self._uid}__:' + self.subset_name + ':' + self.split return self.dataset_id.replace(os.sep, '_').replace('.', '_') + ':' + self.subset_name + ':' + self.split def __post_init__(self): + import uuid + self._uid = uuid.uuid4().hex[:8] if self.data_slice is not None and not isinstance(self.data_slice, Iterable): raise ValueError('data_slice must be an iterable') + if not self.dataset_id and self.data is None: + raise ValueError('Either dataset_id or data must be provided') @remote_class(execute='first') @@ -58,6 +75,7 @@ class Dataset(TorchDataset): def __init__(self, dataset_meta: DatasetMeta = None, **kwargs): self.template = None + self._mixed = False if dataset_meta is None: self.datasets = {} self.dataset = None @@ -79,6 +97,17 @@ def set_template(self, template_func: Union[Template, Type[Template], str], **kw """ self.template = construct_class(template_func, Template, twinkle.template, **kwargs) + @staticmethod + def _normalize_cache_kwargs(target, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Strip/inject load_from_cache_file based on whether target supports HF cache.""" + kw = dict(kwargs) + # Streaming datasets (HF IterableDataset / torch IterableDataset wrappers) reject load_from_cache_file. 
+ if isinstance(target, (IterableDataset, TorchIterableDataset)): + kw.pop('load_from_cache_file', None) + else: + kw.setdefault('load_from_cache_file', False) + return kw + @remote_function() def encode(self, add_generation_prompt: bool = False, **kwargs): """An inplace operation to encode the dataset. @@ -90,18 +119,16 @@ def encode(self, add_generation_prompt: bool = False, **kwargs): **kwargs: The mapping and filter kwargs of the `datasets.map`. """ kwargs['batched'] = True # Only supported batched, because a single row may explode to several rows - if 'load_from_cache_file' not in kwargs: - # By default, we don't use load_from_cache_file, because read cache will not consider - # the changes in the same file, - # which will cause unexpected behaviors. - kwargs['load_from_cache_file'] = False + kwargs = self._normalize_cache_kwargs(self.dataset, kwargs) from functools import partial encode_fn = partial(self.template.batch_encode, add_generation_prompt=add_generation_prompt) + # Dataset.filter() does not accept map-only kwargs (e.g. remove_columns); split them off. + filter_kwargs = {k: v for k, v in kwargs.items() if k != 'remove_columns'} with processing_lock('dataset'): # use a default lock because encode is to all datasets self.dataset = self.dataset.map(encode_fn, **kwargs).filter( lambda batch: [True] * len(next(iter(batch.values()))) - if 'input_ids' not in batch else [len(x) > 0 for x in batch['input_ids']], **kwargs) + if 'input_ids' not in batch else [len(x) > 0 for x in batch['input_ids']], **filter_kwargs) @remote_function() def check(self, **kwargs): @@ -111,9 +138,7 @@ def check(self, **kwargs): **kwargs: The mapping and filter kwargs of the `datasets.map`. """ kwargs['batched'] = True # Only supported batched, because a single row may explode to several rows - # check depends on template/tokenizer behavior; cached filter results can keep old empty outputs. - # Disable cache here to avoid the "silent stop" caused by stale empty cache. 
- kwargs.setdefault('load_from_cache_file', False) + kwargs = self._normalize_cache_kwargs(self.dataset, kwargs) with processing_lock('dataset'): # use a default lock because check is to all datasets def _check_batch(batch): @@ -126,6 +151,24 @@ def _check_batch(batch): @staticmethod def _load_dataset(dataset_meta: DatasetMeta, **kwargs): + # In-memory / in-process data path + if dataset_meta.data is not None: + from datasets import Dataset as HFDataset + from datasets import IterableDataset as HFIterableDataset + d = dataset_meta.data + if isinstance(d, (HFDataset, HFIterableDataset)): + return d + if isinstance(d, list): + return HFDataset.from_list(d) + if isinstance(d, dict): + return HFDataset.from_dict(d) + if callable(d): + cls = HFIterableDataset if kwargs.get('streaming') else HFDataset + return cls.from_generator(d) + raise ValueError( + f'DatasetMeta.data must be list, dict, callable, or HF Dataset/IterableDataset, ' + f'got {type(d).__name__}') + dataset_id = dataset_meta.dataset_id subset_name = dataset_meta.subset_name split = dataset_meta.split @@ -155,7 +198,7 @@ def _load_dataset(dataset_meta: DatasetMeta, **kwargs): dataset = load_dataset(file_type, **load_kwargs, **kwargs) else: dataset = HubOperation.load_dataset(dataset_id, subset_name, split, **kwargs) - + # fix: Some dataset sources return DatasetDict instead of Dataset, which breaks downstream select/map calls. # fix: Normalize split resolution here (target split first, then train) and fail early with a clear error. if isinstance(dataset, DatasetDict): @@ -168,6 +211,9 @@ def _load_dataset(dataset_meta: DatasetMeta, **kwargs): raise KeyError(f"Split '{split}' not found for dataset '{dataset_id}'. " f'Available splits: {available_splits}') + if hasattr(dataset, 'to_hf_dataset'): + dataset = dataset.to_hf_dataset() + if isinstance(dataset_meta.data_slice, Iterable) and hasattr(dataset, '__len__'): iter_list = [] @@ -209,22 +255,27 @@ def map(self, **kwargs: The kwargs of the `datasets.map`. 
""" init_args = init_args or {} - if 'load_from_cache_file' not in kwargs: - # By default, we don't use load_from_cache_file, because read cache will not consider - # the changes in the same file, - # which will cause unexpected behaviors. - kwargs['load_from_cache_file'] = False preprocess_func = construct_class(preprocess_func, Preprocessor, twinkle.preprocessor, **init_args) - if dataset_meta is None: - assert len(self.datasets) == 1 - key = next(iter(self.datasets.keys())) - else: - key = dataset_meta.get_id() kwargs['batched'] = True - with processing_lock(key): - self.datasets[key] = self.datasets[key].map(preprocess_func, **kwargs) - if len(self.datasets) == 1: - self.dataset = self.datasets[key] + + if self._mixed: + self.dataset = self.dataset.map( + preprocess_func, **self._normalize_cache_kwargs(self.dataset, kwargs)) + else: + if dataset_meta is None: + assert len(self.datasets) == 1 + key = next(iter(self.datasets.keys())) + else: + key = dataset_meta.get_id() + with processing_lock(key): + kw = self._normalize_cache_kwargs(self.datasets[key], kwargs) + if 'remove_columns' not in kw: + features = getattr(self.datasets[key], 'features', None) + if features is not None: + kw['remove_columns'] = list(features.keys()) + self.datasets[key] = self.datasets[key].map(preprocess_func, **kw) + if len(self.datasets) == 1: + self.dataset = self.datasets[key] @remote_function() def filter(self, @@ -242,16 +293,20 @@ def filter(self, """ init_args = init_args or {} filter_func = construct_class(filter_func, DataFilter, twinkle.preprocessor, **init_args) - if dataset_meta is None: - assert len(self.datasets) == 1 - key = next(iter(self.datasets.keys())) + if self._mixed: + kwargs['batched'] = False + self.dataset = self.dataset.filter(filter_func, **kwargs) else: - key = dataset_meta.get_id() - kwargs['batched'] = False - with processing_lock(key): - self.datasets[key] = self.datasets[key].filter(filter_func, **kwargs) - if len(self.datasets) == 1: - self.dataset = 
self.datasets[key] + if dataset_meta is None: + assert len(self.datasets) == 1 + key = next(iter(self.datasets.keys())) + else: + key = dataset_meta.get_id() + kwargs['batched'] = False + with processing_lock(key): + self.datasets[key] = self.datasets[key].filter(filter_func, **kwargs) + if len(self.datasets) == 1: + self.dataset = self.datasets[key] @remote_function() def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): @@ -279,15 +334,296 @@ def mix_dataset(self, interleave=True): dataset_types = [isinstance(ds, IterableDataset) for ds in self.datasets] assert all( dataset_types) or not any(dataset_types), 'All datasets must be all streaming=True or streaming=False' + # Align features: cast large_string → string to avoid concatenation type mismatch + if not any(dataset_types): + from datasets import Features, Value, Sequence + dsets = list(self.datasets.values()) + ref_features = dsets[0].features + aligned = [] + for ds in dsets: + if ds.features != ref_features: + ds = ds.cast(ref_features) + aligned.append(ds) + else: + aligned = list(self.datasets.values()) if interleave: - self.dataset = interleave_datasets(list(self.datasets.values())) + self.dataset = interleave_datasets(aligned) else: - self.dataset = concatenate_datasets(list(self.datasets.values())) + self.dataset = concatenate_datasets(aligned) + self._mixed = True + + @remote_function() + def save_as(self, output_path: str, format: Optional[str] = None, + batch_size: int = 1000, mode: str = 'immediate', **kwargs) -> None: + """Save the merged dataset to a local file. + + Args: + output_path: Target file path. Extension determines format if `format` is None. + format: One of 'jsonl', 'json', 'csv', 'parquet'. Auto-detected from extension if None. + batch_size: Batch size for buffered writing. + mode: 'immediate' to save all data now; 'training' to write-through as data is + consumed by __iter__/__getitem__ — call flush_save() when training ends. 
+ **kwargs: Extra args passed to the underlying HF export method (immediate bulk only). + """ + if self.dataset is None: + raise ValueError('No dataset to save.') + if len(self.datasets) > 1 and any(self.dataset is v for v in self.datasets.values()): + raise ValueError('Call mix_dataset() before save_as() when multiple datasets are loaded.') + + fmt = format or self._infer_format(output_path) + if fmt not in ('jsonl', 'json', 'csv', 'parquet'): + raise ValueError(f"Unsupported format: '{fmt}'. Use jsonl/json/csv/parquet.") + + dir_path = os.path.dirname(os.path.abspath(output_path)) + os.makedirs(dir_path, exist_ok=True) + + if mode == 'training': + self._save_state = _SaveState(output_path, fmt, batch_size) + return + + if self._should_materialize(): + self._save_incremental(output_path, fmt, batch_size) + else: + self._save_bulk(output_path, fmt, **kwargs) + + @remote_function() + def flush_save(self) -> None: + """Finalize and close the training-mode writer opened by save_as(mode='training').""" + state = getattr(self, '_save_state', None) + if state is not None: + state.close() + self._save_state = None + + def _write_through(self, row): + """If training-mode save is active, persist the row.""" + state = getattr(self, '_save_state', None) + if state is not None: + state.write(row) + return row + + @staticmethod + def _infer_format(path: str) -> str: + ext = os.path.splitext(path)[1].lstrip('.').lower() + return {'jsonl': 'jsonl', 'json': 'jsonl', 'csv': 'csv', + 'parquet': 'parquet', 'pq': 'parquet'}.get(ext, 'jsonl') + + def _should_materialize(self) -> bool: + if isinstance(self.dataset, IterableDataset): + return True + if hasattr(self, 'do_encode') and self.do_encode: + return True + if getattr(self, '_lazy_map_ops', None) or getattr(self, '_global_map_ops', None): + return True + return False + + def _save_bulk(self, path: str, fmt: str, **kwargs) -> None: + if fmt in ('jsonl', 'json'): + self.dataset.to_json(path, **kwargs) + elif fmt == 'csv': + 
self.dataset.to_csv(path, **kwargs) + elif fmt == 'parquet': + self.dataset.to_parquet(path, **kwargs) + + def _save_incremental(self, path: str, fmt: str, batch_size: int) -> None: + iterator = self._row_iterator() + if fmt in ('jsonl', 'json'): + self._write_jsonl(path, iterator) + elif fmt == 'csv': + self._write_csv(path, iterator, batch_size) + elif fmt == 'parquet': + self._write_parquet(path, iterator, batch_size) + + def _row_iterator(self): + if isinstance(self.dataset, IterableDataset): + yield from self.dataset + else: + for i in range(len(self)): + yield self[i] + + @staticmethod + def _write_jsonl(path: str, iterator) -> None: + with open(path, 'w', encoding='utf-8') as f: + for row in iterator: + f.write(_json.dumps(row, ensure_ascii=False, default=_default_serializer) + '\n') + + @staticmethod + def _write_csv(path: str, iterator, batch_size: int) -> None: + import pandas as pd + first = True + batch: List[Dict] = [] + for row in iterator: + batch.append(row) + if len(batch) >= batch_size: + pd.DataFrame(batch).to_csv(path, mode='a', header=first, index=False) + first = False + batch = [] + if batch: + pd.DataFrame(batch).to_csv(path, mode='a', header=first, index=False) + + @staticmethod + def _write_parquet(path: str, iterator, batch_size: int) -> None: + import pyarrow as pa + import pyarrow.parquet as pq + writer = None + batch: List[Dict] = [] + for row in iterator: + batch.append(row) + if len(batch) >= batch_size: + table = pa.Table.from_pylist(batch) + if writer is None: + writer = pq.ParquetWriter(path, table.schema) + writer.write_table(table) + batch = [] + if batch: + table = pa.Table.from_pylist(batch) + if writer is None: + writer = pq.ParquetWriter(path, table.schema) + writer.write_table(table) + if writer: + writer.close() @remote_function() def __getitem__(self, idx): - return self.dataset[idx] + item = self.dataset[idx] + self._write_through(item) + return item @remote_function() def __len__(self): return len(self.dataset) + + +def 
_default_serializer(obj): + """Handle numpy types in JSON serialization.""" + import numpy as np + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable') + + +_SENTINEL = object() + + +class _SaveState: + """Async persistent writer for training-mode save_as. + + Writes happen on a background daemon thread so the training loop is never blocked. + Uses fcntl file-lock for cross-process safety when multiple ranks write one file. + """ + + def __init__(self, path: str, fmt: str, batch_size: int): + + self._path = path + self._fmt = fmt + self._batch_size = batch_size + self._queue: Queue = Queue(maxsize=batch_size * 4) + self._lock = PosixFileLock(path + '.lock') + self._error = None + + self._thread = threading.Thread(target=self._writer_loop, daemon=True) + self._thread.start() + + def write(self, row: Dict) -> None: + self._queue.put(row) + + def close(self) -> None: + self._queue.put(_SENTINEL) + self._thread.join() + self._lock.close() + if self._error: + raise self._error + + def _writer_loop(self) -> None: + try: + if self._fmt in ('jsonl', 'json'): + self._loop_jsonl() + elif self._fmt == 'csv': + self._loop_csv() + elif self._fmt == 'parquet': + self._loop_parquet() + except Exception as e: + self._error = e + + def _acquire_lock(self): + self._lock.acquire() + + def _release_lock(self): + self._lock.release() + + def _loop_jsonl(self) -> None: + with open(self._path, 'a', encoding='utf-8') as f: + while True: + item = self._queue.get() + if item is _SENTINEL: + return + line = _json.dumps(item, ensure_ascii=False, default=_default_serializer) + '\n' + self._acquire_lock() + try: + f.write(line) + f.flush() + finally: + self._release_lock() + + def _loop_csv(self) -> None: + import pandas as pd + header_written = False + buffer: List[Dict] = [] + while True: + item = 
self._queue.get() + if item is _SENTINEL: + if buffer: + self._acquire_lock() + try: + pd.DataFrame(buffer).to_csv( + self._path, mode='a', header=not header_written, index=False) + finally: + self._release_lock() + return + buffer.append(item) + if len(buffer) >= self._batch_size: + self._acquire_lock() + try: + pd.DataFrame(buffer).to_csv( + self._path, mode='a', header=not header_written, index=False) + header_written = True + finally: + self._release_lock() + buffer = [] + + def _loop_parquet(self) -> None: + import pyarrow as pa + import pyarrow.parquet as pq + writer = None + buffer: List[Dict] = [] + try: + while True: + item = self._queue.get() + if item is _SENTINEL: + if buffer: + table = pa.Table.from_pylist(buffer) + if writer is None: + writer = pq.ParquetWriter(self._path, table.schema) + self._acquire_lock() + try: + writer.write_table(table) + finally: + self._release_lock() + return + buffer.append(item) + if len(buffer) >= self._batch_size: + table = pa.Table.from_pylist(buffer) + if writer is None: + writer = pq.ParquetWriter(self._path, table.schema) + self._acquire_lock() + try: + writer.write_table(table) + finally: + self._release_lock() + buffer = [] + finally: + if writer: + writer.close() diff --git a/src/twinkle/dataset/iterable_dataset.py b/src/twinkle/dataset/iterable_dataset.py index 21ae82f8..b985d83e 100644 --- a/src/twinkle/dataset/iterable_dataset.py +++ b/src/twinkle/dataset/iterable_dataset.py @@ -29,6 +29,6 @@ def __getitem__(self, idx): @remote_function() def __iter__(self): - # TODO if this class passed through actor handler, an error will occur: - # a global single dataset, multiple dataloaders, the self._iter will cover each other - return self.dataset.__iter__() + for row in self.dataset: + self._write_through(row) + yield row diff --git a/src/twinkle/dataset/lazy_dataset.py b/src/twinkle/dataset/lazy_dataset.py index 29f8f678..383f85d7 100644 --- a/src/twinkle/dataset/lazy_dataset.py +++ 
b/src/twinkle/dataset/lazy_dataset.py @@ -186,6 +186,7 @@ def __getitem__(self, idx): elif self.do_check: item = self.template.check(item) + self._write_through(item) return item @remote_function() diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 83158a28..5027cf32 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -1,10 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import functools import inspect +import itertools import json import numpy as np import os -from typing import Any, Callable, List, Literal, Optional, TypeVar, Union +import random +from typing import Any, AsyncIterator, Callable, List, Literal, Optional, TypeVar, Union from twinkle.notifier import Notifier, notify_exception from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, get_logger, requires @@ -36,6 +38,54 @@ _TWINKLE_NOTIFIER_ENV = 'TWINKLE_NOTIFIER' +def _capture_caller() -> Optional[str]: + """Return ``file:line`` of the first frame outside this module, or ``None``.""" + this_file = __file__ + frame = inspect.currentframe() + if frame is None: + return None + frame = frame.f_back # skip _capture_caller itself + while frame is not None and frame.f_code.co_filename == this_file: + frame = frame.f_back + if frame is None: + return None + return f'{frame.f_code.co_filename}:{frame.f_lineno}' + + +def _attach_caller_note(exc: BaseException, caller: Optional[str]) -> None: + """Append a driver-caller note to ``exc`` so it surfaces in traceback dumps (PY3.11+).""" + if not caller: + return + try: + marker = f'[twinkle] driver caller: {caller}' + notes = getattr(exc, '__notes__', None) or [] + if marker not in notes: + exc.add_note(marker) + except Exception: # noqa: BLE001 + pass + + +def _augment_exc_with_caller(exc: BaseException, caller: Optional[str]) -> None: + """Prepend driver caller to ``exc.args[0]`` so ``f'{exc}'`` / ``str(exc)`` surfaces it. 
+ + ``add_note`` only shows up in ``traceback.format_exception``; downstream code that + logs via ``f'{e}'`` (e.g. ``SamplerBackend.prompt_logprobs``) would otherwise drop + the caller hint. Idempotent via a sentinel attribute so repeated re-raises in nested + wrappers don't stack the prefix. + """ + if not caller or getattr(exc, '_twinkle_caller_augmented', False): + return + try: + prefix = f'[twinkle driver caller: {caller}] ' + if exc.args: + exc.args = (prefix + str(exc.args[0]), *exc.args[1:]) + else: + exc.args = (prefix.rstrip(),) + setattr(exc, '_twinkle_caller_augmented', True) + except Exception: # noqa: BLE001 + pass + + def _maybe_load_worker_notifier() -> None: """Lazily reconstruct notifier + name on ray workers from inherited env vars.""" global _notifier, _name @@ -384,13 +434,17 @@ def dispatch_func(arg, n): # Comment this because remote_class supports `first`` # assert device_mesh.world_size == len(workers) length = len(workers) + # Map actor index to global_rank: with gpus_per_worker>1, consecutive + # global ranks belong to the same actor (TP peers). 
+ _mesh_world = device_mesh.world_size if device_mesh is not None else length + _rank_stride = max(1, _mesh_world // length) def dispatch_func(arg, n): import torch if isinstance(arg, list) or isinstance(arg, torch.Tensor): _args = [] for i in range(n): - _args.append(arg[device_mesh.get_slice(len(arg), device_mesh.get_data_rank_from_global_rank(i))]) + _args.append(arg[device_mesh.get_slice(len(arg), device_mesh.get_data_rank_from_global_rank(i * _rank_stride))]) return _args elif isinstance(arg, dict): _args = [{} for _ in range(n)] @@ -487,15 +541,19 @@ def decorator(cls): @functools.wraps(init_method) def new_init(self, *args, **kwargs): + _caller = _capture_caller() _ctx = f'{cls.__name__}.__init__' + if _caller: + _ctx = f'{_ctx} <- {_caller}' try: _maybe_load_worker_notifier() - _new_init_body(self, *args, **kwargs) + _new_init_body(self, _caller, *args, **kwargs) except Exception as _e: # noqa: BLE001 + _attach_caller_note(_e, _caller) notify_exception(_notifier, _ctx, _e, _name) raise - def _new_init_body(self, *args, **kwargs): + def _new_init_body(self, _caller, *args, **kwargs): if _mode == 'local': # Get the actual device_mesh device_mesh = _get_device_mesh_param(args, kwargs) @@ -519,10 +577,16 @@ def _new_init_body(self, *args, **kwargs): from ._ray import RayHelper # In case the same class created twice in the same device group - # Try to get the caller's line - frame = inspect.currentframe().f_back - caller_file = frame.f_code.co_filename.replace(os.sep, '_').replace('.', '_') - caller_line = frame.f_lineno + # Try to get the caller's line (resolved in ``new_init`` so it points + # at user code, not at the wrapper itself). 
+ if _caller: + _cf, _, _cl = _caller.rpartition(':') + caller_file = _cf.replace(os.sep, '_').replace('.', '_') + caller_line = _cl + else: + frame = inspect.currentframe().f_back + caller_file = frame.f_code.co_filename.replace(os.sep, '_').replace('.', '_') + caller_line = frame.f_lineno # Pass an instance_id is recommended instance_id = kwargs.pop('instance_id', '') + f'{caller_file}_{caller_line}' remote_group = kwargs.get('remote_group') @@ -688,6 +752,10 @@ def decorator(func: Callable[..., T1]) -> Callable[..., T1]: @functools.wraps(func) def wrapper(self, *args, **kwargs) -> T1: _ctx = f'{type(self).__name__}.{func.__name__}' + # Only capture caller on driver side; worker frames are Ray internals + _caller = _capture_caller() if hasattr(self, '_actors') else None + if _caller: + _ctx = f'{_ctx} <- {_caller}' try: device_mesh = getattr(self, 'device_mesh', None) if _mode == 'local': @@ -766,6 +834,8 @@ def _notifying_result_func(*rargs, **rkwargs): try: return _orig_result_func(*rargs, **rkwargs) except Exception as _e: # noqa: BLE001 + _attach_caller_note(_e, _caller) + _augment_exc_with_caller(_e, _caller) notify_exception(_notifier, _ctx, _e, _name) raise @@ -779,6 +849,8 @@ def _notifying_result_func(*rargs, **rkwargs): except StopIteration: raise except Exception as _e: # noqa: BLE001 + _attach_caller_note(_e, _caller) + _augment_exc_with_caller(_e, _caller) notify_exception(_notifier, _ctx, _e, _name) raise @@ -790,3 +862,99 @@ def _notifying_result_func(*rargs, **rkwargs): return wrapper return decorator + + +async def _wrap_async_iter_with_notify(gen: AsyncIterator, ctx: str, caller: Optional[str] = None) -> AsyncIterator: + """Re-emit chunks from a local async generator and forward exceptions to the notifier.""" + try: + async for chunk in gen: + yield chunk + except Exception as _e: # noqa: BLE001 + _attach_caller_note(_e, caller) + _augment_exc_with_caller(_e, caller) + notify_exception(_notifier, ctx, _e, _name) + raise + + +async def 
_wrap_objrefgen_with_notify(ref_gen: Any, ctx: str, caller: Optional[str] = None) -> AsyncIterator: + """Drain a Ray ObjectRefGenerator chunk-by-chunk; forward exceptions to the notifier.""" + import ray + try: + async for ref in ref_gen: + yield await ref + except Exception as _e: # noqa: BLE001 + _attach_caller_note(_e, caller) + _augment_exc_with_caller(_e, caller) + notify_exception(_notifier, ctx, _e, _name) + raise + + +def remote_generator(execute: Literal['first', 'balanced', 'random'] = 'balanced'): + """Streaming counterpart of ``remote_function`` for async-generator methods. + + The decorated method must be ``async def`` with ``yield``. Driver-side + returns an async iterator that yields each chunk as soon as the worker + emits it; under Ray this is backed by ``ObjectRefGenerator``. + + Args: + execute: How to pick the actor for a given call. Streaming is single-rank + inference (no NCCL collective), so we route the whole call to ONE actor. + + - 'first': always ``_actors[0]``. Useful for debugging or when + a particular rank holds privileged state. + - 'balanced': round-robin across ``_actors`` (DEFAULT). Each + decorated method owns an independent counter. + - 'random': uniform random pick. + + Notes: + - Bypasses ``_dispatch_args`` entirely (no ``slice_dp`` "batch too small" + guard fires for streaming). + - On the worker side the decorator is a transparent passthrough; Ray + turns the actor's ``async def + yield`` method into a streaming + generator handle automatically. + """ + + def decorator(func: Callable[..., AsyncIterator[T1]]) -> Callable[..., AsyncIterator[T1]]: + + # Per-method counter, isolated from any other @remote_generator call site. 
+ _rr_counter = itertools.count() + + @functools.wraps(func) + def wrapper(self, *args, **kwargs) -> AsyncIterator[T1]: + _ctx = f'{type(self).__name__}.{func.__name__}' + _caller = _capture_caller() if hasattr(self, '_actors') else None + if _caller: + _ctx = f'{_ctx} <- {_caller}' + try: + if _mode == 'local' or not hasattr(self, '_actors'): + # Worker-side OR pure local mode: just invoke the async generator. + return _wrap_async_iter_with_notify(func(self, *args, **kwargs), _ctx, _caller) + if _mode != 'ray': + raise NotImplementedError(f'Unsupported mode {_mode}') + + check_unsafe(*args, **kwargs) + actors = self._actors + if not actors: + raise RuntimeError(f'{_ctx}: no actors available for streaming dispatch') + if execute == 'first': + actor = actors[0] + elif execute == 'random': + actor = random.choice(actors) + elif execute == 'balanced': + actor = actors[next(_rr_counter) % len(actors)] + else: + raise ValueError(f'Unsupported execute mode for remote_generator: {execute}') + + ref_gen = getattr(actor, func.__name__).remote(*args, **kwargs) + return _wrap_objrefgen_with_notify(ref_gen, _ctx, _caller) + except Exception as _e: # noqa: BLE001 + _attach_caller_note(_e, _caller) + _augment_exc_with_caller(_e, _caller) + notify_exception(_notifier, _ctx, _e, _name) + raise + + wrapper._execute = execute + wrapper._is_generator = True + return wrapper + + return decorator diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py index 4e4d0e82..8e1d0e2a 100644 --- a/src/twinkle/loss/__init__.py +++ b/src/twinkle/loss/__init__.py @@ -5,6 +5,7 @@ from .dpo import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss from .gkd import GKDLoss from .grpo import BNPOLoss, CISPOLoss, DRGRPOLoss, GRPOLoss, GSPOLoss, SAPOLoss +from .infonce import InfonceLoss from .mse import MSELoss torch_loss_mapping = { @@ -25,4 +26,6 @@ 'simpo': SimPOLoss, 'cpo': CPOLoss, 'orpo': ORPOLoss, + # Embedding / contrastive losses + 'infonce': InfonceLoss, } diff --git 
a/src/twinkle/loss/base.py b/src/twinkle/loss/base.py index 334d5edd..5fd046ae 100644 --- a/src/twinkle/loss/base.py +++ b/src/twinkle/loss/base.py @@ -6,6 +6,7 @@ class Loss: require_logits = False require_entropy = False + require_logps = True def __call__(self, inputs: InputFeature, outputs: ModelOutput, **kwargs) -> LossOutput: ... diff --git a/src/twinkle/loss/cross_entropy.py b/src/twinkle/loss/cross_entropy.py index abcc9591..c1b5225d 100644 --- a/src/twinkle/loss/cross_entropy.py +++ b/src/twinkle/loss/cross_entropy.py @@ -4,37 +4,28 @@ class CrossEntropyLoss(Loss): - """Calculate CE from logps""" + """Calculate CE from logps, with optional DFT (arxiv 2508.05629) entropy weighting.""" - def __init__(self, ignore_index: int = -100, reduction='mean', **kwargs): + def __init__(self, ignore_index: int = -100, reduction='mean', dft: bool = False, **kwargs): super().__init__() self.ignore_index = ignore_index self.reduction = reduction + self.dft = dft def __call__(self, inputs, outputs, **kwargs): labels = inputs['labels'] logps = outputs.get('logps') - logits = outputs.get('logits') - if logps is not None: - loss_mask = (labels != self.ignore_index).float() - if self.reduction != 'sum': - return LossOutput( - loss=(-logps * loss_mask).sum() / loss_mask.sum().clamp(min=1), - num_tokens=0, - ) - else: - return LossOutput( - loss=(-logps * loss_mask).sum(), - num_tokens=loss_mask.sum().clamp(min=1), - ) - else: - import torch - assert logits is not None - logits = logits.view(-1, logits.shape[-1]) + if logps is None: + import torch.nn.functional as F + logits = outputs['logits'].view(-1, outputs['logits'].shape[-1]) labels = labels.view(-1) - loss = torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels) - if self.reduction != 'sum': - return LossOutput(loss=loss, num_tokens=0) - else: - return LossOutput(loss=loss, num_tokens=(labels != self.ignore_index).sum()) + logps = F.log_softmax(logits, dim=-1).gather(-1, 
labels.clamp(min=0).unsqueeze(-1)).squeeze(-1) + + mask = (labels != self.ignore_index).float() + # DFT: -p·log(p) instead of -log(p) + per_token = -logps * logps.exp() if self.dft else -logps + + if self.reduction != 'sum': + return LossOutput(loss=(per_token * mask).sum() / mask.sum().clamp(min=1), num_tokens=0) + return LossOutput(loss=(per_token * mask).sum(), num_tokens=mask.sum().clamp(min=1)) diff --git a/src/twinkle/loss/infonce.py b/src/twinkle/loss/infonce.py new file mode 100644 index 00000000..44d2e48d --- /dev/null +++ b/src/twinkle/loss/infonce.py @@ -0,0 +1,274 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Embedding / contrastive losses for Twinkle. + +Inputs convention: + inputs['labels']: pair / multi-negative grouping labels (see each class docstring). + outputs['embeddings']: sentence embeddings produced by the model + (shape ``[B, D]``). Falls back to ``outputs['logits']`` for + backward-compatibility with the legacy hook-side pooling layout. + +All classes return :class:`LossOutput` with ``num_tokens=0`` (no per-token +normalization, matching the convention used by ``DPOLoss``/``GRPOLoss``). +""" +from enum import Enum +from typing import Optional + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + +from twinkle.data_format import LossOutput + +from .base import Loss + + +# Borrowed from sentence_transformers. +class SiameseDistanceMetric(Enum): + """Distance metrics available to the pairwise contrastive losses.""" + + EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2) # noqa + MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1) # noqa + COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) # noqa + + +def _extract_sentences(outputs) -> torch.Tensor: + """Return [B, D] sentence embeddings from postprocess_tensor_sp output. 
+ + Prefers the canonical ``embeddings`` key (post-pooling); falls back to + ``logits`` (legacy hook-side pooling) and applies CLS pooling for 3-D. + """ + sentences = outputs.get('embeddings') + if sentences is None: + sentences = outputs['logits'] + if sentences.dim() == 3: + sentences = sentences[:, 0] + return sentences + + +def _parse_pair_sentence(outputs): + """Split an interleaved [s1_0, s2_0, s1_1, s2_1, ...] tensor into (s1, s2).""" + sentences = _extract_sentences(outputs) + return sentences[0::2], sentences[1::2] + + +def _parse_multi_negative_sentences(sentences: torch.Tensor, + labels: torch.Tensor, + hard_negatives: Optional[int] = None): + """Split a flat embedding tensor into per-sample groups. + + ``labels`` is a 1-D mask where ``1`` marks the start of a new + ``anchor(1)+positive(1)+negatives(n)`` group; the inserted offsets account for + the anchor sitting immediately before each positive in the flat layout. + """ + split_indices = torch.nonzero(labels, as_tuple=False).squeeze().tolist() + if isinstance(split_indices, int): + split_indices = [split_indices] + split_indices.append(len(labels)) + split_tensors = [] + for i in range(len(split_indices) - 1): + start, end = split_indices[i], split_indices[i + 1] + split_part = sentences[start:end] + if hard_negatives is not None: + negatives = len(split_part) - 2 + assert negatives > 0 + if negatives > hard_negatives: + split_part = split_part[:hard_negatives + 2] + elif negatives < hard_negatives: + # upsample negatives with replacement; skip index 0 (positive) + selected = np.random.choice(list(range(negatives)), size=hard_negatives - negatives, replace=True) + 1 + split_part = torch.cat((split_part, split_part[selected]), dim=0) + split_tensors.append(split_part) + return split_tensors + + +class InfonceLoss(Loss): + """InfoNCE contrastive loss with optional cross-DP gathering. 
+ + Each sample is laid out as ``anchor(1) + positive(1) + negatives(n)``; + ``inputs['labels']`` is a 1-D mask where ``1`` marks the start of every + such group. Setting ``use_batch=True`` enables in-batch negatives and, + when distributed is initialized, gathers embeddings from all DP ranks + (only the local shard keeps gradients). + + Args: + temperature: Logit scaling factor. + use_batch: Include cross-sample (and cross-rank) in-batch negatives. + hard_negatives: Fix the per-sample negative count via truncation/upsampling. + ``None`` keeps the original variable counts. + mask_fake_negative: Mask any logit greater than ``positive + fake_neg_margin``. + fake_neg_margin: Threshold offset above the positive logit when masking. + include_qq: Append the query-query similarity block (self diagonal masked). + include_dd: Append the positive-doc to all-docs block (self positive masked). + process_group: Distributed process group used for the all-gather. + When ``None``, the default group (``dist.group.WORLD``) is used. 
+ """ + + require_logits = True + require_entropy = False + require_logps = False + + def __init__( + self, + temperature: float = 0.1, + use_batch: bool = True, + hard_negatives: Optional[int] = None, + mask_fake_negative: bool = False, + fake_neg_margin: float = 0.1, + include_qq: bool = False, + include_dd: bool = False, + process_group=None, + **kwargs, + ): + self.temperature = temperature + self.use_batch = use_batch + self.hard_negatives = hard_negatives + self.mask_fake_negative = mask_fake_negative + self.fake_neg_margin = fake_neg_margin + self.include_qq = include_qq + self.include_dd = include_dd + self.process_group = process_group + + def _gather_across_dp(self, sentences: torch.Tensor, labels: torch.Tensor): + """All-gather embeddings & labels across DP ranks; only local shard keeps grad.""" + if not (dist.is_available() and dist.is_initialized()): + return sentences, labels + world_size = dist.get_world_size(group=self.process_group) + if world_size <= 1: + return sentences, labels + rank = dist.get_rank(group=self.process_group) + + # variable per-rank shapes require communicating shape first + local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long) + shapes = [torch.empty_like(local_shape) for _ in range(world_size)] + dist.all_gather(shapes, local_shape, group=self.process_group) + all_sentences = [sentences.new_empty(shape.tolist()) for shape in shapes] + dist.all_gather(all_sentences, sentences.contiguous(), group=self.process_group) + + local_label_shape = labels.new_tensor(labels.shape, dtype=torch.long) + label_shapes = [torch.empty_like(local_label_shape) for _ in range(world_size)] + dist.all_gather(label_shapes, local_label_shape, group=self.process_group) + all_labels = [labels.new_empty(shape.tolist()) for shape in label_shapes] + dist.all_gather(all_labels, labels.contiguous(), group=self.process_group) + + # keep the local shard differentiable; detach others + all_sentences[rank] = sentences + for idx in 
range(world_size): + if idx != rank: + all_sentences[idx] = all_sentences[idx].detach() + return torch.cat(all_sentences, dim=0), torch.cat(all_labels, dim=0) + + def __call__(self, inputs, outputs, **kwargs) -> LossOutput: + labels = inputs['labels'].view(-1) + sentences = _extract_sentences(outputs) + + if self.use_batch: + sentences, labels = self._gather_across_dp(sentences, labels) + + split_tensors = _parse_multi_negative_sentences(sentences, labels, self.hard_negatives) + can_batched = self.hard_negatives is not None or len({s.shape[0] for s in split_tensors}) == 1 + + if not self.use_batch: + loss = self._intra_sample_loss(split_tensors, can_batched) + else: + loss = self._in_batch_loss(split_tensors, can_batched) + return LossOutput(loss=loss, num_tokens=0) + + def _intra_sample_loss(self, split_tensors, can_batched) -> torch.Tensor: + """InfoNCE with only the per-sample negatives (no cross-sample sharing).""" + if can_batched: + sentences = torch.stack(split_tensors, dim=0) # [B, neg+2, D] + similarity_matrix = torch.matmul(sentences[:, 0:1], sentences[:, 1:].transpose(1, 2)) / self.temperature + labels = torch.zeros(len(split_tensors), dtype=torch.int64, device=sentences.device) + return nn.CrossEntropyLoss()(similarity_matrix.squeeze(1), labels) + + loss = 0 + for tensor in split_tensors: + similarity_matrix = torch.matmul(tensor[0], tensor[1:].T) / self.temperature + labels = torch.tensor(0, device=tensor.device) + loss = loss + nn.CrossEntropyLoss()(similarity_matrix, labels) + return loss / len(split_tensors) + + def _in_batch_loss(self, split_tensors, can_batched) -> torch.Tensor: + """InfoNCE with cross-sample (and optionally cross-rank) negatives.""" + if can_batched: + return self._in_batch_loss_batched(split_tensors) + return self._in_batch_loss_unbatched(split_tensors) + + def _in_batch_loss_batched(self, split_tensors) -> torch.Tensor: + sentences = torch.stack(split_tensors, dim=0) # [B, neg+2, D] + queries = sentences[:, 0] # [B, D] + 
docs_all = sentences[:, 1:].reshape(-1, sentences.size(2)) # [B*(neg+1), D] + qd_matrix = torch.matmul(queries, docs_all.T) # [B, B*(neg+1)] + # each row's positive sits at column row_idx * (neg+1) + block = sentences.size(1) - 1 + labels = torch.arange(0, sentences.size(0) * block, block, device=sentences.device) + + logits_list = [qd_matrix] + + if self.include_qq: + qq_matrix = torch.matmul(queries, queries.T).clone() + qq_matrix.fill_diagonal_(float('-inf')) + logits_list.append(qq_matrix) + + if self.include_dd: + pos_docs = sentences[:, 1] # [B, D] + dd_matrix = torch.matmul(pos_docs, docs_all.T) # [B, B*(neg+1)] + if block > 0: + row_idx = torch.arange(dd_matrix.size(0), device=dd_matrix.device) + dd_matrix[row_idx, row_idx * block] = float('-inf') + logits_list.append(dd_matrix) + + if self.mask_fake_negative: + row_idx = torch.arange(qd_matrix.size(0), device=qd_matrix.device) + thresholds = (qd_matrix[row_idx, labels].view(-1, 1).detach() + self.fake_neg_margin) + + qd_block = qd_matrix.clone() + qd_block[qd_block > thresholds] = float('-inf') + components = [qd_block] + if self.include_qq: + qq_block = logits_list[1].clone() + qq_block[qq_block > thresholds] = float('-inf') + components.append(qq_block) + if self.include_dd: + # align with Qwen3-Embedding: no threshold masking on d-d block + components.append(logits_list[-1]) + similarity_matrix = torch.cat(components, dim=1) + else: + similarity_matrix = torch.cat(logits_list, dim=1) + + return nn.CrossEntropyLoss()(similarity_matrix / self.temperature, labels) + + def _in_batch_loss_unbatched(self, split_tensors) -> torch.Tensor: + # docs from every sample concatenated as a shared negative bank + docs_bank = torch.cat([t[1:] for t in split_tensors], dim=0) + queries_all = torch.stack([t[0] for t in split_tensors], dim=0) if self.include_qq else None + + loss = 0 + length = 0 + for idx, tensor in enumerate(split_tensors): + qd_vec = torch.matmul(tensor[0], docs_bank.T) + target = torch.tensor(length, 
device=tensor.device) + threshold = qd_vec[target].detach() + self.fake_neg_margin + + qd_masked = torch.where(qd_vec > threshold, qd_vec.new_full((), float('-inf')), + qd_vec) if self.mask_fake_negative else qd_vec + logits_parts = [qd_masked] + + if self.include_qq: + qq_vec = torch.matmul(tensor[0], queries_all.T).clone() + qq_vec[idx] = float('-inf') + if self.mask_fake_negative: + qq_vec = torch.where(qq_vec > threshold, qq_vec.new_full((), float('-inf')), qq_vec) + logits_parts.append(qq_vec) + + if self.include_dd: + dd_vec = torch.matmul(tensor[1], docs_bank.T) + dd_vec[length] = float('-inf') + logits_parts.append(dd_vec) + + logits_row = torch.cat(logits_parts, dim=-1) / self.temperature + loss = loss + nn.CrossEntropyLoss()(logits_row.unsqueeze(0), target.unsqueeze(0)) + length += tensor.size(0) - 1 + return loss / len(split_tensors) diff --git a/src/twinkle/metric/__init__.py b/src/twinkle/metric/__init__.py index ad244e1d..baeb6c1c 100644 --- a/src/twinkle/metric/__init__.py +++ b/src/twinkle/metric/__init__.py @@ -3,6 +3,7 @@ from .base import Metric from .completion_and_reward import CompletionRewardMetric from .dpo import DPOMetric +from .embedding import EmbeddingMetric from .grpo import CISPOMetric, GRPOMetric, GSPOMetric from .loss import LossMetric from .train_metric import TrainMetric diff --git a/src/twinkle/metric/embedding.py b/src/twinkle/metric/embedding.py new file mode 100644 index 00000000..543380d6 --- /dev/null +++ b/src/twinkle/metric/embedding.py @@ -0,0 +1,109 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from typing import List, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +from twinkle.data_format import InputFeature, ModelOutput + +from .base import Metric + + +class EmbeddingMetric(Metric): + """Embedding similarity metric for InfoNCE training. 
+ + Reports anchor-positive cosine similarity stats (mean/min/max) and + average anchor-to-other-positives (in-batch negative) similarity. + Performs an extra all_gather to compute cross-rank statistics. + """ + + def __init__(self, device_mesh, process_group, **kwargs): + super().__init__(device_mesh, process_group, **kwargs) + self.reset() + + def reset(self): + self.pos_sim_sum = 0.0 + self.pos_sim_min = float('inf') + self.pos_sim_max = float('-inf') + self.pos_count = 0 + self.neg_sim_sum = 0.0 + self.neg_count = 0 + self.total_loss = 0.0 + self.total_count = 0 + self.grad_norm = 0.0 + + def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs): + sentences = outputs.get('embeddings') + if sentences is None: + sentences = outputs.get('logits') + if sentences is None: + return + if sentences.dim() == 3: + sentences = sentences[:, 0] + + if not isinstance(inputs, list): + inputs = [inputs] + labels = torch.cat([inp['labels'].view(-1) for inp in inputs], dim=0) + + # Gather embeddings and labels across DP for in-batch stats + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: + world_size = dist.get_world_size() + local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long) + shapes = [torch.empty_like(local_shape) for _ in range(world_size)] + dist.all_gather(shapes, local_shape) + all_sentences = [sentences.new_empty(s.tolist()) for s in shapes] + dist.all_gather(all_sentences, sentences.contiguous()) + sentences = torch.cat(all_sentences, dim=0) + + local_lshape = labels.new_tensor(labels.shape, dtype=torch.long) + lshapes = [torch.empty_like(local_lshape) for _ in range(world_size)] + dist.all_gather(lshapes, local_lshape) + all_labels = [labels.new_empty(s.tolist()) for s in lshapes] + dist.all_gather(all_labels, labels.contiguous()) + labels = torch.cat(all_labels, dim=0) + + anchor_idx = torch.nonzero(labels, as_tuple=False).squeeze(-1) + if anchor_idx.numel() == 0: + return 
+ + anchors = sentences[anchor_idx] + positives = sentences[anchor_idx + 1] + + # Anchor-positive cosine similarity + pos_cos = F.cosine_similarity(anchors, positives, dim=1) + self.pos_sim_sum += pos_cos.sum().item() + self.pos_sim_min = min(self.pos_sim_min, pos_cos.min().item()) + self.pos_sim_max = max(self.pos_sim_max, pos_cos.max().item()) + self.pos_count += pos_cos.numel() + + # Anchor vs all other positives (in-batch negatives) + if anchors.size(0) > 1: + sim_matrix = torch.matmul(anchors, positives.T) + mask = ~torch.eye(sim_matrix.size(0), dtype=torch.bool, device=sim_matrix.device) + neg_sims = sim_matrix[mask] + self.neg_sim_sum += neg_sims.sum().item() + self.neg_count += neg_sims.numel() + + loss = outputs.get('loss') + if loss is not None: + self.total_loss += loss.item() if hasattr(loss, 'item') else loss + self.total_count += 1 + grad_norm = kwargs.get('grad_norm') + if grad_norm is not None: + self.grad_norm = grad_norm + + def calculate(self): + results = {} + if self.pos_count > 0: + results['pos_sim'] = f'{self.pos_sim_sum / self.pos_count:.4f}' + results['pos_sim_min'] = f'{self.pos_sim_min:.4f}' + results['pos_sim_max'] = f'{self.pos_sim_max:.4f}' + if self.neg_count > 0: + results['neg_sim'] = f'{self.neg_sim_sum / self.neg_count:.4f}' + if self.total_count > 0: + results['loss'] = f'{self.total_loss / self.total_count:.4f}' + if self.grad_norm > 0: + results['grad_norm'] = f'{self.grad_norm:.6f}' + self.reset() + return results diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 96dace65..61075588 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -10,6 +10,7 @@ import torch import torch.distributed as dist import torch.nn as nn +import contextlib from contextlib import contextmanager from dataclasses import dataclass from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model @@ -31,7 +32,7 @@ from twinkle.metric import LossMetric, 
Metric, TrainMetric from twinkle.model.base import TwinkleModel from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus -from twinkle.patch import Patch, apply_patch +from twinkle.patch import Patch, apply_context, apply_patch from twinkle.processor import InputProcessor from twinkle.template import Template from twinkle.utils import construct_class, get_logger, selective_log_softmax @@ -41,6 +42,22 @@ logger = get_logger() +def _resolve_task_context(model, task): + """Return a context manager that applies the right per-forward Patch for ``task``. + + Mirrors the transformers backend: 'causal_lm' (default) is a no-op, while + 'embedding' installs :class:`MegatronEmbeddingPatch` which swaps the + ``output_layer`` for identity (with TP/SP gather) and registers a hook that + handles CP gather + last-token pooling, returning ``[n_seqs, hidden]``. + """ + if task in (None, 'causal_lm'): + return contextlib.nullcontext() + if task == 'embedding': + from twinkle.patch.megatron_emb import MegatronEmbeddingPatch + return apply_context(model, MegatronEmbeddingPatch()) + raise ValueError(f'Unknown task={task!r}; expected one of: causal_lm, embedding.') + + @dataclass class MegatronOptimizerGroup(BaseOptimizerGroup): """Optimizer group for Megatron training. 
@@ -286,6 +303,7 @@ def forward_backward(self, temperature = float(kwargs.pop('temperature', 1.0)) forward_only = kwargs.pop('forward_only', False) return_logits = kwargs.pop('return_logits', False) + task = kwargs.pop('task', 'causal_lm') optimizer_config = self.optimizer_group[adapter_name] loss_instance = self.optimizer_group[adapter_name].loss_instance if not inputs: @@ -349,14 +367,18 @@ def forward_backward(self, _mb_counter = [0] # mutable counter for closure - def post_loss_function(output_tensor, inputs, logps, unpacked_logits=None, entropies=None): + def post_loss_function(output_tensor, inputs, logps, unpacked_logits=None, entropies=None, + embeddings=None): mb_idx = _mb_counter[0] _mb_counter[0] += 1 current_kwargs = loss_extra_kwargs_per_mb[mb_idx % len(loss_extra_kwargs_per_mb)] - logits = unpacked_logits if unpacked_logits is not None else output_tensor - outputs = ModelOutput(logits=logits, logps=logps) - if entropies is not None: - outputs['entropies'] = entropies + if embeddings is not None: + outputs = ModelOutput(embeddings=embeddings) + else: + logits = unpacked_logits if unpacked_logits is not None else output_tensor + outputs = ModelOutput(logits=logits, logps=logps) + if entropies is not None: + outputs['entropies'] = entropies result = loss_instance(inputs, outputs, **current_kwargs) if unpacked_logits is not None: outputs.pop('logits', None) @@ -390,21 +412,29 @@ def forward_step_func(data_iterator, model): logps = None unpacked_logits = None entropies = None + embeddings = None _loss_instance = loss_instance - if labels is not None and mpu.is_pipeline_last_stage(False, unwrapped_model.vp_stage): - loss_mask = (labels != -100).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - output_tensor.div_(temperature) + is_last_pp = mpu.is_pipeline_last_stage(False, unwrapped_model.vp_stage) + if task == 'embedding': + # MegatronEmbeddingPatch already pooled output to [n_seqs, hidden] on last PP stage. 
+ if is_last_pp: + embeddings = output_tensor + elif labels is not None and is_last_pp: + _loss_require_logps = getattr(_loss_instance, 'require_logps', True) _loss_require_entropy = (hasattr(_loss_instance, 'require_entropy') and _loss_instance.require_entropy) - if _loss_require_entropy: - logps, entropies = selective_log_softmax(output_tensor, masked_labels, return_entropy=True) - else: - logps = selective_log_softmax(output_tensor, masked_labels) - # Reconstruct full-length tensors from CP-split shards - logps = processor.postprocess_tensor_cp(logps) - if entropies is not None: - entropies = processor.postprocess_tensor_cp(entropies) + if _loss_require_logps: + loss_mask = (labels != -100).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + output_tensor.div_(temperature) + if _loss_require_entropy: + logps, entropies = selective_log_softmax(output_tensor, masked_labels, return_entropy=True) + else: + logps = selective_log_softmax(output_tensor, masked_labels) + # Reconstruct full-length tensors from CP-split shards + logps = processor.postprocess_tensor_cp(logps) + if entropies is not None: + entropies = processor.postprocess_tensor_cp(entropies) batch['labels'] = processor.postprocess_tensor_cp(labels) if 'position_ids' in batch: pos = batch['position_ids'] @@ -427,6 +457,7 @@ def forward_step_func(data_iterator, model): logps=logps, unpacked_logits=unpacked_logits, entropies=entropies, + embeddings=embeddings, ) # Get Megatron's forward-backward function @@ -446,15 +477,16 @@ def forward_step_func(data_iterator, model): # Run forward-backward with Megatron's scheduler # Megatron handles all communication internally using proper process groups - losses = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=data_iter, - model=self.model, - num_microbatches=len(inputs), - seq_length=seq_length, - micro_batch_size=micro_batch_size, - forward_only=forward_only, - ) + with _resolve_task_context(self.model, task): + 
losses = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iter, + model=self.model, + num_microbatches=len(inputs), + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=forward_only, + ) # Extract loss from results (only last PP stage returns non-empty) loss = torch.tensor(0.0).to(Platform.get_local_device()) @@ -920,12 +952,11 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): ) else: bridge = self.strategy.bridge - for _model in self.strategy.unwrap_model(self.model): - bridge.load_weights( - _model, - checkpoint_dir, - peft_format=(adapter_name != _default_adapter_name), - ) + bridge.load_weights( + self.strategy.unwrap_model(self.model), + checkpoint_dir, + peft_format=(adapter_name != _default_adapter_name), + ) if dist.is_initialized(): dist.barrier() diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py index 51a28015..1f0547e6 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py @@ -923,6 +923,21 @@ def _trim_gathered_sequence_padding(tensor: torch.Tensor, real_position_ids: tor return torch.cat(pieces, dim=1).contiguous() if pieces else tensor[:, :0].contiguous() return tensor[:, :real_position_ids.shape[-1]].contiguous() + def gather_features(self, features: torch.Tensor) -> torch.Tensor: + """All-gather SP-sharded per-token features ``[B, T_local, H]`` -> ``[B, T_real, H]``. + + Mirrors the gather + trim path used for logps but operates directly on + hidden_states, so embedding pooling can run on the full sequence with + the same ``real_position_ids`` source of truth. 
+ """ + if features is None or not torch.is_tensor(features): + return features + if not self.enabled or self.ulysses_size <= 1: + return features + real_position_ids = sequence_parallel.real_position_ids + gathered, _ = GatherLoss.apply(features, None, 1, real_position_ids) + return self._trim_gathered_sequence_padding(gathered, real_position_ids) + def gather_loss_tensors( self, inputs: Dict[str, Any], diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index c2bf8c7e..58acba95 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -36,7 +36,7 @@ from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus from twinkle.model.transformers.moe import apply_expert_parallel from twinkle.model.transformers.strategy import AccelerateStrategy, NativeFSDPStrategy -from twinkle.patch import Patch, apply_patch +from twinkle.patch import Patch, apply_context, apply_patch from twinkle.processor import InputProcessor from twinkle.template import Template from twinkle.utils import construct_class, get_logger, selective_log_softmax, torch_util @@ -48,6 +48,22 @@ logger = get_logger() +def _resolve_task_context(model, task): + """Return a context manager that applies the right per-forward Patch for ``task``. + + 'causal_lm' (default) keeps the model untouched (returns ``nullcontext``). + 'embedding' swaps lm_head for identity + installs a feature-extraction hook so + downstream pooling can run inside + ``InputProcessor.postprocess_tensor_sp(task='embedding', ...)``. 
+ """ + if task in (None, 'causal_lm'): + return contextlib.nullcontext() + if task == 'embedding': + from twinkle.patch.transformers_emb import TransformersEmbeddingPatch + return apply_context(model, TransformersEmbeddingPatch()) + raise ValueError(f'Unknown task={task!r}; expected one of: causal_lm, embedding.') + + @dataclass class OptimizerGroup(BaseOptimizerGroup): """Optimizer group for Transformers training.""" @@ -107,6 +123,8 @@ def accumulate_metrics(self, is_training): self._ensure_dp_group() status = self.train_status if is_training else self.eval_status if len(status.metrics) > 0 and status.inputs is not None and status.outputs is not None: + forward_kwargs = copy(status.forward_kwargs) + forward_kwargs.pop('gradient_accumulation_steps', None) for metric in status.metrics: metric.accumulate( status.inputs, @@ -116,7 +134,7 @@ def accumulate_metrics(self, is_training): gradient_accumulation_steps=self.gradient_accumulation_steps, grad_norm=self._last_grad_norm, loss_reduction=getattr(self.loss_instance, 'reduction', 'mean'), - **status.forward_kwargs) + **forward_kwargs) _default_adapter_name = '' @@ -380,6 +398,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec adapter_name = kwargs.pop('adapter_name', self._get_default_group()) temperature = float(kwargs.pop('temperature', 1.0)) return_logits = kwargs.pop('return_logits', False) + task = kwargs.pop('task', 'causal_lm') optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() if not inputs: @@ -397,6 +416,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec loss_instance = optimizer_config.loss_instance loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits) loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy) + loss_require_logps = getattr(loss_instance, 'require_logps', True) assert isinstance(processor, InputProcessor), 
'Set a correct `InputProcessor` before forwarding' inputs: Dict[str, Any] = processor( inputs, @@ -407,9 +427,10 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec ) labels: torch.Tensor = inputs.pop('labels', None) optimizer_config.accumulate_metrics(True) - outputs = self.model(**inputs) + with _resolve_task_context(self.model, task): + outputs = self.model(**inputs) inputs['labels'] = labels - if labels is not None: + if labels is not None and loss_require_logps: loss_mask = (labels != -100).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0 @@ -424,8 +445,9 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec outputs['past_key_values'] = None if not (return_logits or loss_require_logits): outputs['logits'] = None - inputs, outputs = processor.postprocess_tensor_sp(inputs, outputs, sp_strategy=self.sp_strategy) - inputs, outputs = processor.unpack_packed_sequences(inputs, outputs) + inputs, outputs = processor.postprocess_tensor_sp( + inputs, outputs, sp_strategy=self.sp_strategy, task=task) + inputs, outputs = processor.unpack_packed_sequences(inputs, outputs, task=task) optimizer_config.train_status.inputs = inputs optimizer_config.train_status.outputs = outputs optimizer_config.train_status.forward_kwargs = kwargs @@ -451,6 +473,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T disable_lora = kwargs.pop('disable_lora', False) temperature = float(kwargs.pop('temperature', 1.0)) return_logits = kwargs.pop('return_logits', False) + task = kwargs.pop('task', 'causal_lm') optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() if not inputs: @@ -470,6 +493,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T loss_instance = optimizer_config.loss_instance loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits) loss_require_entropy = 
(hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy) + loss_require_logps = getattr(loss_instance, 'require_logps', True) inputs: Dict[str, Any] = processor( inputs, sp_strategy=self.sp_strategy, @@ -480,13 +504,13 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T labels = inputs.pop('labels', None) optimizer_config.accumulate_metrics(False) unwrapped_model = self.strategy.unwrap_model(self.model) - if disable_lora and isinstance(unwrapped_model, PeftModel): - with unwrapped_model.disable_adapter(): - outputs = self.model(**inputs) - else: + lora_ctx = (unwrapped_model.disable_adapter() + if disable_lora and isinstance(unwrapped_model, PeftModel) + else contextlib.nullcontext()) + with _resolve_task_context(self.model, task), lora_ctx: outputs = self.model(**inputs) inputs['labels'] = labels - if labels is not None: + if labels is not None and loss_require_logps: loss_mask = (labels != -100).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0 @@ -501,8 +525,9 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T outputs['past_key_values'] = None if not (return_logits or loss_require_logits): outputs['logits'] = None - inputs, outputs = processor.postprocess_tensor_sp(inputs, outputs, sp_strategy=self.sp_strategy) - inputs, outputs = processor.unpack_packed_sequences(inputs, outputs) + inputs, outputs = processor.postprocess_tensor_sp( + inputs, outputs, sp_strategy=self.sp_strategy, task=task) + inputs, outputs = processor.unpack_packed_sequences(inputs, outputs, task=task) optimizer_config.eval_status.inputs = inputs optimizer_config.eval_status.outputs = outputs optimizer_config.eval_status.forward_kwargs = kwargs @@ -582,7 +607,7 @@ def backward(self, **kwargs): scaler = optimizer_config.scaler optimizer_config.cur_step += 1 - should_sync = optimizer_config.do_grad_sync() + should_sync = optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')) 
import contextlib no_sync_ctx = contextlib.nullcontext() diff --git a/src/twinkle/patch/__init__.py b/src/twinkle/patch/__init__.py index 76d42eb9..da7a0165 100644 --- a/src/twinkle/patch/__init__.py +++ b/src/twinkle/patch/__init__.py @@ -1,14 +1,30 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import sys +from contextlib import contextmanager from typing import Any, Type, Union from .base import Patch +def _resolve(patch_cls: Union[Patch, Type[Patch], str]) -> Patch: + from twinkle.utils import construct_class + return construct_class(patch_cls, Patch, sys.modules[__name__]) + + def apply_patch(module: Any, patch_cls: Union[Patch, Type[Patch], str], *args, **kwargs): - from ..utils import construct_class - patch_ins = construct_class(patch_cls, Patch, sys.modules[__name__]) + patch_ins = _resolve(patch_cls) return patch_ins(module, *args, **kwargs) -__all__ = ['apply_patch', 'Patch'] +@contextmanager +def apply_context(module: Any, patch_cls: Union[Patch, Type[Patch], str], *args, **kwargs): + # Apply patch on enter; revert via subclass-implemented unpatch on exit (even on exception). + patch_ins = _resolve(patch_cls) + result = patch_ins(module, *args, **kwargs) + try: + yield result + finally: + patch_ins.unpatch(module, *args, **kwargs) + + +__all__ = ['apply_patch', 'apply_context', 'Patch'] diff --git a/src/twinkle/patch/base.py b/src/twinkle/patch/base.py index 08982ba9..3a9b8c07 100644 --- a/src/twinkle/patch/base.py +++ b/src/twinkle/patch/base.py @@ -9,3 +9,6 @@ class Patch: def __call__(self, module: Union['torch.nn.Module', List['torch.nn.Module'], Any], *args, **kwargs): ... 
+ + def unpatch(self, module: Union['torch.nn.Module', List['torch.nn.Module'], Any], *args, **kwargs): + raise NotImplementedError() diff --git a/src/twinkle/patch/megatron_emb.py b/src/twinkle/patch/megatron_emb.py new file mode 100644 index 00000000..3779feb2 --- /dev/null +++ b/src/twinkle/patch/megatron_emb.py @@ -0,0 +1,138 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Patch a Megatron causal LM into a sentence-embedding model. + +Two mutations applied to every pipeline-last-stage chunk (``post_process=True``): + +1. ``output_layer.forward`` (a ``ColumnParallelLinear``) is replaced with an + identity that returns ``(hidden_states, None)``. When ``sequence_parallel`` + is enabled, the gather across the TP group that ``ColumnParallelLinear`` + normally performs is mirrored, so the chunk's forward hook always sees a + full-length ``[s, b, h]`` tensor. +2. A forward hook on the chunk gathers across CP (when ``cp_size > 1``), + pools the last valid token (per-segment via ``packed_seq_params.cu_seqlens_q`` + for padding-free batches; per-row via ``position_ids`` for padded batches), + L2-normalises and returns ``[n_seqs, hidden]`` embeddings. + +Intermediate PP stages (``post_process=False``) are left untouched. + +Both mutations are reverted by ``unpatch``. 
+""" +from types import MethodType +from typing import List, Optional + +import torch +import torch.nn.functional as F + +from twinkle.patch import Patch +from twinkle.utils.torch_utils import gather_cp_load_balanced + + +def _last_valid_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor: + if position_ids.dim() == 3: + position_ids = position_ids[0] + valid = (position_ids >= 0).int() + seq_len = valid.shape[-1] + return seq_len - 1 - torch.fliplr(valid).argmax(dim=-1) + + +def _last_valid_from_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor: + seq_len = attention_mask.shape[1] + return seq_len - 1 - torch.fliplr(attention_mask).argmax(dim=1) + + +def _resolve_cp_group(module) -> Optional[object]: + cp_group = getattr(module, 'cp_group', None) + if cp_group is None: + pg = getattr(module, 'pg_collection', None) + cp_group = getattr(pg, 'cp', None) if pg is not None else None + return cp_group + + +def _output_embedding_hook(module, args, kwargs, output): + if not torch.is_tensor(output) or output.dim() != 3: + return output + + cp_group = _resolve_cp_group(module) + if cp_group is not None and cp_group.size() > 1: + output = gather_cp_load_balanced(output, cp_group, seq_dim=1) + + packed_seq_params = kwargs.get('packed_seq_params', None) + if packed_seq_params is not None: + cu = getattr(packed_seq_params, 'cu_seqlens_q', None) + if cu is not None and cu.numel() >= 2: + # cu is full-seq based (built before CP split), so it indexes the gathered output directly. 
+ last_idx = (cu[1:].long() - 1).to(output.device) + embeddings = output[0, last_idx] + return F.normalize(embeddings, p=2, dim=1).contiguous() + + position_ids = kwargs.get('position_ids', None) + attention_mask = kwargs.get('attention_mask', None) + if position_ids is not None and cp_group is not None and cp_group.size() > 1: + position_ids = gather_cp_load_balanced( + position_ids if position_ids.dim() >= 2 else position_ids.unsqueeze(0), + cp_group, + seq_dim=1, + ) + + if position_ids is not None: + last_idx = _last_valid_from_position_ids(position_ids) + elif attention_mask is not None and attention_mask.dim() == 2: + last_idx = _last_valid_from_attention_mask(attention_mask) + else: + last_idx = torch.full((output.shape[0],), output.shape[1] - 1, device=output.device, dtype=torch.long) + + last_idx = last_idx.to(device=output.device, dtype=torch.long) + embeddings = output[torch.arange(output.shape[0], device=output.device), last_idx] + return F.normalize(embeddings, p=2, dim=1).contiguous() + + +def _identity_output_layer(self, hidden_states, weight=None, runtime_gather_output=None, **kwargs): + # Mirror ColumnParallelLinear's seq-parallel gather so the hook sees full [s, b, h]. + if getattr(self, 'sequence_parallel', False): + from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region + hidden_states = gather_from_sequence_parallel_region( + hidden_states, tensor_parallel_output_grad=True, group=self.tp_group) + return hidden_states, None + + +def _iter_chunks(module) -> List[torch.nn.Module]: + if isinstance(module, (list, tuple)): + return [m for m in module if isinstance(m, torch.nn.Module)] + return [module] + + +def _find_post_process_owner(chunk: torch.nn.Module) -> Optional[torch.nn.Module]: + """Locate the GPTModel-like owner of ``output_layer`` inside a chunk. + + Walks all submodules so it transparently handles DDP/Float16Module/PeftModel wrappers. 
+ """ + for sub in chunk.modules(): + layer = getattr(sub, 'output_layer', None) + post_process = getattr(sub, 'post_process', None) + if isinstance(layer, torch.nn.Module) and (post_process is None or post_process): + return sub + return None + + +class MegatronEmbeddingPatch(Patch): + """Convert a Megatron causal LM into a sentence-embedding model. Reversible via ``unpatch``.""" + + def __call__(self, module, *args, **kwargs): + self._patched = [] + for chunk in _iter_chunks(module): + owner = _find_post_process_owner(chunk) + if owner is None: + continue + output_layer = owner.output_layer + origin_forward = output_layer.forward + output_layer.forward = MethodType(_identity_output_layer, output_layer) + hook_handle = owner.register_forward_hook(_output_embedding_hook, with_kwargs=True) + self._patched.append((output_layer, origin_forward, hook_handle)) + return module + + def unpatch(self, module, *args, **kwargs): + for output_layer, origin_forward, hook_handle in self._patched: + hook_handle.remove() + output_layer.forward = origin_forward + self._patched = [] + return module diff --git a/src/twinkle/patch/no_split_modules.py b/src/twinkle/patch/no_split_modules.py new file mode 100644 index 00000000..7d8aee58 --- /dev/null +++ b/src/twinkle/patch/no_split_modules.py @@ -0,0 +1,21 @@ +from typing import Set, Union + +from twinkle.patch import Patch + + +class NoSplitModulesPatch(Patch): + """Set _no_split_modules on a model so FSDP2 respects layer boundaries.""" + + def __init__(self, module_names: Union[Set[str], str] = frozenset({'Qwen3_5DecoderLayer'})): + if isinstance(module_names, str): + module_names = {module_names} + self._names = set(module_names) + + def __call__(self, module, *args, **kwargs): + module._no_split_modules = self._names + return module + + def unpatch(self, module, *args, **kwargs): + if hasattr(module, '_no_split_modules'): + del module._no_split_modules + return module diff --git a/src/twinkle/patch/qwen3_chat_template.py 
b/src/twinkle/patch/qwen3_chat_template.py index 822f8e8e..b2aa3e1c 100644 --- a/src/twinkle/patch/qwen3_chat_template.py +++ b/src/twinkle/patch/qwen3_chat_template.py @@ -51,6 +51,17 @@ ' {%- endif %}') +_OLD_TAIL = ( + '{%- if ns.multi_step_tool %}\n' + " {{- raise_exception('No user query found in messages.') }}\n" + '{%- endif %}') + +_NEW_TAIL = ( + '{%- if ns.multi_step_tool %}\n' + ' {#- patched: tool-tail prefix allowed (Qwen3AllowToolTailTemplate) -#}\n' + '{%- endif %}') + + class Qwen3ChatTemplate(Patch): """Patch tokenizer.chat_template in-place to fix Qwen3.x parse defects. @@ -81,3 +92,34 @@ def __call__(self, tokenizer, *args, **kwargs): return False tokenizer.chat_template = tmpl.replace(_OLD, _NEW, 1) return True + + +class Qwen3AllowToolTailTemplate(Patch): + """Relax Qwen3.x ``multi_step_tool`` check so prefixes ending in ``tool`` + (or whose only user messages are ```` wrappers) render + instead of raising ``No user query found in messages``. + + Required by ScoreFilter when scoring intermediate assistant turns of + multi-turn agent rollouts: the slice ``messages[:asst_idx]`` legitimately + ends with a ``tool`` message, and skipping such rounds would silently + discard exactly the turns where tool-call accuracy lives. + """ + + def __call__(self, tokenizer, *args, **kwargs): + tmpl = getattr(tokenizer, 'chat_template', None) + if not tmpl or not isinstance(tmpl, str): + return False + if _NEW_TAIL in tmpl: + return False + if _OLD_TAIL not in tmpl: + warnings.warn( + 'Qwen3AllowToolTailTemplate patch: expected OLD multi_step_tool ' + 'block not found in tokenizer.chat_template. Upstream template ' + 'may have diverged; skipping patch. 
ScoreFilter on multi-turn ' + 'agent prefixes will likely raise TemplateError.', + RuntimeWarning, + stacklevel=2, + ) + return False + tokenizer.chat_template = tmpl.replace(_OLD_TAIL, _NEW_TAIL, 1) + return True diff --git a/src/twinkle/patch/transformers_emb.py b/src/twinkle/patch/transformers_emb.py new file mode 100644 index 00000000..0e10da76 --- /dev/null +++ b/src/twinkle/patch/transformers_emb.py @@ -0,0 +1,85 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Patch a HF transformers causal LM into a sentence-embedding model. + +Two mutations applied to the model: + +1. ``lm_head.forward`` is replaced with identity, so the wrapped model returns + the final hidden states under ``output.logits``. +2. A forward hook on the lm-head-bearing submodule L2-normalizes per-token + hidden states and stores them under ``outputs['features']`` (shape + ``[B, T, H]`` or ``[B, T_local, H]`` under SP). + +Last-token pooling (incl. padding-free, SP gather) is **deferred** to +``InputProcessor.postprocess_tensor_sp(task='embedding', ...)`` so this patch +stays SP/CP/packed-agnostic and the dispatch sits in one place. + +Both mutations are reverted by ``unpatch``. 
+""" +from types import MethodType +from typing import TYPE_CHECKING, Optional +from twinkle.patch import Patch +if TYPE_CHECKING: + import torch + +_LM_HEADS = ['lm_head', 'output', 'embed_out', 'output_layer'] + + +def get_lm_head_model(module, lm_heads=None): + from peft import PeftModel + import torch + if isinstance(module, PeftModel): + module = module.model + if lm_heads is None: + lm_heads = _LM_HEADS + for sub in module.modules(): + for name in lm_heads: + child = getattr(sub, name, None) + if isinstance(child, torch.nn.Module): + return sub + return module + + +def _output_features_hook(module, args, kwargs, output): + import torch.nn.functional as F + hidden_states = output.logits + return {'features': F.normalize(hidden_states, p=2, dim=-1).contiguous()} + + +def _identity_forward(self, hidden_states): + return hidden_states + + +class TransformersEmbeddingPatch(Patch): + """Convert a causal LM into a sentence-embedding feature extractor. Reversible via ``unpatch``.""" + + def __call__(self, module, *args, **kwargs): + import torch + lm_head_model = get_lm_head_model(module, lm_heads=_LM_HEADS) + + head: Optional[torch.nn.Module] = None + for name in _LM_HEADS: + if hasattr(lm_head_model, name): + head = getattr(lm_head_model, name) + break + assert head is not None, 'Cannot find the proper lm_head name' + + # Save originals BEFORE mutation so unpatch can restore them verbatim. 
+ self._head = head + self._origin_forward = head.forward + head.forward = MethodType(_identity_forward, head) + self._hook_handle = lm_head_model.register_forward_hook(_output_features_hook, with_kwargs=True) + return module + + def unpatch(self, module, *args, **kwargs): + handle = getattr(self, '_hook_handle', None) + if handle is not None: + handle.remove() + self._hook_handle = None + + head = getattr(self, '_head', None) + origin = getattr(self, '_origin_forward', None) + if head is not None and origin is not None: + head.forward = origin + self._origin_forward = None + self._head = None + return module diff --git a/src/twinkle/preprocessor/base.py b/src/twinkle/preprocessor/base.py index 06ad06ba..0225d3c1 100644 --- a/src/twinkle/preprocessor/base.py +++ b/src/twinkle/preprocessor/base.py @@ -7,7 +7,9 @@ class Preprocessor: @staticmethod - def map_col_to_row(rows: Dict[str, List[Any]]) -> List[Dict[str, Any]]: + def map_col_to_row(rows) -> List[Dict[str, Any]]: + if isinstance(rows, list): + return rows if not rows: return [] _new_rows = [] @@ -20,12 +22,14 @@ def map_col_to_row(rows: Dict[str, List[Any]]) -> List[Dict[str, Any]]: return _new_rows @staticmethod - def map_row_to_col(rows: List[Dict[str, Any]]) -> Dict[str, List[Any]]: + def map_row_to_col(rows, keys: List[str] = None) -> Dict[str, List[Any]]: + if isinstance(rows, dict): + return rows if not rows: - return {} + return {k: [] for k in keys} if keys else {} columns: Dict[str, List[Any]] = {} - keys = rows[0].keys() + keys = keys or rows[0].keys() for key in keys: columns[key] = [row[key] for row in rows] diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index d6e1eed9..c27f42f9 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -142,12 +142,96 @@ def postprocess_tensor_sp(self, inputs: Dict[str, Any], outputs: Dict[str, Any], After this call, logps and labels are in per-sequence batch format ``[num_sequences, max_seq_len]`` when the input 
was packed, or left unchanged for normal (non-packed) batches. + + For ``task='embedding'`` this also performs the last-valid-token + pooling (with padding-free / SP gather awareness) and writes the + pooled ``[n_seqs, H]`` tensor to ``outputs['embeddings']``; the raw + per-token ``outputs['features']`` is consumed and removed. """ sp_strategy = kwargs.get('sp_strategy') + task = kwargs.get('task', 'causal_lm') + if task == 'embedding': + return self._postprocess_embedding(inputs, outputs, sp_strategy=sp_strategy) if self.framework == 'transformers' and sp_strategy is not None: return sp_strategy.gather_loss_tensors(inputs, outputs) return inputs, outputs + @staticmethod + def _packed_last_indices(position_ids: torch.Tensor, total_len: int) -> torch.Tensor: + """For padding-free batches: per-segment last-token indices into a [1, total] sequence.""" + flat = position_ids.squeeze(0) if position_ids.dim() == 2 else position_ids + starts = (flat == 0).nonzero(as_tuple=False).squeeze(-1) + end_anchor = torch.tensor([total_len], device=flat.device, dtype=starts.dtype) + boundaries = torch.cat([starts, end_anchor]) + return (boundaries[1:] - 1).long() + + def _postprocess_embedding(self, inputs: Dict[str, Any], outputs: Dict[str, Any], + sp_strategy=None) -> tuple[Dict[str, Any], Dict[str, Any]]: + """Pool per-token features to per-sequence embeddings (last-valid-token). + + Build a one-hot end-token mask in the un-padded global frame, route it + through the same pad+split as ``input_ids`` so it aligns with local + features, pool locally, then ``all_reduce`` only the ``[n_seqs, H]`` + tensor across SP × RP. No feature gather; uniform across + DP / Ulysses / zigzag-ring / padding-free. 
+ """ + from copy import copy + import torch.distributed as dist + + features = outputs.get('features') + assert features is not None + + sp_enabled = (self.framework == 'transformers' and sp_strategy is not None + and getattr(sp_strategy, 'enabled', False) + and getattr(sp_strategy, 'world_size', 1) > 1) + + ref_pos = sp_strategy.real_position_ids if sp_enabled else inputs['position_ids'] + if ref_pos.dim() == 3: + ref_pos = ref_pos[0] + cu_seq_lens_q = inputs.get('cu_seq_lens_q') + + is_packed = ( + features.shape[0] == 1 + and (cu_seq_lens_q is not None or int((ref_pos.reshape(-1) == 0).sum()) > 1)) + + device, dtype = features.device, features.dtype + T_real = ref_pos.shape[-1] + + if is_packed: + if torch.is_tensor(cu_seq_lens_q) and cu_seq_lens_q.numel() >= 2: + end_idx = (cu_seq_lens_q[1:].long() - 1).to(device) + else: + end_idx = self._packed_last_indices(ref_pos, T_real).to(device) + n_seqs = end_idx.shape[0] + mask = torch.zeros(1, T_real, n_seqs, dtype=dtype, device=device) + mask[0, end_idx, torch.arange(n_seqs, device=device)] = 1.0 + else: + B = ref_pos.shape[0] + end_idx = (ref_pos >= 0).long().sum(-1) - 1 + mask = torch.zeros(B, T_real, 1, dtype=dtype, device=device) + mask[torch.arange(B, device=device), end_idx, 0] = 1.0 + + if sp_enabled: + # Route mask through the same pad+split as input_ids to align with local features. 
+ rp = sp_strategy.real_position_ids + rp_padded = sp_strategy.pad(rp, padding_value=-1, position_ids=rp, dim=-1) + mask = sp_strategy.pad(mask, padding_value=0, position_ids=rp, dim=1) + mask = sp_strategy.split(mask, dim=1, position_ids=rp_padded) + + embeddings = (torch.einsum('th,tn->nh', features.squeeze(0), mask.squeeze(0)) + if is_packed else (features * mask).sum(dim=1)) + + if sp_enabled and dist.is_available() and dist.is_initialized(): + for grp_attr, size_attr in (('_sp_group', 'sp_world_size'), ('_rp_group', 'rp_world_size')): + grp = getattr(sp_strategy, grp_attr, None) + if grp is not None and getattr(sp_strategy, size_attr, 1) > 1: + dist.all_reduce(embeddings, op=dist.ReduceOp.SUM, group=grp) + + outputs = copy(outputs) + outputs.pop('features', None) + outputs['embeddings'] = embeddings.contiguous() + return inputs, outputs + def pad_cp(self, inputs: List[InputFeature], **kwargs) -> List[InputFeature]: if self.device_mesh is None: @@ -468,6 +552,7 @@ def unpack_packed_sequences( self, inputs: Dict[str, Any], outputs: Optional[Dict[str, Any]] = None, + task: str = 'causal_lm', ) -> tuple[Dict[str, Any], Optional[Dict[str, Any]]]: """Unpack packed (padding_free) sequences into per-sequence batch format. @@ -475,7 +560,12 @@ def unpack_packed_sequences( Unpacks ``labels`` and any present output keys (``logps``, ``logits``) from ``[1, total_tokens, ...]`` to ``[num_sequences, max_seq_len, ...]``. Keys that are ``None`` are silently skipped. + + For ``task='embedding'`` the outputs are already pooled to ``[n_seqs, H]`` + by ``postprocess_tensor_sp``, so this is a no-op. """ + if task == 'embedding': + return inputs, outputs labels = inputs.get('labels') position_ids = inputs.get('position_ids') @@ -645,46 +735,13 @@ def collate_fn(self, return outputs def postprocess_tensor_cp(self, tensor): - """All-gather and reconstruct full sequence from CP-split tensor. 
- - Uses load-balanced split pattern: each CP rank holds chunks [rank] and - [2*cp_size - rank - 1] from the original 2*cp_size chunks. + """All-gather and reconstruct full sequence from a CP load-balanced shard. - Only the current rank's slice retains the original tensor (and its - gradient graph); other ranks' slices are plain copies. This means - backward through the reconstructed tensor only produces gradients for - the local chunk, naturally distributing the gradient across CP ranks - without extra scaling. - - Args: - tensor: [batch_size, seq_len/cp_size] CP-split tensor - - Returns: - [batch_size, full_seq_len] reconstructed full tensor + Thin wrapper over :func:`twinkle.utils.torch_utils.gather_cp_load_balanced` + that resolves the CP group via Megatron's ``parallel_state``. """ if self.device_mesh.cp_world_size <= 1: return tensor - from megatron.core import parallel_state as mpu - cp_size = mpu.get_context_parallel_world_size() - cp_rank = mpu.get_context_parallel_rank() - cp_group = mpu.get_context_parallel_group() - - gathered = [torch.empty_like(tensor) for _ in range(cp_size)] - torch.distributed.all_gather(gathered, tensor.contiguous(), group=cp_group) - gathered[cp_rank] = tensor - - batch_size = tensor.shape[0] - seq_len_per_cp = tensor.shape[1] - full_seq_len = seq_len_per_cp * cp_size - chunk_len = full_seq_len // (2 * cp_size) - half_len = seq_len_per_cp // 2 - - output = tensor.new_zeros(batch_size, full_seq_len) - for j in range(cp_size): - o = gathered[j] - output[:, j * chunk_len:(j + 1) * chunk_len] = o[:, :half_len] - reverse_idx = 2 * cp_size - j - 1 - output[:, reverse_idx * chunk_len:(reverse_idx + 1) * chunk_len] = o[:, half_len:] - - return output + from twinkle.utils.torch_utils import gather_cp_load_balanced + return gather_cp_load_balanced(tensor, mpu.get_context_parallel_group(), seq_dim=1) diff --git a/src/twinkle/sampler/base.py b/src/twinkle/sampler/base.py index d8222ead..e0c012a2 100644 --- a/src/twinkle/sampler/base.py +++ 
b/src/twinkle/sampler/base.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from abc import ABC, abstractmethod from peft import PeftConfig -from typing import Any, List, Optional, Type, Union +from typing import Any, AsyncIterator, Dict, List, Optional, Type, Union import twinkle from twinkle import remote_function @@ -47,6 +47,25 @@ def sample( def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs) -> None: ... + def astream_one( + self, + trajectory: Trajectory, + sampling_params: Optional[SamplingParams] = None, + adapter_name: str = '', + adapter_path: Optional[str] = None, + *, + use_base_model: bool = False, + ) -> AsyncIterator[Dict[str, Any]]: + """Stream OpenAI-shape delta chunks for a single trajectory. + + Default implementation raises ``NotImplementedError``; backend samplers + opt in by overriding (e.g. ``vLLMSampler``). + + Yields: + Dicts shaped ``{'index': int, 'delta': {...}, 'finish_reason': ...}``. + """ + raise NotImplementedError(f'{type(self).__name__} does not support streaming') + @staticmethod def _not_encoded(inputs: Any) -> bool: """Check if inputs are not yet encoded (i.e., is Trajectory, not InputFeature). diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 4965f7b3..29b1a73c 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -337,6 +337,57 @@ async def sample(self, topk_prompt_logprobs=result_topk_prompt_logprobs, ) + async def astream(self, + prompt: Union[List[int], str], + sampling_params: Union[SamplingParams, Dict[str, Any]], + lora_request: Optional[Any] = None, + request_id: Optional[str] = None, + priority: int = 0, + *, + multi_modal_data: Optional[Dict[str, Any]] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + disable_lora: bool = False, + **kwargs): + """Streaming counterpart of :meth:`sample`. 
Yields raw vLLM ``RequestOutput`` + deltas as they arrive from the engine — no aggregation. + + Caller is responsible for diffing token_ids across frames. + """ + from vllm.inputs import TextPrompt, TokensPrompt + + if isinstance(sampling_params, dict): + sampling_params = SamplingParams.from_dict(sampling_params) + vllm_params = sampling_params.to_vllm(**kwargs) + + if request_id is None: + request_id = uuid.uuid4().hex + if isinstance(prompt, str): + prompt = TextPrompt(prompt=prompt) + else: + prompt = TokensPrompt(prompt_token_ids=prompt) + if multi_modal_data: + prompt['multi_modal_data'] = multi_modal_data + if mm_processor_kwargs: + prompt['mm_processor_kwargs'] = mm_processor_kwargs + + if lora_request is not None and not self.enable_lora: + logger.warning('lora_request provided but enable_lora is False — ignored') + lora_request = None + if disable_lora: + lora_request = None + elif lora_request is None and self._synced_lora_request is not None: + lora_request = self._synced_lora_request + + generator = self.engine.generate( + prompt=prompt, + sampling_params=vllm_params, + request_id=request_id, + lora_request=lora_request, + priority=priority, + ) + async for output in generator: + yield output + # ----------------------------------------------------------------- # RL-training synced LoRA helpers # ----------------------------------------------------------------- diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 79db15db..b5706530 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -25,9 +25,9 @@ import os import threading from copy import copy -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, AsyncIterator, Dict, List, Optional, Type, Union -from twinkle import DeviceMesh, get_logger, remote_class, remote_function, requires +from twinkle import DeviceMesh, get_logger, remote_class, 
remote_function, remote_generator, requires from twinkle.checkpoint_engine import CheckpointEngineMixin from twinkle.data_format import InputFeature, SampledSequence, SampleResponse, SamplingParams, Trajectory from twinkle.hub import HubOperation @@ -251,6 +251,7 @@ async def _sample_single( else: feat['input_ids'] = response.prompt_token_ids feat['labels'] = [-100] * len(response.prompt_token_ids) + if not logprobs_only: # response.sequences contains num_samples sequences for this prompt sequences = [] @@ -333,13 +334,12 @@ def sample( sampling_params = copy(sampling_params) sampling_params.max_tokens = 1 logprobs_only = True - assert not is_trajectory, 'Logprobs only not supported for Trajectory inputs' multi_modal_data_list = [] for feat in inputs_list: multi_modal_data_list.append(self._extract_multi_modal_data(feat)) - if is_trajectory and not logprobs_only: + if is_trajectory: template = self.template assert template is not None, \ 'Use set_template to add a template when trying to input Trajectory' @@ -375,6 +375,121 @@ async def _sample_all(): sample_results = self._run_in_loop(_sample_all()) return sample_results + @remote_generator(execute='balanced') + async def astream_one( + self, + trajectory: Trajectory, + sampling_params: Optional[Union[SamplingParams, Dict[str, Any]]] = None, + adapter_name: str = '', + adapter_path: Optional[str] = None, + *, + use_base_model: bool = False, + ) -> AsyncIterator[Dict[str, Any]]: + """Stream OpenAI-shape deltas for a single trajectory. + + Single-trajectory only: routed to one DP actor by ``@remote_generator`` + (see decorator), so DP slicing / NCCL collective constraints do not apply. + + Yields dicts of shape:: + + {'index': int, 'delta': {...}, 'finish_reason': None | 'stop' | 'tool_calls' | 'length'} + + Where ``delta`` is one of ``{'role':'assistant'}``, ``{'content': str}``, + ``{'tool_calls': [{...}]}``, or ``{}`` (final frame). The handler layer + wraps these into ``chat.completion.chunk`` envelopes for SSE. 
+ """ + if sampling_params is None: + sampling_params = SamplingParams() + elif isinstance(sampling_params, dict): + sampling_params = SamplingParams.from_dict(sampling_params) + + assert isinstance(trajectory, dict) and 'input_ids' not in trajectory, \ + 'astream_one accepts a single Trajectory (not InputFeature / not a list)' + assert self.template is not None, 'set_template must be called before streaming' + + multi_modal_data = self._extract_multi_modal_data(trajectory) + feat = self.encode_trajectory_for_vllm(trajectory, adapter_name, True) + + lora_request = None + if adapter_path is not None: + adapter_path = HubOperation.download_model(model_id_or_path=adapter_path) + lora_request = self._run_in_loop(self.engine._get_or_load_lora(adapter_path)) + if lora_request is None: + logger.warning(f'Failed to pre-load LoRA from {adapter_path}, streaming will run without LoRA') + + # vLLM AsyncLLM lives on self._async_loop (background thread); the actor + # method runs on Ray's actor loop. Bridge frames via a per-call queue. 
+ ray_loop = asyncio.get_event_loop() + out_queue: asyncio.Queue = asyncio.Queue(maxsize=64) + _SENTINEL = object() + _ERR_KIND = '__err__' + + async def _producer(): + try: + async for output in self.engine.astream( + prompt=self.template.get_vllm_input_ids(feat['input_ids']), + sampling_params=sampling_params, + lora_request=lora_request, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=feat.get('mm_processor_kwargs'), + disable_lora=use_base_model, + ): + asyncio.run_coroutine_threadsafe( + out_queue.put(('chunk', output)), ray_loop).result() + except BaseException as _e: # noqa: BLE001 + asyncio.run_coroutine_threadsafe( + out_queue.put((_ERR_KIND, _e)), ray_loop).result() + finally: + asyncio.run_coroutine_threadsafe( + out_queue.put((_SENTINEL, None)), ray_loop).result() + + asyncio.run_coroutine_threadsafe(_producer(), self._async_loop) + + seq_state: Dict[int, Dict[str, Any]] = {} + role_emitted: Dict[int, bool] = {} + finished: Dict[int, bool] = {} + had_tool_call: Dict[int, bool] = {} + template = self.template + + while True: + kind, payload = await out_queue.get() + if kind is _SENTINEL: + break + if kind == _ERR_KIND: + raise payload + request_output = payload + for seq_output in request_output.outputs: + idx = getattr(seq_output, 'index', 0) + # Sampler owns last_text_len; tc_state is opaque template state. 
+ state = seq_state.setdefault(idx, {'last_text_len': 0, 'tc_state': {}}) + if not role_emitted.get(idx): + yield {'index': idx, 'delta': {'role': 'assistant'}, 'finish_reason': None} + role_emitted[idx] = True + + full_text = template.decode(list(seq_output.token_ids)) + delta_text = '' + if len(full_text) > state['last_text_len']: + delta_text = full_text[state['last_text_len']:] + state['last_text_len'] = len(full_text) + + is_finished = bool(seq_output.finish_reason) and not finished.get(idx) + if delta_text or is_finished: + for ev in template.parse_tool_call_stream( + state['tc_state'], delta_text, finished=is_finished): + if 'tool_calls' in ev: + had_tool_call[idx] = True + yield {'index': idx, 'delta': ev, 'finish_reason': None} + + if is_finished: + if seq_output.finish_reason == 'length': + fr = 'length' + elif had_tool_call.get(idx): + fr = 'tool_calls' + else: + fr = 'stop' + yield {'index': idx, 'delta': {}, 'finish_reason': fr} + finished[idx] = True + @remote_function(dispatch='all', collect='first') def sleep(self, level: int = 1) -> None: """ diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index b27ec23d..5f90e39b 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -6,9 +6,13 @@ """ from __future__ import annotations +import json +import time import traceback +import uuid from fastapi import Depends, FastAPI, HTTPException, Request -from typing import TYPE_CHECKING, Callable +from fastapi.responses import StreamingResponse +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple from twinkle_client.common.serialize import deserialize_object @@ -49,6 +53,130 @@ def _get_twinkle_sampler_adapter_name(request: Request, adapter_name: str | None return request.state.request_id + '-' + adapter_name +def _openai_body_to_trajectory_and_params( + body: Dict[str, Any]) -> Tuple[Trajectory, SamplingParams]: + """Map an OpenAI 
``/v1/chat/completions`` body to (Trajectory, SamplingParams). + + Trajectory.messages / .tools are already OpenAI-shaped TypedDicts, so they + pass through verbatim — no field renaming needed. + """ + messages = body.get('messages') + if not messages: + raise HTTPException(status_code=400, detail='messages is required') + trajectory: Trajectory = {'messages': list(messages)} + if body.get('tools'): + trajectory['tools'] = list(body['tools']) + + sp_kwargs: Dict[str, Any] = {} + if body.get('temperature') is not None: + sp_kwargs['temperature'] = float(body['temperature']) + if body.get('top_p') is not None: + sp_kwargs['top_p'] = float(body['top_p']) + # max_completion_tokens supersedes max_tokens per the newer OpenAI spec + if body.get('max_completion_tokens') is not None: + sp_kwargs['max_tokens'] = int(body['max_completion_tokens']) + elif body.get('max_tokens') is not None: + sp_kwargs['max_tokens'] = int(body['max_tokens']) + if body.get('seed') is not None: + sp_kwargs['seed'] = int(body['seed']) + if body.get('n') is not None: + sp_kwargs['num_samples'] = int(body['n']) + if body.get('stop'): + sp_kwargs['stop'] = body['stop'] + if body.get('logprobs'): + sp_kwargs['logprobs'] = int(body.get('top_logprobs') or 0) + if body.get('prompt_logprobs') is not None: + sp_kwargs['prompt_logprobs'] = int(body['prompt_logprobs']) + fp = body.get('frequency_penalty') + if fp is not None and fp != 0: + # OpenAI frequency_penalty (-2..2, 0 == no penalty) -> repetition_penalty + sp_kwargs['repetition_penalty'] = 1.0 + float(fp) + return trajectory, SamplingParams(**sp_kwargs) + + +def _format_openai_choice(seq: Any, idx: int, template: Any) -> Dict[str, Any]: + """Build one ``choices[]`` entry from a SampledSequence.""" + decoded = seq.decoded or '' + tool_calls: List[Dict[str, Any]] = [] + if template is not None: + try: + parsed = template.parse_tool_call(decoded) + except Exception: + parsed = [] + for j, tc in enumerate(parsed or []): + fn = dict(tc.get('function') or 
{}) + args = fn.get('arguments') + # OpenAI wire format demands arguments as a JSON string, not a dict + if isinstance(args, dict): + fn['arguments'] = json.dumps(args, ensure_ascii=False) + tool_calls.append({ + 'id': tc.get('id') or f'call_{idx}_{j}', + 'type': tc.get('type') or 'function', + 'function': fn, + }) + if tool_calls: + try: + decoded = template.clean_tool_call(decoded) + except Exception: + pass + + finish_reason = 'length' if seq.stop_reason == 'length' else ( + 'tool_calls' if tool_calls else 'stop') + message: Dict[str, Any] = {'role': 'assistant', 'content': decoded} + if tool_calls: + message['tool_calls'] = tool_calls + choice: Dict[str, Any] = {'index': idx, 'message': message, 'finish_reason': finish_reason} + if seq.logprobs: + choice['logprobs'] = {'token_logprobs': [lp[0][1] if lp else None for lp in seq.logprobs]} + return choice + + +def _build_openai_completion( + response: Any, model_id: str, template: Any) -> Dict[str, Any]: + """Wrap a SampleResponse as an OpenAI ChatCompletion object.""" + choices = [ + _format_openai_choice(seq, i, template) + for i, seq in enumerate(response.sequences) + ] + completion_tokens = sum(len(seq.tokens) for seq in response.sequences) + result: Dict[str, Any] = { + 'id': f'chatcmpl-{uuid.uuid4().hex}', + 'object': 'chat.completion', + 'created': int(time.time()), + 'model': model_id, + 'choices': choices, + 'usage': { + 'prompt_tokens': len(response.prompt_token_ids or []), + 'completion_tokens': completion_tokens, + 'total_tokens': len(response.prompt_token_ids or []) + completion_tokens, + }, + } + if response.prompt_logprobs is not None: + result['prompt_logprobs'] = response.prompt_logprobs + return result + + +def _build_openai_chunk( + delta_event: Dict[str, Any], completion_id: str, created: int, + model_id: str) -> Dict[str, Any]: + """Wrap a sampler delta dict as an OpenAI ``chat.completion.chunk`` object. 
+ + ``delta_event`` is one item yielded by ``Sampler.astream_one``, with keys + ``index``, ``delta``, ``finish_reason``. + """ + return { + 'id': completion_id, + 'object': 'chat.completion.chunk', + 'created': created, + 'model': model_id, + 'choices': [{ + 'index': delta_event.get('index', 0), + 'delta': delta_event.get('delta') or {}, + 'finish_reason': delta_event.get('finish_reason'), + }], + } + + def _register_twinkle_sampler_routes(app: FastAPI, self_fn: Callable[[], SamplerManagement]) -> None: """Register all /twinkle/* sampler routes on the given FastAPI app. @@ -157,6 +285,118 @@ async def _task(): task_type='sample', )) + @app.post('/v1/chat/completions') + async def chat_completions( + request: Request, + body: Dict[str, Any], + self: SamplerManagement = Depends(self_fn), + ): + """OpenAI-compatible chat completions endpoint. + + Accepts the standard ``/v1/chat/completions`` body (messages, tools, + temperature, top_p, max_tokens, n, seed, stop, frequency_penalty, + logprobs/top_logprobs, ...) and returns an OpenAI ``chat.completion`` + response. Twinkle-specific extensions: ``adapter_name`` and + ``adapter_uri`` for LoRA inference. When ``stream=true`` is set the + response is an SSE stream of ``chat.completion.chunk`` objects. + """ + # Flatten extra_body so Twinkle extras (adapter_name/adapter_uri/...) are + # accessible regardless of whether the OpenAI SDK already inlined them. 
+ extra = body.pop('extra_body', None) + if isinstance(extra, dict): + for k, v in extra.items(): + body.setdefault(k, v) + + token = await self._on_request_start(request) + + # Resolve adapter (shared by stream / non-stream paths) + async def _resolve_adapter() -> Tuple[str, Any]: + adapter_path = None + adapter_name = body.get('adapter_name') or '' + full_adapter_name = _get_twinkle_sampler_adapter_name(request, adapter_name) or '' + adapter_uri = body.get('adapter_uri') + if adapter_uri: + from twinkle.server.common.checkpoint_factory import create_checkpoint_manager + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + _, adapter_path = checkpoint_manager.parse_adapter_uri(adapter_uri) + self.sampler.reset_prefix_cache() + return full_adapter_name, adapter_path + + if body.get('stream'): + # Streaming path: bypass the GPU serial queue entirely. Each request + # opens a single async generator on a balanced DP actor and pipes + # chat.completion.chunk events back as SSE. 
+ full_adapter_name, adapter_path = await _resolve_adapter() + trajectory, params = _openai_body_to_trajectory_and_params(body) + model_id = body.get('model') or getattr(self, 'model_id', '') or '' + completion_id = f'chatcmpl-{uuid.uuid4().hex}' + created = int(time.time()) + + async def _sse_generator(): + try: + async for event in self.sampler.astream_one( + trajectory, + params, + adapter_name=full_adapter_name, + adapter_path=adapter_path, + ): + chunk = _build_openai_chunk(event, completion_id, created, model_id) + yield f'data: {json.dumps(chunk, ensure_ascii=False)}\n\n' + yield 'data: [DONE]\n\n' + except HTTPException: + raise + except Exception: + err_tb = traceback.format_exc() + logger.error(err_tb) + err_chunk = { + 'id': completion_id, + 'object': 'chat.completion.chunk', + 'created': created, + 'model': model_id, + 'error': {'message': err_tb, 'type': 'internal_error'}, + } + yield f'data: {json.dumps(err_chunk, ensure_ascii=False)}\n\n' + yield 'data: [DONE]\n\n' + + return StreamingResponse( + _sse_generator(), + media_type='text/event-stream', + headers={ + 'Cache-Control': 'no-cache', + 'X-Accel-Buffering': 'no', + }, + ) + + async def _task(): + full_adapter_name, adapter_path = await _resolve_adapter() + trajectory, params = _openai_body_to_trajectory_and_params(body) + + responses = self.sampler.sample( + [trajectory], + params, + adapter_name=full_adapter_name, + adapter_path=adapter_path, + ) + + return _build_openai_completion( + responses[0], + model_id=body.get('model') or getattr(self, 'model_id', '') or '', + template=getattr(self.sampler, 'template', None), + ) + + # Rough char-based estimate for queue scheduling; trajectory tokens are unknown pre-encode + rough_tokens = sum( + len(m.get('content') or '') if isinstance(m.get('content'), str) else 0 + for m in (body.get('messages') or []) + ) // 4 + return await run_task( + self.schedule_task_and_wait( + _task, + token=token, + input_tokens=rough_tokens, + task_type='sample', + )) + 
@app.post('/twinkle/set_template', response_model=types.SetTemplateResponse) async def set_template( request: Request, diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py index b1ab1d21..6c4bdddd 100644 --- a/src/twinkle/template/__init__.py +++ b/src/twinkle/template/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import Template from .deepseek_v4 import DeepseekV4Template -from .qwen import QwenTemplate from .qwen3_5_vl import Qwen3_5Template diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 50ba3e5e..0512c360 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -1,15 +1,17 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import inspect +import json import numpy as np import os from collections.abc import Mapping from copy import copy, deepcopy -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Union from twinkle import remote_class from twinkle.data_format import InputFeature, Message, Trajectory from twinkle.hub import HubOperation from twinkle.utils import load_image, to_device +from .tools import ToolCallRegistry, trailing_prefix_of from .utils import TokenizeByRound, transfer_to_standard_message if TYPE_CHECKING: @@ -69,25 +71,139 @@ def __init__(self, def parse_tool_call(self, decoded: str) -> List[Dict[str, Any]]: """Parse tool calls from the assistant's decoded output. - Dispatches by model family on ``self.model_id``; the actual - wire-format logic lives in :mod:`.tool_call_parser`. + Polls registered :class:`ToolCallParser` in order; first parser whose + ``detect`` matches takes ownership and produces the result. Other + parsers are not invoked on the same text — prevents nested re-extraction. 
""" - mid = (self.model_id or '').lower() - if 'qwen' in mid: - from .qwen import QwenTemplate - return QwenTemplate.parse(self, decoded) - # TODO: Other models (Llama3, OpenAI JSON, …) — add a parser in - # ``tool_call_parser.py`` and extend this dispatch. - return [] + parser = ToolCallRegistry.detect_first(decoded or '') + return parser.parse(decoded) if parser else [] def clean_tool_call(self, decoded: str) -> str: - """Strip family-specific tool-call markup from assistant text.""" - mid = (self.model_id or '').lower() - if 'qwen' in mid: - from .qwen import QwenTemplate - return QwenTemplate.clean(self, decoded) - # TODO: Other models - return (decoded or '').rstrip() + """Strip tool-call markup using the same parser that ``parse_tool_call`` would pick.""" + parser = ToolCallRegistry.detect_first(decoded or '') + return parser.clean(decoded) if parser else (decoded or '').rstrip() + + def parse_tool_call_stream( + self, + state: Dict[str, Any], + new_text: str, + finished: bool = False, + ) -> List[Dict[str, Any]]: + """Convert incremental decoded text into OpenAI streaming ``delta`` parts. + + Selects a parser once (cached on ``state``) by ``model_id``. If that + parser declares ``open_marker``/``close_marker`` (e.g. Hermes/Qwen), + runs the generic block-buffer state machine: holds back partial + markers, parses each closed block via ``parser.parse``, emits one + ``tool_calls`` delta per parsed call. Otherwise streams plain content. + + Args: + state: Per-sequence opaque dict; caller allocates ``{}`` once. + new_text: Incremental decoded text since the previous call. + finished: True on the final call so partial buffers can flush. + + Returns: + List of delta dicts; each carries at most one of ``content`` / + ``tool_calls``. 
+ """ + parser = state.get('parser') + if 'parser' not in state: + parser = ToolCallRegistry.select_for_model(self.model_id) + state['parser'] = parser + if parser is None or not parser.open_marker: + return [{'content': new_text}] if new_text else [] + return self._stream_marker_blocks(state, new_text, finished, parser) + + def _stream_marker_blocks( + self, + state: Dict[str, Any], + new_text: str, + finished: bool, + parser, + ) -> List[Dict[str, Any]]: + """Generic open/close marker streaming protocol. + + Buffers partial markup until ``parser.close_marker`` arrives, then + parses the block via ``parser.parse``. Used by Hermes/Qwen and any + future block-style format (Mistral ``[TOOL_CALLS]``, etc.). + """ + open_marker, close_marker = parser.open_marker, parser.close_marker + state.setdefault('pending', '') + state.setdefault('tc_count', 0) + if new_text: + state['pending'] += new_text + + events: List[Dict[str, Any]] = [] + while True: + buf = state['pending'] + if not buf: + break + open_idx = buf.find(open_marker) + if open_idx == -1: + partial = 0 if finished else trailing_prefix_of(buf, open_marker) + emit = buf[:-partial] if partial else buf + state['pending'] = buf[-partial:] if partial else '' + if emit: + events.append({'content': emit}) + break + if open_idx > 0: + events.append({'content': buf[:open_idx]}) + state['pending'] = buf[open_idx:] + continue + close_idx = buf.find(close_marker) + if close_idx == -1: + if finished: + # EOF with unclosed block — let parser.parse handle the truncation. 
+ try: + parsed = parser.parse(buf) or [] + except Exception: + import logging + logging.getLogger(__name__).exception( + 'tool-call parse failed for unclosed streamed block; emitting as raw content') + events.append({'content': buf}) + state['pending'] = '' + break + if parsed: + for tc in parsed: + events.append({'tool_calls': [self._format_tc_delta(state, tc)]}) + else: + events.append({'content': buf}) + state['pending'] = '' + break + block_end = close_idx + len(close_marker) + block = buf[:block_end] + try: + parsed = parser.parse(block) or [] + except Exception: + logger.warn( + 'tool-call parse failed for streamed block; emitting as raw content') + events.append({'content': block}) + state['pending'] = buf[block_end:] + continue + for tc in parsed: + events.append({'tool_calls': [self._format_tc_delta(state, tc)]}) + state['pending'] = buf[block_end:] + return events + + @staticmethod + def _format_tc_delta(state: Dict[str, Any], tc: Dict[str, Any]) -> Dict[str, Any]: + """Format a parsed tool_call dict as an OpenAI streaming delta entry. + + ``arguments`` is encoded as JSON string for the wire format (OpenAI + streaming spec); ``index`` and ``id`` are auto-assigned from ``state``. + """ + fn = dict(tc.get('function') or {}) + args = fn.get('arguments') + if isinstance(args, dict): + fn['arguments'] = json.dumps(args, ensure_ascii=False) + delta = { + 'index': state['tc_count'], + 'id': tc.get('id') or f'call_{state["tc_count"]}', + 'type': tc.get('type') or 'function', + 'function': fn, + } + state['tc_count'] += 1 + return delta @property def tokenizer(self): @@ -239,6 +355,10 @@ def _extract_reasoning_content(messages: list[Message]) -> List[Message]: message['reasoning_content'] = reasoning_content message['content'] = new_content + # Always emit string (never None/missing) — keeps PyArrow struct schema + # stable across shards; empty string renders identically to None in jinja. 
+ if not isinstance(message.get('reasoning_content'), str): + message['reasoning_content'] = '' result.append(message) @@ -489,6 +609,20 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo k: v for k, v in b.items() if v is not None } for b in msg['content'] if isinstance(b, dict)] + for msg in messages: + tcs = msg.get('tool_calls') + if isinstance(tcs, str): + tcs = json.loads(tcs) if tcs else [] + msg['tool_calls'] = tcs + if not tcs: + continue + new_tcs = [] + for tc in tcs: + fn = tc['function'] + args = fn['arguments'] + decoded = json.loads(args) if args.strip() else {} + new_tcs.append({**tc, 'function': {**fn, 'arguments': decoded}}) + msg['tool_calls'] = new_tcs # ``tool_calls`` / ``tools`` are already OpenAI-shaped (see # :mod:`twinkle.data_format.message`); pass them through verbatim. tools = list(trajectory.get('tools') or []) @@ -544,10 +678,23 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo **kwargs) return inputs + @staticmethod + def _get_train_indices(trajectory: Trajectory) -> Optional[Set[int]]: + """Extract key-round assistant indices from trajectory's ``user_data``.""" + user_data = trajectory.get('user_data') + if not isinstance(user_data, dict): + return None + key_rounds = user_data.get('key_rounds') + if not isinstance(key_rounds, list) or not key_rounds: + return None + return set(key_rounds) or None + def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs) -> InputFeature: """Encode a single trajectory's messages into InputFeature.""" labels = None input_ids = None + # key-round selective training + train_indices = self._get_train_indices(trajectory) if not add_generation_prompt else None if self.use_chat_template: if add_generation_prompt: # For inference: just get input_ids with generation prompt, no labels needed @@ -557,6 +704,14 @@ def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool = if 
hasattr(input_ids, 'squeeze'): input_ids = input_ids.squeeze(0) labels = np.full_like(input_ids, -100) # No labels for inference + elif train_indices is not None: + # key-round-only: always use TokenizeByRound with filtered indices + if kwargs.get('tokenize', True): + input_ids, labels, encoded = TokenizeByRound.tokenize_with_assistant_labels( + self.tokenizer, self._apply_chat_template, trajectory, + train_indices=train_indices, **kwargs) + else: + encoded = self._apply_chat_template(trajectory, **kwargs) elif self._template_support_assistant_tokens_mask: encoded = self._apply_chat_template( trajectory, return_assistant_tokens_mask=kwargs.get('tokenize', True), **kwargs) diff --git a/src/twinkle/template/qwen3_5_vl.py b/src/twinkle/template/qwen3_5_vl.py index c8332f49..e0a4487d 100644 --- a/src/twinkle/template/qwen3_5_vl.py +++ b/src/twinkle/template/qwen3_5_vl.py @@ -7,8 +7,7 @@ from twinkle import remote_class, requires from twinkle.data_format import InputFeature -from twinkle.template.base import ImageInput, VideoInput -from twinkle.template.qwen import QwenTemplate +from twinkle.template.base import ImageInput, Template, VideoInput from twinkle.template.utils import get_inputs_embeds_hf _ROPE_INDEX_CACHE: Dict[str, Callable] = {} @@ -31,7 +30,7 @@ def _build_rope_index_func(config) -> Callable: @remote_class() -class Qwen3_5Template(QwenTemplate): +class Qwen3_5Template(Template): """ Processor for Qwen VL series. @@ -44,8 +43,16 @@ def __init__(self, *args, **kwargs): # Fix upstream Qwen3 chat_template parse bugs (orphan
handling). # Deferred import to avoid cycles; idempotent across Ray actor re-init. from twinkle.patch import apply_patch - from twinkle.patch.qwen3_chat_template import Qwen3ChatTemplate + from twinkle.patch.qwen3_chat_template import ( + Qwen3AllowToolTailTemplate, Qwen3ChatTemplate) apply_patch(self.tokenizer, Qwen3ChatTemplate) + # Allow ScoreFilter to render multi-turn agent prefixes ending in `tool`. + apply_patch(self.tokenizer, Qwen3AllowToolTailTemplate) + # Qwen3VLProcessor carries its own chat_template; _apply_chat_template + # routes through self.processor, so the patch must be applied there too. + if self.processor is not self.tokenizer: + apply_patch(self.processor, Qwen3ChatTemplate) + apply_patch(self.processor, Qwen3AllowToolTailTemplate) self._patch_size: Optional[int] = None self._merge_size: Optional[int] = None self._init_vision_config() diff --git a/src/twinkle/template/tools/__init__.py b/src/twinkle/template/tools/__init__.py new file mode 100644 index 00000000..243774bd --- /dev/null +++ b/src/twinkle/template/tools/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tool-call parser registry. + +Importing this package auto-registers every parser. Order matters: +narrower / stronger formats first so round-robin detection prefers them +over weaker fallbacks. +""" +from .base import ToolCallParser, ToolCallRegistry, trailing_prefix_of +from .cline import ClineParser +from .qwen import HermesQwenParser +from .react import ReActParser +from .vcp import VCPParser + +# Order: strongest/most-specific markers first. Hermes owns ```` +# (also denied by Cline), so its detection wins for shared-XML inputs. 
+ToolCallRegistry.register(HermesQwenParser()) +ToolCallRegistry.register(ClineParser()) +ToolCallRegistry.register(VCPParser()) +ToolCallRegistry.register(ReActParser()) + +__all__ = [ + 'ToolCallParser', + 'ToolCallRegistry', + 'trailing_prefix_of', + 'HermesQwenParser', + 'ClineParser', + 'VCPParser', + 'ReActParser', +] diff --git a/src/twinkle/template/tools/base.py b/src/twinkle/template/tools/base.py new file mode 100644 index 00000000..fd94206c --- /dev/null +++ b/src/twinkle/template/tools/base.py @@ -0,0 +1,79 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + + +class ToolCallParser(ABC): + """Single-format tool-call parser.""" + + name: str = '' + open_marker: Optional[str] = None + close_marker: Optional[str] = None + + def matches_model(self, model_id: str) -> bool: + """Return True if this parser is the canonical choice for ``model_id``. + + Used for streaming where we must commit to a parser before any text + has arrived. Default False — parser is text-detection-only. + """ + return False + + @abstractmethod + def detect(self, text: str) -> bool: + """Cheap pre-check: does ``text`` carry this format's markup?""" + + @abstractmethod + def parse(self, text: str) -> List[Dict[str, Any]]: + """Return OpenAI-shape tool_calls. 
``arguments`` is a dict (jinja-friendly).""" + + @abstractmethod + def clean(self, text: str) -> str: + """Strip parser-specific markup; return plain content text.""" + + +class ToolCallRegistry: + """Global ordered registry of :class:`ToolCallParser` instances.""" + + _parsers: List[ToolCallParser] = [] + + @classmethod + def register(cls, parser: ToolCallParser) -> ToolCallParser: + for p in cls._parsers: + if p.name == parser.name: + return p + cls._parsers.append(parser) + return parser + + @classmethod + def parsers(cls) -> List[ToolCallParser]: + return list(cls._parsers) + + @classmethod + def select_for_model(cls, model_id: Optional[str]) -> Optional[ToolCallParser]: + mid = (model_id or '').lower() + for p in cls._parsers: + if p.matches_model(mid): + return p + return None + + @classmethod + def detect_first(cls, text: str) -> Optional[ToolCallParser]: + if not text: + return None + for p in cls._parsers: + if p.detect(text): + return p + return None + + +def trailing_prefix_of(buf: str, marker: str) -> int: + """Length of trailing chars of ``buf`` that form a strict prefix of ``marker``. + + Used by streaming protocols to hold back the tail when it could be the + start of an upcoming open tag, preventing mid-marker splits. + """ + upper = min(len(marker) - 1, len(buf)) + for k in range(upper, 0, -1): + if buf.endswith(marker[:k]): + return k + return 0 diff --git a/src/twinkle/template/tools/cline.py b/src/twinkle/template/tools/cline.py new file mode 100644 index 00000000..e6273cfb --- /dev/null +++ b/src/twinkle/template/tools/cline.py @@ -0,0 +1,105 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Cline / OpenClaw text-embedded XML tool-call format. 
+ +Wire format (Layer-B agent app protocol — lives in plain ``content``, +not in the OpenAI ``tool_calls`` field): + + src/foo.py + + ls -la + false + + +Detection is **structural** (no hardcoded tool-name whitelist): + +* outer tag is snake_case ``[a-z][a-z0-9_]*`` and not in :data:`_DENY` +* outer block contains at least one nested ``VAL`` child + +Streaming: ``open_marker``/``close_marker`` are ``None`` because the +outer tag varies per call. The base ``parse_tool_call_stream`` therefore +falls back to plain content passthrough; recognised blocks are extracted +only on full-text :meth:`parse` (e.g. by ``AgentTraceFilter`` after +trajectory assembly). +""" +from __future__ import annotations + +import re +from typing import Any, Dict, List + +from .base import ToolCallParser + +# Common HTML-like / template tags that are NOT Cline tool calls. Outer +# tags falling here are skipped to prevent false positives. +_DENY = frozenset({ + # twinkle-internal / model-internal markers + 'think', 'answer', 'tool_call', 'tool_response', 'function', 'parameter', + 'parameters', 'tools', 'tool', 'system', 'user', 'assistant', 'message', + 'messages', 'content', 'response', 'output', 'role', 'reasoning_content', + # html / markdown + 'p', 'a', 'b', 'i', 'em', 'strong', 'div', 'span', 'pre', 'code', 'br', + 'hr', 'ul', 'ol', 'li', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'table', + 'tr', 'td', 'th', 'tbody', 'thead', 'img', 'video', 'audio', +}) + +# Outer tool-call block: matched-pair via backreference. Body is non-greedy. +_BLOCK_RE = re.compile(r'<(?P[a-z][a-z0-9_]*)>(?P[\s\S]*?)') +# Inner parameter: matched-pair via backreference. +_PARAM_RE = re.compile(r'<(?P[a-z][a-z0-9_]*)>(?P[\s\S]*?)') + + +class ClineParser(ToolCallParser): + name = 'cline' + # Outer tag varies per tool — no fixed marker; streaming uses passthrough. 
+ open_marker = None + close_marker = None + + def matches_model(self, model_id: str) -> bool: + # Cline is an app-level prompt protocol, not bound to any model family. + return False + + def detect(self, text: str) -> bool: + if not text or '<' not in text: + return False + for m in _BLOCK_RE.finditer(text): + if m.group('tool') in _DENY: + continue + if _PARAM_RE.search(m.group('body')): + return True + return False + + def parse(self, text: str) -> List[Dict[str, Any]]: + calls: List[Dict[str, Any]] = [] + for m in _BLOCK_RE.finditer(text or ''): + tool = m.group('tool') + if tool in _DENY: + continue + args: Dict[str, Any] = {} + for pm in _PARAM_RE.finditer(m.group('body')): + args[pm.group('key')] = pm.group('val').strip() + if not args: + continue + calls.append({ + 'type': 'function', + 'function': {'name': tool, 'arguments': args}, + }) + return calls + + def clean(self, text: str) -> str: + if not text: + return text or '' + spans: List[tuple] = [] + for m in _BLOCK_RE.finditer(text): + if m.group('tool') in _DENY: + continue + if not _PARAM_RE.search(m.group('body')): + continue + spans.append((m.start(), m.end())) + if not spans: + return text.rstrip() + out: List[str] = [] + last = 0 + for s, e in spans: + out.append(text[last:s]) + last = e + out.append(text[last:]) + return ''.join(out).rstrip() diff --git a/src/twinkle/template/qwen.py b/src/twinkle/template/tools/qwen.py similarity index 61% rename from src/twinkle/template/qwen.py rename to src/twinkle/template/tools/qwen.py index 4c68ab3a..12361b73 100644 --- a/src/twinkle/template/qwen.py +++ b/src/twinkle/template/tools/qwen.py @@ -3,21 +3,28 @@ import re from typing import Any, Dict, List -from twinkle import remote_class -from twinkle.template import Template +from .base import ToolCallParser -@remote_class() -class QwenTemplate(Template): +class HermesQwenParser(ToolCallParser): + name = 'hermes_qwen' + open_marker = '' + close_marker = '' _BLOCK_RE = re.compile(r'\s*([\s\S]*?)\s*(?:|\Z)') 
_FUNCTION_RE = re.compile(r']+)>([\s\S]*?)') _PARAMETER_RE = re.compile(r']+)>\s*([\s\S]*?)\s*') _STRIP_RE = re.compile(r'[\s\S]*?(?:|\Z)') - def parse(self, decoded: str) -> List[Dict[str, Any]]: + def matches_model(self, model_id: str) -> bool: + return 'qwen' in model_id + + def detect(self, text: str) -> bool: + return self.open_marker in text + + def parse(self, text: str) -> List[Dict[str, Any]]: calls: List[Dict[str, Any]] = [] - for block_m in self._BLOCK_RE.finditer(decoded or ''): + for block_m in self._BLOCK_RE.finditer(text or ''): block = block_m.group(1) func_m = self._FUNCTION_RE.search(block) if func_m: @@ -37,7 +44,6 @@ def parse(self, decoded: str) -> List[Dict[str, Any]]: }, }) continue - # JSON fallback: ``{"name": ..., "arguments": ...}`` inside the block. try: data = json.loads(block) except json.JSONDecodeError: @@ -60,26 +66,5 @@ def parse(self, decoded: str) -> List[Dict[str, Any]]: }) return calls - def clean(self, decoded: str) -> str: - return self._STRIP_RE.sub('', decoded or '').rstrip() - - def parse_tool_call(self, decoded: str) -> List[Dict[str, Any]]: - """Parse tool calls from the assistant's decoded output. - - Dispatches by model family on ``self.model_id``; the actual - wire-format logic lives in :mod:`.tool_call_parser`. - """ - mid = (self.model_id or '').lower() - if 'qwen' in mid: - return self.parse(decoded) - # TODO: Other models (Llama3, OpenAI JSON, …) — add a parser in - # ``tool_call_parser.py`` and extend this dispatch. 
- return [] - - def clean_tool_call(self, decoded: str) -> str: - """Strip family-specific tool-call markup from assistant text.""" - mid = (self.model_id or '').lower() - if 'qwen' in mid: - return self.clean(decoded) - # TODO: Other models - return (decoded or '').rstrip() + def clean(self, text: str) -> str: + return self._STRIP_RE.sub('', text or '').rstrip() diff --git a/src/twinkle/template/tools/react.py b/src/twinkle/template/tools/react.py new file mode 100644 index 00000000..774f7421 --- /dev/null +++ b/src/twinkle/template/tools/react.py @@ -0,0 +1,32 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import re +from typing import Any, Dict, List + +from .base import ToolCallParser + +_ACTION_RE = re.compile( + r'^\s*Action\s*:\s*(?P[\w\-./]+)\s*\[(?P.*?)\]\s*$', + re.MULTILINE, +) + + +class ReActParser(ToolCallParser): + name = 'react' + + def detect(self, text: str) -> bool: + return bool(_ACTION_RE.search(text or '')) + + def parse(self, text: str) -> List[Dict[str, Any]]: + calls: List[Dict[str, Any]] = [] + for m in _ACTION_RE.finditer(text or ''): + calls.append({ + 'type': 'function', + 'function': { + 'name': m.group('name'), + 'arguments': {'input': m.group('args')}, + }, + }) + return calls + + def clean(self, text: str) -> str: + return _ACTION_RE.sub('', text or '').rstrip() diff --git a/src/twinkle/template/tools/vcp.py b/src/twinkle/template/tools/vcp.py new file mode 100644 index 00000000..5e030f9d --- /dev/null +++ b/src/twinkle/template/tools/vcp.py @@ -0,0 +1,65 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+import re +from typing import Any, Dict, List + +from .base import ToolCallParser + +_VCP_OPEN = '<<<[TOOL_REQUEST]>>>' +_VCP_CLOSE = '<<<[END_TOOL_REQUEST]>>>' + +_VCP_BLOCK_RE = re.compile( + r'<<<\[TOOL_REQUEST\]>>>(.*?)<<<\[END_TOOL_REQUEST\]>>>', + re.DOTALL, +) + +# `「始ESCAPE」...「末ESCAPE」` is the nesting-safe variant; pair them strictly +# so an escaped value is not closed by a bare `「末」` from an inner block. +_VCP_KV_RE = re.compile( + r'(?P[A-Za-z_]\w*)\s*:\s*' + r'(?:「始ESCAPE」(?P.*?)「末ESCAPE」' + r'|「始」(?P.*?)「末」)', + re.DOTALL, +) + + +class VCPParser(ToolCallParser): + """VCPChat / VCPSystem custom tool-call format. + + Outer markers ``<<<[TOOL_REQUEST]>>> ... <<<[END_TOOL_REQUEST]>>>`` wrap + one call; parameters use full-width brackets ``「始」value「末」`` (escape + variant ``「始ESCAPE」...「末ESCAPE」`` permits nested outer markers). + The canonical function name lives in the ``tool_name`` field. + """ + + name = 'vcp' + open_marker = _VCP_OPEN + close_marker = _VCP_CLOSE + + def detect(self, text: str) -> bool: + return _VCP_OPEN in (text or '') + + def parse(self, text: str) -> List[Dict[str, Any]]: + calls: List[Dict[str, Any]] = [] + for block in _VCP_BLOCK_RE.findall(text or ''): + args: Dict[str, Any] = {} + name = '' + for m in _VCP_KV_RE.finditer(block): + k = m.group('key') + v = m.group('val_esc') if m.group('val_esc') is not None else m.group('val') + if k == 'tool_name': + name = (v or '').strip() + else: + args[k] = v + if not name: + continue + calls.append({ + 'type': 'function', + 'function': { + 'name': name, + 'arguments': args, + }, + }) + return calls + + def clean(self, text: str) -> str: + return _VCP_BLOCK_RE.sub('', text or '').rstrip() diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py index 72975d78..fe2b1e03 100644 --- a/src/twinkle/template/utils.py +++ b/src/twinkle/template/utils.py @@ -194,6 +194,7 @@ class TokenizeByRound: @staticmethod def tokenize_with_assistant_labels(tokenizer: 'PreTrainedTokenizer', 
encode_func: Callable, trajectory: Trajectory, + train_indices: Optional[set] = None, **kwargs) -> Tuple[List[int], List[int], Dict[str, Any]]: """Tokenize trajectory and generate labels for assistant turns. @@ -201,6 +202,8 @@ def tokenize_with_assistant_labels(tokenizer: 'PreTrainedTokenizer', encode_func tokenizer: The tokenizer (unused, kept for interface compatibility). encode_func: Function to encode a trajectory. Must support add_generation_prompt. trajectory: The trajectory containing messages. + train_indices: If provided, only label assistant messages whose + message index is in this set. ``None`` means label all. Returns: Tuple of (input_ids, labels, extra_encoded_fields). @@ -225,6 +228,8 @@ def tokenize_with_assistant_labels(tokenizer: 'PreTrainedTokenizer', encode_func for i, msg in enumerate(messages): if msg['role'] != 'assistant': continue + if train_indices is not None and i not in train_indices: + continue # Get position AFTER assistant prefix: # encode(messages[:i], add_generation_prompt=True) includes the prefix diff --git a/src/twinkle/utils/parallel.py b/src/twinkle/utils/parallel.py index ba3b63e3..a235c136 100644 --- a/src/twinkle/utils/parallel.py +++ b/src/twinkle/utils/parallel.py @@ -87,6 +87,55 @@ def _try_create_claim(path: str, session: str, payload: str) -> bool: return True +class PosixFileLock: + """POSIX advisory file lock with persistent fd for repeated acquire/release. + + Fork-safe: reopens its fd lazily when used from a child process so each + worker owns its own descriptor. + """ + + def __init__(self, path: str): + import fcntl + self._path = path + self._fcntl = fcntl + self._fd = open(path, 'w') + self._pid = os.getpid() + + def _ensure_fd(self): + # After fork, child must reopen so it doesn't share parent's fd state. 
+ pid = os.getpid() + if pid != self._pid: + self._fd = open(self._path, 'w') + self._pid = pid + + def acquire(self): + self._ensure_fd() + self._fcntl.flock(self._fd, self._fcntl.LOCK_EX) + + def release(self): + self._fcntl.flock(self._fd, self._fcntl.LOCK_UN) + + def close(self): + self._fd.close() + + def __enter__(self): + self.acquire() + return self + + def __exit__(self, *exc): + self.release() + + def __getstate__(self): + return {'_path': self._path} + + def __setstate__(self, state): + import fcntl + self._path = state['_path'] + self._fcntl = fcntl + self._fd = open(self._path, 'w') + self._pid = os.getpid() + + @contextmanager def processing_lock(lock_file: str): """A file lock to prevent parallel operations to one file. diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index deb788db..a2aa8ad9 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -268,6 +268,45 @@ def pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -200 return torch.stack(padded_tensors, dim=0) +def gather_cp_load_balanced(tensor: 'torch.Tensor', cp_group, seq_dim: int = 1) -> 'torch.Tensor': + """All-gather a CP-load-balanced shard along ``seq_dim`` into the full sequence. + + Inverse of :func:`split_cp_inputs`: each CP rank ``r`` holds chunks ``[r, 2*cp - r - 1]`` + of the original ``2*cp`` sequence chunks. The local rank's slice keeps autograd; + other ranks' slices are detached copies, so backward through the gathered tensor + only produces gradients for the local chunk. 
+ """ + import torch + cp_size = cp_group.size() + if cp_size <= 1: + return tensor + cp_rank = torch.distributed.get_rank(group=cp_group) + gathered = [torch.empty_like(tensor) for _ in range(cp_size)] + torch.distributed.all_gather(gathered, tensor.contiguous(), group=cp_group) + gathered[cp_rank] = tensor + seq_local = tensor.shape[seq_dim] + half_len = seq_local // 2 + full_seq = seq_local * cp_size + chunk_len = full_seq // (2 * cp_size) + out_shape = list(tensor.shape) + out_shape[seq_dim] = full_seq + output = tensor.new_zeros(*out_shape) + for j in range(cp_size): + o = gathered[j] + front = [slice(None)] * tensor.ndim + front[seq_dim] = slice(j * chunk_len, (j + 1) * chunk_len) + rev = 2 * cp_size - j - 1 + back = [slice(None)] * tensor.ndim + back[seq_dim] = slice(rev * chunk_len, (rev + 1) * chunk_len) + local_front = [slice(None)] * tensor.ndim + local_front[seq_dim] = slice(0, half_len) + local_back = [slice(None)] * tensor.ndim + local_back[seq_dim] = slice(half_len, seq_local) + output[tuple(front)] = o[tuple(local_front)] + output[tuple(back)] = o[tuple(local_back)] + return output + + def split_cp_inputs(inputs: 'torch.Tensor', cu_seqlens: Optional['torch.Tensor'], dim: int): import torch from megatron.core import mpu diff --git a/src/twinkle_agentic/data_format/__init__.py b/src/twinkle_agentic/data_format/__init__.py index 6298015c..35457599 100644 --- a/src/twinkle_agentic/data_format/__init__.py +++ b/src/twinkle_agentic/data_format/__init__.py @@ -1 +1,2 @@ from .chunks import Chunk, Chunks +from .score import RoundContext, ScoreResult, Scorer diff --git a/src/twinkle_agentic/data_format/score.py b/src/twinkle_agentic/data_format/score.py new file mode 100644 index 00000000..20550940 --- /dev/null +++ b/src/twinkle_agentic/data_format/score.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Protocol + + +@dataclass +class RoundContext: + """Per-round payload passed to scorers.""" + 
row_idx: int + rnd_idx: int + asst_idx: int + row: Dict[str, Any] + intent: Optional[str] + messages: List[Dict[str, Any]] + context_messages: List[Dict[str, Any]] + cond_ids: List[int] + n_prompt: int + asst_ids: List[int] + asst_text: str + user_prompt: str + features: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ScoreResult: + score: Optional[float] = None + passed: bool = True + extras: Dict[str, Any] = field(default_factory=dict) + + +class Scorer(Protocol): + name: str + requires_logprobs: bool + + def score(self, contexts: List[RoundContext]) -> List[ScoreResult]: + ... diff --git a/src/twinkle_agentic/preprocessor/__init__.py b/src/twinkle_agentic/preprocessor/__init__.py new file mode 100644 index 00000000..f351683c --- /dev/null +++ b/src/twinkle_agentic/preprocessor/__init__.py @@ -0,0 +1,105 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import json +import time +from typing import Any, Callable, Dict, List, Optional + +from twinkle.preprocessor import Preprocessor +from twinkle.utils import get_logger +from twinkle.utils.parallel import PosixFileLock + +from .consistency_filter import ConsistencyFilter +from .agent_trace_filter import AgentTraceFilter +from .data_juicer import ( + AlphanumericFilter, + CharRepeatFilter, + FlaggedWordsFilter, + FixUnicodeFilter, + KenLMFilter, + LanguageFilter, + LLMConditionFilter, + LLMDifficultyFilter, + LLMQualityFilter, + LLMTaskRelevanceFilter, + MinHashDedupFilter, + RemoveRepeatSentencesFilter, + SpecialCharsFilter, + StopwordsFilter, + TextActionFilter, + TokenNumFilter, + WordRepeatFilter, +) +from .dead_loop_filter import DeadLoopFilter +from .hard_filter import HardFilter +from .intent_classifier import IntentClassifier +from .llm_backend import LLMBackend, OpenAIBackend, SamplerBackend # noqa: F401 +from .majority_vote import MajorityVoteFilter +from .message_sanity import MessageSanityFilter +from .perplexity import PerplexityFilter +from .pii_presidio_filter 
class QualityPreprocessor(Preprocessor):
    """Thin pipeline runner: accepts a list of callables, runs them in order.

    Each step must accept and return List[Dict[str, Any]].
    Per-step logging (before/after count) and an optional dropped-row JSONL
    file are provided.

    Args:
        pipeline: Callables applied in order to the current list of rows.
        dropped_log_path: When non-empty, rows removed by a step are appended
            to this JSONL file (one ``{'step', 'row'}`` object per line). Any
            existing file at this path is removed on construction.
    """

    def __init__(self, pipeline: List[Callable], dropped_log_path: str = ''):
        import contextlib
        import os
        super().__init__()
        self._pipelines = list(pipeline)
        self._dropped_log_path = dropped_log_path
        self._lock: Optional[PosixFileLock] = None
        if dropped_log_path:
            os.makedirs(os.path.dirname(os.path.abspath(dropped_log_path)), exist_ok=True)
            self._lock = PosixFileLock(dropped_log_path + '.lock')
            # Start from a clean log. Removing unconditionally and ignoring a
            # missing file avoids the exists()/remove() race when several
            # workers construct the preprocessor at the same time.
            with contextlib.suppress(FileNotFoundError):
                os.remove(dropped_log_path)

    def __call__(self, rows):
        """Run every step in order and return the surviving rows in column form."""
        rows_list = self.map_col_to_row(rows)
        for step in self._pipelines:
            if not rows_list:
                # Nothing left to filter; skip the remaining steps.
                break
            step_name = getattr(step, '__name__', None) or type(step).__name__
            before = len(rows_list)
            prev = rows_list
            t0 = time.perf_counter()
            rows_list = self.map_col_to_row(step(rows_list))
            elapsed = time.perf_counter() - t0
            after = len(rows_list)
            logger.info(
                f'[QualityPreprocessor] {step_name}: {before} -> {after} '
                f'(dropped {before - after}, {elapsed:.3f}s)')
            self._log_dropped(step_name, prev, rows_list)
        return self.map_row_to_col(rows_list)

    @staticmethod
    def _row_key(row: Dict[str, Any]):
        """Return a hashable matching key for ``row``.

        Rows with an ``id`` are keyed by it; rows without one fall back to
        object identity. The two kinds are tagged so that an integer row id can
        never collide with an ``id(obj)`` value.
        """
        rid = row.get('id')
        if rid is None:
            return ('obj', id(row))
        try:
            hash(rid)
        except TypeError:
            # Unhashable ids (e.g. lists) are matched by their repr.
            rid = repr(rid)
        return ('id', rid)

    def _log_dropped(self, step_name: str, prev: List[Dict[str, Any]],
                     kept: List[Dict[str, Any]]) -> None:
        """Append rows present in ``prev`` but absent from ``kept`` to the log.

        Matching is by multiset of row keys, so when two rows share an ``id``
        and only one survives, the other is still reported as dropped.
        NOTE(review): rows without an ``id`` are matched by object identity,
        which only works if the step returns the same dict objects.
        """
        if not self._lock or len(kept) == len(prev):
            return
        from collections import Counter
        remaining = Counter(self._row_key(r) for r in kept)
        dropped = []
        for r in prev:
            key = self._row_key(r)
            if remaining[key] > 0:
                remaining[key] -= 1
            else:
                dropped.append(r)
        if not dropped:
            return
        with self._lock:
            with open(self._dropped_log_path, 'a', encoding='utf-8') as f:
                f.writelines(
                    json.dumps({'step': step_name, 'row': r},
                               ensure_ascii=False, default=str) + '\n'
                    for r in dropped)
def _msg_text(m: Dict[str, Any]) -> str:
    """Return the plain text of a message's ``content``.

    A string is returned unchanged. A list of content parts is reduced to the
    space-joined ``text`` of its ``type == 'text'`` dict parts; a part whose
    ``text`` is missing, ``None`` or not a string contributes an empty string
    instead of raising. Any other content yields ``''``.
    """
    c = m.get('content')
    if isinstance(c, str):
        return c
    if isinstance(c, list):
        pieces = []
        for p in c:
            if isinstance(p, dict) and p.get('type') == 'text':
                text = p.get('text')
                pieces.append(text if isinstance(text, str) else '')
        return ' '.join(pieces)
    return ''


def _is_agent_row(messages: Any) -> bool:
    """Return True when ``messages`` looks like an agent rollout.

    A row qualifies if any dict message has ``role == 'tool'``, carries
    normalized tool calls, or is an assistant message whose text contains a
    tool call recognised by ``ToolCallRegistry``. Non-list input and non-dict
    entries never qualify.
    """
    if not isinstance(messages, list):
        return False
    for m in messages:
        if not isinstance(m, dict):
            continue
        role = m.get('role')
        if role == 'tool':
            return True
        if _normalize_tool_calls(m):
            return True
        # Text-embedded tool calls (Cline / OpenClaw / Claude-Code style):
        # delegate detection to the parser registry — no hardcoded tag list.
        if role == 'assistant' and ToolCallRegistry.detect_first(_msg_text(m)) is not None:
            return True
    return False


class AgentTraceFilter(Preprocessor):
    """Tag rows that look like agent rollouts; never drops rows."""

    def __call__(self, rows) -> List[Dict[str, Any]]:
        """Return shallow copies of ``rows`` with a boolean ``is_agent`` key."""
        # Set is_agent on every row (not just matches) so map_row_to_col sees a
        # uniform schema; otherwise rows[0].keys() may miss 'is_agent' and KeyError later.
        return [
            dict(row, is_agent=_is_agent_row(row.get('messages')))
            for row in rows
        ]
+from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any, Dict, List, Optional + +import numpy as np + +from twinkle.preprocessor import Preprocessor + +from .llm_backend import LLMBackend, OpenAIBackend + + +def _get_assistant_text(messages: List[Dict[str, Any]]) -> Optional[str]: + for m in reversed(messages): + if isinstance(m, dict) and m.get('role') == 'assistant': + return (m.get('content') or '').strip() + return None + + +def _get_prompt_messages(messages: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]: + """Return messages up to (not including) the last assistant turn.""" + for i in range(len(messages) - 1, -1, -1): + if isinstance(messages[i], dict) and messages[i].get('role') == 'assistant': + return messages[:i] + return None + + +def _cosine_sim(a: np.ndarray, b: np.ndarray) -> float: + denom = np.linalg.norm(a) * np.linalg.norm(b) + if denom < 1e-12: + return 0.0 + return float(np.dot(a, b) / denom) + + +def _pairwise_cosine_mean(embeddings: np.ndarray) -> float: + """Mean pairwise cosine similarity for N embeddings of shape (N, dim).""" + n = len(embeddings) + if n < 2: + return 1.0 + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + normed = embeddings / np.clip(norms, 1e-12, None) + sim_matrix = normed @ normed.T + return float(sim_matrix[np.triu_indices(n, k=1)].mean()) + + +def _generate_rollouts( + backend: LLMBackend, + prompt_messages: List[Dict[str, Any]], + n: int, + temperature: float, +) -> List[str]: + choices = backend.chat(prompt_messages, temperature=temperature, max_tokens=4096, n=n) + return [c.get('content', '') for c in choices] + + +def _embed_texts( + backend: LLMBackend, + texts: List[str], +) -> np.ndarray: + return backend.embeddings(texts) + + +def _process_row( + backend: LLMBackend, + embed_backend: LLMBackend, + messages: List[Dict[str, Any]], + n_rollouts: int, + temperature: float, +) -> Optional[Dict[str, Any]]: + """Returns {'C': float, 'D': float, 'best_rollout': str, 
'best_d': float} or None.""" + prompt_msgs = _get_prompt_messages(messages) + if not prompt_msgs: + return None + + traj_text = _get_assistant_text(messages) + if not traj_text: + return None + + try: + rollout_texts = _generate_rollouts( + backend, + prompt_msgs, n_rollouts, temperature, + ) + except Exception: + return None + + rollout_texts = [t for t in rollout_texts if t.strip()] + if len(rollout_texts) < 2: + return None + + try: + embeddings = _embed_texts( + embed_backend, [traj_text] + rollout_texts) + except Exception: + return None + + if len(embeddings) != 1 + len(rollout_texts): + return None + + traj_emb = embeddings[0] + rollout_embs = embeddings[1:] + + c = _pairwise_cosine_mean(rollout_embs) + d = 1.0 - _cosine_sim(rollout_embs.mean(axis=0), traj_emb) + + # rollout closest to original traj + norms = np.linalg.norm(rollout_embs, axis=1, keepdims=True) + normed_r = rollout_embs / np.clip(norms, 1e-12, None) + traj_norm = traj_emb / max(np.linalg.norm(traj_emb), 1e-12) + sims = normed_r @ traj_norm + best_idx = int(np.argmax(sims)) + + return { + 'C': c, + 'D': d, + 'best_rollout': rollout_texts[best_idx], + 'best_d': 1.0 - float(sims[best_idx]), + } + + +class ConsistencyFilter(Preprocessor): + """2D consistency filter: rollout consistency (C) × deviation from original traj (D). 
class ConsistencyFilter(Preprocessor):
    """2D consistency filter: rollout consistency (C) × deviation from original traj (D).

    Quadrants, as assigned by ``_assign_quadrant``:
        A (C >= c_thresh, D <  d_thresh): consistent rollouts, close to the original
        B (C >= c_thresh, D >= d_thresh): stable but drifted → source-dependent
        C (C <  c_thresh, D <  d_thresh): inconsistent rollouts, close to the original
        D (C <  c_thresh, D >= d_thresh): unstable & off-target → filter

    Modes (combinable):
        filter only:   drop quadrant D (and B when source=self)
        annotate=True: keep all, attach _quadrant/_diff_score/_consistency/_deviation
        replace=True:  replace assistant traj with best rollout where safe

    Rows whose metrics could not be computed are always kept. In annotate and
    replace modes every returned row carries the same set of added keys, so a
    later row→column conversion sees a uniform schema.
    """

    def __init__(
        self,
        backend: Optional[LLMBackend] = None,
        embed_backend: Optional[LLMBackend] = None,
        n_rollouts: int = 8,
        c_thresh: float = 0.7,
        d_thresh: float = 0.3,
        temperature: float = 0.7,
        max_workers: int = 4,
        source: str = 'auto',
        annotate: bool = False,
        replace: bool = False,
        min_density_ratio: float = 0.4,
        # Legacy params
        sampler_endpoint: str = '',
        embed_endpoint: str = '',
        sampler_model: str = 'default',
        embed_model: str = 'bge-m3',
    ):
        """Store thresholds and build default OpenAI backends when none are given.

        The legacy ``*_endpoint`` / ``*_model`` arguments are only used to
        construct a backend when the corresponding backend argument is None.
        """
        if backend is not None:
            self._backend = backend
        else:
            self._backend = OpenAIBackend(
                endpoint=sampler_endpoint, model=sampler_model, timeout=300.0)
        if embed_backend is not None:
            self._embed_backend = embed_backend
        else:
            self._embed_backend = OpenAIBackend(
                endpoint=embed_endpoint, model=embed_model, timeout=300.0)
        self._n_rollouts = n_rollouts
        self._c_thresh = c_thresh
        self._d_thresh = d_thresh
        self._temperature = temperature
        self._max_workers = max_workers
        self._source = source
        self._annotate = annotate
        self._replace = replace
        self._min_density_ratio = min_density_ratio

    def _assign_quadrant(self, c: float, d: float) -> str:
        """Map (consistency, deviation) to a quadrant letter 'A'..'D'."""
        if c >= self._c_thresh:
            return 'A' if d < self._d_thresh else 'B'
        return 'C' if d < self._d_thresh else 'D'

    def _should_drop(self, quadrant: str, row: Dict[str, Any]) -> bool:
        """Whether to remove the row entirely (only applies in non-annotate mode)."""
        if quadrant == 'D':
            return True
        if quadrant == 'B':
            if self._source == 'self':
                return True
            if self._source == 'auto' and row.get('_source') == 'self':
                return True
        return False

    @staticmethod
    def _reset_replace_flags(row: Dict[str, Any]) -> None:
        """Give ``row`` the full set of replace-mode keys with default values."""
        row['_replaced'] = False
        row['_needs_completion'] = False
        row['_needs_verification'] = False

    def _try_replace(self, row: Dict[str, Any], metrics: Dict[str, Any], quadrant: str) -> None:
        """Attempt in-place replacement of assistant content with best rollout.

        Replacement happens in quadrant A, or in quadrant C when the best
        rollout is well inside the deviation threshold, and only if the rollout
        is not much shorter than the original (``min_density_ratio``).
        Quadrant B is flagged for verification instead.
        """
        original = _get_assistant_text(row.get('messages') or []) or ''
        best = metrics['best_rollout']
        density = len(best) / max(len(original), 1)

        if quadrant == 'A':
            if density >= self._min_density_ratio:
                self._set_assistant_text(row, best)
                row['_replaced'] = True
            else:
                row['_replaced'] = False
                row['_needs_completion'] = True
        elif quadrant == 'C' and metrics['best_d'] < self._d_thresh * 0.8:
            if density >= self._min_density_ratio:
                self._set_assistant_text(row, best)
                row['_replaced'] = True
            else:
                row['_replaced'] = False
        elif quadrant == 'B':
            row['_replaced'] = False
            row['_needs_verification'] = True
        else:
            row['_replaced'] = False

    @staticmethod
    def _set_assistant_text(row: Dict[str, Any], text: str) -> None:
        """Overwrite the content of the last assistant message, if there is one."""
        for m in reversed(row.get('messages') or []):
            if isinstance(m, dict) and m.get('role') == 'assistant':
                m['content'] = text
                return

    def __call__(self, rows) -> List[Dict[str, Any]]:
        """Score rows concurrently, then filter / annotate / replace per the modes.

        Rows are mutated in place when annotate or replace is enabled.
        """
        if not rows:
            return rows

        results: Dict[int, Optional[Dict[str, Any]]] = {}
        # ThreadPoolExecutor rejects max_workers <= 0; always use at least one.
        n_workers = max(1, min(self._max_workers, len(rows)))

        with ThreadPoolExecutor(max_workers=n_workers) as pool:
            future_to_idx = {
                pool.submit(
                    _process_row,
                    self._backend, self._embed_backend,
                    row.get('messages') or [], self._n_rollouts, self._temperature,
                ): i
                for i, row in enumerate(rows)
            }
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    results[idx] = future.result()
                except Exception:
                    # Best effort: a failed row is treated as "no metrics".
                    results[idx] = None

        out = []
        for i, row in enumerate(rows):
            metrics = results.get(i)

            if metrics is None:
                if self._annotate:
                    row['_quadrant'] = 'unknown'
                    row['_diff_score'] = -1.0
                    # Present on every annotated row for a uniform schema.
                    row['_consistency'] = None
                    row['_deviation'] = None
                if self._replace:
                    self._reset_replace_flags(row)
                out.append(row)
                continue

            c, d = metrics['C'], metrics['D']
            quadrant = self._assign_quadrant(c, d)

            # filter decision (skip in annotate mode — annotate keeps everything)
            if not self._annotate and self._should_drop(quadrant, row):
                continue

            if self._annotate:
                row['_quadrant'] = quadrant
                row['_diff_score'] = (1.0 - c) if d < self._d_thresh else 0.0
                row['_consistency'] = c
                row['_deviation'] = d

            if self._replace:
                self._reset_replace_flags(row)
                self._try_replace(row, metrics, quadrant)

            out.append(row)

        return out
annotate keeps everything) + if not self._annotate and self._should_drop(quadrant, row): + continue + + if self._annotate: + row['_quadrant'] = quadrant + row['_diff_score'] = (1.0 - c) if d < self._d_thresh else 0.0 + row['_consistency'] = c + row['_deviation'] = d + + if self._replace: + self._try_replace(row, metrics, quadrant) + + out.append(row) + + return out diff --git a/src/twinkle_agentic/preprocessor/data_juicer.py b/src/twinkle_agentic/preprocessor/data_juicer.py new file mode 100644 index 00000000..2b447f20 --- /dev/null +++ b/src/twinkle_agentic/preprocessor/data_juicer.py @@ -0,0 +1,350 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Data-Juicer integration for trajectory quality filtering. +# +# Each class below is a standalone Preprocessor with __call__ interface. +# They share a module-level op cache for model/tokenizer reuse. +from typing import Any, Dict, List, Optional, Union + +from twinkle.preprocessor import Preprocessor + + +# ── Shared helpers ──────────────────────────────────────────────────────────── + +_OP_CACHE: Dict = {} + + +def _get_op(op_class, **kwargs): + key = (op_class, repr(tuple(sorted(kwargs.items())))) + if key not in _OP_CACHE: + _OP_CACHE[key] = op_class(**kwargs) + return _OP_CACHE[key] + + +def _get_tokenizer(hf_tokenizer: str): + key = ('_tokenizer', hf_tokenizer) + if key not in _OP_CACHE: + from modelscope import AutoTokenizer + _OP_CACHE[key] = AutoTokenizer.from_pretrained(hf_tokenizer, trust_remote_code=True) + return _OP_CACHE[key] + + +def _get_text(row: Dict[str, Any], role: str = 'assistant') -> str: + """Concatenate all turns for a given role from messages.""" + parts = [] + for msg in row.get('messages') or []: + if msg.get('role') == role: + content = msg.get('content') or '' + if isinstance(content, list): + content = ' '.join(b.get('text', '') for b in content if isinstance(b, dict)) + parts.append(str(content)) + return ' '.join(parts) + + +def _keep_mask(op, texts: List[str]) -> 
# ── Wrapper classes ───────────────────────────────────────────────────────────


class _RoleTextMapper(Preprocessor):
    """Shared ``__call__`` for wrappers that rewrite message content in place.

    Subclasses set ``self._role`` and implement ``_build_op`` to return a
    Data-Juicer mapper exposing ``text_key`` and ``process_batched``.
    """

    _role: str = 'assistant'

    def _build_op(self):
        raise NotImplementedError

    def __call__(self, rows):
        rows = self.map_col_to_row(rows)
        indices, texts = [], []
        for ri, row in enumerate(rows):
            for mi, msg in enumerate(row.get('messages') or []):
                if not isinstance(msg, dict) or msg.get('role') != self._role:
                    continue
                content = msg.get('content') or ''
                if not isinstance(content, str):
                    # Structured (list) content is left untouched: the mapper
                    # only understands plain strings.
                    continue
                texts.append(content)
                indices.append((ri, mi))
        if not texts:
            return rows
        op = self._build_op()
        result = op.process_batched({op.text_key: list(texts)})
        for (ri, mi), new_text in zip(indices, result[op.text_key]):
            rows[ri]['messages'][mi]['content'] = new_text
        return rows


class _RoleTextMaskFilter(Preprocessor):
    """Shared ``__call__`` for wrappers that keep or drop whole rows.

    Subclasses set ``self._role`` and implement ``_build_op`` to return a
    Data-Juicer filter usable with ``_keep_mask``. The filter sees the
    concatenated text of every ``self._role`` turn in the row.
    """

    _role: str = 'assistant'

    def _build_op(self):
        raise NotImplementedError

    def __call__(self, rows):
        rows = self.map_col_to_row(rows)
        if not rows:
            # Avoid constructing (and possibly downloading) an op for nothing.
            return []
        op = self._build_op()
        texts = [_get_text(r, self._role) for r in rows]
        mask = _keep_mask(op, texts)
        return [r for r, keep in zip(rows, mask) if keep]


class FixUnicodeFilter(_RoleTextMapper):
    """Rewrite ``role`` message content with Data-Juicer ``FixUnicodeMapper``."""

    def __init__(self, normalization: str = 'NFC', role: str = 'assistant'):
        self._normalization = normalization
        self._role = role

    def _build_op(self):
        from data_juicer.ops.mapper import FixUnicodeMapper
        return _get_op(FixUnicodeMapper, normalization=self._normalization)


class RemoveRepeatSentencesFilter(_RoleTextMapper):
    """Rewrite ``role`` message content with ``RemoveRepeatSentencesMapper``."""

    def __init__(self, lowercase: bool = False, ignore_special_character: bool = True, role: str = 'assistant'):
        self._lowercase = lowercase
        self._ignore = ignore_special_character
        self._role = role

    def _build_op(self):
        from data_juicer.ops.mapper import RemoveRepeatSentencesMapper
        return _get_op(RemoveRepeatSentencesMapper, lowercase=self._lowercase,
                       ignore_special_character=self._ignore)


class WordRepeatFilter(_RoleTextMaskFilter):
    """Row filter backed by Data-Juicer ``WordRepetitionFilter``."""

    def __init__(self, rep_len: int = 10, max_ratio: float = 0.4, role: str = 'assistant'):
        self._rep_len = rep_len
        self._max_ratio = max_ratio
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import WordRepetitionFilter
        return _get_op(WordRepetitionFilter, rep_len=self._rep_len, min_ratio=0.0, max_ratio=self._max_ratio)


class CharRepeatFilter(_RoleTextMaskFilter):
    """Row filter backed by Data-Juicer ``CharacterRepetitionFilter``."""

    def __init__(self, rep_len: int = 10, max_ratio: float = 0.4, role: str = 'assistant'):
        self._rep_len = rep_len
        self._max_ratio = max_ratio
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import CharacterRepetitionFilter
        return _get_op(CharacterRepetitionFilter, rep_len=self._rep_len, min_ratio=0.0, max_ratio=self._max_ratio)


class SpecialCharsFilter(_RoleTextMaskFilter):
    """Row filter backed by Data-Juicer ``SpecialCharactersFilter``."""

    def __init__(self, max_ratio: float = 0.25, role: str = 'assistant'):
        self._max_ratio = max_ratio
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import SpecialCharactersFilter
        return _get_op(SpecialCharactersFilter, min_ratio=0.0, max_ratio=self._max_ratio)


class AlphanumericFilter(_RoleTextMaskFilter):
    """Row filter backed by Data-Juicer ``AlphanumericFilter`` (no tokenization)."""

    def __init__(self, min_ratio: float = 0.25, role: str = 'assistant'):
        self._min_ratio = min_ratio
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import AlphanumericFilter as DJAlphanumericFilter
        return _get_op(DJAlphanumericFilter, tokenization=False, min_ratio=self._min_ratio)


class TokenNumFilter(Preprocessor):
    """Keep rows whose ``role`` text has between ``min_num`` and ``max_num`` tokens (inclusive)."""

    def __init__(self, hf_tokenizer: str = 'Qwen/Qwen2.5-0.5B', min_num: int = 10, max_num: int = 8192,
                 role: str = 'assistant'):
        self._hf_tokenizer = hf_tokenizer
        self._min_num = min_num
        self._max_num = max_num
        self._role = role

    def __call__(self, rows):
        rows = self.map_col_to_row(rows)
        if not rows:
            # Nothing to tokenize; also avoids loading the tokenizer.
            return []
        tokenizer = _get_tokenizer(self._hf_tokenizer)
        texts = [_get_text(r, self._role) for r in rows]
        encoded = tokenizer(texts, add_special_tokens=False)
        return [r for r, ids in zip(rows, encoded['input_ids']) if self._min_num <= len(ids) <= self._max_num]


class TextActionFilter(_RoleTextMaskFilter):
    """Row filter backed by Data-Juicer ``TextActionFilter``."""

    def __init__(self, lang: str = 'en', min_action_num: int = 1, role: str = 'assistant'):
        self._lang = lang
        self._min_action_num = min_action_num
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import TextActionFilter as DJTextActionFilter
        return _get_op(DJTextActionFilter, lang=self._lang, min_action_num=self._min_action_num)


class StopwordsFilter(_RoleTextMaskFilter):
    """Row filter backed by Data-Juicer ``StopWordsFilter``."""

    def __init__(self, lang: str = 'en', min_ratio: float = 0.1, max_ratio: float = 1.0, role: str = 'assistant'):
        self._lang = lang
        self._min_ratio = min_ratio
        self._max_ratio = max_ratio
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import StopWordsFilter
        return _get_op(StopWordsFilter, lang=self._lang, min_ratio=self._min_ratio, max_ratio=self._max_ratio)


class FlaggedWordsFilter(_RoleTextMaskFilter):
    """Row filter backed by Data-Juicer ``FlaggedWordFilter``."""

    def __init__(self, lang: str = 'en', max_ratio: float = 0.045, role: str = 'assistant'):
        self._lang = lang
        self._max_ratio = max_ratio
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import FlaggedWordFilter
        return _get_op(FlaggedWordFilter, lang=self._lang, min_ratio=0.0, max_ratio=self._max_ratio)


class LanguageFilter(_RoleTextMaskFilter):
    """Row filter backed by Data-Juicer ``LanguageIDScoreFilter``."""

    def __init__(self, lang: Union[str, List[str]] = '', min_score: float = 0.7, role: str = 'assistant'):
        self._lang = lang
        self._min_score = min_score
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import LanguageIDScoreFilter
        return _get_op(LanguageIDScoreFilter, lang=self._lang, min_score=self._min_score)


class KenLMFilter(_RoleTextMaskFilter):
    """Row filter backed by Data-Juicer ``PerplexityFilter``."""

    def __init__(self, lang: str = 'en', min_ppl: float = 0, max_ppl: float = 1500, role: str = 'assistant'):
        self._lang = lang
        self._min_ppl = min_ppl
        self._max_ppl = max_ppl
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import PerplexityFilter as KenLMPPLFilter
        return _get_op(KenLMPPLFilter, lang=self._lang, min_ppl=self._min_ppl, max_ppl=self._max_ppl)


class MinHashDedupFilter(Preprocessor):
    """Deduplicate rows by ``role`` text using ``DocumentMinhashDeduplicator``.

    A row is kept when its text survives the deduplicator and no earlier row in
    the same batch had exactly the same text.
    """

    def __init__(self, tokenization: str = 'character', window_size: int = 5, num_permutations: int = 256,
                 jaccard_threshold: float = 0.7, role: str = 'assistant'):
        self._tokenization = tokenization
        self._window_size = window_size
        self._num_permutations = num_permutations
        self._jaccard_threshold = jaccard_threshold
        self._role = role

    def __call__(self, rows):
        rows = self.map_col_to_row(rows)
        if not rows:
            # Nothing to deduplicate; skip building a dataset and the op.
            return []
        from data_juicer.ops.deduplicator import DocumentMinhashDeduplicator
        from data_juicer.core.data import NestedDataset
        from data_juicer.utils.constant import Fields
        import datasets

        texts = [_get_text(r, self._role) for r in rows]
        ds = datasets.Dataset.from_dict({'text': texts})
        ds = ds.map(lambda x: {Fields.stats: {}, Fields.meta: {}}, batched=False)
        nd = NestedDataset(ds)

        op = _get_op(
            DocumentMinhashDeduplicator,
            tokenization=self._tokenization,
            window_size=self._window_size,
            num_permutations=self._num_permutations,
            jaccard_threshold=self._jaccard_threshold,
        )
        nd = op.run(nd)
        keep_texts = set(nd['text'])
        seen, result = set(), []
        for r, t in zip(rows, texts):
            if t in keep_texts and t not in seen:
                seen.add(t)
                result.append(r)
        return result


class LLMQualityFilter(_RoleTextMaskFilter):
    """Row filter backed by Data-Juicer ``LLMQualityScoreFilter``."""

    def __init__(self, api_endpoint: str, model: str = 'default', min_score: float = 0.5, role: str = 'assistant'):
        self._api_endpoint = api_endpoint
        self._model = model
        self._min_score = min_score
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import LLMQualityScoreFilter
        return _get_op(LLMQualityScoreFilter, api_or_hf_model=self._model, api_endpoint=self._api_endpoint,
                       min_score=self._min_score)


class LLMDifficultyFilter(_RoleTextMaskFilter):
    """Row filter backed by ``LLMDifficultyScoreFilter``; scores ``user`` turns by default."""

    def __init__(self, api_endpoint: str, model: str = 'default', min_score: float = 0.4, max_score: float = 1.0,
                 role: str = 'user'):
        self._api_endpoint = api_endpoint
        self._model = model
        self._min_score = min_score
        self._max_score = max_score
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import LLMDifficultyScoreFilter
        return _get_op(LLMDifficultyScoreFilter, api_or_hf_model=self._model, api_endpoint=self._api_endpoint,
                       min_score=self._min_score, max_score=self._max_score)


class LLMConditionFilter(_RoleTextMaskFilter):
    """Row filter backed by Data-Juicer ``LLMConditionFilter``."""

    def __init__(self, condition: str, api_endpoint: str, model: str = 'default', role: str = 'assistant'):
        self._condition = condition
        self._api_endpoint = api_endpoint
        self._model = model
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import LLMConditionFilter as DJLLMConditionFilter
        return _get_op(DJLLMConditionFilter, condition=self._condition, api_or_hf_model=self._model,
                       api_endpoint=self._api_endpoint)


class LLMTaskRelevanceFilter(_RoleTextMaskFilter):
    """Row filter backed by Data-Juicer ``LLMTaskRelevanceFilter``."""

    def __init__(self, api_endpoint: str, task_desc: str = '', model: str = 'default', min_score: float = 0.5,
                 role: str = 'assistant'):
        self._api_endpoint = api_endpoint
        self._task_desc = task_desc
        self._model = model
        self._min_score = min_score
        self._role = role

    def _build_op(self):
        from data_juicer.ops.filter import LLMTaskRelevanceFilter as DJLLMTaskRelevanceFilter
        return _get_op(DJLLMTaskRelevanceFilter, api_or_hf_model=self._model, api_endpoint=self._api_endpoint,
                       min_score=self._min_score, valid_dataset=None, task_desc=self._task_desc)
a/src/twinkle_agentic/preprocessor/dead_loop_filter.py b/src/twinkle_agentic/preprocessor/dead_loop_filter.py new file mode 100644 index 00000000..eb1bd5f6 --- /dev/null +++ b/src/twinkle_agentic/preprocessor/dead_loop_filter.py @@ -0,0 +1,216 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import re +from typing import List, Dict, Any + +from twinkle.preprocessor import Preprocessor + +# ── Hesitation-marker regexes ───────────────────────────────────────────────── +# +# Matches thinking-aloud / self-interruption signals. +# Each pattern intentionally targets SURFACE FORM, not semantic meaning, +# to avoid false positives on normal explanatory language. + +_EN_HESITATE = re.compile( + r'\b(' + # Direct hesitation tokens + r'wait[,\s]*\.{2,}|wait[,\s]+(wait|no|actually|hmm|let)|' + r'no\s+wait|oh\s+wait|but\s+wait|' + # Thinking aloud with self-doubt + r'hmm+[,\s]*\.{0,3}|uh+m*[,\s]*\.{0,3}|' + # Self-correction cascade starters + r'actually[,\s]+no|actually[,\s]+wait|actually[,\s]+i\s+was|' + r'no[,\s]+actually[,\s]+(that|this|i)|' + # Explicit restart / reconsideration + r'let\s+me\s+(re-?think|try\s+again|start\s+over|reconsider)|' + r'i\'?ll\s+(start\s+over|try\s+again|redo\s+this)|' + # Confusion / disorientation + r'i\'?m\s+(getting\s+confused|going\s+in\s+circles|lost\s+here|not\s+sure\s+where)|' + r'this\s+is\s+(getting|becoming)\s+(messy|complicated\s+fast|circular)|' + # Repeated-mistake acknowledgement + r'i\s+keep\s+(making|getting)\s+(the\s+same\s+)?error|' + r'i\s+(made|keep\s+making)\s+(the\s+same\s+)?(mistake|error)\s+again' + r')\b', + re.IGNORECASE, +) + +_ZH_HESITATE = re.compile( + r'(' + # Direct hesitation tokens. Note: '等一下' is excluded — it overwhelmingly + # appears as a polite '稍等一下' / '请等一下' rather than self-hesitation. + r'等等[,,。\s]*\.{0,3}|哦等等|不不不+|' + # Note: 哦 is excluded (95%+ sentence-final particle, e.g. "拍拍我哦"); 嗯 requires + # repetition (single 嗯 is often affirmation, e.g. "嗯,好的"). 
+ r'嗯{2,}[,,。\s]*\.{0,3}|呃+[,,。\s]*\.{0,3}|' + # Self-correction + r'不对[,,。]?[,,\s]?(等等|重新|让我)|错了[,,。]?\s*让我|' + r'让我(重新|再次?)(想|试|来|考虑|计算)|' + r'我(再|重新)(想想|试试|来一次|考虑)|' + # Confusion / disorientation + r'我(越来越|有点)?(搞不清楚?|不确定|迷糊了?|乱了?)|' + r'这(变得|太|越来越)(复杂|乱|难以?理清)|' + # Repeated-mistake + r'我(好像|似乎|又)(搞|弄)错(了)?|我(又犯|再次犯)(了)?错|' + r'一直(出错|犯错|搞错)' + r')', + re.UNICODE, +) + +_JA_HESITATE = re.compile( + r'(' + r'ちょっと待って|待って待って|いや待って|えっと+[、。\s]*\.{0,3}|' + r'うーん+[、。\s]*\.{0,3}|あれ[、。]?[、。\s]*(また|もう一度)|' + r'もう一度考え直|やり直し|混乱してきた|わからなくなって' + r')', + re.UNICODE, +) + +_KO_HESITATE = re.compile( + r'(' + r'잠깐[,\s]*\.{0,3}|아\s*잠깐|잠깐만요?|' + r'음+[,\s]*\.{0,3}|어+[,\s]*\.{0,3}|' + r'다시\s*(생각|시작|해보|해야)|' + r'헷갈(리기|리네|려서)|' + r'계속\s*(틀리|실수|잘못)' + r')', + re.UNICODE, +) + +# Combined list for density scan +_HESITATE_PATTERNS = (_EN_HESITATE, _ZH_HESITATE, _JA_HESITATE, _KO_HESITATE) + +# Lightweight per-char cascade pattern (fast scan for dense clusters). +# 'let me' is excluded — it is the canonical agent-prelude phrasing +# ("Let me read the file...") and over-fires on long agent trajectories. 
# Self-correction / backtracking markers (English and Chinese).  A dense run of
# these inside a short window is treated as a "correction cascade".
_CASCADE_RE = re.compile(
    r'\b(wait|actually|hmm|no\s+wait|oh\s+wait|'
    r'i\s+was\s+wrong|i\s+made\s+an?\s+(error|mistake))\b|'
    r'(等等|不对|重新|错了|嗯{2,}|让我再)',
    re.IGNORECASE | re.UNICODE,
)

# Reasoning block that precedes the visible answer.
# NOTE(review): the pattern previously read r'(.*?)', which always matches the
# empty string at offset 0 -- the think_* thresholds were never applied and the
# no-think fallback was unreachable.  The <think>...</think> delimiters are
# assumed (they look stripped from the source); confirm against the data format.
_THINK_BLOCK_RE = re.compile(r'<think>(.*?)</think>', re.DOTALL)


# ── Detection helpers ─────────────────────────────────────────────────────────

def _hesitation_density(text: str) -> float:
    """Return the number of hesitation-marker hits per 1000 characters.

    Hits are summed over every pattern in the module-level ``_HESITATE_PATTERNS``.
    Empty text yields ``0.0`` (the length is clamped to 1 to avoid dividing by zero).
    """
    hits = sum(len(pattern.findall(text)) for pattern in _HESITATE_PATTERNS)
    return hits / max(len(text), 1) * 1000


def _has_correction_cascade_with_threshold(text: str, threshold: int, window: int = 800) -> bool:
    """Return True if ``threshold`` correction markers start within ``window`` chars.

    Args:
        text: Text to scan with ``_CASCADE_RE``.
        threshold: Number of markers that must fall inside one window.
        window: Maximum distance, in characters, between the first and the last
            marker of the group.
    """
    starts = [m.start() for m in _CASCADE_RE.finditer(text)]
    if len(starts) < threshold:
        return False
    # Slide over every run of ``threshold`` consecutive markers.
    for first in range(len(starts) - threshold + 1):
        if starts[first + threshold - 1] - starts[first] <= window:
            return True
    return False


def _high_repetition_with_threshold(text: str, threshold: float, ngram_size: int = 8,
                                    ngram_min_words: int = 30) -> bool:
    """Return True if the share of repeated word n-grams exceeds ``threshold``.

    Args:
        text: Text to analyse; it is split on whitespace.
        threshold: Upper bound on ``1 - unique_ngrams / total_ngrams``.
        ngram_size: Number of words per n-gram.
        ngram_min_words: Texts with fewer words are never flagged.
    """
    words = text.split()
    if len(words) < ngram_min_words:
        return False
    ngrams = [' '.join(words[i:i + ngram_size]) for i in range(len(words) - ngram_size + 1)]
    if not ngrams:
        # Fewer words than ``ngram_size`` (possible when ngram_min_words < ngram_size):
        # nothing to compare, and the ratio below would divide by zero.
        return False
    unique_ratio = len(set(ngrams)) / len(ngrams)
    return (1.0 - unique_ratio) > threshold


def _is_stuck(
    text: str,
    hesitation_density_threshold: float = 7.0,
    cascade_window: int = 800,
    cascade_threshold: int = 5,
    repetition_threshold: float = 0.45,
    ngram_size: int = 8,
    ngram_min_words: int = 30,
    think_hesitation_density_threshold: float = 15.0,
    think_cascade_threshold: int = 20,
    think_repetition_threshold: float = 0.65,
) -> bool:
    """Return True if the text exhibits signs of a hesitation / dead-loop.

    Three signals are checked: hesitation density, a correction cascade, and
    n-gram repetition.  When the text contains a think block, the block is
    judged with the ``think_*`` thresholds and the remainder (if non-blank)
    with the regular ones; either part being stuck flags the whole text.
    Without a think block the regular thresholds apply to the full text.
    """

    def _exceeds(part: str, density_limit: float, cascade_limit: int, repetition_limit: float) -> bool:
        return (
            _hesitation_density(part) > density_limit
            or _has_correction_cascade_with_threshold(part, cascade_limit, cascade_window)
            or _high_repetition_with_threshold(part, repetition_limit, ngram_size, ngram_min_words)
        )

    think_match = _THINK_BLOCK_RE.search(text)
    if think_match is None:
        return _exceeds(text, hesitation_density_threshold, cascade_threshold, repetition_threshold)

    think_part = think_match.group(1)
    response_part = text[think_match.end():]
    if _exceeds(think_part, think_hesitation_density_threshold,
                think_cascade_threshold, think_repetition_threshold):
        return True
    # bool(): the previous `str and (...)` form leaked '' to callers for blank responses.
    return bool(response_part.strip()) and _exceeds(
        response_part, hesitation_density_threshold, cascade_threshold, repetition_threshold)


# ── Preprocessor ─────────────────────────────────────────────────────────────

class DeadLoopFilter(Preprocessor):
    """Drop rows in which any assistant message looks stuck (see ``_is_stuck``).

    Rows tagged ``is_agent`` and rows without assistant messages pass through
    unchanged.  The constructor arguments are forwarded to ``_is_stuck``.
    """

    def __init__(
        self,
        hesitation_density_threshold: float = 7.0,
        cascade_window: int = 800,
        cascade_threshold: int = 5,
        repetition_threshold: float = 0.45,
        ngram_size: int = 8,
        ngram_min_words: int = 30,
        think_hesitation_density_threshold: float = 15.0,
        think_cascade_threshold: int = 20,
        think_repetition_threshold: float = 0.65,
    ) -> None:
        super().__init__()
        self._hesitation_density_threshold = hesitation_density_threshold
        self._cascade_window = cascade_window
        self._cascade_threshold = cascade_threshold
        self._repetition_threshold = repetition_threshold
        self._ngram_size = ngram_size
        self._ngram_min_words = ngram_min_words
        self._think_hesitation_density_threshold = think_hesitation_density_threshold
        self._think_cascade_threshold = think_cascade_threshold
        self._think_repetition_threshold = think_repetition_threshold

    def __call__(self, rows) -> List[Dict[str, Any]]:
        """Return the rows whose assistant messages show no dead-loop signal."""
        # Loop-invariant: build the keyword arguments once per batch.
        thresholds = {
            'hesitation_density_threshold': self._hesitation_density_threshold,
            'cascade_window': self._cascade_window,
            'cascade_threshold': self._cascade_threshold,
            'repetition_threshold': self._repetition_threshold,
            'ngram_size': self._ngram_size,
            'ngram_min_words': self._ngram_min_words,
            'think_hesitation_density_threshold': self._think_hesitation_density_threshold,
            'think_cascade_threshold': self._think_cascade_threshold,
            'think_repetition_threshold': self._think_repetition_threshold,
        }
        out = []
        for row in rows:
            # Agent rollouts (Cline / OpenClaw / Claude Code) carry long
            # trajectories whose phrasing legitimately matches our hesitation
            # heuristics; trust the upstream AgentTraceFilter tag and skip.
            if row.get('is_agent'):
                out.append(row)
                continue
            messages = row.get('messages') or []
            asst_msgs = [
                m for m in messages
                if isinstance(m, dict) and m.get('role') == 'assistant'
            ]
            if not asst_msgs:
                out.append(row)
                continue
            stuck = any(
                _is_stuck((m.get('content') or '').strip(), **thresholds)
                for m in asst_msgs
            )
            if not stuck:
                out.append(row)
        return out


# ── patch boundary ── next file: src/twinkle_agentic/preprocessor/hard_filter.py
# (diff --git header: new file mode 100644, index 00000000..b6309098, hunk @@ -0,0 +1,191 @@)
# Copyright (c) ModelScope Contributors. All rights reserved.
import re
from typing import Any, Dict, List

from twinkle.preprocessor import Preprocessor

# ── Language detection ────────────────────────────────────────────────────────

_CJK_RE = re.compile(
    r'[\u4e00-\u9fff'   # CJK Unified Ideographs (Chinese)
    r'\u3040-\u309f'    # Hiragana
    r'\u30a0-\u30ff'    # Katakana
    r'\uac00-\ud7a3]',  # Hangul Syllables
    re.UNICODE,
)


def _cjk_ratio(text: str) -> float:
    """Return the fraction of characters in ``text`` matched by ``_CJK_RE`` (0.0 for empty text)."""
    return len(_CJK_RE.findall(text)) / max(len(text), 1)


# ── English simple-query patterns ─────────────────────────────────────────────

_EN_GREETING_RE = re.compile(
    r'^(h+e+l+l+o+|h+i+|hey+|yo+|howdy|greetings|'
    r'good\s+(morning|afternoon|evening|night|day)|'
    r'what\'?s\s+up|how\'?s\s+it\s+going|how\s+are\s+you)'
    r'[\s,!.?]*$',
    re.IGNORECASE,
)

_EN_SIMPLE_RE = re.compile(
    r'^('
    # bare wh-question: interrogative word + short tail
    r'(what|who|where|when|why|how)\s+(is|are|was|were|does|do|did|has|have|can|could|would|should)\b.{0,30}|'
    r'(what|who|where|when|why|how)\'s\b.{0,30}|'
    # polar question opener
    r'(is|are|was|were|do|does|did|can|could|would|should|may|might)\s+(it|this|that|you|there|they|he|she)\b.{0,30}|'
    # imperative with no body
    r'(tell\s+me(\s+(about|more))?|explain(\s+to\s+me)?|define|describe|list|summarize|give\s+me)\b.{0,20}|'
    # help-me opener (no task detail)
    r'(please\s+)?(help\s+me|assist\s+me)\b.{0,20}'
    r')\s*[?!.]?$',
    re.IGNORECASE | re.DOTALL,
)

# ── Chinese simple-query patterns ─────────────────────────────────────────────

_ZH_GREETING_RE = re.compile(
    r'^(你好+|您好+|早上好|下午好|晚上好|大家好|嗨+|哈+喽+|哈+|喂+|hello+|hi+)'
    r'[\s,,!!。.]*$',
    re.UNICODE,
)

_ZH_SIMPLE_RE = re.compile(
    r'^('
    # "X是什么" / "什么是X" / "X怎么样"
    r'.{0,7}(是什么|是啥|啥意思|是何|什么意思|怎么样|如何|为什么|为啥)[??。]?|'
    r'(什么|啥|哪|谁|何|怎么|怎样|为什么|为啥|几|多少|何时|何地).{0,7}[??。]?|'
    # single-verb imperative with no substantive object
    r'(介绍|解释|说明|告诉我|帮我说说|请问|能说说|讲讲).{0,5}|'
    # short open-ended knowledge prompt with no substantive body
    r'(请\s*(给出|介绍|解释|说明|提供|列举|讲讲|阐述|描述|概述|举例|分析|说一下)|能否\s*(给出|设计|提供|介绍|解释|说明)).{0,10}'
    r')\s*[??!!。]?$',
    re.UNICODE,
)

# ── Japanese simple-query patterns ────────────────────────────────────────────

_JA_GREETING_RE = re.compile(
    r'^(こんにちは+|こんばんは+|おはよう(ございます)?|やあ+|どうも+|はじめまして|よろしく(おねがいします)?)'
    r'[\s!!。.]*$',
    re.UNICODE,
)

_JA_SIMPLE_RE = re.compile(
    r'^('
    r'.{0,7}(とは何ですか|って何|とはなんですか|について教えて(ください)?|はどうですか|ですか)[??]?|'
    r'(何|なに|どこ|いつ|誰|だれ|なぜ|どうして|どう|どれ|どの).{0,7}[??。]?'
    r')\s*[??!!。]?$',
    re.UNICODE,
)

# ── Korean simple-query patterns ──────────────────────────────────────────────

_KO_GREETING_RE = re.compile(
    r'^(안녕(하세요|하십니까)?|좋은\s*(아침|오후|저녁)|반갑습니다|여보세요)'
    r'[\s!!.]*$',
    re.UNICODE,
)

_KO_SIMPLE_RE = re.compile(
    r'^('
    r'.{0,7}(이?란\s*무엇|는\s*무엇|은\s*무엇|이?\s*뭐|가\s*뭐)[인가요까요]?[??]?|'
    r'(무엇|뭐|어디|언제|누가|왜|어떻게).{0,7}[??]?|'
    r'.{0,7}(에\s*대해|에\s*관해)\s*(알려주|설명해)[세요주십시오]?'
    r')\s*[??!!]?$',
    re.UNICODE,
)


# ── Core helpers ──────────────────────────────────────────────────────────────

def _content_text(msg: Dict[str, Any]) -> str:
    """Return the stripped textual content of a chat message.

    A string ``content`` is returned stripped.  A list ``content`` is flattened
    by joining the ``text`` of every ``{'type': 'text', ...}`` part, so such
    rows no longer crash on ``.strip()``.  Any other value yields ``''``.
    """
    content = msg.get('content')
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        texts = [
            part['text'] for part in content
            if isinstance(part, dict) and part.get('type') == 'text' and isinstance(part.get('text'), str)
        ]
        return ' '.join(texts).strip()
    return ''


def _is_simple_query(text: str, min_user_chars: int = 10, min_user_chars_cjk: int = 6) -> bool:
    """Return True if ``text`` is a greeting or trivially simple question.

    Text whose CJK ratio is at least 0.3 is checked against the Chinese,
    Japanese and Korean patterns with the ``min_user_chars_cjk`` length floor;
    everything else uses the English patterns and ``min_user_chars``.  Blank
    text counts as simple.
    """
    t = text.strip()
    if not t:
        return True

    if _cjk_ratio(t) >= 0.3:
        if len(t) < min_user_chars_cjk:
            return True
        cjk_patterns = (
            _ZH_GREETING_RE, _ZH_SIMPLE_RE,
            _JA_GREETING_RE, _JA_SIMPLE_RE,
            _KO_GREETING_RE, _KO_SIMPLE_RE,
        )
        return any(pattern.match(t) for pattern in cjk_patterns)

    if len(t) < min_user_chars:
        return True
    return bool(_EN_GREETING_RE.match(t) or _EN_SIMPLE_RE.match(t))


_MIN_THINKING_CHARS = 200


def _has_thinking(msg: Dict[str, Any], min_chars: int = _MIN_THINKING_CHARS) -> bool:
    """Return True if an assistant message carries a sufficiently long thinking chain.

    The chain is read from ``thinking`` or, failing that, ``reasoning_content``.
    A string must have at least ``min_chars`` characters after stripping; a
    non-string value counts when it is truthy.
    """
    thinking = msg.get('thinking') or msg.get('reasoning_content') or ''
    if isinstance(thinking, str):
        return len(thinking.strip()) >= min_chars
    return bool(thinking)


# ── Preprocessor ─────────────────────────────────────────────────────────────

class HardFilter(Preprocessor):
    """Rule-based filter that removes trivially low-quality conversations.

    Args:
        min_user_chars: Length floor for non-CJK user text (see ``_is_simple_query``).
        min_user_chars_cjk: Length floor for CJK user text.
        min_assistant_chars_2turn: Minimum assistant reply length for Rule 2.
        allow_incomplete_role: Keep rows that contain no user message.
    """

    def __init__(
        self,
        min_user_chars: int = 10,
        min_user_chars_cjk: int = 6,
        min_assistant_chars_2turn: int = 80,
        allow_incomplete_role: bool = False,
    ) -> None:
        super().__init__()
        self._min_user_chars = min_user_chars
        self._min_user_chars_cjk = min_user_chars_cjk
        self._min_assistant_chars_2turn = min_assistant_chars_2turn
        self.allow_incomplete_role = allow_incomplete_role

    def __call__(self, rows) -> List[Dict[str, Any]]:
        """Drop rows that are trivially low-quality by two rules:

        Rule 1 — Single-turn simple query:
            Only one user message AND that message is a greeting or bare simple
            question, unless the first assistant message has a thinking chain.

        Rule 2 — Two-turn shallow assistant reply:
            Exactly one user + one assistant turn, assistant reply is shorter than
            ``min_assistant_chars_2turn``, and the assistant message has no thinking chain.

        Rows whose ``messages`` is not a list are dropped; rows without a user
        message are kept only when ``allow_incomplete_role`` is set.
        """
        out = []
        for row in rows:
            messages = row.get('messages') or []
            if not isinstance(messages, list):
                continue

            user_msgs = [m for m in messages if isinstance(m, dict) and m.get('role') == 'user']
            asst_msgs = [m for m in messages if isinstance(m, dict) and m.get('role') == 'assistant']

            if not user_msgs:
                if self.allow_incomplete_role:
                    out.append(row)
                continue

            if len(user_msgs) == 1:
                # Rule 1: single-turn trivial query (skip if assistant has thinking)
                user_text = _content_text(user_msgs[0])
                if _is_simple_query(user_text, self._min_user_chars, self._min_user_chars_cjk):
                    if not asst_msgs or not _has_thinking(asst_msgs[0], _MIN_THINKING_CHARS):
                        continue

                # Rule 2: two-turn shallow reply without thinking
                if len(asst_msgs) == 1:
                    asst = asst_msgs[0]
                    if len(_content_text(asst)) < self._min_assistant_chars_2turn and not _has_thinking(asst):
                        continue

            out.append(row)
        return out


# ── patch boundary ── (the "diff --git" header of the next file in the patch starts here)
# ── patch boundary ── file: src/twinkle_agentic/preprocessor/intent_classifier.py
# (diff --git header: new file mode 100644, index 00000000..2706f417, hunk @@ -0,0 +1,391 @@)
# Copyright (c) ModelScope Contributors. All rights reserved.
import re
from collections import Counter
from typing import Any, Dict, List, Optional

from twinkle.preprocessor import Preprocessor
from twinkle.utils import get_logger

logger = get_logger(only_local_master=False)

# ── Intent categories ─────────────────────────────────────────────────────────
INTENT_TOOL_CALL = 'tool_call'
INTENT_CODE = 'code'
INTENT_MATH = 'math'
INTENT_COMPLEX_LOGIC = 'complex_logic'
INTENT_USER_DISSATISFACTION = 'user_dissatisfaction'
INTENT_OTHER = 'other'

# ── Heuristic patterns ────────────────────────────────────────────────────────
_CODE_BLOCK_RE = re.compile(r'```[\s\S]{10,}?```')
_CODE_KEYWORD_RE = re.compile(
    r'\b(def |class |import |from |function |const |let |var |return |if \(|for \(|while \(|'
    r'public class|private |protected |async |await |yield |throw |throws |catch |'
    r'switch |case |break |continue |void |struct |enum |interface |abstract |static |final |'
    r'namespace |package |module |export |lambda |func |fn |println|console\.log)\b|'
    # '#include' lives outside the \b(...)\b group: '#' is not a word character,
    # so a leading \b only held when a word character came directly before it,
    # which never happens at the start of a line or after whitespace.
    r'#include\b|'
    # Symbolic call / arrow signatures occur even without the keywords above.
    r'(?:[a-zA-Z_]\w*\([^)\n]*\)\s*\{|=>\s*\{|->\s*[A-Za-z_]\w*)'
)

_MATH_LATEX_RE = re.compile(
    r'(\$\$.+?\$\$|\$[^$\n]+?\$|'
    r'\\frac|\\sum|\\int|\\lim|\\begin\{(equation|align|matrix)|'
    r'\\mathbb|\\partial|\\nabla|\\sqrt|\\overline|'
    r'\\boxed|\\text\{|\\mathrm|\\langle|\\rangle|\\cdot|'
    r'\\times|\\div|\\pm|\\leq|\\geq|\\neq|\\approx|\\equiv|'
    r'\\infty|\\pi|\\alpha|\\beta|\\gamma|\\theta|\\lambda|\\mu|\\sigma|\\prod|\\to|\\rightarrow|'
    r'\\\[.+?\\\]|'
    # R1-distill writes math in plain Unicode without $...$; catch operators, Greek, sub/super digits, fractions.
    r'[×÷±°∑∏∫√∂∇∞∈∋⊂⊃⊆⊇≤≥≠≈≡≅∝⇒⇔]|'
    r'[α-ωΔΘΛΞΠΣΦΨΩ]|'
    r'[⁰¹²³⁴-⁹₀-₉]|'
    r'[½⅓⅔¼¾⅛⅜⅝⅞]|'
    # Arithmetic equation pattern catches '30 ÷ 6 = 5' even when other markers are absent.
    r'\d+\s*[×÷\*/\+\-]\s*\d+\s*=\s*\d+|'
    # ≥4 comma-separated integers — number-sequence pattern.
    r'\d+\s*,\s*\d+\s*,\s*\d+\s*,\s*\d+|'
    # 'x = 5' / 'a = -3' style assignment.
    r'[a-zA-Z]\s*=\s*-?\d+|'
    # Chinese math vocabulary (strong indicators).  A single occurrence in
    # non-math text is harmless because MathDetector only fires once the number
    # of hits reaches its threshold (default 4).
    r'积分|微分|导数|求导|偏导|梯度|极限|矩阵|向量|行列式|特征值|特征向量|'
    r'多项式|因式分解|不等式|方程组?|二次方程|线性方程|求解|解方程|未知数|化简|约分|通分|因式|代入|应用题|算式|算术|计算题|一元(?:一次|二次|三次|方程|不等式|多项式)|二元(?:一次|二次|方程)?|'
    r'平方|立方|开方|根号|对数|指数函数|三角函数|正弦|余弦|正切|余切|反三角|'
    r'概率|期望值?|方差|标准差|分布|随机变量|均值|中位数|众数|百分比|比例|比率|'
    r'子集|并集|交集|空集|集合|映射|'
    r'乘以|除以|平方根|立方根|平方米|立方米|'
    r'系数|常数项|首项|项数|公差|公比|'
    r'切线|法线|渐近线|对称轴|双曲线|抛物线|椭圆|'
    # Geometry.
    r'三角形|四边形|多边形|长方形|正方形|圆形|圆锥|圆柱|球体|平行四边形|梯形|菱形|'
    r'半径|直径|周长|面积|体积|对角线|内角|外角|锐角|钝角|直角|平角|余角|补角|勾股|弧度|象限|坐标系|'
    # Sequences / number theory / elementary math.
    r'数列|数字序列|等差数列|等比数列|等差|等比|通项|递推公式|'
    r'奇数(?:位|项)?|偶数(?:位|项)?|质数|素数|合数|整数|小数|分数|有理数|无理数|实数|'
    r'因数|倍数|公因数|公倍数|最大公约数|最小公倍数|阶乘|排列组合|'
    r'余数|商(?=是|为|等)|被除数|除数|被乘数|乘数|'
    r'(?:加|减|乘|除)\d+|'
    r'第\d+(?:位|项)|'
    # English math vocabulary.
    r'\b(integral|differential|derivative|gradient|polynomial|equation|inequality|'
    r'matrix|vector|determinant|eigenvalue|eigenvector|coefficient|'
    r'logarithm|exponential|sqrt|theorem|lemma|proof|qed|axiom|corollary|'
    r'sine|cosine|tangent|cosecant|secant|cotangent|arcsin|arccos|arctan|'
    r'probability|variance|expectation|distribution|stddev|deviation|median|mean|mode|'
    r'subset|superset|union|intersection|multiply|divide|squared|cubed|factorial|'
    r'radius|diameter|circumference|perimeter|hypotenuse|congruent|parallel|perpendicular)\b|'
    r'\w_\{[^}]+\}|\w\^\{[^}]+\})',
    re.DOTALL,
)

# ── Complex logic patterns ────────────────────────────────────────────────────
_LOGIC_STRUCTURE_RE = re.compile(
    # Sequential reasoning markers (Chinese)
    r'首先.{4,}其次|其次.{4,}最后|第一.{4,}第二.{4,}第三|'
    r'一方面.{4,}另一方面|从.{1,6}角度|'
    # Conditional / branching (Chinese)
    r'如果.{2,30}那么|假设.{2,30}则|若.{2,20}则|'
    r'分(为|成).{0,5}(种|类|个).{0,10}(情况|情形|场景|类型)|分情况讨论|'
    # Causal chains (Chinese)
    r'因为.{2,40}所以|由于.{2,40}因此|既然.{2,30}那么|'
    r'导致.{2,30}进而|之所以.{2,30}是因为|'
    # Synthesis / conclusion (Chinese)
    r'综上(所述)?|综合(以上|来看|分析)|总[的而]言之|由此可[得见知]|'
    # Comparison / trade-off (Chinese)
    r'优缺点|利弊|优劣|权衡|对比分析|相比之下|'
    # Multi-constraint reasoning (Chinese)
    r'需要同时满足|同时考虑|兼顾|约束条件|'
    # Sequential reasoning markers (English)
    r'\b(first(ly)?|second(ly)?|third(ly)?|finally|furthermore|moreover|in addition|'  # noqa: E501
    r'on (the )?one hand|on the other hand|'  # noqa: E501
    r'as a result|consequently|therefore|hence|thus|accordingly)\b|'
    # Conditional / branching (English)
    r'\b(if .{5,30} then|assuming .{5,30} then|in (case|scenario) .{2,10}(A|B|1|2)|'  # noqa: E501
    r'case \d|scenario \d)\b|'
    # Synthesis (English)
    r'\b(in (conclusion|summary)|to (summarize|conclude)|overall|all things considered|'  # noqa: E501
    r'weighing .{3,20} against|pros and cons|trade-?offs?|advantages .{0,10} disadvantages)\b',
    re.DOTALL | re.IGNORECASE,
)

_DISSATISFACTION_ZH_RE = re.compile(
    # Quality / correctness complaints.
    r'不[满好对行准确靠谱严]|不太[行好对准]|不正确|不准确|不对劲|不靠谱|不严谨|'
    # Severity intensifiers.
    r'太(差|慢|烂|傻|笨|垃圾|菜|弱|水|差劲|low)|这(么)?(差|烂|垃圾|傻|破|low)|'
    # Redo / retry.
    r'重[做来新答试]|重新(回答|做|来|算|想|考虑|生成)|再(答|来|做|算|想|试)一(次|遍|回|下)|你再答|'
    # Wrong / errors.
    r'错了?|错误|又错|搞错|弄错|出错|完全错|全错|大错|根本不(对|是)|压根不(对|是)|'
    # Off-topic / unhelpful.
    r'有问题|没用|没帮助|答非所问|文不对题|牛头不对|风马牛|跑题|偏题|偏离|跑偏|'
    # Stop talking nonsense.
    r'别瞎|别乱|别胡|你在说(什么|啥)|这是什么|这都什么|'
    r'离谱|搞什么|质量(太|很差)|胡(说|扯|言|乱|写|编|闹)|瞎(编|说|扯|写|想|猜|蒙|讲)|'
    # Random / illogical.
    r'莫名其妙|一塌糊涂|一派胡言|谬(论|误)|废话|屁话|没逻辑|没道理|说不通|不合逻辑|'
    # Negative emotion.
    r'不(满意|开心|高兴)|失望|让(我|人)失望|烦人|真烦|厌|气死|'
    # Misunderstanding / model failure.
    r'你(没|不)(懂|理解|明白|听懂)|理解错|抓不住重点|没get|没get到|'
    r'我说的不是|我问的不是|这不是我(说|问|想|要)|你听(错|不懂)|没听懂|'
    # Time / value waste.
    r'浪费时间|没意义|没价值|垃圾|废物|'
    # Generic anger.
    r'什么(玩意|东西|鬼)|你这是|你这答',
)
_DISSATISFACTION_EN_RE = re.compile(
    # Negative adjectives.
    r'\b(wrong|incorrect|useless|terrible|awful|horrible|bad|poor|lousy|sloppy|stupid|dumb|'
    r'idiotic|ridiculous|broken|misleading|infuriating|annoying|disappointing|disappointed|'
    r'unacceptable|unhelpful|inaccurate|imprecise|sub[- ]?par|low[- ]?quality)\b|'
    # "not X" complaints.
    r'\bnot (correct|right|good|helpful|useful|accurate|relevant|making sense|'
    r'what (i|I) (asked|wanted|meant|need|expected|requested))\b|'
    # Negation phrasings.
    r'(doesn\'?t|does not|didn\'?t|did not) (make sense|work|help|fit|match|address)|'
    r'makes? (no|zero|little) sense|'
    # Redo / retry.
    r'\b(redo|try again|do (it|this|that) again|start over|start again|do over|do better|'
    r'once more|again from scratch)\b|'
    # Insults / bullshit.
    r'\b(nonsense|garbage|trash|crap|bullshit|bs|baloney|hogwash|gibberish)\b|'
    r'(low|poor|bad|terrible) quality|waste of (time|effort|energy)|'
    # Misunderstanding.
    r'you (misunderstood|don\'?t understand|didn\'?t (get it|understand|listen)|missed (the|my) point)|'
    r'that\'?s (not what|wrong|incorrect|terrible|garbage|nonsense|useless)|'
    # Profanity.
    r'\b(WTF|wth|what the (heck|hell|fuck))\b|'
    # Off-target.
    r'\b(off[- ]topic|missed the mark|way off|completely off|totally wrong|nowhere near)\b|'
    r'not (even|really|quite) (close|right|correct)|'
    # Sarcasm / disbelief.
    r'come on|are you (serious|kidding|joking|sure)|'
    r'\bfrustrat\w+\b',
    re.IGNORECASE,
)


# ── Helpers ───────────────────────────────────────────────────────────────────

def _msg_text(msg: Dict[str, Any]) -> str:
    """Return the text of a message.

    A string ``content`` is returned as is; a list ``content`` is flattened by
    joining the ``text`` of its ``{'type': 'text'}`` parts with spaces; any
    other value yields ``''``.
    """
    c = msg.get('content')
    if isinstance(c, str):
        return c
    if isinstance(c, list):
        return ' '.join(
            p.get('text', '') for p in c
            if isinstance(p, dict) and p.get('type') == 'text'
        )
    return ''


def _pair_assistant(messages: List[Dict[str, Any]], idx: int, role: str) -> Optional[int]:
    """Resolve which assistant idx represents the round that owns a signal at (idx, role).

    An assistant message owns its own signal.  A user message is paired with
    the next assistant message after it.  Returns ``None`` for any other role
    or when no later assistant message exists.
    """
    if role == 'assistant':
        return idx
    if role == 'user':
        for j in range(idx + 1, len(messages)):
            m = messages[j]
            if isinstance(m, dict) and m.get('role') == 'assistant':
                return j
    return None


# ── Intent detectors (extensible pipeline) ────────────────────────────────────

class IntentDetector:
    """Base class. Each subclass sets ``intent`` and implements ``__call__``.

    ``__call__(messages)`` returns a list of assistant indices (key rounds) that
    match this intent within the given trajectory. An empty list means no match.
    Set ``definitive = True`` so the pipeline short-circuits on this detector
    (used for hard signals such as tool calls).
    """

    intent: str = ''
    definitive: bool = False

    def __call__(self, messages: List[Dict[str, Any]]) -> List[int]:
        raise NotImplementedError


class _RegexDetector(IntentDetector):
    """Common scaffolding: scan messages, run ``_match`` on each text, pair to assistant."""

    # When set, only messages with this role are scanned.
    role_filter: Optional[str] = None

    def _match(self, text: str) -> bool:
        """Return True if ``text`` carries this detector's signal (override in subclasses)."""
        return False

    def __call__(self, messages):
        rounds = set()
        for idx, m in enumerate(messages):
            if not isinstance(m, dict):
                continue
            role = m.get('role')
            # tool/system messages can never resolve to a key round (see _pair_assistant)
            # and tool outputs are often multi-MB — skip to avoid wasted regex scans.
            if role not in ('assistant', 'user'):
                continue
            if self.role_filter and role != self.role_filter:
                continue
            text = _msg_text(m)
            if not text or not self._match(text):
                continue
            asst_idx = _pair_assistant(messages, idx, role)
            if asst_idx is not None:
                rounds.add(asst_idx)
        return sorted(rounds)


class ToolCallDetector(IntentDetector):
    """Mark every assistant turn that carries a ``tool_calls`` payload."""

    intent = INTENT_TOOL_CALL
    definitive = True

    def __call__(self, messages):
        return [
            i for i, m in enumerate(messages)
            if isinstance(m, dict) and m.get('role') == 'assistant' and m.get('tool_calls')
        ]


class CodeDetector(_RegexDetector):
    """Match text with a fenced code block or at least ``threshold`` code-keyword hits."""

    intent = INTENT_CODE

    def __init__(self, threshold: int = 3) -> None:
        self.threshold = threshold

    def _match(self, text):
        if _CODE_BLOCK_RE.search(text):
            return True
        return len(_CODE_KEYWORD_RE.findall(text)) >= self.threshold


class MathDetector(_RegexDetector):
    """Match text with at least ``threshold`` hits of ``_MATH_LATEX_RE``."""

    intent = INTENT_MATH

    def __init__(self, threshold: int = 4) -> None:
        self.threshold = threshold

    def _match(self, text):
        return len(_MATH_LATEX_RE.findall(text)) >= self.threshold


class ComplexLogicDetector(_RegexDetector):
    """Match assistant text with at least ``threshold`` hits of ``_LOGIC_STRUCTURE_RE``."""

    intent = INTENT_COMPLEX_LOGIC
    role_filter = 'assistant'

    def __init__(self, threshold: int = 6) -> None:
        self.threshold = threshold

    def _match(self, text):
        return len(_LOGIC_STRUCTURE_RE.findall(text)) >= self.threshold


class UserDissatisfactionDetector(_RegexDetector):
    """Match user complaints that follow at least one assistant turn."""

    intent = INTENT_USER_DISSATISFACTION
    role_filter = 'user'

    def _match(self, text):
        return bool(_DISSATISFACTION_ZH_RE.search(text) or _DISSATISFACTION_EN_RE.search(text))

    def __call__(self, messages):
        # Dissatisfaction is a reaction — require at least one prior assistant turn.
        seen_assistant = False
        rounds = set()
        for idx, m in enumerate(messages):
            if not isinstance(m, dict):
                continue
            role = m.get('role')
            if role == 'assistant':
                seen_assistant = True
                continue
            if role != 'user' or not seen_assistant:
                continue
            text = _msg_text(m)
            if text and self._match(text):
                asst_idx = _pair_assistant(messages, idx, role)
                if asst_idx is not None:
                    rounds.add(asst_idx)
        return sorted(rounds)


# ── Preprocessor ──────────────────────────────────────────────────────────────

class IntentClassifier(Preprocessor):
    """Annotate each trajectory with its primary intent and key-round indices.

    Pure-heuristic, no LLM. Each intent is a pluggable :class:`IntentDetector`;
    pass ``detectors=[...]`` to extend or override.

    Annotates per row::

        row['intent']                    # primary intent string
        row['user_data']['key_rounds']   # list[int] of assistant indices
        row['user_data']['intents']      # dict[str, str]: str(assistant index) -> intent

    Rows without any key round are dropped when ``drop_no_key_rounds`` is True;
    otherwise they are kept with intent ``INTENT_OTHER`` and no ``user_data`` change.
    """

    DEFAULT_DETECTORS: List[IntentDetector] = [
        ToolCallDetector(),
        CodeDetector(),
        MathDetector(),
        ComplexLogicDetector(),
        UserDissatisfactionDetector(),
    ]

    def __init__(
        self,
        detectors: Optional[List[IntentDetector]] = None,
        intent_field: str = 'intent',
        drop_no_key_rounds: bool = True,
    ) -> None:
        super().__init__()
        self._intent_field = intent_field
        self._drop_no_key_rounds = drop_no_key_rounds
        self._detectors = list(detectors) if detectors is not None else list(self.DEFAULT_DETECTORS)

    def _detect(self, messages: List[Dict[str, Any]]) -> Dict[int, str]:
        """Run detector pipeline; later detectors never override earlier intent on the same round.

        The pipeline stops after the first ``definitive`` detector that matched.
        """
        round_intents: Dict[int, str] = {}
        for det in self._detectors:
            rounds = det(messages)
            if not rounds:
                continue
            for idx in rounds:
                round_intents.setdefault(idx, det.intent)
            if det.definitive:
                break
        return round_intents

    def __call__(self, rows) -> List[Dict[str, Any]]:
        """Return annotated copies of ``rows`` and log the intent distribution."""
        if not rows:
            return rows

        out = []
        for row in rows:
            row = dict(row)  # shallow copy so the input row is not mutated
            messages = row.get('messages')
            round_intents = (
                self._detect(messages) if isinstance(messages, list) and messages else {}
            )

            if round_intents:
                # Most frequent per-round intent; ties resolve to the first one recorded.
                primary = Counter(round_intents.values()).most_common(1)[0][0]
                user_data = dict(row.get('user_data') or {})
                user_data['key_rounds'] = sorted(round_intents)
                user_data['intents'] = {str(k): v for k, v in round_intents.items()}
                row['user_data'] = user_data
            else:
                if self._drop_no_key_rounds:
                    continue
                primary = INTENT_OTHER

            row[self._intent_field] = primary
            out.append(row)

        dist = Counter(r[self._intent_field] for r in out)
        logger.info(f'[IntentClassifier] distribution: {dict(dist)}')
        return out


# ── patch boundary ── (the "diff --git" header of the next file in the patch starts here)
# ── patch boundary ── file: src/twinkle_agentic/preprocessor/llm_backend.py
# (diff --git header: new file mode 100644, index 00000000..106825db, hunk @@ -0,0 +1,341 @@)
# Copyright (c) ModelScope Contributors. All rights reserved.
"""Abstract LLM backend for preprocessor pipeline.

Supports two modes:
  - OpenAIBackend: httpx-based calls to any OpenAI-compatible HTTP server
  - SamplerBackend: direct calls to Twinkle vLLMSampler Ray actor (no HTTP)
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

from twinkle.utils import get_logger

logger = get_logger(only_local_master=False)


def _request_embeddings(client, url: str, model: str, texts: List[str]):
    """POST ``texts`` to an embeddings endpoint and return a float32 array.

    Rows are ordered by each item's ``index`` field.  HTTP errors propagate via
    ``raise_for_status``.
    """
    import numpy as np
    resp = client.post(url, json={'model': model, 'input': texts})
    resp.raise_for_status()
    data = resp.json().get('data', [])
    data_sorted = sorted(data, key=lambda x: x.get('index', 0))
    return np.array([d['embedding'] for d in data_sorted], dtype=np.float32)


class LLMBackend(ABC):
    """Abstract base for LLM inference used by QualityPreprocessor stages."""

    @abstractmethod
    def chat(
        self,
        messages: List[Dict[str, Any]],
        *,
        temperature: float = 0.0,
        max_tokens: int = 16,
        n: int = 1,
    ) -> List[Dict[str, str]]:
        """Chat completion.

        Returns:
            List of n choices, each a dict with keys 'content' and 'reasoning_content'.
        """

    def chat_batch(
        self,
        messages_list: List[List[Dict[str, Any]]],
        *,
        temperature: float = 0.0,
        max_tokens: int = 16,
        n: int = 1,
    ) -> List[List[Dict[str, str]]]:
        """Batched chat completion. Returns one List[choice] per input messages list.

        Default impl loops over `chat`; backends should override to fan out concurrently
        (HTTP) or pass the full list to the underlying sampler in a single call (vLLM DP).
        """
        return [
            self.chat(m, temperature=temperature, max_tokens=max_tokens, n=n)
            for m in messages_list
        ]

    @abstractmethod
    def prompt_logprobs(self, messages: List[Dict[str, Any]]) -> Optional[List]:
        """Evaluate prompt tokens without generation.

        Returns:
            List of per-token logprob entries (format varies by backend but
            is compatible with _extract_logprob helpers), or None on failure.
        """

    @abstractmethod
    def prompt_logprobs_ids(self, input_ids_list: List[List[int]]) -> List[List]:
        """Batched: evaluate raw token-id prompts without chat template wrapping.

        Used for unconditional perplexity (e.g. IFD denominator). Caller MUST
        supply a list of token-id sequences; for distributed backends the list
        length must satisfy backend-specific batching constraints (e.g.
        ``len >= dp_world_size`` for SamplerBackend).
        """

    def embeddings(self, texts: List[str]) -> Any:
        """Compute text embeddings. Override in backends that support it."""
        raise NotImplementedError(f'{type(self).__name__} does not support embeddings')


class OpenAIBackend(LLMBackend):
    """Backend wrapping any OpenAI-compatible HTTP endpoint."""

    def __init__(
        self,
        endpoint: str,
        model: str = 'default',
        api_key: str = '',
        timeout: float = 120.0,
    ):
        """
        Args:
            endpoint: Server base URL; ``/v1/...`` paths are appended to it.
            model: Model name sent with every request.
            api_key: Optional bearer token.
            timeout: Per-request timeout in seconds for the shared httpx client.
        """
        import httpx
        headers = {'Content-Type': 'application/json'}
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'
        self._client = httpx.Client(timeout=timeout, headers=headers)
        base = endpoint.rstrip('/')
        self._chat_endpoint = f'{base}/v1/chat/completions'
        # Built from the base URL directly: deriving it from the chat endpoint
        # with rsplit('/', 2)[0] kept the '/v1' and yielded '.../v1/v1/completions'.
        self._completions_endpoint = f'{base}/v1/completions'
        self._embed_endpoint = f'{base}/v1/embeddings'
        self._model = model

    @property
    def model(self) -> str:
        return self._model

    def chat(
        self,
        messages: List[Dict[str, Any]],
        *,
        temperature: float = 0.0,
        max_tokens: int = 16,
        n: int = 1,
    ) -> List[Dict[str, str]]:
        """Send one chat request; return ``[]`` (and log a warning) on any failure."""
        try:
            resp = self._client.post(self._chat_endpoint, json={
                'model': self._model,
                'messages': messages,
                'temperature': temperature,
                'max_tokens': max_tokens,
                'n': n,
            })
            resp.raise_for_status()
            choices = resp.json().get('choices', [])
            results = []
            for c in choices:
                msg = c.get('message') or {}
                results.append({
                    'content': msg.get('content') or '',
                    'reasoning_content': msg.get('reasoning_content') or '',
                })
            return results
        except Exception as e:
            logger.warning(f'[OpenAIBackend] chat failed: {e}')
            return []

    def chat_batch(
        self,
        messages_list: List[List[Dict[str, Any]]],
        *,
        temperature: float = 0.0,
        max_tokens: int = 16,
        n: int = 1,
        max_workers: int = 16,
    ) -> List[List[Dict[str, str]]]:
        """Concurrent chat: vLLM HTTP server multiplexes requests; httpx.Client is thread-safe."""
        from concurrent.futures import ThreadPoolExecutor
        if not messages_list:
            return []
        workers = max(1, min(max_workers, len(messages_list)))
        results: List[List[Dict[str, str]]] = [[] for _ in messages_list]
        with ThreadPoolExecutor(max_workers=workers) as ex:
            futs = {
                ex.submit(self.chat, m, temperature=temperature, max_tokens=max_tokens, n=n): i
                for i, m in enumerate(messages_list)
            }
            # Results are written back by input position, so output order matches input order.
            for fut, i in futs.items():
                results[i] = fut.result()
        return results

    def prompt_logprobs(self, messages: List[Dict[str, Any]]) -> Optional[List]:
        """Return the response's ``prompt_logprobs`` field, or None on any failure."""
        try:
            resp = self._client.post(self._chat_endpoint, json={
                'model': self._model,
                'messages': messages,
                'max_tokens': 0,
                'prompt_logprobs': 1,
            })
            resp.raise_for_status()
            return resp.json().get('prompt_logprobs')
        except Exception as e:
            # Still best-effort (returns None), but no longer silent.
            logger.warning(f'[OpenAIBackend] prompt_logprobs failed: {e}')
            return None

    def prompt_logprobs_ids(self, input_ids_list: List[List[int]]) -> List[List]:
        """Score each token-id prompt through ``/v1/completions``, one request per prompt.

        HTTP errors propagate; a response without ``prompt_logprobs`` raises ``KeyError``.
        """
        results: List[List] = []
        for input_ids in input_ids_list:
            resp = self._client.post(self._completions_endpoint, json={
                'model': self._model,
                'prompt': list(input_ids),
                'max_tokens': 0,
                'echo': True,
                'prompt_logprobs': 1,
            })
            resp.raise_for_status()
            data = resp.json()
            choices = data.get('choices') or []
            # The field may sit on the first choice or at the top level of the response.
            if choices and 'prompt_logprobs' in choices[0]:
                results.append(choices[0]['prompt_logprobs'])
            else:
                results.append(data['prompt_logprobs'])
        return results

    def embeddings(self, texts: List[str]):
        """Return embeddings for ``texts`` as a float32 array ordered by response index."""
        return _request_embeddings(self._client, self._embed_endpoint, self._model, texts)


class SamplerBackend(LLMBackend):
    """Backend wrapping a Twinkle vLLMSampler (Ray actor, no HTTP overhead)."""

    def __init__(
        self,
        sampler,
        embed_endpoint: str = '',
        embed_model: str = 'bge-m3',
    ):
        """
        Args:
            sampler: A vLLMSampler instance (with template already set).
            embed_endpoint: Optional OpenAI-compatible endpoint for embeddings.
            embed_model: Model name for embeddings.
        """
        self._sampler = sampler
        self._embed_endpoint = embed_endpoint
        self._embed_model = embed_model
        self._embed_client = None
        if embed_endpoint:
            import httpx
            self._embed_client = httpx.Client(timeout=120.0)
            self._embed_url = f'{embed_endpoint.rstrip("/")}/v1/embeddings'

    @staticmethod
    def _split_think(text: str) -> Tuple[str, str]:
        """Split decoded text into ``(content, reasoning)``.

        When the text contains the closing think tag, everything after its
        first occurrence is the content and whatever follows the last opening
        tag before it is the reasoning (both stripped).  Otherwise the text is
        returned unchanged with empty reasoning.

        NOTE(review): the tag literals were empty strings, which made
        ``text.split('', 1)`` raise ``ValueError`` on every call.  The
        <think>/</think> delimiters are assumed; confirm against the template.
        """
        if '</think>' in text:
            head, tail = text.split('</think>', 1)
            return tail.strip(), head.split('<think>')[-1].strip()
        return text, ''

    def chat(
        self,
        messages: List[Dict[str, Any]],
        *,
        temperature: float = 0.0,
        max_tokens: int = 16,
        n: int = 1,
    ) -> List[Dict[str, str]]:
        """Sample ``n`` completions for one conversation; return ``[]`` on failure."""
        from twinkle.data_format import SamplingParams
        trajectory = {'messages': messages}
        params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            num_samples=n,
        )
        try:
            responses = self._sampler.sample(trajectory, params)
            results = []
            for resp in responses:
                for seq in resp.sequences:
                    text, reasoning = self._split_think(seq.decoded or '')
                    results.append({'content': text, 'reasoning_content': reasoning})
            return results
        except Exception as e:
            logger.warning(f'[SamplerBackend] chat failed: {e}')
            return []

    def chat_batch(
        self,
        messages_list: List[List[Dict[str, Any]]],
        *,
        temperature: float = 0.0,
        max_tokens: int = 16,
        n: int = 1,
    ) -> List[List[Dict[str, str]]]:
        """One sampler dispatch over the full list; lets vLLM DP workers stay saturated."""
        from twinkle.data_format import SamplingParams
        if not messages_list:
            return []
        device_mesh = getattr(self._sampler, 'device_mesh', None)
        dp_world_size = getattr(device_mesh, 'dp_world_size', 1) or 1
        n_inputs = len(messages_list)
        feats = [{'messages': m} for m in messages_list]
        # Pad the dispatch so every DP worker has at least one item; trim duplicates after.
        if n_inputs < dp_world_size:
            feats = feats + [feats[-1]] * (dp_world_size - n_inputs)
        params = SamplingParams(temperature=temperature, max_tokens=max_tokens, num_samples=n)
        try:
            responses = self._sampler.sample(feats, params)
        except Exception as e:
            logger.warning(f'[SamplerBackend] chat_batch failed: {e}')
            return [[] for _ in range(n_inputs)]
        responses = list(responses)[:n_inputs]
        out: List[List[Dict[str, str]]] = []
        for resp in responses:
            choices: List[Dict[str, str]] = []
            for seq in (getattr(resp, 'sequences', None) or []):
                text, reasoning = self._split_think(seq.decoded or '')
                choices.append({'content': text, 'reasoning_content': reasoning})
            out.append(choices)
        # Keep the one-result-per-input contract even if the sampler returned fewer items.
        while len(out) < n_inputs:
            out.append([])
        return out

    def prompt_logprobs(self, messages: List[Dict[str, Any]]) -> Optional[List]:
        """Return the first response's prompt logprobs, or None when missing or on failure."""
        from twinkle.data_format import SamplingParams
        trajectory = {'messages': messages}
        params = SamplingParams(max_tokens=0, prompt_logprobs=1)
        try:
            responses = self._sampler.sample(trajectory, params)
            if responses and responses[0].prompt_logprobs is not None:
                return responses[0].prompt_logprobs
            return None
        except Exception as e:
            logger.warning(f'[SamplerBackend] prompt_logprobs failed: {e}')
            return None

    def prompt_logprobs_ids(self, input_ids_list: List[List[int]]) -> List[List]:
        """Score raw token-id prompts in a single sampler dispatch.

        Raises:
            ValueError: If the input is not a non-empty list or holds fewer
                items than the sampler's ``dp_world_size``.
        """
        from twinkle.data_format import SamplingParams
        if not isinstance(input_ids_list, list) or not input_ids_list:
            raise ValueError('prompt_logprobs_ids requires a non-empty List[List[int]].')
        device_mesh = getattr(self._sampler, 'device_mesh', None)
        dp_world_size = getattr(device_mesh, 'dp_world_size', 1) or 1
        if len(input_ids_list) < dp_world_size:
            raise ValueError(
                f'SamplerBackend.prompt_logprobs_ids requires at least '
                f'dp_world_size={dp_world_size} inputs to keep all DP workers busy, '
                f'got {len(input_ids_list)}. Batch upstream before calling.')
        feats = [{'input_ids': list(ids)} for ids in input_ids_list]
        params = SamplingParams(max_tokens=0, prompt_logprobs=1)
        responses = self._sampler.sample(feats, params)
        return [r.prompt_logprobs for r in responses]

    def embeddings(self, texts: List[str]):
        """Return embeddings from the configured ``embed_endpoint``.

        Raises:
            NotImplementedError: If no ``embed_endpoint`` was given at construction.
        """
        if self._embed_client is None:
            raise NotImplementedError(
                'SamplerBackend requires embed_endpoint for embeddings. '
                'Pass embed_endpoint when constructing SamplerBackend.')
        return _request_embeddings(self._embed_client, self._embed_url, self._embed_model, texts)


# ── patch boundary ── next file: src/twinkle_agentic/preprocessor/majority_vote.py
# (diff --git header: new file mode 100644, index 00000000..c0c5f583, hunk @@ -0,0 +1,155 @@)
# Copyright (c) ModelScope Contributors. All rights reserved.
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional

from twinkle.preprocessor import Preprocessor

from .llm_backend import LLMBackend, OpenAIBackend

# Default judge instruction: asks for a one-word PASS / FAIL verdict.
_DEFAULT_SYSTEM_PROMPT = (
    'You are a strict trajectory quality judge. '
    'Given a multi-turn conversation, decide whether the assistant response is high-quality. '
    'Criteria: factual accuracy, helpfulness, coherence, and completeness. '
    'Reply with EXACTLY one word: PASS or FAIL.'
)
def _vote_one(
    source: 'JudgeSource',
    judge_messages: List[Dict[str, Any]],
    temperature: float,
) -> Optional[bool]:
    """Send one judge request and parse its verdict.

    Args:
        source: Judge whose ``backend.chat`` is called.
        judge_messages: Chat messages forming the judge prompt.
        temperature: Sampling temperature forwarded to the backend.

    Returns:
        ``True`` for PASS, ``False`` for FAIL, and ``None`` when the request
        raised, no choice came back, or the reply contains neither keyword.
        When both keywords occur, the one appearing first decides, so a reply
        such as ``"FAIL - does not pass"`` is not counted as PASS.
    """
    import logging

    try:
        choices = source.backend.chat(judge_messages, temperature=temperature, max_tokens=16)
    except Exception as exc:
        # A broken judge abstains instead of sinking the other judges' votes.
        logging.getLogger(__name__).warning('[MajorityVote] judge request failed: %s', exc)
        return None
    if not choices:
        return None

    first = choices[0]
    raw = first.get('content') if isinstance(first, dict) else None
    # ``content`` may legitimately be None (e.g. reasoning-only replies).
    text = (raw if isinstance(raw, str) else '').strip().upper()

    pass_at = text.find('PASS')
    fail_at = text.find('FAIL')
    if pass_at < 0 and fail_at < 0:
        return None
    if fail_at < 0:
        return True
    if pass_at < 0:
        return False
    return pass_at < fail_at
def _judge_row(self, messages: List[Dict[str, Any]]) -> Optional[bool]:
    """Collect votes from every judge source for one row.

    Each source is queried concurrently. A judge that raises or returns
    ``None`` abstains; it no longer invalidates the votes of the others.

    Args:
        messages: The row's conversation.

    Returns:
        ``True`` when the fraction of PASS votes is strictly greater than
        ``pass_threshold``, ``False`` otherwise, and ``None`` when no judge
        produced a usable vote.
    """
    judge_msgs = _build_judge_messages(messages, self._system_prompt)

    votes: List[bool] = []
    with ThreadPoolExecutor(max_workers=len(self._sources)) as pool:
        futures = [
            pool.submit(_vote_one, src, judge_msgs, self._temperature)
            for src in self._sources
        ]
        for f in as_completed(futures):
            try:
                result = f.result()
            except Exception:
                # One unreachable judge must not discard the remaining votes.
                continue
            if result is not None:
                votes.append(result)

    if not votes:
        return None
    return sum(votes) / len(votes) > self._pass_threshold
def _load_sensitive_words(path: Optional[str]) -> Set[str]:
    """Load sensitive words from an external file (one word per line).

    Blank lines and lines starting with ``#`` are ignored; surrounding
    whitespace is stripped from every word.

    Args:
        path: File to read. ``None``, an empty string, or a path that is not
            an existing file yields an empty set.

    Returns:
        The set of words found in the file.
    """
    if not path or not os.path.isfile(path):
        return set()
    # 'utf-8-sig' drops a leading BOM; with plain 'utf-8' the BOM would stick
    # to the first word and that word would never match anything.
    with open(path, 'r', encoding='utf-8-sig') as f:
        stripped = (line.strip() for line in f)
        return {word for word in stripped if word and not word.startswith('#')}
def _msg_content_text(msg: Dict[str, Any]) -> str:
    """Extract plain text from a message's content (str | list | dict).

    * ``str`` content is returned unchanged.
    * ``list`` content: the ``text`` of every ``{'type': 'text'}`` part is
      joined with single spaces; other parts are skipped.
    * ``dict`` content: its ``text`` when ``type == 'text'``.

    A missing, ``None`` or non-string ``text`` counts as an empty string, so
    the function always returns ``str`` and never raises on such parts.
    """
    def _text_of(part: Dict[str, Any]) -> str:
        value = part.get('text')
        return value if isinstance(value, str) else ''

    c = msg.get('content')
    if isinstance(c, str):
        return c
    if isinstance(c, list):
        return ' '.join(
            _text_of(p) for p in c
            if isinstance(p, dict) and p.get('type') == 'text'
        )
    if isinstance(c, dict) and c.get('type') == 'text':
        return _text_of(c)
    return ''
def _validate_role_order(messages: List[Dict[str, Any]], is_agent: bool = False) -> bool:
    """Check that message roles follow a sane conversational order.

    Strict rules (default):
      - Every message is a dict with a valid role.
      - system (if present) must be at index 0.
      - The first non-system message must be ``user``.
      - Every ``assistant`` has at least one ``user`` somewhere before it.
      - tool messages immediately follow an assistant with ``tool_calls`` (or a
        preceding tool, for parallel calls). ``tool_calls`` is read through
        ``_normalize_tool_calls``, so the JSON-string form is decoded and an
        empty or malformed value does not count as a tool call.

    Agent rules (``is_agent=True``, e.g. Cline / OpenClaw text-based tool calls):
      - tool messages may follow any role as long as some assistant exists
        earlier in the conversation (the structured ``tool_calls`` field is
        absent because the call is encoded inside assistant text).

    Returns:
        ``True`` when the order is acceptable; ``False`` otherwise, including
        for an empty message list.
    """
    if not messages:
        return False

    seen_user = False
    seen_assistant = False
    saw_first_non_system = False
    for i, m in enumerate(messages):
        if not isinstance(m, dict):
            return False
        role = m.get('role')
        if role not in _VALID_ROLES:
            return False
        if role == 'system':
            if i != 0:
                return False
            continue
        if not saw_first_non_system:
            if role != 'user':
                return False
            saw_first_non_system = True
        if role == 'user':
            seen_user = True
        elif role == 'assistant':
            if not seen_user:
                return False
            seen_assistant = True
        elif role == 'tool':
            if is_agent:
                if not seen_assistant:
                    return False
            else:
                # i >= 1 here: the first non-system message had to be 'user'.
                prev = messages[i - 1]
                if not isinstance(prev, dict):
                    return False
                prev_role = prev.get('role')
                if prev_role not in ('assistant', 'tool'):
                    return False
                # Raw truthiness would accept the string '[]' or broken JSON.
                if prev_role == 'assistant' and not _normalize_tool_calls(prev):
                    return False
    return True
def _validate_tool_call_matching(messages: List[Dict[str, Any]]) -> bool:
    """Check that tool replies reference the tool calls they follow.

    For every assistant message whose tool calls carry at least one ``id``,
    the run of ``tool`` messages directly after it is inspected:

    * at least one of them must carry a ``tool_call_id``, and
    * every ``tool_call_id`` seen must be one of the assistant's call ids.

    Calls that receive no reply are tolerated. Assistants without tool calls,
    or whose calls have no ids, are skipped, as are non-dict entries.
    """
    total = len(messages)
    pos = 0
    while pos < total:
        msg = messages[pos]
        pos += 1
        if not isinstance(msg, dict) or msg.get('role') != 'assistant':
            continue
        calls = _normalize_tool_calls(msg)
        if not calls:
            continue

        expected = set()
        for call in calls:
            if isinstance(call, dict) and call.get('id'):
                expected.add(call['id'])
        if not expected:
            continue

        # Consume the contiguous block of tool replies after this assistant.
        answered = set()
        while pos < total:
            follower = messages[pos]
            if not isinstance(follower, dict) or follower.get('role') != 'tool':
                break
            call_id = follower.get('tool_call_id')
            if call_id:
                answered.add(call_id)
            pos += 1

        if not answered or not answered.issubset(expected):
            return False
    return True
class MessageSanityFilter(Preprocessor):
    """Structural and content sanity filter for messages-format datasets.

    Each row is processed in this order; a row failing an enabled check is
    dropped:

    0. System messages are consolidated into one block at index 0 (always on).
    1. Role order validation (``check_role_order``).
    2. tool_call_id matching (``check_tool_matching``; skipped for rows whose
       ``is_agent`` field is truthy).
    3. Trim to last assistant (``trim_to_assistant``; discard if no assistant
       remains).
    4. Content integrity (``check_content_integrity``), run on the trimmed
       conversation with ``min_turns`` / ``max_msg_chars``.
    5. Sensitive word filtering (``filter_sensitive``; discard row if any
       message text matches).

    Sensitive words source:
    - ``sensitive_words_file``: external text file (one word per line, # for comments)
    - ``extra_sensitive_words``: additional words merged programmatically
    """

    def __init__(
        self,
        check_role_order: bool = True,
        check_tool_matching: bool = True,
        check_content_integrity: bool = True,
        trim_to_assistant: bool = True,
        filter_sensitive: bool = True,
        sensitive_words_file: Optional[str] = None,
        extra_sensitive_words: Optional[List[str]] = None,
        min_turns: int = 2,
        max_msg_chars: int = 50000,
    ) -> None:
        """Store the check switches and compile the sensitive-word regex."""
        super().__init__()
        self.check_role_order = check_role_order
        self.check_tool_matching = check_tool_matching
        self.check_content_integrity = check_content_integrity
        self.trim_to_assistant = trim_to_assistant
        self.filter_sensitive = filter_sensitive
        self._min_turns = min_turns
        self._max_msg_chars = max_msg_chars

        # Build unified sensitive word set: the file (when given) replaces the
        # built-in defaults, then the extra words are merged on top.
        if sensitive_words_file:
            all_words = _load_sensitive_words(sensitive_words_file)
        else:
            all_words = set(_DEFAULT_SENSITIVE)
        if extra_sensitive_words:
            all_words.update(w.strip() for w in extra_sensitive_words if w and w.strip())
        # None when the word set is empty, which disables step 5 below.
        self._sensitive_re = _build_sensitive_regex(all_words)

    def __call__(self, rows) -> List[Dict[str, Any]]:
        """Return the rows that pass every enabled check.

        Rows whose ``messages`` is missing, empty or not a list are dropped.
        A kept row is a shallow copy of the input row whenever its messages
        were consolidated or trimming is enabled.
        """
        out = []
        for row in rows:
            messages = row.get('messages')
            if not isinstance(messages, list) or not messages:
                continue
            is_agent = bool(row.get('is_agent'))

            # Step 0: fold all system blocks into one at index 0
            normalized = _consolidate_system_messages(messages)
            # Identity check: the helper returns the same list when unchanged.
            if normalized is not messages:
                messages = normalized
                row = dict(row, messages=messages)

            # Step 1: role order check
            if self.check_role_order and not _validate_role_order(messages, is_agent=is_agent):
                continue

            # Step 1.5: tool_call_id matching (skip for agent rows: text-based tool calls have no IDs)
            if self.check_tool_matching and not is_agent and not _validate_tool_call_matching(messages):
                continue

            # Step 2: trim to last assistant
            if self.trim_to_assistant:
                messages = _trim_to_last_assistant(messages)
                if not messages:
                    continue
                row = dict(row, messages=messages)

            # Step 2.5: content integrity (after trim so we validate the final sample)
            if self.check_content_integrity and not _validate_content_integrity(
                messages,
                min_turns=self._min_turns,
                max_msg_chars=self._max_msg_chars,
            ):
                continue

            # Step 3: sensitive word check
            if self.filter_sensitive and self._sensitive_re:
                has_bad = False
                for m in messages:
                    text = _msg_content_text(m)
                    if self._sensitive_re.search(text):
                        has_bad = True
                        break
                if has_bad:
                    continue

            out.append(row)
        return out
def _extract_logprob(lp) -> Optional[float]:
    """Extract a scalar log-prob from one ``prompt_logprobs`` element.

    Accepted shapes:
      * a bare number;
      * ``{token_id: number}``;
      * ``{token_id: {"logprob": number, ...}}`` (vLLM after a JSON round-trip).

    For dict elements only the first entry is read.

    Returns:
        The log-prob as ``float``, or ``None`` for ``None``, an empty dict, or
        any entry without a numeric log-prob. Malformed entries no longer raise
        ``KeyError`` / ``TypeError``.
    """
    if isinstance(lp, dict):
        lp = next(iter(lp.values()), None)
        if isinstance(lp, dict):
            # .get(): a missing 'logprob' means "no score", not a crash.
            lp = lp.get('logprob')
    if isinstance(lp, (int, float)):
        return float(lp)
    return None
class PerplexityFilter(Preprocessor):
    """Filter dataset rows by model perplexity on the assistant response.

    Prompt log-probs are requested from an ``LLMBackend``; when none is given
    an ``OpenAIBackend`` is built from ``api_endpoint`` / ``model``.

    ppl_min / ppl_max define the keep window:
      - Too low  → trivially memorized / degenerate output.
      - Too high → out-of-distribution, garbled, or badly formatted.

    Rows that cannot be scored (no assistant turn, response too short,
    backend failure or empty result) are kept.

    Requirement: tokenizer_name_or_path must match the model served by the
    backend, because the prompt/response boundary is computed locally.
    """

    def __init__(
        self,
        backend: LLMBackend = None,
        tokenizer_name_or_path: str = '',
        ppl_min: float = 2.0,
        ppl_max: float = 100.0,
        max_workers: int = 8,
        # Legacy params
        api_endpoint: str = '',
        model: str = 'default',
    ):
        """
        Args:
            backend: Scoring backend; takes precedence over the legacy params.
            tokenizer_name_or_path: Passed to ``AutoTokenizer.from_pretrained``.
            ppl_min: Lower bound (inclusive) of the keep window.
            ppl_max: Upper bound (inclusive) of the keep window.
            max_workers: Upper bound on concurrent scoring threads.
            api_endpoint: Legacy endpoint used when ``backend`` is ``None``.
            model: Legacy model name used when ``backend`` is ``None``.
        """
        from transformers import AutoTokenizer

        if backend is not None:
            self._backend = backend
        else:
            self._backend = OpenAIBackend(endpoint=api_endpoint, model=model)
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        self.ppl_min = ppl_min
        self.ppl_max = ppl_max
        self._max_workers = max_workers

    def __call__(self, rows) -> List[Dict[str, Any]]:
        """Score rows in parallel; keep rows with PPL in [ppl_min, ppl_max]."""
        scoreable: List[Tuple[int, List[Dict[str, Any]], int]] = []  # (row_idx, messages, n_prompt)
        for i, row in enumerate(rows):
            messages = row.get('messages') or []
            result = _encode_pair(self._tokenizer, messages)
            if result is not None:
                scoreable.append((i, result[0], result[1]))

        if not scoreable:
            return rows

        drop: set = set()
        n_workers = min(self._max_workers, len(scoreable))
        with ThreadPoolExecutor(max_workers=n_workers) as pool:
            future_to_meta = {
                pool.submit(_score_one, self._backend, messages): (row_idx, n_prompt)
                for row_idx, messages, n_prompt in scoreable
            }
            for future in as_completed(future_to_meta):
                row_idx, n_prompt = future_to_meta[future]
                try:
                    prompt_logprobs = future.result()
                except Exception:
                    continue
                if not prompt_logprobs:
                    # Backends may report failure as None / empty instead of
                    # raising; slicing None would abort the whole batch.
                    continue
                ppl = _ppl_from_logprobs(prompt_logprobs, n_prompt)
                if ppl is not None and not (self.ppl_min <= ppl <= self.ppl_max):
                    drop.add(row_idx)

        return [row for i, row in enumerate(rows) if i not in drop]
def _is_valid_luhn(s: str) -> bool:
    """Return True when the digits of ``s`` pass the Luhn checksum.

    Non-digit characters (spaces, dashes, ...) are ignored. Fewer than 13
    digits is never valid.

    Digits are selected with ``str.isdecimal`` rather than ``str.isdigit``:
    ``isdigit`` also accepts characters such as superscript two that ``int``
    cannot convert, which used to raise ``ValueError``.
    """
    digits = [int(c) for c in s if c.isdecimal()]
    if len(digits) < 13:
        return False
    checksum = 0
    # Walk from the rightmost digit; every second digit is doubled and, when
    # the result exceeds 9, reduced by 9 (same as summing its two digits).
    for i, d in enumerate(reversed(digits)):
        if i % 2 == 1:
            d *= 2
            if d > 9:
                d -= 9
        checksum += d
    return checksum % 10 == 0
def _mask_keep_edges(s: str, head: int = 3, tail: int = 4, ch: str = '*') -> str:
    """Mask the middle of ``s`` while keeping ``head`` leading and ``tail`` trailing characters.

    Args:
        s: Value to mask.
        head: Number of leading characters kept (negative values count as 0).
        tail: Number of trailing characters kept (negative values count as 0).
        ch: Mask character.

    Returns:
        The masked string, always as long as ``s``. When ``s`` is not longer
        than ``head + tail`` the whole value is masked.
    """
    head = max(0, head)
    tail = max(0, tail)
    n = len(s)
    if n <= head + tail:
        return ch * n
    # Slice from an explicit index: ``s[-tail:]`` with tail == 0 is ``s[0:]``,
    # which would append the entire unmasked value.
    return s[:head] + ch * (n - head - tail) + s[n - tail:]
def _cn_recognizer_classes():
    """Import presidio lazily and build the checksum-validating CN recognizers.

    ``presidio_analyzer`` is imported inside the function so this module can
    be imported without it. The two subclasses are defined here because they
    need ``PatternRecognizer`` as a base class.

    NOTE(review): the classes are re-created on every call; nothing in this
    function caches them.

    Returns:
        ``(Pattern, PatternRecognizer, CNIDRecognizer, CNBankRecognizer)``.
    """
    from presidio_analyzer import Pattern, PatternRecognizer

    class CNIDRecognizer(PatternRecognizer):
        # Accept a regex hit only if the 18-character ID checksum is valid.
        def validate_result(self, pattern_text: str) -> bool:
            return _is_valid_cn_id(pattern_text)

    class CNBankRecognizer(PatternRecognizer):
        # Accept a regex hit only if the number passes the Luhn check.
        def validate_result(self, pattern_text: str) -> bool:
            return _is_valid_luhn(pattern_text)

    return Pattern, PatternRecognizer, CNIDRecognizer, CNBankRecognizer
@classmethod
def _require_deps(cls) -> None:
    """Fail early when an optional dependency is missing.

    Raises:
        ImportError: The original import error with ``cls.INSTALL_HINT``
            appended to its message, chained to the original exception.
    """
    try:
        # Imported only to probe availability; the names are not used here.
        import presidio_analyzer  # noqa: F401
        import presidio_anonymizer  # noqa: F401
        import faker  # noqa: F401
        import spacy  # noqa: F401
    except ImportError as e:
        raise ImportError(f'{e}. {cls.INSTALL_HINT}') from e
def _resolve_language(self, text: str) -> str:
    """Pick the analysis language for ``text``.

    The share of characters in U+4E00..U+9FFF decides between ``'zh'``
    (share strictly above ``CJK_LANG_THRESHOLD``) and ``'en'``. If the guess
    is not among the configured languages, the first configured language is
    returned instead.
    """
    han_count = 0
    for ch in text:
        if '\u4e00' <= ch <= '\u9fff':
            han_count += 1

    # max(1, ...) keeps the empty string from dividing by zero.
    ratio = han_count / max(1, len(text))
    if ratio > self.CJK_LANG_THRESHOLD:
        candidate = 'zh'
    else:
        candidate = 'en'

    if candidate in self._languages:
        return candidate
    return self._languages[0]
@staticmethod
def _dedupe_overlaps(results: List[Any]) -> List[Any]:
    """Resolve overlapping spans greedily, best span first.

    Spans are visited by descending score, then descending length, then
    ascending start; a span is kept only if it overlaps none of the spans
    already kept. The result is in that visiting order.
    """
    def priority(span):
        # start - end is the negated length, so longer spans sort earlier.
        return (-span.score, span.start - span.end, span.start)

    accepted: List[Any] = []
    for cand in sorted(results, key=priority):
        clash = False
        for winner in accepted:
            if cand.start < winner.end and winner.start < cand.end:
                clash = True
                break
        if not clash:
            accepted.append(cand)
    return accepted
def _scrub_row(
    self, row: Dict[str, Any], local_map: Dict[Tuple[str, str], str],
) -> Dict[str, int]:
    """Scrub every eligible message of ``row`` in place and tally the hits.

    A message is eligible when it is a dict, its role is in ``self._roles``
    and its ``content`` is a non-empty string. Its content is overwritten
    only when ``_scrub_text`` reported at least one hit.

    Returns:
        Hit counts per entity type, summed over the row's messages.
    """
    totals: Dict[str, int] = {}
    for message in (row.get('messages') or []):
        if not isinstance(message, dict):
            continue
        if message.get('role') not in self._roles:
            continue
        original = message.get('content')
        if not (isinstance(original, str) and original):
            continue

        scrubbed, found = self._scrub_text(original, local_map)
        if not found:
            continue
        message['content'] = scrubbed
        for entity, count in found.items():
            totals[entity] = totals.get(entity, 0) + count
    return totals
# Core: I/we + modal inability + task verb
# Shape: subject (i|we), a gap of at most 25 characters, an inability/refusal
# phrase, a gap of at most 60 characters, then a task verb. DOTALL lets both
# gaps run across newlines.
_EN_CORE = re.compile(
    r'\b(i|we)\b.{0,25}\b('
    r"can'?t|cannot|am\s+not\s+able|are\s+not\s+able|"
    r"won'?t|will\s+not|am\s+unable|are\s+unable|"
    r"must\s+decline|have\s+to\s+decline|"
    r"decline\s+to|refuse\s+to|"
    r"am\s+not\s+(allowed|permitted|authorized|comfortable)\s+to|"
    r"are\s+not\s+(allowed|permitted|authorized)"
    r")\b.{0,60}\b("
    r'help|assist|answer|respond|provide|generate|create|produce|'
    r'fulfill|comply|address|process|complete|handle|discuss|support'
    r')\b',
    re.IGNORECASE | re.DOTALL,
)

# Apology opener + refusal: "I'm sorry, but I can't…" / "Unfortunately I cannot…"
# The inability phrase must start within 80 characters after the opener; no
# task verb is required for this pattern.
_EN_APOLOGY = re.compile(
    r'\b(i\'?m\s+sorry|i\s+apologize|unfortunately|i\s+regret)\b.{0,80}'
    r'\b(can\'?t|cannot|unable|won\'?t|will\s+not|must\s+decline|have\s+to\s+decline|'
    r'not\s+(allowed|able|comfortable|appropriate))\b',
    re.IGNORECASE | re.DOTALL,
)

# Policy / content violation signal
# A determiner + request noun, then within 60 characters a violation predicate.
_EN_POLICY = re.compile(
    r'\b(this|that|your|the)\s+(request|question|prompt|content|topic|task)\b.{0,60}'
    r'\b(violates?|goes?\s+against|is\s+(inappropriate|not\s+(appropriate|allowed|permitted|'
    r'something\s+i\s+can)))\b',
    re.IGNORECASE | re.DOTALL,
)

# Standalone declarative refusal phrases
# Four alternatives joined by `|`. Alternation has the lowest precedence, so the
# trailing `.{0,40}` + task-verb group belongs to the last ("as an AI …")
# alternative only; the first three match on their own. DOTALL is not set
# here, so that 40-character gap cannot cross a newline.
_EN_STANDALONE = re.compile(
    r'\b(i|we)\s+(must|have\s+to|am\s+going\s+to|need\s+to)\s+(decline|refuse)\b|'
    r'\b(i|we)\s+(decline|refuse)\s+(this|your|to)\b|'
    r'\bthis\s+(falls\s+outside|is\s+outside|is\s+beyond)\s+(what\s+i|my)\b|'
    r'\bas\s+an\s+ai[,.]?\s+i\s+(can\'?t|cannot|am\s+not\s+able|won\'?t)\b'
    r'.{0,40}\b(help|assist|answer|respond|provide|generate|create|fulfill|comply|'
    r'address|process|complete|handle|discuss|support)\b',
    re.IGNORECASE,
)

# All English patterns, in the order they are tried.
_EN_PATTERNS = (_EN_CORE, _EN_APOLOGY, _EN_POLICY, _EN_STANDALONE)
def _is_refusal(text: str, check_window: int = 600) -> bool:
    """Report whether the head of ``text`` matches any refusal pattern.

    Args:
        text: Reply text to inspect.
        check_window: Only the first ``check_window`` characters are searched.

    Returns:
        True as soon as one pattern in ``_ALL_PATTERNS`` matches, else False.
    """
    head = text[:check_window]
    for pattern in _ALL_PATTERNS:
        if pattern.search(head):
            return True
    return False
class RefuseFilter(Preprocessor):
    """Drop rows whose first assistant reply is a refusal or statement of inability.

    Only the first ``check_window`` characters of the reply, after reasoning
    blocks have been removed, are matched against the module's refusal
    patterns via ``_is_refusal``.
    """

    # Reasoning wrapped in ``<think>…</think>`` is removed so that only the
    # visible answer is checked. The tag delimiters are essential: a bare
    # ``.*?\s*`` can only ever match whitespace runs, which would delete every
    # space from the reply and stop the ``\b`` / ``\s+`` based patterns from
    # matching at all.
    _REASONING_BLOCK = re.compile(r'<think>.*?</think>\s*', re.DOTALL)

    def __init__(self, check_window: int = 600) -> None:
        """Store the inspection window.

        Args:
            check_window: Number of leading characters of the reply to inspect.
        """
        super().__init__()
        self._check_window = check_window

    def __call__(self, rows) -> List[Dict[str, Any]]:
        """Return the rows whose first assistant reply is not a refusal.

        A row is kept unchanged when it has no assistant message, when the
        first assistant ``content`` is not a string, or when nothing is left
        after reasoning blocks are stripped.
        """
        kept = []
        for row in rows:
            first = next(
                (m for m in (row.get('messages') or [])
                 if isinstance(m, dict) and m.get('role') == 'assistant'),
                None,
            )
            if first is None:
                kept.append(row)
                continue
            content = first.get('content') or ''
            if not isinstance(content, str):
                # Structured (non-text) content cannot be pattern-matched;
                # keep the row rather than fail on ``.strip()``.
                kept.append(row)
                continue
            response = self._REASONING_BLOCK.sub('', content.strip()).strip()
            if not response or not _is_refusal(response, self._check_window):
                kept.append(row)
        return kept
Structure: Use clear organization (numbered steps, code blocks, formulas) when appropriate. +5. Length: Response length should be proportional to question complexity — \ +short questions get short answers, complex ones get detailed answers. + +Output format: +- Return ONLY the assistant's response content. Do not include any meta-commentary.\ +""" + +_INTENT_PROMPT_SUFFIX = { + 'code': ( + '\nFocus: This round is about CODE. ' + 'Ensure the code is correct, complete, runnable, and well-commented. ' + 'Fix any bugs in the original. Use proper formatting with language-tagged fenced blocks.' + ), + 'math': ( + '\nFocus: This round is about MATH. ' + 'Show derivation steps clearly with proper LaTeX notation. ' + 'Verify the final answer by substitution or sanity check.' + ), + 'complex_logic': ( + '\nFocus: This round requires COMPLEX REASONING. ' + 'Present a clean logical chain without backtracking. ' + 'Number each reasoning step. State assumptions explicitly.' + ), + 'user_dissatisfaction': ( + '\nFocus: The user was DISSATISFIED with the previous response. ' + 'Address the root cause of dissatisfaction directly. ' + 'Acknowledge the issue and provide a substantially improved answer.' + ), + 'tool_call': ( + '\nFocus: This round involves TOOL CALLS. ' + 'Ensure tool call arguments are correct and the synthesis of tool results is accurate. ' + 'Present the final answer clearly based on tool outputs.' 
+ ), +} + + +def _call_model( + backend: LLMBackend, + context_messages: List[Dict[str, Any]], + temperature: float, + max_tokens: int, + intent: str = '', +) -> Optional[Dict[str, str]]: + """Call the model and return {'content': ..., 'reasoning_content': ...}.""" + system_prompt = _REFINE_SYSTEM_PROMPT + _INTENT_PROMPT_SUFFIX.get(intent, '') + messages = [{'role': 'system', 'content': system_prompt}] + context_messages + + choices = backend.chat(messages, temperature=temperature, max_tokens=max_tokens) + if not choices: + return None + + content = choices[0].get('content') or '' + reasoning = choices[0].get('reasoning_content') or '' + + if not content.strip(): + return None + + return {'content': content, 'reasoning_content': reasoning} + + +def _refine_round( + backend: LLMBackend, + messages: List[Dict[str, Any]], + assistant_idx: int, + temperature: float, + max_tokens: int, + intent: str = '', +) -> Optional[Dict[str, str]]: + """Refine a single key round's assistant response.""" + if assistant_idx >= len(messages) or assistant_idx < 1: + return None + + asst_msg = messages[assistant_idx] + if not isinstance(asst_msg, dict) or asst_msg.get('role') != 'assistant': + return None + + context = messages[:assistant_idx] + if not context: + return None + + return _call_model(backend, context, temperature, max_tokens, intent) + + +class ResponseRefiner(Preprocessor): + """Re-annotate key rounds with a strong model for highest quality responses. + + For each key round (from IntentClassifier/IFDFilter), sends the context + to an OpenAI-compatible API and replaces the assistant response with a + refined version containing both reasoning_content and content. + + Rows without key_rounds are discarded. + If refinement fails for a round, the original response is kept. 
+ """ + + def __init__( + self, + backend: LLMBackend = None, + temperature: float = 0.6, + max_tokens: int = 4096, + max_workers: int = 8, + # Legacy params (used to create OpenAIBackend if backend is None) + api_endpoint: str = '', + model: str = 'default', + api_key: str = '', + ): + super().__init__() + if backend is not None: + self._backend = backend + else: + self._backend = OpenAIBackend( + endpoint=api_endpoint, model=model, api_key=api_key, timeout=180.0) + self._temperature = temperature + self._max_tokens = max_tokens + self._max_workers = max_workers + + def __call__(self, rows) -> List[Dict[str, Any]]: + """Refine key round responses in parallel.""" + if not rows: + return rows + + # Collect tasks: (row_idx, round_idx, assistant_idx, messages, intent) + tasks: List[Tuple[int, int, int, List[Dict[str, Any]], str]] = [] + for ri, row in enumerate(rows): + user_data = row.get('user_data') + if not isinstance(user_data, dict): + continue + key_rounds = user_data.get('key_rounds') + if not isinstance(key_rounds, list) or not key_rounds: + continue + messages = row.get('messages') or [] + intents = user_data.get('intents') or {} + for rnd_idx, asst_idx in enumerate(key_rounds): + tasks.append((ri, rnd_idx, asst_idx, messages, intents.get(asst_idx, ''))) + + if not tasks: + # No key rounds anywhere → drop all + logger.info('[ResponseRefiner] no key rounds found, dropping all rows') + return [] + + # Parallel refinement + results: Dict[Tuple[int, int], Optional[Dict[str, str]]] = {} + n_workers = min(self._max_workers, len(tasks)) + with ThreadPoolExecutor(max_workers=n_workers) as pool: + future_to_key = { + pool.submit( + _refine_round, self._backend, + msgs, asst_idx, self._temperature, self._max_tokens, intent, + ): (ri, rnd_idx) + for ri, rnd_idx, asst_idx, msgs, intent in tasks + } + for future in as_completed(future_to_key): + key = future_to_key[future] + try: + results[key] = future.result() + except Exception as e: + 
logger.warning(f'[ResponseRefiner] round {key} failed: {e}') + results[key] = None + + # Apply refinements + out = [] + n_refined = 0 + n_dropped = 0 + + for ri, row in enumerate(rows): + user_data = row.get('user_data') + if not isinstance(user_data, dict): + n_dropped += 1 + continue + key_rounds = user_data.get('key_rounds') + if not isinstance(key_rounds, list) or not key_rounds: + n_dropped += 1 + continue + + messages = list(row.get('messages') or []) + modified = False + + for rnd_idx, asst_idx in enumerate(key_rounds): + result = results.get((ri, rnd_idx)) + if result is None: + continue + + if asst_idx >= len(messages): + continue + + # Replace assistant content + old_msg = messages[asst_idx] + new_msg = dict(old_msg) + new_msg['content'] = result['content'] + if result['reasoning_content']: + new_msg['reasoning_content'] = result['reasoning_content'] + elif 'reasoning_content' in new_msg: + del new_msg['reasoning_content'] + messages[asst_idx] = new_msg + modified = True + n_refined += 1 + + row = dict(row, messages=messages) + if modified: + row['user_data'] = dict(user_data, refined=True) + out.append(row) + + logger.info( + f'[ResponseRefiner] refined {n_refined} rounds, ' + f'dropped {n_dropped} rows without key_rounds, ' + f'output {len(out)} rows') + return out diff --git a/src/twinkle_agentic/preprocessor/score_filter.py b/src/twinkle_agentic/preprocessor/score_filter.py new file mode 100644 index 00000000..bf1f08df --- /dev/null +++ b/src/twinkle_agentic/preprocessor/score_filter.py @@ -0,0 +1,802 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Pluggable per-round scorer/filter for SFT key rounds. + +Architecture: + + ScoreFilter(backend, scorers=[...]) + ├── pre-fetches logprobs once if any scorer requires them + ├── runs each Scorer in order, collecting ScoreResult per round + ├── trace dump (per-round JSON, multi_turn-style) + └── AND aggregation: a round is kept iff every scorer returns passed=True. 
class ChrMinScorer:
    """Gate rounds on the ``chr_dist_min_pos`` metric with a single upper bound.

    A round passes when its score is strictly below ``threshold``, or when no
    score could be computed (``None``). The value comes from
    ``_chr_min_distinct`` applied to the logprobs stored under the
    ``cond_lp`` / ``asst_lp`` feature keys of each context.
    """
    name = 'chr_min'
    requires_logprobs = True

    def __init__(self, threshold: float = 0.47):
        self._threshold = float(threshold)

    def _grade(self, ctx: RoundContext) -> ScoreResult:
        """Score one round and apply the threshold."""
        features = ctx.features
        value = _chr_min_distinct(
            features.get('cond_lp'), features.get('asst_lp'),
            ctx.cond_ids, ctx.asst_ids, ctx.n_prompt,
        )
        # An unscorable round is kept rather than dropped.
        keep = value is None or value < self._threshold
        return ScoreResult(
            score=value, passed=keep,
            extras={'threshold': self._threshold},
        )

    def score(self, contexts: List[RoundContext]) -> List[ScoreResult]:
        """Return one ``ScoreResult`` per context, in input order."""
        return [self._grade(ctx) for ctx in contexts]
Observation-only by default.""" + name = 'sifd' + requires_logprobs = True + + def __init__(self, ifd_threshold: Optional[float] = None): + # If set, passed = (ifd >= threshold). HIGH IFD = hard = keep. + self._ifd_threshold = ifd_threshold + + def score(self, contexts: List[RoundContext]) -> List[ScoreResult]: + out: List[ScoreResult] = [] + for ctx in contexts: + cond_lp = ctx.features.get('cond_lp') + asst_lp = ctx.features.get('asst_lp') + fam = _ifd_family_metrics( + cond_lp, asst_lp, ctx.cond_ids, ctx.asst_ids, ctx.n_prompt) + score = fam.get('ifd') + if self._ifd_threshold is None or score is None: + passed = True + else: + passed = score >= self._ifd_threshold + out.append(ScoreResult(score=score, passed=passed, extras=dict(fam))) + return out + + + +_JUDGE_SYSTEM_PROMPT = ( + 'You are a strict but fair answer grader. Judge whether the [Model Answer] is acceptable based on the reference answer (Ground Truth).\n' + 'Evaluate the following three aspects; if any has a major issue, return FAIL:\n\n' + '1. Computational/factual correctness: whether the final conclusion, numbers, and key factual statements match the reference answer;\n' + '2. Reasoning/approach similarity: whether the solution path, key steps, and considered dimensions are close to the reference answer;\n' + ' For open-ended questions (no single correct answer), assess whether the style, stance, and considered dimensions align with the reference answer;\n' + '3. 
Completeness: the answer is not truncated, ends naturally, and covers all points of the question.\n\n' + 'First give a brief 1-3 sentence justification, then on the last line strictly output:\n' + 'PASS or FAIL' +) + + +class PassNScorer: + """Self-rollouts (n × per round) judged by an LLM.""" + name = 'pass_n' + requires_logprobs = False + + def __init__( + self, + backend: LLMBackend, + judge_api=None, + judge_model: Optional[str] = None, + judge_base_url: Optional[str] = None, + judge_api_key: Optional[str] = None, + judge_client_kwargs: Optional[Dict[str, Any]] = None, + n: int = 4, + min_pass: int = 0, + sample_temperature: float = 0.7, + sample_max_tokens: int = 4096, + judge_temperature: float = 0.0, + judge_max_tokens: int = 512, + judge_max_rollout_chars: int = 8000, + judge_max_workers: int = 8, + ): + self._backend = backend + self._judge_api = self._build_judge_api( + judge_api, judge_model, judge_base_url, + judge_api_key, judge_client_kwargs) + self._n = max(1, int(n)) + self._min_pass = int(min_pass) + self._sample_temperature = float(sample_temperature) + self._sample_max_tokens = int(sample_max_tokens) + self._judge_temperature = float(judge_temperature) + self._judge_max_tokens = int(judge_max_tokens) + self._judge_max_rollout_chars = int(judge_max_rollout_chars) + self._judge_max_workers = max(1, int(judge_max_workers)) + if self._judge_api is None: + logger.warning( + '[PassNScorer] no judge_api configured; rollouts will be sampled ' + 'without verdicts (every round trivially passes).') + + @staticmethod + def _build_judge_api(api, model, base_url, api_key, client_kwargs): + if api is not None: + return api + if not model: + return None + from twinkle_agentic.protocol.openai import OpenAI as OpenAIAPI + return OpenAIAPI( + model=model, api_key=api_key, base_url=base_url, + client_kwargs=client_kwargs) + + @staticmethod + def _extract_text_from_choice(choice: Any) -> str: + if not isinstance(choice, dict): + return '' + parts: List[str] = [] + rc 
= choice.get('reasoning_content') + if isinstance(rc, str) and rc.strip(): + parts.append(f'\n{rc.strip()}\n') + content = choice.get('content') + if isinstance(content, str) and content.strip(): + parts.append(content.strip()) + if parts: + return '\n\n'.join(parts) + return content if isinstance(content, str) else '' + + @staticmethod + def _truncate(text: str, max_chars: int) -> str: + if not isinstance(text, str) or max_chars <= 0 or len(text) <= max_chars: + return text + head = max_chars * 2 // 3 + tail = max_chars - head - 32 + if tail <= 0: + return text[:max_chars] + return text[:head] + '\n\n...[truncated]...\n\n' + text[-tail:] + + @staticmethod + def _parse_verdict(judge_text: str) -> Optional[bool]: + if not isinstance(judge_text, str): + return None + compact = ''.join(judge_text.upper().split()) + has_pass = 'PASS' in compact + has_fail = 'FAIL' in compact + if has_pass and not has_fail: + return True + if has_fail and not has_pass: + return False + # Fallback: keyword scan in the tail (last 200 chars, post-compact). 
+ tail = compact[-200:] + if 'PASS' in tail and 'FAIL' not in tail: + return True + if 'FAIL' in tail and 'PASS' not in tail: + return False + return None + + def _judge_one(self, user_prompt: str, gt_text: str, rollout_text: str) -> Tuple[bool, str]: + if self._judge_api is None: + return True, '(no judge configured)' + if not rollout_text or not rollout_text.strip(): + return False, '(empty rollout)' + from twinkle.data_format.sampling import SamplingParams + body = ( + f'[问题]\n{self._truncate(user_prompt, self._judge_max_rollout_chars)}\n\n' + f'[参考答案]\n{self._truncate(gt_text, self._judge_max_rollout_chars)}\n\n' + f'[模型回答]\n{self._truncate(rollout_text, self._judge_max_rollout_chars)}\n\n' + '请评分。' + ) + trajectory = {'messages': [ + {'role': 'system', 'content': _JUDGE_SYSTEM_PROMPT}, + {'role': 'user', 'content': body}, + ]} + sp = SamplingParams( + temperature=self._judge_temperature, + max_tokens=self._judge_max_tokens, + num_samples=1, + ) + # extra_body forwards `enable_thinking=False` so the judge skips CoT. + msg = self._judge_api(trajectory, sp, extra_body={'enable_thinking': False}) + if isinstance(msg, list): + msg = msg[0] if msg else {} + text = msg.get('content', '') if isinstance(msg, dict) else str(msg) + text = text or '' + verdict = self._parse_verdict(text) + # Conservative default: ambiguous verdict → FAIL. 
def score(self, contexts: List[RoundContext]) -> List[ScoreResult]:
    """Sample rollouts for every round and grade them with the judge.

    For each context, ``self._n`` rollouts are requested from the backend in
    one ``chat_batch`` call. When a judge is configured, every rollout is
    graded through ``_judge_one`` on a thread pool; the round's score is
    ``pass_count / n`` (None when the round produced no rollouts) and it
    passes when ``pass_count >= min_pass``.

    When no judge is configured nothing can be graded: the score is None and
    the round passes, matching the warning emitted by the constructor.

    Args:
        contexts: Rounds to score.

    Returns:
        One ``ScoreResult`` per context, in input order. ``extras`` carries
        ``pass_count``, ``n_rollouts``, ``rollouts``, ``judgments`` and
        ``min_pass``.
    """
    if not contexts:
        return []
    from concurrent.futures import ThreadPoolExecutor

    # Copy into a private list so the backend's return value is never mutated.
    sampled = list(self._backend.chat_batch(
        [ctx.context_messages for ctx in contexts],
        temperature=self._sample_temperature,
        max_tokens=self._sample_max_tokens,
        n=self._n,
    ) or [])

    # Rollout texts aligned with ``contexts``. A missing or malformed entry
    # becomes an empty list, and each text is extracted exactly once.
    texts_by_round: List[List[str]] = []
    for i in range(len(contexts)):
        choices = sampled[i] if i < len(sampled) else []
        if isinstance(choices, list):
            texts_by_round.append(
                [self._extract_text_from_choice(c) for c in choices])
        else:
            texts_by_round.append([])

    judged = self._judge_api is not None
    verdict_by_round: Dict[int, List[Tuple[int, bool, str]]] = {}
    if judged:
        work: List[Tuple[int, int, str, str, str]] = [
            (i, r_i, ctx.user_prompt, ctx.asst_text, text)
            for i, (ctx, texts) in enumerate(zip(contexts, texts_by_round))
            for r_i, text in enumerate(texts)
        ]
        if work:
            def _do(item):
                i, r_i, up, gt, rt = item
                ok, raw = self._judge_one(up, gt, rt)
                return i, r_i, ok, raw
            with ThreadPoolExecutor(max_workers=self._judge_max_workers) as ex:
                for i, r_i, ok, raw in ex.map(_do, work):
                    verdict_by_round.setdefault(i, []).append((r_i, ok, raw))

    out: List[ScoreResult] = []
    for i, texts in enumerate(texts_by_round):
        rollouts = [
            {'rollout_idx': r_i, 'content': text}
            for r_i, text in enumerate(texts)
        ]
        verdicts = sorted(verdict_by_round.get(i, []), key=lambda x: x[0])
        judgments = [
            {'rollout_idx': r_i, 'passed': bool(p), 'judge_raw': raw}
            for r_i, p, raw in verdicts
        ]
        pass_count = sum(1 for _, p, _ in verdicts if p)
        if judged:
            score = (pass_count / self._n) if rollouts else None
            passed = pass_count >= self._min_pass
        else:
            # Without verdicts a pass-rate of 0.0 would be fabricated, and
            # min_pass > 0 would silently reject every round.
            score = None
            passed = True
        out.append(ScoreResult(
            score=score, passed=passed,
            extras={
                'pass_count': pass_count,
                'n_rollouts': len(rollouts),
                'rollouts': rollouts,
                'judgments': judgments,
                'min_pass': self._min_pass,
            },
        ))

    scored = [r for r in out if r.score is not None]
    if scored:
        avg = sum(r.score for r in scored) / len(scored)
        logger.info(
            f'[PassNScorer] graded {len(scored)}/{len(out)} rounds × {self._n} '
            f'rollouts; avg pass-rate = {avg:.3f}')
    return out
def _encode_prompt(self, ctx_msgs):
    """Encode the context with a generation prompt and cap its length.

    When ``self._max_prompt_tokens`` is positive and the encoding is longer,
    only the trailing ``_max_prompt_tokens`` ids are returned (the oldest
    tokens are cut). A non-positive limit disables truncation.
    """
    encoded = self._template.encode(
        {'messages': list(ctx_msgs)}, add_generation_prompt=True)
    token_ids = _to_int_list(encoded['input_ids'])
    limit = self._max_prompt_tokens
    over_budget = limit > 0 and len(token_ids) > limit
    return token_ids[-limit:] if over_budget else token_ids
+ para_data: Dict[int, Tuple[List[int], int, List[int], str]] = {} + for i, choices in zip(keys, batched): + text = None + if choices: + c0 = choices[0] + if isinstance(c0, dict): + text = c0.get('content') + if not isinstance(text, str) or not text.strip(): + continue + ctx = contexts[i] + prompt_ids = self._encode_prompt(ctx.context_messages) + asst_ids = _to_int_list(self._template.tokenizer( + text, add_special_tokens=False)['input_ids']) + if len(asst_ids) < _MIN_RESPONSE_TOKENS + 1: + continue + cond_ids = prompt_ids + asst_ids + para_data[i] = (cond_ids, len(prompt_ids), asst_ids, text) + + if not para_data: + return out + + ordered = list(para_data.keys()) + cond_batch = [para_data[i][0] for i in ordered] + asst_batch = [para_data[i][2] for i in ordered] + cond_lps = self._backend.prompt_logprobs_ids(cond_batch) + asst_lps = self._backend.prompt_logprobs_ids(asst_batch) + + for i, cond_lp, asst_lp in zip(ordered, cond_lps, asst_lps): + cond_ids, n_prompt, asst_ids, text = para_data[i] + score = _chr_min_distinct(cond_lp, asst_lp, cond_ids, asst_ids, n_prompt) + if self._threshold is None or score is None: + passed = True + else: + passed = score < self._threshold + out[i] = ScoreResult( + score=score, passed=passed, + extras={ + 'paraphrase_text': text, + 'n_prompt': n_prompt, + 'cond_lp': _lp_to_jsonable(cond_lp), + 'asst_lp': _lp_to_jsonable(asst_lp), + 'threshold': self._threshold, + }, + ) + + logger.info( + f'[ParaphraseScorer] paraphrased + scored {len(para_data)}/' + f'{len(contexts)} rounds') + return out + + +# ============================================================================ +# ScoreFilter (Preprocessor entry point) +# ============================================================================ + +class ScoreFilter(Preprocessor): + """Score and filter assistant turns by a pluggable scorer set. + + A round is kept iff every scorer returns ``passed=True``. 
Rows that lose + all key rounds are dropped (configurable via ``keep_if_no_key_rounds``). + + Decoupling rules: + * `key_rounds` missing/empty in `user_data` → every assistant turn + becomes a candidate round. + * `intents=None` → no intent-based gating. + """ + + def __init__( + self, + template: Template, + backend: LLMBackend, + scorers: List[Scorer], + intents: Optional[Iterable[str]] = None, + keep_if_no_key_rounds: bool = False, + drop_row_on_any_fail: bool = True, + max_prompt_tokens: int = 1024, + trace_dir: Optional[str] = None, + trace_callback: Optional[Callable[[Dict[str, Any]], bool]] = None, + success_callback: Optional[Callable[[Dict[str, Any]], bool]] = None, + ): + super().__init__() + if not isinstance(template, Template): + raise TypeError( + f'ScoreFilter requires a `Template` instance, got ' + f'{type(template).__name__}.') + self._template = template + self._backend = backend + self._scorers = list(scorers) + self._intents: Optional[Set[str]] = ( + None if intents is None else set(intents)) + self._keep_if_no_key_rounds = bool(keep_if_no_key_rounds) + self._drop_row_on_any_fail = bool(drop_row_on_any_fail) + self._max_prompt_tokens = int(max_prompt_tokens) + self._trace_dir = trace_dir + self._trace_callback = trace_callback + self._success_callback = success_callback + if self._trace_dir: + import shutil + if os.path.exists(self._trace_dir): + shutil.rmtree(self._trace_dir) + os.makedirs(self._trace_dir, exist_ok=True) + + def __call__(self, rows): + rows_list = self.map_col_to_row(rows) + contexts = self._build_contexts(rows_list) + if contexts: + score_table = self._score_contexts(contexts) + self._log_score_summary(contexts, score_table) + if self._trace_dir: + self._write_traces(contexts, score_table) + rows_list = self._apply_filter(rows_list, contexts, score_table) + return self.map_row_to_col(rows_list) + + def _log_score_summary(self, contexts, score_table): + for scorer in self._scorers: + scores = [t[scorer.name].score for t in 
score_table + if scorer.name in t and t[scorer.name].score is not None] + if not scores: + continue + n_pass = sum(1 for t in score_table + if scorer.name in t and t[scorer.name].passed) + extras_sample = {} + for t in score_table: + if scorer.name in t and t[scorer.name].extras: + extras_sample = t[scorer.name].extras + break + extra_keys = [k for k in extras_sample if k != 'threshold'] + extra_stats = '' + for k in extra_keys: + vals = [t[scorer.name].extras.get(k) for t in score_table + if scorer.name in t and t[scorer.name].extras + and t[scorer.name].extras.get(k) is not None] + if vals and isinstance(vals[0], (int, float)): + avg = sum(vals) / len(vals) + extra_stats += f', {k}_avg={avg:.4f}' + logger.info( + f'[ScoreFilter/{scorer.name}] n={len(scores)}, ' + f'mean={sum(scores)/len(scores):.4f}, ' + f'min={min(scores):.4f}, max={max(scores):.4f}, ' + f'pass={n_pass}/{len(score_table)}' + f'{extra_stats}') + + # ---- scoring (inlined DefaultScoreCalculator) -------------------------- + + def _score_contexts(self, contexts: List[RoundContext]) -> List[Dict[str, ScoreResult]]: + if any(getattr(s, 'requires_logprobs', False) for s in self._scorers): + self._attach_logprobs(contexts) + out: List[Dict[str, ScoreResult]] = [dict() for _ in contexts] + for scorer in self._scorers: + results = scorer.score(contexts) + if len(results) != len(contexts): + raise RuntimeError( + f'scorer {scorer.name!r} returned {len(results)} results ' + f'for {len(contexts)} contexts') + for i, r in enumerate(results): + out[i][scorer.name] = r + return out + + def _attach_logprobs(self, contexts: List[RoundContext]) -> None: + cond_batch = [ctx.cond_ids for ctx in contexts] + asst_batch = [ctx.asst_ids for ctx in contexts] + floor = self._batch_floor() + cond_padded, n_cond = _pad_batch(cond_batch, floor) + asst_padded, n_asst = _pad_batch(asst_batch, floor) + cond_lps = self._backend.prompt_logprobs_ids(cond_padded)[:n_cond] + asst_lps = 
self._backend.prompt_logprobs_ids(asst_padded)[:n_asst] + for ctx, c, a in zip(contexts, cond_lps, asst_lps): + ctx.features['cond_lp'] = c + ctx.features['asst_lp'] = a + + def _batch_floor(self) -> int: + sampler = getattr(self._backend, '_sampler', None) + device_mesh = getattr(sampler, 'device_mesh', None) + return getattr(device_mesh, 'dp_world_size', 1) or 1 + + # ---- context construction -------------------------------------------- + + def _build_contexts(self, rows: List[Dict[str, Any]]) -> List[RoundContext]: + out: List[RoundContext] = [] + for ri, row in enumerate(rows): + messages = row.get('messages') if isinstance(row, dict) else None + if not isinstance(messages, list): + continue + user_data = row.get('user_data') if isinstance(row, dict) else None + key_rounds = (user_data.get('key_rounds') + if isinstance(user_data, dict) else None) + if not isinstance(key_rounds, list) or not key_rounds: + key_rounds = [ + i for i, m in enumerate(messages) + if isinstance(m, dict) and m.get('role') == 'assistant' + ] + for rnd_idx, asst_idx in enumerate(key_rounds): + if not isinstance(asst_idx, int): + continue + intent = self._lookup_intent(row, asst_idx) + if self._intents is not None and intent not in self._intents: + continue + ctx = self._prepare_round(row, messages, ri, rnd_idx, asst_idx, intent) + if ctx is not None: + out.append(ctx) + return out + + def _prepare_round( + self, + row: Dict[str, Any], + messages: List[Dict[str, Any]], + ri: int, rnd_idx: int, asst_idx: int, + intent: Optional[str], + ) -> Optional[RoundContext]: + if not (0 <= asst_idx < len(messages)): + return None + asst_msg = messages[asst_idx] + if not isinstance(asst_msg, dict) or asst_msg.get('role') != 'assistant': + return None + asst_text = asst_msg.get('content') or '' + if isinstance(asst_text, list): + asst_text = ' '.join(p.get('text', '') for p in asst_text + if isinstance(p, dict) and p.get('type') == 'text') + if not asst_text.strip(): + return None + context_messages = 
messages[:asst_idx] + if not context_messages: + return None + prompt_ids = self._encode_prompt_within_budget(context_messages) + # Raw asst_ids (no chat-template wrapping) so cond/asst share byte-equal + # A-token sequences; otherwise chr_min positions desync. + asst_ids = _to_int_list(self._template.tokenizer( + asst_text, add_special_tokens=False)['input_ids']) + if len(asst_ids) < _MIN_RESPONSE_TOKENS + 1: + return None + return RoundContext( + row_idx=ri, rnd_idx=rnd_idx, asst_idx=asst_idx, + row=row, intent=intent, + messages=messages, + context_messages=context_messages, + cond_ids=prompt_ids + asst_ids, + n_prompt=len(prompt_ids), + asst_ids=asst_ids, + asst_text=asst_text, + user_prompt=self._render_user_prompt(context_messages), + ) + + def _encode_prompt_within_budget(self, ctx_msgs: List[Dict[str, Any]]) -> List[int]: + ctx = list(ctx_msgs) + ids = _to_int_list(self._template.encode( + {'messages': ctx}, add_generation_prompt=True)['input_ids']) + budget = self._max_prompt_tokens + if budget <= 0 or len(ids) <= budget: + return ids + has_sys = bool(ctx) and isinstance(ctx[0], dict) and ctx[0].get('role') == 'system' + body_start = 1 if has_sys else 0 + while len(ctx) - body_start > 1: + ctx.pop(body_start) + ids = _to_int_list(self._template.encode( + {'messages': ctx}, add_generation_prompt=True)['input_ids']) + if len(ids) <= budget: + return ids + # Single message still over budget → keep tail tokens. 
+ return ids[-budget:] + + @staticmethod + def _render_user_prompt(ctx_msgs: List[Dict[str, Any]]) -> str: + parts: List[str] = [] + for m in ctx_msgs: + if not isinstance(m, dict): + continue + role = m.get('role') or 'user' + content = m.get('content', '') + if isinstance(content, list): + content = ' '.join(p.get('text', '') for p in content + if isinstance(p, dict) and p.get('type') == 'text') + if isinstance(content, str) and content.strip(): + parts.append(f'[{role}] {content.strip()}') + return '\n\n'.join(parts) + + @staticmethod + def _lookup_intent(row: Dict[str, Any], asst_idx: int) -> Optional[str]: + user_data = row.get('user_data') if isinstance(row, dict) else None + if not isinstance(user_data, dict): + return None + intents = user_data.get('intents') + if not isinstance(intents, dict): + return None + v = intents.get(asst_idx) + if v is None: + v = intents.get(str(asst_idx)) + return v if isinstance(v, str) else None + + # ---- trace dump (multi_turn-style) ----------------------------------- + + def _write_traces( + self, + contexts: List[RoundContext], + score_table: List[Dict[str, ScoreResult]], + ) -> None: + for i, ctx in enumerate(contexts): + try: + scores = score_table[i] if i < len(score_table) else {} + kept = all(r.passed for r in scores.values()) if scores else True + record = self._build_trace_record(ctx, scores, kept) + if self._trace_callback is not None and not bool(self._trace_callback(record)): + continue + success = ( + bool(self._success_callback(record)) + if self._success_callback is not None else kept + ) + prefix = 'ok' if success else 'fail' + rid = f'{ctx.row_idx}-{ctx.asst_idx}-{i}-{int(time.time() * 1000)}' + rid = re.sub(r'[^A-Za-z0-9_\-.]+', '_', rid)[:64] + path = os.path.join(self._trace_dir, f'{prefix}-{rid}.json') + with open(path, 'w', encoding='utf-8') as f: + json.dump(record, f, ensure_ascii=False, + indent=2, default=str) + except Exception as e: + # Observability must never break filtering; surface the cause. 
+ logger.warning( + f'[ScoreFilter] trace dump failed for row={ctx.row_idx} ' + f'asst={ctx.asst_idx}: {e}') + + @staticmethod + def _build_trace_record( + ctx: RoundContext, + scores: Dict[str, ScoreResult], + kept: bool, + ) -> Dict[str, Any]: + return { + 'row_idx': ctx.row_idx, + 'rnd_idx': ctx.rnd_idx, + 'asst_idx': ctx.asst_idx, + 'intent': ctx.intent, + 'messages': ctx.messages, + 'n_prompt': ctx.n_prompt, + 'cond_ids': ctx.cond_ids, + 'asst_ids': ctx.asst_ids, + 'features': { + k: (_lp_to_jsonable(v) if k.endswith('_lp') else v) + for k, v in ctx.features.items() + }, + 'scores': { + name: {'score': r.score, 'passed': r.passed, 'extras': r.extras} + for name, r in scores.items() + }, + 'kept': bool(kept), + } + + # ---- aggregation & row reassembly ------------------------------------ + + def _apply_filter( + self, + rows: List[Dict[str, Any]], + contexts: List[RoundContext], + score_table: List[Dict[str, ScoreResult]], + ) -> List[Dict[str, Any]]: + per_row: Dict[int, Dict[str, Any]] = {} + for i, ctx in enumerate(contexts): + scores = score_table[i] if i < len(score_table) else {} + passed = all(r.passed for r in scores.values()) if scores else True + slot = per_row.setdefault(ctx.row_idx, { + 'kept': [], 'failed': 0, + }) + if passed: + slot['kept'].append(ctx.asst_idx) + else: + slot['failed'] += 1 + + out: List[Dict[str, Any]] = [] + n_removed_rounds = 0 + n_removed_rows = 0 + for ri, row in enumerate(rows): + user_data = row.get('user_data') if isinstance(row, dict) else None + had_key_rounds = ( + isinstance(user_data, dict) + and isinstance(user_data.get('key_rounds'), list) + and bool(user_data['key_rounds']) + ) + decision = per_row.get(ri) + + if decision is None: + # Row produced no contexts (no asst turns or filtered by intent). 
+ if had_key_rounds and not self._keep_if_no_key_rounds: + n_removed_rows += 1 + continue + if self._intents is not None and not self._keep_if_no_key_rounds: + n_removed_rows += 1 + continue + out.append(row) + continue + + n_removed_rounds += decision['failed'] + kept = decision['kept'] + if had_key_rounds: + if not kept: + n_removed_rows += 1 + continue + new_row = dict(row) + new_row['user_data'] = dict(user_data, key_rounds=list(kept)) + out.append(new_row) + else: + if decision['failed'] > 0 and self._drop_row_on_any_fail: + n_removed_rows += 1 + continue + out.append(row) + + logger.info( + f'[ScoreFilter] removed {n_removed_rounds} rounds, ' + f'dropped {n_removed_rows} rows, kept {len(out)}/{len(rows)}') + return out diff --git a/src/twinkle_agentic/preprocessor/token_soup.py b/src/twinkle_agentic/preprocessor/token_soup.py new file mode 100644 index 00000000..6f88dd9a --- /dev/null +++ b/src/twinkle_agentic/preprocessor/token_soup.py @@ -0,0 +1,156 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import re +import unicodedata +from collections import Counter +from typing import Any, Dict, List + +from twinkle.preprocessor import Preprocessor + +# ── Pre-compiled patterns ───────────────────────────────────────────────────── + +# Unicode replacement character +_REPLACEMENT_CHAR_RE = re.compile(r'\ufffd') + +# Non-printable control chars (keep \t \n \r as legitimate whitespace) +_CONTROL_CHAR_RE = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]') + +# Unicode private use area (E000–F8FF, F0000–FFFFF, 100000–10FFFF) +_PRIVATE_USE_RE = re.compile(r'[\ue000-\uf8ff\U000f0000-\U000fffff\U00100000-\U0010ffff]') + +# Chat-template special tokens repeated ≥ _SPECIAL_TOKEN_COUNT times. +# Bracket-style BERT tokens (PAD/UNK/SEP/CLS/MASK) are case-sensitive via (?-i:...) — +# lowercase "[mask]"/"[pad]" collide with ordinary bitmask-DP variable names like dp[mask]. 
+_SPECIAL_TOKEN_RE = re.compile( + r'(<\|[^|>\n]{1,40}\|>||(?-i:\[/?(?:PAD|UNK|SEP|CLS|MASK)\])|||<0x[0-9A-Fa-f]{2}>)', + re.IGNORECASE, +) + +# Same printable character repeated 20+ times consecutively. +# Excludes whitespace and chars commonly used as legitimate decorations / numerical output: +# - ASCII rule/separator chars: - = _ . * + ~ # | > < +# - Digits 0-9 (float precision padding, test fixtures like 999999..., 111111...) +# - Box drawing (U+2500-257F), Block elements (U+2580-259F), +# Geometric shapes (U+25A0-25FF), Braille patterns (U+2800-28FF) +# - Em/en dash (U+2013-2015), fullwidth dash/hyphen (U+30FC, U+FF0D) +_SINGLE_CHAR_REPEAT_RE = re.compile( + r'([^\s\n\-=_.\*\+~#|><0-9\u2013-\u2015\u2500-\u25ff\u2800-\u28ff\u30fc\uff0d])\1{19,}' +) + + +# ── Unicode script classifier ───────────────────────────────────────────────── + +def _script_of(cp: int) -> str: + """Map a codepoint to a coarse script bucket.""" + if cp <= 0x024F: return 'latin' + if 0x0370 <= cp <= 0x03FF: return 'greek' + if 0x0400 <= cp <= 0x04FF: return 'cyrillic' + if 0x0590 <= cp <= 0x05FF: return 'hebrew' + if 0x0600 <= cp <= 0x06FF: return 'arabic' + if 0x0900 <= cp <= 0x097F: return 'devanagari' + if 0x0E00 <= cp <= 0x0E7F: return 'thai' + if 0x3040 <= cp <= 0x309F: return 'hiragana' + if 0x30A0 <= cp <= 0x30FF: return 'katakana' + if 0x4E00 <= cp <= 0x9FFF: return 'cjk' + if 0xAC00 <= cp <= 0xD7A3: return 'hangul' + if 0xE000 <= cp <= 0xF8FF: return 'private' + return 'other' + + +def _script_chaos(text: str, min_chars: int = 40) -> float: + """Return the fraction of adjacent non-space char pairs that switch script.""" + chars = [c for c in text if unicodedata.category(c)[0] in ('L', 'N')] + if len(chars) < min_chars: + return 0.0 + scripts = [_script_of(ord(c)) for c in chars] + switches = sum(a != b for a, b in zip(scripts, scripts[1:])) + return switches / (len(scripts) - 1) + + +# ── Per-signal detectors ────────────────────────────────────────────────────── + +def 
class TokenSoupFilter(Preprocessor):
    """Drop rows in which any assistant message looks like garbled "token soup".

    All constructor arguments are stored and forwarded unchanged to
    ``_is_token_soup`` for every assistant message.

    Args:
        replacement_char_ratio: Forwarded as ``replacement_char_ratio``.
        control_char_ratio: Forwarded as ``control_char_ratio``.
        private_use_ratio: Forwarded as ``private_use_ratio``.
        special_token_count: Forwarded as ``special_token_count``.
        script_chaos_threshold: Forwarded as ``script_chaos_threshold``.
        script_chaos_min_chars: Forwarded as ``script_chaos_min_chars``.
        max_chars: Forwarded as ``max_chars``.
    """

    def __init__(
        self,
        replacement_char_ratio: float = 0.02,
        control_char_ratio: float = 0.01,
        private_use_ratio: float = 0.03,
        special_token_count: int = 20,
        script_chaos_threshold: float = 0.55,
        script_chaos_min_chars: int = 40,
        max_chars: int = 0,
    ) -> None:
        super().__init__()
        self._replacement_char_ratio = replacement_char_ratio
        self._control_char_ratio = control_char_ratio
        self._private_use_ratio = private_use_ratio
        self._special_token_count = special_token_count
        self._script_chaos_threshold = script_chaos_threshold
        self._script_chaos_min_chars = script_chaos_min_chars
        self._max_chars = max_chars

    @staticmethod
    def _content_text(message: Dict[str, Any]) -> str:
        """Return the textual content of one message as a plain string.

        String content is returned as-is.  List content is reduced to its
        ``{'type': 'text', 'text': <str>}`` parts joined by a single space.
        Anything else (``None``, numbers, ...) yields ``''`` so callers can
        always call ``str`` methods on the result.
        """
        content = message.get('content')
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return ' '.join(
                part['text'] for part in content
                if isinstance(part, dict)
                and part.get('type') == 'text'
                and isinstance(part.get('text'), str)
            )
        return ''

    def __call__(self, rows) -> List[Dict[str, Any]]:
        """Return the rows whose assistant messages all pass the soup check.

        Args:
            rows: Iterable of row dicts; each row's ``'messages'`` entry is
                read with ``row.get``.

        Returns:
            The kept rows, in input order.  Rows without assistant messages
            are always kept.
        """
        thresholds = {
            'replacement_char_ratio': self._replacement_char_ratio,
            'control_char_ratio': self._control_char_ratio,
            'private_use_ratio': self._private_use_ratio,
            'special_token_count': self._special_token_count,
            'script_chaos_threshold': self._script_chaos_threshold,
            'script_chaos_min_chars': self._script_chaos_min_chars,
            'max_chars': self._max_chars,
        }
        out = []
        for row in rows:
            messages = row.get('messages') or []
            # Flatten content first: list-style (multimodal) or otherwise
            # non-string content used to raise AttributeError on ``.strip()``.
            assistant_texts = (
                self._content_text(m).strip()
                for m in messages
                if isinstance(m, dict) and m.get('role') == 'assistant'
            )
            if any(_is_token_soup(text, **thresholds) for text in assistant_texts):
                continue
            out.append(row)
        return out
+ entry = None + if token_id is not None: + entry = lp.get(token_id) + if entry is None: + entry = lp.get(str(token_id)) + if entry is None: + entry = next(iter(lp.values()), None) + if entry is None: + return None + if hasattr(entry, 'logprob'): + return float(entry.logprob) + if isinstance(entry, dict): + v = entry.get('logprob') + return float(v) if v is not None else None + if isinstance(entry, (int, float)): + return float(entry) + return None + + +def _to_int_list(x) -> List[int]: + if hasattr(x, 'tolist'): + return x.tolist() + return list(x) + + +def _chr_min_distinct( + cond_lp: List, asst_lp: List, + cond_ids: List[int], asst_ids: List[int], + n_prompt: int, + exclude_ids: Optional[Set[int]] = None, +) -> Optional[float]: + """chr_dist_min_pos: fraction of distinct asst-token ids whose + per-occurrence min(cond_lp - asst_lp) is strictly positive.""" + if not asst_lp or not cond_lp or not asst_ids: + return None + n_a = min(len(asst_lp), len(asst_ids)) + n_c = len(cond_lp) + by_tok: Dict[int, List[float]] = {} + for i in range(n_a): + ci = n_prompt + i + if ci >= n_c: + break + tid = asst_ids[i] + if tid is None: + continue + if exclude_ids is not None and int(tid) in exclude_ids: + continue + a = _extract_logprob(asst_lp[i], tid) + c_tok = cond_ids[ci] if ci < len(cond_ids) else None + c = _extract_logprob(cond_lp[ci], c_tok) + if a is None or c is None: + continue + by_tok.setdefault(int(tid), []).append(c - a) + if not by_tok: + return None + pos = sum(1 for diffs in by_tok.values() if min(diffs) > 0) + return pos / len(by_tok) + + +def _chr_min_weighted( + cond_lp: List, asst_lp: List, + cond_ids: List[int], asst_ids: List[int], + n_prompt: int, +) -> Optional[float]: + """Magnitude-weighted chr_min: each distinct token contributes |min_delta| + as weight; returns sum(pos_weights) / sum(all_weights).""" + if not asst_lp or not cond_lp or not asst_ids: + return None + n_a = min(len(asst_lp), len(asst_ids)) + n_c = len(cond_lp) + by_tok: Dict[int, 
List[float]] = {} + for i in range(n_a): + ci = n_prompt + i + if ci >= n_c: + break + tid = asst_ids[i] + if tid is None: + continue + a = _extract_logprob(asst_lp[i], tid) + c_tok = cond_ids[ci] if ci < len(cond_ids) else None + c = _extract_logprob(cond_lp[ci], c_tok) + if a is None or c is None: + continue + by_tok.setdefault(int(tid), []).append(c - a) + if not by_tok: + return None + total_w = 0.0 + pos_w = 0.0 + for diffs in by_tok.values(): + md = min(diffs) + w = abs(md) + total_w += w + if md > 0: + pos_w += w + if total_w == 0: + return None + return pos_w / total_w + + +def _ifd_family_metrics( + cond_lp: List, asst_lp: List, + cond_ids: List[int], asst_ids: List[int], + n_prompt: int, +) -> Dict[str, Any]: + """IFD (Cherry-LLM) and S-IFD-{50,75} (T-SHIRT) for one round.""" + if not asst_lp or not cond_lp or not asst_ids: + return {} + n_a = min(len(asst_lp), len(asst_ids)) + n_c = len(cond_lp) + deltas: List[float] = [] + for i in range(n_a): + ci = n_prompt + i + if ci >= n_c: + break + tid = asst_ids[i] + if tid is None: + continue + a = _extract_logprob(asst_lp[i], tid) + c_tok = cond_ids[ci] if ci < len(cond_ids) else None + c = _extract_logprob(cond_lp[ci], c_tok) + if a is None or c is None: + continue + deltas.append(c - a) + if not deltas: + return {} + n = len(deltas) + mean_delta = sum(deltas) / n + out: Dict[str, Any] = { + 'n_tokens': n, + 'mean_delta': mean_delta, + 'ifd': math.exp(-mean_delta), + } + abs_sorted = sorted(range(n), key=lambda i: abs(deltas[i]), reverse=True) + for k_pct in (50, 75): + keep = max(1, int(round(n * k_pct / 100))) + sub = [deltas[i] for i in abs_sorted[:keep]] + out[f's_ifd_{k_pct}'] = math.exp(-sum(sub) / len(sub)) + return out + + +def _mean_logprob_delta( + cond_lp: List, asst_lp: List, + cond_ids: List[int], asst_ids: List[int], + n_prompt: int, +) -> Optional[float]: + """Mean per-token (cond_lp - asst_lp) over the response span.""" + if not asst_lp or not cond_lp or not asst_ids: + return None + n_a = 
min(len(asst_lp), len(asst_ids)) + n_c = len(cond_lp) + deltas: List[float] = [] + for i in range(n_a): + ci = n_prompt + i + if ci >= n_c: + break + tid = asst_ids[i] + if tid is None: + continue + a = _extract_logprob(asst_lp[i], tid) + c_tok = cond_ids[ci] if ci < len(cond_ids) else None + c = _extract_logprob(cond_lp[ci], c_tok) + if a is None or c is None: + continue + deltas.append(c - a) + if not deltas: + return None + return sum(deltas) / len(deltas) + + +def _lp_to_jsonable(lp_list): + """Convert per-position prompt_logprobs into JSON-safe form.""" + out = [] + for lp in (lp_list or []): + if lp is None: + out.append(None) + continue + if isinstance(lp, (int, float)): + out.append(float(lp)) + continue + if not isinstance(lp, dict): + out.append(repr(lp)) + continue + d = {} + for k, v in lp.items(): + if hasattr(v, 'logprob'): + d[str(k)] = {'logprob': float(v.logprob), + 'rank': getattr(v, 'rank', None), + 'decoded': getattr(v, 'decoded_token', None)} + elif isinstance(v, dict): + d[str(k)] = v + else: + d[str(k)] = repr(v) + out.append(d) + return out + + +def _pad_batch(batch: List[List[int]], floor: int) -> Tuple[List[List[int]], int]: + n = len(batch) + if n >= floor or not batch: + return batch, n + return list(batch) + [batch[-1]] * (floor - n), n diff --git a/src/twinkle_agentic/sampler/__init__.py b/src/twinkle_agentic/sampler/__init__.py new file mode 100644 index 00000000..93d4eec2 --- /dev/null +++ b/src/twinkle_agentic/sampler/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from .router_sampler import RouterSampler diff --git a/src/twinkle_agentic/sampler/router_sampler.py b/src/twinkle_agentic/sampler/router_sampler.py new file mode 100644 index 00000000..847540f2 --- /dev/null +++ b/src/twinkle_agentic/sampler/router_sampler.py @@ -0,0 +1,196 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+import math +from copy import copy +from typing import Any, Dict, List, Literal, Optional, Union + +import httpx + +from twinkle import get_logger +from twinkle.data_format import SampledSequence, SampleResponse, SamplingParams, Trajectory + +logger = get_logger() + + +def _entropy_from_topk(logprobs_per_token: List[List[tuple]]) -> float: + """Mean per-token entropy approximated from top-K logprobs (renormalized).""" + if not logprobs_per_token: + return float('inf') + total = 0.0 + for candidates in logprobs_per_token: + if not candidates: + total += float('inf') + continue + lps = [lp for _, lp in candidates] + max_lp = max(lps) + # numerically stable softmax over top-K + exps = [math.exp(lp - max_lp) for lp in lps] + z = sum(exps) + total += sum(-(e / z) * (lp - max_lp - math.log(z)) for e, lp in zip(exps, lps)) + return total / len(logprobs_per_token) + + +def _mean_logp(logprobs_per_token: List[List[tuple]], tokens: List[int]) -> float: + """Mean log-probability of generated tokens (sequence-level confidence).""" + if not logprobs_per_token or not tokens: + return float('-inf') + total = 0.0 + count = 0 + for t, candidates in enumerate(logprobs_per_token): + if t >= len(tokens) or not candidates: + continue + tok = tokens[t] + lp = next((v for tid, v in candidates if tid == tok), None) + if lp is None: + lp = candidates[0][1] + total += lp + count += 1 + return total / max(count, 1) + + +class RouterSampler: + """Confidence-based routing sampler. + + Generates with a local sampler first; if confidence is low, falls back + to an OpenAI-compatible endpoint (stronger model). 
+ """ + + def __init__( + self, + sampler, + fallback_endpoint: str, + fallback_model: str = 'default', + fallback_api_key: str = '', + method: Literal['entropy', 'logp'] = 'entropy', + threshold: float = 2.0, + top_k_logprobs: int = 10, + fallback_temperature: float = 0.7, + fallback_max_tokens: int = 4096, + timeout: float = 120.0, + ): + """ + Args: + sampler: Inner sampler instance (e.g. vLLMSampler). + fallback_endpoint: OpenAI-compatible API base URL. + fallback_model: Model name for fallback requests. + fallback_api_key: Bearer token for fallback API. + method: Confidence metric — 'entropy' (route when H > threshold) + or 'logp' (route when mean logp < threshold). + threshold: Routing threshold. For entropy: higher = more routing. + For logp: lower (more negative) = more routing. + top_k_logprobs: Number of top logprobs to request from inner sampler. + fallback_temperature: Temperature for fallback generation. + fallback_max_tokens: Max tokens for fallback generation. + timeout: HTTP timeout for fallback requests. 
+ """ + self.sampler = sampler + self._method = method + self._threshold = threshold + self._top_k = top_k_logprobs + self._fb_temperature = fallback_temperature + self._fb_max_tokens = fallback_max_tokens + self._fb_endpoint = f'{fallback_endpoint.rstrip("/")}/v1/chat/completions' + self._fb_model = fallback_model + headers = {'Content-Type': 'application/json'} + if fallback_api_key: + headers['Authorization'] = f'Bearer {fallback_api_key}' + self._client = httpx.Client(timeout=timeout, headers=headers) + + @property + def template(self): + return self.sampler.template + + def set_template(self, *args, **kwargs): + return self.sampler.set_template(*args, **kwargs) + + def _should_route(self, seq: SampledSequence) -> bool: + if not seq.logprobs: + return True + if self._method == 'entropy': + score = _entropy_from_topk(seq.logprobs) + return score > self._threshold + score = _mean_logp(seq.logprobs, seq.tokens) + return score < self._threshold + + def _fallback_generate(self, trajectory: Trajectory) -> Optional[str]: + messages = trajectory.get('messages', []) + if not messages: + return None + api_messages = [] + for m in messages: + if not isinstance(m, dict): + continue + entry = {'role': m.get('role', 'user')} + content = m.get('content', '') + if isinstance(content, list): + parts = [] + for block in content: + if isinstance(block, dict) and block.get('type') == 'text': + parts.append(block.get('text', '')) + content = '\n'.join(parts) if parts else '' + entry['content'] = content or '' + api_messages.append(entry) + try: + resp = self._client.post(self._fb_endpoint, json={ + 'model': self._fb_model, + 'messages': api_messages, + 'temperature': self._fb_temperature, + 'max_tokens': self._fb_max_tokens, + }) + resp.raise_for_status() + choices = resp.json().get('choices', []) + if choices: + return (choices[0].get('message') or {}).get('content', '') + except Exception as e: + logger.warning(f'RouterSampler fallback failed: {e}') + return None + + def sample( 
+ self, + inputs: Union[Dict, List[Dict]], + sampling_params: Optional[Union[SamplingParams, Dict[str, Any]]] = None, + adapter_name: str = '', + adapter_path: Optional[str] = None, + **kwargs, + ) -> List[SampleResponse]: + """Sample with confidence-based routing to fallback model.""" + if sampling_params is None: + sampling_params = SamplingParams() + elif isinstance(sampling_params, dict): + sampling_params = SamplingParams.from_dict(sampling_params) + + # Ensure logprobs are requested for confidence evaluation + routed_params = copy(sampling_params) + if routed_params.logprobs is None or routed_params.logprobs < self._top_k: + routed_params.logprobs = self._top_k + + inputs_list = inputs if isinstance(inputs, list) else [inputs] + is_trajectory = isinstance(inputs_list[0], dict) and 'input_ids' not in inputs_list[0] + + results = self.sampler.sample( + inputs_list, routed_params, adapter_name, adapter_path=adapter_path, **kwargs) + + if not is_trajectory: + return results + + for i, (resp, traj) in enumerate(zip(results, inputs_list)): + new_sequences = [] + for seq in resp.sequences: + if self._should_route(seq): + fallback_text = self._fallback_generate(traj) + if fallback_text is not None: + new_sequences.append(SampledSequence( + stop_reason='stop', + tokens=[], + logprobs=None, + decoded=fallback_text, + )) + continue + new_sequences.append(seq) + results[i] = SampleResponse( + sequences=new_sequences, + prompt_token_ids=resp.prompt_token_ids, + prompt_logprobs=resp.prompt_logprobs, + topk_prompt_logprobs=resp.topk_prompt_logprobs, + ) + + return results diff --git a/tests/preprocessor/test_agent_trace_filter.py b/tests/preprocessor/test_agent_trace_filter.py new file mode 100644 index 00000000..bf58359c --- /dev/null +++ b/tests/preprocessor/test_agent_trace_filter.py @@ -0,0 +1,238 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for AgentTraceFilter. 
+ +AgentTraceFilter is detection-only — it tags rows with ``is_agent=True/False`` +and never drops or mutates messages. Detection delegates to +``ToolCallRegistry.detect_first`` so the test surface is: + + 1. Tag is set on EVERY row (uniform schema). + 2. role='tool' or non-empty ``tool_calls`` field → True. + 3. Text-embedded tool calls (Cline / Hermes / ReAct) on assistant role → True. + 4. Plain assistant content with no tool markers → False. + 5. Look-alike XML that the registry rejects (e.g. plain ``...`` + without inner params) → False. + 6. Malformed message lists never raise. +""" +import pytest + +from twinkle_agentic.preprocessor.agent_trace_filter import ( + AgentTraceFilter, + _is_agent_row, + _msg_text, +) + + +def _row(messages): + return {'messages': messages} + + +# ── _msg_text helper ───────────────────────────────────────────────────────── + +class TestMsgText: + def test_string_content(self): + assert _msg_text({'role': 'user', 'content': 'hello'}) == 'hello' + + def test_list_content_concat(self): + msg = {'content': [ + {'type': 'text', 'text': 'a'}, + {'type': 'image', 'url': '...'}, # non-text part ignored + {'type': 'text', 'text': 'b'}, + ]} + assert _msg_text(msg) == 'a b' + + def test_missing_content(self): + assert _msg_text({'role': 'user'}) == '' + + def test_none_content(self): + assert _msg_text({'role': 'user', 'content': None}) == '' + + def test_non_str_non_list_content(self): + assert _msg_text({'role': 'user', 'content': 123}) == '' + + +# ── _is_agent_row detection ────────────────────────────────────────────────── + +class TestIsAgentRowStructural: + def test_role_tool_triggers(self): + msgs = [ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': '', 'tool_calls': [ + {'id': 'a', 'type': 'function', 'function': {'name': 'x', 'arguments': '{}'}} + ]}, + {'role': 'tool', 'content': 'result', 'tool_call_id': 'a'}, + ] + assert _is_agent_row(msgs) is True + + def test_tool_calls_field_triggers(self): + msgs = [ 
+ {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': '', 'tool_calls': [ + {'id': 'c1', 'type': 'function', 'function': {'name': 'f', 'arguments': '{}'}} + ]}, + ] + assert _is_agent_row(msgs) is True + + def test_empty_tool_calls_field_does_not_trigger(self): + msgs = [ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': 'plain reply', 'tool_calls': []}, + ] + assert _is_agent_row(msgs) is False + + def test_non_list_tool_calls_field_does_not_trigger(self): + msgs = [ + {'role': 'assistant', 'content': 'x', 'tool_calls': None}, + ] + assert _is_agent_row(msgs) is False + + +class TestIsAgentRowTextEmbedded: + def test_cline_style_triggers(self): + msgs = [ + {'role': 'user', 'content': 'read the file'}, + {'role': 'assistant', 'content': + '/etc/hosts'}, + ] + assert _is_agent_row(msgs) is True + + def test_hermes_qwen_style_triggers(self): + msgs = [ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': + '\n{"name": "search", "arguments": {"q": "x"}}\n'}, + ] + assert _is_agent_row(msgs) is True + + def test_react_action_style_triggers(self): + # ReAct parser uses bracket syntax: ``Action: name[args]``. + msgs = [ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': + 'Thought: I need to search.\nAction: search[query=x]'}, + ] + assert _is_agent_row(msgs) is True + + def test_plain_assistant_text_does_not_trigger(self): + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'Hello! How can I help?'}, + ] + assert _is_agent_row(msgs) is False + + def test_lookalike_xml_without_inner_params_does_not_trigger(self): + # ``echo hi`` has no ``val`` child — Cline parser + # rejects it via inner-param requirement. Hermes/ReAct also reject. + msgs = [ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': 'echo hi'}, + ] + assert _is_agent_row(msgs) is False + + def test_denied_outer_tag_does_not_trigger(self): + # ````/```` are in the Cline DENY frozenset. 
+ msgs = [ + {'role': 'assistant', 'content': + 'because'}, + ] + assert _is_agent_row(msgs) is False + + def test_user_text_with_tool_markers_does_not_trigger(self): + # Markers must come from the assistant — user-side embedded XML is just data. + msgs = [ + {'role': 'user', 'content': + 'x'}, + {'role': 'assistant', 'content': 'I will do that.'}, + ] + assert _is_agent_row(msgs) is False + + def test_list_content_assistant_with_tool_call(self): + msgs = [ + {'role': 'assistant', 'content': [ + {'type': 'text', 'text': ''}, + {'type': 'text', 'text': '{"name":"f","arguments":{}}'}, + ]}, + ] + assert _is_agent_row(msgs) is True + + +class TestIsAgentRowEdgeCases: + def test_non_list_messages(self): + assert _is_agent_row(None) is False + assert _is_agent_row('') is False + assert _is_agent_row({}) is False + + def test_empty_messages(self): + assert _is_agent_row([]) is False + + def test_non_dict_message_skipped(self): + msgs = [ + 'not a dict', + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'hello'}, + ] + assert _is_agent_row(msgs) is False + + def test_short_circuits_on_first_match(self): + # Even if later messages are clean, an earlier tool-call hit wins. + msgs = [ + {'role': 'tool', 'content': 'r', 'tool_call_id': 'x'}, + {'role': 'assistant', 'content': 'plain'}, + ] + assert _is_agent_row(msgs) is True + + +# ── AgentTraceFilter pipeline behavior ─────────────────────────────────────── + +class TestAgentTraceFilterPipeline: + def test_tags_every_row(self): + rows = [ + _row([{'role': 'assistant', 'content': 'plain'}]), + _row([{'role': 'tool', 'content': 'r', 'tool_call_id': 'x'}]), + _row([{'role': 'assistant', 'content': + 'x'}]), + ] + out = AgentTraceFilter()(rows) + assert len(out) == 3 + # Every row must have ``is_agent`` so map_row_to_col sees a uniform schema. 
+ assert all('is_agent' in r for r in out) + assert [r['is_agent'] for r in out] == [False, True, True] + + def test_never_drops_rows(self): + rows = [_row([{'role': 'user', 'content': 'x'}])] * 5 + out = AgentTraceFilter()(rows) + assert len(out) == 5 + + def test_preserves_other_fields(self): + rows = [ + {'messages': [{'role': 'tool', 'content': 'r', 'tool_call_id': 'x'}], + 'id': 'row-1', 'extra': {'k': 'v'}}, + ] + out = AgentTraceFilter()(rows) + assert out[0]['id'] == 'row-1' + assert out[0]['extra'] == {'k': 'v'} + assert out[0]['is_agent'] is True + + def test_does_not_mutate_input(self): + original = _row([{'role': 'assistant', 'content': 'plain'}]) + rows = [original] + AgentTraceFilter()(rows) + # Filter must return new dicts, not mutate originals. + assert 'is_agent' not in original + + def test_missing_messages_key(self): + rows = [{'id': 'lonely'}] # no messages + out = AgentTraceFilter()(rows) + assert len(out) == 1 + assert out[0]['is_agent'] is False + + def test_messages_is_none(self): + rows = [_row(None)] + out = AgentTraceFilter()(rows) + assert out[0]['is_agent'] is False + + def test_empty_input(self): + assert AgentTraceFilter()([]) == [] + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/preprocessor/test_dead_loop_filter.py b/tests/preprocessor/test_dead_loop_filter.py new file mode 100644 index 00000000..06b621dd --- /dev/null +++ b/tests/preprocessor/test_dead_loop_filter.py @@ -0,0 +1,266 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for DeadLoopFilter. + +Three orthogonal "stuck" signals: + 1. Hesitation density — markers per 1000 chars > threshold + 2. Correction cascade — ≥N markers within a sliding window + 3. High n-gram repetition — (1 - unique/total) > threshold + +A row is dropped if ANY signal trips on any assistant turn. +Rows with ``is_agent=True`` are always kept (agent rollouts have legitimate +self-correction phrasing). 
+ +When the message contains ``...``, the think part and the +response part are scored independently with separate (looser) think-thresholds. +""" +import pytest + +from twinkle_agentic.preprocessor.dead_loop_filter import ( + DeadLoopFilter, + _has_correction_cascade_with_threshold, + _hesitation_density, + _high_repetition_with_threshold, + _is_stuck, +) + + +def _row(messages, **extra): + return {'messages': messages, **extra} + + +def _fil(rows, **kw): + return DeadLoopFilter(**kw)(rows) + + +# ── _hesitation_density ───────────────────────────────────────────────────── + +class TestHesitationDensity: + def test_no_markers(self): + text = 'This is a perfectly normal explanation of gradient descent.' + assert _hesitation_density(text) == 0.0 + + def test_english_marker_counted(self): + # "wait, wait" matches `wait[,\s]+(wait|...)` — one marker. + text = 'wait, wait this is wrong' + d = _hesitation_density(text) + assert d > 0 + + def test_density_per_1000(self): + # ~5 markers in 100 chars → density ~50/1000 + text = ('hmm hmm hmm hmm hmm ' * 1).strip() # 5 hmm tokens + # Each "hmm" matches `hmm+[,\s]*\.{0,3}` → 5 matches + density = _hesitation_density(text) + assert density > 100 # very dense + + def test_chinese_marker(self): + text = '等等,让我重新想想这个问题。' + assert _hesitation_density(text) > 0 + + def test_empty_text(self): + assert _hesitation_density('') == 0.0 + + def test_japanese_marker(self): + text = 'ちょっと待って、もう一度考え直してみます。' + assert _hesitation_density(text) > 0 + + def test_korean_marker(self): + text = '잠깐, 다시 생각해봐야겠어요.' + assert _hesitation_density(text) > 0 + + +# ── _has_correction_cascade_with_threshold ────────────────────────────────── + +class TestCorrectionCascade: + def test_below_threshold(self): + # Only 2 cascade markers; threshold=5 → no cascade. + text = 'wait, actually let me think.' 
+ assert _has_correction_cascade_with_threshold(text, threshold=5) is False + + def test_at_threshold_in_window(self): + # 5 cascade tokens packed into <800 chars → cascade detected. + text = 'wait wait wait wait wait' + assert _has_correction_cascade_with_threshold(text, threshold=5, + window=800) is True + + def test_threshold_outside_window(self): + # 5 markers but spread across >800 chars → no cascade. + spacer = ' ' * 200 # each spacer is 200 chars + text = f'wait{spacer}wait{spacer}wait{spacer}wait{spacer}wait' # 4 spacers * 200 + 5 * 4 = 820 chars + assert _has_correction_cascade_with_threshold(text, threshold=5, + window=800) is False + + def test_chinese_cascade(self): + text = '等等,不对,重新想想,错了,让我再算一遍。' + assert _has_correction_cascade_with_threshold(text, threshold=4) is True + + def test_zero_threshold_unreachable(self): + # threshold=0 means need 0 matches in any window — len(matches) < 0 is + # never true so this returns True even on empty. Test the sane case. + assert _has_correction_cascade_with_threshold('clean text', + threshold=1) is False + + +# ── _high_repetition_with_threshold ───────────────────────────────────────── + +class TestRepetition: + def test_below_min_words(self): + # Fewer than ngram_min_words words → False (insufficient sample). + text = 'this is a short text' + assert _high_repetition_with_threshold( + text, threshold=0.0, ngram_min_words=30) is False + + def test_no_repetition(self): + # 30 distinct words → unique_ratio ~ 1.0 → repetition ~ 0. + text = ' '.join(f'word{i}' for i in range(40)) + assert _high_repetition_with_threshold( + text, threshold=0.45, ngram_min_words=30) is False + + def test_high_repetition_triggers(self): + # Same 8-gram repeated → unique_ratio low → repetition high. 
+ phrase = 'the quick brown fox jumps over the lazy' + text = ' '.join([phrase] * 10) + assert _high_repetition_with_threshold( + text, threshold=0.45, ngram_size=8, ngram_min_words=30) is True + + def test_threshold_boundary(self): + # Same text under different thresholds. + phrase = 'a b c d e f g h ' + text = phrase * 6 # 48 words, only 8 unique + # very low threshold → trips + assert _high_repetition_with_threshold(text, threshold=0.1) is True + # very high threshold → does not trip even with high duplication + assert _high_repetition_with_threshold(text, threshold=0.99) is False + + +# ── _is_stuck ─────────────────────────────────────────────────────────────── + +class TestIsStuck: + def test_clean_text_not_stuck(self): + # Use diverse prose so n-gram repetition stays below threshold. + text = ( + 'Gradient descent is an iterative optimization algorithm used ' + 'for finding the local minimum of a differentiable function. ' + 'It updates parameters in the direction opposite to the ' + 'gradient of the objective at the current point. Variants ' + 'such as momentum and Adam improve convergence speed.' + ) + assert _is_stuck(text) is False + + def test_high_density_stuck(self): + # Pack many hesitation tokens to exceed 7/1000 density. + text = 'wait, wait this is wrong. hmm... actually no. uh, wait wait wait.' + assert _is_stuck(text) is True + + def test_cascade_stuck(self): + # 5 cascade tokens in tight window + text = 'wait actually wait actually wait!' + assert _is_stuck(text, hesitation_density_threshold=999.0, + cascade_threshold=5, + repetition_threshold=0.99) is True + + def test_repetition_stuck(self): + phrase = 'the quick brown fox jumps over the lazy' + text = ' '.join([phrase] * 10) + assert _is_stuck(text, hesitation_density_threshold=999.0, + cascade_threshold=999, + repetition_threshold=0.45) is True + + def test_think_block_separate_thresholds(self): + # Hesitation that would trip in response section is allowed inside + # ... 
because think-thresholds are looser (15.0 vs 7.0). + # Build a think with moderate density (~10/1000) — below 15 think + # threshold, but would exceed 7 in normal text. + think_part = 'wait, actually let me reconsider this. ' * 3 + 'a' * 1500 + text = f'{think_part}The answer is 42.' + assert _is_stuck(text) is False # think-density well below 15 + + def test_response_part_after_think_stuck(self): + # Clean think but stuck response → still stuck. + text = ('Calculating step by step.' + 'wait, wait this is wrong. hmm... actually no. uh, wait wait wait.') + assert _is_stuck(text) is True + + +# ── DeadLoopFilter pipeline ───────────────────────────────────────────────── + +class TestDeadLoopFilterPipeline: + def test_drops_stuck_row(self): + rows = [_row([ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': + 'wait, wait this is wrong. hmm... actually no. ' + 'uh, wait wait wait.'}, + ])] + assert _fil(rows) == [] + + def test_keeps_clean_row(self): + rows = [_row([ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': + 'A clear, well-formed answer goes here.'}, + ])] + assert len(_fil(rows)) == 1 + + def test_agent_row_always_kept(self): + # is_agent=True bypasses all stuck checks. + rows = [_row([ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': + 'wait wait wait wait wait wait wait!!!'}, + ], is_agent=True)] + assert len(_fil(rows)) == 1 + + def test_no_assistant_kept(self): + rows = [_row([{'role': 'user', 'content': 'hi'}])] + assert len(_fil(rows)) == 1 + + def test_any_assistant_stuck_drops_row(self): + rows = [_row([ + {'role': 'user', 'content': 'q1'}, + {'role': 'assistant', 'content': 'clean reply'}, + {'role': 'user', 'content': 'q2'}, + {'role': 'assistant', 'content': + 'wait, wait this is wrong. hmm... actually no. 
' + 'uh, wait wait wait.'}, + ])] + assert _fil(rows) == [] + + def test_empty_input(self): + assert _fil([]) == [] + + def test_custom_thresholds(self): + # 1 hesitation marker in a long message — density well below the + # default 7/1000. Tightening the threshold should drop it. + long_msg = ( + 'Hmm, let me think about this carefully. Gradient descent ' + 'requires a learning rate, the loss function, and an ' + 'initial parameter point. The algorithm iteratively ' + 'updates the parameters towards the negative gradient. ' + 'Momentum-based variants accumulate past gradients to ' + 'smooth the trajectory and accelerate convergence on ' + 'ill-conditioned problems. Adam additionally adapts the ' + 'per-parameter learning rate using running second-moment ' + 'estimates, which often makes it the default choice for ' + 'practitioners across many deep-learning tasks.' + ) + rows = [_row([ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': long_msg}, + ])] + # Default 7/1000 — single marker in long text → kept + assert len(_fil(rows)) == 1 + # Aggressive threshold drops it + assert _fil(rows, hesitation_density_threshold=0.5) == [] + + def test_chinese_stuck(self): + rows = [_row([ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': + '等等,不对,让我重新想想。错了,让我再来一次。' + '我又搞错了。等等,等等。'}, + ])] + assert _fil(rows) == [] + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/preprocessor/test_hard_filter.py b/tests/preprocessor/test_hard_filter.py new file mode 100644 index 00000000..3385b49c --- /dev/null +++ b/tests/preprocessor/test_hard_filter.py @@ -0,0 +1,285 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for HardFilter. + +HardFilter drops: + Rule 1 — Single-turn trivial query (greeting / bare wh-question). + Rule 2 — Two-turn shallow assistant reply (< min chars, no thinking chain). 
+ +CJK and ASCII branches use different length thresholds because of the +information density gap. +""" +import pytest + +from twinkle_agentic.preprocessor.hard_filter import ( + HardFilter, + _cjk_ratio, + _has_thinking, + _is_simple_query, +) + + +def _row(messages): + return {'messages': messages} + + +# ── _cjk_ratio ─────────────────────────────────────────────────────────────── + +class TestCjkRatio: + def test_pure_ascii(self): + assert _cjk_ratio('hello world') == 0.0 + + def test_pure_chinese(self): + assert _cjk_ratio('你好世界') == 1.0 + + def test_mixed(self): + # 2 CJK chars / 6 total + assert abs(_cjk_ratio('hi你好zz') - 2 / 6) < 1e-9 + + def test_japanese_hiragana(self): + # Hiragana is in the CJK range covered by the regex. + assert _cjk_ratio('こんにちは') == 1.0 + + def test_korean_hangul(self): + assert _cjk_ratio('안녕하세요') == 1.0 + + def test_empty(self): + # max(len, 1) → 0/1 = 0 + assert _cjk_ratio('') == 0.0 + + +# ── _is_simple_query: ASCII / English ──────────────────────────────────────── + +class TestSimpleQueryEnglish: + def test_short_text_is_simple(self): + assert _is_simple_query('hi') is True + assert _is_simple_query('a' * 9) is True # default min=10 + + def test_at_threshold_not_simple_unless_pattern(self): + # 10 non-pattern chars escapes both length and pattern checks + assert _is_simple_query('quantum xx') is False + + def test_greeting_hello(self): + assert _is_simple_query('Hello!') is True + assert _is_simple_query('Heeellloooo') is True + + def test_greeting_good_morning(self): + assert _is_simple_query('Good morning') is True + + def test_greeting_how_are_you(self): + assert _is_simple_query('How are you') is True + + def test_bare_wh_question(self): + assert _is_simple_query('what is python') is True + + def test_imperative_short(self): + assert _is_simple_query('tell me about it') is True + assert _is_simple_query('explain') is True + + def test_substantive_question_not_simple(self): + # Long, technical question should pass (not 
simple). + text = ('Please explain the difference between gradient descent and ' + 'momentum-based optimization in deep learning training.') + assert _is_simple_query(text) is False + + +class TestSimpleQueryChinese: + def test_short_cjk_is_simple(self): + assert _is_simple_query('你好') is True + assert _is_simple_query('你好啊') is True # < 6 + + def test_at_cjk_threshold(self): + # 6 CJK chars; greeting (`你好+` matches `你好好好好好`) → simple + assert _is_simple_query('你好好好好好') is True + # 6 substantive CJK chars; no greeting/simple pattern → NOT simple + assert _is_simple_query('量子计算原理') is False + + def test_greeting_zh(self): + assert _is_simple_query('你好!') is True + assert _is_simple_query('早上好') is True + assert _is_simple_query('哈喽哈喽') is True + + def test_what_is_x(self): + assert _is_simple_query('什么是机器学习?') is True + assert _is_simple_query('梯度下降是什么?') is True + + def test_substantive_zh_not_simple(self): + text = '请详细解释一下变换器架构中的多头自注意力机制是如何并行计算的,以及为什么需要位置编码。' + assert _is_simple_query(text) is False + + +class TestSimpleQueryJapanese: + def test_japanese_greeting(self): + assert _is_simple_query('こんにちは') is True + + def test_japanese_what_is(self): + assert _is_simple_query('機械学習とは何ですか') is True + + +class TestSimpleQueryKorean: + def test_korean_greeting(self): + assert _is_simple_query('안녕하세요') is True + + def test_korean_what_is(self): + # KO_SIMPLE_RE expects "X이/가 뭐" pattern; trailing 인가요/까요 are + # only single optional chars, so use the bare 뭐 form here. + assert _is_simple_query('머신러닝이 뭐') is True + + +class TestSimpleQueryEdge: + def test_empty(self): + assert _is_simple_query('') is True + + def test_whitespace_only(self): + assert _is_simple_query(' \n ') is True + + def test_custom_thresholds(self): + # Raise the bar so a 12-char query becomes simple. + text = 'short query!' 
+ assert _is_simple_query(text, min_user_chars=20) is True + assert _is_simple_query(text, min_user_chars=5) is False + + +# ── _has_thinking ──────────────────────────────────────────────────────────── + +class TestHasThinking: + def test_thinking_field_long_enough(self): + msg = {'thinking': 'a' * 250} + assert _has_thinking(msg) is True + + def test_thinking_field_too_short(self): + msg = {'thinking': 'short'} + assert _has_thinking(msg) is False + + def test_reasoning_content_alias(self): + msg = {'reasoning_content': 'a' * 250} + assert _has_thinking(msg) is True + + def test_no_thinking(self): + assert _has_thinking({'content': 'reply'}) is False + + def test_custom_min_chars(self): + msg = {'thinking': 'short'} + assert _has_thinking(msg, min_chars=3) is True + + def test_non_string_thinking_truthy(self): + # Falls through to bool(thinking) + assert _has_thinking({'thinking': {'a': 1}}) is True + assert _has_thinking({'thinking': []}) is False + + +# ── HardFilter pipeline ────────────────────────────────────────────────────── + +def _fil(rows, **kw): + return HardFilter(**kw)(rows) + + +class TestRule1SimpleQuery: + def test_drops_greeting_only(self): + rows = [_row([ + {'role': 'user', 'content': 'hello'}, + {'role': 'assistant', 'content': 'hi there!'}, + ])] + assert _fil(rows) == [] + + def test_drops_bare_wh_question(self): + rows = [_row([ + {'role': 'user', 'content': 'what is AI'}, + {'role': 'assistant', 'content': 'a short answer'}, + ])] + assert _fil(rows) == [] + + def test_keeps_when_substantive(self): + rows = [_row([ + {'role': 'user', 'content': + 'Could you explain gradient descent step by step in detail?'}, + {'role': 'assistant', 'content': + 'Gradient descent is an iterative optimization algorithm... ' * 5}, + ])] + assert len(_fil(rows)) == 1 + + def test_keeps_simple_query_with_thinking(self): + # Rule 1 rescue: thinking chain ≥200 chars saves the row. 
+ rows = [_row([ + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'hello', + 'reasoning_content': 'Now I need to greet politely... ' * 20}, + ])] + assert len(_fil(rows)) == 1 + + def test_simple_query_no_assistant_dropped(self): + # No assistant turn → no thinking → dropped. + rows = [_row([{'role': 'user', 'content': 'hi'}])] + assert _fil(rows) == [] + + +class TestRule2ShallowReply: + def test_drops_short_reply(self): + rows = [_row([ + {'role': 'user', 'content': + 'Explain the difference between A and B in detail please.'}, + {'role': 'assistant', 'content': 'A is good.'}, # < 80 chars + ])] + assert _fil(rows) == [] + + def test_keeps_long_reply(self): + rows = [_row([ + {'role': 'user', 'content': + 'Explain the difference between A and B in detail please.'}, + {'role': 'assistant', 'content': + 'A and B differ in several ways. ' * 5}, + ])] + assert len(_fil(rows)) == 1 + + def test_short_reply_with_thinking_kept(self): + # Rule 2 rescue: thinking saves a short final reply. + rows = [_row([ + {'role': 'user', 'content': + 'Explain the difference between A and B in detail please.'}, + {'role': 'assistant', 'content': 'A is good.', + 'thinking': 'Step 1: compare features... ' * 20}, + ])] + assert len(_fil(rows)) == 1 + + +class TestPipelineEdges: + def test_no_user_dropped_by_default(self): + rows = [_row([{'role': 'assistant', 'content': 'orphan reply'}])] + assert _fil(rows) == [] + + def test_no_user_kept_when_allowed(self): + rows = [_row([{'role': 'assistant', 'content': 'orphan'}])] + assert len(_fil(rows, allow_incomplete_role=True)) == 1 + + def test_multi_user_skips_rules(self): + # With ≥2 user turns, neither Rule 1 nor Rule 2 applies. 
+ rows = [_row([ + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'short'}, + {'role': 'user', 'content': 'follow-up?'}, + {'role': 'assistant', 'content': 'tiny'}, + ])] + assert len(_fil(rows)) == 1 + + def test_non_list_messages(self): + rows = [{'messages': 'not a list'}] + assert _fil(rows) == [] # invalid → continue (skip) + + def test_missing_messages(self): + rows = [{'id': 'x'}] + # No user_msgs and allow_incomplete_role=False → skipped. + assert _fil(rows) == [] + + def test_empty_input(self): + assert _fil([]) == [] + + def test_custom_thresholds_applied(self): + # Lower min_assistant_chars_2turn → keep what would normally be dropped. + rows = [_row([ + {'role': 'user', 'content': 'tell me a real story please now'}, + {'role': 'assistant', 'content': 'A is good.'}, + ])] + assert len(_fil(rows, min_assistant_chars_2turn=5)) == 1 + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/preprocessor/test_intent_classifier.py b/tests/preprocessor/test_intent_classifier.py new file mode 100644 index 00000000..3159d29c --- /dev/null +++ b/tests/preprocessor/test_intent_classifier.py @@ -0,0 +1,457 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for the heuristic IntentClassifier pipeline. + +Focus areas: +- Per-detector recall on representative samples (ZH + EN, R1-distill-flavoured). +- Per-detector FP guards (chitchat, role mismatch, first-turn dissatisfaction). +- Multi-detector ordering: ToolCallDetector short-circuit, ``setdefault`` semantics. +- Edge cases: empty / None / non-dict / list-content messages, empty trajectories. +- Public API contract: ``row['intent']``, ``user_data['key_rounds']``, ``user_data['intents']``. +- Detector pluggability: custom subclass, overriding ``DEFAULT_DETECTORS``. 
+""" +import pytest + +from twinkle_agentic.preprocessor.intent_classifier import ( + INTENT_CODE, + INTENT_MATH, + INTENT_OTHER, + INTENT_TOOL_CALL, + INTENT_USER_DISSATISFACTION, + CodeDetector, + IntentClassifier, + IntentDetector, + MathDetector, + ToolCallDetector, + UserDissatisfactionDetector, + _msg_text, + _pair_assistant, +) + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _u(text): + return {'role': 'user', 'content': text} + + +def _a(text, **extra): + msg = {'role': 'assistant', 'content': text} + msg.update(extra) + return msg + + +def _row(*messages): + return {'messages': list(messages)} + + +def _classify_one(*messages, detectors=None): + ic = IntentClassifier(detectors=detectors) + out = ic.classify_intent([_row(*messages)]) + return out[0] + + +# ── Helper functions ────────────────────────────────────────────────────────── + +class TestHelpers: + def test_msg_text_string(self): + assert _msg_text({'content': 'hi'}) == 'hi' + + def test_msg_text_list_with_text_parts(self): + msg = {'content': [ + {'type': 'text', 'text': 'foo'}, + {'type': 'image', 'url': 'x'}, + {'type': 'text', 'text': 'bar'}, + ]} + assert _msg_text(msg) == 'foo bar' + + def test_msg_text_missing_content(self): + assert _msg_text({}) == '' + + def test_msg_text_none_content(self): + assert _msg_text({'content': None}) == '' + + def test_msg_text_list_no_text_parts(self): + assert _msg_text({'content': [{'type': 'image'}]}) == '' + + def test_pair_assistant_user_finds_next_assistant(self): + msgs = [_u('q'), _a('a1'), _u('follow'), _a('a2')] + assert _pair_assistant(msgs, 0, 'user') == 1 + assert _pair_assistant(msgs, 2, 'user') == 3 + + def test_pair_assistant_assistant_returns_self(self): + msgs = [_u('q'), _a('a1')] + assert _pair_assistant(msgs, 1, 'assistant') == 1 + + def test_pair_assistant_user_no_following_assistant(self): + # User turn at the tail with no assistant after — un-pairable. 
+ msgs = [_a('a1'), _u('dangling')] + assert _pair_assistant(msgs, 1, 'user') is None + + def test_pair_assistant_other_role(self): + assert _pair_assistant([{'role': 'system', 'content': 's'}], 0, 'system') is None + + +# ── ToolCallDetector ───────────────────────────────────────────────────────── + +class TestToolCallDetector: + def test_definitive_flag(self): + assert ToolCallDetector.definitive is True + + def test_detects_assistant_with_tool_calls(self): + msgs = [_u('q'), _a('', tool_calls=[{'name': 'f'}])] + assert ToolCallDetector()(msgs) == [1] + + def test_ignores_assistant_without_tool_calls(self): + assert ToolCallDetector()([_u('q'), _a('plain')]) == [] + + def test_ignores_user_with_tool_calls_field(self): + # A user dict carrying a tool_calls key must not be picked up. + msgs = [{'role': 'user', 'content': 'q', 'tool_calls': [{'name': 'x'}]}] + assert ToolCallDetector()(msgs) == [] + + def test_short_circuits_pipeline(self): + # When ToolCall fires it must suppress later detectors on the same round. + msgs = [ + _u('解一元二次方程 x^2 - 5x + 6 = 0 的因式分解'), + _a('answer', tool_calls=[{'name': 'calc'}]), + ] + out = _classify_one(*msgs) + assert out['intent'] == INTENT_TOOL_CALL + # math detector must not have written into intents. + assert out['user_data']['intents'] == {1: INTENT_TOOL_CALL} + + +# ── CodeDetector ────────────────────────────────────────────────────────────── + +class TestCodeDetector: + def test_fenced_code_block(self): + text = '```python\ndef f():\n return 1\n```' + assert CodeDetector()._match(text) + + def test_short_fenced_block_below_min_length(self): + # Block content must be ≥10 chars to qualify. + assert not CodeDetector()._match('```\nhi\n```') + + def test_keyword_threshold_three(self): + # Three keyword hits must trigger. 
+ assert CodeDetector()._match('use async function and await the response') + + def test_two_keywords_below_threshold(self): + assert not CodeDetector()._match('a class and a function') + + def test_arrow_signature_alone_insufficient(self): + # Single arrow without other signals doesn't reach threshold. + assert not CodeDetector()._match('x => x + 1') + + def test_call_signature_with_brace(self): + # `name(args) {` is a strong code indicator. + assert CodeDetector()._match( + 'function fetchData(url) { return fetch(url); } and async await yield' + ) + + def test_chitchat_with_word_class_no_fp(self): + assert not CodeDetector()._match('I took a yoga class today') + + +# ── MathDetector ────────────────────────────────────────────────────────────── + +class TestMathDetector: + @pytest.mark.parametrize('text', [ + '设 $f(x)=x^2$ 求导得 2x', + '矩阵 A 的行列式 det(A) 不等于 0', + '三角形 ABC 周长是 12,面积约为 6', + '数列 {a_n} 是等差数列,公差为 2,首项为 1', + '4, 3, 4, 3, (),奇数位是 4', + 'Σ_{i=1}^n A_{ik} B_{kj}', + 'gradient and integral are both fundamental', + '求一元二次方程 x^2 - 5x + 6 = 0 的解', + '一个圆形的直径是 10cm,所以周长是 10π', + ]) + def test_math_recall(self, text): + assert MathDetector()._match(text), f'should detect: {text!r}' + + @pytest.mark.parametrize('text', [ + '今天天气真好', + '我最近在追一部电视剧', + '帮我写一首诗', + '请帮我翻译这句英文', + # Single math keyword in non-math context — must not trip ≥2 threshold. + '积分兑换可以兑换礼品', + '矩阵这个电影很好看', + ]) + def test_math_fp_guard(self, text): + assert not MathDetector()._match(text), f'must NOT detect: {text!r}' + + def test_arithmetic_equation_single_hit(self): + # Only the arithmetic equation matches, threshold ≥2 not met. + assert not MathDetector()._match('计算 30 ÷ 6 = 5') + + def test_threshold_is_configurable(self): + # Subclass with looser threshold catches single-hit case. 
+ class LooseMath(MathDetector): + threshold = 1 + assert LooseMath()._match('计算 30 ÷ 6 = 5') + + def test_subscript_pattern(self): + assert MathDetector()._match('矩阵元素 a_{ij} 与 b_{kl} 满足条件') + + +# ── UserDissatisfactionDetector ─────────────────────────────────────────────── + +class TestUserDissatisfactionDetector: + @pytest.mark.parametrize('text', [ + '不对,再来一次', + '完全错了', + '答非所问', + '你这是在胡扯', + '太离谱了', + '一塌糊涂', + '没逻辑啊', + '你根本没听懂我的意思', + '我说的不是这个', + '别瞎编', + '什么玩意', + '不靠谱', + '让我失望', + '不严谨', + '没get到', + ]) + def test_zh_recall(self, text): + assert UserDissatisfactionDetector()._match(text) + + @pytest.mark.parametrize('text', [ + 'this is wrong', + 'totally incorrect', + 'try again please', + "doesn't make sense", + 'that is garbage', + 'you misunderstood me', + 'low quality response', + 'completely off topic', + 'are you serious', + 'waste of time', + 'this is bullshit', + 'redo it', + 'sub-par answer', + 'do better', + 'WTF is this', + 'nowhere near correct', + ]) + def test_en_recall(self, text): + assert UserDissatisfactionDetector()._match(text) + + @pytest.mark.parametrize('text', [ + '今天心情很好', + '我喜欢这个回答', + '请帮我修改一下', + 'this is exactly what I wanted', + 'great answer thanks', + '能再详细一点吗', + ]) + def test_fp_guard(self, text): + det = UserDissatisfactionDetector() + assert not det._match(text), f'FP on: {text!r}' + + def test_first_turn_user_complaint_ignored(self): + # No prior assistant — the negative phrasing is part of the initial query, not a reaction. + msgs = [_u('你这答案完全错了,太垃圾'), _a('sorry')] + assert UserDissatisfactionDetector()(msgs) == [] + + def test_system_first_then_user_complaint_ignored(self): + msgs = [ + {'role': 'system', 'content': 'You are helpful.'}, + _u('上次回答简直一塌糊涂'), + _a('sorry'), + ] + # System turn must not satisfy "prior assistant". 
+ assert UserDissatisfactionDetector()(msgs) == [] + + def test_multiturn_reaction_detected(self): + msgs = [_u('解释勾股定理'), _a('a²+b²=c²'), _u('不对,再来一次'), _a('好的')] + # The dissat user is at idx 2 → key round is the next assistant idx 3. + assert UserDissatisfactionDetector()(msgs) == [3] + + def test_dissat_with_no_following_assistant_dropped(self): + # User dissatisfaction at the tail with no assistant pair → unpaired, no key round. + msgs = [_u('q'), _a('answer'), _u('完全错了')] + assert UserDissatisfactionDetector()(msgs) == [] + + def test_role_filter_blocks_assistant_self_correction(self): + # "等等我算错了,重新推导" appearing on assistant must not be tagged dissatisfaction. + msgs = [_u('推导一下'), _a('等等,我之前算错了,让我重新推导')] + assert UserDissatisfactionDetector()(msgs) == [] + + +# ── End-to-end IntentClassifier ─────────────────────────────────────────────── + +class TestIntentClassifierE2E: + def test_chitchat_other(self): + out = _classify_one(_u('今天天气真好'), _a('是的,挺适合出门的')) + assert out['intent'] == INTENT_OTHER + assert 'user_data' not in out or 'key_rounds' not in (out.get('user_data') or {}) + + def test_math_round(self): + out = _classify_one( + _u('求一元二次方程 x^2 - 5x + 6 = 0 的解'), + _a('由因式分解得 (x-2)(x-3)=0'), + ) + assert out['intent'] == INTENT_MATH + assert out['user_data']['key_rounds'] == [1] + assert out['user_data']['intents'] == {1: INTENT_MATH} + + def test_code_round(self): + out = _classify_one( + _u('use async function and await the response in JavaScript'), + _a('try const fetchData = async () => { return await fetch(url); }'), + ) + assert out['intent'] == INTENT_CODE + + def test_dissat_round(self): + out = _classify_one(_u('q'), _a('answer'), _u('totally garbage answer, redo'), _a('sorry')) + assert out['intent'] == INTENT_USER_DISSATISFACTION + assert out['user_data']['key_rounds'] == [3] + + def test_assistant_self_correction_not_dissat(self): + # Root cause for original FP: role-agnostic regex on assistant text. Must stay fixed. 
+ out = _classify_one(_u('推导一下'), _a('等等,我之前算错了,让我重新推导...')) + assert out['intent'] == INTENT_OTHER + + def test_first_turn_user_negative_words_not_dissat(self): + out = _classify_one(_u('你这答案完全错了,太垃圾'), _a('抱歉')) + assert out['intent'] == INTENT_OTHER + + def test_setdefault_earlier_detector_wins(self): + # When a round is first claimed by MathDetector, a later UserDissatisfactionDetector + # touching the same round must not overwrite it. + out = _classify_one( + _u('解一元二次方程 x^2 - 5x + 6 = 0'), + _a('factoring: (x-2)(x-3)'), + _u('不对,再来一次'), + _a('好的'), + ) + intents = out['user_data']['intents'] + assert intents[1] == INTENT_MATH + assert intents[3] == INTENT_USER_DISSATISFACTION + + def test_tool_call_definitive_short_circuits(self): + out = _classify_one( + _u('解一元二次方程 x^2 - 5x + 6 = 0'), + _a('', tool_calls=[{'name': 'calc'}]), + ) + assert out['intent'] == INTENT_TOOL_CALL + # MathDetector must not have run after the definitive ToolCallDetector. + assert set(out['user_data']['intents'].values()) == {INTENT_TOOL_CALL} + + def test_multimodal_list_content(self): + # List-content messages must work transparently. 
+ msgs = [ + _u([{'type': 'text', 'text': '求一元二次方程'}, {'type': 'image', 'url': 'x'}]), + _a([{'type': 'text', 'text': '因式分解后得到结果'}]), + ] + out = _classify_one(*msgs) + assert out['intent'] == INTENT_MATH + + +# ── Edge / robustness ───────────────────────────────────────────────────────── + +class TestEdgeCases: + def test_empty_rows(self): + assert IntentClassifier().classify_intent([]) == [] + + def test_missing_messages_field(self): + out = IntentClassifier().classify_intent([{'foo': 'bar'}]) + assert out[0]['intent'] == INTENT_OTHER + + def test_messages_is_none(self): + out = IntentClassifier().classify_intent([{'messages': None}]) + assert out[0]['intent'] == INTENT_OTHER + + def test_messages_empty_list(self): + out = IntentClassifier().classify_intent([{'messages': []}]) + assert out[0]['intent'] == INTENT_OTHER + + def test_messages_with_non_dict_entries(self): + # Non-dict entries must be silently skipped. + out = IntentClassifier().classify_intent([{'messages': [ + 'not a dict', + None, + _u('求一元二次方程'), + _a('因式分解'), + ]}]) + assert out[0]['intent'] == INTENT_MATH + + def test_user_data_preexists_preserved(self): + # IntentClassifier merges into existing user_data, must not clobber. + rows = [{ + 'messages': [_u('解一元二次方程 x^2'), _a('因式分解 (x-2)(x-3)')], + 'user_data': {'source': 'gsm8k', 'difficulty': 'easy'}, + }] + out = IntentClassifier().classify_intent(rows) + ud = out[0]['user_data'] + assert ud['source'] == 'gsm8k' + assert ud['difficulty'] == 'easy' + assert ud['key_rounds'] == [1] + assert ud['intents'] == {1: INTENT_MATH} + + def test_input_row_not_mutated(self): + # classify_intent must shallow-copy rows; original dict must remain untouched. 
+ original = {'messages': [_u('你好'), _a('hi')]} + IntentClassifier().classify_intent([original]) + assert 'intent' not in original + assert 'user_data' not in original + + def test_other_intent_does_not_emit_user_data(self): + out = _classify_one(_u('你好'), _a('hi')) + # No detectors fired → no key_rounds / intents written. + assert 'user_data' not in out or 'key_rounds' not in (out.get('user_data') or {}) + + +# ── Pluggability ────────────────────────────────────────────────────────────── + +class TestPluggability: + def test_custom_detector_via_constructor(self): + class GreetingDetector(IntentDetector): + intent = 'greeting' + + def __call__(self, messages): + return [ + i for i, m in enumerate(messages) + if isinstance(m, dict) and m.get('role') == 'assistant' + and isinstance(m.get('content'), str) and 'hello' in m['content'].lower() + ] + + ic = IntentClassifier(detectors=[GreetingDetector()]) + out = ic.classify_intent([_row(_u('hi'), _a('Hello there'))]) + assert out[0]['intent'] == 'greeting' + + def test_empty_detector_list_yields_other(self): + ic = IntentClassifier(detectors=[]) + out = ic.classify_intent([_row(_u('q'), _a('因式分解 一元二次方程'))]) + assert out[0]['intent'] == INTENT_OTHER + + def test_intent_field_override(self): + ic = IntentClassifier(intent_field='label') + out = ic.classify_intent([_row(_u('q'), _a('a'))]) + assert 'label' in out[0] + assert 'intent' not in out[0] + + def test_definitive_short_circuits_custom_pipeline(self): + # User-defined definitive detector must halt the pipeline after firing. 
+ seen = [] + + class StopAll(IntentDetector): + intent = 'stop' + definitive = True + def __call__(self, messages): + seen.append('stop') + return [len(messages) - 1] + + class NeverRuns(IntentDetector): + intent = 'never' + def __call__(self, messages): + seen.append('never') + return [0] + + ic = IntentClassifier(detectors=[StopAll(), NeverRuns()]) + ic.classify_intent([_row(_u('q'), _a('a'))]) + assert seen == ['stop'] diff --git a/tests/preprocessor/test_message_sanity.py b/tests/preprocessor/test_message_sanity.py new file mode 100644 index 00000000..3996219d --- /dev/null +++ b/tests/preprocessor/test_message_sanity.py @@ -0,0 +1,386 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for MessageSanityFilter preprocessor.""" +import pytest + +from twinkle_agentic.preprocessor.message_sanity import ( + MessageSanityFilter, + _validate_role_order, + _validate_tool_call_matching, + _validate_content_integrity, + _trim_to_last_assistant, +) + + +# ── Fixtures ────────────────────────────────────────────────────────────────── + +def _make_rows(messages_list): + """Wrap messages lists into row-format for the filter.""" + return [{'messages': m} for m in messages_list] + + +def _run_filter(messages_list, **kwargs): + """Run MessageSanityFilter on a list of message sequences, return surviving messages.""" + f = MessageSanityFilter(**kwargs) + rows = _make_rows(messages_list) + result = f.message_sanity_filter(rows) + return [r['messages'] for r in result] + + +# ── Role order tests ────────────────────────────────────────────────────────── + +class TestRoleOrder: + def test_valid_simple(self): + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'hello'}, + ] + assert _validate_role_order(msgs) is True + + def test_valid_with_system(self): + msgs = [ + {'role': 'system', 'content': 'You are helpful.'}, + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'hello'}, + ] + assert 
_validate_role_order(msgs) is True + + def test_system_not_first(self): + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'system', 'content': 'late system'}, + {'role': 'assistant', 'content': 'hello'}, + ] + assert _validate_role_order(msgs) is False + + def test_tool_without_tool_calls(self): + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'let me check'}, + {'role': 'tool', 'content': 'result', 'tool_call_id': 'x'}, + ] + assert _validate_role_order(msgs) is False + + def test_tool_after_assistant_with_tool_calls(self): + msgs = [ + {'role': 'user', 'content': 'search'}, + {'role': 'assistant', 'content': '', 'tool_calls': [ + {'id': 'c1', 'type': 'function', 'function': {'name': 'search', 'arguments': '{}'}} + ]}, + {'role': 'tool', 'content': 'found it', 'tool_call_id': 'c1'}, + ] + assert _validate_role_order(msgs) is True + + def test_tool_after_user(self): + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'tool', 'content': 'bad', 'tool_call_id': 'x'}, + ] + assert _validate_role_order(msgs) is False + + def test_invalid_role(self): + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'bot', 'content': 'hello'}, + ] + assert _validate_role_order(msgs) is False + + def test_empty(self): + assert _validate_role_order([]) is False + + def test_consecutive_tools(self): + msgs = [ + {'role': 'user', 'content': 'do things'}, + {'role': 'assistant', 'content': '', 'tool_calls': [ + {'id': 'c1', 'type': 'function', 'function': {'name': 'a', 'arguments': '{}'}}, + {'id': 'c2', 'type': 'function', 'function': {'name': 'b', 'arguments': '{}'}}, + ]}, + {'role': 'tool', 'content': 'res1', 'tool_call_id': 'c1'}, + {'role': 'tool', 'content': 'res2', 'tool_call_id': 'c2'}, + ] + assert _validate_role_order(msgs) is True + + +# ── Tool call matching tests ────────────────────────────────────────────────── + +class TestToolCallMatching: + def test_valid_matching(self): + msgs = [ + {'role': 'user', 'content': 
'go'}, + {'role': 'assistant', 'content': '', 'tool_calls': [ + {'id': 'c1', 'type': 'function', 'function': {'name': 'fn', 'arguments': '{}'}}, + ]}, + {'role': 'tool', 'content': 'ok', 'tool_call_id': 'c1'}, + {'role': 'assistant', 'content': 'done'}, + ] + assert _validate_tool_call_matching(msgs) is True + + def test_orphan_tool_calls(self): + """Assistant has tool_calls but no tool response follows.""" + msgs = [ + {'role': 'user', 'content': 'go'}, + {'role': 'assistant', 'content': '', 'tool_calls': [ + {'id': 'c1', 'type': 'function', 'function': {'name': 'fn', 'arguments': '{}'}}, + ]}, + {'role': 'user', 'content': 'what happened?'}, + ] + assert _validate_tool_call_matching(msgs) is False + + def test_phantom_tool_response(self): + """Tool response references an ID not in the assistant's tool_calls.""" + msgs = [ + {'role': 'user', 'content': 'go'}, + {'role': 'assistant', 'content': '', 'tool_calls': [ + {'id': 'c1', 'type': 'function', 'function': {'name': 'fn', 'arguments': '{}'}}, + ]}, + {'role': 'tool', 'content': 'ok', 'tool_call_id': 'WRONG_ID'}, + ] + assert _validate_tool_call_matching(msgs) is False + + def test_partial_response_ok(self): + """Only some tool_calls get responses — currently allowed.""" + msgs = [ + {'role': 'user', 'content': 'go'}, + {'role': 'assistant', 'content': '', 'tool_calls': [ + {'id': 'c1', 'type': 'function', 'function': {'name': 'a', 'arguments': '{}'}}, + {'id': 'c2', 'type': 'function', 'function': {'name': 'b', 'arguments': '{}'}}, + ]}, + {'role': 'tool', 'content': 'res1', 'tool_call_id': 'c1'}, + ] + assert _validate_tool_call_matching(msgs) is True + + def test_no_tool_calls_passes(self): + """Conversations without tool_calls pass trivially.""" + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'hello'}, + ] + assert _validate_tool_call_matching(msgs) is True + + +# ── Content integrity tests ─────────────────────────────────────────────────── + +class TestContentIntegrity: + 
def test_valid_basic(self):
    """A two-message exchange with non-empty text is expected to validate."""
    conversation = [
        {'role': 'user', 'content': 'hi'},
        {'role': 'assistant', 'content': 'hello there'},
    ]
    verdict = _validate_content_integrity(conversation)
    assert verdict is True


def test_empty_assistant(self):
    """An assistant turn with empty content and no tool_calls is expected to fail."""
    question = {'role': 'user', 'content': 'hi'}
    blank_reply = {'role': 'assistant', 'content': ''}
    verdict = _validate_content_integrity([question, blank_reply])
    assert verdict is False


def test_assistant_with_tool_calls_no_content_ok(self):
    """Empty assistant content is expected to be accepted when tool_calls are present."""
    call = {
        'id': 'c1',
        'type': 'function',
        'function': {'name': 'search_web', 'arguments': '{"q":"test"}'},
    }
    conversation = [
        {'role': 'user', 'content': 'search'},
        {'role': 'assistant', 'content': '', 'tool_calls': [call]},
    ]
    verdict = _validate_content_integrity(conversation)
    assert verdict is True


def test_empty_system(self):
    """A system message with empty content is expected to fail validation."""
    blank_system = {'role': 'system', 'content': ''}
    conversation = [
        blank_system,
        {'role': 'user', 'content': 'hi'},
        {'role': 'assistant', 'content': 'hello'},
    ]
    verdict = _validate_content_integrity(conversation)
    assert verdict is False


def test_too_long_message(self):
    """A 60000-char assistant message is expected to fail under a 50000-char cap."""
    limit = 50000
    oversized = 'x' * 60000
    conversation = [
        {'role': 'user', 'content': 'hi'},
        {'role': 'assistant', 'content': oversized},
    ]
    verdict = _validate_content_integrity(conversation, max_msg_chars=limit)
    assert verdict is False


def test_invalid_tool_call_structure(self):
    """A tool call whose ``function`` field is a string is expected to fail."""
    # ``function`` is a plain string here rather than a dict.
    malformed_call = {'id': 'c1', 'function': 'not_a_dict'}
    conversation = [
        {'role': 'user', 'content': 'go'},
        {'role': 'assistant', 'content': '', 'tool_calls': [malformed_call]},
    ]
    verdict = _validate_content_integrity(conversation)
    assert verdict is False


def test_invalid_function_name(self):
    """A tool call whose function name is ``123bad`` is expected to fail."""
    call = {
        'id': 'c1',
        'type': 'function',
        'function': {'name': '123bad', 'arguments': '{}'},
    }
    conversation = [
        {'role': 'user', 'content': 'go'},
        {'role': 'assistant', 'content': '', 'tool_calls': [call]},
    ]
    verdict = _validate_content_integrity(conversation)
    assert verdict is False


def test_invalid_arguments_json(self):
    """String arguments that are not parseable JSON are expected to fail."""
    broken_arguments = '{invalid json'
    call = {
        'id': 'c1',
        'type': 'function',
        'function': {'name': 'fn', 'arguments': broken_arguments},
    }
    conversation = [
        {'role': 'user', 'content': 'go'},
        {'role': 'assistant', 'content': '', 'tool_calls': [call]},
    ]
    verdict = _validate_content_integrity(conversation)
    assert verdict is False

def test_dict_arguments_ok(self): + msgs = [ + {'role': 'user', 'content': 'go'}, + {'role': 'assistant', 'content': '', 'tool_calls': [ + {'id': 'c1', 'type': 'function', 'function': {'name': 'fn', 'arguments': {'key': 'val'}}}, + ]}, + ] + assert _validate_content_integrity(msgs) is True + + def test_duplicate_user_messages(self): + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'hello'}, + ] + assert _validate_content_integrity(msgs) is False + + def test_duplicate_tool_messages_allowed(self): + """Two consecutive tool messages with same content should NOT be rejected.""" + msgs = [ + {'role': 'user', 'content': 'search both'}, + {'role': 'assistant', 'content': '', 'tool_calls': [ + {'id': 'c1', 'type': 'function', 'function': {'name': 'search', 'arguments': '{"q":"x"}'}}, + {'id': 'c2', 'type': 'function', 'function': {'name': 'search', 'arguments': '{"q":"x"}'}}, + ]}, + {'role': 'tool', 'content': 'same result', 'tool_call_id': 'c1'}, + {'role': 'tool', 'content': 'same result', 'tool_call_id': 'c2'}, + {'role': 'assistant', 'content': 'both returned same'}, + ] + assert _validate_content_integrity(msgs) is True + + def test_min_turns(self): + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'hello'}, + ] + # min_turns=2 → user(1)+assistant(1)=2 >= 2 → pass + assert _validate_content_integrity(msgs, min_turns=2) is True + # min_turns=3 → total=2 < 3 → fail + assert _validate_content_integrity(msgs, min_turns=3) is False + + +# ── Trim tests ──────────────────────────────────────────────────────────────── + +class TestTrimToLastAssistant: + def test_already_ends_with_assistant(self): + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'hello'}, + ] + assert _trim_to_last_assistant(msgs) == msgs + + def test_trim_trailing_user(self): + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'hello'}, + 
{'role': 'user', 'content': 'bye'}, + ] + assert _trim_to_last_assistant(msgs) == msgs[:2] + + def test_no_assistant(self): + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'user', 'content': 'hello?'}, + ] + assert _trim_to_last_assistant(msgs) == [] + + +# ── Sensitive word tests ────────────────────────────────────────────────────── + +class TestSensitiveWords: + def test_english_word_boundary(self): + msgs_clean = [ + {'role': 'user', 'content': 'hello world'}, + {'role': 'assistant', 'content': 'hi there'}, + ] + msgs_bad = [ + {'role': 'user', 'content': 'hello world'}, + {'role': 'assistant', 'content': 'what the fuck'}, + ] + result = _run_filter( + [msgs_clean, msgs_bad], + extra_sensitive_words=['fuck'], + ) + assert len(result) == 1 + assert result[0] == msgs_clean + + def test_chinese_sensitive(self): + msgs_bad = [ + {'role': 'user', 'content': '你好'}, + {'role': 'assistant', 'content': '操你妈'}, + ] + result = _run_filter( + [msgs_bad], + extra_sensitive_words=['操你妈'], + ) + assert len(result) == 0 + + def test_no_sensitive_config_passes_all(self): + msgs = [ + {'role': 'user', 'content': 'fuck'}, + {'role': 'assistant', 'content': 'hello'}, + ] + # No sensitive words configured → everything passes + result = _run_filter([msgs]) + assert len(result) == 1 + + +# ── End-to-end filter tests ─────────────────────────────────────────────────── + +class TestEndToEnd: + def test_full_valid_agentic_trajectory(self): + msgs = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'What is the weather?'}, + {'role': 'assistant', 'content': '', 'tool_calls': [ + {'id': 'call_1', 'type': 'function', + 'function': {'name': 'get_weather', 'arguments': '{"city":"Beijing"}'}}, + ]}, + {'role': 'tool', 'content': '{"temp": 22, "condition": "sunny"}', 'tool_call_id': 'call_1'}, + {'role': 'assistant', 'content': 'It is 22°C and sunny in Beijing.'}, + ] + result = _run_filter([msgs]) + assert len(result) == 1 + + def 
test_trim_and_validate(self): + """Trailing user message gets trimmed, result still valid.""" + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'assistant', 'content': 'hello'}, + {'role': 'user', 'content': 'thanks'}, + ] + result = _run_filter([msgs]) + assert len(result) == 1 + assert result[0][-1]['role'] == 'assistant' + + def test_no_assistant_discarded(self): + msgs = [ + {'role': 'user', 'content': 'hi'}, + {'role': 'user', 'content': 'hello?'}, + ] + result = _run_filter([msgs]) + assert len(result) == 0 + + def test_multiple_tool_rounds(self): + msgs = [ + {'role': 'user', 'content': 'plan a trip'}, + {'role': 'assistant', 'content': '', 'tool_calls': [ + {'id': 'c1', 'type': 'function', 'function': {'name': 'search_flights', 'arguments': '{}'}}, + ]}, + {'role': 'tool', 'content': 'flight options...', 'tool_call_id': 'c1'}, + {'role': 'assistant', 'content': 'Found flights. Let me check hotels.', 'tool_calls': [ + {'id': 'c2', 'type': 'function', 'function': {'name': 'search_hotels', 'arguments': '{}'}}, + ]}, + {'role': 'tool', 'content': 'hotel options...', 'tool_call_id': 'c2'}, + {'role': 'assistant', 'content': 'Here is your complete trip plan.'}, + ] + result = _run_filter([msgs]) + assert len(result) == 1 diff --git a/tests/preprocessor/test_pii_presidio_filter.py b/tests/preprocessor/test_pii_presidio_filter.py new file mode 100644 index 00000000..9a642ace --- /dev/null +++ b/tests/preprocessor/test_pii_presidio_filter.py @@ -0,0 +1,223 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for pure helpers in pii_presidio_filter. + +Only validators and replacement primitives are tested here — the full +``PIIPresidioFilter`` requires presidio_analyzer + spacy + faker which are +heavy/optional deps. Pure helpers are usable standalone and have clear +mathematical contracts. 
+ +Coverage: + * ``_is_valid_cn_id`` — 18-digit checksum (last digit may be 'X') + * ``_is_valid_luhn`` — Luhn algorithm with min length 13 + * ``_mask_keep_edges`` — keep head/tail, mask middle + * ``_hash_short`` — SHA-256 prefix, deterministic w/ salt + * ``Strategy.coerce`` — enum coercion + strict failure mode +""" +import hashlib + +import pytest + +from twinkle_agentic.preprocessor.pii_presidio_filter import ( + Strategy, + _hash_short, + _is_valid_cn_id, + _is_valid_luhn, + _mask_keep_edges, +) + + +# ── _is_valid_cn_id ───────────────────────────────────────────────────────── + +class TestIsValidCnId: + """ + Verified against the official GB 11643-1999 weights: + weights = (7,9,10,5,8,4,2,1,6,3,7,9,10,5,8,4,2) + checks = '10X98765432' + Test ID `11010519491231002X` is a textbook valid example. + """ + + def test_valid_id_with_x_check(self): + assert _is_valid_cn_id('11010519491231002X') is True + + def test_valid_id_with_x_lowercase(self): + # Implementation upper-cases the check digit before compare. + assert _is_valid_cn_id('11010519491231002x') is True + + def test_invalid_check_digit(self): + # Flip the last char to a wrong number. + assert _is_valid_cn_id('110105194912310020') is False + + def test_too_short(self): + assert _is_valid_cn_id('110105194912310') is False + + def test_too_long(self): + assert _is_valid_cn_id('11010519491231002X9') is False + + def test_non_digit_in_first_17(self): + assert _is_valid_cn_id('1101051949123100AX') is False + + def test_empty(self): + assert _is_valid_cn_id('') is False + + def test_18_digits_invalid_checksum(self): + # 18 digits but last is wrong number + assert _is_valid_cn_id('110105194912310029') is False + + +# ── _is_valid_luhn ────────────────────────────────────────────────────────── + +class TestIsValidLuhn: + """ + `4532015112830366` is a well-known Visa test number that satisfies Luhn. 
+ """ + + def test_valid_visa_test_number(self): + assert _is_valid_luhn('4532015112830366') is True + + def test_valid_with_separators(self): + # Implementation strips non-digits via `c.isdigit()`. + assert _is_valid_luhn('4532-0151-1283-0366') is True + assert _is_valid_luhn('4532 0151 1283 0366') is True + + def test_invalid_checksum(self): + # Flip the last digit. + assert _is_valid_luhn('4532015112830367') is False + + def test_too_short(self): + # Only 12 digits — below 13-digit minimum. + assert _is_valid_luhn('453201511283') is False + + def test_empty(self): + assert _is_valid_luhn('') is False + + def test_no_digits(self): + assert _is_valid_luhn('abcd-efgh-ijkl-mnop') is False + + def test_amex_test_number(self): + # 15-digit Amex test card. + assert _is_valid_luhn('378282246310005') is True + + def test_mastercard_test_number(self): + assert _is_valid_luhn('5555555555554444') is True + + +# ── _mask_keep_edges ──────────────────────────────────────────────────────── + +class TestMaskKeepEdges: + def test_default_head_tail(self): + # head=3, tail=4 → keep 3 + mask middle + keep 4 + s = '13800138000' # 11 chars + # 11 > 3+4 = 7 → masked = 11 - 7 = 4 stars + out = _mask_keep_edges(s) + assert out == '138' + '*' * 4 + '8000' + + def test_short_string_all_masked(self): + # len ≤ head+tail → entire string masked. 
+ s = 'short' # 5 chars; head+tail = 7 + assert _mask_keep_edges(s) == '*****' + + def test_at_threshold_all_masked(self): + # len == head+tail → all masked (boundary is `<=`) + s = '1234567' # 7 chars + assert _mask_keep_edges(s) == '*' * 7 + + def test_custom_head_tail(self): + s = 'abcdefghij' # 10 chars + # head=2, tail=2 → keep ab + 6 stars + ij + assert _mask_keep_edges(s, head=2, tail=2) == 'ab' + '*' * 6 + 'ij' + + def test_custom_mask_char(self): + s = '1234567890' + out = _mask_keep_edges(s, head=1, tail=1, ch='X') + assert out == '1' + 'X' * 8 + '0' + + def test_empty_string(self): + # len=0 ≤ head+tail → '' * 0 = '' + assert _mask_keep_edges('') == '' + + def test_credit_card_default(self): + s = '4532015112830366' # 16 chars + out = _mask_keep_edges(s) + # head=3, tail=4 → keep 453 + 9 stars + 0366 + assert out == '453' + '*' * 9 + '0366' + + +# ── _hash_short ───────────────────────────────────────────────────────────── + +class TestHashShort: + def test_length_is_12(self): + assert len(_hash_short('alice@example.com')) == 12 + + def test_deterministic_same_input(self): + a = _hash_short('hello') + b = _hash_short('hello') + assert a == b + + def test_different_inputs_different_outputs(self): + a = _hash_short('alice@example.com') + b = _hash_short('bob@example.com') + assert a != b + + def test_salt_changes_output(self): + a = _hash_short('hello', salt='') + b = _hash_short('hello', salt='secret') + assert a != b + + def test_matches_sha256_prefix(self): + expected = hashlib.sha256(b'hello').hexdigest()[:12] + assert _hash_short('hello') == expected + + def test_matches_sha256_with_salt(self): + expected = hashlib.sha256(b'saltyhello').hexdigest()[:12] + assert _hash_short('hello', salt='salty') == expected + + def test_empty_string(self): + # Hash is well-defined for empty input too. + expected = hashlib.sha256(b'').hexdigest()[:12] + assert _hash_short('') == expected + + def test_unicode_input(self): + # UTF-8 encoding before hashing. 
+ expected = hashlib.sha256('张三'.encode('utf-8')).hexdigest()[:12] + assert _hash_short('张三') == expected + + +# ── Strategy.coerce ───────────────────────────────────────────────────────── + +class TestStrategyCoerce: + def test_coerce_string_to_enum(self): + assert Strategy.coerce('mask') is Strategy.MASK + assert Strategy.coerce('replace') is Strategy.REPLACE + assert Strategy.coerce('redact') is Strategy.REDACT + assert Strategy.coerce('hash') is Strategy.HASH + + def test_coerce_enum_returns_self(self): + assert Strategy.coerce(Strategy.MASK) is Strategy.MASK + + def test_coerce_unknown_raises(self): + with pytest.raises(ValueError) as exc: + Strategy.coerce('encrypt') + # Error message lists allowed strategies for diagnosability. + msg = str(exc.value) + assert 'mask' in msg + assert 'replace' in msg + assert 'redact' in msg + assert 'hash' in msg + + def test_coerce_empty_string_raises(self): + with pytest.raises(ValueError): + Strategy.coerce('') + + def test_string_enum_membership(self): + # Strategy is a str-Enum: values should compare equal to their str form. + assert Strategy.MASK == 'mask' + assert Strategy.REPLACE.value == 'replace' + + def test_coerce_case_sensitive(self): + # Implementation does not lowercase before lookup. + with pytest.raises(ValueError): + Strategy.coerce('MASK') + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/preprocessor/test_preprocessor_utils.py b/tests/preprocessor/test_preprocessor_utils.py new file mode 100644 index 00000000..a41796b1 --- /dev/null +++ b/tests/preprocessor/test_preprocessor_utils.py @@ -0,0 +1,333 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for preprocessor.utils — pure logprob math helpers. + +These helpers compute conditional-vs-unconditional logprob deltas for +IFD-family scoring (CherryLLM, T-SHIRT, ChR). All functions are stateless +and accept simple list inputs. 
+ +Conventions used in this test file: + * "lp" lists are aligned to the FULL sequence (prompt + answer). + * ``n_prompt`` is the number of prompt tokens; assistant tokens start at + index ``n_prompt`` in the cond list. + * Each lp entry is a dict {token_id: logprob_float}. +""" +import math + +import pytest + +from twinkle_agentic.preprocessor.utils import ( + _chr_min_distinct, + _chr_min_weighted, + _extract_logprob, + _ifd_family_metrics, + _lp_to_jsonable, + _mean_logprob_delta, + _pad_batch, + _to_int_list, +) + + +# ── _extract_logprob ──────────────────────────────────────────────────────── + +class TestExtractLogprob: + def test_none(self): + assert _extract_logprob(None) is None + + def test_scalar_int(self): + assert _extract_logprob(5) == 5.0 + + def test_scalar_float(self): + assert _extract_logprob(-1.2) == -1.2 + + def test_dict_with_int_token_id(self): + lp = {7: -0.5, 8: -2.0} + assert _extract_logprob(lp, token_id=7) == -0.5 + assert _extract_logprob(lp, token_id=8) == -2.0 + + def test_dict_with_str_token_id_fallback(self): + # vLLM may emit string keys; lookup must fall back to str(token_id). + lp = {'7': -0.5} + assert _extract_logprob(lp, token_id=7) == -0.5 + + def test_dict_no_token_id_picks_first(self): + # No token_id → iter-first behaviour. + lp = {7: -0.5} + assert _extract_logprob(lp) == -0.5 + + def test_dict_token_id_missing_uses_first(self): + # token_id not in dict → fall back to first entry. 
+ lp = {99: -3.0} + assert _extract_logprob(lp, token_id=7) == -3.0 + + def test_dict_with_logprob_attr_object(self): + class Entry: + def __init__(self, v): + self.logprob = v + lp = {7: Entry(-0.7)} + assert _extract_logprob(lp, token_id=7) == -0.7 + + def test_dict_with_nested_dict(self): + lp = {7: {'logprob': -0.9, 'rank': 1}} + assert _extract_logprob(lp, token_id=7) == -0.9 + + def test_dict_with_nested_dict_none_logprob(self): + lp = {7: {'logprob': None}} + assert _extract_logprob(lp, token_id=7) is None + + def test_unrecognized_type(self): + # str entries → returns None + lp = {7: 'oops'} + assert _extract_logprob(lp, token_id=7) is None + + def test_non_dict_non_scalar(self): + # A list is neither scalar nor dict → None. + assert _extract_logprob([1, 2, 3]) is None + + +# ── _to_int_list ──────────────────────────────────────────────────────────── + +class TestToIntList: + def test_plain_list(self): + assert _to_int_list([1, 2, 3]) == [1, 2, 3] + + def test_tuple(self): + assert _to_int_list((1, 2, 3)) == [1, 2, 3] + + def test_with_tolist(self): + class Tensor: + def tolist(self): + return [4, 5, 6] + assert _to_int_list(Tensor()) == [4, 5, 6] + + def test_empty(self): + assert _to_int_list([]) == [] + + +# ── _chr_min_distinct ─────────────────────────────────────────────────────── + +class TestChrMinDistinct: + def test_empty_inputs_returns_none(self): + assert _chr_min_distinct([], [{1: -1.0}], [], [1], 0) is None + assert _chr_min_distinct([{1: -1.0}], [], [1], [], 0) is None + assert _chr_min_distinct([{1: -1.0}], [{1: -1.0}], [1], [], 0) is None + + def test_simple_all_positive(self): + # cond_lp[i] - asst_lp[i] > 0 for all i → ratio = 1.0 + n_prompt = 1 + # cond covers prompt(1) + asst(2) = 3 positions + cond_lp = [{0: -10.0}, # prompt position + {1: -0.1}, # asst pos 0 — high cond logprob + {2: -0.2}] # asst pos 1 + asst_lp = [{1: -1.0}, {2: -1.5}] + cond_ids = [0, 1, 2] + asst_ids = [1, 2] + ratio = _chr_min_distinct(cond_lp, asst_lp, 
cond_ids, asst_ids, n_prompt) + assert ratio == 1.0 + + def test_all_negative(self): + # delta < 0 → ratio = 0 + n_prompt = 1 + cond_lp = [{0: 0.0}, {1: -3.0}, {2: -3.0}] + asst_lp = [{1: -0.5}, {2: -0.5}] + ratio = _chr_min_distinct(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt) + assert ratio == 0.0 + + def test_distinct_token_min_aggregation(self): + # Two occurrences of same token: one has +delta, one has -delta. + # min(deltas) is negative → token contributes 0 to ratio. + n_prompt = 1 + cond_lp = [{0: 0.0}, {1: -0.1}, {1: -3.0}] + asst_lp = [{1: -1.0}, {1: -0.5}] # delta1=+0.9, delta2=-2.5 + ratio = _chr_min_distinct(cond_lp, asst_lp, [0, 1, 1], [1, 1], n_prompt) + assert ratio == 0.0 # min < 0 + + def test_exclude_ids(self): + # Excluded token is dropped before counting. + n_prompt = 1 + cond_lp = [{0: 0.0}, {1: -0.1}, {2: -0.1}] + asst_lp = [{1: -1.0}, {2: -1.0}] + # Without exclude: 2 distinct tokens, both positive → 1.0 + ratio = _chr_min_distinct(cond_lp, asst_lp, [0, 1, 2], [1, 2], + n_prompt, exclude_ids={1}) + assert ratio == 1.0 # only token 2 counted, still positive + + def test_truncation_when_cond_short(self): + # cond_lp shorter than n_prompt + n_asst → loop breaks early. 
+ n_prompt = 2 + cond_lp = [{0: 0.0}, {0: 0.0}, {1: -0.1}] # only 1 asst position + asst_lp = [{1: -1.0}, {2: -1.0}] # 2 asst positions requested + ratio = _chr_min_distinct(cond_lp, asst_lp, [0, 0, 1], [1, 2], n_prompt) + assert ratio == 1.0 # only the first delta processed + + +# ── _chr_min_weighted ─────────────────────────────────────────────────────── + +class TestChrMinWeighted: + def test_empty_returns_none(self): + assert _chr_min_weighted([], [{1: -1.0}], [], [1], 0) is None + + def test_all_positive_returns_one(self): + n_prompt = 1 + cond_lp = [{0: 0.0}, {1: -0.1}, {2: -0.2}] + asst_lp = [{1: -1.0}, {2: -1.5}] + ratio = _chr_min_weighted(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt) + assert ratio == 1.0 # all positive → pos_w == total_w + + def test_zero_total_weight_returns_none(self): + # All deltas == 0 → total_w == 0 → None + n_prompt = 1 + cond_lp = [{0: 0.0}, {1: -1.0}] + asst_lp = [{1: -1.0}] + assert _chr_min_weighted(cond_lp, asst_lp, [0, 1], [1], n_prompt) is None + + def test_weighted_mixture(self): + # Token A: min_delta = +2.0 (weight 2) + # Token B: min_delta = -1.0 (weight 1) + # pos / total = 2 / 3 + n_prompt = 1 + cond_lp = [{0: 0.0}, {1: 1.0}, {2: -2.0}] # cond: A=1.0, B=-2.0 + asst_lp = [{1: -1.0}, {2: -1.0}] # asst: A=-1.0, B=-1.0 + # delta A = 1.0 - (-1.0) = 2.0 + # delta B = -2.0 - (-1.0) = -1.0 + ratio = _chr_min_weighted(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt) + assert abs(ratio - 2 / 3) < 1e-9 + + +# ── _ifd_family_metrics ───────────────────────────────────────────────────── + +class TestIfdFamilyMetrics: + def test_empty_returns_empty_dict(self): + assert _ifd_family_metrics([], [{1: -1.0}], [], [1], 0) == {} + + def test_simple_uniform(self): + # All deltas = 0.5 → mean=0.5, ifd=exp(-0.5) + n_prompt = 1 + cond_lp = [{0: 0.0}, {1: -0.5}, {2: -0.5}] + asst_lp = [{1: -1.0}, {2: -1.0}] + out = _ifd_family_metrics(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt) + assert out['n_tokens'] == 2 + assert abs(out['mean_delta'] - 
0.5) < 1e-9 + assert abs(out['ifd'] - math.exp(-0.5)) < 1e-9 + # s_ifd_50 keeps top-1 by |delta| = 0.5; s_ifd_75 keeps top-2 (rounded up). + assert abs(out['s_ifd_50'] - math.exp(-0.5)) < 1e-9 + assert abs(out['s_ifd_75'] - math.exp(-0.5)) < 1e-9 + + def test_mixed_deltas(self): + # deltas = [+2.0, -1.0]; mean = 0.5 + n_prompt = 1 + cond_lp = [{0: 0.0}, {1: 1.0}, {2: -2.0}] + asst_lp = [{1: -1.0}, {2: -1.0}] + out = _ifd_family_metrics(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt) + assert out['n_tokens'] == 2 + assert abs(out['mean_delta'] - 0.5) < 1e-9 + # s_ifd_50 keeps top-1 by |delta| = 2.0 → exp(-2.0) + assert abs(out['s_ifd_50'] - math.exp(-2.0)) < 1e-9 + + +# ── _mean_logprob_delta ───────────────────────────────────────────────────── + +class TestMeanLogprobDelta: + def test_empty(self): + assert _mean_logprob_delta([], [{1: -1.0}], [], [1], 0) is None + + def test_uniform_delta(self): + n_prompt = 1 + cond_lp = [{0: 0.0}, {1: -0.5}, {2: -0.5}] + asst_lp = [{1: -1.0}, {2: -1.0}] + out = _mean_logprob_delta(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt) + assert abs(out - 0.5) < 1e-9 + + def test_mixed_average(self): + # deltas = [+2.0, -1.0] → mean 0.5 + n_prompt = 1 + cond_lp = [{0: 0.0}, {1: 1.0}, {2: -2.0}] + asst_lp = [{1: -1.0}, {2: -1.0}] + out = _mean_logprob_delta(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt) + assert abs(out - 0.5) < 1e-9 + + def test_skips_none_logprobs(self): + # When asst lp returns None, that position is skipped silently. 
+ n_prompt = 1 + cond_lp = [{0: 0.0}, {1: -0.5}, {2: -0.5}] + asst_lp = [None, {2: -1.0}] + out = _mean_logprob_delta(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt) + assert abs(out - 0.5) < 1e-9 # only position 1 used + + +# ── _lp_to_jsonable ───────────────────────────────────────────────────────── + +class TestLpToJsonable: + def test_none_input(self): + assert _lp_to_jsonable(None) == [] + + def test_empty(self): + assert _lp_to_jsonable([]) == [] + + def test_none_passthrough(self): + assert _lp_to_jsonable([None, None]) == [None, None] + + def test_scalar_to_float(self): + assert _lp_to_jsonable([1, -2.0]) == [1.0, -2.0] + + def test_dict_with_logprob_object(self): + class Entry: + def __init__(self, lp, rank, decoded): + self.logprob = lp + self.rank = rank + self.decoded_token = decoded + out = _lp_to_jsonable([{7: Entry(-0.5, 1, 'hello')}]) + assert out == [{ + '7': {'logprob': -0.5, 'rank': 1, 'decoded': 'hello'} + }] + + def test_dict_with_nested_dict(self): + out = _lp_to_jsonable([{7: {'logprob': -0.5}}]) + assert out == [{'7': {'logprob': -0.5}}] + + def test_dict_with_repr_fallback(self): + # Non-dict, non-Entry value falls back to repr string. + out = _lp_to_jsonable([{7: 'plain'}]) + assert out == [{'7': repr('plain')}] + + def test_non_dict_non_scalar_repr(self): + # An object that isn't dict/scalar gets repr-ed. 
+ out = _lp_to_jsonable([(1, 2)]) + assert out == [repr((1, 2))] + + +# ── _pad_batch ────────────────────────────────────────────────────────────── + +class TestPadBatch: + def test_empty_batch(self): + padded, n = _pad_batch([], floor=4) + assert padded == [] + assert n == 0 + + def test_already_at_floor(self): + batch = [[1], [2], [3], [4]] + padded, n = _pad_batch(batch, floor=4) + assert padded == batch + assert n == 4 + + def test_above_floor(self): + batch = [[1], [2], [3], [4], [5]] + padded, n = _pad_batch(batch, floor=3) + assert padded == batch # unchanged + assert n == 5 + + def test_below_floor_pads_with_last(self): + batch = [[1], [2]] + padded, n = _pad_batch(batch, floor=4) + assert padded == [[1], [2], [2], [2]] + assert n == 2 # original size + + def test_returns_new_list(self): + batch = [[1], [2]] + padded, _ = _pad_batch(batch, floor=4) + # Mutating padded should not affect original. + padded.append([99]) + assert batch == [[1], [2]] + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/preprocessor/test_refuse_filter.py b/tests/preprocessor/test_refuse_filter.py new file mode 100644 index 00000000..bab7d25a --- /dev/null +++ b/tests/preprocessor/test_refuse_filter.py @@ -0,0 +1,246 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for RefuseFilter. + +RefuseFilter drops rows whose first assistant reply contains a self-referential +refusal signal across en/zh/ja/ko. Key invariants: + + * Detection is anchored to ``I/we`` + inability + task verb (en) and 我 + + 无法/不能 + 帮/回答/... (zh). Generic uses of ``cannot`` / ``无法`` that are + NOT self-referential must NOT trip the filter. + * Only the FIRST assistant turn is scanned, within ``check_window`` chars. + * ``...`` blocks are stripped before scanning. + * Empty / missing assistant content is kept (filter is conservative). 
+""" +import pytest + +from twinkle_agentic.preprocessor.refuse_filter import ( + RefuseFilter, + _is_refusal, +) + + +def _row(messages): + return {'messages': messages} + + +def _fil(rows, **kw): + return RefuseFilter(**kw)(rows) + + +# ── _is_refusal: English ──────────────────────────────────────────────────── + +class TestRefusalEnglish: + def test_core_self_referential(self): + assert _is_refusal("I can't help with that request.") is True + assert _is_refusal('I cannot assist with creating malware.') is True + assert _is_refusal('I am not able to provide that information.') is True + assert _is_refusal("We won't generate that content.") is True + + def test_apology_opener(self): + assert _is_refusal("I'm sorry, but I can't help with that.") is True + assert _is_refusal('Unfortunately, I am unable to comply.') is True + assert _is_refusal('I apologize, but I cannot answer that.') is True + + def test_policy_violation(self): + assert _is_refusal('This request violates my guidelines.') is True + assert _is_refusal('Your question goes against safety policies.') is True + + def test_standalone_decline(self): + assert _is_refusal('I must decline this request.') is True + assert _is_refusal('I refuse to participate in that.') is True + assert _is_refusal("As an AI, I can't help with illegal activities.") is True + + def test_false_positive_cannot_stress(self): + # "I cannot stress enough" — has I+cannot but NOT a task verb in window + text = 'I cannot stress enough how important documentation is.' + assert _is_refusal(text) is False + + def test_false_positive_cannot_find(self): + # "I cannot find the bug" — find is not in task-verb whitelist + text = 'I cannot find the bug after looking everywhere.' 
+ assert _is_refusal(text) is False + + def test_false_positive_third_person(self): + # No I/we subject → not refusal + assert _is_refusal('The API cannot handle null inputs.') is False + assert _is_refusal('You cannot use this without auth.') is False + + def test_case_insensitive(self): + assert _is_refusal("I CANNOT HELP with that.") is True + assert _is_refusal("i can't assist you here") is True + + +# ── _is_refusal: Chinese ──────────────────────────────────────────────────── + +class TestRefusalChinese: + def test_apology_inability(self): + assert _is_refusal('非常抱歉,我无法回答这个问题。') is True + assert _is_refusal('对不起,我不能提供这类信息。') is True + assert _is_refusal('抱歉,没有办法帮助您完成这个请求。') is True + + def test_self_referential(self): + assert _is_refusal('我无法帮助你完成这项任务。') is True + assert _is_refusal('我不能提供医疗建议。') is True + assert _is_refusal('我不被允许回答此类问题。') is True + + def test_request_violation(self): + assert _is_refusal('您的请求违反了使用规则。') is True + assert _is_refusal('这个问题不合适讨论。') is True + + def test_ai_identity_refusal(self): + assert _is_refusal('作为AI,我无法回答涉及隐私的问题。') is True + assert _is_refusal('作为人工智能,我不能提供该类内容。') is True + + def test_false_positive_self_deprecating_preamble(self): + # "作为AI,我虽无法体验情感,但..." — 无法 not paired with task verb + text = '作为AI,我虽无法亲身体验情感,但可以分析其特征。' + assert _is_refusal(text) is False + + def test_false_positive_third_person(self): + # No 我 subject + task verb pattern + assert _is_refusal('该接口无法处理空输入。') is False + # "无法理解" with no 我 / task verb pattern + assert _is_refusal('系统暂时不能正常运作。') is False + + +# ── _is_refusal: Japanese ─────────────────────────────────────────────────── + +class TestRefusalJapanese: + def test_apology_inability(self): + # Pattern is `(申し訳|恐れ入り)ます(が|けれど)` — only 恐れ入りますが + # is grammatically natural. 
+ assert _is_refusal('恐れ入りますが、お答えできません。') is True + assert _is_refusal('恐れ入りますけれど、対応できません。') is True + + def test_explicit_inability(self): + assert _is_refusal('回答することはできません。') is True + assert _is_refusal('ご要望にはお応えできません。') is True + + def test_request_decline(self): + # Matches `(その|この)(リクエスト|質問|依頼).{0,20}(お断り|辞退|対応できません)` + assert _is_refusal('そのリクエストはお断りします。') is True + assert _is_refusal('このリクエストには対応できません。') is True + + +# ── _is_refusal: Korean ───────────────────────────────────────────────────── + +class TestRefusalKorean: + def test_apology_inability(self): + assert _is_refusal('죄송하지만 답변을 드릴 수 없습니다.') is True + assert _is_refusal('유감스럽게도 도와드릴 수 없습니다.') is True + + def test_action_difficulty(self): + assert _is_refusal('답변드리기 어렵습니다.') is True + assert _is_refusal('처리하기 불가능합니다.') is True + + +# ── check_window ──────────────────────────────────────────────────────────── + +class TestCheckWindow: + def test_window_excludes_late_refusal(self): + # Refusal at position 700 — beyond default 600-char window + text = 'a' * 700 + " I can't help you complete that task." + assert _is_refusal(text, check_window=600) is False + + def test_custom_window_includes_late_refusal(self): + text = 'a' * 700 + " I can't help you complete that task." 
+ assert _is_refusal(text, check_window=1000) is True + + def test_zero_window_finds_nothing(self): + assert _is_refusal("I can't help you complete tasks.", check_window=0) is False + + +# ── RefuseFilter pipeline ─────────────────────────────────────────────────── + +class TestRefuseFilterPipeline: + def test_drops_refusal_row(self): + rows = [_row([ + {'role': 'user', 'content': 'do bad thing'}, + {'role': 'assistant', 'content': + "I'm sorry, but I cannot help with that request."}, + ])] + assert _fil(rows) == [] + + def test_keeps_normal_reply(self): + rows = [_row([ + {'role': 'user', 'content': 'explain X'}, + {'role': 'assistant', 'content': 'X is a concept that...'}, + ])] + assert len(_fil(rows)) == 1 + + def test_only_first_assistant_scanned(self): + # Refusal in SECOND assistant turn → kept (filter only checks first). + rows = [_row([ + {'role': 'user', 'content': 'q1'}, + {'role': 'assistant', 'content': 'A clean reply.'}, + {'role': 'user', 'content': 'q2'}, + {'role': 'assistant', 'content': "I can't help with that."}, + ])] + assert len(_fil(rows)) == 1 + + def test_think_block_stripped(self): + # Refusal phrasing inside ... must NOT trigger. 
+ rows = [_row([ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': + "I cannot help with this request" + "Sure, here is the answer: 42."}, + ])] + assert len(_fil(rows)) == 1 + + def test_no_assistant_kept(self): + rows = [_row([{'role': 'user', 'content': 'hi'}])] + assert len(_fil(rows)) == 1 + + def test_empty_assistant_kept(self): + rows = [_row([ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': ''}, + ])] + assert len(_fil(rows)) == 1 + + def test_empty_input(self): + assert _fil([]) == [] + + def test_missing_messages_kept(self): + # No messages key → no assistant → kept + rows = [{'id': 'x'}] + assert len(_fil(rows)) == 1 + + def test_mixed_batch(self): + rows = [ + _row([ + {'role': 'user', 'content': 'q1'}, + {'role': 'assistant', 'content': 'a normal answer'}, + ]), + _row([ + {'role': 'user', 'content': 'q2'}, + {'role': 'assistant', 'content': + 'I refuse to help you with that task.'}, + ]), + _row([ + {'role': 'user', 'content': 'q3'}, + {'role': 'assistant', 'content': + '抱歉,我无法回答这个问题。'}, + ]), + ] + out = _fil(rows) + assert len(out) == 1 + assert out[0]['messages'][0]['content'] == 'q1' + + def test_custom_check_window(self): + # Default 600 would miss a late refusal; tighten via pipeline kw. + long_prefix = 'a' * 700 + rows = [_row([ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': + long_prefix + " I can't help you complete that."}, + ])] + # default window → kept + assert len(_fil(rows)) == 1 + # widen → dropped + assert _fil(rows, check_window=1000) == [] + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/preprocessor/test_token_soup.py b/tests/preprocessor/test_token_soup.py new file mode 100644 index 00000000..c1b35beb --- /dev/null +++ b/tests/preprocessor/test_token_soup.py @@ -0,0 +1,253 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for TokenSoupFilter. 
+ +Covers each garbled-output signal in ``_is_token_soup`` plus the +script-chaos analyzer and the row-filter pipeline. +""" +import pytest + +from twinkle_agentic.preprocessor.token_soup import ( + TokenSoupFilter, + _is_token_soup, + _script_chaos, + _script_of, +) + + +def _row(content): + return {'messages': [ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': content}, + ]} + + +# ── Per-signal detector tests ──────────────────────────────────────────────── + +class TestReplacementChar: + def test_above_threshold(self): + text = '\ufffd' * 5 + 'short' # 5/10 = 50% > 2% + assert _is_token_soup(text) is True + + def test_below_threshold(self): + text = '\ufffd' + 'hello world this is text. ' * 30 # 1/~780 ≈ 0.1% < 2% + # No other signal should fire + assert _is_token_soup(text) is False + + def test_no_replacement_char(self): + assert _is_token_soup('hello world') is False + + +class TestControlChar: + def test_above_threshold(self): + text = '\x01\x02\x03\x04\x05' + 'a' * 100 # 5/105 ≈ 4.8% > 1% + assert _is_token_soup(text) is True + + def test_keeps_legitimate_whitespace(self): + text = 'line1\nline2\tindented\rcr' + assert _is_token_soup(text) is False + + def test_del_char_triggers(self): + text = '\x7f' * 5 + 'a' * 100 + assert _is_token_soup(text) is True + + +class TestPrivateUseArea: + def test_bmp_pua_above_threshold(self): + text = '\ue000\ue001\ue002\ue003\ue004' + 'a' * 100 # 5/105 ≈ 4.8% > 3% + assert _is_token_soup(text) is True + + def test_below_threshold(self): + text = '\ue000' + 'hello world this is text. 
' * 30 # ~0.1% < 3% + assert _is_token_soup(text) is False + + +class TestSpecialTokens: + def test_repeated_pipe_token(self): + text = '<|endoftext|>' * 25 + assert _is_token_soup(text, special_token_count=20) is True + + def test_repeated_bert_uppercase(self): + text = '[PAD]' * 25 + assert _is_token_soup(text, special_token_count=20) is True + + def test_lowercase_brackets_not_matched(self): + # ``dp[mask]`` is normal code; lowercase variant must NOT match. + text = 'arr[mask] = arr[mask] | 1; ' * 30 + assert _is_token_soup(text, special_token_count=20) is False + + def test_byte_token_form(self): + text = '<0x0A>' * 25 + assert _is_token_soup(text, special_token_count=20) is True + + def test_below_count(self): + text = '<|endoftext|>' * 5 + assert _is_token_soup(text, special_token_count=20) is False + + def test_unk_pad_html_tags(self): + text = '' * 12 + '' * 13 + assert _is_token_soup(text, special_token_count=20) is True + + +class TestSingleCharRepeat: + def test_letter_repeat_triggers(self): + text = 'aaaaaaaaaaaaaaaaaaaaaaaaaa hello world' # 26 a's > 19 + assert _is_token_soup(text) is True + + def test_dash_excluded(self): + text = '-' * 50 + ' separator' + assert _is_token_soup(text) is False + + def test_equals_excluded(self): + text = '=' * 50 + assert _is_token_soup(text) is False + + def test_digit_excluded(self): + text = '9' * 50 + assert _is_token_soup(text) is False + + def test_box_drawing_excluded(self): + text = '\u2500' * 50 # ─ box-drawing horizontal + assert _is_token_soup(text) is False + + def test_below_threshold(self): + text = 'a' * 19 # 19 < 20 (regex requires \1{19,} → 1 + 19 = 20) + assert _is_token_soup(text) is False + + def test_at_threshold(self): + text = 'a' * 20 # 20 a's: 1 + 19 repeats → matches + assert _is_token_soup(text) is True + + +# ── Script-chaos analyzer ──────────────────────────────────────────────────── + +class TestScriptOf: + def test_latin(self): + assert _script_of(ord('A')) == 'latin' + assert 
_script_of(ord('z')) == 'latin' + + def test_cjk(self): + assert _script_of(ord('中')) == 'cjk' + + def test_hiragana_katakana(self): + assert _script_of(0x3042) == 'hiragana' # あ + assert _script_of(0x30A2) == 'katakana' # ア + + def test_cyrillic(self): + assert _script_of(0x0410) == 'cyrillic' + + def test_hangul(self): + assert _script_of(0xAC00) == 'hangul' + + def test_private(self): + assert _script_of(0xE000) == 'private' + + def test_other(self): + assert _script_of(0x2000) == 'other' # general punctuation + + +class TestScriptChaos: + def test_pure_latin_zero_chaos(self): + assert _script_chaos('hello world this is a long english sentence') == 0.0 + + def test_pure_cjk_zero_chaos(self): + assert _script_chaos('这是一段足够长的中文文本用于测试脚本切换检测' * 2) == 0.0 + + def test_short_text_returns_zero(self): + # Below ``min_chars`` → returns 0.0 regardless of mix. + assert _script_chaos('aあ', min_chars=40) == 0.0 + + def test_high_chaos_alternation(self): + # Pure letter/number alternation between scripts → chaos ≈ 1.0. + text = ('aあbいcうdえeお' * 5) # 50 alternating letters + score = _script_chaos(text, min_chars=40) + assert score > 0.9 + + def test_filter_with_chaos(self): + text = ('aあbいcうdえeお' * 5) # high chaos + assert _is_token_soup(text, script_chaos_min_chars=40, + script_chaos_threshold=0.55) is True + + def test_skips_punct_whitespace(self): + # Categories not in (L, N) are dropped before script-of pairing. + text = 'hello, world! how are you?' + assert _script_chaos(text) == 0.0 + + +# ── max_chars head-sampling ────────────────────────────────────────────────── + +class TestMaxChars: + def test_only_head_examined(self): + # Soup at the tail; head is clean. With max_chars=100 we should not see it. + head = 'hello world this is plain text. 
' * 4 # ~128 chars, no repeat-20 + text = head[:100] + '\ufffd' * 100 + assert _is_token_soup(text, max_chars=100, + replacement_char_ratio=0.02) is False + + def test_full_text_when_max_chars_zero(self): + head = 'hello world this is plain text. ' * 4 + text = head[:100] + '\ufffd' * 100 + assert _is_token_soup(text, max_chars=0, + replacement_char_ratio=0.02) is True + + +# ── Empty / trivial inputs ─────────────────────────────────────────────────── + +class TestTrivial: + def test_empty_text(self): + assert _is_token_soup('') is False + + def test_short_clean_text(self): + assert _is_token_soup('Hi there!') is False + + +# ── Pipeline ───────────────────────────────────────────────────────────────── + +class TestTokenSoupFilterPipeline: + def test_drops_soupy_assistant(self): + f = TokenSoupFilter() + rows = [_row('clean response'), _row('aaaaaaaaaaaaaaaaaaaaaaaaaaaaa')] + out = f(rows) + assert len(out) == 1 + assert out[0]['messages'][1]['content'] == 'clean response' + + def test_keeps_row_without_assistant(self): + f = TokenSoupFilter() + rows = [{'messages': [{'role': 'user', 'content': 'q'}]}] + out = f(rows) + assert len(out) == 1 + + def test_any_assistant_soupy_drops_row(self): + f = TokenSoupFilter() + rows = [{'messages': [ + {'role': 'user', 'content': 'q'}, + {'role': 'assistant', 'content': 'fine'}, + {'role': 'user', 'content': 'q2'}, + {'role': 'assistant', 'content': '\ufffd' * 10 + 'a' * 5}, + ]}] + out = f(rows) + assert out == [] + + def test_strips_whitespace_before_check(self): + # Leading/trailing whitespace shouldn't bypass detection. + f = TokenSoupFilter() + rows = [_row(' ' + 'a' * 30 + ' ')] + assert f(rows) == [] + + def test_threshold_overrides_propagated(self): + # With a stricter ratio, even small amounts of \ufffd trip it. 
+ f = TokenSoupFilter(replacement_char_ratio=0.0) + rows = [_row('hello\ufffdworld')] + assert f(rows) == [] + + def test_empty_rows(self): + assert TokenSoupFilter()([]) == [] + + def test_messages_missing(self): + f = TokenSoupFilter() + rows = [{'id': 'no-msgs'}] + out = f(rows) + assert len(out) == 1 + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/template/test_tool_parsers.py b/tests/template/test_tool_parsers.py new file mode 100644 index 00000000..48402927 --- /dev/null +++ b/tests/template/test_tool_parsers.py @@ -0,0 +1,449 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Pure-Python tests for tool-call parsers (no model download). + +Covers Hermes/Qwen, ReAct, Cline parsing, cleaning, and — most importantly +— streaming correctness via the generic state machine in +:class:`twinkle.template.base.Template`. +""" +import json + +import pytest + +from twinkle.template.base import Template +from twinkle.template.tools import ( + ClineParser, + HermesQwenParser, + ReActParser, + ToolCallRegistry, + trailing_prefix_of, +) + + +class _StubTemplate: + """Minimal Template-shaped object exposing only stream-related members. + + Avoids loading a real tokenizer/processor (which would need network). 
+ """ + + parse_tool_call_stream = Template.parse_tool_call_stream + _stream_marker_blocks = Template._stream_marker_blocks + _format_tc_delta = staticmethod(Template._format_tc_delta) + + def __init__(self, model_id: str): + self.model_id = model_id + + +def _stream(model_id, chunks_with_finished): + t = _StubTemplate(model_id) + state = {} + events = [] + for chunk, fin in chunks_with_finished: + events.extend(t.parse_tool_call_stream(state, chunk, finished=fin)) + return events, state + + +# --------------------------------------------------------------------------- +# HermesQwenParser +# --------------------------------------------------------------------------- + + +class TestHermesQwenParser: + + def setup_method(self): + self.p = HermesQwenParser() + + def test_detect(self): + assert self.p.detect('hi {"name":"f","arguments":{}}') + assert not self.p.detect('plain text') + assert not self.p.detect('') + + def test_matches_model(self): + assert self.p.matches_model('qwen2.5-7b') + assert self.p.matches_model('qwen3-32b') + assert not self.p.matches_model('llama-3.1-8b') + + def test_parse_json_variant(self): + text = '{"name": "get_weather", "arguments": {"city": "Paris"}}' + out = self.p.parse(text) + assert out == [{ + 'type': 'function', + 'function': {'name': 'get_weather', 'arguments': {'city': 'Paris'}}, + }] + + def test_parse_function_xml_variant(self): + text = ('' + '12' + '') + out = self.p.parse(text) + assert len(out) == 1 + assert out[0]['function']['name'] == 'add' + # JSON-decoding of param values: numbers come back as int. 
+ assert out[0]['function']['arguments'] == {'a': 1, 'b': 2} + + def test_parse_multiple_blocks(self): + text = ('{"name":"f1","arguments":{}}' + 'between ' + '{"name":"f2","arguments":{"k":"v"}}') + out = self.p.parse(text) + assert [c['function']['name'] for c in out] == ['f1', 'f2'] + assert out[1]['function']['arguments'] == {'k': 'v'} + + def test_parse_unclosed_block_at_eof(self): + # ``\Z`` fallback in _BLOCK_RE handles truncated trailing block. + text = '{"name": "f", "arguments": {}}' + out = self.p.parse(text) + assert out and out[0]['function']['name'] == 'f' + + def test_parse_empty_returns_empty_list(self): + assert self.p.parse('') == [] + assert self.p.parse('plain text without markers') == [] + + def test_clean_strips_blocks(self): + text = 'hello {"name":"f","arguments":{}} world' + assert self.p.clean(text) == 'hello world' + + def test_clean_unclosed_at_eof(self): + text = 'hello {"name":"f"' + assert self.p.clean(text) == 'hello' + + def test_clean_empty(self): + assert self.p.clean('') == '' + + def test_markers_declared(self): + assert self.p.open_marker == '' + assert self.p.close_marker == '' + + +class TestHermesQwenStreaming: + """Generic open/close marker buffer state machine.""" + + def test_plain_text_passthrough(self): + events, _ = _stream('qwen2.5-7b', [('Hello world!', True)]) + assert events == [{'content': 'Hello world!'}] + + def test_holds_back_partial_open_marker(self): + events, state = _stream('qwen2.5-7b', [ + ('Hello! ', False), + ('{"name":"f","arguments":{}}', False), + ('done.', False), + ('', True), + ]) + types = [next(iter(e)) for e in events] + assert types == ['content', 'tool_calls', 'content'] + tc = events[1]['tool_calls'][0] + assert tc['function']['name'] == 'f' + # OpenAI streaming spec: arguments serialised as JSON string. 
+ assert tc['function']['arguments'] == '{}' + assert tc['index'] == 0 + assert tc['id'].startswith('call_') + assert tc['type'] == 'function' + + def test_stream_chunked_inside_block(self): + # Split the block at every char to torture-test the partial-marker + # hold-back logic. + full = '{"name":"f","arguments":{"x":1}}' + chunks = [(full[i:i + 1], False) for i in range(len(full))] + chunks.append(('', True)) + events, state = _stream('qwen2.5-7b', chunks) + tcs = [e['tool_calls'][0] for e in events if 'tool_calls' in e] + assert len(tcs) == 1 + assert tcs[0]['function']['name'] == 'f' + assert json.loads(tcs[0]['function']['arguments']) == {'x': 1} + assert state['pending'] == '' + # No content events should leak the markup. + for e in events: + if 'content' in e: + assert '' not in e['content'] + assert '' not in e['content'] + + def test_multiple_blocks_increasing_indices(self): + events, _ = _stream('qwen2.5-7b', [ + ('{"name":"a","arguments":{}}' + '{"name":"b","arguments":{}}', True), + ]) + tcs = [e['tool_calls'][0] for e in events if 'tool_calls' in e] + assert [t['function']['name'] for t in tcs] == ['a', 'b'] + assert [t['index'] for t in tcs] == [0, 1] + + def test_unclosed_block_flushed_on_finish(self): + events, state = _stream('qwen2.5-7b', [ + ('{"name":"f","arguments":{}}', True), + ]) + assert state['pending'] == '' + tcs = [e['tool_calls'][0] for e in events if 'tool_calls' in e] + assert tcs and tcs[0]['function']['name'] == 'f' + + def test_arguments_serialised_as_json_string(self): + events, _ = _stream('qwen2.5-7b', [ + ('{"name":"f","arguments":{"k":"v","n":3}}', True), + ]) + tc = next(e['tool_calls'][0] for e in events if 'tool_calls' in e) + assert isinstance(tc['function']['arguments'], str) + assert json.loads(tc['function']['arguments']) == {'k': 'v', 'n': 3} + + def test_content_events_lossless_for_non_block_text(self): + # All non-tool-call text must pass through verbatim, regardless of + # chunk boundaries. 
+ original_content_outside = 'aXY' + full = ('a' + '{"name":"f","arguments":{}}' + 'XY') + chunks = [(full[i:i + 3], False) for i in range(0, len(full), 3)] + chunks.append(('', True)) + events, _ = _stream('qwen2.5-7b', chunks) + rebuilt = ''.join(e['content'] for e in events if 'content' in e) + assert rebuilt == original_content_outside + + def test_no_emission_until_chunk_arrives(self): + # Streaming with empty chunk and not-finished should be a no-op. + events, _ = _stream('qwen2.5-7b', [('', False)]) + assert events == [] + + +# --------------------------------------------------------------------------- +# ReActParser +# --------------------------------------------------------------------------- + + +class TestReActParser: + + def setup_method(self): + self.p = ReActParser() + + def test_detect_action_line(self): + assert self.p.detect('Thought: I need search.\nAction: search[python]') + assert not self.p.detect('plain text without action keyword') + assert not self.p.detect('') + + def test_no_block_marker(self): + # Prose format — streaming has no marker to lock onto. 
+ assert self.p.open_marker is None + assert self.p.close_marker is None + + def test_does_not_match_qwen_model(self): + assert not self.p.matches_model('qwen2.5') + assert not self.p.matches_model('llama-3') + + def test_parse_single_action(self): + text = 'Thought: search the web.\nAction: search[hello world]' + out = self.p.parse(text) + assert out == [{ + 'type': 'function', + 'function': {'name': 'search', 'arguments': {'input': 'hello world'}}, + }] + + def test_parse_multiple_actions(self): + text = ('Thought: a\nAction: tool_a[x]\n' + 'Observation: ok\n' + 'Thought: b\nAction: tool_b[y z]') + out = self.p.parse(text) + assert [c['function']['name'] for c in out] == ['tool_a', 'tool_b'] + assert out[1]['function']['arguments'] == {'input': 'y z'} + + def test_clean_removes_action_lines(self): + text = 'Thought: hi\nAction: search[x]\nDone' + cleaned = self.p.clean(text) + assert 'Action: search' not in cleaned + assert 'Thought: hi' in cleaned + assert 'Done' in cleaned + + def test_parse_empty(self): + assert self.p.parse('') == [] + + +class TestReActStreaming: + """ReAct has no marker → falls back to plain content passthrough. + + Detection is a final-pass concern; streaming preserves content faithfully. + """ + + def test_passthrough_when_no_marker_parser(self): + # 'react-agent' doesn't match HermesQwen ('qwen' substring) → no parser + # cached → passthrough mode. 
+ events, state = _stream('react-agent', [ + ('Thought: hi\n', False), + ('Action: foo[bar]\n', False), + ('done', False), + ('', True), + ]) + rebuilt = ''.join(e['content'] for e in events if 'content' in e) + assert rebuilt == 'Thought: hi\nAction: foo[bar]\ndone' + assert state.get('parser') is None + + def test_no_tool_calls_event_emitted(self): + events, _ = _stream('react-agent', [ + ('Action: foo[bar]', True), + ]) + assert all('tool_calls' not in e for e in events) + + +# --------------------------------------------------------------------------- +# ClineParser +# --------------------------------------------------------------------------- + + +class TestClineParser: + + def setup_method(self): + self.p = ClineParser() + + def test_detect_simple_tool(self): + assert self.p.detect('foo.py') + + def test_detect_ignores_html_like_tags(self): + # ``think`` / ``code`` are denied — even with inner content they aren't + # treated as tool calls. + assert not self.p.detect('x') + assert not self.p.detect('x') + + def test_detect_requires_inner_param(self): + # No inner ``VAL`` → not a Cline call. + assert not self.p.detect('just text') + + def test_detect_ignores_hermes_block(self): + # Hermes already owns ```` — Cline must skip it. + assert not self.p.detect('{"name":"f","arguments":{}}') + + def test_no_marker_for_streaming(self): + # Outer tag varies per call — streaming uses passthrough, not the + # marker state machine. + assert self.p.open_marker is None + assert self.p.close_marker is None + + def test_does_not_match_any_model_by_default(self): + # Cline is an app-level prompt protocol, not a model-family format. 
+ assert not self.p.matches_model('qwen2.5') + assert not self.p.matches_model('claude-3') + + def test_parse_single_arg(self): + text = 'src/foo.py' + out = self.p.parse(text) + assert out == [{ + 'type': 'function', + 'function': {'name': 'read_file', 'arguments': {'path': 'src/foo.py'}}, + }] + + def test_parse_multi_arg_with_whitespace(self): + text = ('\n' + ' ls -la\n' + ' false\n' + '') + out = self.p.parse(text) + fn = out[0]['function'] + assert fn['name'] == 'execute_command' + assert fn['arguments'] == {'command': 'ls -la', 'requires_approval': 'false'} + + def test_parse_multiple_blocks(self): + text = ('a' + ' between ' + 'btrue') + out = self.p.parse(text) + assert [c['function']['name'] for c in out] == ['read_file', 'list_files'] + assert out[1]['function']['arguments'] == {'path': 'b', 'recursive': 'true'} + + def test_parse_skips_hermes_block(self): + text = '{"name":"f","arguments":{}}' + assert self.p.parse(text) == [] + + def test_clean_strips_tool_blocks(self): + text = 'before x after' + assert self.p.clean(text) == 'before after' + + def test_clean_preserves_non_tool_xml(self): + text = 'reasoning x tail' + cleaned = self.p.clean(text) + assert 'reasoning' in cleaned + assert '' not in cleaned + assert 'tail' in cleaned + + def test_clean_empty(self): + assert self.p.clean('') == '' + + +class TestClineStreaming: + """Cline streams as plain content (no fixed open marker).""" + + def test_content_passthrough_lossless_across_chunk_boundaries(self): + full = ('intro src/foo.py outro' + ' next x') + # Chunk every 4 chars — boundaries fall inside tags, args, etc. + chunks = [(full[i:i + 4], False) for i in range(0, len(full), 4)] + chunks.append(('', True)) + events, _ = _stream('cline-bot', chunks) + rebuilt = ''.join(e['content'] for e in events if 'content' in e) + assert rebuilt == full + # No tool_calls events because no parser was selected by model_id. 
+ assert all('tool_calls' not in e for e in events) + + +# --------------------------------------------------------------------------- +# Registry round-robin & helpers +# --------------------------------------------------------------------------- + + +class TestRegistryRoundRobin: + + def test_first_match_wins_no_nested_reparse(self): + # Hermes block must take ownership; ReAct/Cline shouldn't see it. + text = '{"name":"f","arguments":{}}' + parser = ToolCallRegistry.detect_first(text) + assert parser is not None and parser.name == 'hermes_qwen' + + def test_cline_wins_for_xml_tools(self): + text = 'x' + parser = ToolCallRegistry.detect_first(text) + assert parser is not None and parser.name == 'cline' + + def test_react_wins_for_action_keyword(self): + text = 'Thought: hi\nAction: search[x]' + parser = ToolCallRegistry.detect_first(text) + assert parser is not None and parser.name == 'react' + + def test_no_parser_for_plain_text(self): + assert ToolCallRegistry.detect_first('just some plain text') is None + assert ToolCallRegistry.detect_first('') is None + + def test_select_for_qwen_picks_hermes(self): + parser = ToolCallRegistry.select_for_model('qwen2.5-7b') + assert parser is not None and parser.name == 'hermes_qwen' + + def test_select_for_unknown_returns_none(self): + assert ToolCallRegistry.select_for_model('llama-3.1-8b') is None + assert ToolCallRegistry.select_for_model(None) is None + + +class TestTrailingPrefixOf: + """Holdback length helper used by the marker state machine.""" + + def test_no_prefix(self): + assert trailing_prefix_of('hello world', '') == 0 + + def test_partial_prefix_4_chars(self): + # buf ends with '' length 4. 
+ assert trailing_prefix_of('hello ') == 4 + + def test_partial_prefix_1_char(self): + assert trailing_prefix_of('hello <', '') == 1 + + def test_full_marker_returns_zero(self): + # Full marker at end is NOT a strict prefix (search range is 1..len-1), + # so the helper returns 0 — block code path will see the marker via + # ``find()`` rather than holdback. + assert trailing_prefix_of('text', '') == 0 + + def test_empty_buf(self): + assert trailing_prefix_of('', '') == 0 + + +if __name__ == '__main__': + pytest.main([__file__, '-v'])