From 05f84d93a63617a882f8c4245968357887b4e5e9 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 14 Jun 2026 22:39:32 +0800 Subject: [PATCH 01/53] fix --- .../exp/embedding/build_thinking_rag_index.py | 810 ++++++++++++++++++ cookbook/exp/embedding/dataset_index.py | 718 ++++++++++++++++ .../exp/embedding/train_embedding_full_ddp.py | 49 +- 3 files changed, 1566 insertions(+), 11 deletions(-) create mode 100644 cookbook/exp/embedding/build_thinking_rag_index.py create mode 100644 cookbook/exp/embedding/dataset_index.py diff --git a/cookbook/exp/embedding/build_thinking_rag_index.py b/cookbook/exp/embedding/build_thinking_rag_index.py new file mode 100644 index 00000000..cf4d1dd7 --- /dev/null +++ b/cookbook/exp/embedding/build_thinking_rag_index.py @@ -0,0 +1,810 @@ +"""Build a thinking-trace RAG index from condensed (query, cot) pairs. + +Pipeline (per row, batched): + 1. Load (user_query, reasoning_content) pairs from ``dataset_think.get_dataset``. + 2. Compress query with ``FIXED_QUERY_NEED`` and cot with ``FIXED_QUERY_SKILL`` + using a Twinkle ``vLLMSampler`` (TP=4 across GPUs 0-3). Reuses the prompt + suite from ``cookbook/exp/condenser/make_condenser_dataset.py``. + 3. On condenser truncation (``stop_reason='length'`` or skeleton-incomplete + output), fall back to an external OpenAI-compatible API. + 4. Encode the condensed pair via the trained embedding model — Twinkle + ``TransformersModel`` on the ``emb_model`` device group (DP=4 across GPUs + 4-7) using ``forward_only(task='embedding')``, the same code path as + training. + 5. Compute cosine similarity for each (query, thinking) pair, drop pairs with + ``sim < SIM_THRESHOLD``, and insert kept rows into LanceDB. The vector + column carries the **positive (compressed-skill)** embedding so a search + keyed by an anchor-encoded query retrieves the matching thinking trace. + 6. Each row stores the **raw thinking** alongside its embedding, so a hit + in the index can directly surface the original CoT. 
+ +Eval mode (``--mode eval`` or ``--mode both``): + * Self-recall test — encode a sample of dataset queries (whose corresponding + rows are already in the index) as anchors and report recall@1/5/10 plus + a per-source breakdown. + +Architecture (8 GPUs): + * GPU 0-3: vLLM condenser (tensor-parallel, ``DeviceGroup name='sampler'``) + * GPU 4-7: TransformersModel embedding (data-parallel, ``DeviceGroup name='emb_model'``) + * Single ``twinkle.initialize(mode='ray', ...)`` call wires both groups. + +Launch examples: + python build_thinking_rag_index.py --mode build --total 500000 + python build_thinking_rag_index.py --mode eval --eval-size 1000 + python build_thinking_rag_index.py --mode both --total 200000 --eval-size 500 +""" +import argparse +import json +import os +import re +import sys +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from tqdm import tqdm + +# --------------------------------------------------------------------------- +# Reuse condenser prompts (single source of truth) +# --------------------------------------------------------------------------- +_HERE = Path(__file__).resolve().parent +sys.path.insert(0, str(_HERE.parent / 'condenser')) +sys.path.insert(0, str(_HERE)) +from make_condenser_dataset import ( # noqa: E402 (after sys.path tweak) + COMPRESS_SYSTEM, + COMPRESS_USER, + FIXED_QUERY_NEED, + FIXED_QUERY_SKILL, +) +# Default dataset loader is the index-time corpus (broader retrieval profile); +# pass --dataset-module dataset_think to fall back to the training mix. 
+from dataset_index import get_dataset as _default_get_dataset # noqa: E402 + +_GET_DATASET = _default_get_dataset + +import twinkle # noqa: E402 +from twinkle import DeviceGroup, DeviceMesh, get_logger # noqa: E402 +from twinkle.data_format import SamplingParams as TwinkleSamplingParams # noqa: E402 +from twinkle.loss import InfonceLoss # noqa: E402 +from twinkle.model import TransformersModel # noqa: E402 +from twinkle.processor import InputProcessor # noqa: E402 +from twinkle.sampler import vLLMSampler # noqa: E402 +from twinkle.template import Qwen3_5Template # noqa: E402 +from twinkle.utils.parallel import PosixFileLock # noqa: E402 +from twinkle_agentic.protocol.openai import OpenAI as OpenAIClient # noqa: E402 + +logger = get_logger() + + +# =========================================================================== +# Config (most fields overridable via CLI / env) +# =========================================================================== + +EMBED_MODEL_ID = os.environ.get( + 'EMBED_MODEL_ID', + 'ms://twinkle-kit/Qwen3.5-4B-QA-emb', +) +CONDENSE_MODEL_ID = os.environ.get('CONDENSE_MODEL_ID', 'ms://twinkle-kit/Qwen3.5-4B-CM-v2') + +# Twinkle device topology: TP=4 sampler on 0-3, DP=4 embedding on 4-7. +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +EMB_GPUS = int(os.environ.get('EMB_GPUS', 4)) +NUM_GPUS = SAMPLER_GPUS + EMB_GPUS + +# vLLM engine sizing. +CONDENSE_GPU_MEM = float(os.environ.get('CONDENSE_GPU_MEM', 0.85)) +CONDENSE_MAX_MODEL_LEN = int(os.environ.get('CONDENSE_MAX_MODEL_LEN', 32768)) +CONDENSE_MAX_TOKENS = int(os.environ.get('CONDENSE_MAX_TOKENS', 8192)) +COMPRESS_TEMPERATURE = float(os.environ.get('COMPRESS_TEMPERATURE', 0.2)) +COMPRESS_TOP_P = float(os.environ.get('COMPRESS_TOP_P', 0.5)) + +# Embedding sizing. 
+EMBED_MAX_LENGTH = int(os.environ.get('EMBED_MAX_LENGTH', 8192)) + +SIM_THRESHOLD = float(os.environ.get('SIM_THRESHOLD', 0.65)) +MIN_TEXT_CHARS = int(os.environ.get('MIN_TEXT_CHARS', 256)) + +# OpenAI API fallback (used when vLLM truncates). +COMPRESS_API_KEY = os.environ.get('COMPRESS_API_KEY', '') +COMPRESS_BASE_URL = os.environ.get( + 'COMPRESS_BASE_URL', 'https://dashscope.aliyuncs.com/compatible-mode/v1') +COMPRESS_API_MODEL = os.environ.get('COMPRESS_API_MODEL', 'qwen3.7-max') + +# Source → coarse domain (for filtered eval). +DOMAIN_MAP = { + 'CodeX-2M-Thinking': 'code', + 'OpenThoughts3-1.2M': 'reasoning', + 'LIMO-v2': 'math', + 'Chinese-DeepSeek-R1-Distill-data-110k': 'reasoning_zh', + 'Opus-4.6-Reasoning-3000x-filtered': 'reasoning', + 'claude-opus-4.6-10000x': 'mixed', + 'angrygiraffe-claude-opus-4.6-4.7-reasoning-8.7k': 'mixed', +} + + +# =========================================================================== +# Small helpers +# =========================================================================== + +def _is_truncated_compression(text: str) -> bool: + """Detect structurally incomplete condenser output even when stop='stop'. + + The skeleton mandates a ``## Summary`` plus a ``## More`` bullet list whose + last line begins with ``-`` or ends with ``)`` (the ``(none)`` sentinel). + Anything short of that signals truncation; the API fallback is invoked. 
+ """ + if not text or not text.strip(): + return True + if '## More' not in text or '## Summary' not in text: + return True + after_more = text.split('## More', 1)[1].strip() + if not after_more: + return True + last_line = after_more.splitlines()[-1].strip() + if not (last_line.startswith('-') or last_line.endswith(')')): + return True + return False + + +def _strip_outer_codefence(text: str) -> str: + m = re.match(r'^```[a-zA-Z]*\n(.*?)\n```\s*$', text, re.DOTALL) + if m: + return m.group(1).strip() + return text.strip() + + +def _wrap_anchor(text: str) -> List[Dict[str, str]]: + """Anchor-side message wrapping (must match training).""" + return [ + {'role': 'user', 'content': text}, + {'role': 'assistant', 'content': 'Match the correct response here.'}, + ] + + +def _wrap_positive(text: str) -> List[Dict[str, str]]: + """Positive-side message wrapping (must match training).""" + return [ + {'role': 'user', 'content': 'Match the correct query here.'}, + {'role': 'assistant', 'content': text}, + ] + + +def _short(text: str, n: int = 96) -> str: + text = (text or '').replace('\n', ' ').strip() + return text[:n] + ('…' if len(text) > n else '') + + +def _detect_lang(text: str) -> str: + if not text: + return 'unknown' + cjk = sum(1 for ch in text[:512] if '\u4e00' <= ch <= '\u9fff') + return 'zh' if cjk >= 8 else 'en' + + +def _build_compress_messages(text: str, query: str) -> List[Dict[str, str]]: + return [ + {'role': 'system', 'content': COMPRESS_SYSTEM}, + {'role': 'user', 'content': COMPRESS_USER.format(query=query, text=text)}, + ] + + +# =========================================================================== +# Twinkle component wrappers +# =========================================================================== + +def initialize_twinkle() -> Tuple[DeviceMesh, DeviceMesh]: + """Wire two device groups (sampler / emb_model) and return their meshes.""" + device_groups = [ + DeviceGroup( + name='sampler', + ranks=list(range(SAMPLER_GPUS)), + 
device_type='GPU', + gpus_per_worker=SAMPLER_GPUS, # TP=4 → one worker spans all 4 GPUs + ), + DeviceGroup( + name='emb_model', + ranks=list(range(SAMPLER_GPUS, NUM_GPUS)), + device_type='GPU', + ), + ] + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, tp_size=SAMPLER_GPUS) + emb_mesh = DeviceMesh.from_sizes(world_size=EMB_GPUS, dp_size=EMB_GPUS) + twinkle.initialize( + mode='ray', + nproc_per_node=NUM_GPUS, + groups=device_groups, + lazy_collect=False, + ) + return sampler_mesh, emb_mesh + + +def build_sampler(sampler_mesh: DeviceMesh) -> vLLMSampler: + sampler = vLLMSampler( + model_id=CONDENSE_MODEL_ID, + engine_args={ + 'gpu_memory_utilization': CONDENSE_GPU_MEM, + 'max_model_len': CONDENSE_MAX_MODEL_LEN, + }, + device_mesh=sampler_mesh, + remote_group='sampler', + ) + sampler.set_template( + 'Qwen3_5Template', + model_id=CONDENSE_MODEL_ID, + enable_thinking=False, + max_length=CONDENSE_MAX_MODEL_LEN, + ) + return sampler + + +def build_emb_model(emb_mesh: DeviceMesh) -> Tuple[TransformersModel, Qwen3_5Template]: + model = TransformersModel( + model_id=EMBED_MODEL_ID, + device_mesh=emb_mesh, + remote_group='emb_model', + ) + model.set_processor(InputProcessor) + # InfonceLoss is required by the framework even though forward_only does + # not actually invoke it; matches the training-time configuration. + model.set_loss(InfonceLoss, temperature=0.03, use_batch=True) + # Qwen3.5-specific subclass applies orphan- chat-template patches. 
+ template = Qwen3_5Template( + model_id=EMBED_MODEL_ID, + max_length=EMBED_MAX_LENGTH, + truncation_strategy='delete', + enable_thinking=False, + ) + return model, template + + +# =========================================================================== +# Compression helpers (vLLMSampler) + API fallback +# =========================================================================== + +def _vllm_compress(sampler: vLLMSampler, texts: List[str], query_hint: str + ) -> List[Tuple[str, str]]: + """Compress ``texts`` via the sampler; return ``(decoded, stop_reason)``.""" + if not texts: + return [] + prompts = [{'messages': _build_compress_messages(t, query_hint)} for t in texts] + params = TwinkleSamplingParams( + max_tokens=CONDENSE_MAX_TOKENS, + temperature=COMPRESS_TEMPERATURE, + top_p=COMPRESS_TOP_P, + num_samples=1, + ) + responses = sampler.sample(prompts, params) + results: List[Tuple[str, str]] = [] + for resp in responses: + seq = resp.sequences[0] if resp and resp.sequences else None + if seq is None: + results.append(('', 'error')) + continue + text = seq.decoded or '' + # Strip any leaked chat-template special tokens like ``<|im_end|>``. 
+ text = re.sub(r'<\|[^|]+\|>', '', text).rstrip() + text = _strip_outer_codefence(text) + results.append((text, seq.stop_reason or 'stop')) + return results + + +def _api_compress(api: OpenAIClient, messages: List[Dict[str, str]]) -> Optional[str]: + sp = TwinkleSamplingParams(temperature=COMPRESS_TEMPERATURE, max_tokens=CONDENSE_MAX_TOKENS) + try: + reply = api({'messages': messages}, sp, extra_body={'enable_thinking': False}) + except Exception as exc: # noqa: BLE001 — broad catch is intentional + sys.stderr.write(f'[api_fallback] error: {exc}\n') + return None + content = (reply.get('content') or '').strip() + if not content: + return None + return _strip_outer_codefence(content) + + +def _resolve_compressed(sampler: vLLMSampler, api: Optional[OpenAIClient], + texts: List[str], query_hint: str) -> List[Optional[str]]: + """Run vLLM batch; replace truncations / skeleton-incomplete with API output.""" + pairs = _vllm_compress(sampler, texts, query_hint) + results: List[Optional[str]] = [] + for (text, stop), src_text in zip(pairs, texts): + if stop != 'length' and not _is_truncated_compression(text): + results.append(text) + continue + if api is None: + results.append(None) + continue + api_text = _api_compress(api, _build_compress_messages(src_text, query_hint)) + if api_text is None or _is_truncated_compression(api_text): + results.append(None) + else: + results.append(api_text) + return results + + +# =========================================================================== +# Embedding helpers (TransformersModel.forward_only(task='embedding')) +# =========================================================================== + +def _build_features(template: Qwen3_5Template, texts: List[str], role: str + ) -> List[Dict[str, Any]]: + """Wrap each text into the role-specific anchor / positive feature dict.""" + features: List[Dict[str, Any]] = [] + for text in texts: + if not text or not text.strip(): + # Pad with a single space so positional alignment holds 
against + # the input list — the caller filters out empty-text rows upstream. + text = ' ' + if role == 'anchor': + feat = template.encode({'messages': _wrap_anchor(text)}) + feat['labels'] = [1] + else: + feat = template.encode({'messages': _wrap_positive(text)}) + feat['labels'] = [0] + features.append(feat) + return features + + +def get_embeddings(model: TransformersModel, template: Qwen3_5Template, + texts: List[str], role: str) -> np.ndarray: + """Return ``[N, H]`` float32 L2-normalised embeddings for ``texts``. + + Inputs are padded up to a multiple of ``EMB_GPUS`` and sliced back to the + original ``N``: the dispatch layer (``_dispatch_args``) starves any rank + whose chunk lands beyond ``len(texts)``, so a single forward of fewer than + ``EMB_GPUS`` items (e.g. the probe) would otherwise raise + ``Batch too small for {EMB_GPUS} workers``. + """ + if not texts: + return np.zeros((0,), dtype=np.float32) + n = len(texts) + pad_n = (-n) % EMB_GPUS + padded = list(texts) + [' '] * pad_n if pad_n else list(texts) + features = _build_features(template, padded, role) + out = model.forward_only(inputs=features, task='embedding', return_logits=True) + emb = out['embeddings'] + if isinstance(emb, torch.Tensor): + emb = emb.detach().to(torch.float32).cpu().numpy() + emb = np.asarray(emb, dtype=np.float32) + return emb[:n] if pad_n else emb + + +def _probe_hidden_size(model: TransformersModel, template: Qwen3_5Template) -> int: + """One-shot warmup forward to read out the embedding dimension.""" + emb = get_embeddings(model, template, ['probe'], role='anchor') + if emb.ndim != 2 or emb.shape[0] == 0: + raise RuntimeError(f'unexpected embedding shape from probe: {emb.shape}') + return int(emb.shape[1]) + + +# =========================================================================== +# LanceDB I/O +# =========================================================================== + +def _make_arrow_schema(hidden_size: int): + import pyarrow as pa + return pa.schema([ + 
pa.field('id', pa.string()), + pa.field('vector', pa.list_(pa.float32(), hidden_size)), + pa.field('thinking_raw', pa.string()), + pa.field('query_raw', pa.string()), + pa.field('cot_compressed', pa.string()), + pa.field('query_compressed', pa.string()), + pa.field('source', pa.string()), + pa.field('domain', pa.string()), + pa.field('language', pa.string()), + pa.field('sim', pa.float32()), + ]) + + +def _open_or_create_table(db_path: str, table_name: str, hidden_size: int, + mode: str): + """Open an existing table for append/eval, or create a fresh one.""" + import lancedb + db = lancedb.connect(db_path) + schema = _make_arrow_schema(hidden_size) + if table_name in db.table_names(): + if mode == 'overwrite': + db.drop_table(table_name) + tbl = db.create_table(table_name, schema=schema, mode='overwrite') + else: + tbl = db.open_table(table_name) + else: + tbl = db.create_table(table_name, schema=schema, mode='create') + return db, tbl + + +def _existing_ids(table) -> set: + try: + col = table.to_pandas(columns=['id']) + return set(col['id'].astype(str).tolist()) + except Exception: # noqa: BLE001 + return set() + + +# =========================================================================== +# Build pipeline +# =========================================================================== + +def _stream_corpus(total: Optional[int], load_from_cache_file: bool, + max_rows: int = 0) -> Iterator[Dict[str, Any]]: + ds = _GET_DATASET(total=total, load_from_cache_file=load_from_cache_file) + n_full = len(ds) + cap = max_rows if (max_rows and max_rows < n_full) else n_full + sys.stderr.write(f'[corpus] get_dataset: {n_full} rows' + + (f' → yielding first {cap}\n' if cap < n_full else '\n')) + for i, row in enumerate(ds): + if i >= cap: + break + yield row + + +def _extract_query_cot(row: Dict[str, Any]) -> Tuple[str, str]: + user_query, cot = '', '' + for m in row.get('messages') or []: + if not isinstance(m, dict): + continue + role = m.get('role') or '' + if role == 
'user' and not user_query: + user_query = (m.get('content') or '').strip() + elif role == 'assistant': + cot = (m.get('reasoning_content') or '').strip() + break + return user_query, cot + + +def _log_miss(misses_path: str, lock: PosixFileLock, record: Dict[str, Any]) -> None: + line = json.dumps(record, ensure_ascii=False, default=str) + '\n' + with lock: + with open(misses_path, 'a', encoding='utf-8') as fh: + fh.write(line) + + +def build_index(args: argparse.Namespace, + sampler: vLLMSampler, + emb_model: TransformersModel, + emb_template: Qwen3_5Template, + api: Optional[OpenAIClient]) -> None: + # ---- Probe embedding dimension ----------------------------------------- + sys.stderr.write('[build] probing embedding hidden size...\n') + hidden_size = _probe_hidden_size(emb_model, emb_template) + sys.stderr.write(f'[build] hidden_size={hidden_size}\n') + + # ---- LanceDB ------------------------------------------------------------ + db, tbl = _open_or_create_table( + args.db_path, args.table, hidden_size, + mode='overwrite' if args.overwrite else 'append', + ) + indexed = _existing_ids(tbl) if not args.overwrite else set() + sys.stderr.write(f'[build] table "{args.table}" — {len(indexed)} existing rows.\n') + + misses_path = args.misses_log or (str(Path(args.db_path) / f'{args.table}.misses.jsonl')) + Path(misses_path).parent.mkdir(parents=True, exist_ok=True) + misses_lock = PosixFileLock(misses_path + '.lock') + + # ---- Streaming loop ----------------------------------------------------- + n_seen = n_kept = n_dropped_short = n_dropped_compress = n_dropped_sim = 0 + n_dropped_dup = 0 + pbar = tqdm(desc='index', unit='row', dynamic_ncols=True) + + batch: List[Dict[str, Any]] = [] + + def _flush(rows: List[Dict[str, Any]]) -> None: + nonlocal n_kept, n_dropped_compress, n_dropped_sim + if not rows: + return + # Phase 1 — compress query (FIXED_QUERY_NEED) and cot (FIXED_QUERY_SKILL). 
+ q_compressed = _resolve_compressed( + sampler, api, [r['query_raw'] for r in rows], FIXED_QUERY_NEED) + c_compressed = _resolve_compressed( + sampler, api, [r['cot_raw'] for r in rows], FIXED_QUERY_SKILL) + kept_rows: List[Dict[str, Any]] = [] + for r, q_cmp, c_cmp in zip(rows, q_compressed, c_compressed): + if not q_cmp or not c_cmp: + n_dropped_compress += 1 + _log_miss(misses_path, misses_lock, { + 'id': r['id'], 'source': r['source'], 'reason': 'compress_fail', + 'query_raw_head': _short(r['query_raw'], 200), + 'cot_raw_head': _short(r['cot_raw'], 200), + }) + continue + r['query_compressed'] = q_cmp + r['cot_compressed'] = c_cmp + kept_rows.append(r) + if not kept_rows: + return + # Phase 2 — encode anchor (compressed query) + positive (compressed cot). + anchor_emb = get_embeddings( + emb_model, emb_template, [r['query_compressed'] for r in kept_rows], role='anchor') + positive_emb = get_embeddings( + emb_model, emb_template, [r['cot_compressed'] for r in kept_rows], role='positive') + sims = (anchor_emb * positive_emb).sum(axis=1).astype(np.float32) + # Phase 3 — sim filter + LanceDB insert. 
+ to_insert: List[Dict[str, Any]] = [] + for idx, (r, sim_val) in enumerate(zip(kept_rows, sims)): + tag = 'KEEP' if sim_val >= SIM_THRESHOLD else 'DROP' + print(f'[{tag} sim={sim_val:.4f}] {r["source"][:24]} ' + f'q={_short(r["query_raw"], 60)!r} ' + f'cot={_short(r["cot_raw"], 60)!r}', flush=True) + if sim_val < SIM_THRESHOLD: + n_dropped_sim += 1 + _log_miss(misses_path, misses_lock, { + 'id': r['id'], 'source': r['source'], 'reason': 'sim_low', + 'sim': float(sim_val), + 'query_raw': r['query_raw'], + 'cot_raw': r['cot_raw'], + 'query_compressed': r['query_compressed'], + 'cot_compressed': r['cot_compressed'], + }) + continue + to_insert.append({ + 'id': r['id'], + 'vector': positive_emb[idx].tolist(), + 'thinking_raw': r['cot_raw'], + 'query_raw': r['query_raw'], + 'cot_compressed': r['cot_compressed'], + 'query_compressed': r['query_compressed'], + 'source': r['source'], + 'domain': DOMAIN_MAP.get(r['source'], 'mixed'), + 'language': _detect_lang(r['cot_raw']), + 'sim': float(sim_val), + }) + if to_insert: + tbl.add(to_insert) + n_kept += len(to_insert) + indexed.update(r['id'] for r in to_insert) + + try: + for row in _stream_corpus(total=args.total, load_from_cache_file=not args.no_cache, + max_rows=args.max_rows): + n_seen += 1 + if args.limit and n_kept >= args.limit: + break + rid = row.get('id') or '' + if not rid: + continue + if rid in indexed: + n_dropped_dup += 1 + continue + user_query, cot = _extract_query_cot(row) + if not user_query or len(cot) < MIN_TEXT_CHARS: + n_dropped_short += 1 + continue + batch.append({ + 'id': rid, + 'source': row.get('source') or 'unknown', + 'query_raw': user_query, + 'cot_raw': cot, + }) + if len(batch) >= args.batch_size: + _flush(batch) + batch.clear() + pbar.set_postfix(kept=n_kept, sim_drop=n_dropped_sim, + cmp_drop=n_dropped_compress, refresh=False) + pbar.update(1) + if batch: + _flush(batch) + batch.clear() + finally: + pbar.close() + + sys.stderr.write( + f'[build] seen={n_seen} kept={n_kept} 
sim_drop={n_dropped_sim} ' + f'cmp_drop={n_dropped_compress} short_drop={n_dropped_short} ' + f'dup_skip={n_dropped_dup}\n') + + # ---- Build vector index for fast retrieval ------------------------------ + if n_kept >= 64 and not args.skip_index: + sys.stderr.write('[build] creating IVF_PQ index (metric=dot)...\n') + n_partitions = max(8, min(256, n_kept // 1000 + 1)) + try: + tbl.create_index( + metric='dot', + vector_column_name='vector', + num_partitions=n_partitions, + num_sub_vectors=16, + index_type='IVF_PQ', + replace=True, + ) + except Exception as exc: # noqa: BLE001 + sys.stderr.write(f'[build] index build failed: {exc} ' + '(table is still queryable via brute-force scan)\n') + sys.stderr.write(f'[build] done. table rows={tbl.count_rows()}\n') + + +# =========================================================================== +# Eval pipeline (self-recall on indexed rows) +# =========================================================================== + +def eval_recall(args: argparse.Namespace, + sampler: vLLMSampler, + emb_model: TransformersModel, + emb_template: Qwen3_5Template, + api: Optional[OpenAIClient]) -> None: + """Probe each gold query against the index; report recall@k. + + Self-recall semantics: only rows whose ``id`` is already present in the + index are probed. The corresponding ``cot``-keyed vector must be retrieved + by encoding the **raw user query** through the condenser → embedder + pipeline (anchor side). The match is correct iff the retrieved row's + ``id`` equals the probe row's ``id``. 
+ """ + import lancedb + db = lancedb.connect(args.db_path) + if args.table not in db.table_names(): + raise SystemExit(f'[eval] table "{args.table}" does not exist in {args.db_path}') + tbl = db.open_table(args.table) + indexed_ids = _existing_ids(tbl) + sys.stderr.write(f'[eval] table rows={tbl.count_rows()} indexed_ids={len(indexed_ids)}\n') + if not indexed_ids: + sys.stderr.write('[eval] empty index — nothing to evaluate.\n') + return + + ks = sorted({1, 5, 10, args.top_k}) + hits = {k: 0 for k in ks} + per_source_hits: Dict[str, Dict[int, int]] = {} + per_source_total: Dict[str, int] = {} + probed = 0 + + pbar = tqdm(desc='eval', unit='probe', dynamic_ncols=True) + batch_rows: List[Dict[str, Any]] = [] + + def _flush(rows: List[Dict[str, Any]]) -> None: + nonlocal probed + if not rows: + return + compressed = _resolve_compressed( + sampler, api, [r['query_raw'] for r in rows], FIXED_QUERY_NEED) + useful = [(r, c) for r, c in zip(rows, compressed) if c] + if not useful: + return + anchor_emb = get_embeddings( + emb_model, emb_template, [c for _, c in useful], role='anchor') + for (r, _), vec in zip(useful, anchor_emb): + res = ( + tbl.search(vec.astype(np.float32).tolist()) + .metric('dot') + .limit(max(ks)) + .select(['id', 'source']) + .to_list() + ) + hit_ids = [item['id'] for item in res] + try: + rank = hit_ids.index(r['id']) + except ValueError: + rank = -1 + for k in ks: + if 0 <= rank < k: + hits[k] += 1 + per_source_hits.setdefault(r['source'], {kk: 0 for kk in ks})[k] += 1 + per_source_total[r['source']] = per_source_total.get(r['source'], 0) + 1 + per_source_hits.setdefault(r['source'], {kk: 0 for kk in ks}) + probed += 1 + pbar.update(len(useful)) + + try: + for row in _stream_corpus(total=args.total, load_from_cache_file=not args.no_cache, + max_rows=args.max_rows): + if probed + len(batch_rows) >= args.eval_size: + break + rid = row.get('id') or '' + if not rid or rid not in indexed_ids: + continue + user_query, _ = _extract_query_cot(row) + if 
not user_query or len(user_query) < MIN_TEXT_CHARS: + continue + batch_rows.append({ + 'id': rid, + 'source': row.get('source') or 'unknown', + 'query_raw': user_query, + }) + if len(batch_rows) >= args.batch_size: + _flush(batch_rows) + batch_rows.clear() + if batch_rows: + _flush(batch_rows) + finally: + pbar.close() + + if probed == 0: + sys.stderr.write( + '[eval] no probed rows — index empty, queries too short, or ' + 'corpus exhausted before eval-size?\n') + return + + print('\n=== Recall @ k (self-recall, gold present in index) ===') + print(f'probed = {probed}') + for k in ks: + print(f' recall@{k:<3} = {hits[k]/probed:.4f} ({hits[k]}/{probed})') + + print('\n=== Per-source recall@10 ===') + for src in sorted(per_source_total): + tot = per_source_total[src] + h10 = per_source_hits.get(src, {}).get(10, 0) + print(f' {src:<48s} {h10/tot:.4f} ({h10}/{tot})') + + +# =========================================================================== +# CLI +# =========================================================================== + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument('--mode', choices=['build', 'eval', 'both'], default='build') + p.add_argument('--db-path', default='./output/thinking_rag/lance.db', + help='LanceDB on-disk directory (persisted across runs).') + p.add_argument('--table', default='thinking_traces', + help='LanceDB table name within --db-path.') + p.add_argument('--total', type=int, default=0, + help='Total dataset rows to scale corpus to (0 = base sizes from the loader module).') + p.add_argument('--dataset-module', default='dataset_index', + choices=['dataset_index', 'dataset_think'], + help='Which loader to use: dataset_index (RAG profile) or ' + 'dataset_think (training mix).') + p.add_argument('--limit', type=int, default=0, + help='Stop building once this many rows are kept (0 = no cap).') + p.add_argument('--max-rows', 
type=int, default=0, + help='Truncate corpus to this many rows AFTER get_dataset (0 = no cap). ' + 'Use this instead of --total to avoid invalidating the dataset cache.') + p.add_argument('--batch-size', type=int, default=64, + help='Rows per condense+encode batch.') + p.add_argument('--no-cache', action='store_true', + help='Disable load_from_cache_file in dataset_think.get_dataset.') + p.add_argument('--overwrite', action='store_true', + help='Drop the table before build and start fresh.') + p.add_argument('--skip-index', action='store_true', + help='Skip IVF_PQ index build at the end (debug).') + p.add_argument('--misses-log', default='', + help='Path for filtered-row JSONL log (defaults to /.misses.jsonl).') + + # eval-only + p.add_argument('--eval-size', type=int, default=500, + help='Number of probes for self-recall evaluation.') + p.add_argument('--top-k', type=int, default=10, + help='Largest k to report. Smaller ks (1, 5) are always reported.') + + return p.parse_args() + + +def main() -> None: + args = parse_args() + Path(args.db_path).mkdir(parents=True, exist_ok=True) + + global _GET_DATASET + if args.dataset_module == 'dataset_think': + from dataset_think import get_dataset as _swap + _GET_DATASET = _swap + sys.stderr.write(f'[main] dataset loader: {args.dataset_module}\n') + + # Build/eval both depend on the same Twinkle stack — initialize once. 
+ sampler_mesh, emb_mesh = initialize_twinkle() + sys.stderr.write(f'[main] twinkle initialized: ' + f'sampler ranks 0-{SAMPLER_GPUS - 1} (TP={SAMPLER_GPUS}), ' + f'emb_model ranks {SAMPLER_GPUS}-{NUM_GPUS - 1} (DP={EMB_GPUS}).\n') + + sys.stderr.write('[main] starting vLLM condenser sampler...\n') + sampler = build_sampler(sampler_mesh) + sys.stderr.write('[main] starting embedding TransformersModel...\n') + emb_model, emb_template = build_emb_model(emb_mesh) + + api: Optional[OpenAIClient] = None + if COMPRESS_API_KEY: + api = OpenAIClient( + model=COMPRESS_API_MODEL, + api_key=COMPRESS_API_KEY, + base_url=COMPRESS_BASE_URL, + ) + else: + sys.stderr.write( + '[main] WARNING: COMPRESS_API_KEY unset — truncated rows will be dropped.\n') + + if args.mode in ('build', 'both'): + build_index(args, sampler, emb_model, emb_template, api) + if args.mode in ('eval', 'both'): + eval_recall(args, sampler, emb_model, emb_template, api) + + +if __name__ == '__main__': + main() diff --git a/cookbook/exp/embedding/dataset_index.py b/cookbook/exp/embedding/dataset_index.py new file mode 100644 index 00000000..7d2905a5 --- /dev/null +++ b/cookbook/exp/embedding/dataset_index.py @@ -0,0 +1,718 @@ +"""RAG-index corpus loader — abstract reasoning skills + textbook-style methods. + +Distinct from training-time ``dataset_think.py``. Optimizes for **abstraction +density**, not raw coverage: every row should encode a transferable method, +theorem, or solution pattern that downstream queries can retrieve as a +"use-when-X-do-Y" recipe. + +Single-table design (``thinking_traces``); EMBED_QUERY_COT condense step in +``build_thinking_rag_index`` homogenizes thinking-style and textbook-style +content into the same retrieval form, so dual-table is unnecessary. The +``source`` field carries the original dataset name for eval-time +domain-bucket diagnostics. 
+ +Output schema matches ``dataset_think.get_dataset()``: ``{id, source, messages}`` +with ``messages[1].reasoning_content`` carrying the CoT. + +Mix (≈3.6M rows base, 10 datasets): + Math thinking 23% — OpenMathReasoning + OpenR1-Math-220k + s1K-1.1 + Code thinking 19% — OpenCodeReasoning-2 + codeforces-cots + Cross-domain R1 39% — Bespoke-Stratos + dolphin-r1 + reasoning-v1-20m + + natural_reasoning + Textbook synth 17% — cosmopedia v1 (auto_math_text, chunked by H2) + Olympiad solutions <1% — Omni-MATH + +Dropped: camel-ai/{physics,chemistry,biology} (zip-only, no parquet/jsonl) and +swift/stack-exchange-paired (dataset_infos.json/data layout mismatch); the +textbook-density gap is covered by a larger cosmopedia slice. + +Textbook processors synthesize a question from the chapter heading and place +the explanatory body into the ``cot`` field — embedding+condense reads +``query | cot`` so the textbook prose becomes a retrievable method. + +Field extraction is defensive: each processor tries multiple plausible column +names and silently drops rows that miss a usable signal. Inspect +``dropped_index.jsonl`` after the first run to verify field-name guesses. +""" +import re +from typing import Any, Dict, List, Optional + +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.preprocessor import Preprocessor + +from dataset_think import _THINK_RE, _hash_id, _register, ToMessagesProcessor + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +# Sky-T1 / Bespoke-Stratos custom markers (used in place of ). +_BOT_RE = re.compile( + r'<\|begin_of_thought\|>(.*?)<\|end_of_thought\|>', re.DOTALL) +_BOS_RE = re.compile( + r'<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>', re.DOTALL) + +# H2 heading split for cosmopedia-style markdown chunks. 
+_H2_RE = re.compile(r'^##\s+(.+?)\s*$', re.MULTILINE) + + +def _split_think(text: str) -> tuple: + """Return ``(cot, response)``; cot empty if no ``_THINK_RE`` match is found.""" + if not text: + return '', '' + m = _THINK_RE.search(text) + if not m: + return '', text.strip() + return m.group(1).strip(), text[m.end():].strip() + + +def _split_sky_t1(text: str) -> tuple: + """Return ``(cot, response)`` for Sky-T1 / Bespoke-Stratos marker format.""" + if not text: + return '', '' + bot = _BOT_RE.search(text) + bos = _BOS_RE.search(text) + cot = bot.group(1).strip() if bot else '' + sol = bos.group(1).strip() if bos else '' + return cot, sol + + +def _from_messages(messages: Any) -> tuple: + """Pull (first_user, first_assistant) from OpenAI/ShareGPT-style list.""" + if not isinstance(messages, list): + return '', '' + query, assistant = '', '' + for msg in messages: + if not isinstance(msg, dict): + continue + role = msg.get('role') or msg.get('from') or '' + content = msg.get('content') or msg.get('value') or '' + if not isinstance(content, str): + continue + if role in ('user', 'human') and not query: + query = content.strip() + elif role in ('assistant', 'gpt') and not assistant: + assistant = content.strip() + break + return query, assistant + + +def _chunk_by_h2(text: str, min_chars: int = 200, max_chars: int = 6000): + """Split markdown text on ``## `` headings; yield ``(title, body)`` pairs.""" + if not text: + return + matches = list(_H2_RE.finditer(text)) + if not matches: + head = text.strip()[:80].splitlines()[0] if text.strip() else '' + body = text.strip() + if head and min_chars <= len(body) <= max_chars: + yield head, body + return + for i, m in enumerate(matches): + title = m.group(1).strip() + start = m.end() + end = matches[i + 1].start() if i + 1 < len(matches) else len(text) + body = text[start:end].strip() + if min_chars <= len(body) <= max_chars and title: + yield title, body + + +# =========================================================================== 
+# Math thinking +# =========================================================================== + +OPEN_MATH_REASONING_REPO = 'ms://AI-ModelScope/OpenMathReasoning' + + +class OpenMathReasoningProcessor(Preprocessor): + """OpenMathReasoning → ``{id, source, query, cot, response}``. + + Schema: ``problem``, ``generated_solution`` (R1 trace with an inline think block), + ``expected_answer``. The ``cot`` *split* (not column) is the long-CoT + portion — TIR/genselect/additional_problems sit in sibling splits and + are filtered at load time, not row-level. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('problem') or row.get('question') or '').strip() + assistant = (row.get('generated_solution') or row.get('solution') + or row.get('output') or '').strip() + if not query or not assistant: + continue + cot, response = _split_think(assistant) + if not cot: + continue + if not response: + response = (row.get('expected_answer') or row.get('answer') or '').strip() + if not response: + continue + out.append({ + 'id': _hash_id('open_math_reasoning', f'{query}\n{response}'), + 'source': 'OpenMathReasoning', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +OPEN_R1_MATH_REPO = 'ms://open-r1/OpenR1-Math-220k' + + +class OpenR1MathProcessor(Preprocessor): + """OpenR1-Math-220k → ``{id, source, query, cot, response}``. + + Schema: ``problem``, ``solution``, ``answer``, ``generations`` (list of + R1 traces), ``correctness_math_verify`` (parallel bool list). Pick the + first generation whose math-verify passed, else the first non-empty generation; fall back to ``solution``. 
+ """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('problem') or row.get('question') or '').strip() + if not query: + continue + assistant = '' + gens = row.get('generations') + verifies = row.get('correctness_math_verify') + if isinstance(gens, list): + if isinstance(verifies, list) and len(verifies) == len(gens): + for g, v in zip(gens, verifies): + if v and isinstance(g, str) and g.strip(): + assistant = g.strip() + break + if not assistant: + for g in gens: + if isinstance(g, str) and g.strip(): + assistant = g.strip() + break + if not assistant: + assistant = (row.get('solution') or '').strip() + if not assistant: + continue + cot, response = _split_think(assistant) + if not cot: + continue + if not response: + response = (row.get('answer') or '').strip() + if not response: + continue + out.append({ + 'id': _hash_id('open_r1_math', f'{query}\n{response}'), + 'source': 'OpenR1-Math-220k', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +S1K_REPO = 'ms://simplescaling/s1K-1.1' + + +class S1KProcessor(Preprocessor): + """s1K-1.1 → ``{id, source, query, cot, response}``. + + Schema: ``question`` + ``deepseek_thinking_trajectory`` (or + ``thinking_trajectories`` legacy) + ``deepseek_attempt`` (final answer). + Hand-curated peak-abstraction set, kept whole. 
+ """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('question') or row.get('problem') or '').strip() + thinking = (row.get('deepseek_thinking_trajectory') + or row.get('thinking_trajectories') + or row.get('thinking') or '') + if isinstance(thinking, list): + thinking = '\n\n'.join(t for t in thinking if isinstance(t, str)) + cot = (thinking or '').strip() + response = (row.get('deepseek_attempt') or row.get('attempt') + or row.get('answer') or row.get('solution') or '').strip() + if not query or not cot or not response: + continue + out.append({ + 'id': _hash_id('s1k', f'{query}\n{response}'), + 'source': 's1K-1.1', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +# =========================================================================== +# Code thinking +# =========================================================================== + +OPEN_CODE_REASONING_REPO = 'ms://nv-community/OpenCodeReasoning-2' + + +class OpenCodeReasoning2Processor(Preprocessor): + """OpenCodeReasoning-2 → ``{id, source, query, cot, response}``. + + Schema: ``input``/``problem``, plus per-model R1-style trace columns + (``r1_generation``, ``qwq_generation``, etc.). Prefer the ``r1`` trace; + fall back to ``solution``. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('input') or row.get('problem') + or row.get('question') or '').strip() + # OCR-2 'python' split ships dirty rows where question is literally '-'; + # the real prompt is buried in r1_generation and not recoverable here. 
+ if not query or query == '-': + continue + assistant = (row.get('r1_generation') or row.get('reasoning_content') + or row.get('solution') or row.get('output') or '').strip() + if not assistant: + continue + cot, response = _split_think(assistant) + if not cot: + continue + if not response: + response = (row.get('expected_solution') or row.get('answer') or '').strip() + if not response: + continue + out.append({ + 'id': _hash_id('opencode_reasoning2', f'{query}\n{response}'), + 'source': 'OpenCodeReasoning-2', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +CODEFORCES_COTS_REPO = 'ms://open-r1/codeforces-cots' + + +class CodeforcesCotsProcessor(Preprocessor): + """codeforces-cots → ``{id, source, query, cot, response}``. + + Schema: ``description``/``problem``, ``generation``/``solution`` (R1 + trace with a think block + final code). Algorithmic patterns at high + abstraction density. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('description') or row.get('problem') + or row.get('input') or row.get('question') or '').strip() + assistant = (row.get('generation') or row.get('solution') + or row.get('output') or '').strip() + if not query or not assistant: + continue + cot, response = _split_think(assistant) + if not cot or not response: + continue + out.append({ + 'id': _hash_id('codeforces_cots', f'{query}\n{response}'), + 'source': 'codeforces-cots', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +# =========================================================================== +# Cross-domain R1 +# =========================================================================== + +BESPOKE_STRATOS_REPO = 'ms://bespokelabs/Bespoke-Stratos-17k' + + +class BespokeStratosProcessor(Preprocessor): + """Bespoke-Stratos-17k → ``{id, source, query, 
cot, response}``. + + Schema: ``conversations`` (ShareGPT). Assistant content uses Sky-T1 + markers ``<|begin_of_thought|>...<|end_of_thought|>`` then + ``<|begin_of_solution|>...<|end_of_solution|>``. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query, assistant = _from_messages( + row.get('conversations') or row.get('messages')) + if not query or not assistant: + continue + cot, response = _split_sky_t1(assistant) + if not cot: + cot, response = _split_think(assistant) + if not cot or not response: + continue + out.append({ + 'id': _hash_id('bespoke_stratos', f'{query}\n{response}'), + 'source': 'Bespoke-Stratos-17k', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +DOLPHIN_R1_REPO = 'ms://AI-ModelScope/dolphin-r1' + + +class DolphinR1Processor(Preprocessor): + """dolphin-r1 → ``{id, source, query, cot, response}``. + + Schema (reasoning-deepseek subset): ``messages=[system, user]`` (no + assistant turn) + flat ``reasoning`` (CoT) + ``answer`` (final response) + + ``model``. Pull the last user turn as query, ``reasoning``/``answer`` as + cot/response. Fallback to a think block embedded in the assistant turn for legacy rows. 
+ """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + msgs = row.get('messages') or row.get('conversations') + query = '' + if isinstance(msgs, list): + for msg in msgs: + if not isinstance(msg, dict): + continue + role = msg.get('role') or msg.get('from') or '' + content = msg.get('content') or msg.get('value') or '' + if role in ('user', 'human') and isinstance(content, str): + query = content.strip() + cot = (row.get('reasoning') or row.get('reasoning_content') or '').strip() + response = (row.get('answer') or '').strip() + if (not cot or not response) and isinstance(msgs, list): + _, assistant = _from_messages(msgs) + if assistant: + c2, r2 = _split_think(assistant) + if c2: + cot = cot or c2 + response = response or r2 or assistant + if not query or not cot or not response: + continue + out.append({ + 'id': _hash_id('dolphin_r1', f'{query}\n{response}'), + 'source': 'dolphin-r1', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +GLAIVE_REASONING_REPO = 'ms://glaiveai/reasoning-v1-20m' + + +class GlaiveReasoningProcessor(Preprocessor): + """reasoning-v1-20m → ``{id, source, query, cot, response}``. + + Schema: ``prompt``, ``response`` (R1 trace with a think block + answer). + Largest cross-domain corpus in the mix; downsample aggressively. 
+ """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('prompt') or row.get('question') + or row.get('input') or '').strip() + assistant = (row.get('response') or row.get('output') + or row.get('answer') or '').strip() + if not query or not assistant: + continue + cot, response = _split_think(assistant) + if not cot or not response: + continue + out.append({ + 'id': _hash_id('glaive_reasoning', f'{query}\n{response}'), + 'source': 'reasoning-v1-20m', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +NATURAL_REASONING_REPO = 'ms://facebook/natural_reasoning' + + +class NaturalReasoningProcessor(Preprocessor): + """natural_reasoning → ``{id, source, query, cot, response}``. + + Schema: ``question`` + ``reference_answer`` + ``responses=[{response_model, + response}]``. The ``response`` field itself is the step-by-step CoT + (``## Step 1...## Step 2...``); there is no separate ``reasoning`` key. + Map ``responses[i].response`` → cot, ``reference_answer`` → response. + Rows with empty ``reference_answer`` (~18% per README) are dropped. 
+ """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('question') or '').strip() + if not query: + continue + cot = '' + responses = row.get('responses') + if isinstance(responses, list): + for r in responses: + if not isinstance(r, dict): + continue + txt = (r.get('response') or r.get('reasoning') + or r.get('thinking') or r.get('answer') or '').strip() + if txt: + cot = txt + break + if not cot: + cot = (row.get('reasoning') or row.get('thinking') + or row.get('response') or '').strip() + response = (row.get('reference_answer') or row.get('answer') or '').strip() + if not cot or not response: + continue + out.append({ + 'id': _hash_id('natural_reasoning', f'{query}\n{response}'), + 'source': 'natural_reasoning', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +# =========================================================================== +# Textbook-style — synthesize query from chapter heading; body → cot +# =========================================================================== + +COSMOPEDIA_REPO = 'ms://HuggingFaceTB/cosmopedia' + +class CosmopediaProcessor(Preprocessor): + """cosmopedia v1 → ``{id, source, query, cot, response}``. + + Schema: ``prompt`` (writing instruction), ``text`` (full chapter body), + ``format``/``audience``/``seed_data``. The subset is selected at load + time (``subset_name='auto_math_text'`` — densest math-textbook slice); + H2 chunking inside each row yields one query per section (heading plus + the section's lead paragraph) with the remaining body placed into ``cot``. 
+ """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + text = (row.get('text') or row.get('content') or '').strip() + if not text: + continue + for title, body in _chunk_by_h2(text): + # Heading-only "Explain: X" was 1-2 tokens and impossible to align + # with full-section cot. Promote the section's lead paragraph into + # the query so anchor carries real semantic content. + parts = body.split('\n\n', 1) + first_para = parts[0].strip() + rest = parts[1].strip() if len(parts) > 1 else '' + if len(first_para) < 256 or len(rest) < 256: + continue + query = f'{title}\n\n{first_para}' if title else first_para + out.append({ + 'id': _hash_id('cosmopedia', f'{title}\n{first_para[:200]}'), + 'source': 'cosmopedia-v1', + 'query': query, + 'cot': rest, + 'response': '', + }) + return self.map_row_to_col(out) + + +OMNI_MATH_REPO = 'ms://AI-ModelScope/Omni-MATH' + + +class OmniMathProcessor(Preprocessor): + """Omni-MATH → ``{id, source, query, cot, response}``. + + Schema: ``problem``, ``solution`` (full proof), ``answer``, ``domain``, + ``difficulty``. Olympiad-grade derivations — solution body → cot, + answer → response. 
+ """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('problem') or row.get('question') or '').strip() + solution = (row.get('solution') or '').strip() + answer = (row.get('answer') or row.get('expected_answer') or '').strip() + if not query or not solution: + continue + out.append({ + 'id': _hash_id('omni_math', f'{query}\n{solution[:200]}'), + 'source': 'Omni-MATH', + 'query': query, + 'cot': solution, + 'response': answer, + }) + return self.map_row_to_col(out) + + +# =========================================================================== +# Mix configuration — base sizes target ≈3.6M total rows +# =========================================================================== + +_BASE_SIZES = { + 'open_math_reasoning': 600_000, + 'open_r1_math': 220_000, + 's1k': 1_000, + 'opencode_reasoning2': 500_000, + 'codeforces_cots': 200_000, + 'bespoke_stratos': 17_000, + 'dolphin_r1': 400_000, + 'glaive_reasoning': 800_000, + 'natural_reasoning': 200_000, + 'cosmopedia': 700_000, + 'omni_math': 4_000, +} + + +def _scaled_sizes(total: Optional[int]) -> Dict[str, int]: + if total is None or total <= 0: + return dict(_BASE_SIZES) + scale = total / sum(_BASE_SIZES.values()) + return {k: max(1, int(round(v * scale))) for k, v in _BASE_SIZES.items()} + + +def _build_dataset(total: Optional[int] = None, + load_from_cache_file: bool = True) -> Dataset: + sizes = _scaled_sizes(total) + dataset = Dataset() + + _register(dataset, OpenMathReasoningProcessor, + DatasetMeta(dataset_id=OPEN_MATH_REASONING_REPO, split='cot', + data_slice=range(sizes['open_math_reasoning'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, OpenR1MathProcessor, + DatasetMeta(dataset_id=OPEN_R1_MATH_REPO, split='train', + data_slice=range(sizes['open_r1_math'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, S1KProcessor, + 
DatasetMeta(dataset_id=S1K_REPO, split='train'), + load_from_cache_file=load_from_cache_file) + + _register(dataset, OpenCodeReasoning2Processor, + DatasetMeta(dataset_id=OPEN_CODE_REASONING_REPO, + subset_name='train', split='python', + data_slice=range(sizes['opencode_reasoning2'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, CodeforcesCotsProcessor, + DatasetMeta(dataset_id=CODEFORCES_COTS_REPO, + subset_name='solutions_w_editorials_decontaminated', + split='train', + data_slice=range(sizes['codeforces_cots'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, BespokeStratosProcessor, + DatasetMeta(dataset_id=BESPOKE_STRATOS_REPO, split='train'), + load_from_cache_file=load_from_cache_file) + + _register(dataset, DolphinR1Processor, + DatasetMeta(dataset_id=DOLPHIN_R1_REPO, + subset_name='reasoning-deepseek', split='train', + data_slice=range(sizes['dolphin_r1'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, GlaiveReasoningProcessor, + DatasetMeta(dataset_id=GLAIVE_REASONING_REPO, split='train', + data_slice=range(sizes['glaive_reasoning'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, NaturalReasoningProcessor, + DatasetMeta(dataset_id=NATURAL_REASONING_REPO, split='train', + data_slice=range(sizes['natural_reasoning'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, CosmopediaProcessor, + DatasetMeta(dataset_id=COSMOPEDIA_REPO, + subset_name='auto_math_text', split='train', + data_slice=range(sizes['cosmopedia'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, OmniMathProcessor, + DatasetMeta(dataset_id=OMNI_MATH_REPO, split='test'), + load_from_cache_file=load_from_cache_file) + + dataset.mix_dataset(False) + # Mix is concatenated in registration order; shuffle so the streaming + # consumer sees all sources interleaved instead of 600k OpenMathReasoning + # rows before it ever reaches code/textbook splits. 
+ dataset.dataset = dataset.dataset.shuffle(seed=42) + return dataset + + +def get_dataset(total: Optional[int] = None, + dropped_log: Optional[str] = None, + load_from_cache_file: bool = True) -> Dataset: + """Build, convert to messages, and quality-filter the RAG-index corpus. + + Mirrors ``dataset_think.get_dataset``: identical signature + output + schema so ``build_thinking_rag_index`` consumes both modules unchanged. + """ + from twinkle_agentic.preprocessor import ( + DeadLoopFilter, + FixUnicodeFilter, + HardFilter, + MessageSanityFilter, + QualityPreprocessor, + RefuseFilter, + RemoveRepeatSentencesFilter, + TokenNumFilter, + TokenSoupFilter, + ) + + dataset = _build_dataset(total=total, load_from_cache_file=load_from_cache_file) + # Drop trivially-short queries (e.g. one-line math problems, OmniMath stubs) + # before message conversion — anchor side needs enough tokens to embed meaningfully. + dataset.dataset = dataset.dataset.filter( + lambda x: len((x.get('query') or '').strip()) >= 100, + num_proc=32, load_from_cache_file=load_from_cache_file) + dataset.map(ToMessagesProcessor(), remove_columns=['query', 'cot', 'response'], + load_from_cache_file=load_from_cache_file) + qp = QualityPreprocessor( + pipeline=[ + HardFilter(), + RefuseFilter(), + DeadLoopFilter(), + TokenSoupFilter(), + MessageSanityFilter(min_turns=1, max_msg_chars=200000), + FixUnicodeFilter(), + RemoveRepeatSentencesFilter(), + TokenNumFilter(max_num=32768), + ], + dropped_log_path=dropped_log or '', + ) + dataset.map(qp, num_proc=32, load_from_cache_file=load_from_cache_file) + return dataset + + +if __name__ == '__main__': + import os + dropped_log = os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'dropped_index.jsonl') + if os.path.exists(dropped_log): + os.remove(dropped_log) + dataset = get_dataset(load_from_cache_file=False) + print(len(dataset)) diff --git a/cookbook/exp/embedding/train_embedding_full_ddp.py b/cookbook/exp/embedding/train_embedding_full_ddp.py index 
492e29aa..bed1dbe9 100644 --- a/cookbook/exp/embedding/train_embedding_full_ddp.py +++ b/cookbook/exp/embedding/train_embedding_full_ddp.py @@ -40,7 +40,8 @@ from twinkle_agentic.protocol.openai import OpenAI as OpenAIClient sys.path.insert(0, str(Path(__file__).resolve().parent)) -from dataset_think import get_dataset # noqa: E402 +from dataset_think import get_dataset as get_dataset_think # noqa: E402 +from dataset_index import get_dataset as get_dataset_index # noqa: E402 logger = get_logger() @@ -49,7 +50,7 @@ # Condenser (online compression + LoRA self-improvement); embedding model trains LoRA on top of MODEL_ID. CONDENSE_MODEL_ID = os.environ.get('CONDENSE_MODEL_ID', 'ms://twinkle-kit/Qwen3.5-4B-CM-v2') -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') +MODEL_ID = os.environ.get('MODEL_ID', 'ms://twinkle-kit/Qwen3.5-4B-QA-emb') TEMPLATE_NAME = 'Qwen3_5Template' # -- GPU placement (8 total) -------------------------------------------------- @@ -68,15 +69,19 @@ GRADIENT_ACCUMULATION_STEPS = 1 LOG_INTERVAL = 2 SAVE_INTERVAL = 4000 -NUM_EPOCHS = 2 +NUM_EPOCHS = 1 TOTAL_SAMPLES: Optional[int] = None +# Post-build caps on each loader (None = no cap). Applied via .select() before mix. +THINK_CAP: Optional[int] = 200_000 +INDEX_CAP: Optional[int] = 400_000 +MIX_SHUFFLE_SEED = 42 # -- Resume from checkpoint --------------------------------------------------- -RESUME_CHECKPOINT = os.environ.get( - 'RESUME_CHECKPOINT', - './output/embedding_lora_transformers/step_16000') -RESUME_STEP = int(os.environ.get('RESUME_STEP', 16000)) +# Empty by default — build_model falls back to MODEL_ID (the published emb model). +# Set both to point at a local in-progress run only when resuming the *same* schedule. 
+RESUME_CHECKPOINT = os.environ.get('RESUME_CHECKPOINT', '') +RESUME_STEP = int(os.environ.get('RESUME_STEP', 0)) # -- Online-compression knobs ------------------------------------------------- # Below this length, condenser fabricates content for open-ended short prompts; @@ -355,10 +360,15 @@ def _build_compress_prompts(rows: List[Dict[str, Any]]) -> tuple: raw_pairs: List[tuple] = [] prompt_queries: List[str] = [] passthrough: List[Optional[str]] = [] + # Conservative char budget: 32768 max_length - 8192 gen - ~2k prompt overhead = ~22k tokens. + # At worst-case 1.5 chars/token (CJK), that's ~33k chars. Use 60k as safe English-mix cap. + _MAX_COT_CHARS = 60_000 for i, row in enumerate(rows): query, cot = _extract_query_cot(row) if not query or len(cot) < MIN_TEXT_CHARS: continue + if len(cot) > _MAX_COT_CHARS: + continue valid_indices.append(i) raw_pairs.append((query, cot)) # Short query bypasses condenser to avoid skeleton-induced hallucination. @@ -521,7 +531,20 @@ def train(): twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups) # -------- Data ----------------------------------------------------------- - dataset = get_dataset(total=TOTAL_SAMPLES, load_from_cache_file=True) + dataset = get_dataset_think(total=TOTAL_SAMPLES, load_from_cache_file=True) + if THINK_CAP and len(dataset.dataset) > THINK_CAP: + dataset.dataset = dataset.dataset.select(range(THINK_CAP)) + if INDEX_CAP != 0: + from datasets import concatenate_datasets + ds_index = get_dataset_index(total=None, load_from_cache_file=True) + if INDEX_CAP and len(ds_index.dataset) > INDEX_CAP: + ds_index.dataset = ds_index.dataset.select(range(INDEX_CAP)) + n_think = len(dataset.dataset) + n_index = len(ds_index.dataset) + # Both loaders emit identical {id, source, messages} schema post-QP. 
+ dataset.dataset = concatenate_datasets( + [dataset.dataset, ds_index.dataset]).shuffle(seed=MIX_SHUFFLE_SEED) + logger.info(f'[mix] think={n_think} + index={n_index} → total={len(dataset.dataset)}') dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True) total_forward_steps = len(dataloader) * NUM_EPOCHS optimizer_steps = total_forward_steps // GRADIENT_ACCUMULATION_STEPS @@ -608,15 +631,19 @@ def _sample_batch(raw_batch): """Compress via vLLM sampler; fall back to API on truncation.""" compress_prompts, valid_indices, raw_pairs, prompt_queries, passthrough = \ _build_compress_prompts(raw_batch) - if not compress_prompts: + if len(compress_prompts) < 4: return None # Only submit non-passthrough prompts to the sampler. sampler_input = [p for p in compress_prompts if p is not None] sampler_pos = [ri for ri, p in enumerate(compress_prompts) if p is not None] if sampler_input: - with retrainer.sampler_lock: - sampler_responses = condenser_sampler.sample(sampler_input, compress_params) + try: + with retrainer.sampler_lock: + sampler_responses = condenser_sampler.sample(sampler_input, compress_params) + except Exception as exc: + logger.warning(f'[sampler] encode overflow in batch, falling back to API: {exc}') + sampler_responses = [None] * len(sampler_input) else: sampler_responses = [] responses = [None] * len(compress_prompts) From 28ab28fa4ea5f3c321dfcc4605a9bfcac03ea0fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 11:12:33 +0800 Subject: [PATCH 02/53] wip --- src/twinkle/checkpoint_engine/manager.py | 1 - src/twinkle/checkpoint_engine/mixin.py | 1 - src/twinkle/cli/cli.py | 7 +++---- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py index cde5c519..3860d284 100644 --- a/src/twinkle/checkpoint_engine/manager.py +++ b/src/twinkle/checkpoint_engine/manager.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope 
Contributors. All rights reserved. # Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py -import time from typing import List, Optional from twinkle import Platform, get_logger diff --git a/src/twinkle/checkpoint_engine/mixin.py b/src/twinkle/checkpoint_engine/mixin.py index e2e5d94d..8dc15c92 100644 --- a/src/twinkle/checkpoint_engine/mixin.py +++ b/src/twinkle/checkpoint_engine/mixin.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import os from twinkle import Platform, remote_function from twinkle.checkpoint_engine.base import CheckpointEngine diff --git a/src/twinkle/cli/cli.py b/src/twinkle/cli/cli.py index 1730887f..10ad4a39 100644 --- a/src/twinkle/cli/cli.py +++ b/src/twinkle/cli/cli.py @@ -1,12 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from __future__ import annotations - import os import sys from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields from pathlib import Path -from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union +from typing import Any, Iterator, Literal + # ──────────────────────────────────────────────────────────────────────────────── # Arg group dataclasses @@ -243,7 +242,7 @@ def _resolve_path(self) -> Path | None: class EnvVarSource(ConfigSource): """Reads os.environ; recognizes TWINKLE_ prefix and any key known to the registry.""" - def __init__(self, registry: ConfigRegistry): + def __init__(self, registry: 'ConfigRegistry'): self._registry = registry def load(self) -> dict[str, str]: From 30e8412de6d3ec5e7ae929c62d6a40c28829a88a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 11:43:42 +0800 Subject: [PATCH 03/53] wip --- src/twinkle/data_format/sampling.py | 1 - src/twinkle/dataloader/dataloader.py | 3 +- .../dataset/iterable_packing_dataset.py | 30 ++++++++++++++----- src/twinkle/dataset/packing_dataset.py | 2 ++ 4 files changed, 26 
insertions(+), 10 deletions(-) diff --git a/src/twinkle/data_format/sampling.py b/src/twinkle/data_format/sampling.py index 1d5fe07c..01ff0377 100644 --- a/src/twinkle/data_format/sampling.py +++ b/src/twinkle/data_format/sampling.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import numpy as np from dataclasses import dataclass from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index c392d56c..408c8d4b 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -146,7 +146,7 @@ def _tracking_iter(self, inner): def skip_consumed_samples(self, consumed_train_samples: int) -> None: from torch.utils.data import IterableDataset - if isinstance(self.dataset, IterableDataset): + if isinstance(self.dataset, IterableDataset) or consumed_train_samples is None or consumed_train_samples <= 0: warnings.warn('IterableDataset does not support consumed-data skipping; continuing without skipping.') self._skip_samples = 0 return @@ -164,6 +164,7 @@ def resume_from_checkpoint(self, consumed_train_samples, **kwargs): @remote_function() def get_state(self) -> dict: + """The dataloader state for saving.""" return {'consumed_train_samples': self._consumed_train_samples} def _rebuild_sampler_stack(self): diff --git a/src/twinkle/dataset/iterable_packing_dataset.py b/src/twinkle/dataset/iterable_packing_dataset.py index ca7c6fbd..ab1d3a98 100644 --- a/src/twinkle/dataset/iterable_packing_dataset.py +++ b/src/twinkle/dataset/iterable_packing_dataset.py @@ -88,10 +88,27 @@ def _fetch_data_out_queue(self, last_res, num_samples): last_res += res return last_res - @staticmethod - def _cyclic_iter(iterable): - while True: - yield from iterable + def _write_through_iter(self, iterable): + """Yields from iterable, meanwhile, save it to disk if needed. 
+ Saving is needed when you are using several datasets at a time. + """ + if not self.cyclic: + for row in iterable: + self._write_through(row) + yield row + return + else: + first_pass = True + while True: + empty = True + for row in iterable: + empty = False + if first_pass: + self._write_through(row) + yield row + if empty: + return + first_pass = False @remote_function() def __iter__(self): @@ -102,10 +119,7 @@ def __iter__(self): except StopIteration: return - if self.cyclic: - iterator = self._cyclic_iter(self.dataset) - else: - iterator = iter(self.dataset) + iterator = self._write_through_iter(self.dataset) data = [] max_length = self.template.max_length or 2048 while True: diff --git a/src/twinkle/dataset/packing_dataset.py b/src/twinkle/dataset/packing_dataset.py index fa4acbd5..ada9498b 100644 --- a/src/twinkle/dataset/packing_dataset.py +++ b/src/twinkle/dataset/packing_dataset.py @@ -114,6 +114,8 @@ def __getitem__(self, index): assert self._packed_called, 'Call `pack_dataset()` first before index the sample.' 
sequence = self.packed_idx[index] rows = [self.dataset[i] for i in sequence] + for row in rows: + self._write_through(row) output = {} for key in rows[0]: output[key] = [r[key] for r in rows] From 6909de3bd3f2ef03ffefb45524a40ee2c02105de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 11:44:19 +0800 Subject: [PATCH 04/53] wip --- docs/source_en/Components/Gym/Gym.md | 3 ++- "docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" | 3 ++- src/twinkle/gym/__init__.py | 2 -- src/twinkle/gym/base.py | 10 ---------- 4 files changed, 4 insertions(+), 14 deletions(-) delete mode 100644 src/twinkle/gym/__init__.py delete mode 100644 src/twinkle/gym/base.py diff --git a/docs/source_en/Components/Gym/Gym.md b/docs/source_en/Components/Gym/Gym.md index 4db355b8..8d243677 100644 --- a/docs/source_en/Components/Gym/Gym.md +++ b/docs/source_en/Components/Gym/Gym.md @@ -3,7 +3,8 @@ The Gym component provides an interface for reinforcement learning environments in Twinkle. ```python -from twinkle.gym import Gym +from twinkle_agentic.env import Gym + class CustomGym(Gym): diff --git "a/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" "b/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" index 63dc87aa..4e34ffb9 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" @@ -3,7 +3,8 @@ Gym 组件为 Twinkle 中的强化学习环境提供接口。 ```python -from twinkle.gym import Gym +from twinkle_agentic.env import Gym + class CustomGym(Gym): diff --git a/src/twinkle/gym/__init__.py b/src/twinkle/gym/__init__.py deleted file mode 100644 index 44b0771b..00000000 --- a/src/twinkle/gym/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-from .base import Gym diff --git a/src/twinkle/gym/base.py b/src/twinkle/gym/base.py deleted file mode 100644 index aca79809..00000000 --- a/src/twinkle/gym/base.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. - - -class Gym: - - def __init__(self): - pass - - def step(self): - pass From c7a8b88b125545c64b7ff4b748ded32725bfa422 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 14:02:40 +0800 Subject: [PATCH 05/53] wip --- src/twinkle/infra/__init__.py | 20 +-- src/twinkle/loss/chunked_cross_entropy.py | 200 ++++++++++++++++------ src/twinkle/notifier/base.py | 2 +- 3 files changed, 152 insertions(+), 70 deletions(-) diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 83e10d13..23fb1fda 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -1,11 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import functools import inspect -import itertools import json import numpy as np import os -import random import sys from typing import Any, Callable, List, Literal, Optional, TypeVar, Union @@ -59,7 +57,7 @@ def _tag_exc(exc: BaseException, caller: Optional[str]) -> None: prefix = f'[twinkle driver caller: {caller}] ' exc.args = (prefix + str(exc.args[0]), *exc.args[1:]) if exc.args else (prefix.rstrip(), ) exc._twinkle_caller_augmented = True - except Exception: # noqa: BLE001 + except Exception: # noqa pass @@ -404,6 +402,7 @@ def dispatch_func(arg, n): return result elif dispatch == 'slice_dp': + assert device_mesh is not None # split by dp. 
each worker in one ep will receive the same argument result = [] # if device_mesh is not None: @@ -420,14 +419,6 @@ def dispatch_func(arg, n): import torch if isinstance(arg, list) or isinstance(arg, torch.Tensor): _args = [] - if device_mesh is None: - total = len(arg) - chunk = max(1, (total + n - 1) // n) - for i in range(n): - start = i * chunk - end = min(total, start + chunk) - _args.append(arg[start:end]) - return _args for i in range(n): _args.append(arg[device_mesh.get_slice( len(arg), device_mesh.get_data_rank_from_global_rank(i * _rank_stride))]) @@ -803,18 +794,13 @@ def wrapper(self, *args, **kwargs) -> T1: # And this is user independent, only decided by the code. _local_lazy_collect = self._lazy_collect if _local_lazy_collect: - # Wrap the deferred collector so that exceptions - # raised when the caller later materializes the - # result also trigger the notifier. Attributes - # (``_futures`` etc.) on the original collector - # are preserved for downstream code paths. _orig_result_func = result_func @functools.wraps(_orig_result_func) def _notifying_result_func(*rargs, **rkwargs): try: return _orig_result_func(*rargs, **rkwargs) - except Exception as _e: # noqa: BLE001 + except Exception as _e: # noqa _tag_exc(_e, _caller) notify_exception(_notifier, _ctx, _e, _name) raise diff --git a/src/twinkle/loss/chunked_cross_entropy.py b/src/twinkle/loss/chunked_cross_entropy.py index 22d3d407..061ca216 100644 --- a/src/twinkle/loss/chunked_cross_entropy.py +++ b/src/twinkle/loss/chunked_cross_entropy.py @@ -1,63 +1,159 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import math -from typing import Any - -from ..data_format import LossOutput +from twinkle.data_format import LossOutput from .base import Loss +# Lazily-built singleton autograd.Function, so we neither pay the +# class-construction cost on every forward nor force a top-level torch import. 
+_CHUNKED_CE_FUNC = None + + +def _get_chunked_ce_func(): + global _CHUNKED_CE_FUNC + if _CHUNKED_CE_FUNC is not None: + return _CHUNKED_CE_FUNC + + import torch + import torch.nn.functional as F + + class _ChunkedCrossEntropyFunc(torch.autograd.Function): + """Chunked CE that materialises log_softmax(B, V) only one chunk at a time. + + Forward returns a scalar loss; backward writes per-token gradients into + a freshly allocated `grad_logits` tensor (the input `logits` is never + mutated). Mathematically equivalent to ``CrossEntropyLoss`` in the same + package; ``chunk_size`` only controls the memory/throughput trade-off. + """ + + @staticmethod + def forward(ctx, logits, labels, chunk_size, ignore_index, reduction, dft): + ctx.save_for_backward(logits, labels) + ctx.chunk_size = chunk_size + ctx.ignore_index = ignore_index + ctx.reduction = reduction + ctx.dft = dft + + n = logits.shape[0] + # Use fp32 accumulators so we don't lose precision when summing + # over many tokens under fp16/bf16 autocast (matches cross_entropy.py). 
+ total_loss = logits.new_zeros((), dtype=torch.float32) + total_count = logits.new_zeros((), dtype=torch.float32) + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + logits_chunk = logits[start:end] + labels_chunk = labels[start:end] + mask = (labels_chunk != ignore_index).float() + + logps = F.log_softmax(logits_chunk, dim=-1).gather( + -1, labels_chunk.clamp(min=0).unsqueeze(-1)).squeeze(-1) + per_token = -logps * logps.exp() if dft else -logps + + total_loss = total_loss + (per_token * mask).sum() + total_count = total_count + mask.sum() + + ctx.num_tokens = total_count.detach() + if reduction == 'mean': + return total_loss / total_count.clamp(min=1) + return total_loss + + @staticmethod + def backward(ctx, grad_output): + logits, labels = ctx.saved_tensors + chunk_size = ctx.chunk_size + ignore_index = ctx.ignore_index + reduction = ctx.reduction + dft = ctx.dft + + if reduction == 'mean': + scale = grad_output / ctx.num_tokens.clamp(min=1) + else: + scale = grad_output + + grad_logits = torch.empty_like(logits) + n = logits.shape[0] + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + logits_chunk = logits[start:end].detach().requires_grad_(True) + labels_chunk = labels[start:end] + mask = (labels_chunk != ignore_index).float() + + with torch.enable_grad(): + logps = F.log_softmax(logits_chunk, dim=-1).gather( + -1, labels_chunk.clamp(min=0).unsqueeze(-1)).squeeze(-1) + per_token = -logps * logps.exp() if dft else -logps + loss_chunk = (per_token * mask).sum() + + grad_chunk = torch.autograd.grad(loss_chunk, logits_chunk, retain_graph=False)[0] + grad_logits[start:end] = grad_chunk * scale + + # logits, labels, chunk_size, ignore_index, reduction, dft + return grad_logits, None, None, None, None, None + + _CHUNKED_CE_FUNC = _ChunkedCrossEntropyFunc + return _CHUNKED_CE_FUNC + class ChunkedCrossEntropyLoss(Loss): - """TODO untested code""" + """CE loss that chunks the (B, V) softmax to bound peak memory. 
+ Drop-in replacement for :class:`CrossEntropyLoss` when ``outputs['logits']`` + is large (e.g. long sequence x big vocab). Behaviour matches that loss + up to floating-point rounding; ``chunk_size`` only affects memory/throughput. + + Args: + chunk_size: How many rows of ``logits`` to process per chunk. + ignore_index: Label id treated as padding (excluded from loss). + reduction: ``'mean'`` or ``'sum'``; matches ``CrossEntropyLoss``. + dft: If True, use DFT weighting ``-p*log(p)`` (arxiv 2508.05629). + """ - def __init__(self, chunk_size): + require_logits = True + # We chunk the (B, V) softmax ourselves; tell upstream not to materialise + # `logps` (which would already pay the full memory cost we're trying to + # avoid). The `_loss_from_logps` fast path is kept only for the rare case + # where someone explicitly hands us pre-computed logps. + require_logps = False + + def __init__(self, + chunk_size: int, + ignore_index: int = -100, + reduction: str = 'mean', + dft: bool = False, + **kwargs): + super().__init__() + assert chunk_size > 0, 'chunk_size must be positive' + assert reduction in ('mean', 'sum'), f"reduction must be 'mean' or 'sum', got {reduction!r}" self.chunk_size = chunk_size + self.ignore_index = ignore_index + self.reduction = reduction + self.dft = dft def __call__(self, inputs, outputs, **kwargs): - import torch - - class ChunkedCrossEntropyLossFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, logits, labels, chunk_size): - import torch - ctx.save_for_backward(logits, labels) - ctx.chunk_size = chunk_size - - losses = [] - for i in range(math.ceil(logits.shape[0] / chunk_size)): - l_start = i * chunk_size - l_end = min((i + 1) * chunk_size, logits.shape[0]) - logits_chunk = logits[l_start:l_end] - labels_chunk = labels[l_start:l_end] - loss_fct = torch.nn.CrossEntropyLoss(reduction='none') - loss_chunk = loss_fct(logits_chunk, labels_chunk) - losses.append(loss_chunk) - del logits_chunk - del labels_chunk - all_losses = torch.cat(losses) -
return all_losses - - @staticmethod - def backward(ctx: Any, *grad_outputs: Any): - import torch - logits, labels = ctx.saved_tensors - chunk_size = ctx.chunk_size - - for i in range(math.ceil(logits.shape[0] / chunk_size)): - l_start = i * chunk_size - l_end = min((i + 1) * chunk_size, logits.shape[0]) - logits_chunk = logits[l_start:l_end].detach().requires_grad_(True) - labels_chunk = labels[l_start:l_end] - loss_fct = torch.nn.CrossEntropyLoss(reduction='none') - with torch.enable_grad(): - loss_chunk = loss_fct(logits_chunk, labels_chunk) - grad_output_chunk = grad_outputs[0][l_start:l_end] - _loss_chunk = (loss_chunk * grad_output_chunk).sum() - grad_chunk = torch.autograd.grad(_loss_chunk, logits_chunk, retain_graph=False)[0] - logits[l_start:l_end] = grad_chunk - - return logits, None, None + labels = inputs['labels'] + logps = outputs.get('logps') + + # Fast path: if logps is already gathered upstream, chunking the + # softmax is moot — fall back to the same scalar formula as + # CrossEntropyLoss to keep behaviour identical. 
+ if logps is not None: + return self._loss_from_logps(labels, logps) logits = outputs['logits'] - labels = inputs['labels'] - return LossOutput(loss=ChunkedCrossEntropyLossFunc.apply(logits, labels, self.chunk_size), num_tokens=0) + labels = labels.view(-1) + logits = logits.view(-1, logits.shape[-1]) + + func = _get_chunked_ce_func() + loss = func.apply(logits, labels, self.chunk_size, self.ignore_index, self.reduction, self.dft) + + if self.reduction == 'mean': + return LossOutput(loss=loss, num_tokens=0) + num_tokens = (labels != self.ignore_index).float().sum().clamp(min=1) + return LossOutput(loss=loss, num_tokens=num_tokens) + + def _loss_from_logps(self, labels, logps): + mask = (labels != self.ignore_index).float() + per_token = -logps * logps.exp() if self.dft else -logps + if self.reduction == 'mean': + return LossOutput(loss=(per_token * mask).sum() / mask.sum().clamp(min=1), num_tokens=0) + return LossOutput(loss=(per_token * mask).sum(), num_tokens=mask.sum().clamp(min=1)) diff --git a/src/twinkle/notifier/base.py b/src/twinkle/notifier/base.py index a83903b5..6c138c99 100644 --- a/src/twinkle/notifier/base.py +++ b/src/twinkle/notifier/base.py @@ -66,7 +66,7 @@ def notify_exception(notifier: Notifier, context: str, exc: BaseException, name: if not _try_claim_notify_slot(exc, context, name): try: setattr(exc, '_twinkle_notified', True) - except Exception: # noqa: BLE001 + except Exception: # noqa pass return From 47a83204450a99182876d2093a7ed5059a5ba7cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 17:13:29 +0800 Subject: [PATCH 06/53] wip --- src/twinkle/loss/dpo.py | 44 +++++++++++++++------- src/twinkle/loss/gkd.py | 19 +++++++--- src/twinkle/loss/grpo.py | 31 ++++++---------- src/twinkle/loss/infonce.py | 74 ++++++++++++++++++++++--------------- 4 files changed, 102 insertions(+), 66 deletions(-) diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index fe526ab4..d5301951 100644 --- 
a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -7,7 +7,7 @@ (https://arxiv.org/abs/2305.18290) """ from typing import TYPE_CHECKING, Dict, List, Optional, Union - +import math from twinkle.data_format import LossOutput from twinkle.loss.base import Loss from twinkle.utils.torch_utils import selective_log_softmax @@ -132,6 +132,12 @@ def __init__( **kwargs, ): super().__init__(ignore_index=ignore_index) + if loss_type not in ('sigmoid', 'hinge', 'ipo', 'kto_pair'): + raise ValueError(f'Unknown loss_type: {loss_type}') + if label_smoothing > 0 and loss_type != 'sigmoid': + raise ValueError( + f'label_smoothing > 0 is only defined for loss_type="sigmoid", ' + f'got loss_type="{loss_type}". Set label_smoothing=0.0 or switch to sigmoid.') self.beta = beta self.label_smoothing = label_smoothing self.loss_type = loss_type @@ -217,6 +223,11 @@ def _compute_dpo_loss( if self.loss_type == 'sigmoid': # Standard DPO loss: -log(sigmoid(beta * margin)) losses = -F.logsigmoid(logits) + # Apply label smoothing (only meaningful here: Bradley-Terry soft labels). 
+ if self.label_smoothing > 0: + # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected + smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference + losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses elif self.loss_type == 'hinge': # Hinge loss variant losses = torch.relu(1 - logits) @@ -234,12 +245,6 @@ def _compute_dpo_loss( else: raise ValueError(f'Unknown loss_type: {self.loss_type}') - # Apply label smoothing if specified - if self.label_smoothing > 0: - # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected - smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference - losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses - return losses.mean() def __call__( @@ -321,7 +326,8 @@ def __call__( reference_chosen_logps = torch.zeros_like(policy_chosen_logps) reference_rejected_logps = torch.zeros_like(policy_rejected_logps) else: - return LossOutput(loss=torch.tensor(0.0, device=chosen_logps.device), num_tokens=0) + zero = (policy_chosen_logps.sum() + policy_rejected_logps.sum()) * 0.0 + return LossOutput(loss=zero, num_tokens=0) # Compute DPO loss dpo_loss = self._compute_dpo_loss( @@ -535,11 +541,23 @@ def __call__( # Odds ratio: log(odds_chosen / odds_rejected) # log_odds = log(p/(1-p)) = log(p) - log(1-p) - # Compute entirely in log-space to avoid exp() underflow: - # log(p) = avg_logps (already in log-space) - # log(1-p) = log1p(-exp(avg_logps)) (numerically stable via log1p) - log_odds_chosen = chosen_avg_logps - torch.log1p(-torch.exp(chosen_avg_logps)) - log_odds_rejected = rejected_avg_logps - torch.log1p(-torch.exp(rejected_avg_logps)) + # Compute log(1-p) = log(1 - exp(avg_logp)) numerically stably: + # - For x > -log(2): log(-expm1(x)) (avoids log(0) when p → 1) + # - For x ≤ -log(2): log1p(-exp(x)) (avoids cancellation when p → 0) + # ``avg_logp ∈ (-∞, 0]`` so the threshold partitions the safe regime. 
+ log_two = math.log(2.0) + + def _log1mexp(x: 'torch.Tensor') -> 'torch.Tensor': + # Clamp at a tiny negative to keep both branches well-defined when p≈1. + x_safe = torch.clamp(x, max=-1e-7) + return torch.where( + x_safe > -log_two, + torch.log(-torch.expm1(x_safe)), + torch.log1p(-torch.exp(x_safe)), + ) + + log_odds_chosen = chosen_avg_logps - _log1mexp(chosen_avg_logps) + log_odds_rejected = rejected_avg_logps - _log1mexp(rejected_avg_logps) # ORPO odds ratio loss odds_ratio = log_odds_chosen - log_odds_rejected diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index 3f7db4bf..7c198ad0 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -41,6 +41,10 @@ def __init__( chunk_size: int = 512, **kwargs, ): + if not (0.0 <= beta <= 1.0): + raise ValueError(f'beta must be in [0, 1], got {beta}') + if temperature <= 0: + raise ValueError(f'temperature must be > 0, got {temperature}') self.beta = beta self.temperature = temperature self.ignore_index = ignore_index @@ -94,6 +98,7 @@ def __call__( labels=labels, beta=self.beta, temperature=self.temperature, + ignore_index=self.ignore_index, chunk_size=self.chunk_size, topk=topk, teacher_topk_logprobs=teacher_topk_logprobs, @@ -108,6 +113,7 @@ def _generalized_jsd_loss( labels=None, beta: float = 0.5, temperature: float = 1.0, + ignore_index: int = -100, chunk_size: int = 512, topk: Optional[int] = None, teacher_topk_logprobs=None, @@ -164,7 +170,7 @@ def _generalized_jsd_loss( # ── Mask valid (response) tokens ────────────────────────────────────── if labels is not None: - mask = labels != -100 # ignore_index is always -100 per convention + mask = labels != ignore_index # Vocab-size mismatch (e.g. Qwen2.5-VL-3B vs 7B): pad the smaller side # so both distributions are defined over the same token set. 
stu_dim = student_logits.shape[-1] @@ -178,12 +184,15 @@ def _generalized_jsd_loss( student_logits = student_logits[mask] # [num_valid, vocab/topk] teacher_logits = teacher_logits[mask] num_valid = mask.sum() + # ``[mask]`` already created fresh storage, so in-place divide is safe + # and avoids an extra [num_valid, V] allocation. + student_logits.div_(temperature) + teacher_logits.div_(temperature) else: - student_logits = student_logits.view(-1, student_logits.size(-1)) - teacher_logits = teacher_logits.view(-1, teacher_logits.size(-1)) + # No mask (possibly an inference scenario): divide out-of-place so the input logits are not mutated + student_logits = student_logits.reshape(-1, student_logits.size(-1)) / temperature + teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1)) / temperature num_valid = student_logits.size(0) - student_logits.div_(temperature) - teacher_logits.div_(temperature) if num_valid == 0: return student_logits.new_zeros(()) diff --git a/src/twinkle/loss/grpo.py b/src/twinkle/loss/grpo.py index 4bb71216..781b2206 100644 --- a/src/twinkle/loss/grpo.py +++ b/src/twinkle/loss/grpo.py @@ -42,18 +42,6 @@ def __init__( self.require_entropy = entropy_coef > 0.0 self.ignore_index = ignore_index - def _compute_loss_mask(self, labels: 'torch.Tensor') -> 'torch.Tensor': - """ - Compute loss mask from labels.
- - Args: - labels: [batch, seq_len] target token ids, -100 for ignored positions - - Returns: - mask: [batch, seq_len] float tensor, 1.0 for valid positions, 0.0 for ignored - """ - return (labels != self.ignore_index).float() - def _compute_log_importance_weights( self, per_token_logps: 'torch.Tensor', @@ -165,10 +153,13 @@ def _pad_and_align_to_batch( return data # Already aligned if data.dim() == 1: data = data.unsqueeze(1) - if data.shape[1] == 1: # Scalars - result = torch.full((batch_size, seq_len), fill_value, dtype=dtype, device=device) - result[mask] = data[mask.any(dim=1).nonzero(as_tuple=True)[0].repeat_interleave(mask.sum(dim=1)), 0] - return result + if data.shape[1] == 1: + assert data.shape[0] == batch_size, ( + f'scalar broadcast expects data.shape[0]==batch_size, ' + f'got data.shape={tuple(data.shape)} mask.shape={(batch_size, seq_len)}') + fill = torch.full((batch_size, seq_len), fill_value, dtype=dtype, device=device) + expanded = data.expand(batch_size, seq_len) + return torch.where(mask, expanded, fill) data = [data[i] for i in range(batch_size)] # To list # Handle list (scalars or sequences) @@ -276,10 +267,12 @@ def __call__( ) # GRPO loss is ill-defined without advantages (e.g. ref-logps-only forward, - # or eval/validation forwards). Return a zero loss so the forward still - # flows through cleanly and callers can harvest outputs['logps'] freely. + # or eval/validation forwards). Return a zero loss that still flows through + # autograd so DDP/FSDP do not see unused params, and callers can harvest + # outputs['logps'] freely. 
if advantages is None: - return LossOutput(loss=torch.zeros((), device=device, dtype=logps.dtype), num_tokens=0) + zero = logps.sum() * 0.0 + return LossOutput(loss=zero, num_tokens=0) advantages = self._pad_and_align_to_batch( advantages, diff --git a/src/twinkle/loss/infonce.py b/src/twinkle/loss/infonce.py index 68d14840..c356bd64 100644 --- a/src/twinkle/loss/infonce.py +++ b/src/twinkle/loss/infonce.py @@ -13,8 +13,6 @@ import numpy as np import torch import torch.distributed as dist -import torch.nn.functional as F -from enum import Enum from torch import nn from typing import Optional @@ -22,15 +20,6 @@ from .base import Loss -# Borrowed from sentence_transformers. -class SiameseDistanceMetric(Enum): - """Distance metrics available to the pairwise contrastive losses.""" - - EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2) # noqa - MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1) # noqa - COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) # noqa - - def _extract_sentences(outputs) -> torch.Tensor: """Return [B, D] sentence embeddings from postprocess_tensor_sp output. @@ -119,6 +108,11 @@ def __init__( process_group=None, **kwargs, ): + if mask_fake_negative and fake_neg_margin <= 0: + raise ValueError( + f'fake_neg_margin must be > 0 when mask_fake_negative=True, got {fake_neg_margin}. ' + 'A non-positive margin would mask out the positive itself or every above-positive ' + 'logit indiscriminately, collapsing the contrastive signal.') self.temperature = temperature self.use_batch = use_batch self.hard_negatives = hard_negatives @@ -129,7 +123,13 @@ def __init__( self.process_group = process_group def _gather_across_dp(self, sentences: torch.Tensor, labels: torch.Tensor): - """All-gather embeddings & labels across DP ranks; only local shard keeps grad.""" + """All-gather embeddings & labels across DP ranks; only local shard keeps grad. + + NCCL ``all_gather`` requires every rank to send the *same* tensor size. 
Under + ``slice_dp`` dispatch the per-rank batch is uneven (``divmod`` splits), so we + pad each rank to the global max along dim-0, do an equal-sized all_gather, + then strip padding back. Only the local shard retains gradients. + """ if not (dist.is_available() and dist.is_initialized()): return sentences, labels world_size = dist.get_world_size(group=self.process_group) @@ -137,24 +137,40 @@ def _gather_across_dp(self, sentences: torch.Tensor, labels: torch.Tensor): return sentences, labels rank = dist.get_rank(group=self.process_group) - # variable per-rank shapes require communicating shape first - local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long) - shapes = [torch.empty_like(local_shape) for _ in range(world_size)] - dist.all_gather(shapes, local_shape, group=self.process_group) - all_sentences = [sentences.new_empty(shape.tolist()) for shape in shapes] - dist.all_gather(all_sentences, sentences.contiguous(), group=self.process_group) - - local_label_shape = labels.new_tensor(labels.shape, dtype=torch.long) - label_shapes = [torch.empty_like(local_label_shape) for _ in range(world_size)] - dist.all_gather(label_shapes, local_label_shape, group=self.process_group) - all_labels = [labels.new_empty(shape.tolist()) for shape in label_shapes] - dist.all_gather(all_labels, labels.contiguous(), group=self.process_group) - - # keep the local shard differentiable; detach others - all_sentences[rank] = sentences + # ``labels`` is a 1-D mask aligned to ``sentences`` along dim-0, so they + # share the same per-rank size. Gather sizes once and reuse for both. 
+ assert sentences.shape[0] == labels.shape[0], ( + f'sentences/labels dim-0 mismatch: {sentences.shape[0]} vs {labels.shape[0]}') + local_n = torch.tensor([sentences.shape[0]], device=sentences.device, dtype=torch.long) + sizes = [torch.empty_like(local_n) for _ in range(world_size)] + dist.all_gather(sizes, local_n, group=self.process_group) + sizes_int = [int(s.item()) for s in sizes] + max_n = max(sizes_int) + + def _pad_gather(tensor: torch.Tensor): + if tensor.shape[0] < max_n: + pad_shape = (max_n - tensor.shape[0],) + tuple(tensor.shape[1:]) + padded = torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=0) + else: + padded = tensor + buffers = [torch.empty_like(padded) for _ in range(world_size)] + dist.all_gather(buffers, padded.contiguous(), group=self.process_group) + return buffers + + sent_buffers = _pad_gather(sentences) + label_buffers = _pad_gather(labels) + + # Strip padding; keep local shard differentiable, detach others. + all_sentences = [] + all_labels = [] for idx in range(world_size): - if idx != rank: - all_sentences[idx] = all_sentences[idx].detach() + n = sizes_int[idx] + if idx == rank: + all_sentences.append(sentences) + all_labels.append(labels) + else: + all_sentences.append(sent_buffers[idx][:n].detach()) + all_labels.append(label_buffers[idx][:n]) return torch.cat(all_sentences, dim=0), torch.cat(all_labels, dim=0) def __call__(self, inputs, outputs, **kwargs) -> LossOutput: From 18526e508eab5c9f62d9a747b942d7d218eb954f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 17:17:38 +0800 Subject: [PATCH 07/53] wip --- src/twinkle/metric/accuracy.py | 1 - src/twinkle/metric/embedding.py | 6 +++--- src/twinkle/metric/train_metric.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/twinkle/metric/accuracy.py b/src/twinkle/metric/accuracy.py index b3034c57..4dfb0119 100644 --- a/src/twinkle/metric/accuracy.py +++ b/src/twinkle/metric/accuracy.py @@ -1,5 +1,4 @@ # Copyright (c) 
ModelScope Contributors. All rights reserved. -import numpy as np from typing import List, Union from ..data_format import InputFeature, ModelOutput diff --git a/src/twinkle/metric/embedding.py b/src/twinkle/metric/embedding.py index 9fb3aed8..ec67ac02 100644 --- a/src/twinkle/metric/embedding.py +++ b/src/twinkle/metric/embedding.py @@ -1,7 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import torch -import torch.distributed as dist -import torch.nn.functional as F from typing import List, Union from twinkle.data_format import InputFeature, ModelOutput @@ -32,6 +29,9 @@ def reset(self): self.grad_norm = 0.0 def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs): + import torch + import torch.distributed as dist + import torch.nn.functional as F sentences = outputs.get('embeddings') if sentences is None: sentences = outputs.get('logits') diff --git a/src/twinkle/metric/train_metric.py b/src/twinkle/metric/train_metric.py index da82a878..8d785c38 100644 --- a/src/twinkle/metric/train_metric.py +++ b/src/twinkle/metric/train_metric.py @@ -2,7 +2,7 @@ import time from typing import List, Union -from ..data_format import InputFeature, ModelOutput +from twinkle.data_format import InputFeature, ModelOutput from .base import Metric From 34a3e9a39e2584552d258beca4f56057aa1be8da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 17:39:33 +0800 Subject: [PATCH 08/53] wip --- src/twinkle/metric/dpo.py | 6 +++++ src/twinkle/metric/embedding.py | 40 +++++++++++++++++++++------------ src/twinkle/metric/grpo.py | 20 ++++++++++++++++- 3 files changed, 51 insertions(+), 15 deletions(-) diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index b203d255..024cb047 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -131,6 +131,12 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M ref_outputs = 
kwargs.get('ref_outputs') if ref_outputs is not None: ref_logps = ref_outputs.get('logps') + if ref_logps is not None: + if isinstance(ref_logps, list): + if len(ref_logps) == 0: + ref_logps = None + else: + ref_logps = pad_and_stack_tensors(ref_logps) if ref_logps is not None: # Align ref_logps to match labels shape (handles different seq lengths) ref_logps = self._align_logps(ref_logps, labels.shape, labels.device, logps.dtype) diff --git a/src/twinkle/metric/embedding.py b/src/twinkle/metric/embedding.py index ec67ac02..8b368103 100644 --- a/src/twinkle/metric/embedding.py +++ b/src/twinkle/metric/embedding.py @@ -44,22 +44,34 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M inputs = [inputs] labels = torch.cat([inp['labels'].view(-1) for inp in inputs], dim=0) - # Gather embeddings and labels across DP for in-batch stats + # Gather embeddings and labels across DP for in-batch stats. + # NCCL ``all_gather`` requires every rank to send the same tensor size, + # but ``slice_dp`` dispatch (``divmod`` split) can leave per-rank dim-0 + # uneven. Pad to the global max along dim-0, gather, then strip padding. 
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: world_size = dist.get_world_size() - local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long) - shapes = [torch.empty_like(local_shape) for _ in range(world_size)] - dist.all_gather(shapes, local_shape) - all_sentences = [sentences.new_empty(s.tolist()) for s in shapes] - dist.all_gather(all_sentences, sentences.contiguous()) - sentences = torch.cat(all_sentences, dim=0) - - local_lshape = labels.new_tensor(labels.shape, dtype=torch.long) - lshapes = [torch.empty_like(local_lshape) for _ in range(world_size)] - dist.all_gather(lshapes, local_lshape) - all_labels = [labels.new_empty(s.tolist()) for s in lshapes] - dist.all_gather(all_labels, labels.contiguous()) - labels = torch.cat(all_labels, dim=0) + assert sentences.shape[0] == labels.shape[0], ( + f'sentences/labels dim-0 mismatch: {sentences.shape[0]} vs {labels.shape[0]}') + local_n = torch.tensor([sentences.shape[0]], device=sentences.device, dtype=torch.long) + sizes = [torch.empty_like(local_n) for _ in range(world_size)] + dist.all_gather(sizes, local_n) + sizes_int = [int(s.item()) for s in sizes] + max_n = max(sizes_int) + + def _pad_gather(tensor: 'torch.Tensor') -> 'List[torch.Tensor]': + if tensor.shape[0] < max_n: + pad_shape = (max_n - tensor.shape[0],) + tuple(tensor.shape[1:]) + padded = torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=0) + else: + padded = tensor + buffers = [torch.empty_like(padded) for _ in range(world_size)] + dist.all_gather(buffers, padded.contiguous()) + return buffers + + sent_buffers = _pad_gather(sentences) + label_buffers = _pad_gather(labels) + sentences = torch.cat([sent_buffers[i][:sizes_int[i]] for i in range(world_size)], dim=0) + labels = torch.cat([label_buffers[i][:sizes_int[i]] for i in range(world_size)], dim=0) anchor_idx = torch.nonzero(labels, as_tuple=False).squeeze(-1) if anchor_idx.numel() == 0: diff --git a/src/twinkle/metric/grpo.py 
b/src/twinkle/metric/grpo.py index 06e082ee..e2797b1e 100644 --- a/src/twinkle/metric/grpo.py +++ b/src/twinkle/metric/grpo.py @@ -3,9 +3,12 @@ from typing import Any, Dict, List, Optional, Union from twinkle.data_format import InputFeature, ModelOutput +from twinkle.utils import get_logger from twinkle.utils.transformers_utils import align_logps_to_mask from .base import Metric +logger = get_logger() + class GRPOMetric(Metric): @@ -254,6 +257,11 @@ def accumulate( if len(seq_lens) == 1: merged = torch.cat(label_tensors, dim=0) inputs_list = [{'labels': merged}] + else: + logger.warning( + f'GRPOMetric: logps is a single tensor but inputs_list has ' + f'{len(inputs_list)} mb with mismatched seq_lens={sorted(seq_lens)}. ' + f'Only mb[0] will be accumulated; check the model forward path.') flat_old: Optional[List] = None if old_logps is not None and isinstance(old_logps, (list, tuple)): @@ -284,7 +292,17 @@ def accumulate( # Uncommon: aligned global tensor. Only honour when it # exactly matches the single-mb shape; otherwise drop. import torch as _torch # noqa: F811 - old_slice = old_logps if (_torch.is_tensor(old_logps) and old_logps.shape == logps_mb.shape) else None + if _torch.is_tensor(old_logps) and old_logps.shape == logps_mb.shape: + old_slice = old_logps + else: + if mb_idx == 0: + # Warn once per accumulate call (not per mb) to avoid log spam. 
+ old_shape = tuple(old_logps.shape) if _torch.is_tensor(old_logps) else 'unknown' + logger.warning( + f'GRPOMetric: old_logps shape {old_shape} does not match ' + f'logps_mb shape {tuple(logps_mb.shape)}; ratio/kl metrics will ' + f'be skipped for this step.') + old_slice = None else: old_slice = None From 420d0da101124a625577c417052bd70a3ad6f823 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 18:36:09 +0800 Subject: [PATCH 09/53] wip --- src/twinkle/data_format/output.py | 2 ++ src/twinkle/model/megatron/megatron.py | 34 +++++++++++++------------- src/twinkle/processor/base.py | 1 - 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/twinkle/data_format/output.py b/src/twinkle/data_format/output.py index 763ef246..596252fb 100644 --- a/src/twinkle/data_format/output.py +++ b/src/twinkle/data_format/output.py @@ -20,11 +20,13 @@ class ModelOutput(TypedDict, total=False): loss: The loss calculated by the model. logps: The log-probabilities of correct tokens by the model. num_tokens: The token denominator associated with ``loss``. + embeddings: The embeddings output by the model, used by the embedding task.
""" logits: Optional[OutputType] loss: Optional[OutputType] logps: Optional[OutputType] num_tokens: Optional[OutputType] + embeddings: Optional[OutputType] class LossOutput(TypedDict, total=False): diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index a5ea3fc5..bd24bc02 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -420,7 +420,7 @@ def forward_step_func(data_iterator, model): embeddings = output_tensor elif labels is not None and is_last_pp: _loss_require_logps = getattr(_loss_instance, 'require_logps', True) - _loss_require_entropy = (hasattr(_loss_instance, 'require_entropy') and _loss_instance.require_entropy) + _loss_require_entropy = getattr(_loss_instance, 'require_entropy', True) _packed = batch.get('packed_seq_params') cu_seqlens_q = getattr(_packed, 'cu_seqlens_q', None) if _packed is not None else None if _loss_require_logps: @@ -446,7 +446,7 @@ def forward_step_func(data_iterator, model): _outputs = {'logps': logps} if entropies is not None: _outputs['entropies'] = entropies - if hasattr(_loss_instance, 'require_logits') and _loss_instance.require_logits: + if getattr(_loss_instance, 'require_logits', False): _outputs['logits'] = output_tensor batch, _outputs = processor.unpack_packed_sequences(batch, _outputs) logps = _outputs['logps'] @@ -1139,7 +1139,7 @@ def _load_mcore_optimizer( ) iteration = self._read_iteration(tracker_path) if iteration == 0: - logging.getLogger(__name__).warning(f'No checkpoint found in {checkpoint_dir}') + logger.warning(f'No checkpoint found in {checkpoint_dir}') return iter_dir = os.path.join(checkpoint_dir, f'iter_{iteration:07d}') @@ -1216,21 +1216,21 @@ def _load_mcore_optimizer( @staticmethod def _read_iteration(tracker_path: str) -> int: - if not os.path.exists(tracker_path): - return 0 - with open(tracker_path) as f: - iteration = int(f.read().strip()) + # All ranks must enter the all_reduce together; missing tracker on 
some + # ranks (e.g. NFS lag, partial mount) must NOT short-circuit, otherwise + # the remaining ranks hang at the collective. Treat missing as 0 and + # let MAX reduction recover the canonical iteration from any rank that + # successfully read the file. + iteration = 0 + if os.path.exists(tracker_path): + with open(tracker_path) as f: + iteration = int(f.read().strip()) if torch.distributed.is_initialized(): - iters_cuda = torch.tensor( - [iteration], - dtype=torch.long, - device='cuda', - ) - torch.distributed.all_reduce( - iters_cuda, - op=torch.distributed.ReduceOp.MAX, - ) - iteration = iters_cuda[0].item() + # Use Platform.get_local_device() to stay backend-agnostic + # (CUDA / NPU / MPS); 'cuda' would crash on NPU. + iters_dev = torch.tensor([iteration], dtype=torch.long, device=Platform.get_local_device()) + torch.distributed.all_reduce(iters_dev, op=torch.distributed.ReduceOp.MAX) + iteration = int(iters_dev[0].item()) return iteration def _merge_lora_adapters(self, adapter_name: str = 'default'): diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 8709d98a..d600ec2b 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -776,6 +776,5 @@ def postprocess_tensor_cp(self, tensor, cu_seqlens=None): if self.device_mesh.cp_world_size <= 1: return tensor from megatron.core import parallel_state as mpu - from twinkle.utils.torch_utils import gather_cp_load_balanced return gather_cp_load_balanced(tensor, mpu.get_context_parallel_group(), seq_dim=1, cu_seqlens=cu_seqlens) From 5f093e5aa4c7ec9cea21f3b6c72dad9b31fbbda2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 21:06:32 +0800 Subject: [PATCH 10/53] wip --- src/twinkle/infra/__init__.py | 2 +- src/twinkle/model/megatron/megatron.py | 12 +++++++---- .../model/megatron/multi_lora_megatron.py | 15 ++++++++----- src/twinkle/server/state/backend/factory.py | 5 ++--- src/twinkle/server/state/base.py | 4 ++-- 
src/twinkle/server/state/session_manager.py | 4 ++-- src/twinkle/server/telemetry/provider.py | 3 ++- src/twinkle/server/telemetry/worker_init.py | 5 +++-- src/twinkle/utils/platforms/base.py | 21 +++++++++++++++++++ src/twinkle/utils/platforms/gpu.py | 13 ++++++++++++ src/twinkle/utils/platforms/mps.py | 19 +++++++++++++++++ src/twinkle/utils/platforms/npu.py | 21 +++++++++++++++++++ 12 files changed, 104 insertions(+), 20 deletions(-) diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 23fb1fda..1227584f 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -687,7 +687,7 @@ def __next__(_self): return decorator -def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp'], Callable] = 'slice', +def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp', 'last_pp_first'], Callable] = 'slice', execute: Literal['first', 'peer', 'all'] = 'all', collect: Union[Literal['none', 'flatten', 'mean', 'sum', 'first', 'last_pp'], Callable] = 'none', sync: bool = False, diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index bd24bc02..c7174df9 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -990,7 +990,9 @@ def _get_rng_state() -> 'ShardedObject': 'random_rng_state': random.getstate(), 'np_rng_state': np.random.get_state(), 'torch_rng_state': torch.get_rng_state(), - 'cuda_rng_state': torch.cuda.get_rng_state(), + # Backend-agnostic device RNG (CUDA / NPU / MPS); key kept as + # 'cuda_rng_state' for backward compatibility with existing checkpoints. 
+ 'cuda_rng_state': Platform.get_device_rng_state(), 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states(), } rng_state_list = [rng_state] @@ -1112,7 +1114,7 @@ def _save_mcore_optimizer( with open(tracker_path, 'w') as f: f.write(str(iteration)) - logging.getLogger(__name__).info(f'Saved mcore optimizer state at iteration {iteration} ' + logger.info(f'Saved mcore optimizer state at iteration {iteration} ' f'to {checkpoint_dir}') def _load_mcore_optimizer( @@ -1201,7 +1203,9 @@ def _load_mcore_optimizer( random.setstate(rng['random_rng_state']) np.random.set_state(rng['np_rng_state']) torch.set_rng_state(rng['torch_rng_state']) - torch.cuda.set_rng_state(rng['cuda_rng_state']) + # Backend-agnostic restore: tolerates ckpt produced on different backend + # (returns None) and avoids hard-coded torch.cuda which crashes on NPU. + Platform.set_device_rng_state(rng.get('cuda_rng_state')) tensor_parallel.get_cuda_rng_tracker().set_states(rng['rng_tracker_states'], ) # Restore iteration counter. 
@@ -1211,7 +1215,7 @@ def _load_mcore_optimizer( if dist.is_initialized(): dist.barrier() - logging.getLogger(__name__).info(f'Resumed from mcore checkpoint at iteration {iteration} ' + logger.info(f'Resumed from mcore checkpoint at iteration {iteration} ' f'from {checkpoint_dir}') @staticmethod diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 2dd6b7a5..9a7e7784 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -15,7 +15,7 @@ from transformers import AutoConfig, PretrainedConfig from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union -from twinkle import DeviceMesh, remote_class, remote_function, requires, template, torch_util +from twinkle import DeviceMesh, Platform, remote_class, remote_function, requires, template, torch_util from twinkle.data_format import InputFeature, Trajectory from twinkle.hub import HubOperation from twinkle.infra import collect_tensor_dict @@ -221,8 +221,11 @@ def _save_local_training_rng_state(): 'np_rng_state': np.random.get_state(), 'torch_rng_state': torch.get_rng_state(), } - if torch.cuda.is_available(): - rng_state['cuda_rng_state'] = torch.cuda.get_rng_state() + # Backend-agnostic device RNG capture (CUDA / NPU / MPS). Key is kept as + # 'cuda_rng_state' for backward compatibility with existing checkpoints. 
+ device_rng = Platform.get_device_rng_state() + if device_rng is not None: + rng_state['cuda_rng_state'] = device_rng rng_state['rng_tracker_states'] = tensor_parallel.get_cuda_rng_tracker().get_states() return rng_state @@ -233,8 +236,10 @@ def _load_local_training_rng_state(rng_state): random.setstate(rng_state['random_rng_state']) np.random.set_state(rng_state['np_rng_state']) torch.set_rng_state(rng_state['torch_rng_state']) - if 'cuda_rng_state' in rng_state and torch.cuda.is_available(): - torch.cuda.set_rng_state(rng_state['cuda_rng_state']) + # Backend-agnostic device RNG restore: tolerates ckpt produced on different + # backend (key absent or None) and avoids hard-coded torch.cuda on NPU. + if 'cuda_rng_state' in rng_state: + Platform.set_device_rng_state(rng_state['cuda_rng_state']) tensor_parallel.get_cuda_rng_tracker().set_states(rng_state['rng_tracker_states']) def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kwargs): diff --git a/src/twinkle/server/state/backend/factory.py b/src/twinkle/server/state/backend/factory.py index 326a6916..be24c782 100644 --- a/src/twinkle/server/state/backend/factory.py +++ b/src/twinkle/server/state/backend/factory.py @@ -1,12 +1,11 @@ """Backend factory for creating StateBackend instances based on configuration.""" from __future__ import annotations -import logging - from twinkle.server.config.persistence import PersistenceConfig +from twinkle.utils import get_logger from .base import StateBackend -logger = logging.getLogger(__name__) +logger = get_logger() def create_backend(config: PersistenceConfig | None = None) -> StateBackend: diff --git a/src/twinkle/server/state/base.py b/src/twinkle/server/state/base.py index d931ce33..8cb055ae 100644 --- a/src/twinkle/server/state/base.py +++ b/src/twinkle/server/state/base.py @@ -1,7 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from __future__ import annotations -import logging import time from abc import ABC, abstractmethod from datetime import datetime, timezone @@ -9,9 +8,10 @@ from typing import Generic, TypeVar from twinkle.server.state.backend.base import StateBackend +from twinkle.utils import get_logger T = TypeVar('T', bound=BaseModel) -logger = logging.getLogger(__name__) +logger = get_logger() class BaseManager(ABC, Generic[T]): diff --git a/src/twinkle/server/state/session_manager.py b/src/twinkle/server/state/session_manager.py index 442019ea..1bc901b0 100644 --- a/src/twinkle/server/state/session_manager.py +++ b/src/twinkle/server/state/session_manager.py @@ -2,14 +2,14 @@ from __future__ import annotations import functools -import logging import time +from twinkle.utils import get_logger from .backend.base import ConcurrencyError, StateBackend from .base import BaseManager from .models import SessionRecord -logger = logging.getLogger(__name__) +logger = get_logger() def _session_touch_transform(existing: dict | None, *, now: float) -> dict | None: diff --git a/src/twinkle/server/telemetry/provider.py b/src/twinkle/server/telemetry/provider.py index 77212c75..059f301f 100644 --- a/src/twinkle/server/telemetry/provider.py +++ b/src/twinkle/server/telemetry/provider.py @@ -16,8 +16,9 @@ from typing import Any from twinkle.server.config.telemetry import TelemetryConfig +from twinkle.utils import get_logger -logger = logging.getLogger(__name__) +logger = get_logger() # Loggers belonging to the OTLP transport stack. 
Their own records must never # be routed back through the OTLP LoggingHandler: an exporter error logged diff --git a/src/twinkle/server/telemetry/worker_init.py b/src/twinkle/server/telemetry/worker_init.py index 997f2e14..40edc662 100644 --- a/src/twinkle/server/telemetry/worker_init.py +++ b/src/twinkle/server/telemetry/worker_init.py @@ -7,10 +7,11 @@ """ from __future__ import annotations -import logging import os -logger = logging.getLogger(__name__) +from twinkle.utils import get_logger + +logger = get_logger() _worker_initialized = False diff --git a/src/twinkle/utils/platforms/base.py b/src/twinkle/utils/platforms/base.py index 71c2bd18..483c725a 100644 --- a/src/twinkle/utils/platforms/base.py +++ b/src/twinkle/utils/platforms/base.py @@ -136,3 +136,24 @@ def device_backend(platform: str = None): def get_vllm_device_uuid(device_id: int = 0, platform=None) -> str: platform = Platform.get_platform(platform) return platform.get_vllm_device_uuid(device_id) + + @staticmethod + def get_device_rng_state(platform: str = None): + """Return device-specific RNG state (e.g. CUDA / NPU / MPS). + + Backend-agnostic replacement for hard-coded ``torch.cuda.get_rng_state()``. + Returns ``None`` when no accelerator is available, so callers can safely + skip persistence on CPU-only or unsupported devices. + """ + return Platform.get_platform(platform).get_device_rng_state() + + @staticmethod + def set_device_rng_state(state, *, platform: str = None) -> None: + """Restore device-specific RNG state. + + No-op when ``state`` is ``None`` (e.g. checkpoint produced on a different + backend) or when the current platform has no accelerator available. 
+ """ + if state is None: + return + Platform.get_platform(platform).set_device_rng_state(state) diff --git a/src/twinkle/utils/platforms/gpu.py b/src/twinkle/utils/platforms/gpu.py index 0b99f885..0f213448 100644 --- a/src/twinkle/utils/platforms/gpu.py +++ b/src/twinkle/utils/platforms/gpu.py @@ -24,3 +24,16 @@ def device_backend(platform: str = None): def get_vllm_device_uuid(device_id: int = 0) -> str: from vllm.platforms import current_platform return current_platform.get_device_uuid(device_id) + + @staticmethod + def get_device_rng_state(): + import torch + if torch.cuda.is_available(): + return torch.cuda.get_rng_state() + return None + + @staticmethod + def set_device_rng_state(state) -> None: + import torch + if torch.cuda.is_available(): + torch.cuda.set_rng_state(state) diff --git a/src/twinkle/utils/platforms/mps.py b/src/twinkle/utils/platforms/mps.py index e99abb0e..86ecf751 100644 --- a/src/twinkle/utils/platforms/mps.py +++ b/src/twinkle/utils/platforms/mps.py @@ -40,3 +40,22 @@ def device_backend(platform: str = None): @staticmethod def get_vllm_device_uuid(device_id: int = 0) -> str: raise NotImplementedError + + @staticmethod + def get_device_rng_state(): + import torch + if hasattr(torch, 'mps') and hasattr(torch.mps, 'get_rng_state'): + try: + return torch.mps.get_rng_state() + except Exception: # noqa: BLE001 + return None + return None + + @staticmethod + def set_device_rng_state(state) -> None: + import torch + if hasattr(torch, 'mps') and hasattr(torch.mps, 'set_rng_state'): + try: + torch.mps.set_rng_state(state) + except Exception: # noqa: BLE001 + pass diff --git a/src/twinkle/utils/platforms/npu.py b/src/twinkle/utils/platforms/npu.py index 89066b28..de15707f 100644 --- a/src/twinkle/utils/platforms/npu.py +++ b/src/twinkle/utils/platforms/npu.py @@ -133,3 +133,24 @@ def get_vllm_device_uuid(device_id: int = 0) -> str: visible = os.environ.get(Platform.visible_device_env()) raw = f'{socket.gethostname()}:{visible}:{device_id}' return 
hashlib.sha1(raw.encode('utf-8')).hexdigest()[:16] + + @staticmethod + def get_device_rng_state(): + import torch + try: + import torch_npu # noqa: F401 + except ImportError: + return None + if hasattr(torch, 'npu') and torch.npu.is_available(): + return torch.npu.get_rng_state() + return None + + @staticmethod + def set_device_rng_state(state) -> None: + import torch + try: + import torch_npu # noqa: F401 + except ImportError: + return + if hasattr(torch, 'npu') and torch.npu.is_available(): + torch.npu.set_rng_state(state) From 520f797c32677b4863981e3b8915569f594e1b63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 21:24:13 +0800 Subject: [PATCH 11/53] wip --- src/twinkle/model/megatron/megatron.py | 28 +++++++++++-------- .../model/megatron/multi_lora_megatron.py | 10 +++++++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index c7174df9..60b45f77 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -1260,7 +1260,7 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non For distributed training: - All PP ranks participate in export (each has different layers) - - Only DP rank 0 actually writes to disk + - Only global rank 0 actually writes shared config files - Uses barrier for synchronization For LoRA training: @@ -1268,12 +1268,9 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non """ # Check if this is LoRA training is_peft_format = (adapter_name != _default_adapter_name) + is_global_zero = (not dist.is_initialized()) or dist.get_rank() == 0 - # Create output directory on rank 0 only - from megatron.core import parallel_state as mpu - dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0 - - if dp_rank == 0: + if is_global_zero: os.makedirs(output_dir, exist_ok=True) # Synchronize before saving @@ -1285,8 
+1282,8 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non self.strategy.bridge.save_weights( model, output_dir, peft_format=is_peft_format, adapter_name=adapter_name, converter=lora_converter) - # Save config on rank 0 only - if dp_rank == 0: + # Save config on global rank 0 only (avoid concurrent writers). + if is_global_zero: self.hf_config.save_pretrained(output_dir) if isinstance(model[0], PeftModel): config = model[0].peft_config[adapter_name] @@ -1295,11 +1292,13 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non model[0].peft_config[adapter_name].save_pretrained(output_dir) config.target_modules = target_modules + if dist.is_initialized(): + dist.barrier() + def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_converter=None): """Save in Megatron checkpoint format.""" + is_global_zero = (not dist.is_initialized()) or dist.get_rank() == 0 os.makedirs(output_dir, exist_ok=True) - from megatron.core import parallel_state as mpu - dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0 state_dict = self._get_trainable_parameters(adapter_name) cpu_state_dict = {} for k, v in state_dict.items(): @@ -1315,13 +1314,18 @@ def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_convert rank = dist.get_rank() if dist.is_initialized() else 0 checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt') torch.save(cpu_state_dict, checkpoint_path) - # Save config on rank 0 only + # Save shared config on global rank 0 only (avoid concurrent writers). model = self.strategy.unwrap_model(self.model) - if dp_rank == 0: + if is_global_zero: self.hf_config.save_pretrained(output_dir) if isinstance(model[0], PeftModel): model[0].peft_config[adapter_name].save_pretrained(output_dir) + # Finalize barrier: ensure all ranks finish writing model_rank*.pt + # before the caller proceeds (e.g. uploading / loading the ckpt). 
+ if dist.is_initialized(): + dist.barrier() + def _save_tokenizer(self, output_dir: str, **kwargs): from twinkle.utils import is_last_rank if not is_last_rank(): diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 9a7e7784..ae2eb4b2 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -256,6 +256,9 @@ def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kw torch.save(state_dict, self._rank_local_optimizer_path(checkpoint_dir)) + if dist.is_initialized(): + dist.barrier() + def _load_multi_lora_optimizer(self, checkpoint_dir: str, adapter_name: str = '', **kwargs): no_load_optim = kwargs.pop('no_load_optim', False) no_load_rng = kwargs.pop('no_load_rng', False) @@ -265,6 +268,13 @@ def _load_multi_lora_optimizer(self, checkpoint_dir: str, adapter_name: str = '' if not no_load_optim and optimizer_config is not None: if optimizer_config.optimizer is not None and 'optimizer' in state_dict: optimizer_config.optimizer.load_state_dict(state_dict['optimizer']) + device = Platform.get_local_device() + for group_state in optimizer_config.optimizer.state.values(): + if not isinstance(group_state, dict): + continue + for k, v in group_state.items(): + if isinstance(v, torch.Tensor): + group_state[k] = v.to(device) if optimizer_config.lr_scheduler is not None and 'opt_param_scheduler' in state_dict: optimizer_config.lr_scheduler.load_state_dict(state_dict['opt_param_scheduler']) if not no_load_rng and 'rng_state' in state_dict: From a79fb3b4b62f368c3c2a5f2c4e99fd583465f02e Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 15 Jun 2026 22:00:24 +0800 Subject: [PATCH 12/53] fix --- .../exp/embedding/build_thinking_rag_index.py | 66 +++-- .../exp/embedding/train_embedding_full_ddp.py | 242 +++++++++--------- 2 files changed, 170 insertions(+), 138 deletions(-) diff --git 
a/cookbook/exp/embedding/build_thinking_rag_index.py b/cookbook/exp/embedding/build_thinking_rag_index.py index cf4d1dd7..0299aa48 100644 --- a/cookbook/exp/embedding/build_thinking_rag_index.py +++ b/cookbook/exp/embedding/build_thinking_rag_index.py @@ -2,9 +2,10 @@ Pipeline (per row, batched): 1. Load (user_query, reasoning_content) pairs from ``dataset_think.get_dataset``. - 2. Compress query with ``FIXED_QUERY_NEED`` and cot with ``FIXED_QUERY_SKILL`` - using a Twinkle ``vLLMSampler`` (TP=4 across GPUs 0-3). Reuses the prompt - suite from ``cookbook/exp/condenser/make_condenser_dataset.py``. + 2. Compress query with ``RAG_QUERY_HINT`` and cot with ``RAG_THINKING_HINT`` + (a symmetric Problem/Skill/Knowledge schema defined in this file) using a + Twinkle ``vLLMSampler`` (TP=4 across GPUs 0-3). Reuses the system/user + wrappers from ``cookbook/exp/condenser/make_condenser_dataset.py``. 3. On condenser truncation (``stop_reason='length'`` or skeleton-incomplete output), fall back to an external OpenAI-compatible API. 4. Encode the condensed pair via the trained embedding model — Twinkle @@ -55,8 +56,6 @@ from make_condenser_dataset import ( # noqa: E402 (after sys.path tweak) COMPRESS_SYSTEM, COMPRESS_USER, - FIXED_QUERY_NEED, - FIXED_QUERY_SKILL, ) # Default dataset loader is the index-time corpus (broader retrieval profile); # pass --dataset-module dataset_think to fall back to the training mix. 
@@ -84,7 +83,7 @@ EMBED_MODEL_ID = os.environ.get( 'EMBED_MODEL_ID', - 'ms://twinkle-kit/Qwen3.5-4B-QA-emb', + '/mnt/workspace/yzhao/tastelikefeet/Qwen3.5-4B-QA-emb-v2', ) CONDENSE_MODEL_ID = os.environ.get('CONDENSE_MODEL_ID', 'ms://twinkle-kit/Qwen3.5-4B-CM-v2') @@ -106,6 +105,31 @@ SIM_THRESHOLD = float(os.environ.get('SIM_THRESHOLD', 0.65)) MIN_TEXT_CHARS = int(os.environ.get('MIN_TEXT_CHARS', 256)) +# Hard-templated hints: the condenser SFT prior maps `Skill` to the legacy +# `Use when: / numbered steps / Output:` skeleton on long inputs; embedding the +# exact 4-line body template + explicit negative constraints is the only way to +# override it deterministically across query and cot sides. +RAG_QUERY_HINT = ( + 'Summarize this query for retrieval. ' + 'The body of ## Summary MUST follow this EXACT 4-line template — ' + 'do NOT emit "Use when:", numbered procedure steps, or "Output:":\n' + 'Topic: \n' + 'Problem: \n' + 'Skill: \n' + 'Knowledge: \n' + 'Then emit the mandatory ## More section as usual. ' + 'Topic must name the specific pattern, never generic labels.') +RAG_THINKING_HINT = ( + 'Summarize this reasoning trace for retrieval. ' + 'The body of ## Summary MUST follow this EXACT 4-line template — ' + 'do NOT emit "Use when:", numbered procedure steps, or "Output:":\n' + 'Topic: \n' + 'Problem: \n' + 'Skill: \n' + 'Knowledge: \n' + 'Then emit the mandatory ## More section as usual. ' + 'Topic must name the specific pattern, never generic labels.') + # OpenAI API fallback (used when vLLM truncates). COMPRESS_API_KEY = os.environ.get('COMPRESS_API_KEY', '') COMPRESS_BASE_URL = os.environ.get( @@ -128,12 +152,19 @@ # Small helpers # =========================================================================== -def _is_truncated_compression(text: str) -> bool: - """Detect structurally incomplete condenser output even when stop='stop'. 
+_LEGACY_USE_WHEN_RE = re.compile(r'(?im)^\s*Use when\s*:') +_SCHEMA_MARKERS = ('Problem:', 'Skill:', 'Knowledge:') + - The skeleton mandates a ``## Summary`` plus a ``## More`` bullet list whose - last line begins with ``-`` or ends with ``)`` (the ``(none)`` sentinel). - Anything short of that signals truncation; the API fallback is invoked. +def _is_truncated_compression(text: str) -> bool: + """Reject structurally incomplete OR schema-regressed condenser output. + + Triggers API fallback when the vLLM output: + * lacks ``## Summary`` / ``## More``, + * has an empty or unterminated ``## More`` bullet list, or + * regresses to the legacy ``Use when: / numbered-steps / Output:`` skeleton + instead of the mandated Problem/Skill/Knowledge 4-line body — the + dominant cot-side failure mode that drives sim < 0.45 drops. """ if not text or not text.strip(): return True @@ -145,6 +176,11 @@ def _is_truncated_compression(text: str) -> bool: last_line = after_more.splitlines()[-1].strip() if not (last_line.startswith('-') or last_line.endswith(')')): return True + summary_body = text.split('## Summary', 1)[1].split('## More', 1)[0] + if _LEGACY_USE_WHEN_RE.search(summary_body): + return True + if not all(marker in summary_body for marker in _SCHEMA_MARKERS): + return True return False @@ -494,11 +530,11 @@ def _flush(rows: List[Dict[str, Any]]) -> None: nonlocal n_kept, n_dropped_compress, n_dropped_sim if not rows: return - # Phase 1 — compress query (FIXED_QUERY_NEED) and cot (FIXED_QUERY_SKILL). + # Phase 1 — compress query (RAG_QUERY_HINT) and cot (RAG_THINKING_HINT). 
q_compressed = _resolve_compressed( - sampler, api, [r['query_raw'] for r in rows], FIXED_QUERY_NEED) + sampler, api, [r['query_raw'] for r in rows], RAG_QUERY_HINT) c_compressed = _resolve_compressed( - sampler, api, [r['cot_raw'] for r in rows], FIXED_QUERY_SKILL) + sampler, api, [r['cot_raw'] for r in rows], RAG_THINKING_HINT) kept_rows: List[Dict[str, Any]] = [] for r, q_cmp, c_cmp in zip(rows, q_compressed, c_compressed): if not q_cmp or not c_cmp: @@ -655,7 +691,7 @@ def _flush(rows: List[Dict[str, Any]]) -> None: if not rows: return compressed = _resolve_compressed( - sampler, api, [r['query_raw'] for r in rows], FIXED_QUERY_NEED) + sampler, api, [r['query_raw'] for r in rows], RAG_QUERY_HINT) useful = [(r, c) for r, c in zip(rows, compressed) if c] if not useful: return diff --git a/cookbook/exp/embedding/train_embedding_full_ddp.py b/cookbook/exp/embedding/train_embedding_full_ddp.py index bed1dbe9..9ac925b6 100644 --- a/cookbook/exp/embedding/train_embedding_full_ddp.py +++ b/cookbook/exp/embedding/train_embedding_full_ddp.py @@ -1,14 +1,12 @@ -"""LoRA embedding training with online condenser self-improvement. +"""LoRA embedding training with online compression via frozen vLLM condenser. Architecture (8 GPUs total): - Ranks 0-3 (``model``): Trainable embedding model with LoRA, InfoNCE loss. - - Ranks 4-5 (``condenser_sampler``): Frozen vLLM condenser for online compression. - - Ranks 6-7 (``condenser_model``): Trainable condenser with LoRA for self-improvement. + - Ranks 4-7 (``condenser_sampler``): Frozen vLLM condenser for online compression. -When the condenser sampler truncates (stop_reason='length'), an external OpenAI- -compatible API produces the correct compression. The failure is logged as SFT -training data. A background thread retrains the condenser on accumulated failures -mixed with condense_300K, then syncs weights back to the sampler. 
+When the condenser sampler truncates or regresses to the legacy schema, an +external OpenAI-compatible API produces the correct compression. The failure is +logged to failures.jsonl for offline SFT data regeneration. Launch: python cookbook/exp/train_embedding_lora_ddp.py @@ -19,6 +17,8 @@ import re import sys import threading +import time +from collections import deque from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, Dict, List, Literal, Optional @@ -27,7 +27,6 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger -from twinkle.checkpoint_engine import CheckpointEngineManager from twinkle.data_format import SamplingParams from twinkle.dataloader import DataLoader from twinkle.loss import InfonceLoss @@ -55,14 +54,14 @@ # -- GPU placement (8 total) -------------------------------------------------- MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -CONDENSER_SAMPLER_GPUS = int(os.environ.get('CONDENSER_SAMPLER_GPUS', 2)) -CONDENSER_MODEL_GPUS = int(os.environ.get('CONDENSER_MODEL_GPUS', 2)) -NUM_GPUS = MODEL_GPUS + CONDENSER_SAMPLER_GPUS + CONDENSER_MODEL_GPUS +CONDENSER_SAMPLER_GPUS = int(os.environ.get('CONDENSER_SAMPLER_GPUS', 4)) +NUM_GPUS = MODEL_GPUS + CONDENSER_SAMPLER_GPUS # -- Embedding training hyper-params ------------------------------------------ EMB_MAX_LENGTH = 8192 HARD_NEGATIVES = None -TEMPERATURE = 0.03 +# 0.07 keeps gradient on diag pairs until cosine clears ~0.75; 0.03 saturated near 0.40. +TEMPERATURE = 0.07 BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 32)) LEARNING_RATE = 1.5e-6 @@ -92,16 +91,15 @@ COMPRESS_TOP_P = 0.5 COMPRESS_MAX_MODEL_LEN = 32768 +# Prefetch depth: >1 lets next batch's vLLM/API run while current batch trains. 
+PREFETCH_WORKERS = int(os.environ.get('PREFETCH_WORKERS', 2)) + # -- OpenAI API fallback for truncated compressions --------------------------- COMPRESS_API_KEY = os.environ.get('COMPRESS_API_KEY', '') COMPRESS_BASE_URL = os.environ.get('COMPRESS_BASE_URL', 'https://dashscope.aliyuncs.com/compatible-mode/v1') COMPRESS_MODEL = os.environ.get('COMPRESS_MODEL', 'qwen3.7-max') - -# -- Condenser retraining knobs ----------------------------------------------- -CONDENSER_DATASET_ID = 'ms://twinkle-kit/condense_300K' -CONDENSER_RETRAIN_SAMPLES = 128 -CONDENSER_RETRAIN_EPOCHS = 3 -CONDENSER_RETRAIN_LR = 1e-5 +# Minimum gap between API calls (seconds); bounds dashscope qps under provider limits. +API_MIN_INTERVAL = float(os.environ.get('API_MIN_INTERVAL', 0.1)) # -- Output paths ------------------------------------------------------------- OUTPUT_DIR = f'./output/embedding_lora_{BACKEND}' @@ -209,6 +207,20 @@ _sample_counter = 0 _sample_counter_lock = threading.Lock() +# Serializes vLLM sample() across PREFETCH_WORKERS threads. +_sampler_lock = threading.Lock() + +_api_throttle_lock = threading.Lock() +_api_last_call = [0.0] + + +def _api_throttle(): + with _api_throttle_lock: + gap = time.monotonic() - _api_last_call[0] + if gap < API_MIN_INTERVAL: + time.sleep(API_MIN_INTERVAL - gap) + _api_last_call[0] = time.monotonic() + def _next_sample_id() -> int: global _sample_counter @@ -319,11 +331,40 @@ def save_checkpoint(model, name: str): # Compression prompt building # ============================================================================= +# Hard-templated hints: the condenser SFT prior maps `Skill` to the legacy +# `Use when: / numbered steps / Output:` skeleton on long inputs; embedding the +# exact 4-line body template + explicit negative constraints is the only way to +# override it deterministically across query and cot sides. EMBED_QUERY_Q = ( + 'Summarize this query for retrieval. 
' + 'The body of ## Summary MUST follow this EXACT 4-line template — ' + 'do NOT emit "Use when:", numbered procedure steps, or "Output:":\n' + 'Topic: \n' + 'Problem: \n' + 'Skill: \n' + 'Knowledge: \n' + 'Then emit the mandatory ## More section as usual. ' + 'Topic must name the specific pattern, never generic labels.') +EMBED_QUERY_COT = ( + 'Summarize this reasoning trace for retrieval. ' + 'The body of ## Summary MUST follow this EXACT 4-line template — ' + 'do NOT emit "Use when:", numbered procedure steps, or "Output:":\n' + 'Topic: \n' + 'Problem: \n' + 'Skill: \n' + 'Knowledge: \n' + 'Then emit the mandatory ## More section as usual. ' + 'Topic must name the specific pattern, never generic labels.') + +# Legacy schema (Use when: / numbered steps / Output:) — mixed in 50/50 with the +# new schema to expose the embedder to schema-invariant semantic alignment. +# Both query and cot of the SAME pair always use the SAME schema; cross-schema +# anchors and positives would re-introduce the schema asymmetry we just fixed. +EMBED_QUERY_Q_LEGACY = ( 'What problem does this passage address, and what skill or method is needed? ' 'Topic must name the specific pattern, never generic labels. ' 'Compress into a retrieval-friendly need description.') -EMBED_QUERY_COT = ( +EMBED_QUERY_COT_LEGACY = ( 'Extract the reusable skill: trigger conditions, key steps, and expected output. ' 'Topic names the method/pattern; format as "Use when: ...", numbered steps, ' '"Output: ...". Compress into a standardized procedure for retrieval.') @@ -347,22 +388,25 @@ def _extract_query_cot(row: Dict[str, Any]): def _build_compress_prompts(rows: List[Dict[str, Any]]) -> tuple: """Build prompts for compressing both query and cot per row. 
- Returns (prompts, valid_indices, raw_pairs, prompt_queries, passthrough) where: + Returns (prompts, valid_indices, raw_pairs, prompt_queries, passthrough, schemas) + where: - prompts: flat-interleaved [query_0, cot_0, query_1, cot_1, ...]; ``None`` means passthrough (use raw text directly, do not call sampler) - valid_indices: which rows passed the min-length filter - raw_pairs: [(query, cot), ...] - prompt_queries: the query string used for each prompt (for failure logging) - passthrough: parallel to prompts; non-None text means "use this verbatim as qc" + - schemas: parallel to prompts; 'new' or 'legacy', drives validator branch """ prompts: List[Optional[Dict[str, Any]]] = [] valid_indices: List[int] = [] raw_pairs: List[tuple] = [] prompt_queries: List[str] = [] passthrough: List[Optional[str]] = [] + schemas: List[str] = [] # Conservative char budget: 32768 max_length - 8192 gen - ~2k prompt overhead = ~22k tokens. - # At worst-case 1.5 chars/token (CJK), that's ~33k chars. Use 60k as safe English-mix cap. - _MAX_COT_CHARS = 60_000 + # 30k cap bounds vLLM batch latency (vLLM batches by max prompt length). + _MAX_COT_CHARS = 30_000 for i, row in enumerate(rows): query, cot = _extract_query_cot(row) if not query or len(cot) < MIN_TEXT_CHARS: @@ -371,26 +415,32 @@ def _build_compress_prompts(rows: List[Dict[str, Any]]) -> tuple: continue valid_indices.append(i) raw_pairs.append((query, cot)) + # 50/50 schema mix; same schema for query+cot of one pair to keep alignment. + schema = 'legacy' if (i % 2 == 0) else 'new' + q_hint = EMBED_QUERY_Q_LEGACY if schema == 'legacy' else EMBED_QUERY_Q + c_hint = EMBED_QUERY_COT_LEGACY if schema == 'legacy' else EMBED_QUERY_COT # Short query bypasses condenser to avoid skeleton-induced hallucination. 
if len(query) < MIN_TEXT_CHARS: prompts.append(None) passthrough.append(query) else: - user = COMPRESS_USER.format(query=EMBED_QUERY_Q, text=query) + user = COMPRESS_USER.format(query=q_hint, text=query) prompts.append({'messages': [ {'role': 'system', 'content': COMPRESS_SYSTEM}, {'role': 'user', 'content': user}, ]}) passthrough.append(None) - prompt_queries.append(EMBED_QUERY_Q) - user = COMPRESS_USER.format(query=EMBED_QUERY_COT, text=cot) + prompt_queries.append(q_hint) + schemas.append(schema) + user = COMPRESS_USER.format(query=c_hint, text=cot) prompts.append({'messages': [ {'role': 'system', 'content': COMPRESS_SYSTEM}, {'role': 'user', 'content': user}, ]}) - prompt_queries.append(EMBED_QUERY_COT) + prompt_queries.append(c_hint) passthrough.append(None) - return prompts, valid_indices, raw_pairs, prompt_queries, passthrough + schemas.append(schema) + return prompts, valid_indices, raw_pairs, prompt_queries, passthrough, schemas def _get_first_feature(decoded_text: str, template: Template, role: str) -> Optional[Dict[str, Any]]: @@ -415,13 +465,24 @@ def _get_first_feature(decoded_text: str, template: Template, role: str) -> Opti # OpenAI API fallback # ============================================================================= -def _is_truncated_compression(text: str) -> bool: - """Detect structurally incomplete output that vLLM may report as stop_reason='stop'. +_LEGACY_USE_WHEN_RE = re.compile(r'(?im)^\s*Use when\s*:') +_SCHEMA_MARKERS = ('Problem:', 'Skill:', 'Knowledge:') + + +def _is_truncated_compression(text: str, schema: str = 'new') -> bool: + """Reject structurally incomplete OR schema-regressed condenser output. 
+ + Triggers API fallback when the vLLM output: + * lacks ``## Summary`` / ``## More``, + * has an empty or unterminated ``## More`` bullet list, or + * (schema='new' only) regresses to the legacy ``Use when: / numbered-steps / + Output:`` skeleton instead of the mandated Problem/Skill/Knowledge 4-line + body — the dominant cot-side failure mode that drives sim < 0.45 drops on + the RAG index. - The condenser sometimes emits a chat-template token mid-skeleton (which we then - strip), so the visible text ends mid-sentence even though stop_reason!='length'. - The COMPRESS_SYSTEM skeleton mandates a `## More` section ending in a bullet list; - its absence is an unambiguous truncation signal. + For schema='legacy', body markers are intentionally NOT enforced: the legacy + template legitimately emits ``Use when:`` and the SFT prior already produces + that shape natively, so only structural completeness is checked. """ if not text or not text.strip(): return True @@ -433,11 +494,18 @@ def _is_truncated_compression(text: str) -> bool: last_line = after_more.splitlines()[-1].strip() if not (last_line.startswith('-') or last_line.endswith(')')): return True + if schema == 'new': + summary_body = text.split('## Summary', 1)[1].split('## More', 1)[0] + if _LEGACY_USE_WHEN_RE.search(summary_body): + return True + if not all(marker in summary_body for marker in _SCHEMA_MARKERS): + return True return False def _api_compress(api_client: OpenAIClient, prompt: Dict[str, Any]) -> Optional[str]: """Call external API to compress when vLLM truncates.""" + _api_throttle() trajectory = {'messages': prompt['messages']} # Cap max_tokens to leave ample prompt headroom inside the API model context. 
sp = SamplingParams(temperature=0.2, max_tokens=8192) @@ -456,61 +524,12 @@ def _api_compress(api_client: OpenAIClient, prompt: Dict[str, Any]) -> Optional[ return content -# ============================================================================= -# Condenser Retrainer (background thread) -# ============================================================================= - -class CondenserRetrainer: - """Async condenser self-improvement: retrains from failures, syncs to sampler.""" - - def __init__(self, condenser_model, ckpt_manager: CheckpointEngineManager, - condenser_sampler): - self._model = condenser_model - self._ckpt_manager = ckpt_manager - self._sampler = condenser_sampler - self._signal = threading.Event() - self._stop = threading.Event() - self._thread = threading.Thread(target=self._loop, daemon=True) - self._condense_300k_cache = None - self._retrain_count = 0 - # Prevents sample() and sync_weights() from running concurrently - self.sampler_lock = threading.Lock() - - def start(self): - self._thread.start() - - def stop(self): - self._stop.set() - self._signal.set() - self._thread.join(timeout=10) - - def notify_failure(self): - self._signal.set() - - def _loop(self): - while not self._stop.is_set(): - self._signal.wait(timeout=60) - if self._stop.is_set(): - break - if not self._signal.is_set(): - continue - self._signal.clear() - try: - self._retrain_and_sync() - except Exception as exc: - logger.error(f'[condenser_retrain] crashed: {exc}') - - def _retrain_and_sync(self): - # Retrain + sync temporarily disabled; failures.jsonl is written directly by _log_failure. 
- pass - - # ============================================================================= # Main training # ============================================================================= def train(): - # -------- Device groups (3 groups) ---------------------------------------- + # -------- Device groups (2 groups) ---------------------------------------- device_groups = [ DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), @@ -518,15 +537,10 @@ def train(): DeviceGroup(name='condenser_sampler', ranks=list(range(MODEL_GPUS, MODEL_GPUS + CONDENSER_SAMPLER_GPUS)), device_type='GPU'), - DeviceGroup(name='condenser_model', - ranks=list(range(MODEL_GPUS + CONDENSER_SAMPLER_GPUS, NUM_GPUS)), - device_type='GPU'), ] model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) condenser_sampler_mesh = DeviceMesh.from_sizes( world_size=CONDENSER_SAMPLER_GPUS, dp_size=CONDENSER_SAMPLER_GPUS) - condenser_model_mesh = DeviceMesh.from_sizes( - world_size=CONDENSER_MODEL_GPUS, dp_size=1, fsdp_size=CONDENSER_MODEL_GPUS) twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups) @@ -557,7 +571,7 @@ def train(): setup_optimizer(model, optimizer_steps) model.add_metric(EmbeddingMetric, is_training=True) - # -------- Condenser sampler (2 GPU, vLLM) -------------------------------- + # -------- Condenser sampler (4 GPU, vLLM) -------------------------------- emb_template = Template(model_id=MODEL_ID, max_length=EMB_MAX_LENGTH, enable_thinking=False) # Special tokens come from the condenser tokenizer because the leak we strip is in its decoded output. 
condenser_template = Template(model_id=CONDENSE_MODEL_ID, max_length=DATASET_MAX_TOKENS, @@ -582,24 +596,6 @@ def train(): num_samples=1, ) - # -------- Condenser model (2 GPU, trainable full-param) ------------------- - condenser_model = TransformersModel( - model_id=CONDENSE_MODEL_ID, - device_mesh=condenser_model_mesh, - remote_group='condenser_model', - ) - condenser_model.set_optimizer(optimizer_cls='AdamW', lr=CONDENSER_RETRAIN_LR) - - # -------- CheckpointEngineManager: condenser_model → condenser_sampler --- - condenser_ckpt_manager = CheckpointEngineManager( - model=condenser_model, sampler=condenser_sampler) - condenser_ckpt_manager.sync_weights() - - # -------- Background retrainer ------------------------------------------- - retrainer = CondenserRetrainer(condenser_model, condenser_ckpt_manager, - condenser_sampler) - retrainer.start() - # -------- OpenAI API client for fallback --------------------------------- api_client = OpenAIClient( model=COMPRESS_MODEL, @@ -629,7 +625,7 @@ def train(): # -------- Train loop ----------------------------------------------------- def _sample_batch(raw_batch): """Compress via vLLM sampler; fall back to API on truncation.""" - compress_prompts, valid_indices, raw_pairs, prompt_queries, passthrough = \ + compress_prompts, valid_indices, raw_pairs, prompt_queries, passthrough, schemas = \ _build_compress_prompts(raw_batch) if len(compress_prompts) < 4: return None @@ -639,7 +635,7 @@ def _sample_batch(raw_batch): sampler_pos = [ri for ri, p in enumerate(compress_prompts) if p is not None] if sampler_input: try: - with retrainer.sampler_lock: + with _sampler_lock: sampler_responses = condenser_sampler.sample(sampler_input, compress_params) except Exception as exc: logger.warning(f'[sampler] encode overflow in batch, falling back to API: {exc}') @@ -668,7 +664,7 @@ def _sample_batch(raw_batch): # Premature-EOS: model emits chat-template token mid-skeleton, vLLM reports # stop_reason='stop' but the stripped text is 
structurally incomplete. needs_fallback = (not seq or seq.stop_reason == 'length' - or _is_truncated_compression(text)) + or _is_truncated_compression(text, schemas[ri])) if not needs_fallback: decoded_texts.append(text) continue @@ -676,14 +672,13 @@ def _sample_batch(raw_batch): api_result = _api_compress(api_client, compress_prompts[ri]) # Skip logging when the API itself produced truncated output: an incomplete # gold answer would teach the condenser to imitate broken outputs. - if api_result and not _is_truncated_compression(api_result): + if api_result and not _is_truncated_compression(api_result, schemas[ri]): decoded_texts.append(api_result) pair_idx = ri // 2 q_raw, c_raw = raw_pairs[pair_idx] source_text = q_raw if ri % 2 == 0 else c_raw _log_failure(source_text, prompt_queries[ri], api_result, valid_indices[pair_idx]) - retrainer.notify_failure() else: decoded_texts.append('') @@ -711,7 +706,7 @@ def _sample_batch(raw_batch): _start_epoch = cur_step // _batches_per_epoch if cur_step > 0 else 0 _skip_batches_in_epoch = cur_step - _start_epoch * _batches_per_epoch if cur_step > 0 else 0 - prefetch_executor = ThreadPoolExecutor(max_workers=1) + prefetch_executor = ThreadPoolExecutor(max_workers=PREFETCH_WORKERS) for epoch in range(_start_epoch, NUM_EPOCHS): # Skip consumed samples for the resume epoch (shuffle order won't match # exactly, but the correct number of samples is skipped). 
@@ -720,14 +715,16 @@ def _sample_batch(raw_batch): batch_iter = iter(dataloader) # Reset skip after first resumed epoch _skip_batches_in_epoch = 0 - prefetch_future = None - first_batch = next(batch_iter, None) - if first_batch is not None: - prefetch_future = prefetch_executor.submit(_sample_batch, first_batch) + prefetch_pool: deque = deque() + for _ in range(PREFETCH_WORKERS): + nb = next(batch_iter, None) + if nb is None: + break + prefetch_pool.append(prefetch_executor.submit(_sample_batch, nb)) for raw_batch in batch_iter: - emb_features = prefetch_future.result() if prefetch_future else None - prefetch_future = prefetch_executor.submit(_sample_batch, raw_batch) + emb_features = prefetch_pool.popleft().result() if prefetch_pool else None + prefetch_pool.append(prefetch_executor.submit(_sample_batch, raw_batch)) if emb_features is None: continue @@ -753,16 +750,15 @@ def _sample_batch(raw_batch): if cur_step % SAVE_INTERVAL == 0: save_checkpoint(model, f'step_{cur_step}') - # # Drain last prefetched batch - # if prefetch_future is not None: - # emb_features = prefetch_future.result() + # # Drain remaining prefetched batches + # while prefetch_pool: + # emb_features = prefetch_pool.popleft().result() # if emb_features is not None: # model.forward_backward(inputs=emb_features, task='embedding') # model.clip_grad_and_step() # cur_step += 1 prefetch_executor.shutdown(wait=False) - retrainer.stop() save_checkpoint(model, 'last-checkpoint') From df8855691bf706d517642ffc898510f530719769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 22:02:48 +0800 Subject: [PATCH 13/53] wip --- src/twinkle/model/megatron/strategy/megatron.py | 1 - src/twinkle/model/transformers/strategy/accelerate.py | 9 +++------ src/twinkle/model/transformers/transformers.py | 9 ++++----- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index 
819014eb..1bd80902 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -48,7 +48,6 @@ def __init__( ddp_config: Dict[str, Any] = None, **kwargs, ): - import torch.distributed as dist from megatron.core import mpu self.device_mesh = device_mesh self.use_distributed_optimizer = use_distributed_optimizer diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index d0434991..018f4c49 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -34,10 +34,8 @@ def __init__( parallelism_config = self._parallelism_config_from_device_mesh(device_mesh) fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient_init) - kwargs_handlers = [] - kwargs_handlers.append( - InitProcessGroupKwargs( - timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))))) + kwargs_handlers = [InitProcessGroupKwargs( + timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))))] if ddp_config is not None: from accelerate import DistributedDataParallelKwargs ddp_config = DistributedDataParallelKwargs(**ddp_config) @@ -131,8 +129,7 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di return fsdp_plugin def wrap_model(self, model, *args): - result = self.accelerator.prepare(model, *args) - return result + return self.accelerator.prepare(model, *args) def unwrap_model(self, model): return self.accelerator.unwrap_model(model, keep_torch_compile=False) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 61733d7d..cac37526 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -414,8 +414,8 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], 
List[Trajec inputs = optimizer_config.template.batch_encode(inputs) # noqa processor: InputProcessor = optimizer_config.processor loss_instance = optimizer_config.loss_instance - loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits) - loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy) + loss_require_logits = getattr(loss_instance, 'require_logits', False) + loss_require_entropy = getattr(loss_instance, 'require_entropy', False) loss_require_logps = getattr(loss_instance, 'require_logps', True) assert isinstance(processor, InputProcessor), 'Set a correct `InputProcessor` before forwarding' inputs: Dict[str, Any] = processor( @@ -490,8 +490,8 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T processor: InputProcessor = optimizer_config.processor assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding' loss_instance = optimizer_config.loss_instance - loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits) - loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy) + loss_require_logits = getattr(loss_instance, 'require_logits', False) + loss_require_entropy = getattr(loss_instance, 'require_entropy', False) loss_require_logps = getattr(loss_instance, 'require_logps', True) inputs: Dict[str, Any] = processor( inputs, @@ -929,7 +929,6 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int if optimizer_config.cur_step % interval != 0: return model = self.strategy.unwrap_model(self.model) - processed_state_dict = {} save_kwargs = {} if adapter_name == _default_adapter_name: # Full model save From 28ee773e5e3bc0b0c9427a580731ead33c02ad43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Tue, 16 Jun 2026 14:20:17 +0800 Subject: [PATCH 14/53] wip --- cookbook/transformers/fsdp2.py | 
12 +++++- src/twinkle/model/base.py | 3 +- src/twinkle/notifier/__init__.py | 1 + src/twinkle/notifier/base.py | 1 + src/twinkle/notifier/ding_notifier.py | 1 + src/twinkle/processor/base.py | 59 ++++++++++++++++++--------- 6 files changed, 55 insertions(+), 22 deletions(-) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index ad4c917f..ecad0220 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -2,6 +2,7 @@ from peft import LoraConfig from tqdm import tqdm +from torch.optim import Muon # PyTorch 2.9+; matrix-orthogonalized momentum optimizer. import twinkle from twinkle import DeviceMesh, get_device_placement, get_logger @@ -64,7 +65,16 @@ def train(): model.add_adapter_to_model( args.lora.adapter_name, lora_config, gradient_accumulation_steps=args.training.gradient_accumulation_steps) - model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate) + # Muon optimizes 2D hidden-layer weight matrices via Newton-Schulz orthogonalization. + # In LoRA training the trainable params are exclusively lora_A / lora_B (both 2D), + # so Muon applies cleanly without an AdamW fallback for 1D params. + # ``adjust_lr_fn='match_rms_adamw'`` rescales the orthogonalized update so the same + # lr / weight_decay tuned for AdamW can be reused directly (Moonshot Muon recipe). + model.set_optimizer( + optimizer_cls=Muon, + lr=args.optimizer.learning_rate, + adjust_lr_fn='match_rms_adamw', + ) # Add LRScheduler for lora `default` model.set_lr_scheduler( diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index a4d4ea06..8ea00d69 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -1,8 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
import os from abc import ABC, abstractmethod -from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union from twinkle import Platform, torch_util from twinkle.data_format import InputFeature, ModelOutput diff --git a/src/twinkle/notifier/__init__.py b/src/twinkle/notifier/__init__.py index 329cb6f1..067db71a 100644 --- a/src/twinkle/notifier/__init__.py +++ b/src/twinkle/notifier/__init__.py @@ -1,2 +1,3 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. from .base import Notifier, notify_exception from .ding_notifier import DingNotifier diff --git a/src/twinkle/notifier/base.py b/src/twinkle/notifier/base.py index 6c138c99..6f50ca65 100644 --- a/src/twinkle/notifier/base.py +++ b/src/twinkle/notifier/base.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import os from typing import Dict, Optional diff --git a/src/twinkle/notifier/ding_notifier.py b/src/twinkle/notifier/ding_notifier.py index fe102d8a..fc535edd 100644 --- a/src/twinkle/notifier/ding_notifier.py +++ b/src/twinkle/notifier/ding_notifier.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import base64 import hashlib import hmac diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index d600ec2b..c505d91c 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -42,7 +42,6 @@ class InputProcessor: 'video_grid_thw': 0, 'input_features': 0.0, 'feature_attention_mask': 0, - 'mm_token_type_ids': 0, } # VLM fields to concatenate (not pad) in batch @@ -108,8 +107,12 @@ def to_tensor(_input): # so tensor ops like labels != ignore_index or .to(device) would fail without this. 
if isinstance(value, np.ndarray): value = torch.from_numpy(value) - elif (isinstance(value, list) and isinstance(value[0], - (int, float, np.number))) or key == 'position_ids': + elif isinstance(value, list) and len(value) > 0 and isinstance( + value[0], (int, float, np.number)): + value = torch.tensor(value) + elif key == 'position_ids' and not isinstance(value, torch.Tensor): + if value is None: + continue value = torch.tensor(value) elif (isinstance(value, list)) and key in ('completion_mask', 'mm_token_type_ids'): value = torch.tensor(value) @@ -284,7 +287,9 @@ def pad_cp_inputs(input_tensor: torch.Tensor, padding_value: int) -> torch.Tenso return input_tensor if cp_size > 1: - position_ids_f = position_ids.flatten() + pos_for_cu = position_ids[:1] if position_ids.dim() >= 2 and position_ids.shape[0] > 1 \ + else position_ids + position_ids_f = pos_for_cu.flatten() indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32) cu_seqlens = torch.cat([ indices_q[position_ids_f == 0], @@ -354,8 +359,11 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], di view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:]) val = val.view(view_shape) - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu', - pin_memory=True).cuda(non_blocking=True) + index = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], + device=inputs.device, + dtype=torch.long, + ) val = val.index_select(dim, index) view_shape = (*inputs.shape[:dim], -1, *inputs.shape[dim + 1:]) new_inputs.append(val.view(view_shape)) @@ -402,17 +410,18 @@ def prepare_transformers_padding_free_patch(self, inputs: List[InputFeature], ** if not padding_free or bool(kwargs.get('enable_sp', False)): return inputs - from twinkle.patch import apply_patch - from twinkle.patch.gdn_padding_free import GatedDeltaNetPaddingFreePatch - - apply_patch( - model, - GatedDeltaNetPaddingFreePatch, - 
hf_config=kwargs.get('hf_config'), - enable_sp=False, - ) if not getattr(model, '_twinkle_gdn_padding_free_patched', False): - return inputs + from twinkle.patch import apply_patch + from twinkle.patch.gdn_padding_free import GatedDeltaNetPaddingFreePatch + + apply_patch( + model, + GatedDeltaNetPaddingFreePatch, + hf_config=kwargs.get('hf_config'), + enable_sp=False, + ) + if not getattr(model, '_twinkle_gdn_padding_free_patched', False): + return inputs for _inp in inputs: position_ids = _inp.get('position_ids') @@ -631,15 +640,27 @@ def to_transformers_dict(inputs: List[InputFeature], **kwargs) -> List[InputFeat output = {} _keys = [ 'input_ids', - 'input_embeddings', + 'inputs_embeds', 'attention_mask', 'position_ids', 'labels', 'completion_mask', + 'cu_seq_lens_q', + 'cu_seq_lens_k', + 'cu_seqlens_q', + 'cu_seqlens_kv', + 'max_length_q', + 'max_length_k', + 'packed_seq_params', ] + list(InputProcessor.VLM_CONCAT_FIELDS) for key in list(_input.keys()): - if key in _keys: - output[key] = np.array(_input[key]) if not isinstance(_input[key], torch.Tensor) else _input[key] + if key not in _keys: + continue + value = _input[key] + if isinstance(value, torch.Tensor) or not isinstance(value, (list, np.ndarray)): + output[key] = value + else: + output[key] = np.array(value) results.append(InputFeature(**output)) return results From ffeba0b44cb734ba33493e8f4d3fb25fa181d546 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Tue, 16 Jun 2026 14:31:21 +0800 Subject: [PATCH 15/53] wip --- .../sampler/vllm_sampler/vllm_sampler.py | 20 ------------- src/twinkle/template/tools/base.py | 29 +++++-------------- src/twinkle/template/tools/cline.py | 13 ++------- src/twinkle/template/tools/qwen.py | 3 -- .../preprocessor/message_normalizer.py | 12 ++++---- tests/template/test_tool_parsers.py | 22 -------------- 6 files changed, 16 insertions(+), 83 deletions(-) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py 
b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 32dc1ca5..0433ef5a 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -1,24 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""vLLM-based sampler using VLLMEngine (AsyncLLM). - -Device Configuration: - vLLMSampler automatically detects the number of available GPUs from - CUDA_VISIBLE_DEVICES environment variable (set by twinkle's ResourceManager) - and configures vLLM's tensor_parallel_size accordingly. - - To use tensor parallelism, configure DeviceGroup with gpus_per_worker > 1: - - # DP2 with TP2 (4 GPUs total, 2 workers, each with 2 GPUs) - DeviceGroup(name='sampler', ranks=[0,1,2,3], gpus_per_worker=2) - - # TP4 (4 GPUs, 1 worker with all 4 GPUs) - DeviceGroup(name='sampler', ranks=[0,1,2,3], gpus_per_worker=4) - -Data Flow: - When multiple vLLMSampler workers exist (DP > 1): - - Data is dispatched via dispatch='slice_dp' (each worker gets a slice) - - Results are collected via collect='flatten' (merged into single list) -""" import asyncio import atexit import numpy as np diff --git a/src/twinkle/template/tools/base.py b/src/twinkle/template/tools/base.py index a6d7040e..35b63dc8 100644 --- a/src/twinkle/template/tools/base.py +++ b/src/twinkle/template/tools/base.py @@ -10,14 +10,6 @@ class ToolCallParser(ABC): open_marker: Optional[str] = None close_marker: Optional[str] = None - def matches_model(self, model_id: str) -> bool: - """Return True if this parser is the canonical choice for ``model_id``. - - Used for streaming where we must commit to a parser before any text - has arrived. Default False — parser is text-detection-only. 
- """ - return False - @abstractmethod def detect(self, text: str) -> bool: """Cheap pre-check: does ``text`` carry this format's markup?""" @@ -30,13 +22,14 @@ def parse(self, text: str) -> List[Dict[str, Any]]: def clean(self, text: str) -> str: """Strip parser-specific markup; return plain content text.""" - def detect_result(self, text: str) -> bool: - """Does ``text`` look like a tool-result message for this protocol?""" - return False + def extract_tool_result(self, text: str) -> Optional[str]: + """If ``text`` is a tool-result message of this protocol, return the + body with the protocol-specific prefix stripped; otherwise return ``None``. - def parse_result(self, text: str) -> str: - """Strip protocol-specific result prefix; return the raw tool output body.""" - return text + Default returns ``None`` — only protocols carrying their own tool-result + framing (e.g. Cline) need to override this. + """ + return None class ToolCallRegistry: @@ -56,14 +49,6 @@ def register(cls, parser: ToolCallParser) -> ToolCallParser: def parsers(cls) -> List[ToolCallParser]: return list(cls._parsers) - @classmethod - def select_for_model(cls, model_id: Optional[str]) -> Optional[ToolCallParser]: - mid = (model_id or '').lower() - for p in cls._parsers: - if p.matches_model(mid): - return p - return None - @classmethod def detect_first(cls, text: str) -> Optional[ToolCallParser]: if not text: diff --git a/src/twinkle/template/tools/cline.py b/src/twinkle/template/tools/cline.py index 2673e82e..8f36324c 100644 --- a/src/twinkle/template/tools/cline.py +++ b/src/twinkle/template/tools/cline.py @@ -21,7 +21,7 @@ from __future__ import annotations import re -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from .base import ToolCallParser @@ -99,10 +99,6 @@ class ClineParser(ToolCallParser): open_marker = None close_marker = None - def matches_model(self, model_id: str) -> bool: - # Cline is an app-level prompt protocol, not bound to any model 
family. - return False - def detect(self, text: str) -> bool: if not text or '<' not in text: return False @@ -153,9 +149,6 @@ def clean(self, text: str) -> str: out.append(text[last:]) return ''.join(out).rstrip() - def detect_result(self, text: str) -> bool: - return bool(_RESULT_RE.match(text or '')) - - def parse_result(self, text: str) -> str: + def extract_tool_result(self, text: str) -> Optional[str]: m = _RESULT_RE.match(text or '') - return text[m.end():] if m else text + return text[m.end():] if m else None diff --git a/src/twinkle/template/tools/qwen.py b/src/twinkle/template/tools/qwen.py index 12361b73..6713d570 100644 --- a/src/twinkle/template/tools/qwen.py +++ b/src/twinkle/template/tools/qwen.py @@ -16,9 +16,6 @@ class HermesQwenParser(ToolCallParser): _PARAMETER_RE = re.compile(r']+)>\s*([\s\S]*?)\s*') _STRIP_RE = re.compile(r'[\s\S]*?(?:|\Z)') - def matches_model(self, model_id: str) -> bool: - return 'qwen' in model_id - def detect(self, text: str) -> bool: return self.open_marker in text diff --git a/src/twinkle_agentic/preprocessor/message_normalizer.py b/src/twinkle_agentic/preprocessor/message_normalizer.py index a8606d8f..d3074a56 100644 --- a/src/twinkle_agentic/preprocessor/message_normalizer.py +++ b/src/twinkle_agentic/preprocessor/message_normalizer.py @@ -105,12 +105,12 @@ def _normalize_tool_calls(messages: List[Dict[str, Any]]) -> List[Dict[str, Any] nxt_text = msg_content_text(messages[j]) if not nxt_text: break - if parser.detect_result(nxt_text): - body = parser.parse_result(nxt_text) - elif tc_idx == 0 and len(tc_list) == 1: - body = nxt_text - else: - break + body = parser.extract_tool_result(nxt_text) + if body is None: + if tc_idx == 0 and len(tc_list) == 1: + body = nxt_text + else: + break out.append({ 'role': 'tool', 'content': body, diff --git a/tests/template/test_tool_parsers.py b/tests/template/test_tool_parsers.py index 41f6a3a4..9269ed1f 100644 --- a/tests/template/test_tool_parsers.py +++ 
b/tests/template/test_tool_parsers.py @@ -23,11 +23,6 @@ def test_detect(self): assert not self.p.detect('plain text') assert not self.p.detect('') - def test_matches_model(self): - assert self.p.matches_model('qwen2.5-7b') - assert self.p.matches_model('qwen3-32b') - assert not self.p.matches_model('llama-3.1-8b') - def test_parse_json_variant(self): text = '{"name": "get_weather", "arguments": {"city": "Paris"}}' out = self.p.parse(text) @@ -104,10 +99,6 @@ def test_no_block_marker(self): assert self.p.open_marker is None assert self.p.close_marker is None - def test_does_not_match_qwen_model(self): - assert not self.p.matches_model('qwen2.5') - assert not self.p.matches_model('llama-3') - def test_parse_single_action(self): text = 'Thought: search the web.\nAction: search[hello world]' out = self.p.parse(text) @@ -172,11 +163,6 @@ def test_no_marker(self): assert self.p.open_marker is None assert self.p.close_marker is None - def test_does_not_match_any_model_by_default(self): - # Cline is an app-level prompt protocol, not a model-family format. 
- assert not self.p.matches_model('qwen2.5') - assert not self.p.matches_model('claude-3') - def test_parse_single_arg(self): text = 'src/foo.py' out = self.p.parse(text) @@ -254,14 +240,6 @@ def test_no_parser_for_plain_text(self): assert ToolCallRegistry.detect_first('just some plain text') is None assert ToolCallRegistry.detect_first('') is None - def test_select_for_qwen_picks_hermes(self): - parser = ToolCallRegistry.select_for_model('qwen2.5-7b') - assert parser is not None and parser.name == 'hermes_qwen' - - def test_select_for_unknown_returns_none(self): - assert ToolCallRegistry.select_for_model('llama-3.1-8b') is None - assert ToolCallRegistry.select_for_model(None) is None - if __name__ == '__main__': pytest.main([__file__, '-v']) From a762f0f20e32e3ce709c8fb1c4e79a5135f5b8d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Tue, 16 Jun 2026 16:51:15 +0800 Subject: [PATCH 16/53] wip --- src/twinkle/template/__init__.py | 1 + src/twinkle/template/base.py | 75 ++++++++++++++++++----------- src/twinkle/template/deepseek_v4.py | 26 +++++----- src/twinkle/template/qwen3_5_vl.py | 3 ++ src/twinkle/template/utils.py | 62 ++++++++++++++++-------- 5 files changed, 104 insertions(+), 63 deletions(-) diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py index 6c4bdddd..168456ea 100644 --- a/src/twinkle/template/__init__.py +++ b/src/twinkle/template/__init__.py @@ -2,3 +2,4 @@ from .base import Template from .deepseek_v4 import DeepseekV4Template from .qwen3_5_vl import Qwen3_5Template +from .tools import ToolCallParser, ToolCallRegistry, ClineParser, HermesQwenParser, ReActParser, VCPParser diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 26c2e4f2..c1e8f069 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -32,6 +32,19 @@ class Template: video_placeholder: str = '