diff --git a/INSTALL_MEGATRON.sh b/INSTALL_MEGATRON.sh index 276598c50..951b0171d 100644 --- a/INSTALL_MEGATRON.sh +++ b/INSTALL_MEGATRON.sh @@ -5,6 +5,7 @@ set -e # Exit immediately on error export SETUPTOOLS_USE_DISTUTILS=local +export UV_INDEX_URL=${UV_INDEX_URL:-https://mirrors.aliyun.com/pypi/simple/} echo "==========================================" echo "Starting deep learning dependencies installation..." echo "==========================================" @@ -53,15 +54,15 @@ TORCH_CUDA_ARCH_LIST=$(get_cuda_arch "$GPU_NAME") export TORCH_CUDA_ARCH_LIST echo "Using CUDA architecture: $TORCH_CUDA_ARCH_LIST" -# Install latest base packages +# Install vllm 0.21.x (latest 0.2x uses CUDA 12 toolchain, avoids CUDA 13 CUTLASS conflicts) echo "" -echo "Installing peft, accelerate, transformers, modelscope..." -pip install --upgrade peft accelerate transformers "modelscope[framework]" --no-cache-dir +echo "Installing vllm 0.21..." +uv pip install "vllm>=0.21,<0.22" -# Install latest vllm +# Install latest base packages echo "" -echo "Installing latest vllm..." -pip install --upgrade vllm --no-cache-dir +echo "Installing peft, accelerate, transformers, modelscope..." +uv pip install --upgrade peft accelerate transformers "modelscope[framework]" # Get site-packages path and install transformer_engine and megatron_core echo "" @@ -69,26 +70,30 @@ echo "Installing transformer_engine and megatron_core..." SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") echo "Site-packages path: $SITE_PACKAGES" -CUDNN_PATH=$SITE_PACKAGES/nvidia/cudnn \ -CPLUS_INCLUDE_PATH=$SITE_PACKAGES/nvidia/cudnn/include \ -pip install --no-build-isolation "transformer_engine[pytorch]" --no-cache-dir +export CUDA_HOME=${SITE_PACKAGES}/nvidia/cu13 +export PATH=$CUDA_HOME/bin:$PATH +export CPATH=$CUDA_HOME/include:$CPATH +export LIBRARY_PATH=$CUDA_HOME/lib:$LIBRARY_PATH +export LD_LIBRARY_PATH=$CUDA_HOME/lib:$LD_LIBRARY_PATH +uv pip install transformer_engine_torch --no-build-isolation -pip install megatron_core mcore_bridge --no-cache-dir +uv pip install megatron_core mcore_bridge -# Install flash-attention (force local build) +# Install flash-attention +# Prefer prebuilt wheel; fall back to source build only if needed. echo "" -echo "Installing flash-attention (local build for $GPU_NAME)..." -TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST" \ -MAX_JOBS=8 \ -FLASH_ATTENTION_FORCE_BUILD=TRUE \ -pip install flash-attn --no-build-isolation --no-cache-dir +echo "Installing flash-attention..." +export TORCH_CUDA_ARCH_LIST +export MAX_JOBS=8 +pip install flash-attn --no-cache-dir || \ + FLASH_ATTENTION_FORCE_BUILD=TRUE pip install flash-attn --no-build-isolation --no-cache-dir -pip install flash-linear-attention -U --no-cache-dir +uv pip install flash-linear-attention --upgrade # Install numpy echo "" echo "Installing numpy==2.2 and deep_gemm..." -pip install numpy==2.2 --no-cache-dir +uv pip install numpy==2.2 # Verify installation echo "" diff --git a/README.md b/README.md index 4dd203cfb..4867358a3 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@

- English Documentation   |   中文文档   |   Twinkle Web   + English Documentation   |   中文文档   |   Twinkle Web  

## ✨ What is Twinkle? @@ -100,11 +100,11 @@ sh INSTALL_MEGATRON.sh | DPO multi-LoRA training | transformers | [Script](cookbook/rl/dpo_multi_lora.py) | | GKD on-policy distillation | megatron | [Script](cookbook/rl/gkd_on_policy.py) | | GKD off-policy distillation | megatron | [Script](cookbook/rl/gkd_off_policy.py) | -| Tinker client finetuning (self-host) | transformers | [Script](cookbook/client/tinker/self_host) | -| Tinker client finetuning (ModelScope) | transformers | [Script](cookbook/client/tinker/modelscope) | -| Twinkle client finetuning (self-host) | transformers | [Script](cookbook/client/twinkle/self_host) | -| Twinkle client finetuning (ModelScope) | transformers | [Script](cookbook/client/twinkle/modelscope) | -| Server startup scripts | transformers/megatron | [Script](cookbook/client/server) | +| Tinker client finetuning (self-host) | transformers | [Script](cookbook/server_mode/tinker/self_host) | +| Tinker client finetuning (ModelScope) | transformers | [Script](cookbook/server_mode/tinker/modelscope) | +| Twinkle client finetuning (self-host) | transformers | [Script](cookbook/server_mode/twinkle/self_host) | +| Twinkle client finetuning (ModelScope) | transformers | [Script](cookbook/server_mode/twinkle/modelscope) | +| Server startup scripts | transformers/megatron | [Script](cookbook/server_mode/server) | ## Changelog - 🎉2026-05-20 Support DeepSeek-V4-Flash and DeepSeek-V4-Pro models. @@ -122,7 +122,7 @@ sh INSTALL_MEGATRON.sh We are rolling out training service built atop Twinkle✨ on ModelScope. You may train via API endpoint `base_url=https://www.modelscope.cn/twinkle`. For more details, please refer to -our [documentation](docs/source_en/Usage%20Guide/Train-as-a-Service.md). +our [documentation](https://modelscope.github.io/twinkle-web/docs/usage-guide/train-as-a-service/). ## Supported Hardware @@ -177,7 +177,7 @@ supported on Twinkle✨ framework. ## Sample Code Below are some of the capabilities demonstrated in the example code. For a complete introduction to training capabilities, -please refer to [Quick Start](docs/source_en/Usage%20Guide/Quick-Start.md) and [cookbook](cookbook). +please refer to [Quick Start](https://modelscope.github.io/twinkle-web/docs/usage-guide/quick-start/) and [cookbook](cookbook). ### Train with Ray diff --git a/README_ZH.md b/README_ZH.md index 5d588b393..92923209c 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -19,7 +19,7 @@ by ModelScope & - 英文文档   |   中文文档   |   Twinkle 站点   + 英文文档   |   中文文档   |   Twinkle 站点  

## ✨ Twinkle 是什么? @@ -94,13 +94,13 @@ sh INSTALL_MEGATRON.sh | DPO 多 LoRA 训练 | transformers | [脚本](cookbook/rl/dpo_multi_lora.py) | | GKD 在线蒸馏 | megatron | [脚本](cookbook/rl/gkd_on_policy.py) | | GKD 离线蒸馏 | megatron | [脚本](cookbook/rl/gkd_off_policy.py) | -| Tinker 客户端微调(自部署) | transformers | [脚本](cookbook/client/tinker/self_host) | -| Tinker 客户端微调(ModelScope) | transformers | [脚本](cookbook/client/tinker/modelscope) | -| Twinkle 客户端微调(自部署) | transformers | [脚本](cookbook/client/twinkle/self_host) | -| Twinkle 客户端微调(ModelScope) | transformers | [脚本](cookbook/client/twinkle/modelscope) | -| 服务端启动脚本 | transformers/megatron | [脚本](cookbook/client/server) | +| Tinker 客户端微调(自部署) | transformers | [脚本](cookbook/server_mode/tinker/self_host) | +| Tinker 客户端微调(ModelScope) | transformers | [脚本](cookbook/server_mode/tinker/modelscope) | +| Twinkle 客户端微调(自部署) | transformers | [脚本](cookbook/server_mode/twinkle/self_host) | +| Twinkle 客户端微调(ModelScope) | transformers | [脚本](cookbook/server_mode/twinkle/modelscope) | +| 服务端启动脚本 | transformers/megatron | [脚本](cookbook/server_mode/server) | -Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Client等各场景下。其算法过程是外露的,非常便于修改和调试。完整的框架介绍请查看[快速开始](docs/source_zh/使用指引/快速开始.md) +Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Client等各场景下。其算法过程是外露的,非常便于修改和调试。完整的框架介绍请查看[快速开始](https://modelscope.github.io/twinkle-web/zh/docs/usage-guide/quick-start/) ## 更新日志 - 🎉2026-05-20 支持DeepSeek-V4-Flash and DeepSeek-V4-Pro系列模型。 @@ -116,7 +116,7 @@ Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Cl ## ModelScope 的训练服务 -我们正在 ModelScope 上推出基于 Twinkle✨ 构建的训练服务。你可以通过 API 端点 `base_url=https://www.modelscope.cn/twinkle` 进行训练。更多详情请参阅我们的[文档](docs/source_zh/使用指引/训练服务.md)。 +我们正在 ModelScope 上推出基于 Twinkle✨ 构建的训练服务。你可以通过 API 端点 `base_url=https://www.modelscope.cn/twinkle` 进行训练。更多详情请参阅我们的[文档](https://modelscope.github.io/twinkle-web/zh/docs/usage-guide/train-as-a-service/)。 ## 支持的硬件 @@ -166,7 +166,7 @@ Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Cl ## 示例代码 -下面列出了示例代码的一部分能力。完整的训练能力介绍请参考[快速开始](docs/source_zh/使用指引/快速开始.md)以及[cookbook](cookbook)。 +下面列出了示例代码的一部分能力。完整的训练能力介绍请参考[快速开始](https://modelscope.github.io/twinkle-web/zh/docs/usage-guide/quick-start/)以及[cookbook](cookbook)。 ### 使用 Ray 训练 diff --git a/cookbook/exp/condenser/untested/eval_condensed_compressed.sh b/cookbook/exp/condenser/untested/eval_condensed_compressed.sh index 5567a1a3b..833b446fa 100755 --- a/cookbook/exp/condenser/untested/eval_condensed_compressed.sh +++ b/cookbook/exp/condenser/untested/eval_condensed_compressed.sh @@ -3,14 +3,14 @@ # Identical --dataset / --limit / --model_id as eval_condensed_native.sh for an A/B comparison. set -euo pipefail -DATASET="${DATASET:-/mnt/data/yzhao/datasets/musique_ans_v1.0_dev.jsonl}" -MODEL_ID="${MODEL_ID:-ms://Qwen/Qwen3.5-4B}" -CONDENSER_LORA="${CONDENSER_LORA:-ms://twinkle-kit/Qwen3.5-4B-Condenser}" -LIMIT="${LIMIT:-500}" -NUM_GPUS="${NUM_GPUS:-4}" -OUT_DIR="${OUT_DIR:-eval_out}" +DATASET="/mnt/data/yzhao/datasets/musique_ans_v1.0_dev.jsonl" +MODEL_ID="ms://Qwen/Qwen3.5-4B" +CONDENSER_LORA="ms://twinkle-kit/Qwen3.5-4B-Condenser" +LIMIT="500" +NUM_GPUS="4" +OUT_DIR="eval_out" -CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3} \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ python cookbook/exp/eval_condensed.py \ --mode condensed \ --dataset_format musique \ diff --git a/cookbook/exp/condenser/untested/eval_condensed_native.sh b/cookbook/exp/condenser/untested/eval_condensed_native.sh index 0849e9378..176c767b6 100755 --- a/cookbook/exp/condenser/untested/eval_condensed_native.sh +++ b/cookbook/exp/condenser/untested/eval_condensed_native.sh @@ -3,13 +3,13 @@ # Compare against eval_condensed_compressed.sh on identical --dataset / --limit / --model_id. set -euo pipefail -DATASET="${DATASET:-/mnt/data/yzhao/datasets/musique_ans_v1.0_dev.jsonl}" -MODEL_ID="${MODEL_ID:-ms://Qwen/Qwen3.5-4B}" -LIMIT="${LIMIT:-500}" -NUM_GPUS="${NUM_GPUS:-4}" -OUT_DIR="${OUT_DIR:-eval_out}" +DATASET="/mnt/data/yzhao/datasets/musique_ans_v1.0_dev.jsonl" +MODEL_ID="ms://Qwen/Qwen3.5-4B" +LIMIT="500" +NUM_GPUS="4" +OUT_DIR="eval_out" -CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3} \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ python cookbook/exp/eval_condensed.py \ --mode native \ --dataset_format musique \ diff --git a/cookbook/exp/embedding/build_thinking_rag_index.py b/cookbook/exp/embedding/build_thinking_rag_index.py new file mode 100644 index 000000000..d228a597a --- /dev/null +++ b/cookbook/exp/embedding/build_thinking_rag_index.py @@ -0,0 +1,935 @@ +"""Build a thinking-trace RAG index from condensed (query, cot) pairs. + +Pipeline (per row, batched): + 1. Load (user_query, reasoning_content) pairs from ``dataset_think.get_dataset``. + 2. Compress query with ``RAG_QUERY_HINT`` and cot with ``RAG_THINKING_HINT`` + (a symmetric Problem/Skill/Knowledge schema defined in this file) using a + Twinkle ``vLLMSampler`` (TP=4 across GPUs 0-3). Reuses the system/user + wrappers from ``cookbook/exp/condenser/make_condenser_dataset.py``. + 3. On condenser truncation (``stop_reason='length'`` or skeleton-incomplete + output), fall back to an external OpenAI-compatible API. + 4. Encode the condensed pair via the trained embedding model — Twinkle + ``TransformersModel`` on the ``emb_model`` device group (DP=4 across GPUs + 4-7) using ``forward_only(task='embedding')``, the same code path as + training. + 5. Compute cosine similarity for each (query, thinking) pair, drop pairs with + ``sim < SIM_THRESHOLD``, and insert kept rows into LanceDB. The vector + column carries the **positive (compressed-skill)** embedding so a search + keyed by an anchor-encoded query retrieves the matching thinking trace. + 6. Each row stores the **raw thinking** alongside its embedding, so a hit + in the index can directly surface the original CoT. + +Eval mode (``--mode eval`` or ``--mode both``): + * Self-recall test — encode a sample of dataset queries (whose corresponding + rows are already in the index) as anchors and report recall@1/5/10 plus + a per-source breakdown. + +Architecture (8 GPUs): + * GPU 0-3: vLLM condenser (tensor-parallel, ``DeviceGroup name='sampler'``) + * GPU 4-7: TransformersModel embedding (data-parallel, ``DeviceGroup name='emb_model'``) + * Single ``twinkle.initialize(mode='ray', ...)`` call wires both groups. + +Launch examples: + python build_thinking_rag_index.py --mode build --total 500000 + python build_thinking_rag_index.py --mode eval --eval-size 1000 + python build_thinking_rag_index.py --mode both --total 200000 --eval-size 500 +""" +import argparse +import json +import os +import re +import sys +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from tqdm import tqdm + +# --------------------------------------------------------------------------- +# Compress prompts — MUST match train_embedding_full_ddp.py exactly. +# --------------------------------------------------------------------------- +_HERE = Path(__file__).resolve().parent +sys.path.insert(0, str(_HERE)) + +COMPRESS_SYSTEM = """\ +You are a compression and summary assistant. For the (query, source) pair, emit a Markdown \ +answer with TWO sections, designed to pair with the `extract_compressed` tool: \ +the reader absorbs `## Summary` directly, then calls `extract_compressed` \ +on any topic-key listed under `## More` to recover its \ +fuller content. + + `## Summary` \u2014 extreme-density text the reader reads directly. + `## More` \u2014 a topic index whose keys are valid arguments \ +to `extract_compressed` for recovering material not captured inline. + +Together the two sections must form a COMPLETE, NON-DISTORTING inventory of the \ +source for the query \u2014 nothing essential lost, nothing implied that the source \ +does not support. NO preamble, NO meta-commentary, NO code fences wrapping the \ +whole output. + +Output skeleton: + +## Summary +Topic: + + +## More +- : +- ... + +Format selection for the inline body (pick the MOST COMPACT form per query, mix \ +when helpful): +- Interface / signature \u2192 code notation directly: `func(a:int)->str` +- Factual / entity \u2192 telegraphic prose; drop function words; \":\" for \"is\", \",\" \ +for \"has\" +- Skill / how-to / usage \u2192 lead with `Use when: `; numbered telegraphic \ +steps `1.do X 2.then Y`; close with `Output: ` when relevant +- Procedural \u2192 numbered short steps +- Analytical / design \u2192 hierarchical bullets with abbreviations + +`## Summary` rules: +1. TOPIC LINE \u2014 line 1 is ALWAYS `Topic: `, even when the \ +query is narrow. Anchors both the reader and the tool. +2. DENSITY \u2014 every token in the body carries query-relevant signal; cut filler. +3. PRIMARY-COMPLETE \u2014 never silently drop a fact essential to answering the \ +query. Anything cut for length MUST appear as a key under \ +`## More`. +4. NON-MISLEADING \u2014 phrasing must not let the reader infer anything the source \ +does not support; partial truths that mislead are worse than honest omissions \ +flagged in the index. +5. SELF-CONTAINED \u2014 the reader can act on the answer without re-opening the source. +6. FAITHFUL \u2014 only content the source supports; no fabrication, no extrapolation. +7. LANGUAGE \u2014 match the source language. +8. NO outer code fences around the whole answer; no meta-commentary. + +`## More` rules (MANDATORY \u2014 this section is never omitted): +1. FORMAT \u2014 each bullet is `- : `: + \u2022 topic-key \u2014 short, unambiguous, grounded in source vocabulary so the \ +`extract_compressed` tool can locate the aspect (e.g. `decorators`, \ +`error handling`, `pitfalls`). + \u2022 hint \u2014 tells WHAT the reader gains by expanding (concrete numbers, code \ +listings, secondary cases, edge details, related context, \u2026); do NOT restate \ +the inline answer. +2. CRITERION \u2014 each bullet names an aspect that EXISTS in the source but is \ +NOT fully captured inline. Material that genuinely fits inline without \ +distortion MUST NOT be duplicated here. +3. FAITHFUL \u2014 hints must be grounded in the source; never speculate or invent. +4. ORDER \u2014 by relevance to the query, then by importance. +5. EMPTY CASE \u2014 if the source is so short / single-purpose that everything \ +fits inline, write a single line `- (none)`. + +Now begin.\ +""" + +COMPRESS_USER = ( + 'Downstream model will read your compressed block to decide whether to ' + 'expand it. Compress faithfully: preserve the passage topic + core facts. ' + 'Do NOT invent facts. Do NOT drop major facts. Do NOT write meta-commentary ' + 'about the Query (never write "Query info: absent", "no X mention", etc.); ' + 'if the passage does not address the Query, still summarize the passage. ' + 'CRITICAL LANGUAGE RULE: detect the dominant language of the Passage ' + '(NOT the Query, NOT this instruction) and write the ENTIRE output in that ' + 'same language; English passage \u2192 English output, Chinese passage \u2192 ' + 'Chinese output, Japanese passage \u2192 Japanese output. NEVER translate, ' + 'NEVER mix languages, NEVER copy these instructions into the output.\n\n' + '## Query (ordering hint only \u2014 still summarize the whole passage)\n{query}\n\n' + '## Passage\n{text}') + +# Default dataset loader is the index-time corpus (broader retrieval profile); +# pass --dataset-module dataset_think to fall back to the training mix. +from dataset_index import get_dataset as _default_get_dataset # noqa: E402 + +_GET_DATASET = _default_get_dataset + +import twinkle # noqa: E402 +from twinkle import DeviceGroup, DeviceMesh, get_logger # noqa: E402 +from twinkle.data_format import SamplingParams as TwinkleSamplingParams # noqa: E402 +from twinkle.loss import InfonceLoss # noqa: E402 +from twinkle.model import TransformersModel # noqa: E402 +from twinkle.processor import InputProcessor # noqa: E402 +from twinkle.sampler import vLLMSampler # noqa: E402 +from twinkle.template import Qwen3_5Template # noqa: E402 +from twinkle.utils.parallel import PosixFileLock # noqa: E402 +from twinkle_agentic.protocol.openai import OpenAI as OpenAIClient # noqa: E402 + +logger = get_logger() + + +# =========================================================================== +# Config (most fields overridable via CLI / env) +# =========================================================================== + +EMBED_MODEL_ID = os.environ.get( + 'EMBED_MODEL_ID', + 'output/embedding_lora_transformers/step_4000', +) +CONDENSE_MODEL_ID = os.environ.get('CONDENSE_MODEL_ID', 'ms://twinkle-kit/Qwen3.5-4B-CM-v2') + +# Twinkle device topology: TP=4 sampler on 0-3, DP=4 embedding on 4-7. +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +EMB_GPUS = int(os.environ.get('EMB_GPUS', 4)) +NUM_GPUS = SAMPLER_GPUS + EMB_GPUS + +# vLLM engine sizing. +CONDENSE_GPU_MEM = float(os.environ.get('CONDENSE_GPU_MEM', 0.85)) +CONDENSE_MAX_MODEL_LEN = int(os.environ.get('CONDENSE_MAX_MODEL_LEN', 32768)) +CONDENSE_MAX_TOKENS = int(os.environ.get('CONDENSE_MAX_TOKENS', 8192)) +COMPRESS_TEMPERATURE = float(os.environ.get('COMPRESS_TEMPERATURE', 0.2)) +COMPRESS_TOP_P = float(os.environ.get('COMPRESS_TOP_P', 0.5)) + +# Embedding sizing. +EMBED_MAX_LENGTH = int(os.environ.get('EMBED_MAX_LENGTH', 8192)) + +SIM_THRESHOLD = float(os.environ.get('SIM_THRESHOLD', 0.65)) +MIN_TEXT_CHARS = int(os.environ.get('MIN_TEXT_CHARS', 256)) + +# Hard-templated hints: the condenser SFT prior maps `Skill` to the legacy +# `Use when: / numbered steps / Output:` skeleton on long inputs; embedding the +# exact 4-line body template + explicit negative constraints is the only way to +# override it deterministically across query and cot sides. +RAG_QUERY_HINT = ( + 'Summarize this query for retrieval. ' + 'The body of ## Summary MUST follow this EXACT 4-line template — ' + 'do NOT emit "Use when:", numbered procedure steps, or "Output:":\n' + 'Topic: \n' + 'Problem: \n' + 'Skill: \n' + 'Knowledge: \n' + 'Then emit the mandatory ## More section as usual. ' + 'Topic must name the specific pattern, never generic labels.') +RAG_THINKING_HINT = ( + 'Summarize this reasoning trace for retrieval. ' + 'The body of ## Summary MUST follow this EXACT 4-line template — ' + 'do NOT emit "Use when:", numbered procedure steps, or "Output:":\n' + 'Topic: \n' + 'Problem: \n' + 'Skill: \n' + 'Knowledge: \n' + 'Then emit the mandatory ## More section as usual. ' + 'Topic must name the specific pattern, never generic labels.') + +# OpenAI API fallback (used when vLLM truncates). +COMPRESS_API_KEY = os.environ.get('COMPRESS_API_KEY', '') +COMPRESS_BASE_URL = os.environ.get( + 'COMPRESS_BASE_URL', 'https://dashscope.aliyuncs.com/compatible-mode/v1') +COMPRESS_API_MODEL = os.environ.get('COMPRESS_API_MODEL', 'qwen3.7-max') + +# Source → coarse domain (for filtered eval). +DOMAIN_MAP = { + 'CodeX-2M-Thinking': 'code', + 'OpenThoughts3-1.2M': 'reasoning', + 'LIMO-v2': 'math', + 'Chinese-DeepSeek-R1-Distill-data-110k': 'reasoning_zh', + 'Opus-4.6-Reasoning-3000x-filtered': 'reasoning', + 'claude-opus-4.6-10000x': 'mixed', + 'angrygiraffe-claude-opus-4.6-4.7-reasoning-8.7k': 'mixed', +} + + +# =========================================================================== +# Small helpers +# =========================================================================== + +_LEGACY_USE_WHEN_RE = re.compile(r'(?im)^\s*Use when\s*:') +_SCHEMA_MARKERS = ('Problem:', 'Skill:', 'Knowledge:') + + +def _is_truncated_compression(text: str) -> bool: + """Reject structurally incomplete OR schema-regressed condenser output. + + Triggers API fallback when the vLLM output: + * lacks ``## Summary`` / ``## More``, + * has an empty or unterminated ``## More`` bullet list, or + * regresses to the legacy ``Use when: / numbered-steps / Output:`` skeleton + instead of the mandated Problem/Skill/Knowledge 4-line body — the + dominant cot-side failure mode that drives sim < 0.45 drops. + """ + if not text or not text.strip(): + return True + if '## More' not in text or '## Summary' not in text: + return True + after_more = text.split('## More', 1)[1].strip() + if not after_more: + return True + last_line = after_more.splitlines()[-1].strip() + if not (last_line.startswith('-') or last_line.endswith(')')): + return True + summary_body = text.split('## Summary', 1)[1].split('## More', 1)[0] + if _LEGACY_USE_WHEN_RE.search(summary_body): + return True + if not all(marker in summary_body for marker in _SCHEMA_MARKERS): + return True + return False + + +def _strip_outer_codefence(text: str) -> str: + m = re.match(r'^```[a-zA-Z]*\n(.*?)\n```\s*$', text, re.DOTALL) + if m: + return m.group(1).strip() + return text.strip() + + +def _wrap_anchor(text: str) -> List[Dict[str, str]]: + """Anchor-side message wrapping (must match training).""" + return [ + {'role': 'user', 'content': text}, + {'role': 'assistant', 'content': 'Match the correct response here.'}, + ] + + +def _wrap_positive(text: str) -> List[Dict[str, str]]: + """Positive-side message wrapping (must match training).""" + return [ + {'role': 'user', 'content': 'Match the correct query here.'}, + {'role': 'assistant', 'content': text}, + ] + + +def _short(text: str, n: int = 96) -> str: + text = (text or '').replace('\n', ' ').strip() + return text[:n] + ('…' if len(text) > n else '') + + +def _detect_lang(text: str) -> str: + if not text: + return 'unknown' + cjk = sum(1 for ch in text[:512] if '\u4e00' <= ch <= '\u9fff') + return 'zh' if cjk >= 8 else 'en' + + +def _build_compress_messages(text: str, query: str) -> List[Dict[str, str]]: + return [ + {'role': 'system', 'content': COMPRESS_SYSTEM}, + {'role': 'user', 'content': COMPRESS_USER.format(query=query, text=text)}, + ] + + +# =========================================================================== +# Twinkle component wrappers +# =========================================================================== + +def initialize_twinkle() -> Tuple[DeviceMesh, DeviceMesh]: + """Wire two device groups (sampler / emb_model) and return their meshes.""" + device_groups = [ + DeviceGroup( + name='sampler', + ranks=list(range(SAMPLER_GPUS)), + device_type='GPU', + gpus_per_worker=SAMPLER_GPUS, # TP=4 → one worker spans all 4 GPUs + ), + DeviceGroup( + name='emb_model', + ranks=list(range(SAMPLER_GPUS, NUM_GPUS)), + device_type='GPU', + ), + ] + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, tp_size=SAMPLER_GPUS) + emb_mesh = DeviceMesh.from_sizes(world_size=EMB_GPUS, dp_size=EMB_GPUS) + twinkle.initialize( + mode='ray', + nproc_per_node=NUM_GPUS, + groups=device_groups, + lazy_collect=False, + ) + return sampler_mesh, emb_mesh + + +def build_sampler(sampler_mesh: DeviceMesh) -> vLLMSampler: + sampler = vLLMSampler( + model_id=CONDENSE_MODEL_ID, + engine_args={ + 'gpu_memory_utilization': CONDENSE_GPU_MEM, + 'max_model_len': CONDENSE_MAX_MODEL_LEN, + }, + device_mesh=sampler_mesh, + remote_group='sampler', + ) + sampler.set_template( + 'Qwen3_5Template', + model_id=CONDENSE_MODEL_ID, + enable_thinking=False, + max_length=CONDENSE_MAX_MODEL_LEN, + ) + return sampler + + +def build_emb_model(emb_mesh: DeviceMesh) -> Tuple[TransformersModel, Qwen3_5Template]: + model = TransformersModel( + model_id=EMBED_MODEL_ID, + device_mesh=emb_mesh, + remote_group='emb_model', + ) + model.set_processor(InputProcessor) + # InfonceLoss is required by the framework even though forward_only does + # not actually invoke it; matches the training-time configuration. + model.set_loss(InfonceLoss, temperature=0.03, use_batch=True) + # Qwen3.5-specific subclass applies orphan- chat-template patches. + template = Qwen3_5Template( + model_id=EMBED_MODEL_ID, + max_length=EMBED_MAX_LENGTH, + truncation_strategy='delete', + enable_thinking=False, + ) + return model, template + + +# =========================================================================== +# Compression helpers (vLLMSampler) + API fallback +# =========================================================================== + +def _vllm_compress(sampler: vLLMSampler, texts: List[str], query_hint: str + ) -> List[Tuple[str, str]]: + """Compress ``texts`` via the sampler; return ``(decoded, stop_reason)``.""" + if not texts: + return [] + prompts = [{'messages': _build_compress_messages(t, query_hint)} for t in texts] + params = TwinkleSamplingParams( + max_tokens=CONDENSE_MAX_TOKENS, + temperature=COMPRESS_TEMPERATURE, + top_p=COMPRESS_TOP_P, + num_samples=1, + ) + responses = sampler.sample(prompts, params) + results: List[Tuple[str, str]] = [] + for resp in responses: + seq = resp.sequences[0] if resp and resp.sequences else None + if seq is None: + results.append(('', 'error')) + continue + text = seq.decoded or '' + # Strip any leaked chat-template special tokens like ``<|im_end|>``. + text = re.sub(r'<\|[^|]+\|>', '', text).rstrip() + text = _strip_outer_codefence(text) + results.append((text, seq.stop_reason or 'stop')) + return results + + +def _api_compress(api: OpenAIClient, messages: List[Dict[str, str]]) -> Optional[str]: + sp = TwinkleSamplingParams(temperature=COMPRESS_TEMPERATURE, max_tokens=CONDENSE_MAX_TOKENS) + try: + reply = api({'messages': messages}, sp, extra_body={'enable_thinking': False}) + except Exception as exc: # noqa: BLE001 — broad catch is intentional + sys.stderr.write(f'[api_fallback] error: {exc}\n') + return None + content = (reply.get('content') or '').strip() + if not content: + return None + return _strip_outer_codefence(content) + + +def _resolve_compressed(sampler: vLLMSampler, api: Optional[OpenAIClient], + texts: List[str], query_hint: str) -> List[Optional[str]]: + """Run vLLM batch; replace truncations / skeleton-incomplete with API output.""" + pairs = _vllm_compress(sampler, texts, query_hint) + results: List[Optional[str]] = [] + for (text, stop), src_text in zip(pairs, texts): + if stop != 'length' and not _is_truncated_compression(text): + results.append(text) + continue + if api is None: + results.append(None) + continue + api_text = _api_compress(api, _build_compress_messages(src_text, query_hint)) + if api_text is None or _is_truncated_compression(api_text): + results.append(None) + else: + results.append(api_text) + return results + + +# =========================================================================== +# Embedding helpers (TransformersModel.forward_only(task='embedding')) +# =========================================================================== + +def _build_features(template: Qwen3_5Template, texts: List[str], role: str + ) -> List[Dict[str, Any]]: + """Wrap each text into the role-specific anchor / positive feature dict.""" + features: List[Dict[str, Any]] = [] + for text in texts: + if not text or not text.strip(): + # Pad with a single space so positional alignment holds against + # the input list — the caller filters out empty-text rows upstream. + text = ' ' + if role == 'anchor': + feat = template.encode({'messages': _wrap_anchor(text)}) + feat['labels'] = [1] + else: + feat = template.encode({'messages': _wrap_positive(text)}) + feat['labels'] = [0] + features.append(feat) + return features + + +def get_embeddings(model: TransformersModel, template: Qwen3_5Template, + texts: List[str], role: str) -> np.ndarray: + """Return ``[N, H]`` float32 L2-normalised embeddings for ``texts``. + + Inputs are padded up to a multiple of ``EMB_GPUS`` and sliced back to the + original ``N``: the dispatch layer (``_dispatch_args``) starves any rank + whose chunk lands beyond ``len(texts)``, so a single forward of fewer than + ``EMB_GPUS`` items (e.g. the probe) would otherwise raise + ``Batch too small for {EMB_GPUS} workers``. + """ + if not texts: + return np.zeros((0,), dtype=np.float32) + n = len(texts) + pad_n = (-n) % EMB_GPUS + padded = list(texts) + [' '] * pad_n if pad_n else list(texts) + features = _build_features(template, padded, role) + out = model.forward_only(inputs=features, task='embedding', return_logits=True) + emb = out['embeddings'] + if isinstance(emb, torch.Tensor): + emb = emb.detach().to(torch.float32).cpu().numpy() + emb = np.asarray(emb, dtype=np.float32) + return emb[:n] if pad_n else emb + + +def _probe_hidden_size(model: TransformersModel, template: Qwen3_5Template) -> int: + """One-shot warmup forward to read out the embedding dimension.""" + emb = get_embeddings(model, template, ['probe'], role='anchor') + if emb.ndim != 2 or emb.shape[0] == 0: + raise RuntimeError(f'unexpected embedding shape from probe: {emb.shape}') + return int(emb.shape[1]) + + +# =========================================================================== +# LanceDB I/O +# =========================================================================== + +def _make_arrow_schema(hidden_size: int): + import pyarrow as pa + return pa.schema([ + pa.field('id', pa.string()), + pa.field('vector', pa.list_(pa.float32(), hidden_size)), + pa.field('thinking_raw', pa.string()), + pa.field('query_raw', pa.string()), + pa.field('cot_compressed', pa.string()), + pa.field('query_compressed', pa.string()), + pa.field('source', pa.string()), + pa.field('domain', pa.string()), + pa.field('language', pa.string()), + pa.field('sim', pa.float32()), + ]) + + +def _open_or_create_table(db_path: str, table_name: str, hidden_size: int, + mode: str): + """Open an existing table for append/eval, or create a fresh one.""" + import lancedb + db = lancedb.connect(db_path) + schema = _make_arrow_schema(hidden_size) + if table_name in db.table_names(): + if mode == 'overwrite': + db.drop_table(table_name) + tbl = db.create_table(table_name, schema=schema, mode='overwrite') + else: + tbl = db.open_table(table_name) + else: + tbl = db.create_table(table_name, schema=schema, mode='create') + return db, tbl + + +def _existing_ids(table) -> set: + try: + col = table.to_pandas(columns=['id']) + return set(col['id'].astype(str).tolist()) + except Exception: # noqa: BLE001 + return set() + + +# =========================================================================== +# Build pipeline +# =========================================================================== + +def _stream_corpus(total: Optional[int], load_from_cache_file: bool, + max_rows: int = 0) -> Iterator[Dict[str, Any]]: + ds = _GET_DATASET(total=total, load_from_cache_file=load_from_cache_file) + n_full = len(ds) + cap = max_rows if (max_rows and max_rows < n_full) else n_full + sys.stderr.write(f'[corpus] get_dataset: {n_full} rows' + + (f' → yielding first {cap}\n' if cap < n_full else '\n')) + for i, row in enumerate(ds): + if i >= cap: + break + yield row + + +def _extract_query_cot(row: Dict[str, Any]) -> Tuple[str, str]: + user_query, cot = '', '' + for m in row.get('messages') or []: + if not isinstance(m, dict): + continue + role = m.get('role') or '' + if role == 'user' and not user_query: + user_query = (m.get('content') or '').strip() + elif role == 'assistant': + cot = (m.get('reasoning_content') or '').strip() + break + return user_query, cot + + +def _log_miss(misses_path: str, lock: PosixFileLock, record: Dict[str, Any]) -> None: + line = json.dumps(record, ensure_ascii=False, default=str) + '\n' + with lock: + with open(misses_path, 'a', encoding='utf-8') as fh: + fh.write(line) + + +def build_index(args: argparse.Namespace, + sampler: vLLMSampler, + emb_model: TransformersModel, + emb_template: Qwen3_5Template, + api: Optional[OpenAIClient]) -> None: + # ---- Probe embedding dimension ----------------------------------------- + sys.stderr.write('[build] probing embedding hidden size...\n') + hidden_size = _probe_hidden_size(emb_model, emb_template) + sys.stderr.write(f'[build] hidden_size={hidden_size}\n') + + # ---- LanceDB ------------------------------------------------------------ + db, tbl = _open_or_create_table( + args.db_path, args.table, hidden_size, + mode='overwrite' if args.overwrite else 'append', + ) + indexed = _existing_ids(tbl) if not args.overwrite else set() + sys.stderr.write(f'[build] table "{args.table}" — {len(indexed)} existing rows.\n') + + misses_path = args.misses_log or (str(Path(args.db_path) / f'{args.table}.misses.jsonl')) + Path(misses_path).parent.mkdir(parents=True, exist_ok=True) + misses_lock = PosixFileLock(misses_path + '.lock') + + # ---- Streaming loop ----------------------------------------------------- + n_seen = n_kept = n_dropped_short = n_dropped_compress = n_dropped_sim = 0 + n_dropped_dup = 0 + pbar = tqdm(desc='index', unit='row', dynamic_ncols=True) + + batch: List[Dict[str, Any]] = [] + + def _flush(rows: List[Dict[str, Any]]) -> None: + nonlocal n_kept, n_dropped_compress, n_dropped_sim + if not rows: + return + # Phase 1 — compress query (RAG_QUERY_HINT) and cot (RAG_THINKING_HINT). + # Short queries bypass condenser (passthrough) — matches training behaviour. + long_q_indices = [i for i, r in enumerate(rows) if len(r['query_raw']) >= MIN_TEXT_CHARS] + q_compressed: List[Optional[str]] = [None] * len(rows) + for i, r in enumerate(rows): + if len(r['query_raw']) < MIN_TEXT_CHARS: + q_compressed[i] = r['query_raw'] + if long_q_indices: + long_results = _resolve_compressed( + sampler, api, [rows[i]['query_raw'] for i in long_q_indices], RAG_QUERY_HINT) + for idx, res in zip(long_q_indices, long_results): + q_compressed[idx] = res + c_compressed = _resolve_compressed( + sampler, api, [r['cot_raw'] for r in rows], RAG_THINKING_HINT) + kept_rows: List[Dict[str, Any]] = [] + for r, q_cmp, c_cmp in zip(rows, q_compressed, c_compressed): + if not q_cmp or not c_cmp: + n_dropped_compress += 1 + _log_miss(misses_path, misses_lock, { + 'id': r['id'], 'source': r['source'], 'reason': 'compress_fail', + 'query_raw_head': _short(r['query_raw'], 200), + 'cot_raw_head': _short(r['cot_raw'], 200), + }) + continue + r['query_compressed'] = q_cmp + r['cot_compressed'] = c_cmp + kept_rows.append(r) + if not kept_rows: + return + # Phase 2 — encode anchor (compressed query) + positive (compressed cot). + anchor_emb = get_embeddings( + emb_model, emb_template, [r['query_compressed'] for r in kept_rows], role='anchor') + positive_emb = get_embeddings( + emb_model, emb_template, [r['cot_compressed'] for r in kept_rows], role='positive') + sims = (anchor_emb * positive_emb).sum(axis=1).astype(np.float32) + # Phase 3 — sim filter + LanceDB insert. + to_insert: List[Dict[str, Any]] = [] + for idx, (r, sim_val) in enumerate(zip(kept_rows, sims)): + tag = 'KEEP' if sim_val >= SIM_THRESHOLD else 'DROP' + print(f'[{tag} sim={sim_val:.4f}] {r["source"][:24]} ' + f'q={_short(r["query_raw"], 60)!r} ' + f'cot={_short(r["cot_raw"], 60)!r}', flush=True) + if sim_val < SIM_THRESHOLD: + n_dropped_sim += 1 + _log_miss(misses_path, misses_lock, { + 'id': r['id'], 'source': r['source'], 'reason': 'sim_low', + 'sim': float(sim_val), + 'query_raw': r['query_raw'], + 'cot_raw': r['cot_raw'], + 'query_compressed': r['query_compressed'], + 'cot_compressed': r['cot_compressed'], + }) + continue + to_insert.append({ + 'id': r['id'], + 'vector': positive_emb[idx].tolist(), + 'thinking_raw': r['cot_raw'], + 'query_raw': r['query_raw'], + 'cot_compressed': r['cot_compressed'], + 'query_compressed': r['query_compressed'], + 'source': r['source'], + 'domain': DOMAIN_MAP.get(r['source'], 'mixed'), + 'language': _detect_lang(r['cot_raw']), + 'sim': float(sim_val), + }) + if to_insert: + tbl.add(to_insert) + n_kept += len(to_insert) + indexed.update(r['id'] for r in to_insert) + + try: + for row in _stream_corpus(total=args.total, load_from_cache_file=not args.no_cache, + max_rows=args.max_rows): + n_seen += 1 + if args.limit and n_kept >= args.limit: + break + rid = row.get('id') or '' + if not rid: + continue + if rid in indexed: + n_dropped_dup += 1 + continue + user_query, cot = _extract_query_cot(row) + if not user_query or len(cot) < MIN_TEXT_CHARS: + n_dropped_short += 1 + continue + batch.append({ + 'id': rid, + 'source': row.get('source') or 'unknown', + 'query_raw': user_query, + 'cot_raw': cot, + }) + if len(batch) >= args.batch_size: + _flush(batch) + batch.clear() + pbar.set_postfix(kept=n_kept, sim_drop=n_dropped_sim, + cmp_drop=n_dropped_compress, refresh=False) + pbar.update(1) + if batch: + _flush(batch) + batch.clear() + finally: + pbar.close() + + sys.stderr.write( + f'[build] seen={n_seen} kept={n_kept} sim_drop={n_dropped_sim} ' + f'cmp_drop={n_dropped_compress} short_drop={n_dropped_short} ' + f'dup_skip={n_dropped_dup}\n') + + # ---- Build vector index for fast retrieval ------------------------------ + if n_kept >= 64 and not args.skip_index: + sys.stderr.write('[build] creating IVF_PQ index (metric=dot)...\n') + n_partitions = max(8, min(256, n_kept // 1000 + 1)) + try: + tbl.create_index( + metric='dot', + vector_column_name='vector', + num_partitions=n_partitions, + num_sub_vectors=16, + index_type='IVF_PQ', + replace=True, + ) + except Exception as exc: # noqa: BLE001 + sys.stderr.write(f'[build] index build failed: {exc} ' + '(table is still queryable via brute-force scan)\n') + sys.stderr.write(f'[build] done. table rows={tbl.count_rows()}\n') + + +# =========================================================================== +# Eval pipeline (self-recall on indexed rows) +# =========================================================================== + +def eval_recall(args: argparse.Namespace, + sampler: vLLMSampler, + emb_model: TransformersModel, + emb_template: Qwen3_5Template, + api: Optional[OpenAIClient]) -> None: + """Probe each gold query against the index; report recall@k. + + Self-recall semantics: only rows whose ``id`` is already present in the + index are probed. The corresponding ``cot``-keyed vector must be retrieved + by encoding the **raw user query** through the condenser → embedder + pipeline (anchor side). The match is correct iff the retrieved row's + ``id`` equals the probe row's ``id``. + """ + import lancedb + db = lancedb.connect(args.db_path) + if args.table not in db.table_names(): + raise SystemExit(f'[eval] table "{args.table}" does not exist in {args.db_path}') + tbl = db.open_table(args.table) + indexed_ids = _existing_ids(tbl) + sys.stderr.write(f'[eval] table rows={tbl.count_rows()} indexed_ids={len(indexed_ids)}\n') + if not indexed_ids: + sys.stderr.write('[eval] empty index — nothing to evaluate.\n') + return + + ks = sorted({1, 5, 10, args.top_k}) + hits = {k: 0 for k in ks} + per_source_hits: Dict[str, Dict[int, int]] = {} + per_source_total: Dict[str, int] = {} + probed = 0 + + pbar = tqdm(desc='eval', unit='probe', dynamic_ncols=True) + batch_rows: List[Dict[str, Any]] = [] + + def _flush(rows: List[Dict[str, Any]]) -> None: + nonlocal probed + if not rows: + return + compressed = _resolve_compressed( + sampler, api, [r['query_raw'] for r in rows], RAG_QUERY_HINT) + useful = [(r, c) for r, c in zip(rows, compressed) if c] + if not useful: + return + anchor_emb = get_embeddings( + emb_model, emb_template, [c for _, c in useful], role='anchor') + for (r, _), vec in zip(useful, anchor_emb): + res = ( + tbl.search(vec.astype(np.float32).tolist()) + .metric('dot') + .limit(max(ks)) + .select(['id', 'source']) + .to_list() + ) + hit_ids = [item['id'] for item in res] + try: + rank = hit_ids.index(r['id']) + except ValueError: + rank = -1 + for k in ks: + if 0 <= rank < k: + hits[k] += 1 + per_source_hits.setdefault(r['source'], {kk: 0 for kk in ks})[k] += 1 + per_source_total[r['source']] = per_source_total.get(r['source'], 0) + 1 + per_source_hits.setdefault(r['source'], {kk: 0 for kk in ks}) + probed += 1 + pbar.update(len(useful)) + + try: + for row in _stream_corpus(total=args.total, load_from_cache_file=not args.no_cache, + max_rows=args.max_rows): + if probed + len(batch_rows) >= args.eval_size: + break + rid = row.get('id') or '' + if not rid or rid not in indexed_ids: + continue + user_query, _ = _extract_query_cot(row) + if not user_query or len(user_query) < MIN_TEXT_CHARS: + continue + batch_rows.append({ + 'id': rid, + 'source': row.get('source') or 'unknown', + 'query_raw': user_query, + }) + if len(batch_rows) >= args.batch_size: + _flush(batch_rows) + batch_rows.clear() + if batch_rows: + _flush(batch_rows) + finally: + pbar.close() + + if probed == 0: + sys.stderr.write( + '[eval] no probed rows — index empty, queries too short, or ' + 'corpus exhausted before eval-size?\n') + return + + print('\n=== Recall @ k (self-recall, gold present in index) ===') + print(f'probed = {probed}') + for k in ks: + print(f' recall@{k:<3} = {hits[k]/probed:.4f} ({hits[k]}/{probed})') + + print('\n=== Per-source recall@10 ===') + for src in sorted(per_source_total): + tot = per_source_total[src] + h10 = per_source_hits.get(src, {}).get(10, 0) + print(f' {src:<48s} {h10/tot:.4f} ({h10}/{tot})') + + +# =========================================================================== +# CLI +# =========================================================================== + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument('--mode', choices=['build', 'eval', 'both'], default='build') + p.add_argument('--db-path', default='./output/thinking_rag/lance.db', + help='LanceDB on-disk directory (persisted across runs).') + p.add_argument('--table', default='thinking_traces', + help='LanceDB table name within --db-path.') + p.add_argument('--total', type=int, default=0, + help='Total dataset rows to scale corpus to (0 = base sizes from the loader module).') + p.add_argument('--dataset-module', default='dataset_index', + choices=['dataset_index', 'dataset_think'], + help='Which loader to use: dataset_index (RAG profile) or ' + 'dataset_think (training mix).') + p.add_argument('--limit', type=int, default=0, + help='Stop building once this many rows are kept (0 = no cap).') + p.add_argument('--max-rows', type=int, default=0, + help='Truncate corpus to this many rows AFTER get_dataset (0 = no cap). ' + 'Use this instead of --total to avoid invalidating the dataset cache.') + p.add_argument('--batch-size', type=int, default=64, + help='Rows per condense+encode batch.') + p.add_argument('--no-cache', action='store_true', + help='Disable load_from_cache_file in dataset_think.get_dataset.') + p.add_argument('--overwrite', action='store_true', + help='Drop the table before build and start fresh.') + p.add_argument('--skip-index', action='store_true', + help='Skip IVF_PQ index build at the end (debug).') + p.add_argument('--misses-log', default='', + help='Path for filtered-row JSONL log (defaults to /.misses.jsonl).') + + # eval-only + p.add_argument('--eval-size', type=int, default=500, + help='Number of probes for self-recall evaluation.') + p.add_argument('--top-k', type=int, default=10, + help='Largest k to report. Smaller ks (1, 5) are always reported.') + + return p.parse_args() + + +def main() -> None: + args = parse_args() + Path(args.db_path).mkdir(parents=True, exist_ok=True) + + global _GET_DATASET + if args.dataset_module == 'dataset_think': + from dataset_think import get_dataset as _swap + _GET_DATASET = _swap + sys.stderr.write(f'[main] dataset loader: {args.dataset_module}\n') + + # Build/eval both depend on the same Twinkle stack — initialize once. + sampler_mesh, emb_mesh = initialize_twinkle() + sys.stderr.write(f'[main] twinkle initialized: ' + f'sampler ranks 0-{SAMPLER_GPUS - 1} (TP={SAMPLER_GPUS}), ' + f'emb_model ranks {SAMPLER_GPUS}-{NUM_GPUS - 1} (DP={EMB_GPUS}).\n') + + sys.stderr.write('[main] starting vLLM condenser sampler...\n') + sampler = build_sampler(sampler_mesh) + sys.stderr.write('[main] starting embedding TransformersModel...\n') + emb_model, emb_template = build_emb_model(emb_mesh) + + api: Optional[OpenAIClient] = None + if COMPRESS_API_KEY: + api = OpenAIClient( + model=COMPRESS_API_MODEL, + api_key=COMPRESS_API_KEY, + base_url=COMPRESS_BASE_URL, + ) + else: + sys.stderr.write( + '[main] WARNING: COMPRESS_API_KEY unset — truncated rows will be dropped.\n') + + if args.mode in ('build', 'both'): + build_index(args, sampler, emb_model, emb_template, api) + if args.mode in ('eval', 'both'): + eval_recall(args, sampler, emb_model, emb_template, api) + + +if __name__ == '__main__': + main() diff --git a/cookbook/exp/embedding/dataset_index.py b/cookbook/exp/embedding/dataset_index.py new file mode 100644 index 000000000..7d2905a59 --- /dev/null +++ b/cookbook/exp/embedding/dataset_index.py @@ -0,0 +1,718 @@ +"""RAG-index corpus loader — abstract reasoning skills + textbook-style methods. + +Distinct from training-time ``dataset_think.py``. Optimizes for **abstraction +density**, not raw coverage: every row should encode a transferable method, +theorem, or solution pattern that downstream queries can retrieve as a +"use-when-X-do-Y" recipe. + +Single-table design (``thinking_traces``); EMBED_QUERY_COT condense step in +``build_thinking_rag_index`` homogenizes thinking-style and textbook-style +content into the same retrieval form, so dual-table is unnecessary. The +``source`` field carries the original dataset name for eval-time +domain-bucket diagnostics. + +Output schema matches ``dataset_think.get_dataset()``: ``{id, source, messages}`` +with ``messages[1].reasoning_content`` carrying the CoT. + +Mix (≈3.6M rows base, 10 datasets): + Math thinking 23% — OpenMathReasoning + OpenR1-Math-220k + s1K-1.1 + Code thinking 19% — OpenCodeReasoning-2 + codeforces-cots + Cross-domain R1 39% — Bespoke-Stratos + dolphin-r1 + reasoning-v1-20m + + natural_reasoning + Textbook synth 17% — cosmopedia v1 (auto_math_text, chunked by H2) + Olympiad solutions <1% — Omni-MATH + +Dropped: camel-ai/{physics,chemistry,biology} (zip-only, no parquet/jsonl) and +swift/stack-exchange-paired (dataset_infos.json/data layout mismatch); the +textbook-density gap is covered by a larger cosmopedia slice. + +Textbook processors synthesize a question from the chapter heading and place +the explanatory body into the ``cot`` field — embedding+condense reads +``query | cot`` so the textbook prose becomes a retrievable method. + +Field extraction is defensive: each processor tries multiple plausible column +names and silently drops rows that miss a usable signal. Inspect +``dropped_index.jsonl`` after the first run to verify field-name guesses. +""" +import re +from typing import Any, Dict, List, Optional + +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.preprocessor import Preprocessor + +from dataset_think import _THINK_RE, _hash_id, _register, ToMessagesProcessor + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +# Sky-T1 / Bespoke-Stratos custom markers (used in place of ). +_BOT_RE = re.compile( + r'<\|begin_of_thought\|>(.*?)<\|end_of_thought\|>', re.DOTALL) +_BOS_RE = re.compile( + r'<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>', re.DOTALL) + +# H2 heading split for cosmopedia-style markdown chunks. +_H2_RE = re.compile(r'^##\s+(.+?)\s*$', re.MULTILINE) + + +def _split_think(text: str) -> tuple: + """Return ``(cot, response)``; cot empty if no ```` block found.""" + if not text: + return '', '' + m = _THINK_RE.search(text) + if not m: + return '', text.strip() + return m.group(1).strip(), text[m.end():].strip() + + +def _split_sky_t1(text: str) -> tuple: + """Return ``(cot, response)`` for Sky-T1 / Bespoke-Stratos marker format.""" + if not text: + return '', '' + bot = _BOT_RE.search(text) + bos = _BOS_RE.search(text) + cot = bot.group(1).strip() if bot else '' + sol = bos.group(1).strip() if bos else '' + return cot, sol + + +def _from_messages(messages: Any) -> tuple: + """Pull (first_user, first_assistant) from OpenAI/ShareGPT-style list.""" + if not isinstance(messages, list): + return '', '' + query, assistant = '', '' + for msg in messages: + if not isinstance(msg, dict): + continue + role = msg.get('role') or msg.get('from') or '' + content = msg.get('content') or msg.get('value') or '' + if not isinstance(content, str): + continue + if role in ('user', 'human') and not query: + query = content.strip() + elif role in ('assistant', 'gpt') and not assistant: + assistant = content.strip() + break + return query, assistant + + +def _chunk_by_h2(text: str, min_chars: int = 200, max_chars: int = 6000): + """Split markdown text on ``## `` headings; yield ``(title, body)`` pairs.""" + if not text: + return + matches = list(_H2_RE.finditer(text)) + if not matches: + head = text.strip()[:80].splitlines()[0] if text.strip() else '' + body = text.strip() + if head and min_chars <= len(body) <= max_chars: + yield head, body + return + for i, m in enumerate(matches): + title = m.group(1).strip() + start = m.end() + end = matches[i + 1].start() if i + 1 < len(matches) else len(text) + body = text[start:end].strip() + if min_chars <= len(body) <= max_chars and title: + yield title, body + + +# =========================================================================== +# Math thinking +# =========================================================================== + +OPEN_MATH_REASONING_REPO = 'ms://AI-ModelScope/OpenMathReasoning' + + +class OpenMathReasoningProcessor(Preprocessor): + """OpenMathReasoning → ``{id, source, query, cot, response}``. + + Schema: ``problem``, ``generated_solution`` (R1 trace with ````), + ``expected_answer``. The ``cot`` *split* (not column) is the long-CoT + portion — TIR/genselect/additional_problems sit in sibling splits and + are filtered at load time, not row-level. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('problem') or row.get('question') or '').strip() + assistant = (row.get('generated_solution') or row.get('solution') + or row.get('output') or '').strip() + if not query or not assistant: + continue + cot, response = _split_think(assistant) + if not cot: + continue + if not response: + response = (row.get('expected_answer') or row.get('answer') or '').strip() + if not response: + continue + out.append({ + 'id': _hash_id('open_math_reasoning', f'{query}\n{response}'), + 'source': 'OpenMathReasoning', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +OPEN_R1_MATH_REPO = 'ms://open-r1/OpenR1-Math-220k' + + +class OpenR1MathProcessor(Preprocessor): + """OpenR1-Math-220k → ``{id, source, query, cot, response}``. + + Schema: ``problem``, ``solution``, ``answer``, ``generations`` (list of + R1 traces), ``correctness_math_verify`` (parallel bool list). Pick the + first generation whose math-verify passed; fall back to ``solution``. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('problem') or row.get('question') or '').strip() + if not query: + continue + assistant = '' + gens = row.get('generations') + verifies = row.get('correctness_math_verify') + if isinstance(gens, list): + if isinstance(verifies, list) and len(verifies) == len(gens): + for g, v in zip(gens, verifies): + if v and isinstance(g, str) and g.strip(): + assistant = g.strip() + break + if not assistant: + for g in gens: + if isinstance(g, str) and g.strip(): + assistant = g.strip() + break + if not assistant: + assistant = (row.get('solution') or '').strip() + if not assistant: + continue + cot, response = _split_think(assistant) + if not cot: + continue + if not response: + response = (row.get('answer') or '').strip() + if not response: + continue + out.append({ + 'id': _hash_id('open_r1_math', f'{query}\n{response}'), + 'source': 'OpenR1-Math-220k', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +S1K_REPO = 'ms://simplescaling/s1K-1.1' + + +class S1KProcessor(Preprocessor): + """s1K-1.1 → ``{id, source, query, cot, response}``. + + Schema: ``question`` + ``deepseek_thinking_trajectory`` (or + ``thinking_trajectories`` legacy) + ``deepseek_attempt`` (final answer). + Hand-curated peak-abstraction set, kept whole. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('question') or row.get('problem') or '').strip() + thinking = (row.get('deepseek_thinking_trajectory') + or row.get('thinking_trajectories') + or row.get('thinking') or '') + if isinstance(thinking, list): + thinking = '\n\n'.join(t for t in thinking if isinstance(t, str)) + cot = (thinking or '').strip() + response = (row.get('deepseek_attempt') or row.get('attempt') + or row.get('answer') or row.get('solution') or '').strip() + if not query or not cot or not response: + continue + out.append({ + 'id': _hash_id('s1k', f'{query}\n{response}'), + 'source': 's1K-1.1', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +# =========================================================================== +# Code thinking +# =========================================================================== + +OPEN_CODE_REASONING_REPO = 'ms://nv-community/OpenCodeReasoning-2' + + +class OpenCodeReasoning2Processor(Preprocessor): + """OpenCodeReasoning-2 → ``{id, source, query, cot, response}``. + + Schema: ``input``/``problem``, plus per-model R1-style trace columns + (``r1_generation``, ``qwq_generation``, etc.). Prefer the ``r1`` trace; + fall back to ``solution``. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('input') or row.get('problem') + or row.get('question') or '').strip() + # OCR-2 'python' split ships dirty rows where question is literally '-'; + # the real prompt is buried in r1_generation and not recoverable here. + if not query or query == '-': + continue + assistant = (row.get('r1_generation') or row.get('reasoning_content') + or row.get('solution') or row.get('output') or '').strip() + if not assistant: + continue + cot, response = _split_think(assistant) + if not cot: + continue + if not response: + response = (row.get('expected_solution') or row.get('answer') or '').strip() + if not response: + continue + out.append({ + 'id': _hash_id('opencode_reasoning2', f'{query}\n{response}'), + 'source': 'OpenCodeReasoning-2', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +CODEFORCES_COTS_REPO = 'ms://open-r1/codeforces-cots' + + +class CodeforcesCotsProcessor(Preprocessor): + """codeforces-cots → ``{id, source, query, cot, response}``. + + Schema: ``description``/``problem``, ``generation``/``solution`` (R1 + trace with ```` + final code). Algorithmic patterns at high + abstraction density. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('description') or row.get('problem') + or row.get('input') or row.get('question') or '').strip() + assistant = (row.get('generation') or row.get('solution') + or row.get('output') or '').strip() + if not query or not assistant: + continue + cot, response = _split_think(assistant) + if not cot or not response: + continue + out.append({ + 'id': _hash_id('codeforces_cots', f'{query}\n{response}'), + 'source': 'codeforces-cots', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +# =========================================================================== +# Cross-domain R1 +# =========================================================================== + +BESPOKE_STRATOS_REPO = 'ms://bespokelabs/Bespoke-Stratos-17k' + + +class BespokeStratosProcessor(Preprocessor): + """Bespoke-Stratos-17k → ``{id, source, query, cot, response}``. + + Schema: ``conversations`` (ShareGPT). Assistant content uses Sky-T1 + markers ``<|begin_of_thought|>...<|end_of_thought|>`` then + ``<|begin_of_solution|>...<|end_of_solution|>``. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query, assistant = _from_messages( + row.get('conversations') or row.get('messages')) + if not query or not assistant: + continue + cot, response = _split_sky_t1(assistant) + if not cot: + cot, response = _split_think(assistant) + if not cot or not response: + continue + out.append({ + 'id': _hash_id('bespoke_stratos', f'{query}\n{response}'), + 'source': 'Bespoke-Stratos-17k', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +DOLPHIN_R1_REPO = 'ms://AI-ModelScope/dolphin-r1' + + +class DolphinR1Processor(Preprocessor): + """dolphin-r1 → ``{id, source, query, cot, response}``. + + Schema (reasoning-deepseek subset): ``messages=[system, user]`` (no + assistant turn) + flat ``reasoning`` (CoT) + ``answer`` (final response) + + ``model``. Pull the user turn as query, ``reasoning``/``answer`` as + cot/response. Fallback to embedded ```` for legacy rows. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + msgs = row.get('messages') or row.get('conversations') + query = '' + if isinstance(msgs, list): + for msg in msgs: + if not isinstance(msg, dict): + continue + role = msg.get('role') or msg.get('from') or '' + content = msg.get('content') or msg.get('value') or '' + if role in ('user', 'human') and isinstance(content, str): + query = content.strip() + cot = (row.get('reasoning') or row.get('reasoning_content') or '').strip() + response = (row.get('answer') or '').strip() + if (not cot or not response) and isinstance(msgs, list): + _, assistant = _from_messages(msgs) + if assistant: + c2, r2 = _split_think(assistant) + if c2: + cot = cot or c2 + response = response or r2 or assistant + if not query or not cot or not response: + continue + out.append({ + 'id': _hash_id('dolphin_r1', f'{query}\n{response}'), + 'source': 'dolphin-r1', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +GLAIVE_REASONING_REPO = 'ms://glaiveai/reasoning-v1-20m' + + +class GlaiveReasoningProcessor(Preprocessor): + """reasoning-v1-20m → ``{id, source, query, cot, response}``. + + Schema: ``prompt``, ``response`` (R1 trace with ```` + answer). + Largest cross-domain corpus in the mix; downsample aggressively. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('prompt') or row.get('question') + or row.get('input') or '').strip() + assistant = (row.get('response') or row.get('output') + or row.get('answer') or '').strip() + if not query or not assistant: + continue + cot, response = _split_think(assistant) + if not cot or not response: + continue + out.append({ + 'id': _hash_id('glaive_reasoning', f'{query}\n{response}'), + 'source': 'reasoning-v1-20m', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +NATURAL_REASONING_REPO = 'ms://facebook/natural_reasoning' + + +class NaturalReasoningProcessor(Preprocessor): + """natural_reasoning → ``{id, source, query, cot, response}``. + + Schema: ``question`` + ``reference_answer`` + ``responses=[{response_model, + response}]``. The ``response`` field itself is the step-by-step CoT + (``## Step 1...## Step 2...``); there is no separate ``reasoning`` key. + Map ``responses[i].response`` → cot, ``reference_answer`` → response. + Rows with empty ``reference_answer`` (~18% per README) are dropped. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('question') or '').strip() + if not query: + continue + cot = '' + responses = row.get('responses') + if isinstance(responses, list): + for r in responses: + if not isinstance(r, dict): + continue + txt = (r.get('response') or r.get('reasoning') + or r.get('thinking') or r.get('answer') or '').strip() + if txt: + cot = txt + break + if not cot: + cot = (row.get('reasoning') or row.get('thinking') + or row.get('response') or '').strip() + response = (row.get('reference_answer') or row.get('answer') or '').strip() + if not cot or not response: + continue + out.append({ + 'id': _hash_id('natural_reasoning', f'{query}\n{response}'), + 'source': 'natural_reasoning', + 'query': query, + 'cot': cot, + 'response': response, + }) + return self.map_row_to_col(out) + + +# =========================================================================== +# Textbook-style — synthesize query from chapter heading; body → cot +# =========================================================================== + +COSMOPEDIA_REPO = 'ms://HuggingFaceTB/cosmopedia' + +class CosmopediaProcessor(Preprocessor): + """cosmopedia v1 → ``{id, source, query, cot, response}``. + + Schema: ``prompt`` (writing instruction), ``text`` (full chapter body), + ``format``/``audience``/``seed_data``. The subset is selected at load + time (``subset_name='auto_math_text'`` — densest math-textbook slice); + H2 chunking inside each row yields synthetic queries + (``Explain {heading}``) with the body placed into ``cot``. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + text = (row.get('text') or row.get('content') or '').strip() + if not text: + continue + for title, body in _chunk_by_h2(text): + # Heading-only "Explain: X" was 1-2 tokens and impossible to align + # with full-section cot. Promote the section's lead paragraph into + # the query so anchor carries real semantic content. + parts = body.split('\n\n', 1) + first_para = parts[0].strip() + rest = parts[1].strip() if len(parts) > 1 else '' + if len(first_para) < 256 or len(rest) < 256: + continue + query = f'{title}\n\n{first_para}' if title else first_para + out.append({ + 'id': _hash_id('cosmopedia', f'{title}\n{first_para[:200]}'), + 'source': 'cosmopedia-v1', + 'query': query, + 'cot': rest, + 'response': '', + }) + return self.map_row_to_col(out) + + +OMNI_MATH_REPO = 'ms://AI-ModelScope/Omni-MATH' + + +class OmniMathProcessor(Preprocessor): + """Omni-MATH → ``{id, source, query, cot, response}``. + + Schema: ``problem``, ``solution`` (full proof), ``answer``, ``domain``, + ``difficulty``. Olympiad-grade derivations — solution body → cot, + answer → response. + """ + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + out: List[Dict[str, Any]] = [] + for row in rows: + query = (row.get('problem') or row.get('question') or '').strip() + solution = (row.get('solution') or '').strip() + answer = (row.get('answer') or row.get('expected_answer') or '').strip() + if not query or not solution: + continue + out.append({ + 'id': _hash_id('omni_math', f'{query}\n{solution[:200]}'), + 'source': 'Omni-MATH', + 'query': query, + 'cot': solution, + 'response': answer, + }) + return self.map_row_to_col(out) + + +# =========================================================================== +# Mix configuration — base sizes target ≈3.6M total rows +# =========================================================================== + +_BASE_SIZES = { + 'open_math_reasoning': 600_000, + 'open_r1_math': 220_000, + 's1k': 1_000, + 'opencode_reasoning2': 500_000, + 'codeforces_cots': 200_000, + 'bespoke_stratos': 17_000, + 'dolphin_r1': 400_000, + 'glaive_reasoning': 800_000, + 'natural_reasoning': 200_000, + 'cosmopedia': 700_000, + 'omni_math': 4_000, +} + + +def _scaled_sizes(total: Optional[int]) -> Dict[str, int]: + if total is None or total <= 0: + return dict(_BASE_SIZES) + scale = total / sum(_BASE_SIZES.values()) + return {k: max(1, int(round(v * scale))) for k, v in _BASE_SIZES.items()} + + +def _build_dataset(total: Optional[int] = None, + load_from_cache_file: bool = True) -> Dataset: + sizes = _scaled_sizes(total) + dataset = Dataset() + + _register(dataset, OpenMathReasoningProcessor, + DatasetMeta(dataset_id=OPEN_MATH_REASONING_REPO, split='cot', + data_slice=range(sizes['open_math_reasoning'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, OpenR1MathProcessor, + DatasetMeta(dataset_id=OPEN_R1_MATH_REPO, split='train', + data_slice=range(sizes['open_r1_math'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, S1KProcessor, + DatasetMeta(dataset_id=S1K_REPO, split='train'), + load_from_cache_file=load_from_cache_file) + + _register(dataset, OpenCodeReasoning2Processor, + DatasetMeta(dataset_id=OPEN_CODE_REASONING_REPO, + subset_name='train', split='python', + data_slice=range(sizes['opencode_reasoning2'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, CodeforcesCotsProcessor, + DatasetMeta(dataset_id=CODEFORCES_COTS_REPO, + subset_name='solutions_w_editorials_decontaminated', + split='train', + data_slice=range(sizes['codeforces_cots'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, BespokeStratosProcessor, + DatasetMeta(dataset_id=BESPOKE_STRATOS_REPO, split='train'), + load_from_cache_file=load_from_cache_file) + + _register(dataset, DolphinR1Processor, + DatasetMeta(dataset_id=DOLPHIN_R1_REPO, + subset_name='reasoning-deepseek', split='train', + data_slice=range(sizes['dolphin_r1'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, GlaiveReasoningProcessor, + DatasetMeta(dataset_id=GLAIVE_REASONING_REPO, split='train', + data_slice=range(sizes['glaive_reasoning'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, NaturalReasoningProcessor, + DatasetMeta(dataset_id=NATURAL_REASONING_REPO, split='train', + data_slice=range(sizes['natural_reasoning'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, CosmopediaProcessor, + DatasetMeta(dataset_id=COSMOPEDIA_REPO, + subset_name='auto_math_text', split='train', + data_slice=range(sizes['cosmopedia'])), + load_from_cache_file=load_from_cache_file) + + _register(dataset, OmniMathProcessor, + DatasetMeta(dataset_id=OMNI_MATH_REPO, split='test'), + load_from_cache_file=load_from_cache_file) + + dataset.mix_dataset(False) + # Mix is concatenated in registration order; shuffle so the streaming + # consumer sees all sources interleaved instead of 600k OpenMathReasoning + # rows before it ever reaches code/textbook splits. + dataset.dataset = dataset.dataset.shuffle(seed=42) + return dataset + + +def get_dataset(total: Optional[int] = None, + dropped_log: Optional[str] = None, + load_from_cache_file: bool = True) -> Dataset: + """Build, convert to messages, and quality-filter the RAG-index corpus. + + Mirrors ``dataset_think.get_dataset``: identical signature + output + schema so ``build_thinking_rag_index`` consumes both modules unchanged. + """ + from twinkle_agentic.preprocessor import ( + DeadLoopFilter, + FixUnicodeFilter, + HardFilter, + MessageSanityFilter, + QualityPreprocessor, + RefuseFilter, + RemoveRepeatSentencesFilter, + TokenNumFilter, + TokenSoupFilter, + ) + + dataset = _build_dataset(total=total, load_from_cache_file=load_from_cache_file) + # Drop trivially-short queries (e.g. one-line math problems, OmniMath stubs) + # before message conversion — anchor side needs enough tokens to embed meaningfully. + dataset.dataset = dataset.dataset.filter( + lambda x: len((x.get('query') or '').strip()) >= 100, + num_proc=32, load_from_cache_file=load_from_cache_file) + dataset.map(ToMessagesProcessor(), remove_columns=['query', 'cot', 'response'], + load_from_cache_file=load_from_cache_file) + qp = QualityPreprocessor( + pipeline=[ + HardFilter(), + RefuseFilter(), + DeadLoopFilter(), + TokenSoupFilter(), + MessageSanityFilter(min_turns=1, max_msg_chars=200000), + FixUnicodeFilter(), + RemoveRepeatSentencesFilter(), + TokenNumFilter(max_num=32768), + ], + dropped_log_path=dropped_log or '', + ) + dataset.map(qp, num_proc=32, load_from_cache_file=load_from_cache_file) + return dataset + + +if __name__ == '__main__': + import os + dropped_log = os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'dropped_index.jsonl') + if os.path.exists(dropped_log): + os.remove(dropped_log) + dataset = get_dataset(load_from_cache_file=False) + print(len(dataset)) diff --git a/cookbook/exp/embedding/train_embedding_full_ddp.py b/cookbook/exp/embedding/train_embedding_full_ddp.py index 492e29aae..5db9f786c 100644 --- a/cookbook/exp/embedding/train_embedding_full_ddp.py +++ b/cookbook/exp/embedding/train_embedding_full_ddp.py @@ -1,14 +1,12 @@ -"""LoRA embedding training with online condenser self-improvement. +"""LoRA embedding training with online compression via frozen vLLM condenser. Architecture (8 GPUs total): - Ranks 0-3 (``model``): Trainable embedding model with LoRA, InfoNCE loss. - - Ranks 4-5 (``condenser_sampler``): Frozen vLLM condenser for online compression. - - Ranks 6-7 (``condenser_model``): Trainable condenser with LoRA for self-improvement. + - Ranks 4-7 (``condenser_sampler``): Frozen vLLM condenser for online compression. -When the condenser sampler truncates (stop_reason='length'), an external OpenAI- -compatible API produces the correct compression. The failure is logged as SFT -training data. A background thread retrains the condenser on accumulated failures -mixed with condense_300K, then syncs weights back to the sampler. +When the condenser sampler truncates or regresses to the legacy schema, an +external OpenAI-compatible API produces the correct compression. The failure is +logged to failures.jsonl for offline SFT data regeneration. Launch: python cookbook/exp/train_embedding_lora_ddp.py @@ -19,6 +17,7 @@ import re import sys import threading +import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, Dict, List, Literal, Optional @@ -27,7 +26,6 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger -from twinkle.checkpoint_engine import CheckpointEngineManager from twinkle.data_format import SamplingParams from twinkle.dataloader import DataLoader from twinkle.loss import InfonceLoss @@ -35,12 +33,13 @@ from twinkle.model import TransformersModel from twinkle.processor import InputProcessor from twinkle.sampler import vLLMSampler -from twinkle.template import Template +from twinkle.template import Qwen3_5Template, Template from twinkle.utils.parallel import PosixFileLock from twinkle_agentic.protocol.openai import OpenAI as OpenAIClient sys.path.insert(0, str(Path(__file__).resolve().parent)) -from dataset_think import get_dataset # noqa: E402 +from dataset_think import get_dataset as get_dataset_think # noqa: E402 +from dataset_index import get_dataset as get_dataset_index # noqa: E402 logger = get_logger() @@ -54,29 +53,33 @@ # -- GPU placement (8 total) -------------------------------------------------- MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -CONDENSER_SAMPLER_GPUS = int(os.environ.get('CONDENSER_SAMPLER_GPUS', 2)) -CONDENSER_MODEL_GPUS = int(os.environ.get('CONDENSER_MODEL_GPUS', 2)) -NUM_GPUS = MODEL_GPUS + CONDENSER_SAMPLER_GPUS + CONDENSER_MODEL_GPUS +CONDENSER_SAMPLER_GPUS = int(os.environ.get('CONDENSER_SAMPLER_GPUS', 4)) +NUM_GPUS = MODEL_GPUS + CONDENSER_SAMPLER_GPUS # -- Embedding training hyper-params ------------------------------------------ EMB_MAX_LENGTH = 8192 HARD_NEGATIVES = None -TEMPERATURE = 0.03 +# 0.07 keeps gradient on diag pairs until cosine clears ~0.75; 0.03 saturated near 0.40. +TEMPERATURE = 0.07 BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 32)) -LEARNING_RATE = 1.5e-6 +LEARNING_RATE = 1e-5 GRADIENT_ACCUMULATION_STEPS = 1 LOG_INTERVAL = 2 -SAVE_INTERVAL = 4000 -NUM_EPOCHS = 2 +SAVE_INTERVAL = 2000 +NUM_EPOCHS = 1 TOTAL_SAMPLES: Optional[int] = None +# Post-build caps on each loader (None = no cap). Applied via .select() before mix. +THINK_CAP: Optional[int] = 400_000 +INDEX_CAP: Optional[int] = 400_000 +MIX_SHUFFLE_SEED = 42 # -- Resume from checkpoint --------------------------------------------------- -RESUME_CHECKPOINT = os.environ.get( - 'RESUME_CHECKPOINT', - './output/embedding_lora_transformers/step_16000') -RESUME_STEP = int(os.environ.get('RESUME_STEP', 16000)) +# Empty by default — build_model falls back to MODEL_ID (the published emb model). +# Set both to point at a local in-progress run only when resuming the *same* schedule. +RESUME_CHECKPOINT = os.environ.get('RESUME_CHECKPOINT', '') +RESUME_STEP = int(os.environ.get('RESUME_STEP', 0)) # -- Online-compression knobs ------------------------------------------------- # Below this length, condenser fabricates content for open-ended short prompts; @@ -87,16 +90,18 @@ COMPRESS_TOP_P = 0.5 COMPRESS_MAX_MODEL_LEN = 32768 +# How many BATCH_SIZE chunks to fetch and compress in one vLLM call. +PREFETCH_BATCH_MULTIPLIER = int(os.environ.get('PREFETCH_BATCH_MULTIPLIER', 8)) + # -- OpenAI API fallback for truncated compressions --------------------------- COMPRESS_API_KEY = os.environ.get('COMPRESS_API_KEY', '') COMPRESS_BASE_URL = os.environ.get('COMPRESS_BASE_URL', 'https://dashscope.aliyuncs.com/compatible-mode/v1') COMPRESS_MODEL = os.environ.get('COMPRESS_MODEL', 'qwen3.7-max') - -# -- Condenser retraining knobs ----------------------------------------------- -CONDENSER_DATASET_ID = 'ms://twinkle-kit/condense_300K' -CONDENSER_RETRAIN_SAMPLES = 128 -CONDENSER_RETRAIN_EPOCHS = 3 -CONDENSER_RETRAIN_LR = 1e-5 +# Minimum gap between API calls (seconds); bounds dashscope qps under provider limits. +API_MIN_INTERVAL = float(os.environ.get('API_MIN_INTERVAL', 0.1)) +API_CONCURRENCY = int(os.environ.get('API_CONCURRENCY', 8)) +# vLLM sampler timeout (seconds); if a sample() call exceeds this, fall back to API. +SAMPLER_TIMEOUT = float(os.environ.get('SAMPLER_TIMEOUT', 300)) # -- Output paths ------------------------------------------------------------- OUTPUT_DIR = f'./output/embedding_lora_{BACKEND}' @@ -204,6 +209,17 @@ _sample_counter = 0 _sample_counter_lock = threading.Lock() +_api_throttle_lock = threading.Lock() +_api_last_call = [0.0] + + +def _api_throttle(): + with _api_throttle_lock: + gap = time.monotonic() - _api_last_call[0] + if gap < API_MIN_INTERVAL: + time.sleep(API_MIN_INTERVAL - gap) + _api_last_call[0] = time.monotonic() + def _next_sample_id() -> int: global _sample_counter @@ -314,11 +330,40 @@ def save_checkpoint(model, name: str): # Compression prompt building # ============================================================================= +# Hard-templated hints: the condenser SFT prior maps `Skill` to the legacy +# `Use when: / numbered steps / Output:` skeleton on long inputs; embedding the +# exact 4-line body template + explicit negative constraints is the only way to +# override it deterministically across query and cot sides. EMBED_QUERY_Q = ( + 'Summarize this query for retrieval. ' + 'The body of ## Summary MUST follow this EXACT 4-line template — ' + 'do NOT emit "Use when:", numbered procedure steps, or "Output:":\n' + 'Topic: \n' + 'Problem: \n' + 'Skill: \n' + 'Knowledge: \n' + 'Then emit the mandatory ## More section as usual. ' + 'Topic must name the specific pattern, never generic labels.') +EMBED_QUERY_COT = ( + 'Summarize this reasoning trace for retrieval. ' + 'The body of ## Summary MUST follow this EXACT 4-line template — ' + 'do NOT emit "Use when:", numbered procedure steps, or "Output:":\n' + 'Topic: \n' + 'Problem: \n' + 'Skill: \n' + 'Knowledge: \n' + 'Then emit the mandatory ## More section as usual. ' + 'Topic must name the specific pattern, never generic labels.') + +# Legacy schema (Use when: / numbered steps / Output:) — mixed in 50/50 with the +# new schema to expose the embedder to schema-invariant semantic alignment. +# Both query and cot of the SAME pair always use the SAME schema; cross-schema +# anchors and positives would re-introduce the schema asymmetry we just fixed. +EMBED_QUERY_Q_LEGACY = ( 'What problem does this passage address, and what skill or method is needed? ' 'Topic must name the specific pattern, never generic labels. ' 'Compress into a retrieval-friendly need description.') -EMBED_QUERY_COT = ( +EMBED_QUERY_COT_LEGACY = ( 'Extract the reusable skill: trigger conditions, key steps, and expected output. ' 'Topic names the method/pattern; format as "Use when: ...", numbered steps, ' '"Output: ...". Compress into a standardized procedure for retrieval.') @@ -342,45 +387,59 @@ def _extract_query_cot(row: Dict[str, Any]): def _build_compress_prompts(rows: List[Dict[str, Any]]) -> tuple: """Build prompts for compressing both query and cot per row. - Returns (prompts, valid_indices, raw_pairs, prompt_queries, passthrough) where: + Returns (prompts, valid_indices, raw_pairs, prompt_queries, passthrough, schemas) + where: - prompts: flat-interleaved [query_0, cot_0, query_1, cot_1, ...]; ``None`` means passthrough (use raw text directly, do not call sampler) - valid_indices: which rows passed the min-length filter - raw_pairs: [(query, cot), ...] - prompt_queries: the query string used for each prompt (for failure logging) - passthrough: parallel to prompts; non-None text means "use this verbatim as qc" + - schemas: parallel to prompts; 'new' or 'legacy', drives validator branch """ prompts: List[Optional[Dict[str, Any]]] = [] valid_indices: List[int] = [] raw_pairs: List[tuple] = [] prompt_queries: List[str] = [] passthrough: List[Optional[str]] = [] + schemas: List[str] = [] + # Conservative char budget: 32768 max_length - 8192 gen - ~2k prompt overhead = ~22k tokens. + # 30k cap bounds vLLM batch latency (vLLM batches by max prompt length). + _MAX_COT_CHARS = 30_000 for i, row in enumerate(rows): query, cot = _extract_query_cot(row) if not query or len(cot) < MIN_TEXT_CHARS: continue + if len(cot) > _MAX_COT_CHARS: + continue valid_indices.append(i) raw_pairs.append((query, cot)) + # 50/50 schema mix; same schema for query+cot of one pair to keep alignment. + schema = 'legacy' if (i % 2 == 0) else 'new' + q_hint = EMBED_QUERY_Q_LEGACY if schema == 'legacy' else EMBED_QUERY_Q + c_hint = EMBED_QUERY_COT_LEGACY if schema == 'legacy' else EMBED_QUERY_COT # Short query bypasses condenser to avoid skeleton-induced hallucination. if len(query) < MIN_TEXT_CHARS: prompts.append(None) passthrough.append(query) else: - user = COMPRESS_USER.format(query=EMBED_QUERY_Q, text=query) + user = COMPRESS_USER.format(query=q_hint, text=query) prompts.append({'messages': [ {'role': 'system', 'content': COMPRESS_SYSTEM}, {'role': 'user', 'content': user}, ]}) passthrough.append(None) - prompt_queries.append(EMBED_QUERY_Q) - user = COMPRESS_USER.format(query=EMBED_QUERY_COT, text=cot) + prompt_queries.append(q_hint) + schemas.append(schema) + user = COMPRESS_USER.format(query=c_hint, text=cot) prompts.append({'messages': [ {'role': 'system', 'content': COMPRESS_SYSTEM}, {'role': 'user', 'content': user}, ]}) - prompt_queries.append(EMBED_QUERY_COT) + prompt_queries.append(c_hint) passthrough.append(None) - return prompts, valid_indices, raw_pairs, prompt_queries, passthrough + schemas.append(schema) + return prompts, valid_indices, raw_pairs, prompt_queries, passthrough, schemas def _get_first_feature(decoded_text: str, template: Template, role: str) -> Optional[Dict[str, Any]]: @@ -405,13 +464,24 @@ def _get_first_feature(decoded_text: str, template: Template, role: str) -> Opti # OpenAI API fallback # ============================================================================= -def _is_truncated_compression(text: str) -> bool: - """Detect structurally incomplete output that vLLM may report as stop_reason='stop'. +_LEGACY_USE_WHEN_RE = re.compile(r'(?im)^\s*Use when\s*:') +_SCHEMA_MARKERS = ('Problem:', 'Skill:', 'Knowledge:') + + +def _is_truncated_compression(text: str, schema: str = 'new') -> bool: + """Reject structurally incomplete OR schema-regressed condenser output. - The condenser sometimes emits a chat-template token mid-skeleton (which we then - strip), so the visible text ends mid-sentence even though stop_reason!='length'. - The COMPRESS_SYSTEM skeleton mandates a `## More` section ending in a bullet list; - its absence is an unambiguous truncation signal. + Triggers API fallback when the vLLM output: + * lacks ``## Summary`` / ``## More``, + * has an empty or unterminated ``## More`` bullet list, or + * (schema='new' only) regresses to the legacy ``Use when: / numbered-steps / + Output:`` skeleton instead of the mandated Problem/Skill/Knowledge 4-line + body — the dominant cot-side failure mode that drives sim < 0.45 drops on + the RAG index. + + For schema='legacy', body markers are intentionally NOT enforced: the legacy + template legitimately emits ``Use when:`` and the SFT prior already produces + that shape natively, so only structural completeness is checked. """ if not text or not text.strip(): return True @@ -423,11 +493,18 @@ def _is_truncated_compression(text: str) -> bool: last_line = after_more.splitlines()[-1].strip() if not (last_line.startswith('-') or last_line.endswith(')')): return True + if schema == 'new': + summary_body = text.split('## Summary', 1)[1].split('## More', 1)[0] + if _LEGACY_USE_WHEN_RE.search(summary_body): + return True + if not all(marker in summary_body for marker in _SCHEMA_MARKERS): + return True return False def _api_compress(api_client: OpenAIClient, prompt: Dict[str, Any]) -> Optional[str]: """Call external API to compress when vLLM truncates.""" + _api_throttle() trajectory = {'messages': prompt['messages']} # Cap max_tokens to leave ample prompt headroom inside the API model context. sp = SamplingParams(temperature=0.2, max_tokens=8192) @@ -446,61 +523,12 @@ def _api_compress(api_client: OpenAIClient, prompt: Dict[str, Any]) -> Optional[ return content -# ============================================================================= -# Condenser Retrainer (background thread) -# ============================================================================= - -class CondenserRetrainer: - """Async condenser self-improvement: retrains from failures, syncs to sampler.""" - - def __init__(self, condenser_model, ckpt_manager: CheckpointEngineManager, - condenser_sampler): - self._model = condenser_model - self._ckpt_manager = ckpt_manager - self._sampler = condenser_sampler - self._signal = threading.Event() - self._stop = threading.Event() - self._thread = threading.Thread(target=self._loop, daemon=True) - self._condense_300k_cache = None - self._retrain_count = 0 - # Prevents sample() and sync_weights() from running concurrently - self.sampler_lock = threading.Lock() - - def start(self): - self._thread.start() - - def stop(self): - self._stop.set() - self._signal.set() - self._thread.join(timeout=10) - - def notify_failure(self): - self._signal.set() - - def _loop(self): - while not self._stop.is_set(): - self._signal.wait(timeout=60) - if self._stop.is_set(): - break - if not self._signal.is_set(): - continue - self._signal.clear() - try: - self._retrain_and_sync() - except Exception as exc: - logger.error(f'[condenser_retrain] crashed: {exc}') - - def _retrain_and_sync(self): - # Retrain + sync temporarily disabled; failures.jsonl is written directly by _log_failure. - pass - - # ============================================================================= # Main training # ============================================================================= def train(): - # -------- Device groups (3 groups) ---------------------------------------- + # -------- Device groups (2 groups) ---------------------------------------- device_groups = [ DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), @@ -508,22 +536,31 @@ def train(): DeviceGroup(name='condenser_sampler', ranks=list(range(MODEL_GPUS, MODEL_GPUS + CONDENSER_SAMPLER_GPUS)), device_type='GPU'), - DeviceGroup(name='condenser_model', - ranks=list(range(MODEL_GPUS + CONDENSER_SAMPLER_GPUS, NUM_GPUS)), - device_type='GPU'), ] model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) condenser_sampler_mesh = DeviceMesh.from_sizes( world_size=CONDENSER_SAMPLER_GPUS, dp_size=CONDENSER_SAMPLER_GPUS) - condenser_model_mesh = DeviceMesh.from_sizes( - world_size=CONDENSER_MODEL_GPUS, dp_size=1, fsdp_size=CONDENSER_MODEL_GPUS) twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups) # -------- Data ----------------------------------------------------------- - dataset = get_dataset(total=TOTAL_SAMPLES, load_from_cache_file=True) - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True) - total_forward_steps = len(dataloader) * NUM_EPOCHS + dataset = get_dataset_think(total=TOTAL_SAMPLES, load_from_cache_file=True) + if THINK_CAP and len(dataset.dataset) > THINK_CAP: + dataset.dataset = dataset.dataset.select(range(THINK_CAP)) + if INDEX_CAP != 0: + from datasets import concatenate_datasets + ds_index = get_dataset_index(total=None, load_from_cache_file=True) + if INDEX_CAP and len(ds_index.dataset) > INDEX_CAP: + ds_index.dataset = ds_index.dataset.select(range(INDEX_CAP)) + n_think = len(dataset.dataset) + n_index = len(ds_index.dataset) + # Both loaders emit identical {id, source, messages} schema post-QP. + dataset.dataset = concatenate_datasets( + [dataset.dataset, ds_index.dataset]).shuffle(seed=MIX_SHUFFLE_SEED) + logger.info(f'[mix] think={n_think} + index={n_index} → total={len(dataset.dataset)}') + _mega_batch_size = BATCH_SIZE * PREFETCH_BATCH_MULTIPLIER + dataloader = DataLoader(dataset=dataset, batch_size=_mega_batch_size, shuffle=True) + total_forward_steps = len(dataloader) * PREFETCH_BATCH_MULTIPLIER * NUM_EPOCHS optimizer_steps = total_forward_steps // GRADIENT_ACCUMULATION_STEPS # -------- Embedding model (4 GPU) ---------------------------------------- @@ -534,10 +571,10 @@ def train(): setup_optimizer(model, optimizer_steps) model.add_metric(EmbeddingMetric, is_training=True) - # -------- Condenser sampler (2 GPU, vLLM) -------------------------------- - emb_template = Template(model_id=MODEL_ID, max_length=EMB_MAX_LENGTH, enable_thinking=False) + # -------- Condenser sampler (4 GPU, vLLM) -------------------------------- + emb_template = Qwen3_5Template(model_id=MODEL_ID, max_length=EMB_MAX_LENGTH, enable_thinking=False) # Special tokens come from the condenser tokenizer because the leak we strip is in its decoded output. - condenser_template = Template(model_id=CONDENSE_MODEL_ID, max_length=DATASET_MAX_TOKENS, + condenser_template = Qwen3_5Template(model_id=CONDENSE_MODEL_ID, max_length=DATASET_MAX_TOKENS, enable_thinking=False) _special_tokens = set(condenser_template.tokenizer.all_special_tokens) condenser_sampler = vLLMSampler( @@ -559,23 +596,32 @@ def train(): num_samples=1, ) - # -------- Condenser model (2 GPU, trainable full-param) ------------------- - condenser_model = TransformersModel( - model_id=CONDENSE_MODEL_ID, - device_mesh=condenser_model_mesh, - remote_group='condenser_model', - ) - condenser_model.set_optimizer(optimizer_cls='AdamW', lr=CONDENSER_RETRAIN_LR) - - # -------- CheckpointEngineManager: condenser_model → condenser_sampler --- - condenser_ckpt_manager = CheckpointEngineManager( - model=condenser_model, sampler=condenser_sampler) - condenser_ckpt_manager.sync_weights() - - # -------- Background retrainer ------------------------------------------- - retrainer = CondenserRetrainer(condenser_model, condenser_ckpt_manager, - condenser_sampler) - retrainer.start() + condenser_sampler._ray_get_timeout = SAMPLER_TIMEOUT + _sampler_epoch = 0 + + def _rebuild_sampler(): + """Kill stuck actors and recreate the vLLM sampler from scratch.""" + nonlocal condenser_sampler, _sampler_epoch + import ray + for actor in getattr(condenser_sampler, '_actors', []): + try: + ray.kill(actor, no_restart=True) + except Exception: + pass + logger.warning('[sampler] killed stuck actors, recreating sampler \u2026') + new = vLLMSampler( + model_id=CONDENSE_MODEL_ID, + engine_args={'gpu_memory_utilization': 0.8, 'max_model_len': COMPRESS_MAX_MODEL_LEN}, + device_mesh=condenser_sampler_mesh, + remote_group='condenser_sampler', + ) + new.set_template( + TEMPLATE_NAME, model_id=CONDENSE_MODEL_ID, enable_thinking=False, + truncation_strategy='delete', max_length=DATASET_MAX_TOKENS) + new._ray_get_timeout = SAMPLER_TIMEOUT + condenser_sampler = new + _sampler_epoch += 1 + logger.warning('[sampler] sampler rebuilt successfully') # -------- OpenAI API client for fallback --------------------------------- api_client = OpenAIClient( @@ -606,28 +652,41 @@ def train(): # -------- Train loop ----------------------------------------------------- def _sample_batch(raw_batch): """Compress via vLLM sampler; fall back to API on truncation.""" - compress_prompts, valid_indices, raw_pairs, prompt_queries, passthrough = \ + _t_enter = time.monotonic() + compress_prompts, valid_indices, raw_pairs, prompt_queries, passthrough, schemas = \ _build_compress_prompts(raw_batch) - if not compress_prompts: + _t_build = time.monotonic() + if len(compress_prompts) < 4: return None # Only submit non-passthrough prompts to the sampler. sampler_input = [p for p in compress_prompts if p is not None] sampler_pos = [ri for ri, p in enumerate(compress_prompts) if p is not None] if sampler_input: - with retrainer.sampler_lock: + try: sampler_responses = condenser_sampler.sample(sampler_input, compress_params) + except Exception as exc: + logger.warning(f'[sampler] error \u2192 API fallback: {exc}') + sampler_responses = [None] * len(sampler_input) + if 'Timeout' in type(exc).__name__: + try: + _rebuild_sampler() + except Exception as re_exc: + logger.error(f'[sampler] rebuild failed: {re_exc}') else: sampler_responses = [] + _t_sample = time.monotonic() + responses = [None] * len(compress_prompts) for resp, pos in zip(sampler_responses, sampler_pos): responses[pos] = resp # Extract decoded texts; detect truncations and fall back to API - decoded_texts: List[str] = [] + decoded_texts: List[Optional[str]] = [None] * len(compress_prompts) + fallback_indices: List[int] = [] for ri in range(len(compress_prompts)): if passthrough[ri] is not None: - decoded_texts.append(passthrough[ri]) + decoded_texts[ri] = passthrough[ri] continue resp = responses[ri] seq = resp.sequences[0] if resp and resp.sequences else None @@ -638,27 +697,33 @@ def _sample_batch(raw_batch): text = text.replace(tok, '') text = text.rstrip() - # Premature-EOS: model emits chat-template token mid-skeleton, vLLM reports - # stop_reason='stop' but the stripped text is structurally incomplete. needs_fallback = (not seq or seq.stop_reason == 'length' - or _is_truncated_compression(text)) + or _is_truncated_compression(text, schemas[ri])) if not needs_fallback: - decoded_texts.append(text) - continue - - api_result = _api_compress(api_client, compress_prompts[ri]) - # Skip logging when the API itself produced truncated output: an incomplete - # gold answer would teach the condenser to imitate broken outputs. - if api_result and not _is_truncated_compression(api_result): - decoded_texts.append(api_result) - pair_idx = ri // 2 - q_raw, c_raw = raw_pairs[pair_idx] - source_text = q_raw if ri % 2 == 0 else c_raw - _log_failure(source_text, prompt_queries[ri], api_result, - valid_indices[pair_idx]) - retrainer.notify_failure() + decoded_texts[ri] = text else: - decoded_texts.append('') + fallback_indices.append(ri) + + _api_calls = len(fallback_indices) + if fallback_indices: + from concurrent.futures import as_completed + api_futures = {} + with ThreadPoolExecutor(max_workers=API_CONCURRENCY) as api_pool: + for ri in fallback_indices: + api_futures[api_pool.submit(_api_compress, api_client, compress_prompts[ri])] = ri + for fut in as_completed(api_futures): + ri = api_futures[fut] + api_result = fut.result() + if api_result and not _is_truncated_compression(api_result, schemas[ri]): + decoded_texts[ri] = api_result + pair_idx = ri // 2 + q_raw, c_raw = raw_pairs[pair_idx] + source_text = q_raw if ri % 2 == 0 else c_raw + _log_failure(source_text, prompt_queries[ri], api_result, + valid_indices[pair_idx]) + else: + decoded_texts[ri] = '' + _t_api = time.monotonic() # Build embedding features from decoded texts emb_features: List[Dict[str, Any]] = [] @@ -673,69 +738,96 @@ def _sample_batch(raw_batch): if feat_q and feat_c: emb_features.append(feat_q) emb_features.append(feat_c) + _t_feat = time.monotonic() - if len(emb_features) < 4: - return None - return emb_features + logger.info( + f'[prefetch] prompts={len(sampler_input)} api={_api_calls} feats={len(emb_features)} | ' + f'build={_t_build - _t_enter:.1f}s ' + f'vllm={_t_sample - _t_build:.1f}s ' + f'api={_t_api - _t_sample:.1f}s feat={_t_feat - _t_api:.1f}s ' + f'total={_t_feat - _t_enter:.1f}s') + + _target = BATCH_SIZE * 2 + minibatches = [emb_features[i:i + _target] for i in range(0, len(emb_features), _target)] + minibatches = [mb for mb in minibatches if len(mb) >= 4] + return minibatches if minibatches else None cur_step = RESUME_STEP - # Compute which epoch and how many batches to skip within that epoch _batches_per_epoch = len(dataloader) - _start_epoch = cur_step // _batches_per_epoch if cur_step > 0 else 0 - _skip_batches_in_epoch = cur_step - _start_epoch * _batches_per_epoch if cur_step > 0 else 0 + _steps_per_mega = PREFETCH_BATCH_MULTIPLIER + _start_epoch = cur_step // (_batches_per_epoch * _steps_per_mega) if cur_step > 0 else 0 + _skip_batches_in_epoch = max(0, cur_step // _steps_per_mega - _start_epoch * _batches_per_epoch) + + _ema_prefetch = 0.0 + _ema_train = 0.0 + _ema_alpha = 0.1 prefetch_executor = ThreadPoolExecutor(max_workers=1) for epoch in range(_start_epoch, NUM_EPOCHS): - # Skip consumed samples for the resume epoch (shuffle order won't match - # exactly, but the correct number of samples is skipped). if _skip_batches_in_epoch > 0: - dataloader.skip_consumed_samples(_skip_batches_in_epoch * BATCH_SIZE) + dataloader.skip_consumed_samples(_skip_batches_in_epoch * _mega_batch_size) batch_iter = iter(dataloader) - # Reset skip after first resumed epoch _skip_batches_in_epoch = 0 - prefetch_future = None - first_batch = next(batch_iter, None) - if first_batch is not None: - prefetch_future = prefetch_executor.submit(_sample_batch, first_batch) - for raw_batch in batch_iter: - emb_features = prefetch_future.result() if prefetch_future else None - prefetch_future = prefetch_executor.submit(_sample_batch, raw_batch) + first = next(batch_iter, None) + future = prefetch_executor.submit(_sample_batch, first) if first else None - if emb_features is None: + for raw_mega_batch in batch_iter: + t0 = time.monotonic() + minibatches = future.result() if future else None + t_prefetch = time.monotonic() - t0 + future = prefetch_executor.submit(_sample_batch, raw_mega_batch) + + if not minibatches: continue - model.forward_backward(inputs=emb_features, task='embedding') - model.clip_grad_and_step(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) - cur_step += 1 - - if cur_step % LOG_INTERVAL == 0: - metric = model.calculate_metric(is_training=True) - logger.info( - f'Epoch {epoch} Step {cur_step}/{total_forward_steps}, metric: {metric}') - log_dict = {} - for k, v in metric.items(): - if not v: - continue - try: - log_dict[k] = float(v) - except (ValueError, TypeError): - pass - log_dict['epoch'] = epoch - swanlab.log(log_dict, step=cur_step) - if cur_step % SAVE_INTERVAL == 0: - save_checkpoint(model, f'step_{cur_step}') - - # # Drain last prefetched batch - # if prefetch_future is not None: - # emb_features = prefetch_future.result() - # if emb_features is not None: - # model.forward_backward(inputs=emb_features, task='embedding') - # model.clip_grad_and_step() - # cur_step += 1 + for mb in minibatches: + t1 = time.monotonic() + model.forward_backward(inputs=mb, task='embedding') + model.clip_grad_and_step(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + t_train = time.monotonic() - t1 + cur_step += 1 + + _ema_prefetch = _ema_alpha * t_prefetch + (1 - _ema_alpha) * _ema_prefetch if cur_step > RESUME_STEP + 1 else t_prefetch + _ema_train = _ema_alpha * t_train + (1 - _ema_alpha) * _ema_train if cur_step > RESUME_STEP + 1 else t_train + + if cur_step % LOG_INTERVAL == 0: + metric = model.calculate_metric(is_training=True) + _bottleneck = 'PREFETCH' if _ema_prefetch > _ema_train else 'TRAIN' + logger.info( + f'Epoch {epoch} Step {cur_step}/{total_forward_steps}, metric: {metric} | ' + f'prefetch={t_prefetch:.1f}s(ema {_ema_prefetch:.1f}) ' + f'train={t_train:.1f}s(ema {_ema_train:.1f}) ' + f'bottleneck={_bottleneck}') + log_dict = {} + for k, v in metric.items(): + if not v: + continue + try: + log_dict[k] = float(v) + except (ValueError, TypeError): + pass + log_dict['epoch'] = epoch + log_dict['prefetch_sec'] = round(t_prefetch, 2) + log_dict['train_sec'] = round(t_train, 2) + swanlab.log(log_dict, step=cur_step) + if cur_step % SAVE_INTERVAL == 0: + save_checkpoint(model, f'step_{cur_step}') + t_prefetch = 0.0 + + # Drain final mega-batch + if future: + minibatches = future.result() + future = None + if minibatches: + for mb in minibatches: + model.forward_backward(inputs=mb, task='embedding') + model.clip_grad_and_step(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + cur_step += 1 + if cur_step % SAVE_INTERVAL == 0: + save_checkpoint(model, f'step_{cur_step}') prefetch_executor.shutdown(wait=False) - retrainer.stop() save_checkpoint(model, 'last-checkpoint') diff --git a/cookbook/megatron/tp.py b/cookbook/megatron/tp.py index 650cf67b6..13c5ccfb1 100644 --- a/cookbook/megatron/tp.py +++ b/cookbook/megatron/tp.py @@ -5,42 +5,26 @@ import twinkle from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import MegatronModel from twinkle.preprocessor import SelfCognitionProcessor logger = get_logger() +args = CLI.from_args() -MODEL_ID = 'ms://Qwen/Qwen3.5-4B' -DATASET_ID = 'ms://swift/self-cognition' -TEMPLATE_NAME = 'Qwen3_5Template' -MODEL_NAME = 'twinkle大模型' -MODEL_AUTHOR = 'ModelScope社区' -DP_SIZE = 2 -TP_SIZE = 2 -PP_SIZE = 2 -BATCH_SIZE = 16 -LEARNING_RATE = 1e-4 -LOG_INTERVAL = 5 -EVAL_INTERVAL = 20 -EVAL_SAMPLES = 100 -TRAIN_SAMPLES = 1000 - -OUTPUT_DIR = './output/megatron_tp' -RESUME_FROM_CHECKPOINT = None -RESUME_ONLY_MODEL = False -IGNORE_DATA_SKIP = False -ADAPTER_NAME = 'default' - -device_mesh = DeviceMesh.from_sizes(dp_size=DP_SIZE, tp_size=TP_SIZE, pp_size=PP_SIZE) -twinkle.initialize(mode='local', global_device_mesh=device_mesh) +device_mesh = DeviceMesh.from_sizes(dp_size=args.infra.dp_size, tp_size=args.infra.tp_size, pp_size=args.infra.pp_size) +twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh) def build_dataset(num_samples: int) -> Dataset: - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples))) - dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR)) + dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(num_samples))) + dataset.set_template(args.template.template_cls, model_id=args.model.model_id) + dataset.map(SelfCognitionProcessor( + args.extra.get('model_name', 'twinkle大模型'), + args.extra.get('model_author', 'ModelScope社区'), + )) dataset.encode() return dataset @@ -48,42 +32,45 @@ def build_dataset(num_samples: int) -> Dataset: def save_checkpoint(model: MegatronModel, checkpoint_name: str, dataloader: DataLoader): model.save( checkpoint_name, - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, - save_optimizer=True, + output_dir=args.training.output_dir, + adapter_name=args.lora.adapter_name, + save_optimizer=args.checkpoint.save_optimizer, consumed_train_samples=dataloader.get_state()['consumed_train_samples'], ) def evaluate(model): - dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE) + eval_samples = args.training.eval_samples or 100 + dataloader = DataLoader(dataset=build_dataset(eval_samples), batch_size=args.training.batch_size) for batch in tqdm(dataloader): model.forward_only(inputs=batch) return model.calculate_metric(is_training=False) def train(): - dataset = build_dataset(TRAIN_SAMPLES) - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + train_samples = args.training.train_samples or 1000 + dataset = build_dataset(train_samples) + dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) - model = MegatronModel(model_id=MODEL_ID) + model = MegatronModel(model_id=args.model.model_id) - lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + lora_config = LoraConfig(**args.get_lora_args()) # Comment this to use full-parameter training - model.add_adapter_to_model(ADAPTER_NAME, lora_config) - model.set_optimizer(optimizer_cls='default', lr=LEARNING_RATE) - model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader)) + model.add_adapter_to_model(args.lora.adapter_name, lora_config) + model.set_optimizer(optimizer_cls='default', lr=args.optimizer.learning_rate) + model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=args.scheduler.num_warmup_steps, + lr_decay_steps=len(dataloader)) start_step = 0 - if RESUME_FROM_CHECKPOINT: - checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + if args.training.resume_from_checkpoint: + checkpoint_path = Path(args.training.resume_from_checkpoint).expanduser().resolve() kwargs = {} - if ADAPTER_NAME: - kwargs['adapter_name'] = ADAPTER_NAME + if args.lora.adapter_name: + kwargs['adapter_name'] = args.lora.adapter_name progress = model.resume_from_checkpoint( - str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) - if not IGNORE_DATA_SKIP: + str(checkpoint_path), resume_only_model=args.training.resume_only_model, **kwargs) + if not args.training.ignore_data_skip: dataloader.resume_from_checkpoint(progress['consumed_train_samples']) start_step = progress['cur_step'] @@ -92,14 +79,15 @@ def train(): logger.info(f'Total steps: {len(dataloader)}') best_loss = float('inf') + eval_interval = args.training.eval_interval or 20 for step, batch in enumerate(dataloader, start=start_step): model.forward_backward(inputs=batch) model.clip_grad_and_step() - if step % LOG_INTERVAL == 0: + if step % args.training.log_interval == 0: metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % EVAL_INTERVAL == 0: + if step > 0 and step % eval_interval == 0: metrics = evaluate(model) logger.info(f'Eval metric: {metrics}') metrics['step'] = step diff --git a/cookbook/megatron/tp.sh b/cookbook/megatron/tp.sh index 5516130e3..789c54379 100644 --- a/cookbook/megatron/tp.sh +++ b/cookbook/megatron/tp.sh @@ -1 +1,23 @@ -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 tp.py +#!/usr/bin/env bash +set -euo pipefail + +# Megatron TP + LoRA training. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash tp.sh --model-id ms://Qwen/Qwen3.5-4B --tp-size 4 + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + torchrun --nproc_per_node=8 tp.py \ + --model-id ms://Qwen/Qwen3.5-4B \ + --dataset-id ms://swift/self-cognition \ + --template-cls Qwen3_5Template \ + --dp-size 4 \ + --tp-size 2 \ + --batch-size 8 \ + --lr 1e-4 \ + --train-samples 1000 \ + --log-interval 10 \ + --eval-interval 20 \ + --output-dir ./output/megatron_tp \ + --model-name twinkle大模型 \ + --model-author ModelScope社区 \ + "$@" diff --git a/cookbook/megatron/tp_moe.py b/cookbook/megatron/tp_moe.py index a13b0e58a..11e2c7d84 100644 --- a/cookbook/megatron/tp_moe.py +++ b/cookbook/megatron/tp_moe.py @@ -1,29 +1,38 @@ -import os from peft import LoraConfig from tqdm import tqdm import twinkle from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import MegatronModel from twinkle.preprocessor import SelfCognitionProcessor +logger = get_logger() +args = CLI.from_args() + # Construct a device_mesh, tp=pp=ep=dp=2 -device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2, ep_size=2, sequence_parallel=True) +device_mesh = DeviceMesh.from_sizes( + dp_size=args.infra.dp_size, tp_size=args.infra.tp_size, + pp_size=args.infra.pp_size, ep_size=args.infra.ep_size, + sequence_parallel=args.infra.sequence_parallel, +) # use torchrun mode -twinkle.initialize(mode='local', global_device_mesh=device_mesh) - -logger = get_logger() +twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh) def eval(model): - # 100 Samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-35B-A3B') - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + # Eval samples + eval_samples = args.training.eval_samples or 100 + dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(eval_samples))) + dataset.set_template(args.template.template_cls, model_id=args.model.model_id) + dataset.map(SelfCognitionProcessor( + args.extra.get('model_name', 'twinkle大模型'), + args.extra.get('model_author', 'ModelScope社区'), + )) dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=16) + dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) for step, batch in tqdm(enumerate(dataloader)): model.forward_only(inputs=batch) metrics = model.calculate_metric(is_training=False) @@ -31,44 +40,49 @@ def eval(model): def train(): - # 1000 samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) + # Training samples + train_samples = args.training.train_samples or 1000 + dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(train_samples))) # Set template to prepare encoding - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-35B-A3B') + dataset.set_template(args.template.template_cls, model_id=args.model.model_id) # Preprocess the dataset to standard format - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.map(SelfCognitionProcessor( + args.extra.get('model_name', 'twinkle大模型'), + args.extra.get('model_author', 'ModelScope社区'), + )) # Encode dataset dataset.encode() - # Global batch size = 1, dp_size = 1 - dataloader = DataLoader(dataset=dataset, batch_size=16) + # Global batch size + dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) # Use a MegatronModel - model = MegatronModel(model_id='ms://Qwen/Qwen3.5-35B-A3B') + model = MegatronModel(model_id=args.model.model_id) - lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + lora_config = LoraConfig(**args.get_lora_args()) - # Add a lora to model, with name `default` + # Add a lora to model, with name from args # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config) - # Add Optimizer for lora `default` - model.set_optimizer(optimizer_cls='default', lr=1e-4) - # Add LRScheduler for lora `default` - model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader)) + model.add_adapter_to_model(args.lora.adapter_name, lora_config) + # Add Optimizer + model.set_optimizer(optimizer_cls='default', lr=args.optimizer.learning_rate) + # Add LRScheduler + model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=args.scheduler.num_warmup_steps, + lr_decay_steps=len(dataloader)) logger.info(get_device_placement()) # Print the training config logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') loss_metric = 99.0 - # lora: 23G * 8 + eval_interval = args.training.eval_interval or 20 for step, batch in enumerate(dataloader): # Do forward and backward model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() - if step % 5 == 0: + if step % args.training.log_interval == 0: # Print metric metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % 20 == 0: + if step > 0 and step % eval_interval == 0: metrics = eval(model) logger.info(f'Eval metric: {metrics}') metrics['step'] = step diff --git a/cookbook/megatron/tp_moe.sh b/cookbook/megatron/tp_moe.sh index 58e586464..7f6a2d06b 100644 --- a/cookbook/megatron/tp_moe.sh +++ b/cookbook/megatron/tp_moe.sh @@ -1 +1,25 @@ -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 tp_moe.py +#!/usr/bin/env bash +set -euo pipefail + +# Megatron TP + MoE + LoRA training. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash tp_moe.sh --model-id ms://Qwen/Qwen3.5-30B-A3B --tp-size 4 + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + torchrun --nproc_per_node=8 tp_moe.py \ + --model-id ms://Qwen/Qwen3.5-30B-A3B \ + --dataset-id ms://swift/self-cognition \ + --template-cls Qwen3_5Template \ + --dp-size 2 \ + --tp-size 2 \ + --pp-size 2 \ + --ep-size 2 \ + --sequence-parallel \ + --batch-size 8 \ + --lr 1e-4 \ + --train-samples 1000 \ + --log-interval 10 \ + --eval-interval 20 \ + --model-name twinkle大模型 \ + --model-author ModelScope社区 \ + "$@" diff --git a/cookbook/mm/fsdp2.py b/cookbook/mm/fsdp2.py index 4dc508506..2edaf54e9 100644 --- a/cookbook/mm/fsdp2.py +++ b/cookbook/mm/fsdp2.py @@ -3,18 +3,20 @@ import twinkle from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.data_format import Trajectory, Message from twinkle.dataloader import DataLoader from twinkle.dataset import LazyDataset, DatasetMeta from twinkle.model import TransformersModel from twinkle.preprocessor import Preprocessor -# Construct a device_mesh, fsdp=2 -device_mesh = DeviceMesh.from_sizes(fsdp_size=2) -# use torchrun mode -twinkle.initialize(mode='local', global_device_mesh=device_mesh) - logger = get_logger() +args = CLI.from_args() + +# Construct a device_mesh +device_mesh = DeviceMesh.from_sizes(fsdp_size=args.infra.fsdp_size, dp_size=args.infra.dp_size) +# use torchrun mode +twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh) class LatexOCRProcessor(Preprocessor): @@ -35,12 +37,13 @@ def preprocess(self, row) -> Trajectory: def eval(model): - # 100 Samples - dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') + # Eval samples + eval_samples = args.training.eval_samples or 100 + dataset = LazyDataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(eval_samples))) + dataset.set_template(args.template.template_cls, model_id=args.model.model_id) dataset.map(LatexOCRProcessor) dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=8) + dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) for step, batch in tqdm(enumerate(dataloader)): model.forward_only(inputs=batch) model.calculate_loss() @@ -49,54 +52,56 @@ def eval(model): def train(): - # 2000 samples - dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(2000))) + # Training samples + train_samples = args.training.train_samples or 2000 + dataset = LazyDataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(train_samples))) # Set template to prepare encoding - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B', max_length=1024) + dataset.set_template(args.template.template_cls, model_id=args.model.model_id, max_length=args.template.max_length) # Preprocess the dataset to standard format dataset.map(LatexOCRProcessor) # Encode dataset dataset.encode() - # Global batch size = 4, for GPUs, so 2 sample per GPU - dataloader = DataLoader(dataset=dataset, batch_size=4) + # Global batch size + dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) # Use a TransformersModel - from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration - model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B', model_cls=Qwen3_5ForConditionalGeneration) + model = TransformersModel(model_id=args.model.model_id, model_cls=args.model.model_cls) model.model._no_split_modules = {'Qwen3_5DecoderLayer'} - lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + lora_config = LoraConfig(**args.get_lora_args()) - # Add a lora to model, with name `default` - # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) - # Add Optimizer for lora `default` - model.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') - model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) - # Add LRScheduler for lora `default` + # Add a lora to model + model.add_adapter_to_model(args.lora.adapter_name, lora_config, + gradient_accumulation_steps=args.training.gradient_accumulation_steps) + # Add Optimizer + model.set_template(args.template.template_cls, model_id=args.model.model_id) + model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate) + # Add LRScheduler model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + scheduler_cls=args.scheduler.scheduler_cls, num_warmup_steps=args.scheduler.num_warmup_steps, + num_training_steps=len(dataloader)) logger.info(get_device_placement()) # Print the training config logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') loss_metric = 99.0 + eval_interval = args.training.eval_interval or 200 for step, batch in enumerate(dataloader): # Do forward and backward model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() - if step % 20 == 0: + if step % args.training.log_interval == 0: # Print metric metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % 200 == 0: + if step > 0 and step % eval_interval == 0: metrics = eval(model) logger.info(f'Eval metric: {metrics}') metrics['step'] = step if loss_metric > float(metrics['loss']): model.save(f'checkpoint-{step}') loss_metric = float(metrics['loss']) - model.save(f'last-checkpoint') + model.save('last-checkpoint') if __name__ == '__main__': diff --git a/cookbook/mm/fsdp2.sh b/cookbook/mm/fsdp2.sh index 46e9f27f6..2e0bed3d5 100644 --- a/cookbook/mm/fsdp2.sh +++ b/cookbook/mm/fsdp2.sh @@ -1 +1,21 @@ -CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 fsdp2.py +#!/usr/bin/env bash +set -euo pipefail + +# Multi-modal FSDP2 + LoRA training (LaTeX OCR). +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash fsdp2.sh --model-id ms://Qwen/Qwen2.5-VL-3B-Instruct --batch-size 4 + +CUDA_VISIBLE_DEVICES=0,1 \ + torchrun --nproc_per_node=2 fsdp2.py \ + --model-id ms://Qwen/Qwen2.5-VL-3B-Instruct \ + --dataset-id ms://AI-ModelScope/LaTeX_OCR \ + --template-cls Qwen2_5VLTemplate \ + --dp-size 2 \ + --batch-size 2 \ + --lr 1e-4 \ + --gradient-accumulation-steps 4 \ + --train-samples 2000 \ + --eval-samples 100 \ + --eval-interval 200 \ + --log-interval 10 \ + "$@" diff --git a/cookbook/mm/fsdp2_gemma4_12b_mm.py b/cookbook/mm/fsdp2_gemma4_12b_mm.py index c21932b33..62e26d776 100644 --- a/cookbook/mm/fsdp2_gemma4_12b_mm.py +++ b/cookbook/mm/fsdp2_gemma4_12b_mm.py @@ -1,4 +1,3 @@ -import os from peft import LoraConfig from tqdm import tqdm from transformers import AutoConfig @@ -8,35 +7,29 @@ import twinkle from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel -# from twinkle.preprocessor import SelfCognitionProcessor, LatexOCRProcessor logger = get_logger() +args = CLI.from_args() ########## Construct a device_mesh ########## device_mesh = DeviceMesh.from_sizes( - # fsdp_size=2, - # dp_size=1, - # ep_size=2, + fsdp_size=args.infra.fsdp_size, + dp_size=args.infra.dp_size, + ep_size=args.infra.ep_size, device_type=Platform.get_platform().device_prefix(), ) # use torchrun mode -twinkle.initialize(mode='local', global_device_mesh=device_mesh) +twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh) ########## hyperparameters ########## -IGNORE_MISMATCHED_SIZES = True -# MODEL_PATH = 'ms://google/gemma-4-26b-a4b' -MODEL_PATH = 'ms://google/gemma-4-12b' -DATASET_PATH = 'ms://AI-ModelScope/LaTeX_OCR' -TRAIN_LEN = 2000 -BATCH_SIZE = 4 -METRIC_STEP = 10 -SAVE_STEP = 10 +IGNORE_MISMATCHED_SIZES = args.extra.get('ignore_mismatched_sizes', True) ### reduce model layers for debug -TEXT_NUM_LAYERS = 8 # gemma-4-12b text_config.num_hidden_layers=48 +TEXT_NUM_LAYERS = args.extra.get('text_num_layers', None) from twinkle.preprocessor import Preprocessor from twinkle.data_format import Message, Trajectory @@ -79,24 +72,21 @@ def train(): 'messages': List(sub_msg_feat) }) ### prepare dataset and dataloader - dataset = Dataset(features=writer_features, dataset_meta=DatasetMeta(DATASET_PATH, subset_name='default', data_slice=range(TRAIN_LEN))) + train_samples = args.training.train_samples or 2000 + dataset = Dataset(features=writer_features, dataset_meta=DatasetMeta( + args.dataset.dataset_id, subset_name=args.dataset.subset_name, data_slice=range(train_samples))) # Set template to prepare encoding - dataset.set_template('Template', model_id=MODEL_PATH) + dataset.set_template(args.template.template_cls, model_id=args.model.model_id) # Preprocess the dataset to standard format - # dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.map(preprocess_func=LatexOCRProcessor) # Encode dataset dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) config, kwargs = AutoConfig.from_pretrained( - MODEL_PATH, + args.model.model_id, trust_remote_code=True, return_unused_kwargs=True, - # code_revision=code_revision, - # _commit_hash=commit_hash, - # **hub_kwargs, - # **kwargs, ) if isinstance(config, Gemma4UnifiedConfig): # 减层 @@ -111,10 +101,10 @@ def train(): from transformers import AutoModelForMultimodalLM model = TransformersModel( model_cls=AutoModelForMultimodalLM, - model_id=MODEL_PATH, + model_id=args.model.model_id, config=config, device_mesh=device_mesh, - strategy='accelerate', # native_fsdp、 accelerate + strategy=args.model.strategy, ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES, fsdp_config={ 'reshard_after_forward': True, @@ -126,46 +116,46 @@ def train(): }, ) - lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + lora_config = LoraConfig(**args.get_lora_args()) - # Add a lora to model, with name `default` - # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) - # Add Optimizer for lora `default` - model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) + # Add a lora to model + model.add_adapter_to_model(args.lora.adapter_name, lora_config, + gradient_accumulation_steps=args.training.gradient_accumulation_steps) + # Add Optimizer + model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate) - # Add LRScheduler for lora `default` + # Add LRScheduler model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + scheduler_cls=args.scheduler.scheduler_cls, num_warmup_steps=args.scheduler.num_warmup_steps, + num_training_steps=len(dataloader)) logger.info(get_device_placement()) # Print the training config logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') best_eval_loss = float('inf') - # lora: 8G * 8 - # full: 18G * 8 ### eval dataset and dataloader - EVAL_LENGTH = 100 - eval_dataset = Dataset(features=writer_features, dataset_meta=DatasetMeta(DATASET_PATH, subset_name='default', data_slice=range(EVAL_LENGTH))) - eval_dataset.set_template('Template', model_id=MODEL_PATH) - # eval_dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + eval_samples = args.training.eval_samples or 100 + eval_dataset = Dataset(features=writer_features, dataset_meta=DatasetMeta( + args.dataset.dataset_id, subset_name=args.dataset.subset_name, data_slice=range(eval_samples))) + eval_dataset.set_template(args.template.template_cls, model_id=args.model.model_id) eval_dataset.map(preprocess_func=LatexOCRProcessor) eval_dataset.encode() - eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8) + eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=args.training.batch_size) + save_step = args.training.save_steps for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)): # Do forward and backward model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() - if step % METRIC_STEP == 0: + if step % args.training.log_interval == 0: # Print metric metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {step} of {len(dataloader)}, Train metric: {metric}') - if step % SAVE_STEP == 0: + if step % save_step == 0: metrics = evaluate(model, eval_dataloader) metrics['step'] = step if float(metrics['loss']) < best_eval_loss: diff --git a/cookbook/mm/fsdp2_gemma4_mm.py b/cookbook/mm/fsdp2_gemma4_mm.py index 778051874..5c756cfbe 100644 --- a/cookbook/mm/fsdp2_gemma4_mm.py +++ b/cookbook/mm/fsdp2_gemma4_mm.py @@ -1,4 +1,3 @@ -import os from peft import LoraConfig from tqdm import tqdm from transformers import AutoConfig @@ -8,35 +7,30 @@ import twinkle from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel -# from twinkle.preprocessor import SelfCognitionProcessor, LatexOCRProcessor logger = get_logger() +args = CLI.from_args() ########## Construct a device_mesh ########## device_mesh = DeviceMesh.from_sizes( - # fsdp_size=2, - # dp_size=1, - # ep_size=2, + fsdp_size=args.infra.fsdp_size, + dp_size=args.infra.dp_size, + ep_size=args.infra.ep_size, device_type=Platform.get_platform().device_prefix(), ) # use torchrun mode -twinkle.initialize(mode='local', global_device_mesh=device_mesh) +twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh) ########## hyperparameters ########## -IGNORE_MISMATCHED_SIZES = True -MODEL_PATH = 'ms://google/gemma-4-26b-a4b' -DATASET_PATH = 'ms://AI-ModelScope/LaTeX_OCR' -TRAIN_LEN = 2000 -BATCH_SIZE = 4 -METRIC_STEP = 10 -SAVE_STEP = 10 +IGNORE_MISMATCHED_SIZES = args.extra.get('ignore_mismatched_sizes', True) ### reduce model layers for debug -TEXT_NUM_LAYERS = 3 -VISION_NUM_LAYERS = 3 +TEXT_NUM_LAYERS = args.extra.get('text_num_layers', None) +VISION_NUM_LAYERS = args.extra.get('vision_num_layers', None) from twinkle.preprocessor import Preprocessor @@ -67,24 +61,20 @@ def eval(model, eval_dataloader): def train(): ### prepare dataset and dataloader - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(TRAIN_LEN))) + train_samples = args.training.train_samples or 2000 + dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(train_samples))) # Set template to prepare encoding - dataset.set_template('Template', model_id=MODEL_PATH) + dataset.set_template(args.template.template_cls, model_id=args.model.model_id) # Preprocess the dataset to standard format - # dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.map(preprocess_func=LatexOCRProcessor) # Encode dataset dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) config, kwargs = AutoConfig.from_pretrained( - MODEL_PATH, + args.model.model_id, trust_remote_code=True, return_unused_kwargs=True, - # code_revision=code_revision, - # _commit_hash=commit_hash, - # **hub_kwargs, - # **kwargs, ) if isinstance(config, Gemma4Config): # 减层 @@ -101,10 +91,10 @@ def train(): # Use a TransformersModel model = TransformersModel( - model_id=MODEL_PATH, + model_id=args.model.model_id, config=config, device_mesh=device_mesh, - strategy='accelerate', # native_fsdp、 accelerate + strategy=args.model.strategy, ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES, fsdp_config={ 'reshard_after_forward': True, @@ -116,46 +106,45 @@ def train(): }, ) - lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + lora_config = LoraConfig(**args.get_lora_args()) - # Add a lora to model, with name `default` - # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) - # Add Optimizer for lora `default` - model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) + # Add a lora to model + model.add_adapter_to_model(args.lora.adapter_name, lora_config, + gradient_accumulation_steps=args.training.gradient_accumulation_steps) + # Add Optimizer + model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate) - # Add LRScheduler for lora `default` + # Add LRScheduler model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + scheduler_cls=args.scheduler.scheduler_cls, num_warmup_steps=args.scheduler.num_warmup_steps, + num_training_steps=len(dataloader)) logger.info(get_device_placement()) # Print the training config logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') best_eval_loss = float('inf') - # lora: 8G * 8 - # full: 18G * 8 ### eval dataset and dataloader - EVAL_LENGTH = 100 - eval_dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(EVAL_LENGTH))) - eval_dataset.set_template('Template', model_id=MODEL_PATH) - # eval_dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + eval_samples = args.training.eval_samples or 100 + eval_dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(eval_samples))) + eval_dataset.set_template(args.template.template_cls, model_id=args.model.model_id) eval_dataset.map(preprocess_func=LatexOCRProcessor) eval_dataset.encode() - eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8) + eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=args.training.batch_size) + save_step = args.training.save_steps for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)): # Do forward and backward model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() - if step % METRIC_STEP == 0: + if step % args.training.log_interval == 0: # Print metric metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {step} of {len(dataloader)}, Train metric: {metric}') - if step % SAVE_STEP == 0: + if step % save_step == 0: metrics = eval(model, eval_dataloader) metrics['step'] = step if float(metrics['loss']) < best_eval_loss: diff --git a/cookbook/mm/fsdp2_gemma4_mm.sh b/cookbook/mm/fsdp2_gemma4_mm.sh index c67113d8f..82d9ef1d5 100644 --- a/cookbook/mm/fsdp2_gemma4_mm.sh +++ b/cookbook/mm/fsdp2_gemma4_mm.sh @@ -1,3 +1,21 @@ -export CUDA_VISIBLE_DEVICES=0,1 +#!/usr/bin/env bash +set -euo pipefail -torchrun --nnodes=1 --nproc_per_node=2 fsdp2_gemma4_mm.py +# Multi-modal FSDP2 + LoRA training for Gemma4 (LaTeX OCR). +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash fsdp2_gemma4_mm.sh --model-id ms://google/gemma-4-4b-it --batch-size 4 + +CUDA_VISIBLE_DEVICES=0,1 \ + torchrun --nnodes=1 --nproc_per_node=2 fsdp2_gemma4_mm.py \ + --model-id ms://google/gemma-4-12b-it \ + --dataset-id ms://AI-ModelScope/LaTeX_OCR \ + --template-cls Gemma4Template \ + --dp-size 2 \ + --batch-size 2 \ + --lr 1e-4 \ + --gradient-accumulation-steps 4 \ + --train-samples 2000 \ + --eval-samples 100 \ + --log-interval 10 \ + --save-steps 200 \ + "$@" diff --git a/cookbook/rl/dpo_full.py b/cookbook/rl/dpo/dpo_full.py similarity index 92% rename from cookbook/rl/dpo_full.py rename to cookbook/rl/dpo/dpo_full.py index 8610b986f..afb3f6155 100644 --- a/cookbook/rl/dpo_full.py +++ b/cookbook/rl/dpo/dpo_full.py @@ -49,6 +49,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.data_format import Trajectory from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta @@ -58,25 +59,26 @@ from twinkle.processor import InputProcessor logger = get_logger() +args = CLI.from_args() # ── Configuration ───────────────────────────────────────────────────────────── -USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0)) -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') +USE_MEGATRON = args.model.strategy != 'native_fsdp' +MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3-4B' +DATASET_ID = args.dataset.dataset_id or 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji' -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) +MODEL_GPUS = args.infra.model_gpus or 4 +REF_MODEL_GPUS = args.infra.ref_model_gpus or 4 NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) -LEARNING_RATE = float(os.environ.get('LR', 1e-5)) -DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) -SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization -LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo, simpo, orpo, cpo -SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) -MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) -SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.') +BATCH_SIZE = args.training.batch_size or 8 +GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 2 +LEARNING_RATE = args.optimizer.learning_rate or 1e-5 +DPO_BETA = args.loss.beta +SFT_WEIGHT = args.loss.sft_weight +LOSS_TYPE = args.loss.loss_type +SAVE_STEPS = args.training.save_steps or 100 +MAX_LENGTH = args.template.max_length +SYSTEM_PROMPT = args.template.default_system or 'You are a helpful assistant.' def create_dpo_dataset(): diff --git a/cookbook/rl/dpo/dpo_full.sh b/cookbook/rl/dpo/dpo_full.sh new file mode 100644 index 000000000..cffba898b --- /dev/null +++ b/cookbook/rl/dpo/dpo_full.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -euo pipefail + +# DPO Full-Parameter Training via Ray. +# Uses separate policy and reference model GPU groups. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash dpo_full.sh --model-id ms://Qwen/Qwen3-8B --beta 0.05 + +python dpo_full.py \ + --model-id ms://Qwen/Qwen3-4B \ + --dataset-id ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji \ + --model-gpus 4 \ + --ref-model-gpus 4 \ + --batch-size 8 \ + --gradient-accumulation-steps 2 \ + --lr 1e-5 \ + --beta 0.1 \ + --sft-weight 1.0 \ + --loss-type sigmoid \ + --max-length 2048 \ + --save-steps 100 \ + "$@" diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo/dpo_lora.py similarity index 91% rename from cookbook/rl/dpo_lora.py rename to cookbook/rl/dpo/dpo_lora.py index c7ec3147c..868de1521 100644 --- a/cookbook/rl/dpo_lora.py +++ b/cookbook/rl/dpo/dpo_lora.py @@ -48,6 +48,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.data_format import Trajectory from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta @@ -57,24 +58,25 @@ from twinkle.processor import InputProcessor logger = get_logger() +args = CLI.from_args() # ── Configuration ───────────────────────────────────────────────────────────── -USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0)) -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') - -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8)) - -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) -LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # LoRA DPO requires higher LR (1e-4 to 3e-4) -DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) -SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization -LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo -SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) -MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) -ADAPTER_NAME = 'default' -SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.') +USE_MEGATRON = args.model.strategy != 'native_fsdp' +MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3-4B' +DATASET_ID = args.dataset.dataset_id or 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji' + +MODEL_GPUS = args.infra.model_gpus or 8 + +BATCH_SIZE = args.training.batch_size or 8 +GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 2 +LEARNING_RATE = args.optimizer.learning_rate or 1e-4 +DPO_BETA = args.loss.beta +SFT_WEIGHT = args.loss.sft_weight +LOSS_TYPE = args.loss.loss_type +SAVE_STEPS = args.training.save_steps or 100 +MAX_LENGTH = args.template.max_length +ADAPTER_NAME = args.lora.adapter_name or 'default' +SYSTEM_PROMPT = args.template.default_system or 'You are a helpful assistant.' def create_dpo_dataset(): diff --git a/cookbook/rl/dpo/dpo_lora.sh b/cookbook/rl/dpo/dpo_lora.sh new file mode 100644 index 000000000..7af42b6dc --- /dev/null +++ b/cookbook/rl/dpo/dpo_lora.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -euo pipefail + +# DPO LoRA Training via Ray (single GPU group). +# Uses base model (disable_lora=True) as reference model. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash dpo_lora.sh --model-id ms://Qwen/Qwen3-8B --lr 5e-5 + +python dpo_lora.py \ + --model-id ms://Qwen/Qwen3-4B \ + --dataset-id ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji \ + --model-gpus 8 \ + --batch-size 8 \ + --gradient-accumulation-steps 2 \ + --lr 1e-4 \ + --beta 0.1 \ + --sft-weight 1.0 \ + --loss-type sigmoid \ + --max-length 2048 \ + --save-steps 100 \ + --adapter-name default \ + "$@" diff --git a/cookbook/rl/dpo_multi_lora.py b/cookbook/rl/dpo/dpo_multi_lora.py similarity index 91% rename from cookbook/rl/dpo_multi_lora.py rename to cookbook/rl/dpo/dpo_multi_lora.py index 7c09bf61f..0a43322fc 100644 --- a/cookbook/rl/dpo_multi_lora.py +++ b/cookbook/rl/dpo/dpo_multi_lora.py @@ -48,6 +48,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.data_format import Trajectory from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta @@ -57,23 +58,24 @@ from twinkle.processor import InputProcessor logger = get_logger() +args = CLI.from_args() # ── Configuration ───────────────────────────────────────────────────────────── -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') - -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2)) - -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) -LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # LoRA DPO requires higher LR (1e-4 to 3e-4) -DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) -SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization -LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo -SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) -MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) -ADAPTER_NAME = 'default_0' -SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.') +MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.5-4B' +DATASET_ID = args.dataset.dataset_id or 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji' + +MODEL_GPUS = args.infra.model_gpus or 2 + +BATCH_SIZE = args.training.batch_size or 8 +GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 2 +LEARNING_RATE = args.optimizer.learning_rate or 1e-4 +DPO_BETA = args.loss.beta +SFT_WEIGHT = args.loss.sft_weight +LOSS_TYPE = args.loss.loss_type +SAVE_STEPS = args.training.save_steps or 100 +MAX_LENGTH = args.template.max_length +ADAPTER_NAME = args.lora.adapter_name or 'default_0' +SYSTEM_PROMPT = args.template.default_system or 'You are a helpful assistant.' def create_dpo_dataset(): diff --git a/cookbook/rl/dpo/dpo_multi_lora.sh b/cookbook/rl/dpo/dpo_multi_lora.sh new file mode 100644 index 000000000..0652b95f2 --- /dev/null +++ b/cookbook/rl/dpo/dpo_multi_lora.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -euo pipefail + +# DPO MultiLoRA Training via Ray (Megatron backend). +# Uses base model (disable_lora=True) as reference model. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash dpo_multi_lora.sh --model-id ms://Qwen/Qwen3.5-4B --lr 5e-5 + +python dpo_multi_lora.py \ + --model-id ms://Qwen/Qwen3.5-4B \ + --dataset-id ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji \ + --model-gpus 2 \ + --batch-size 8 \ + --gradient-accumulation-steps 2 \ + --lr 1e-4 \ + --beta 0.1 \ + --sft-weight 1.0 \ + --loss-type sigmoid \ + --max-length 2048 \ + --save-steps 100 \ + --adapter-name default_0 \ + "$@" diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd/gkd_off_policy.py similarity index 94% rename from cookbook/rl/gkd_off_policy.py rename to cookbook/rl/gkd/gkd_off_policy.py index 204e90f92..bdf992463 100644 --- a/cookbook/rl/gkd_off_policy.py +++ b/cookbook/rl/gkd/gkd_off_policy.py @@ -45,6 +45,7 @@ import twinkle from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.data_format import SamplingParams from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta @@ -55,23 +56,24 @@ from twinkle.template import Template logger = get_logger() +args = CLI.from_args() # ── Configuration ───────────────────────────────────────────────────────────── -STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B') -TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B') +STUDENT_MODEL_ID = args.rl.student_model_id or 'ms://Qwen/Qwen3-0.6B' +TEACHER_MODEL_ID = args.rl.teacher_model_id or 'ms://Qwen/Qwen3-8B' -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +MODEL_GPUS = args.infra.model_gpus or 4 +SAMPLER_GPUS = args.infra.sampler_gpus or 4 NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16)) -MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) -LEARNING_RATE = float(os.environ.get('LR', 5e-5)) +BATCH_SIZE = args.training.batch_size or 16 +MAX_STEPS = args.training.max_steps or 1000 +LEARNING_RATE = args.optimizer.learning_rate or 5e-5 -GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) -GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) -GKD_TOPK = int(os.environ.get('GKD_TOPK', 64)) -ADAPTER_NAME = 'default' +GKD_BETA = args.rl.gkd_beta +GKD_TEMPERATURE = args.rl.gkd_temperature +GKD_TOPK = args.rl.gkd_topk +ADAPTER_NAME = args.lora.adapter_name or 'default' SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem step by step and put ' 'your final answer within #### ') diff --git a/cookbook/rl/gkd/gkd_off_policy.sh b/cookbook/rl/gkd/gkd_off_policy.sh new file mode 100644 index 000000000..262542062 --- /dev/null +++ b/cookbook/rl/gkd/gkd_off_policy.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -euo pipefail + +# GKD Off-Policy Distillation via Ray. +# Teacher vLLM computes prompt logprobs on existing dataset responses. +# Student Megatron model learns to match teacher's token distribution. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash gkd_off_policy.sh --student-model-id ms://Qwen/Qwen3-1.7B --gkd-beta 0.3 + +python gkd_off_policy.py \ + --student-model-id ms://Qwen/Qwen3-0.6B \ + --teacher-model-id ms://Qwen/Qwen3-8B \ + --model-gpus 4 \ + --sampler-gpus 4 \ + --batch-size 16 \ + --max-steps 1000 \ + --lr 5e-5 \ + --gkd-beta 0.5 \ + --gkd-temperature 1.0 \ + --gkd-topk 64 \ + --adapter-name default \ + "$@" diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd/gkd_on_policy.py similarity index 94% rename from cookbook/rl/gkd_on_policy.py rename to cookbook/rl/gkd/gkd_on_policy.py index 2675d0358..1ddc9d89e 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd/gkd_on_policy.py @@ -51,6 +51,7 @@ import twinkle from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.cli import CLI from twinkle.data_format import SamplingParams from twinkle.dataloader import DataLoader from twinkle.dataset import DatasetMeta, LazyDataset @@ -60,26 +61,27 @@ from twinkle.sampler import vLLMSampler logger = get_logger() +args = CLI.from_args() # ── Configuration ───────────────────────────────────────────────────────────── -STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3.5-4B') -TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3.5-9B') -USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) +STUDENT_MODEL_ID = args.rl.student_model_id or 'ms://Qwen/Qwen3.5-4B' +TEACHER_MODEL_ID = args.rl.teacher_model_id or 'ms://Qwen/Qwen3.5-9B' +USE_MEGATRON = args.model.strategy != 'native_fsdp' -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) +MODEL_GPUS = args.infra.model_gpus or 4 +SAMPLER_GPUS = args.infra.sampler_gpus or 2 NUM_GPUS = MODEL_GPUS + 2*SAMPLER_GPUS -MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) -MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) -LEARNING_RATE = float(os.environ.get('LR', 5e-5)) -N_SAMPLES = int(os.environ.get('N_SAMPLES', 1)) +MAX_NEW_TOKENS = args.sampling.max_tokens or 2048 +BATCH_SIZE = args.training.batch_size or 4 +MAX_STEPS = args.training.max_steps or 1000 +LEARNING_RATE = args.optimizer.learning_rate or 5e-5 +N_SAMPLES = args.sampling.num_samples -GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) -GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) -GKD_TOPK = int(os.environ.get('GKD_TOPK', 64)) -ADAPTER_NAME = 'default' +GKD_BETA = args.rl.gkd_beta +GKD_TEMPERATURE = args.rl.gkd_temperature +GKD_TOPK = args.rl.gkd_topk +ADAPTER_NAME = args.lora.adapter_name or 'default' # OlympiadBench subsets SUBSETS = [ diff --git a/cookbook/rl/gkd/gkd_on_policy.sh b/cookbook/rl/gkd/gkd_on_policy.sh new file mode 100644 index 000000000..ed37e5fb5 --- /dev/null +++ b/cookbook/rl/gkd/gkd_on_policy.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +set -euo pipefail + +# GKD On-Policy Multimodal Distillation via Ray. +# Student generates on-policy, teacher provides top-k prompt logprobs, +# student trains to match teacher's distribution. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash gkd_on_policy.sh --student-model-id ms://Qwen/Qwen3.5-4B --max-steps 500 + +python gkd_on_policy.py \ + --student-model-id ms://Qwen/Qwen3.5-4B \ + --teacher-model-id ms://Qwen/Qwen3.5-9B \ + --model-gpus 4 \ + --sampler-gpus 2 \ + --batch-size 4 \ + --max-steps 1000 \ + --max-tokens 2048 \ + --lr 5e-5 \ + --num-samples 1 \ + --gkd-beta 0.5 \ + --gkd-temperature 1.0 \ + --gkd-topk 64 \ + --adapter-name default \ + "$@" diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo/grpo.py similarity index 88% rename from cookbook/rl/grpo.py rename to cookbook/rl/grpo/grpo.py index dd16e7f07..e50402383 100644 --- a/cookbook/rl/grpo.py +++ b/cookbook/rl/grpo/grpo.py @@ -7,6 +7,7 @@ from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger from twinkle.advantage import GRPOAdvantage from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.cli import CLI from twinkle.data_format import SamplingParams from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta @@ -19,24 +20,25 @@ from twinkle.preprocessor.llm import GSM8KProcessor logger = get_logger() +args = CLI.from_args() -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') -USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0'))) +MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.5-4B' +USE_MEGATRON = args.model.strategy != 'native_fsdp' -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS',4)) +MODEL_GPUS = args.infra.model_gpus or 4 +SAMPLER_GPUS = args.infra.sampler_gpus or 4 NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS -NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) -MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) -LEARNING_RATE = float(os.environ.get('LR', 1e-5)) -MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size -MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) # global completion-level mini-batch-size -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # per-device-micro-batch-size (completion-level), batch_size in forward_backward -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) -ADAPTER_NAME = 'default' -SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50)) +NUM_GENERATIONS = args.rl.num_generations or 8 +MAX_NEW_TOKENS = args.sampling.max_tokens or 4096 +LEARNING_RATE = args.optimizer.learning_rate or 1e-5 +MAX_STEPS = args.training.max_steps or 200 +BATCH_SIZE = args.training.batch_size or 8 +MINI_BATCH_SIZE = args.training.mini_batch_size or 8 +MICRO_BATCH_SIZE = args.training.micro_batch_size or 2 +GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 1 +ADAPTER_NAME = args.lora.adapter_name or 'default' +SAVE_STEPS = args.training.save_steps or 50 def create_gsm8k_dataset(): dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) diff --git a/cookbook/rl/grpo/grpo.sh b/cookbook/rl/grpo/grpo.sh new file mode 100644 index 000000000..b6feb02fa --- /dev/null +++ b/cookbook/rl/grpo/grpo.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -euo pipefail + +# GRPO training on GSM8K via Ray. +# Model + vLLM sampler on separate GPU groups. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash grpo.sh --model-id ms://Qwen/Qwen3.5-4B --max-steps 500 + +python grpo.py \ + --model-id ms://Qwen/Qwen3.5-4B \ + --model-gpus 4 \ + --sampler-gpus 4 \ + --num-generations 8 \ + --max-tokens 4096 \ + --batch-size 8 \ + --mini-batch-size 8 \ + --micro-batch-size 2 \ + --max-steps 200 \ + --lr 1e-5 \ + --save-steps 50 \ + --adapter-name default \ + "$@" diff --git a/cookbook/rl/grpo_mm.py b/cookbook/rl/grpo/grpo_mm.py similarity index 92% rename from cookbook/rl/grpo_mm.py rename to cookbook/rl/grpo/grpo_mm.py index 1f89c7a91..f7de43ca5 100644 --- a/cookbook/rl/grpo_mm.py +++ b/cookbook/rl/grpo/grpo_mm.py @@ -14,6 +14,7 @@ from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger from twinkle.advantage import GRPOAdvantage from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.cli import CLI from twinkle.data_format import SamplingParams from twinkle.dataloader import DataLoader from twinkle.dataset import DatasetMeta, LazyDataset @@ -28,27 +29,28 @@ from twinkle.sampler import vLLMSampler logger = get_logger() +args = CLI.from_args() # Model configuration -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') -USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) +MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.5-4B' +USE_MEGATRON = args.model.strategy != 'native_fsdp' # GPU configuration -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +MODEL_GPUS = args.infra.model_gpus or 4 +SAMPLER_GPUS = args.infra.sampler_gpus or 4 NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS # Training hyperparameters -NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) -MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) -LEARNING_RATE = float(os.environ.get('LR', 1e-5)) -MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) -MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4)) -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1)) -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) -ADAPTER_NAME = 'default' -SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50)) +NUM_GENERATIONS = args.rl.num_generations or 8 +MAX_NEW_TOKENS = args.sampling.max_tokens or 4096 +LEARNING_RATE = args.optimizer.learning_rate or 1e-5 +MAX_STEPS = args.training.max_steps or 1000 +BATCH_SIZE = args.training.batch_size or 4 +MINI_BATCH_SIZE = args.training.mini_batch_size or 4 +MICRO_BATCH_SIZE = args.training.micro_batch_size or 1 +GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 1 +ADAPTER_NAME = args.lora.adapter_name or 'default' +SAVE_STEPS = args.training.save_steps or 50 # Dataset configuration SUBSETS = [ diff --git a/cookbook/rl/grpo/grpo_mm.sh b/cookbook/rl/grpo/grpo_mm.sh new file mode 100644 index 000000000..b5ca2fda3 --- /dev/null +++ b/cookbook/rl/grpo/grpo_mm.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -euo pipefail + +# GRPO Multimodal training on OlympiadBench via Ray. +# Supports multimodal math/physics problems (Chinese CEE). +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash grpo_mm.sh --model-id ms://Qwen/Qwen3.5-4B --max-steps 500 + +python grpo_mm.py \ + --model-id ms://Qwen/Qwen3.5-4B \ + --model-gpus 4 \ + --sampler-gpus 4 \ + --num-generations 8 \ + --max-tokens 4096 \ + --batch-size 4 \ + --mini-batch-size 4 \ + --micro-batch-size 1 \ + --max-steps 1000 \ + --lr 1e-5 \ + --save-steps 50 \ + --adapter-name default \ + "$@" diff --git a/cookbook/rl/short_math_grpo.py b/cookbook/rl/grpo/short_math_grpo.py similarity index 91% rename from cookbook/rl/short_math_grpo.py rename to cookbook/rl/grpo/short_math_grpo.py index 5e107b0ae..91fcd7669 100644 --- a/cookbook/rl/short_math_grpo.py +++ b/cookbook/rl/grpo/short_math_grpo.py @@ -14,6 +14,7 @@ from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger from twinkle.advantage import GRPOAdvantage from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.cli import CLI from twinkle.data_format import SamplingParams from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta @@ -26,26 +27,27 @@ from twinkle.preprocessor.llm import GSM8KProcessor logger = get_logger() +args = CLI.from_args() # ========== Configuration ========== -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') -USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) +MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.5-4B' +USE_MEGATRON = args.model.strategy != 'native_fsdp' -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +MODEL_GPUS = args.infra.model_gpus or 4 +SAMPLER_GPUS = args.infra.sampler_gpus or 4 NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS -NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) -MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) -LEARNING_RATE = float(os.environ.get('LR', 1e-5)) -MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) -MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) -ADAPTER_NAME = 'default' -SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000)) -LORA_RANK = int(os.environ.get('LORA_RANK', 16)) +NUM_GENERATIONS = args.rl.num_generations or 8 +MAX_NEW_TOKENS = args.sampling.max_tokens or 4096 +LEARNING_RATE = args.optimizer.learning_rate or 1e-5 +MAX_STEPS = args.training.max_steps or 1000 +BATCH_SIZE = args.training.batch_size or 8 +MINI_BATCH_SIZE = args.training.mini_batch_size or 8 +MICRO_BATCH_SIZE = args.training.micro_batch_size or 2 +GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 1 +ADAPTER_NAME = args.lora.adapter_name or 'default' +SAVE_STEPS = args.training.save_steps or 1000 +LORA_RANK = args.lora.lora_r or 16 SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' 'and put your final answer within \\boxed{}.') diff --git a/cookbook/rl/grpo/short_math_grpo.sh b/cookbook/rl/grpo/short_math_grpo.sh new file mode 100644 index 000000000..033507dc8 --- /dev/null +++ b/cookbook/rl/grpo/short_math_grpo.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +set -euo pipefail + +# GRPO Short Math Reasoning on GSM8K via Ray. +# Uses short reasoning format: shorter thinking gets higher brevity reward. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash short_math_grpo.sh --model-id ms://Qwen/Qwen3.5-4B --max-steps 500 + +python short_math_grpo.py \ + --model-id ms://Qwen/Qwen3.5-4B \ + --model-gpus 4 \ + --sampler-gpus 4 \ + --num-generations 8 \ + --max-tokens 4096 \ + --batch-size 8 \ + --mini-batch-size 8 \ + --micro-batch-size 2 \ + --max-steps 1000 \ + --lr 1e-5 \ + --lora-r 16 \ + --save-steps 1000 \ + --adapter-name default \ + "$@" diff --git a/cookbook/rl/short_math_grpo_moe.py b/cookbook/rl/grpo/short_math_grpo_moe.py similarity index 90% rename from cookbook/rl/short_math_grpo_moe.py rename to cookbook/rl/grpo/short_math_grpo_moe.py index 9d870eacb..f19747282 100644 --- a/cookbook/rl/short_math_grpo_moe.py +++ b/cookbook/rl/grpo/short_math_grpo_moe.py @@ -14,6 +14,7 @@ from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger from twinkle.advantage import GRPOAdvantage from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.cli import CLI from twinkle.data_format import SamplingParams from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta @@ -26,31 +27,32 @@ from twinkle.preprocessor.llm import GSM8KProcessor logger = get_logger() +args = CLI.from_args() # ========== Configuration ========== -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B') -USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) +MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.6-35B-A3B' +USE_MEGATRON = args.model.strategy != 'native_fsdp' -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -MODEL_EP = int(os.environ.get('MODEL_EP', 2)) -MODEL_TP = int(os.environ.get('MODEL_TP', 2)) -MODEL_PP = int(os.environ.get('MODEL_PP', 2)) +MODEL_GPUS = args.infra.model_gpus or 4 +MODEL_EP = args.infra.ep_size or 2 +MODEL_TP = args.infra.tp_size or 2 +MODEL_PP = args.infra.pp_size or 2 -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) -SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 2)) +SAMPLER_GPUS = args.infra.sampler_gpus or 2 +SAMPLER_TP = args.sampler.tensor_parallel_size or 2 NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS -NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) -MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) -LEARNING_RATE = float(os.environ.get('LR', 5e-5)) -MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) -MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4)) -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1)) -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) -ADAPTER_NAME = 'default' -SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000)) -LORA_RANK = int(os.environ.get('LORA_RANK', 16)) +NUM_GENERATIONS = args.rl.num_generations or 8 +MAX_NEW_TOKENS = args.sampling.max_tokens or 4096 +LEARNING_RATE = args.optimizer.learning_rate or 5e-5 +MAX_STEPS = args.training.max_steps or 1000 +BATCH_SIZE = args.training.batch_size or 4 +MINI_BATCH_SIZE = args.training.mini_batch_size or 4 +MICRO_BATCH_SIZE = args.training.micro_batch_size or 1 +GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 1 +ADAPTER_NAME = args.lora.adapter_name or 'default' +SAVE_STEPS = args.training.save_steps or 1000 +LORA_RANK = args.lora.lora_r or 16 SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' 'and put your final answer within \\boxed{}.') diff --git a/cookbook/rl/grpo/short_math_grpo_moe.sh b/cookbook/rl/grpo/short_math_grpo_moe.sh new file mode 100644 index 000000000..00369610c --- /dev/null +++ b/cookbook/rl/grpo/short_math_grpo_moe.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +set -euo pipefail + +# GRPO Short Math MoE on GSM8K via Ray. +# Uses Megatron MoE model with TP+EP+PP parallelism. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash short_math_grpo_moe.sh --model-id ms://Qwen/Qwen3.6-35B-A3B --max-steps 500 + +python short_math_grpo_moe.py \ + --model-id ms://Qwen/Qwen3.6-35B-A3B \ + --model-gpus 4 \ + --sampler-gpus 2 \ + --ep-size 2 \ + --tp-size 2 \ + --pp-size 2 \ + --tensor-parallel-size 2 \ + --num-generations 8 \ + --max-tokens 4096 \ + --batch-size 4 \ + --mini-batch-size 4 \ + --micro-batch-size 1 \ + --max-steps 1000 \ + --lr 5e-5 \ + --lora-r 16 \ + --save-steps 1000 \ + --adapter-name default \ + "$@" diff --git a/cookbook/rl/short_math_grpo_multi_lora.py b/cookbook/rl/grpo/short_math_grpo_multi_lora.py similarity index 92% rename from cookbook/rl/short_math_grpo_multi_lora.py rename to cookbook/rl/grpo/short_math_grpo_multi_lora.py index 9dad8df30..cff4bb4b9 100644 --- a/cookbook/rl/short_math_grpo_multi_lora.py +++ b/cookbook/rl/grpo/short_math_grpo_multi_lora.py @@ -21,6 +21,7 @@ import twinkle from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger from twinkle.advantage import GRPOAdvantage +from twinkle.cli import CLI from twinkle.data_format import SamplingParams from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta @@ -33,28 +34,29 @@ from twinkle.preprocessor.llm import GSM8KProcessor logger = get_logger() +args = CLI.from_args() # ========== Configuration ========== -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B') +MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.6-35B-A3B' -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) -SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 2)) +MODEL_GPUS = args.infra.model_gpus or 4 +SAMPLER_GPUS = args.infra.sampler_gpus or 2 +SAMPLER_TP = args.sampler.tensor_parallel_size or 2 NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS -NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) -MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) -LEARNING_RATE = float(os.environ.get('LR', 5e-5)) -MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) -MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4)) -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1)) -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) -ADAPTER_NAME = 'default_0' -SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000)) -LORA_RANK = int(os.environ.get('LORA_RANK', 16)) -LORA_SYNC_DIR = os.environ.get('LORA_SYNC_DIR', 'output/lora_sync') +NUM_GENERATIONS = args.rl.num_generations or 8 +MAX_NEW_TOKENS = args.sampling.max_tokens or 4096 +LEARNING_RATE = args.optimizer.learning_rate or 5e-5 +MAX_STEPS = args.training.max_steps or 1000 +BATCH_SIZE = args.training.batch_size or 4 +MINI_BATCH_SIZE = args.training.mini_batch_size or 4 +MICRO_BATCH_SIZE = args.training.micro_batch_size or 1 +GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 1 +ADAPTER_NAME = args.lora.adapter_name or 'default_0' +SAVE_STEPS = args.training.save_steps or 1000 +LORA_RANK = args.lora.lora_r or 16 +LORA_SYNC_DIR = args.checkpoint.lora_sync_dir or 'output/lora_sync' SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' 'and put your final answer within \\boxed{}.') diff --git a/cookbook/rl/grpo/short_math_grpo_multi_lora.sh b/cookbook/rl/grpo/short_math_grpo_multi_lora.sh new file mode 100644 index 000000000..a465250c8 --- /dev/null +++ b/cookbook/rl/grpo/short_math_grpo_multi_lora.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash +set -euo pipefail + +# GRPO Short Math MultiLoRA on GSM8K via Ray. +# Uses MultiLoraMegatronModel with filesystem-based LoRA sync to vLLM. +# Model: Qwen3.6-35B-A3B (MoE) with tp=2, ep=2, pp=2. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash short_math_grpo_multi_lora.sh --model-id ms://Qwen/Qwen3.6-35B-A3B --max-steps 500 + +python short_math_grpo_multi_lora.py \ + --model-id ms://Qwen/Qwen3.6-35B-A3B \ + --model-gpus 4 \ + --sampler-gpus 2 \ + --tensor-parallel-size 2 \ + --num-generations 8 \ + --max-tokens 4096 \ + --batch-size 4 \ + --mini-batch-size 4 \ + --micro-batch-size 1 \ + --max-steps 1000 \ + --lr 5e-5 \ + --lora-r 16 \ + --save-steps 1000 \ + --adapter-name default_0 \ + --lora-sync-dir output/lora_sync \ + "$@" diff --git a/cookbook/rl/multi_turn/multi_turn_grpo.py b/cookbook/rl/multi_turn/multi_turn_grpo.py new file mode 100644 index 000000000..6661bfcdc --- /dev/null +++ b/cookbook/rl/multi_turn/multi_turn_grpo.py @@ -0,0 +1,457 @@ +"""Multi-turn GRPO training with EnvPool (integrated environment pool). + +Demonstrates how to train an LLM agent via GRPO on interactive environments +(e.g. Blackjack) using EnvPool and Twinkle's MultiTurnRollout. + +EnvPool is deployed as a @remote_class component — either: + - With remote_group='env': runs on a dedicated CPU DeviceGroup (isolated) + - Without remote_group: runs locally in the driver (zero RPC overhead) + +The agent interacts with environments through tool calls: + 1. EnvPool manages N env instances; each trajectory maps to one slot. + 2. MultiTurnRollout drives the multi-turn loop: model generates tool calls, + EnvTool dispatches them to env.step(), observations are fed back. + 3. Episode reward is extracted after rollout completes. + 4. GRPO advantages are computed across the batch and used for policy update. + +Usage: + # No need to start a separate server — environments are instantiated + # directly inside the EnvPool worker: + # python multi_turn_grpo.py + # + # To run envs on a dedicated CPU worker (isolated): + # ENV_REMOTE=1 python multi_turn_grpo.py + +References: + - OpenEnv GRPO Blackjack: https://github.com/huggingface/OpenEnv/tree/main/examples/grpo_blackjack + - cookbook/rl/grpo/short_math_grpo.py (single-turn GRPO template) +""" +import os +from typing import Any, Dict, List, Tuple + +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle.advantage import GRPOAdvantage +from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.cli import CLI +from twinkle.data_format import SamplingParams +from twinkle.metric import CompletionRewardMetric +from twinkle.model import TransformersModel +from twinkle.processor import InputProcessor +from twinkle.sampler import vLLMSampler +from twinkle.template import Qwen3_5Template +from twinkle_agentic.envs import EnvPool, EnvPoolAdapter, EnvTool +from twinkle_agentic.rollout.multi_turn import MultiTurnRollout +from twinkle_agentic.tools.tool_manager import ToolManager + +logger = get_logger() +args = CLI.from_args() + +# ========== Configuration ========== +MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.5-4B' +USE_MEGATRON = False + +MODEL_GPUS = args.infra.model_gpus or 4 +SAMPLER_GPUS = args.infra.sampler_gpus or 4 +NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS + +NUM_GENERATIONS = args.rl.num_generations or 8 +MAX_NEW_TOKENS = args.sampling.max_tokens or 2048 +LEARNING_RATE = args.optimizer.learning_rate or 1e-5 +MAX_STEPS = args.training.max_steps or 1000 +BATCH_SIZE = args.training.batch_size or 4 +MINI_BATCH_SIZE = args.training.mini_batch_size or 8 +MICRO_BATCH_SIZE = args.training.micro_batch_size or 2 +GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 1 +ADAPTER_NAME = args.lora.adapter_name or 'default' +SAVE_STEPS = args.training.save_steps or 500 +LORA_RANK = args.lora.lora_r or 16 +MAX_TURNS = int(os.environ.get('MAX_TURNS', '6')) + +# Environment configuration +# ENV_CLS: import path to the environment class (no server needed) +ENV_CLS = os.environ.get('ENV_CLS', 'blackjack_env:BlackjackEnv') +# ENV_REMOTE: set to '1' to deploy envs on a dedicated CPU DeviceGroup +ENV_REMOTE = os.environ.get('ENV_REMOTE', '0') == '1' +# Pool size = total trajectories per batch +ENV_POOL_SIZE = int(os.environ.get('ENV_POOL_SIZE', '0')) # 0 = auto + +# ========== Tool Schema (Blackjack example) ========== +# Define tools the model can use in the environment. +# For blackjack: a single "play" tool with hit/stand actions. +# Override TOOL_SCHEMA for different environments. +BLACKJACK_TOOL_SCHEMA = [ + { + 'type': 'function', + 'function': { + 'name': 'play', + 'description': 'Take an action in the blackjack game.', + 'parameters': { + 'type': 'object', + 'properties': { + 'action': { + 'type': 'string', + 'enum': ['hit', 'stand'], + 'description': 'The action to take: "hit" to draw a card, "stand" to keep current hand.', + } + }, + 'required': ['action'], + }, + }, + } +] + +TOOL_SCHEMA = BLACKJACK_TOOL_SCHEMA + +# Action name → OpenSpiel action_id mapping for blackjack. +# OpenSpiel blackjack: 0 = HIT, 1 = STAND +BLACKJACK_ACTION_MAP = {'hit': 0, 'stand': 1} + + +def blackjack_action_mapper(tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Map tool calls to OpenSpielAction format. + + Converts play(action='hit') → {action_id: 0, game_name: 'blackjack'} + """ + action_str = str(arguments.get('action', 'stand')).lower().strip() + action_id = BLACKJACK_ACTION_MAP.get(action_str, 1) # default STAND + return {'action_id': action_id, 'game_name': 'blackjack'} + + +SYSTEM_PROMPT = """You are a skilled blackjack player. You will be told your current hand and the dealer's visible card. + +Your goal is to win the game by getting as close to 21 as possible without going over. + +Strategy guidelines: +- Hit if your hand total is below 12 +- Consider the dealer's visible card when deciding +- Stand if you have 17 or higher +- Be cautious with hard hands (no ace counted as 11) + +Use the `play` tool to take actions. Always reason briefly before acting.""" + + +# ========== Environment Setup ========== +def prepare_trajectories( + env_pool: EnvPool, + n_trajectories: int, + tool_schema: List[Dict], + system_prompt: str, + action_mapper=None, +) -> Tuple[List[Dict[str, Any]], List[ToolManager], List[List[EnvTool]]]: + """Reset environments via EnvPool and build initial trajectories. + + For each trajectory: + 1. Get an EnvPoolAdapter (standard Env interface) from the pool + 2. Reset the env slot to get initial observation + 3. Build a trajectory dict with system + user messages and tools + + Args: + env_pool: The EnvPool instance managing all environments. + n_trajectories: Total number of trajectories to create. + tool_schema: Tool definitions for the environment. + system_prompt: System prompt for the agent. + action_mapper: Optional callable to transform actions. + + Returns: + Tuple of (trajectories, tool_managers, env_tools_list). + """ + # Get per-trajectory adapters from the pool + adapters = env_pool.get_adapters( + n=n_trajectories, + tool_schema=tool_schema, + action_mapper=action_mapper, + ) + + trajectories = [] + tool_managers = [] + env_tools_list = [] + + for adapter in adapters: + # Reset env slot to start a new episode + initial_result = adapter.reset() + initial_obs = initial_result.observation + + # Create EnvTool and ToolManager for this trajectory + env_tools = EnvTool.from_env(adapter) + tm = ToolManager(env_tools) + + # Build trajectory with initial observation as user message + traj = { + 'messages': [ + {'role': 'system', 'content': system_prompt}, + {'role': 'user', 'content': initial_obs}, + ], + 'tools': tool_schema, + } + + trajectories.append(traj) + tool_managers.append(tm) + env_tools_list.append(env_tools) + + return trajectories, tool_managers, env_tools_list + + +def extract_rewards(env_tools_list: List[List[EnvTool]]) -> List[float]: + """Extract episode rewards from EnvTool instances after rollout. + + Each EnvTool tracks the cumulative episode reward from env.step() calls. + """ + rewards = [] + for env_tools in env_tools_list: + if env_tools: + reward = env_tools[0].episode_reward + else: + reward = 0.0 + rewards.append(reward) + return rewards + + +# ========== Main ========== +def main(): + # Determine pool size + n_trajectories = BATCH_SIZE * NUM_GENERATIONS + pool_size = ENV_POOL_SIZE if ENV_POOL_SIZE > 0 else n_trajectories + + # Device groups: model + sampler + (optionally) env + device_groups = [ + DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), + ] + + if ENV_REMOTE: + # Add a CPU-only DeviceGroup for env pool (1 CPU process, colocated on same node) + device_groups.append( + DeviceGroup(name='env', ranks=1, device_type='CPU'), + ) + + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) + + lora_config = LoraConfig( + target_modules='all-linear', + r=LORA_RANK, + lora_alpha=LORA_RANK * 2, + lora_dropout=0.05, + ) + + if USE_MEGATRON: + from twinkle.model.megatron import MegatronModel + model = MegatronModel( + model_id=MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + mixed_precision='bf16', + variable_seq_lengths=True, + ) + else: + model = TransformersModel( + model_id=MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + ) + + model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + if USE_MEGATRON: + model.set_optimizer('default', lr=LEARNING_RATE) + model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE) + else: + model.set_optimizer('AdamW', lr=LEARNING_RATE) + model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) + + model.set_loss('GRPOLoss', epsilon=0.2) + model.set_processor(InputProcessor, padding_free=True) + model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False) + + sampler = vLLMSampler( + model_id=MODEL_ID, + engine_args={ + 'gpu_memory_utilization': 0.8, + 'max_model_len': 8192, + 'max_lora_rank': 32, + 'enable_lora': True, + 'enable_tower_connector_lora': True, + }, + device_mesh=sampler_mesh, + remote_group='sampler', + ) + sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False) + + # ========== EnvPool: environment instances managed by Twinkle ========== + env_pool_kwargs = dict( + env_cls=ENV_CLS, + pool_size=pool_size, + ) + if ENV_REMOTE: + # Deploy on dedicated CPU DeviceGroup + env_mesh = DeviceMesh.from_sizes(world_size=1, dp_size=1) + env_pool_kwargs['remote_group'] = 'env' + env_pool_kwargs['device_mesh'] = env_mesh + # else: runs locally in driver (zero RPC overhead) + + env_pool = EnvPool(**env_pool_kwargs) + logger.info(f'EnvPool created: env_cls={ENV_CLS}, pool_size={pool_size}, ' + f'remote={ENV_REMOTE}') + + # Local template for MultiTurnRollout bridge computation + rollout_template = Qwen3_5Template(MODEL_ID, max_length=8192, enable_thinking=False) + rollout_template.truncation_strategy = 'delete' + + ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) + + # MultiTurnRollout: tool_manager is optional at construction time; + # the actual per-trajectory ToolManagers are provided at call time. + sampling_params = SamplingParams( + max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, + temperature=1.0, top_p=0.95, + ) + rollout = MultiTurnRollout( + sampler=sampler, + template=rollout_template, + sampling_params=sampling_params, + max_turns=MAX_TURNS, + ) + + advantage_fn = GRPOAdvantage() + metrics = CompletionRewardMetric() + + optim_step = 0 + logger.info('Starting multi-turn GRPO training with EnvPool') + logger.info(f'ENV_CLS={ENV_CLS}, MAX_TURNS={MAX_TURNS}, NUM_GENERATIONS={NUM_GENERATIONS}') + logger.info(get_device_placement()) + + while optim_step < MAX_STEPS: + metrics.reset() + + # Total trajectories per batch: BATCH_SIZE * NUM_GENERATIONS + # Each trajectory is an independent game episode. + n_traj = BATCH_SIZE * NUM_GENERATIONS + + # 1. Prepare environments and initial trajectories + logger.info(f'[Step {optim_step}] Resetting {n_traj} environments...') + expand_prompts, tool_managers, env_tools_list = prepare_trajectories( + env_pool=env_pool, + n_trajectories=n_traj, + tool_schema=TOOL_SCHEMA, + system_prompt=SYSTEM_PROMPT, + action_mapper=blackjack_action_mapper, + ) + + # 2. Sync model weights to sampler + ckpt_manager.sync_weights(merge_and_sync=False) + sampler.reset_prefix_cache() + + # 3. Run multi-turn rollout with per-trajectory ToolManagers + all_trajectories: List[Dict[str, Any]] = rollout( + expand_prompts, + tool_manager=tool_managers, + ) + + # 4. Extract rewards and logprobs + env_rewards = extract_rewards(env_tools_list) + + all_old_logps: List[List[float]] = [] + all_completion_lengths: List[int] = [] + n_turns_per_rollout: List[int] = [] + + for traj in all_trajectories: + logprobs = traj.get('logprobs') or [] + old_logps = [lp[0][1] for lp in logprobs] if logprobs else [] + all_old_logps.append(old_logps) + # Completion length = number of trainable tokens (labels != -100) + labels = traj.get('labels') or [] + comp_len = sum(1 for l in labels if l != -100) + all_completion_lengths.append(comp_len) + n_turns_per_rollout.append(int(traj.get('turns') or 0)) + + # 5. Compute advantages (group-relative within NUM_GENERATIONS) + total_rewards = env_rewards + advantages = advantage_fn( + total_rewards, num_generations=NUM_GENERATIONS, scale='group', + ).tolist() + + # 6. Log metrics + metrics.accumulate( + completion_lengths=all_completion_lengths, + rewards={'total': total_rewards}, + ) + + avg_reward = sum(total_rewards) / len(total_rewards) if total_rewards else 0.0 + avg_turns = sum(n_turns_per_rollout) / len(n_turns_per_rollout) if n_turns_per_rollout else 0.0 + logger.info(f'[Step {optim_step}] avg_reward={avg_reward:.3f}, avg_turns={avg_turns:.1f}') + + # 7. Forward-backward with mini-batches + # Filter out oversized/truncated trajectories (strategy='delete'), + # keep only those with valid completions and ensure >= MODEL_GPUS inputs. + all_input_data: List[Dict[str, Any]] = [] + filtered_old_logps: List[List[float]] = [] + filtered_advantages: List[float] = [] + max_len = rollout_template.max_length or float('inf') + for i, traj in enumerate(all_trajectories): + traj_len = len(traj.get('input_ids') or traj.get('labels') or []) + comp_len = sum(1 for l in (traj.get('labels') or []) if l != -100) + if traj_len > max_len or comp_len == 0: + continue + all_input_data.append(traj) + filtered_old_logps.append(all_old_logps[i]) + filtered_advantages.append(advantages[i]) + + if len(all_input_data) < MODEL_GPUS: + logger.warning(f'[Step {optim_step}] Only {len(all_input_data)} valid trajectories ' + f'after filtering (need >= {MODEL_GPUS}), skipping this batch.') + continue + + all_old_logps = filtered_old_logps + advantages = filtered_advantages + total_completions = len(all_input_data) + logger.info(f'[Step {optim_step}] {total_completions}/{n_traj} trajectories ' + f'passed length filter (max_len={max_len})') + + for mb_start in range(0, total_completions, MINI_BATCH_SIZE): + mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions) + mb_inputs = all_input_data[mb_start:mb_end] + mb_old_logps = all_old_logps[mb_start:mb_end] + mb_advantages = advantages[mb_start:mb_end] + + # Print trajectory lengths before forward_backward + traj_lengths = [] + for idx, traj in enumerate(mb_inputs): + labels = traj.get('labels') or traj.get('input_ids') or [] + traj_lengths.append(len(labels)) + logger.info(f'[Step {optim_step}] mini-batch [{mb_start}:{mb_end}] ' + f'n_inputs={len(mb_inputs)}, dp_world={MODEL_GPUS}, ' + f'traj_lengths={traj_lengths}') + + model.forward_backward( + inputs=mb_inputs, + old_logps=mb_old_logps, + advantages=mb_advantages, + micro_batch_size=MICRO_BATCH_SIZE, + ) + model.clip_grad_and_step() + optim_step += 1 + + if optim_step >= MAX_STEPS: + break + if optim_step % SAVE_STEPS == 0: + model.save(f'multi-turn-grpo-checkpoint-{optim_step}') + + # 8. Log step summary + log_dict = metrics.calculate() + log_dict.update(model.calculate_metric(is_training=True)) + log_dict['avg_turns'] = avg_turns + log_dict['avg_reward'] = avg_reward + metrics.reset() + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}') + + # Cleanup + env_pool.close() + logger.info(f'Training completed. optim_steps={optim_step}') + model.save('multi-turn-grpo-final') + + +if __name__ == '__main__': + main() diff --git a/cookbook/sample/emb_sample.py b/cookbook/sample/emb_sample.py index 6d7e4c599..da27a8155 100644 --- a/cookbook/sample/emb_sample.py +++ b/cookbook/sample/emb_sample.py @@ -12,15 +12,15 @@ python cookbook/sample/emb_sample.py EMB_MODEL=./output/embedding_lora_transformers/step_16000 python cookbook/sample/emb_sample.py """ -import os import re -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import torch import torch.nn.functional as F import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.data_format import SamplingParams from twinkle.loss import InfonceLoss from twinkle.model import TransformersModel @@ -29,12 +29,13 @@ from twinkle.template import Template logger = get_logger() +args = CLI.from_args() # -- Config ------------------------------------------------------------------- -CONDENSE_MODEL_ID = os.environ.get('CONDENSE_MODEL_ID', 'ms://twinkle-kit/Qwen3.5-4B-CM-v2') -EMB_MODEL_ID = os.environ.get('EMB_MODEL', 'ms://twinkle-kit/Qwen3.5-4B-QA-emb') -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 1)) -EMB_GPUS = int(os.environ.get('EMB_GPUS', 1)) +CONDENSE_MODEL_ID = args.extra.get('condense_model_id', 'ms://twinkle-kit/Qwen3.5-4B-CM-v2') +EMB_MODEL_ID = args.extra.get('emb_model_id', 'ms://twinkle-kit/Qwen3.5-4B-QA-emb') +SAMPLER_GPUS = args.infra.sampler_gpus or 1 +EMB_GPUS = int(args.extra.get('emb_gpus', 1)) EMB_MAX_LENGTH = 8192 # -- Prompts (aligned with train_embedding_full_ddp.py) ----------------------- diff --git a/cookbook/sample/sample.py b/cookbook/sample/sample.py index b56460ea1..8cd452b8f 100644 --- a/cookbook/sample/sample.py +++ b/cookbook/sample/sample.py @@ -18,19 +18,19 @@ MODEL_ID=/path/to/model LORA_PATH=/path/to/adapter SAMPLER_GPUS=1 python sample.py """ -import os from typing import List, Dict, Any import twinkle from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.data_format import SamplingParams from twinkle.sampler import vLLMSampler logger = get_logger() +args = CLI.from_args() -MODEL_ID = os.environ.get('MODEL_ID', 'Qwen/Qwen3.5-4B') -LORA_PATH = os.environ.get('LORA_PATH', '/path/to/lora') -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 1)) +SAMPLER_GPUS = args.infra.sampler_gpus or 1 +LORA_PATH = args.lora.lora_path or '/path/to/lora' def build_prompts() -> List[Dict[str, Any]]: @@ -67,7 +67,7 @@ def main(): # ── 2. Create vLLMSampler with LoRA enabled ──────────────────────── sampler = vLLMSampler( - model_id=MODEL_ID, + model_id=args.model.model_id, engine_args={ 'gpu_memory_utilization': 0.7, 'max_model_len': 4096, @@ -79,7 +79,7 @@ def main(): device_mesh=sampler_mesh, remote_group='sampler', ) - sampler.set_template('Qwen3_5Template', model_id=MODEL_ID) + sampler.set_template('Qwen3_5Template', model_id=args.model.model_id) logger.info(get_device_placement()) # ── 3. Configure sampling parameters ──────────────────────────────── @@ -92,7 +92,7 @@ def main(): # ── 4. Run inference ──────────────────────────────────────────────── prompts = build_prompts() - logger.info(f'Sampling {len(prompts)} prompts with model {MODEL_ID} ...') + logger.info(f'Sampling {len(prompts)} prompts with model {args.model.model_id} ...') responses = sampler.sample(prompts, sampling_params, adapter_path=LORA_PATH) diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py index af72efa10..79daafd38 100644 --- a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py +++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py @@ -4,7 +4,6 @@ Run on 8 GPUs: torchrun --nproc-per-node=8 cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py """ -import os from pathlib import Path from peft import LoraConfig @@ -12,45 +11,31 @@ import twinkle from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor logger = get_logger() +args = CLI.from_args() -MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') -TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template') -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) -GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) -LOG_INTERVAL = GRAD_ACCUM_STEPS -LR = float(os.environ.get('LR', '1e-4')) -MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) -LORA_R = int(os.environ.get('LORA_R', '8')) -LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32')) -ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1' -OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output_dsv4') -RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None -RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' -IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' -ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') -NUM_GPUS = int(os.environ.get('NUM_GPUS', '8')) +ENABLE_EP = args.extra.get('enable_ep', True) device_mesh = DeviceMesh.from_sizes( - fsdp_size=NUM_GPUS, - dp_size=1, - ep_size=NUM_GPUS, + fsdp_size=args.infra.fsdp_size, + dp_size=args.infra.dp_size, + ep_size=args.infra.ep_size, device_type=Platform.get_platform().device_prefix(), ) -twinkle.initialize(mode='local', global_device_mesh=device_mesh) +twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh) def _build_lora_config(enable_ep: bool): if enable_ep: return LoraConfig( - r=LORA_R, - lora_alpha=LORA_ALPHA, + r=args.lora.lora_r, + lora_alpha=args.lora.lora_alpha, target_modules='all-linear', exclude_modules=['o_a_proj'], target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'], @@ -60,8 +45,8 @@ def _build_lora_config(enable_ep: bool): # during forward. That is not stable with plain FSDP2, so non-EP mode uses # regular module LoRA and does not train expert parameters. return LoraConfig( - r=LORA_R, - lora_alpha=LORA_ALPHA, + r=args.lora.lora_r, + lora_alpha=args.lora.lora_alpha, exclude_modules=['o_a_proj'], target_modules='all-linear', ) @@ -70,31 +55,34 @@ def _build_lora_config(enable_ep: bool): def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): return model.save( name=checkpoint_name, - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, - save_optimizer=True, + output_dir=args.training.output_dir, + adapter_name=args.lora.adapter_name, + save_optimizer=args.checkpoint.save_optimizer, consumed_train_samples=dataloader.get_state()['consumed_train_samples'], ) def train(): - config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) + config = AutoConfig.from_pretrained(args.model.model_id, trust_remote_code=True) text_config = getattr(config, 'text_config', config) if hasattr(text_config, 'use_cache'): text_config.use_cache = False - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) - dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) + dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id)) + dataset.set_template(args.template.template_cls, model_id=args.model.model_id) + dataset.map(SelfCognitionProcessor( + args.extra.get('model_name', 'twinkle'), + args.extra.get('model_author', 'ModelScope'), + )) dataset.encode(batched=True) - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) + dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size, device_mesh=device_mesh) model = TransformersModel( - model_id=MODEL_ID, + model_id=args.model.model_id, config=config, device_mesh=device_mesh, strategy='native_fsdp', - memory_efficient_init=True, + memory_efficient_init=args.model.memory_efficient_init, fsdp_config={ 'expert_parallel': { 'enabled': ENABLE_EP, @@ -104,38 +92,41 @@ def train(): }, ) lora_cfg = _build_lora_config(ENABLE_EP) - model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) - model.set_optimizer('AdamW', lr=LR, foreach=False) + model.add_adapter_to_model(args.lora.adapter_name, lora_cfg, + gradient_accumulation_steps=args.training.gradient_accumulation_steps) + model.set_optimizer(args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate, foreach=False) model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', - num_warmup_steps=5, + scheduler_cls=args.scheduler.scheduler_cls, + num_warmup_steps=args.scheduler.num_warmup_steps, num_training_steps=len(dataloader), ) - if RESUME_FROM_CHECKPOINT: - checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + if args.training.resume_from_checkpoint: + checkpoint_path = Path(args.training.resume_from_checkpoint).expanduser().resolve() kwargs = {} - if ADAPTER_NAME: - kwargs['adapter_name'] = ADAPTER_NAME + if args.lora.adapter_name: + kwargs['adapter_name'] = args.lora.adapter_name progress = model.resume_from_checkpoint( - str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) - if not IGNORE_DATA_SKIP: + str(checkpoint_path), resume_only_model=args.training.resume_only_model, **kwargs) + if not args.training.ignore_data_skip: dataloader.resume_from_checkpoint(progress['consumed_train_samples']) logger.info(get_device_placement()) logger.info(model.get_train_configs()) logger.info( - f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' - f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}') + f'Total steps: {len(dataloader)}, batch_size={args.training.batch_size}, ' + f'grad_accum={args.training.gradient_accumulation_steps}, ' + f'enable_ep={ENABLE_EP}, output_dir={args.training.output_dir}') - optimizer_group = model.optimizer_group[ADAPTER_NAME] + optimizer_group = model.optimizer_group[args.lora.adapter_name] for batch in dataloader: if callable(batch): batch = batch() model.forward_backward(inputs=batch) - model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + model.clip_grad_and_step(max_grad_norm=args.optimizer.max_grad_norm, + gradient_accumulation_steps=args.training.gradient_accumulation_steps) cur_step = optimizer_group.cur_step - if cur_step > 0 and cur_step % LOG_INTERVAL == 0: + if cur_step > 0 and cur_step % args.training.log_interval == 0: metric = model.calculate_metric(is_training=True) if callable(metric): metric = metric() diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh index b4e3d9ffb..f2b01ff6e 100644 --- a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh +++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh @@ -4,13 +4,18 @@ set -euo pipefail # EP + FSDP2 + LoRA training for DeepSeek-V4. # ENABLE_EP=1 trains expert LoRA with target_parameters. # ENABLE_EP=0 runs plain FSDP2 LoRA and does not train expert parameters. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash ep_fsdp2_lora_deepseek_v4.sh --batch-size 8 --lr 5e-5 -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}" -export NPROC_PER_NODE="${NPROC_PER_NODE:-8}" -export ENABLE_EP="${ENABLE_EP:-1}" -export BATCH_SIZE="${BATCH_SIZE:-4}" -export GRAD_ACCUM_STEPS="${GRAD_ACCUM_STEPS:-4}" -export OUTPUT_DIR="${OUTPUT_DIR:-./output_dsv4}" - -torchrun --nproc-per-node="${NPROC_PER_NODE}" \ - cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + torchrun --nproc-per-node=8 \ + cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py \ + --model-id ms://deepseek-ai/DeepSeek-V3-0324 \ + --dataset-id ms://swift/self-cognition \ + --dp-size 4 \ + --ep-size 2 \ + --batch-size 4 \ + --gradient-accumulation-steps 4 \ + --output-dir ./output_dsv4 \ + --enable-ep 1 \ + "$@" diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4_multinode.sh b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4_multinode.sh index 7344474e5..47a7ebfc4 100644 --- a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4_multinode.sh +++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4_multinode.sh @@ -1,26 +1,39 @@ +#!/usr/bin/env bash +set -euo pipefail # `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights. # Convert the checkpoint before training by following: # https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87 # Install `transformers==5.8.0` before running this cookbook. +# All training config passed as CLI flags. Override at invocation. -export DSV4_MODEL_ID="ms://deepseek-ai/DeepSeek-V4-Flash-bf16" -export DATASET_ID="ms://swift/self-cognition" -# The following environment variables are required for multi-node training. Adjust the values according to your cluster setup. -export GLOO_SOCKET_IFNAME="eth0" # Use ifconfig to check the network interface name +# Multi-node networking config — adjust to your cluster setup. +export GLOO_SOCKET_IFNAME="eth0" export HCCL_SOCKET_IFNAME="eth0" export HCCL_EXEC_TIMEOUT=1200 export HCCL_CONNECT_TIMEOUT=1200 -export NNODES=4 -export NUM_GPUS=64 -export MASTER_ADDR="node0" # Replace with the IP address or hostname of the master node -export MASTER_PORT=29500 # Replace with an open port on the master node export HCCL_IF_BASE_PORT=20000 -torchrun --nnodes=$NNODES --node_rank=$NODE_RANK --nproc_per_node=16 \ - --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT ep_fsdp2_lora_deepseek_v4.py +NNODES=4 +MASTER_ADDR=node0 +MASTER_PORT=29500 +NPROC_PER_NODE=16 -# NODE_RANK=0 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh -# NODE_RANK=1 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh -# NODE_RANK=2 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh -# NODE_RANK=3 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh +torchrun --nnodes=$NNODES --node_rank=${NODE_RANK:?"NODE_RANK must be set"} \ + --nproc_per_node=$NPROC_PER_NODE \ + --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ + ep_fsdp2_lora_deepseek_v4.py \ + --model-id ms://deepseek-ai/DeepSeek-V4-Flash-bf16 \ + --dataset-id ms://swift/self-cognition \ + --dp-size 4 \ + --ep-size 2 \ + --batch-size 4 \ + --gradient-accumulation-steps 4 \ + --output-dir ./output_dsv4_multinode \ + --enable-ep 1 \ + "$@" + +# NODE_RANK=0 bash ep_fsdp2_lora_deepseek_v4_multinode.sh +# NODE_RANK=1 bash ep_fsdp2_lora_deepseek_v4_multinode.sh +# NODE_RANK=2 bash ep_fsdp2_lora_deepseek_v4_multinode.sh +# NODE_RANK=3 bash ep_fsdp2_lora_deepseek_v4_multinode.sh diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py index a9f90111d..5d5088182 100644 --- a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py +++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py @@ -4,7 +4,6 @@ Run on 8 GPUs: torchrun --nproc-per-node=8 cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py """ -import os from pathlib import Path from peft import LoraConfig @@ -12,6 +11,7 @@ import twinkle from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.cli import CLI from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel @@ -20,38 +20,24 @@ from twinkle.kernel import kernelize_model logger = get_logger() +args = CLI.from_args() -MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') -TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template') -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) -GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) -LOG_INTERVAL = GRAD_ACCUM_STEPS -LR = float(os.environ.get('LR', '1e-4')) -MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) -LORA_R = int(os.environ.get('LORA_R', '8')) -LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32')) -ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1' -OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') -RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None -RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' -IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' -ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') +ENABLE_EP = args.extra.get('enable_ep', True) device_mesh = DeviceMesh.from_sizes( - fsdp_size=8, - dp_size=1, - ep_size=8, + fsdp_size=args.infra.fsdp_size, + dp_size=args.infra.dp_size, + ep_size=args.infra.ep_size, device_type=Platform.get_platform().device_prefix(), ) -twinkle.initialize(mode='local', global_device_mesh=device_mesh) +twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh) def _build_lora_config(enable_ep: bool): if enable_ep: return LoraConfig( - r=LORA_R, - lora_alpha=LORA_ALPHA, + r=args.lora.lora_r, + lora_alpha=args.lora.lora_alpha, target_modules='all-linear', target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'], ) @@ -60,8 +46,8 @@ def _build_lora_config(enable_ep: bool): # during forward. That is not stable with plain FSDP2, so non-EP mode uses # regular module LoRA and does not train expert parameters. return LoraConfig( - r=LORA_R, - lora_alpha=LORA_ALPHA, + r=args.lora.lora_r, + lora_alpha=args.lora.lora_alpha, target_modules='all-linear', ) @@ -69,30 +55,33 @@ def _build_lora_config(enable_ep: bool): def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): return model.save( name=checkpoint_name, - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, - save_optimizer=True, + output_dir=args.training.output_dir, + adapter_name=args.lora.adapter_name, + save_optimizer=args.checkpoint.save_optimizer, consumed_train_samples=dataloader.get_state()['consumed_train_samples'], ) def train(): - config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) + config = AutoConfig.from_pretrained(args.model.model_id, trust_remote_code=True) text_config = getattr(config, 'text_config', config) if hasattr(text_config, 'use_cache'): text_config.use_cache = False - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) + dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id)) try: - dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) + dataset.set_template(args.template.template_cls, model_id=args.model.model_id) except ValueError: - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) + dataset.set_template('Qwen3_5Template', model_id=args.model.model_id) + dataset.map(SelfCognitionProcessor( + args.extra.get('model_name', 'twinkle'), + args.extra.get('model_author', 'ModelScope'), + )) dataset.encode(batched=True) - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) + dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size, device_mesh=device_mesh) model = TransformersModel( - model_id=MODEL_ID, + model_id=args.model.model_id, config=config, device_mesh=device_mesh, strategy='native_fsdp', @@ -108,38 +97,41 @@ def train(): if Torch.is_npu_available(): model = kernelize_model(model, mode='train', device='npu') lora_cfg = _build_lora_config(ENABLE_EP) - model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) - model.set_optimizer('AdamW', lr=LR, foreach=False) + model.add_adapter_to_model(args.lora.adapter_name, lora_cfg, + gradient_accumulation_steps=args.training.gradient_accumulation_steps) + model.set_optimizer(args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate, foreach=False) model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', - num_warmup_steps=5, + scheduler_cls=args.scheduler.scheduler_cls, + num_warmup_steps=args.scheduler.num_warmup_steps, num_training_steps=len(dataloader), ) - if RESUME_FROM_CHECKPOINT: - checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + if args.training.resume_from_checkpoint: + checkpoint_path = Path(args.training.resume_from_checkpoint).expanduser().resolve() kwargs = {} - if ADAPTER_NAME: - kwargs['adapter_name'] = ADAPTER_NAME + if args.lora.adapter_name: + kwargs['adapter_name'] = args.lora.adapter_name progress = model.resume_from_checkpoint( - str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) - if not IGNORE_DATA_SKIP: + str(checkpoint_path), resume_only_model=args.training.resume_only_model, **kwargs) + if not args.training.ignore_data_skip: dataloader.resume_from_checkpoint(progress['consumed_train_samples']) logger.info(get_device_placement()) logger.info(model.get_train_configs()) logger.info( - f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' - f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}') + f'Total steps: {len(dataloader)}, batch_size={args.training.batch_size}, ' + f'grad_accum={args.training.gradient_accumulation_steps}, ' + f'enable_ep={ENABLE_EP}, output_dir={args.training.output_dir}') - optimizer_group = model.optimizer_group[ADAPTER_NAME] + optimizer_group = model.optimizer_group[args.lora.adapter_name] for batch in dataloader: if callable(batch): batch = batch() model.forward_backward(inputs=batch) - model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + model.clip_grad_and_step(max_grad_norm=args.optimizer.max_grad_norm, + gradient_accumulation_steps=args.training.gradient_accumulation_steps) cur_step = optimizer_group.cur_step - if cur_step > 0 and cur_step % LOG_INTERVAL == 0: + if cur_step > 0 and cur_step % args.training.log_interval == 0: metric = model.calculate_metric(is_training=True) if callable(metric): metric = metric() diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh index 8f1813e4f..5132d9d0b 100644 --- a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh +++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh @@ -2,15 +2,19 @@ set -euo pipefail # EP + FSDP2 + LoRA training for Qwen3.5-MoE. -# ENABLE_EP=1 trains expert LoRA with target_parameters. -# ENABLE_EP=0 runs plain FSDP2 LoRA and does not train expert parameters. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash ep_fsdp2_lora_qwen3_5_moe.sh --batch-size 8 --lr 5e-5 -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}" -export NPROC_PER_NODE="${NPROC_PER_NODE:-8}" -export ENABLE_EP="${ENABLE_EP:-1}" -export BATCH_SIZE="${BATCH_SIZE:-4}" -export GRAD_ACCUM_STEPS="${GRAD_ACCUM_STEPS:-4}" -export OUTPUT_DIR="${OUTPUT_DIR:-./output_qwen3_5_moe}" - -torchrun --nproc-per-node="${NPROC_PER_NODE}" \ - cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + torchrun --nproc-per-node=8 \ + cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py \ + --model-id ms://Qwen/Qwen3.5-30B-A3B \ + --dataset-id ms://swift/self-cognition \ + --template-cls Qwen3_5Template \ + --dp-size 4 \ + --ep-size 2 \ + --batch-size 4 \ + --gradient-accumulation-steps 4 \ + --output-dir ./output_qwen3_5_moe \ + --enable-ep 1 \ + "$@" diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index ad4c917f9..a3b4da645 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -2,6 +2,7 @@ from peft import LoraConfig from tqdm import tqdm +from torch.optim import Muon # PyTorch 2.9+; matrix-orthogonalized momentum optimizer. import twinkle from twinkle import DeviceMesh, get_device_placement, get_logger @@ -51,7 +52,7 @@ def evaluate(model): def train(): - train_samples = int(args.extra.get('train_samples', 1000)) + train_samples = args.training.train_samples or 1000 dataset = build_dataset(train_samples) dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) model = TransformersModel(model_id=args.model.model_id) @@ -64,7 +65,16 @@ def train(): model.add_adapter_to_model( args.lora.adapter_name, lora_config, gradient_accumulation_steps=args.training.gradient_accumulation_steps) - model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate) + # Muon optimizes 2D hidden-layer weight matrices via Newton-Schulz orthogonalization. + # In LoRA training the trainable params are exclusively lora_A / lora_B (both 2D), + # so Muon applies cleanly without an AdamW fallback for 1D params. + # ``adjust_lr_fn='match_rms_adamw'`` rescales the orthogonalized update so the same + # lr / weight_decay tuned for AdamW can be reused directly (Moonshot Muon recipe). + model.set_optimizer( + optimizer_cls=Muon, + lr=args.optimizer.learning_rate, + adjust_lr_fn='match_rms_adamw', + ) # Add LRScheduler for lora `default` model.set_lr_scheduler( diff --git a/cookbook/transformers/fsdp2.sh b/cookbook/transformers/fsdp2.sh index bbe269629..c372fbf2e 100644 --- a/cookbook/transformers/fsdp2.sh +++ b/cookbook/transformers/fsdp2.sh @@ -2,7 +2,7 @@ # All training config passed as CLI flags. Override at invocation, e.g.: # bash fsdp2.sh --batch-size 16 --lr 5e-5 -CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7} \ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ torchrun --nproc_per_node=8 fsdp2.py \ --model-id ms://Qwen/Qwen3.5-4B \ --dataset-id ms://swift/self-cognition \ diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index a6fd0bdcb..56f22c801 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -1,9 +1,9 @@ -import numpy as np from functools import partial from peft import LoraConfig import twinkle -from twinkle import DeviceGroup, DeviceMesh, Platform, get_logger +from twinkle import DeviceMesh, get_logger +from twinkle.cli import CLI from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel @@ -12,48 +12,44 @@ from twinkle.kernel import kernelize_model logger = get_logger() -MODEL_ID = 'ms://Qwen/Qwen3.5-4B' -DATASETS = 'ms://swift/self-cognition' - -device_group = [DeviceGroup( - name='default', - ranks=[0, 1, 2, 3], - device_type=Platform.get_platform().device_prefix(), -)] +args = CLI.from_args() # FSDP + sequence-parallel validation over 4 GPUs: dp=2, fsdp=2. # In Transformers route, ulysses_size is the total sequence-parallel degree. -device_mesh = DeviceMesh( - device_type=Platform.get_platform().device_prefix(), - mesh=np.arange(4).reshape(2, 2), - mesh_dim_names=('dp', 'fsdp'), - ulysses_size=2, +device_mesh = DeviceMesh.from_sizes( + dp_size=args.infra.dp_size, + fsdp_size=args.infra.fsdp_size, + ulysses_size=args.infra.ulysses_size, ) twinkle.initialize( - mode='local', - nproc_per_node=4, + mode=args.infra.mode, global_device_mesh=device_mesh, - lazy_collect=False, + lazy_collect=args.infra.lazy_collect, ) def eval(model): + eval_samples = args.training.eval_samples or 100 dataloader = DataLoader( - dataset=partial(create_dataset, data_slice=range(100)), - batch_size=4, + dataset=partial(create_dataset, data_slice=range(eval_samples)), + batch_size=args.training.batch_size, device_mesh=device_mesh, ) for _, batch in enumerate(dataloader): - model.forward_only(inputs=batch, adapter_name='default') - model.calculate_loss(adapter_name='default') - return model.calculate_metric(is_training=False, adapter_name='default') + model.forward_only(inputs=batch, adapter_name=args.lora.adapter_name) + model.calculate_loss(adapter_name=args.lora.adapter_name) + return model.calculate_metric(is_training=False, adapter_name=args.lora.adapter_name) def create_dataset(data_slice=None): - dataset = Dataset(dataset_meta=DatasetMeta(DATASETS, data_slice=range(500))) - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队')) + train_samples = args.training.train_samples or 500 + dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=data_slice or range(train_samples))) + dataset.set_template(args.template.template_cls, model_id=args.model.model_id) + dataset.map(SelfCognitionProcessor( + args.extra.get('model_name', 'twinkle模型'), + args.extra.get('model_author', 'twinkle团队'), + )) dataset.encode(batched=True) return dataset @@ -61,36 +57,38 @@ def create_dataset(data_slice=None): def train(): dataloader = DataLoader( dataset=partial(create_dataset, data_slice=None), - batch_size=8, + batch_size=args.training.batch_size, device_mesh=device_mesh, ) model = TransformersModel( - model_id=MODEL_ID, + model_id=args.model.model_id, device_mesh=device_mesh, - strategy='native_fsdp', + strategy=args.model.strategy, ) # npu patch if Torch.is_npu_available(): model = kernelize_model(model, mode='train', device='npu') - lora_config = LoraConfig(target_modules='all-linear') - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1) - model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') + lora_config = LoraConfig(**args.get_lora_args()) + model.add_adapter_to_model(args.lora.adapter_name, lora_config, + gradient_accumulation_steps=args.training.gradient_accumulation_steps) + model.set_optimizer(args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate, + adapter_name=args.lora.adapter_name) model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', - num_warmup_steps=5, + scheduler_cls=args.scheduler.scheduler_cls, + num_warmup_steps=args.scheduler.num_warmup_steps, num_training_steps=len(dataloader), - adapter_name='default', + adapter_name=args.lora.adapter_name, ) - logger.info(model.get_train_configs(adapter_name='default')) + logger.info(model.get_train_configs(adapter_name=args.lora.adapter_name)) logger.info(f'Total steps: {len(dataloader)}') for step, batch in enumerate(dataloader): - model.forward_backward(inputs=batch, adapter_name='default') - model.clip_grad_and_step(adapter_name='default') - if step % 20 == 0: - metric = model.calculate_metric(is_training=True, adapter_name='default') + model.forward_backward(inputs=batch, adapter_name=args.lora.adapter_name) + model.clip_grad_and_step(adapter_name=args.lora.adapter_name) + if step % args.training.log_interval == 0: + metric = model.calculate_metric(is_training=True, adapter_name=args.lora.adapter_name) logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') model.save('last-checkpoint', interval=1) diff --git a/cookbook/transformers/sp_fsdp_dense.sh b/cookbook/transformers/sp_fsdp_dense.sh index 2a8bcf08b..841561fe4 100644 --- a/cookbook/transformers/sp_fsdp_dense.sh +++ b/cookbook/transformers/sp_fsdp_dense.sh @@ -1,11 +1,24 @@ -#!/bin/bash -# To enable Transformers sequence parallelism, please set ulysses_size > 1. -# ulysses_size is interpreted as the total sequence-parallel degree. -# device_mesh = DeviceMesh( -# device_type="cuda", -# mesh=np.arange(4).reshape(2, 2), -# mesh_dim_names=("dp", "fsdp"), -# ulysses_size=2, -# ) -# -CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 sp_fsdp_dense.py +#!/usr/bin/env bash +set -euo pipefail + +# FSDP + Sequence Parallelism training. +# To enable Transformers sequence parallelism, set ulysses-size > 1. +# All training config passed as CLI flags. Override at invocation, e.g.: +# bash sp_fsdp_dense.sh --model-id ms://Qwen/Qwen3.5-4B --ulysses-size 4 + +CUDA_VISIBLE_DEVICES=0,1,2,3 \ + torchrun --nproc_per_node=4 sp_fsdp_dense.py \ + --model-id ms://Qwen/Qwen3.5-4B \ + --dataset-id ms://swift/self-cognition \ + --template-cls Qwen3_5Template \ + --dp-size 2 \ + --fsdp-size 2 \ + --ulysses-size 2 \ + --batch-size 4 \ + --lr 1e-4 \ + --gradient-accumulation-steps 2 \ + --train-samples 500 \ + --log-interval 10 \ + --model-name twinkle模型 \ + --model-author twinkle团队 \ + "$@" diff --git a/docs/source_en/Components/Agentic/Envs.md b/docs/source_en/Components/Agentic/Envs.md new file mode 100644 index 000000000..3e2e90b3a --- /dev/null +++ b/docs/source_en/Components/Agentic/Envs.md @@ -0,0 +1,183 @@ +# Environments (Envs) + +The Envs module provides an RL execution environment abstraction for agentic training. Environments can participate in multi-turn rollouts interactively or evaluate completed trajectories in batch. + +## Env Base Class + +```python +from twinkle_agentic.envs.base import Env, StepResult + +class Env(ABC): + + def reset(self, trajectory=None) -> StepResult: + """Reset for a new episode.""" + + @abstractmethod + def step(self, tool_name: str, arguments: dict) -> StepResult: + """Execute a single action, return observation + reward + done.""" + + def tools(self) -> List[ToolInfo]: + """Return tool definitions available in this environment.""" + + def evaluate(self, trajectories, **kwargs) -> List[float]: + """Batch-evaluate completed trajectories, return rewards.""" + + def close(self) -> None: + """Release resources.""" +``` + +### StepResult + +```python +@dataclass +class StepResult: + observation: str = '' # Environment observation after the action + reward: float = 0.0 # Scalar reward for this step + done: bool = False # Whether the episode is terminated + info: Dict[str, Any] = field(default_factory=dict) # Extra metadata +``` + +### Two Usage Modes + +1. **Interactive mode** (multi-turn rollout) — step-by-step execution: + +```python +env = MyEnv() +env.reset(trajectory) +result = env.step('search', {'query': 'Python'}) +# ... repeat until result.done +``` + +2. **Batch evaluation mode** — evaluate completed trajectories: + +```python +rewards = env.evaluate(completed_trajectories) +``` + +## EnvTool + +`EnvTool` wraps an `Env` as a `Tool`, bridging the environment with `ToolManager` and `MultiTurnRollout`. + +```python +from twinkle_agentic.envs.env_tool import EnvTool +from twinkle_agentic.tools.tool_manager import ToolManager + +env = MyEnv() + +# Create one EnvTool per tool defined in the environment +env_tools = EnvTool.from_env(env) + +# Register into ToolManager +manager = ToolManager(env_tools) +``` + +### Key Features + +| Feature | Description | +|---------|-------------| +| `from_env(env)` | Factory: creates one `EnvTool` per tool in `env.tools()`. | +| `last_result` | Stores the most recent `StepResult` for inspection. | +| `done` | Property: whether the last step terminated the episode. | +| `episode_reward` | Property: cumulative reward from `info['episode_reward']`. | + +### Manual Construction + +```python +env_tool = EnvTool( + env=my_env, + tool_name='execute_code', + description='Execute Python code in a sandbox.', + parameters={ + 'type': 'object', + 'properties': { + 'code': {'type': 'string', 'description': 'Python code to execute.'}, + }, + 'required': ['code'], + }, +) +``` + +## OpenEnv + +`OpenEnv` adapts an [OpenEnv](https://github.com/OpenEnv) WebSocket-based environment server as a synchronous Twinkle `Env`. + +```python +from twinkle_agentic.envs.openenv import OpenEnv + +env = OpenEnv( + base_url='http://localhost:8000', + env_cls='coding_env.CodingEnv', # Optional typed client + env_kwargs={'message_timeout_s': 30}, + tool_schema=[...], # Optional tool definitions +) +``` + +### Parameters + +| Parameter | Type | Description | +|-----------|------|-------------| +| `base_url` | `str` | URL of the running OpenEnv server. | +| `env_cls` | `str` or class | Dotted import path or class for a typed client. `None` uses `GenericEnvClient`. | +| `env_kwargs` | `Dict` | Extra kwargs for the client constructor. | +| `tool_schema` | `List[ToolInfo]` | Tool definitions exposed via `tools()`. | +| `action_mapper` | `Callable` | Custom function to map `(tool_name, args)` to the action dict sent to the server. | + +### Usage with Rollout + +```python +from twinkle_agentic.envs.openenv import OpenEnv +from twinkle_agentic.envs.env_tool import EnvTool +from twinkle_agentic.tools.tool_manager import ToolManager +from twinkle_agentic.rollout.api_multi_turn import APIMultiTurnRollout + +# Set up environment +env = OpenEnv(base_url='http://localhost:8000', tool_schema=[...]) +env.reset() + +# Bridge to ToolManager +env_tools = EnvTool.from_env(env) +manager = ToolManager(env_tools) + +# Use in rollout +rollout = APIMultiTurnRollout(api=api, tool_manager=manager, max_turns=10) +results = rollout(trajectories) +``` + +### Implementing a Custom Environment + +```python +from twinkle_agentic.envs.base import Env, StepResult + +class CodeExecutionEnv(Env): + + def reset(self, trajectory=None): + self._sandbox = create_sandbox() + return StepResult(observation='Sandbox ready.') + + def step(self, tool_name, arguments): + code = arguments.get('code', '') + output = self._sandbox.run(code) + return StepResult( + observation=output, + reward=1.0 if 'error' not in output.lower() else 0.0, + done=False, + ) + + def tools(self): + return [{ + 'type': 'function', + 'function': { + 'name': 'execute_code', + 'description': 'Run Python code.', + 'parameters': { + 'type': 'object', + 'properties': { + 'code': {'type': 'string'}, + }, + }, + }, + }] + + def close(self): + self._sandbox.cleanup() +``` diff --git a/docs/source_en/Components/Agentic/Multi-Turn-Tool-Usage.md b/docs/source_en/Components/Agentic/Multi-Turn-Tool-Usage.md new file mode 100644 index 000000000..06a962a33 --- /dev/null +++ b/docs/source_en/Components/Agentic/Multi-Turn-Tool-Usage.md @@ -0,0 +1,205 @@ +# Multi-Turn Tool Usage Guide + +This guide shows how to set up and run multi-turn agentic rollouts with tool use in Twinkle. + +## Architecture Overview + +The agentic rollout pipeline consists of four key components: + +- **Tool** — implements a specific capability (search, code execution, etc.) +- **ToolManager** — registers tools and dispatches LLM tool calls +- **Env** (optional) — RL environment that exposes tools via `EnvTool` +- **Rollout** — drives the multi-turn conversation loop + +## Quick Start: API-based Rollout + +The simplest way to run a multi-turn tool-use rollout using an OpenAI-compatible API: + +```python +from twinkle_agentic.protocol.openai import OpenAI +from twinkle_agentic.tools.base import Tool +from twinkle_agentic.tools.tool_manager import ToolManager +from twinkle_agentic.rollout.api_multi_turn import APIMultiTurnRollout +from twinkle.data_format.sampling import SamplingParams + +# 1. Define tools +class WeatherTool(Tool): + def __call__(self, tool_name, arguments): + city = arguments.get('city', 'unknown') + return f'The weather in {city} is sunny, 25°C.' + + def tool_info(self): + return { + 'type': 'function', + 'function': { + 'name': 'get_weather', + 'description': 'Get the current weather for a city.', + 'parameters': { + 'type': 'object', + 'properties': { + 'city': {'type': 'string', 'description': 'City name.'}, + }, + 'required': ['city'], + }, + }, + } + +# 2. Set up ToolManager +manager = ToolManager([WeatherTool()]) + +# 3. Create API client +api = OpenAI(model='qwen3.5-32b', base_url='http://localhost:8000/v1') + +# 4. Create rollout +rollout = APIMultiTurnRollout( + api=api, + tool_manager=manager, + sampling_params=SamplingParams(temperature=0.7, max_tokens=2048), + max_turns=6, + concurrency=8, +) + +# 5. Prepare trajectories +trajectories = [ + { + 'messages': [ + {'role': 'user', 'content': "What's the weather like in Beijing?"}, + ], + }, +] + +# 6. Run rollout +results = rollout(trajectories) +for r in results: + print(f"Turns: {r['turns']}, Stop: {r['stop_reason']}") + for msg in r['messages']: + print(f" [{msg['role']}] {msg.get('content', '')[:100]}") +``` + +## Training Integration: vLLM-based Rollout + +For RLHF training, use `MultiTurnRollout` which produces `input_ids` and `labels`: + +```python +from twinkle_agentic.rollout.multi_turn import MultiTurnRollout +from twinkle.data_format.sampling import SamplingParams + +rollout = MultiTurnRollout( + sampler=vllm_sampler, # vLLMSampler instance + template=template, # Chat template + tool_manager=manager, + sampling_params=SamplingParams(temperature=0.7, max_tokens=4096), + max_turns=6, + max_trajectory_tokens=8192, + trace_dir='rollout_traces/', +) + +# In GRPO training loop +results = rollout(batch_trajectories) +# results contain input_ids, labels, logprobs for training +``` + +## Using Environments as Tools + +Bridge an RL environment into the tool pipeline: + +```python +from twinkle_agentic.envs.base import Env, StepResult +from twinkle_agentic.envs.env_tool import EnvTool +from twinkle_agentic.tools.tool_manager import ToolManager + +# Define environment +class CodeEnv(Env): + def step(self, tool_name, arguments): + code = arguments.get('code', '') + # Execute code in sandbox + result = execute_in_sandbox(code) + return StepResult(observation=result, reward=1.0, done=False) + + def tools(self): + return [{ + 'type': 'function', + 'function': { + 'name': 'run_python', + 'description': 'Execute Python code.', + 'parameters': { + 'type': 'object', + 'properties': { + 'code': {'type': 'string'}, + }, + 'required': ['code'], + }, + }, + }] + +# Bridge Env -> Tool -> ToolManager +env = CodeEnv() +env_tools = EnvTool.from_env(env) +manager = ToolManager(env_tools) + +# Use manager in rollout as usual +rollout = APIMultiTurnRollout(api=api, tool_manager=manager, max_turns=10) +``` + +## Using OpenEnv Environments + +Connect to a remote OpenEnv WebSocket server: + +```python +from twinkle_agentic.envs.openenv import OpenEnv +from twinkle_agentic.envs.env_tool import EnvTool + +env = OpenEnv( + base_url='http://localhost:8000', + env_cls='coding_env.CodingEnv', + tool_schema=[{ + 'type': 'function', + 'function': { + 'name': 'submit', + 'description': 'Submit code solution.', + 'parameters': { + 'type': 'object', + 'properties': { + 'code': {'type': 'string'}, + }, + }, + }, + }], +) + +env.reset() +env_tools = EnvTool.from_env(env) +manager = ToolManager(env_tools) +``` + +## Per-Trajectory Tool Managers + +For scenarios where each trajectory needs its own tool set (e.g., trajectory-bound state): + +```python +# Create per-trajectory managers +managers = [] +for traj in trajectories: + env = create_env_for(traj) + env_tools = EnvTool.from_env(env) + managers.append(ToolManager(env_tools)) + +# Pass as a list (aligned 1:1 with trajectories) +results = rollout(trajectories, tool_manager=managers) +``` + +## Trace Debugging + +Both rollout implementations support trace dumps for debugging: + +```python +rollout = APIMultiTurnRollout( + api=api, + tool_manager=manager, + trace_dir='traces/', + trace_callback=lambda t: t['turns'] > 1, # Only store multi-turn + success_callback=lambda t: t.get('stop_reason') == 'stop', +) +``` + +Trace files are saved as `{step}-{ok|fail}-{id}.json` with the full conversation and metadata. diff --git a/docs/source_en/Components/Agentic/Preprocessor.md b/docs/source_en/Components/Agentic/Preprocessor.md new file mode 100644 index 000000000..b646739c0 --- /dev/null +++ b/docs/source_en/Components/Agentic/Preprocessor.md @@ -0,0 +1,189 @@ +# Agentic Preprocessor + +The agentic preprocessor module provides a pipeline-based data quality filtering framework for multi-turn conversation datasets. It is designed for cleaning and filtering training data before RLHF / agentic fine-tuning. + +## QualityPreprocessor + +`QualityPreprocessor` is a thin pipeline runner that accepts a list of filter callables and runs them in sequence. Each step receives a list of rows, returns `(kept, dropped)`, and the pipeline logs per-step statistics. + +```python +from twinkle_agentic.preprocessor import QualityPreprocessor, HardFilter, DeadLoopFilter + +pipeline = [ + HardFilter(min_user_chars=10), + DeadLoopFilter(), +] +preprocessor = QualityPreprocessor(pipeline, dropped_log_path='dropped.jsonl') + +# rows is a dict of columns (Dataset.map format) +cleaned = preprocessor(rows) +``` + +### Parameters + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipeline` | `List[Callable]` | Ordered list of filter steps. Each step takes `List[Dict]` and returns `(kept, dropped)`. | +| `dropped_log_path` | `str` | Optional JSONL file path for logging dropped rows with step name and reason. | + +## Built-in Filters + +### HardFilter + +Rule-based filter that removes trivially bad rows using deterministic rules. Supports multi-language detection (EN/ZH/JA/KO). + +```python +from twinkle_agentic.preprocessor import HardFilter + +f = HardFilter( + min_user_chars=10, # Min chars for non-CJK user query + min_user_chars_cjk=6, # Min chars for CJK user query + min_assistant_chars_2turn=80, # Min assistant reply length (2-turn) + min_thinking_chars=200, # Min thinking chain length to exempt + system_deny_keywords=['hack', 'exploit'], + max_chars_per_round=50000, + max_total_chars=200000, + max_rounds=50, +) +``` + +**Drop reasons:** `trivial_single_turn`, `shallow_reply`, `all_empty_assistant`, `system_deny_keyword`, `round_too_long`, `total_too_long`, `too_many_rounds` + +### DeadLoopFilter + +Detects assistant messages exhibiting hesitation or dead-loop patterns — repetitive self-corrections, cascading corrections, and high n-gram repetition. + +```python +from twinkle_agentic.preprocessor import DeadLoopFilter + +f = DeadLoopFilter( + hesitation_density_threshold=7.0, # Markers per 1000 chars (response) + cascade_threshold=5, # Cascade markers in window + cascade_window=800, # Window size in chars + repetition_threshold=0.45, # N-gram repetition ratio + think_hesitation_density_threshold=15.0, # Laxer for blocks + think_repetition_threshold=0.65, +) +``` + +Uses separate threshold profiles for `` reasoning blocks (laxer, free to ramble) and visible response (stricter). + +### DedupFilter + +Global longest-wins deduplication. The signature is derived from the first real user turn (head+tail) and the first assistant reply. + +```python +from twinkle_agentic.preprocessor import DedupFilter + +f = DedupFilter(prefix_chars=100, asst_chars=100) +kept, dropped = f(all_rows) # Must see entire dataset in one call +``` + +> **Note:** `DedupFilter` requires the full dataset in a single call. Do **not** place it inside `QualityPreprocessor` (which processes per-batch). Run it separately before or after the pipeline. + +### RefuseFilter + +Detects self-referential refusals in the first assistant reply (e.g., "I cannot help with that"). Multi-language pattern matching (EN/ZH/JA/KO). + +```python +from twinkle_agentic.preprocessor import RefuseFilter + +f = RefuseFilter(check_window=600) # Only check first N chars +``` + +### TokenSoupFilter + +Detects garbled / token-soup output by checking for replacement characters, control characters, private-use Unicode, leaked special tokens, single-character repetition, and script chaos. + +```python +from twinkle_agentic.preprocessor import TokenSoupFilter + +f = TokenSoupFilter( + replacement_char_ratio=0.02, + special_token_count=20, + script_chaos_threshold=0.55, +) +``` + +### PIIPresidioFilter + +Multi-language PII detection and rewriting using Microsoft Presidio + spaCy NER + Faker. Detects and replaces personal identifiable information (names, emails, phone numbers, addresses, etc.). + +```python +from twinkle_agentic.preprocessor import PIIPresidioFilter + +f = PIIPresidioFilter(languages=['en', 'zh']) +``` + +### IntentClassifier + +Heuristic intent classifier that tags each row with detected intents. Pluggable detector pipeline. + +```python +from twinkle_agentic.preprocessor import IntentClassifier + +classifier = IntentClassifier() +``` + +**Intent categories:** `tool_call`, `code`, `math`, `complex_logic`, `reasoning`, `user_dissatisfaction`, `other` + +### ScoreFilter + +Pluggable scorer-based filter with built-in scorers for character-level metrics, semantic similarity, and code execution. + +```python +from twinkle_agentic.preprocessor import ScoreFilter + +f = ScoreFilter() +``` + +**Built-in scorers:** `ChrMinScorer`, `SIFDScorer`, `PassNScorer`, `ParaphraseScorer` + +### ModelFilter + +Filters rows by model ID whitelist. + +```python +from twinkle_agentic.preprocessor import ModelFilter + +f = ModelFilter(allowed_models=['qwen3.5-4b', 'qwen3.5-32b']) +``` + +### MessageNormalizer + +Three-pass message normalization: heartbeat stripping, tool-call rewriting, and consecutive same-role message merging. + +```python +from twinkle_agentic.preprocessor import MessageNormalizer + +normalizer = MessageNormalizer() +``` + +## Complete Pipeline Example + +```python +from twinkle_agentic.preprocessor import ( + QualityPreprocessor, + HardFilter, + DeadLoopFilter, + RefuseFilter, + TokenSoupFilter, + MessageNormalizer, + DedupFilter, +) + +# Step 1: Global dedup (must run on full dataset) +dedup = DedupFilter() +rows, _ = dedup(all_rows) + +# Step 2: Per-batch pipeline +pipeline = [ + HardFilter(min_user_chars=10, max_rounds=30), + DeadLoopFilter(), + RefuseFilter(), + TokenSoupFilter(), + MessageNormalizer(), +] +preprocessor = QualityPreprocessor(pipeline, dropped_log_path='dropped.jsonl') +cleaned = preprocessor(rows) +``` diff --git a/docs/source_en/Components/Agentic/Protocol.md b/docs/source_en/Components/Agentic/Protocol.md new file mode 100644 index 000000000..a44c5c3a2 --- /dev/null +++ b/docs/source_en/Components/Agentic/Protocol.md @@ -0,0 +1,91 @@ +# Protocol + +The Protocol module provides an abstract LLM API client interface and its OpenAI-compatible implementation. It bridges Twinkle's `Trajectory` / `SamplingParams` data types with external LLM inference services. + +## API Base Class + +```python +from abc import ABC, abstractmethod +from twinkle.data_format import Trajectory +from twinkle.data_format.message import Message +from twinkle.data_format.sampling import SamplingParams + +class API(ABC): + """Abstract LLM API client: Trajectory + SamplingParams -> assistant Message(s).""" + + @abstractmethod + def __call__( + self, + trajectory: Trajectory, + sampling_params: SamplingParams, + **kwargs, + ) -> Union[Message, List[Message]]: + raise NotImplementedError() +``` + +The `API` class defines a simple contract: given a conversation trajectory and sampling parameters, return one or more assistant messages. + +## OpenAI + +`OpenAI` is the built-in implementation that works with any endpoint speaking the `/v1/chat/completions` protocol (OpenAI, Azure OpenAI, vLLM, SGLang, Ollama, etc.). + +```python +from twinkle_agentic.protocol.openai import OpenAI + +api = OpenAI( + model='qwen3.5-32b', + base_url='http://localhost:8000/v1', + api_key='EMPTY', +) +``` + +### Parameters + +| Parameter | Type | Description | +|-----------|------|-------------| +| `model` | `str` | Model name to pass in the API request. | +| `api_key` | `str` | API key. Defaults to the `OPENAI_API_KEY` environment variable. | +| `base_url` | `str` | Base URL of the API endpoint (e.g. `http://localhost:8000/v1`). | +| `client_kwargs` | `Dict` | Extra keyword arguments forwarded to the `openai.OpenAI` client constructor. | + +### Usage + +```python +from twinkle.data_format import Trajectory +from twinkle.data_format.sampling import SamplingParams + +trajectory = { + 'messages': [ + {'role': 'user', 'content': 'What is the capital of France?'}, + ] +} + +sp = SamplingParams(temperature=0.7, max_tokens=512) +reply = api(trajectory, sp) +# reply is a Message dict: {'role': 'assistant', 'content': '...'} +``` + +### Features + +- **Tool calls**: Automatically maps `trajectory['tools']` to the API request and parses structured `tool_calls` from the response. +- **Reasoning content**: Preserves `reasoning_content` from models that support it (e.g., o1-style reasoning). +- **Finish reason**: Surfaces `finish_reason` on the returned message so multi-turn drivers can detect length-cap truncation. +- **Multi-sample**: When `sampling_params.num_samples > 1`, returns a list of messages (one per choice). + +### Custom API Client + +To integrate a non-OpenAI API, subclass `API`: + +```python +from twinkle_agentic.protocol.base import API + +class MyCustomAPI(API): + + def __call__(self, trajectory, sampling_params, **kwargs): + # Call your custom endpoint + response = my_llm_client.chat( + messages=trajectory['messages'], + temperature=sampling_params.temperature, + ) + return {'role': 'assistant', 'content': response.text} +``` diff --git a/docs/source_en/Components/Agentic/Rollout.md b/docs/source_en/Components/Agentic/Rollout.md new file mode 100644 index 000000000..94b143454 --- /dev/null +++ b/docs/source_en/Components/Agentic/Rollout.md @@ -0,0 +1,140 @@ +# Multi-Turn Rollout + +The Rollout module provides multi-turn conversation rollout engines for agentic RLHF training. Two implementations are available: `MultiTurnRollout` for batched vLLM sampling and `APIMultiTurnRollout` for OpenAI-compatible API endpoints. + +## Rollout Base Class + +```python +from abc import ABC, abstractmethod +from twinkle.data_format import Trajectory + +class Rollout(ABC): + + @abstractmethod + def __call__(self, trajectories: List[Trajectory], **kwargs) -> List[Trajectory]: + raise NotImplementedError() +``` + +All rollouts accept a list of trajectories and return the same number of trajectories with additional fields (`messages`, `turns`, `stop_reason`, `truncated`). + +## MultiTurnRollout + +Batched multi-turn rollout engine that uses a vLLM sampler for generation. All active trajectories are sampled in a single batched call per turn for maximum throughput. + +### Per-turn Loop + +1. Encode each trajectory into an `InputFeature` with a generation prompt +2. Batch `sampler.sample(active_pifs)` — all live trajectories in parallel +3. Check termination: `stop_reason == 'length'`, no tool calls, or max turns reached +4. Dispatch tools via `ToolManager`, append tool responses +5. Compute bridge tokens (tool turns + generation prompt) with `labels = -100` +6. Repeat until all trajectories are done + +```python +from twinkle_agentic.rollout.multi_turn import MultiTurnRollout +from twinkle_agentic.tools.tool_manager import ToolManager +from twinkle.data_format.sampling import SamplingParams + +rollout = MultiTurnRollout( + sampler=vllm_sampler, + template=template, + tool_manager=tool_manager, + sampling_params=SamplingParams(temperature=0.7, max_tokens=4096), + max_turns=6, + max_trajectory_tokens=8192, + trace_dir='rollout_traces/', +) + +# Run rollout +results = rollout(trajectories) +``` + +### Parameters + +| Parameter | Type | Description | +|-----------|------|-------------| +| `sampler` | Sampler | vLLM sampler instance for batched generation. | +| `template` | `Template` | Chat template for encoding/decoding. | +| `tool_manager` | `ToolManager` | Tool dispatcher. Can also be passed per-call. | +| `sampling_params` | `SamplingParams` | Default sampling parameters. | +| `max_turns` | `int` | Maximum number of turns per trajectory (default: 6). | +| `max_trajectory_tokens` | `int` | Max total token length; exceeding truncates the trajectory. | +| `trace_dir` | `str` | Directory for per-trajectory JSON trace dumps. | +| `trace_callback` | `Callable` | Decides whether to store a trajectory trace. | +| `success_callback` | `Callable` | Decides filename prefix (`ok-` vs `fail-`). | + +### Output Fields + +Each output trajectory dict includes: + +| Field | Type | Description | +|-------|------|-------------| +| `messages` | `List[Dict]` | Full conversation including tool turns. | +| `input_ids` | `List[int]` | Token IDs of the full sequence. | +| `labels` | `List[int]` | Training labels (`-100` for non-trainable tokens). | +| `turns` | `int` | Number of turns performed. | +| `stop_reason` | `str` | `'stop'` / `'length'` | +| `truncated` | `bool` | Whether the trajectory was truncated. | +| `logprobs` | `List` | Per-token log probabilities (if available). | + +### Ray Remote Support + +`MultiTurnRollout` is decorated with `@remote_class()`, enabling transparent deployment as a Ray actor: + +```python +# The rollout can run as a Ray remote actor +rollout_actor = MultiTurnRollout.remote(sampler=sampler, template=template, ...) +results = ray.get(rollout_actor.__call__.remote(trajectories)) +``` + +## APIMultiTurnRollout + +Multi-turn rollout over an OpenAI-compatible chat-completions API. Each trajectory runs independently in a thread pool for network concurrency. + +```python +from twinkle_agentic.rollout.api_multi_turn import APIMultiTurnRollout +from twinkle_agentic.protocol.openai import OpenAI + +api = OpenAI(model='qwen3.5-32b', base_url='http://localhost:8000/v1') + +rollout = APIMultiTurnRollout( + api=api, + tool_manager=tool_manager, + sampling_params=SamplingParams(temperature=0.7), + max_turns=6, + concurrency=8, + trace_dir='api_traces/', +) + +results = rollout(trajectories) +``` + +### Parameters + +| Parameter | Type | Description | +|-----------|------|-------------| +| `api` | `OpenAI` | OpenAI-compatible API client. | +| `tool_manager` | `ToolManager` | Tool dispatcher (single or per-trajectory list). | +| `sampling_params` | `SamplingParams` | Default sampling parameters. | +| `max_turns` | `int` | Maximum turns per trajectory (default: 6). | +| `concurrency` | `int` | Thread pool size for parallel API calls (default: 8). | +| `extra_body` | `Dict` | Extra fields to include in API requests. | +| `trace_dir` | `str` | Directory for trace dumps. | + +### Stop Reasons + +| Reason | Description | +|--------|-------------| +| `stop` | Assistant responded without tool calls (natural end). | +| `length` | API returned `finish_reason='length'` (token limit). | +| `max_turns` | Reached `max_turns` limit. | +| `api_error` | API call or tool execution raised an exception. | + +## Choosing Between Rollouts + +| Feature | MultiTurnRollout | APIMultiTurnRollout | +|---------|-----------------|---------------------| +| **Backend** | vLLM sampler (local GPU) | OpenAI-compatible API | +| **Training integration** | Produces `input_ids` / `labels` for GRPO | Messages only (for data collection) | +| **Batching** | GPU-level batch parallelism | Network-level thread concurrency | +| **Use case** | Online RLHF training loop | Offline data generation / evaluation | diff --git a/docs/source_en/Components/Agentic/Tools.md b/docs/source_en/Components/Agentic/Tools.md new file mode 100644 index 000000000..e53c7371f --- /dev/null +++ b/docs/source_en/Components/Agentic/Tools.md @@ -0,0 +1,119 @@ +# Tools & ToolManager + +The Tools module provides an abstract tool interface and a central tool dispatcher (`ToolManager`) for agentic multi-turn rollouts. Tools follow the OpenAI function-calling schema for seamless integration with LLM tool-use capabilities. + +## Tool Base Class + +```python +from abc import ABC, abstractmethod +from twinkle.data_format import Tool as ToolInfo + +class Tool(ABC): + + @abstractmethod + def __call__(self, tool_name: str, arguments: Dict[str, Any]) -> str: + """Execute the tool and return a string result.""" + raise NotImplementedError + + @abstractmethod + def tool_info(self) -> ToolInfo: + """Return OpenAI-compatible tool schema.""" + raise NotImplementedError +``` + +### Implementing a Custom Tool + +```python +from twinkle_agentic.tools.base import Tool + +class SearchTool(Tool): + + def __call__(self, tool_name: str, arguments: dict) -> str: + query = arguments.get('query', '') + # Perform search logic + return f'Search results for: {query}' + + def tool_info(self): + return { + 'type': 'function', + 'function': { + 'name': 'search', + 'description': 'Search the web for information.', + 'parameters': { + 'type': 'object', + 'properties': { + 'query': { + 'type': 'string', + 'description': 'The search query.', + }, + }, + 'required': ['query'], + }, + }, + } +``` + +## ToolManager + +`ToolManager` is a registry and dispatcher for tools. It resolves tool calls from the LLM's structured output and routes them to the correct tool implementation. + +```python +from twinkle_agentic.tools.tool_manager import ToolManager + +# Initialize with a list of Tool instances +manager = ToolManager([search_tool, calculator_tool]) + +# Or with a dict +manager = ToolManager({'search': search_tool, 'calc': calculator_tool}) + +# Or register dynamically +manager = ToolManager() +manager.register(search_tool) +manager.register(calculator_tool) +``` + +### Key Methods + +| Method | Description | +|--------|-------------| +| `register(tool)` | Register a tool (name extracted from `tool_info()`). | +| `unregister(name)` | Remove a tool by name. | +| `names()` | List all registered tool names. | +| `copy()` | Create a shallow copy of the manager. | +| `tool_infos()` | Return a list of all tool schemas (for API requests). | +| `__call__(tool_call)` | Dispatch a tool call and return the result string. | + +### Dispatching Tool Calls + +`ToolManager` accepts OpenAI-shaped tool call dicts: + +```python +tool_call = { + 'id': 'call_1', + 'type': 'function', + 'function': { + 'name': 'search', + 'arguments': '{"query": "Python tutorials"}', + }, +} + +result = manager(tool_call) +# result: 'Search results for: Python tutorials' +``` + +**Error handling:** If the tool name is unknown, arguments are invalid JSON, or the tool raises an exception, `ToolManager` returns a descriptive error string instead of raising — this keeps the rollout loop running. + +### Integration with Rollout + +```python +from twinkle_agentic.rollout.multi_turn import MultiTurnRollout + +rollout = MultiTurnRollout( + sampler=sampler, + template=template, + tool_manager=manager, # Pass manager to rollout + max_turns=6, +) +``` + +The rollout engine calls `manager(tool_call)` for each tool call generated by the model, and appends the result as a `{'role': 'tool', 'content': result}` message. diff --git a/docs/source_en/Components/Agentic/index.rst b/docs/source_en/Components/Agentic/index.rst new file mode 100644 index 000000000..802034366 --- /dev/null +++ b/docs/source_en/Components/Agentic/index.rst @@ -0,0 +1,11 @@ +Agentic +=============== +.. toctree:: + :maxdepth: 1 + + Preprocessor.md + Protocol.md + Rollout.md + Tools.md + Envs.md + Multi-Turn-Tool-Usage.md diff --git a/docs/source_en/Components/CLI/CLI.md b/docs/source_en/Components/CLI/CLI.md new file mode 100644 index 000000000..9efe3c894 --- /dev/null +++ b/docs/source_en/Components/CLI/CLI.md @@ -0,0 +1,134 @@ +# CLI + +The CLI module provides a unified configuration system for Twinkle training scripts. It merges multiple configuration sources (environment variables, `.env` files, YAML configs, and command-line arguments) into a single `Args` dataclass with typed argument groups. + +## Resolution Order + +Configuration sources are applied in order (later wins): + +1. **Dataclass defaults** — sensible out-of-the-box values +2. **`.env` file** — project-local overrides +3. **Environment variables** — `TWINKLE_` prefix or bare keys +4. **YAML config file** — `--config path/to/config.yaml` +5. **CLI overrides** — `--key value` (highest priority) + +All keys are case-insensitive and dash/underscore equivalent. + +## Quick Start + +```python +from twinkle.cli import CLI + +args = CLI.from_args() + +# Access typed groups +print(args.model.model_id) +print(args.training.max_steps) +print(args.optimizer.learning_rate) + +# Or get dictionaries for component construction +model_kwargs = args.get_model_args() +optimizer_kwargs = args.get_optimizer_args() +``` + +## Argument Groups + +| Group | Class | Key Parameters | +|:------|:------|:---------------| +| model | `ModelArgs` | `model_id`, `mixed_precision`, `strategy`, `gradient_checkpointing` | +| lora | `LoraArgs` | `use_lora`, `lora_r`, `lora_alpha`, `lora_target_modules` | +| dataset | `DatasetArgs` | `dataset_id`, `subset_name`, `split`, `streaming` | +| template | `TemplateArgs` | `template_cls`, `max_length`, `truncation_strategy`, `enable_thinking` | +| training | `TrainingArgs` | `max_steps`, `batch_size`, `micro_batch_size`, `output_dir`, `save_steps` | +| optimizer | `OptimizerArgs` | `optimizer_cls`, `learning_rate`, `weight_decay`, `max_grad_norm` | +| scheduler | `SchedulerArgs` | `scheduler_cls`, `num_warmup_steps`, `t_max` | +| loss | `LossArgs` | `loss_cls`, `epsilon`, `beta`, `sft_weight` | +| sampler | `SamplerArgs` | `sampler_type`, `gpu_memory_utilization`, `tensor_parallel_size` | +| sampling | `SamplingArgs` | `max_tokens`, `temperature`, `top_k`, `top_p`, `num_samples` | +| infra | `InfraArgs` | `mode`, `nproc_per_node`, `model_gpus`, `sampler_gpus`, `dp_size` | +| server | `ServerArgs` | `config`, `host`, `port`, `ray_namespace` | +| rl | `RLArgs` | `num_generations`, `advantage_type`, `reward_fns` | +| checkpoint | `CheckpointArgs` | `save_optimizer`, `merge_and_sync`, `platform` | + +## YAML Configuration + +```yaml +# config.yaml +model_id: ms://Qwen/Qwen3.5-4B +mixed_precision: bf16 +strategy: accelerate + +use_lora: true +lora_r: 16 +lora_alpha: 32 + +dataset_id: ms://swift/self-cognition +max_length: 4096 + +batch_size: 8 +micro_batch_size: 2 +max_steps: 200 +learning_rate: 1e-5 + +mode: ray +nproc_per_node: 8 +model_gpus: 4 +sampler_gpus: 4 +``` + +## Command-Line Usage + +```bash +# Use with YAML config +python train.py --config config.yaml + +# Override specific values +python train.py --config config.yaml --learning_rate 5e-6 --max_steps 500 + +# Boolean flags +python train.py --use_lora --no_gradient_checkpointing + +# Without config file (all from CLI) +python train.py --model_id ms://Qwen/Qwen3.5-4B --batch_size 4 +``` + +## Environment Variables + +```bash +# TWINKLE_ prefix +export TWINKLE_MODEL_ID=ms://Qwen/Qwen3.5-4B +export TWINKLE_LEARNING_RATE=1e-5 + +# Or bare keys (when recognized) +export MODEL_ID=ms://Qwen/Qwen3.5-4B +``` + +## Field Aliases + +Some fields support aliases for convenience: + +- `learning_rate` ↔ `lr` +- `nproc_per_node` ↔ `num_gpus` +- `max_tokens` ↔ `max_new_tokens` +- `use_megatron=true` → `strategy=native_fsdp` + +## Custom Config Sources + +You can extend the CLI with custom configuration sources: + +```python +from twinkle.cli.cli import ConfigSource, Args, ConfigResolver + +class RemoteConfigSource(ConfigSource): + def __init__(self, url: str): + self.url = url + + def load(self) -> dict: + import requests + return requests.get(self.url).json() + +# Apply custom source +args = Args() +resolver = ConfigResolver(args) +resolver.apply(RemoteConfigSource('http://config-server/my-config').load()) +``` diff --git "a/docs/source_zh/\347\273\204\344\273\266/Gym/index.rst" b/docs/source_en/Components/CLI/index.rst similarity index 76% rename from "docs/source_zh/\347\273\204\344\273\266/Gym/index.rst" rename to docs/source_en/Components/CLI/index.rst index 85d941b97..cf59fa766 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/Gym/index.rst" +++ b/docs/source_en/Components/CLI/index.rst @@ -1,6 +1,6 @@ -Gym +CLI =============== .. toctree:: :maxdepth: 1 - Gym.md + CLI.md diff --git a/docs/source_en/Components/Gym/Gym.md b/docs/source_en/Components/Gym/Gym.md deleted file mode 100644 index 4db355b8a..000000000 --- a/docs/source_en/Components/Gym/Gym.md +++ /dev/null @@ -1,26 +0,0 @@ -# Gym - -The Gym component provides an interface for reinforcement learning environments in Twinkle. - -```python -from twinkle.gym import Gym - -class CustomGym(Gym): - - def step(self, trajectories, **kwargs): - """ - Execute one RL step: evaluate trajectories and return rewards. - - Args: - trajectories: Model-generated trajectories to evaluate - **kwargs: Additional arguments - - Returns: - Reward values for each trajectory - """ - ... -``` - -The Gym abstraction allows you to plug in custom RL environments that interact with the training loop. It decouples reward computation and environment interaction from the core training logic. - -> Gym is typically used in on-policy RL training where the environment needs to provide feedback on model-generated outputs. diff --git a/docs/source_en/Components/Loss/InfoNCELoss.md b/docs/source_en/Components/Loss/InfoNCELoss.md new file mode 100644 index 000000000..1c6cada5d --- /dev/null +++ b/docs/source_en/Components/Loss/InfoNCELoss.md @@ -0,0 +1,68 @@ +# InfoNCE Loss + +The `InfonceLoss` implements contrastive learning with in-batch negatives and optional cross-rank gathering. It is designed for embedding/retrieval model training. + +## Usage + +```python +from twinkle.loss import InfonceLoss + +loss_fn = InfonceLoss( + temperature=0.1, + use_batch=True, # Enable in-batch negatives + hard_negatives=7, # Fix negative count per sample + mask_fake_negative=True, # Mask false negatives + fake_neg_margin=0.1, # Margin for false negative detection +) + +model.set_loss(loss_fn) +``` + +## Input Format + +Each sample is laid out as `anchor(1) + positive(1) + negatives(n)` in a flat embedding tensor. The `inputs['labels']` is a 1-D mask where `1` marks the start of each group. + +``` +embeddings: [a0, p0, n0_1, n0_2, a1, p1, n1_1, n1_2, ...] +labels: [ 1, 0, 0, 0, 1, 0, 0, 0, ...] +``` + +## Parameters + +| Parameter | Type | Default | Description | +|:----------|:-----|:--------|:------------| +| `temperature` | float | 0.1 | Logit scaling factor | +| `use_batch` | bool | True | Use cross-sample in-batch negatives | +| `hard_negatives` | int | None | Fix per-sample negative count (truncate/upsample) | +| `mask_fake_negative` | bool | False | Mask logits > positive + margin | +| `fake_neg_margin` | float | 0.1 | Threshold for false negative masking | +| `include_qq` | bool | False | Add query-query similarity block | +| `include_dd` | bool | False | Add doc-doc similarity block | + +## Cross-Rank Gathering + +When `use_batch=True` and distributed training is active, embeddings are gathered from all DP ranks to maximize in-batch negative diversity. Only the local shard retains gradients. + +## Similarity Blocks + +The loss supports three similarity blocks for comprehensive contrastive learning: + +- **Q→D (default)**: Query to all documents — primary contrastive signal +- **Q→Q** (`include_qq=True`): Query to all other queries — prevents query collapse +- **D→D** (`include_dd=True`): Document to all other documents — Qwen3-Embedding style + +## Example: Embedding Training + +```python +from twinkle.loss import InfonceLoss +from twinkle.metric import EmbeddingMetric + +# Configure model for embedding +model.set_loss(InfonceLoss(temperature=0.05, use_batch=True, include_qq=True)) +model.set_metric(EmbeddingMetric(device_mesh=mesh, process_group=pg)) + +# Training loop +for batch in dataloader: + model.forward_backward(batch) + model.clip_grad_and_step() +``` diff --git a/docs/source_en/Components/Loss/index.rst b/docs/source_en/Components/Loss/index.rst index dceaf20f0..c6f3cb2c9 100644 --- a/docs/source_en/Components/Loss/index.rst +++ b/docs/source_en/Components/Loss/index.rst @@ -3,6 +3,19 @@ Loss .. toctree:: :maxdepth: 1 + CrossEntropy.md + ChunkedCrossEntropy.md + DPOLoss.md + GKDLoss.md + GRPOLoss.md + InfoNCELoss.md + MSELoss.md + Building-Loss.md +Loss +=============== +.. toctree:: + :maxdepth: 1 + CrossEntropy.md ChunkedCrossEntropy.md DPOLoss.md diff --git a/docs/source_en/Components/Metrics/EmbeddingMetric.md b/docs/source_en/Components/Metrics/EmbeddingMetric.md new file mode 100644 index 000000000..e8cea06ba --- /dev/null +++ b/docs/source_en/Components/Metrics/EmbeddingMetric.md @@ -0,0 +1,42 @@ +# EmbeddingMetric + +The `EmbeddingMetric` tracks embedding quality during contrastive (InfoNCE) training. It reports anchor-positive cosine similarity statistics and in-batch negative similarity. + +## Usage + +```python +from twinkle.metric import EmbeddingMetric + +metric = EmbeddingMetric(device_mesh=device_mesh, process_group=process_group) + +# During training +metric.accumulate(inputs, outputs) + +# At log interval +results = metric.calculate() +# results: { +# 'pos_sim': '0.8523', # Mean anchor-positive cosine similarity +# 'pos_sim_min': '0.7102', # Min across batch +# 'pos_sim_max': '0.9451', # Max across batch +# 'neg_sim': '0.2134', # Mean anchor-negative similarity +# 'loss': '0.3412', # Average InfoNCE loss +# 'grad_norm': '1.234567', # Gradient norm +# } +``` + +## Reported Metrics + +| Metric | Description | +|:-------|:------------| +| `pos_sim` | Mean cosine similarity between anchors and their positives | +| `pos_sim_min` | Minimum anchor-positive similarity in the batch | +| `pos_sim_max` | Maximum anchor-positive similarity in the batch | +| `neg_sim` | Mean similarity between anchors and other positives (in-batch negatives) | +| `loss` | Average contrastive loss value | +| `grad_norm` | Gradient norm (passed via kwargs) | + +## Cross-Rank Gathering + +`EmbeddingMetric` performs an `all_gather` to compute similarity statistics across all DP ranks, providing a global view of embedding quality even under data-parallel training. + +> This metric pairs with `InfonceLoss` for embedding/retrieval training tasks. diff --git a/docs/source_en/Components/Metrics/GRPOMetric.md b/docs/source_en/Components/Metrics/GRPOMetric.md new file mode 100644 index 000000000..cd4f11551 --- /dev/null +++ b/docs/source_en/Components/Metrics/GRPOMetric.md @@ -0,0 +1,66 @@ +# GRPOMetric + +The `GRPOMetric` tracks policy optimization diagnostics during GRPO training, including KL divergence, clipping rates, entropy, and log-probability statistics. + +## Usage + +```python +from twinkle.metric import GRPOMetric + +metric = GRPOMetric( + device_mesh=device_mesh, + process_group=process_group, + epsilon=0.2, # PPO clip range + temperature=1.0, # Sampling temperature for logp rescaling + top_k_kl=10, # Track top-K high-KL tokens per step +) + +# During training loop +metric.accumulate(inputs, outputs, old_logps=old_logps, advantages=advantages) + +# At log interval +results = metric.calculate() +# results: { +# 'train/policy_confidence': 0.85, +# 'train/mean_new_logp': -1.23, +# 'train/mean_old_logp': -1.30, +# 'train/logp_diff_mean': 0.07, +# 'train/approx_kl': 0.003, +# 'train/token_kl_max': 0.15, +# 'train/entropy': 2.1, +# 'train/clip_ratio': 0.02, +# 'train/clip_ratio_low': 0.01, +# 'train/clip_ratio_high': 0.01, +# } +``` + +## Reported Metrics + +| Metric | Description | +|:-------|:------------| +| `train/policy_confidence` | exp(mean_new_logp) — higher means model is more confident | +| `train/mean_new_logp` | Average log-probability of generated tokens under current policy | +| `train/mean_old_logp` | Average log-probability under reference policy | +| `train/logp_diff_mean` | Mean (new - old) log-probability difference | +| `train/approx_kl` | Schulman K3 estimator of KL(old \|\| new) | +| `train/token_kl_max` | Maximum per-token KL across all ranks | +| `train/token_ratio_max` | Maximum importance weight across all ranks | +| `train/entropy` | Average token-level entropy | +| `train/clip_ratio` | Fraction of tokens clipped (low + high) | +| `train/clip_ratio_low` | Fraction clipped below (ratio < 1-ε, negative advantage) | +| `train/clip_ratio_high` | Fraction clipped above (ratio > 1+ε, positive advantage) | + +## Variants + +- **`GSPOMetric`** — Computes clip rate at sequence level (geometric-mean ratio per sequence) +- **`CISPOMetric`** — Unconditional clip rate (not gated by advantage sign) + +## Parameters + +| Parameter | Type | Default | Description | +|:----------|:-----|:--------|:------------| +| `epsilon` | float | 0.2 | Lower clip boundary | +| `epsilon_high` | float | None | Upper clip boundary (defaults to epsilon) | +| `temperature` | float | 1.0 | Rescale logps to T=1 before computing KL | +| `top_k_kl` | int | 0 | If > 0, record top-K high-KL token details | +| `ignore_index` | int | -100 | Label value to mask out | diff --git a/docs/source_en/Components/Metrics/index.rst b/docs/source_en/Components/Metrics/index.rst index 5d50e183b..68215482d 100644 --- a/docs/source_en/Components/Metrics/index.rst +++ b/docs/source_en/Components/Metrics/index.rst @@ -8,4 +8,6 @@ Metrics Accuracy.md CompletionRewardMetric.md DPOMetric.md + GRPOMetric.md + EmbeddingMetric.md Building-Metrics.md diff --git a/docs/source_en/Components/Model/MultiLoraTransformersModel.md b/docs/source_en/Components/Model/MultiLoraTransformersModel.md index c196f900b..0c78a6886 100644 --- a/docs/source_en/Components/Model/MultiLoraTransformersModel.md +++ b/docs/source_en/Components/Model/MultiLoraTransformersModel.md @@ -30,3 +30,48 @@ The reason for the existence of max_loras and max_r parameters is that Twinkle's Because of this, the user's r must be less than or equal to the max_r configuration. During actual training, only part of the lora's rank will be used in the calculation. MultiLoraTransformersModel supports the `@remote_class` annotation and supports device_mesh, which means it can run in Ray workers. + +## Tenant Lifecycle + +Under the hood, `MultiLoraTransformersModel` uses the `MultiLora` manager to handle tenant LoRA slots. The key APIs: + +### acquire_lora + +Claim an available LoRA slot for a tenant: + +```python +adapter_name = model.multi_lora.acquire_lora('tenant_a', LoraConfig(r=16, lora_alpha=32)) +``` + +- Raises `RuntimeError` if all slots are in use or `config.r > max_r` + +### release_lora + +Release a tenant's LoRA slot, resetting weights to initial state: + +```python +model.multi_lora.release_lora('tenant_a') +``` + +### Context Manager + +Use `adapter()` for scoped activation: + +```python +with model.multi_lora.adapter('tenant_a') as name: + output = model.forward(inputs) +``` + +### LoraTenant + +Each slot is tracked as a `LoraTenant` dataclass: + +```python +@dataclass +class LoraTenant: + index: int # Slot index (0..max_loras-1) + adapter_name: str # Internal name (e.g. "lora_0") + config: LoraConfig # Pre-allocated config (max_r) + tenant_adapter_name: str # User-facing tenant name (None if free) + tenant_config: LoraConfig # Tenant's actual config (None if free) +``` diff --git a/docs/source_en/Components/Model/SupportedModels.md b/docs/source_en/Components/Model/SupportedModels.md new file mode 100644 index 000000000..7cd9e8b4d --- /dev/null +++ b/docs/source_en/Components/Model/SupportedModels.md @@ -0,0 +1,79 @@ +# Supported Models + +Twinkle supports any model compatible with HuggingFace Transformers or Megatron-LM. Below is a curated list of models tested with Twinkle. + +## Language Models + +| Model Family | Model IDs | Parameters | Features | +|:-------------|:----------|:-----------|:---------| +| Qwen 3.5 | `Qwen/Qwen3.5-0.6B` ~ `Qwen/Qwen3.5-235B-A22B` | 0.6B–235B | MoE, Thinking mode | +| Qwen 2.5 | `Qwen/Qwen2.5-0.5B` ~ `Qwen/Qwen2.5-72B` | 0.5B–72B | Dense | +| DeepSeek V4 | `deepseek-ai/DeepSeek-V4` | 685B MoE | Custom DSML encoding | +| DeepSeek R1 | `deepseek-ai/DeepSeek-R1` | 685B MoE | Reasoning | +| LLaMA 3 | `meta-llama/Llama-3.3-70B-Instruct` | 8B–70B | Dense | +| Mistral | `mistralai/Mistral-7B-v0.3` | 7B | Dense | +| Yi | `01-ai/Yi-1.5-34B` | 6B–34B | Dense | +| GLM-4 | `THUDM/glm-4-9b-chat` | 9B | Dense | +| InternLM 2.5 | `internlm/internlm2_5-7b-chat` | 7B–20B | Dense | + +## Vision-Language Models + +| Model Family | Model IDs | Features | +|:-------------|:----------|:---------| +| Qwen 3.5 VL | `Qwen/Qwen3.5-VL-3B` ~ `Qwen/Qwen3.5-VL-72B` | Image, Video | +| Qwen 2.5 VL | `Qwen/Qwen2.5-VL-7B-Instruct` | Image, Video | +| InternVL 2.5 | `OpenGVLab/InternVL2_5-8B` | Image | + +## Embedding Models + +| Model Family | Model IDs | Training Method | +|:-------------|:----------|:----------------| +| Qwen3 Embedding | `Qwen/Qwen3-Embedding-0.6B` | InfoNCE contrastive | +| GTE | `thenlper/gte-large-zh` | InfoNCE contrastive | + +## Model Loading + +Models can be loaded from ModelScope or HuggingFace: + +```python +from twinkle.model import TransformersModel + +# From ModelScope (ms:// prefix) +model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + +# From HuggingFace (hf:// prefix) +model = TransformersModel(model_id='hf://meta-llama/Llama-3.3-70B-Instruct') + +# Local path +model = TransformersModel(model_id='/path/to/model') +``` + +## Framework Support + +| Framework | Class | Use Case | +|:----------|:------|:---------| +| Transformers | `TransformersModel` | General training (SFT, RLHF, DPO) | +| Transformers + Multi-LoRA | `MultiLoraTransformersModel` | Multi-tenant training | +| Megatron-LM | `MegatronModel` | Large-scale distributed pre-training | +| Megatron + Multi-LoRA | `MultiLoraMegatronModel` | Large-scale multi-tenant | + +## Precision Support + +| Mode | Description | +|:-----|:------------| +| `bf16` | BFloat16 mixed precision (recommended for A100/H100) | +| `fp16` | Float16 mixed precision (for older GPUs) | +| `fp8` | FP8 precision (H100 with Transformer Engine) | +| `no` | Full precision (debugging only) | + +## Parallelism Strategies + +| Strategy | Config Key | Description | +|:---------|:-----------|:------------| +| FSDP | `strategy=accelerate` | Accelerate-managed FSDP (default) | +| Native FSDP | `strategy=native_fsdp` | PyTorch native FSDP | +| Tensor Parallel | `tp_size` | Split layers across GPUs | +| Pipeline Parallel | `pp_size` | Split model stages | +| Data Parallel | `dp_size` | Replicate model, split data | +| Sequence Parallel | `sequence_parallel` | Split long sequences | +| Expert Parallel | `ep_size` | MoE expert distribution | diff --git a/docs/source_en/Components/Model/index.rst b/docs/source_en/Components/Model/index.rst index e0648f00f..4802cd0d3 100644 --- a/docs/source_en/Components/Model/index.rst +++ b/docs/source_en/Components/Model/index.rst @@ -8,3 +8,14 @@ Model MultiLoraTransformersModel.md MegatronModel.md MultiLoraMegatronModel.md + SupportedModels.md +Model +=============== +.. toctree:: + :maxdepth: 1 + + TwinkleModel.md + TransformersModel.md + MultiLoraTransformersModel.md + MegatronModel.md + MultiLoraMegatronModel.md diff --git a/docs/source_en/Components/Notifier/Notifier.md b/docs/source_en/Components/Notifier/Notifier.md new file mode 100644 index 000000000..c4e14b511 --- /dev/null +++ b/docs/source_en/Components/Notifier/Notifier.md @@ -0,0 +1,93 @@ +# Notifier + +The Notifier component provides a pluggable notification system for sending alerts during training. When exceptions occur or training events need attention, notifiers deliver messages to external channels (e.g., DingTalk webhooks). + +## Base Interface + +```python +from twinkle.notifier import Notifier + +class Notifier: + def __call__(self, message: str): + """Send a notification message.""" + ... + + def to_dict(self) -> dict: + """Serialize for checkpoint/restore.""" + ... + + @classmethod + def from_dict(cls, data: dict) -> Notifier: + """Restore from serialized form.""" + ... +``` + +## DingNotifier + +Sends notifications to DingTalk (钉钉) custom robot webhooks. + +```python +from twinkle.notifier import DingNotifier + +notifier = DingNotifier( + ding_url='https://oapi.dingtalk.com/robot/send?access_token=xxx', + secret='SECxxxxxxx', # Optional: for signed robots + timeout=5.0, +) + +# Send a message +notifier("### Training Complete\n\n- Steps: 1000\n- Loss: 0.25") +``` + +**Parameters:** +- `ding_url`: Full DingTalk webhook URL with access token +- `secret`: Optional signing secret for signed-robot mode +- `timeout`: HTTP request timeout in seconds (default: 5.0) + +Messages are sent as DingTalk **Markdown** format. The first heading line is extracted as the chat preview title. + +## Exception Notifications + +Twinkle provides automatic exception notification with deduplication: + +```python +from twinkle.notifier.base import notify_exception + +# Automatically sends formatted exception info +# Only one rank sends per unique exception (prevents flooding) +try: + model.forward_backward(batch) +except Exception as e: + notify_exception(notifier, context='forward_backward', exc=e, name='sft_train') +``` + +The notification includes: +- Exception type and message +- Full traceback +- Runtime metadata (rank, PID, hostname) +- Deduplication: only one notification per unique exception across all ranks + +## Custom Notifier + +Create custom notifiers by subclassing `Notifier`: + +```python +from twinkle.notifier import Notifier + +class SlackNotifier(Notifier): + def __init__(self, webhook_url: str): + self.webhook_url = webhook_url + + def __call__(self, message: str): + import requests + requests.post(self.webhook_url, json={'text': message}) + + def to_dict(self): + return {'class': 'SlackNotifier', 'webhook_url': self.webhook_url} + + @classmethod + def _from_dict_impl(cls, data): + return cls(webhook_url=data['webhook_url']) +``` + +> Notifiers are registered automatically via `__init_subclass__`, so `Notifier.from_dict()` can restore any subclass by name. diff --git a/docs/source_en/Components/Notifier/index.rst b/docs/source_en/Components/Notifier/index.rst new file mode 100644 index 000000000..ff82117d4 --- /dev/null +++ b/docs/source_en/Components/Notifier/index.rst @@ -0,0 +1,6 @@ +Notifier +=============== +.. toctree:: + :maxdepth: 1 + + Notifier.md diff --git a/docs/source_en/Components/TUI/Auto-Research.md b/docs/source_en/Components/TUI/Auto-Research.md new file mode 100644 index 000000000..bd27a9705 --- /dev/null +++ b/docs/source_en/Components/TUI/Auto-Research.md @@ -0,0 +1,313 @@ +# Auto-Research (TUI) + +Twinkle TUI is a terminal-based intelligent training assistant that lets you **control, monitor, and debug ML training through natural language**. It combines a chat-driven AI agent with real-time metrics visualization, log streaming, and an automated health monitor that can detect and fix training failures autonomously. + +## Architecture Overview + +``` +┌──────────────────────────────────────────────────────────┐ +│ TwinkleTUI (Textual App) │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ StatusBar: state / run_id / model / step / progress │ │ +│ ├──────────────────────┬───────────────────────────────┤ │ +│ │ MetricsPanel │ LogPanel │ │ +│ │ (ASCII chart) │ (scrolling logs) │ │ +│ ├──────────────────────┤ │ │ +│ │ ChatPanel │ │ │ +│ │ (user <-> agent) │ │ │ +│ └──────────────────────┴───────────────────────────────┘ │ +│ │ +│ Background Services: │ +│ AgentLoop ─── LLM tool-calling loop │ +│ TrainingMonitor ─── periodic health check & auto-fix │ +│ MetricsPoller ─── incremental metrics reading │ +│ LogsPoller ─── incremental log tailing │ +│ SkillsLoader ─── async plugin loading │ +└──────────────────────────────────────────────────────────┘ +``` + +## Installation & Launch + +TUI is part of the `twinkle-client` package: + +```bash +pip install twinkle-client +``` + +### Command-Line Usage + +```bash +# Basic launch (uses default local Ollama endpoint) +twinkle-tui + +# Specify LLM backend +twinkle-tui --llm-base-url http://localhost:11434/v1 --llm-model qwen3.5 + +# Attach to an existing training run +twinkle-tui --run-id my-grpo-run + +# Use a remote API (e.g., OpenAI-compatible) +twinkle-tui --llm-base-url https://api.example.com/v1 --llm-api-key sk-xxx --llm-model gpt-4o + +# Enable debug logging +twinkle-tui --verbose +``` + +Or run as a Python module: + +```bash +python -m twinkle_client.tui +``` + +### CLI Options + +| Option | Env Var | Default | Description | +|--------|---------|---------|-------------| +| `--run-id`, `-r` | `TWINKLE_TUI_RUN_ID` | None | Attach to an existing training run | +| `--llm-base-url` | `TWINKLE_LLM_BASE_URL` | `http://localhost:11434/v1` | LLM API base URL | +| `--llm-model` | `TWINKLE_LLM_MODEL` | `qwen3.5` | LLM model name | +| `--llm-api-key` | `TWINKLE_LLM_API_KEY` | `not-needed` | LLM API key | +| `--verbose`, `-v` | `TWINKLE_TUI_VERBOSE` | `False` | Enable DEBUG logging | +| `--version`, `-V` | — | — | Show version and exit | + +### Keyboard Shortcuts + +| Key | Action | +|-----|--------| +| `q` | Quit | +| `Ctrl+P` | Toggle metrics panel | +| `Ctrl+L` | Clear logs | + +## Chat Agent + +The core of TUI is an **LLM-powered tool-calling agent** (`AgentLoop`) that processes natural language commands through an OpenAI-compatible API. The agent maintains conversation history with automatic pruning (last 50 messages) and supports up to 10 tool-calling rounds per interaction. + +### What You Can Say + +**Training lifecycle:** +- *"List my training runs"* +- *"Start a new GRPO training with Qwen3.5-4B on gsm8k"* +- *"Pause the current run"* +- *"Resume training"* +- *"Stop training"* + +**Server management:** +- *"Start the server with Qwen3.5-4B and a Qwen3.5-72B sampler on 2 GPUs"* +- *"Shut down the server"* +- *"How many GPUs are available?"* + +**Monitoring & analysis:** +- *"How is the training going?"* +- *"Show me the reward-related metrics"* +- *"Zoom into steps 100-200"* +- *"Reset the chart view"* + +**Search:** +- *"Search for math datasets"* +- *"Find Qwen models on ModelScope"* + +### Available Tools + +The agent has access to 13 built-in tools: + +| Tool | Description | +|------|-------------| +| `list_training_runs` | List all training runs | +| `get_training_status` | Get detailed status and recent metrics | +| `start_server` | Start Ray cluster + Twinkle Server (idempotent) | +| `shutdown_server` | Shut down server and release GPU resources | +| `start_training` | Create and launch a new training run | +| `select_run` | Switch monitoring to a different run | +| `pause_training` | Pause training (SIGKILL, server retains state) | +| `resume_training` | Resume by re-launching the client script | +| `stop_training` | Stop training (SIGTERM, saves checkpoint) | +| `update_script` | Update training script with version archiving | +| `list_supported_models` | Query server for available models | +| `search_datasets` | Search ModelScope for datasets | +| `search_models` | Search ModelScope for models | +| `zoom_metrics` | Adjust metrics chart view range | +| `select_metrics` | Choose which metrics to display (max 4) | +| `get_cluster_info` | Get GPU/cluster resource info | + +### Server Startup + +The `start_server` tool automates a multi-step pipeline: + +1. **GPU detection** — `nvidia-smi` hardware scan +2. **GPU allocation** — partition GPUs between training model and samplers +3. **Config generation** — auto-create `server_config.yaml` +4. **Ray cluster startup** — multi-node GPU partitioning with isolated `CUDA_VISIBLE_DEVICES` +5. **Server launch** — start Twinkle Server as background process +6. **Health check** — poll `/api/v1/healthz` until ready + +Multi-model topology is supported: 1 training model + N sampler/teacher models. + +### Skills System + +TUI supports extensible skill plugins loaded from three sources: + +1. **Bundled skills** — shipped inside `twinkle_client/skills/bundled/` +2. **User-local skills** — `~/.cache/twinkle/tui/skills/local/` +3. **Community skills** — fetched from ModelScope (best-effort, 10s timeout) + +Skills are loaded asynchronously after startup and injected into the agent's system prompt. The agent is usable immediately even before skills finish loading. + +## Training Monitor (Auto-Fix) + +The `TrainingMonitor` is a background service that runs every **30 seconds**, collecting all available signals about the current training run and feeding them to the LLM for analysis. + +### Collected Signals + +- **Process status**: alive / dead / unknown +- **output.log tail**: last 1500 chars (prioritizes tracebacks) +- **Metrics**: recent entries + first-half vs second-half trend analysis +- **Stall duration**: seconds since last metric was produced +- **Current train.py**: full script source (for accurate fixes) + +### Decision Framework + +The LLM classifies each check into one of three actions: + +| Decision | When | Action | +|----------|------|--------| +| **LGTM** | Training progressing normally | No action | +| **WARNING** | Loss plateau, reward hacking, KL explosion, etc. | Relay observation to user | +| **FIX** | Script crashed, process dead with traceback | Auto-fix and restart | + +### Auto-Fix Pipeline + +When a FIX is needed: + +1. LLM outputs diagnosis + complete fixed script +2. Monitor archives the old `train.py` as `train_v{N}.py` +3. Writes the fixed script as the new `train.py` +4. Re-launches training via `resume_training` +5. Resets stall tracking for the new attempt + +Safety guardrails: +- Max **3 auto-fix attempts** per run (prevents infinite retry loops) +- Fix attempts are tracked per `run_id` +- Snapshot deduplication avoids re-analyzing unchanged states + +## File-Based Connection + +TUI communicates with training processes through the local filesystem: + +``` +~/.cache/twinkle/{run_id}/ +├── meta.json — run metadata (model_id, config, status, pid) +├── metrics.jsonl — one JSON object per step (incremental) +├── output.log — combined stdout+stderr from training +├── train.py — current active training script +└── train_v{N}.py — archived previous script versions +``` + +### Training Control Model + +In Server Mode, the Twinkle Server retains all model/optimizer state in GPU memory: + +- **Pause** = kill client process (SIGKILL) — server state preserved +- **Resume** = re-launch client script — seamlessly continues training +- **Stop** = SIGTERM — triggers checkpoint saving then exits +- **Shut down server** = releases GPU resources, **destroys** model state + +## TrainingRuntime (Script Integration) + +Training scripts use `TrainingRuntime` to integrate with TUI: + +```python +from twinkle_client.tui.runtime import TrainingRuntime + +rt = TrainingRuntime(run_id='my-grpo-run') +rt.start(model_id='Qwen/Qwen3.5-4B', config={'lr': 1e-5}) +rt.register_graceful_shutdown(model, dataloader) + +for step, batch in enumerate(dataloader): + # ... training logic ... + rt.log_metrics(step=step, loss=loss, reward=reward, grad_norm=gn, lr=lr) + rt.log(f'Completed step {step}, loss={loss:.4f}') + +rt.finish() +``` + +### Key Methods + +| Method | Description | +|--------|-------------| +| `start(model_id, config, script_path)` | Initialize run directory and metadata | +| `log_metrics(**kwargs)` | Write metrics entry to `metrics.jsonl` | +| `log(message)` | Print log message (captured as `output.log`) | +| `get_resume_info()` | Get `last_step` for resuming from checkpoint | +| `finish(status)` | Mark training as finished, close files | +| `register_graceful_shutdown(model, dataloader)` | Register SIGTERM handler that saves checkpoint | + +### Resume Support + +`TrainingRuntime` automatically saves training progress to `meta.json` (throttled to every 5 seconds). Scripts can use `get_resume_info()` to resume from the last saved step: + +```python +rt = TrainingRuntime(run_id='my-run') +resume = rt.get_resume_info() +global_step = resume['last_step'] + +if global_step > 0: + dataloader.skip_consumed_samples(global_step * BATCH_SIZE) + print(f'Resuming from step {global_step}') +``` + +### Graceful Shutdown + +When `register_graceful_shutdown()` is called, a SIGTERM handler is installed that: + +1. Saves model checkpoint (LoRA weights + optimizer state) +2. Saves dataloader position (`consumed_train_samples`) +3. Logs the checkpoint path +4. Marks training as `stopped` and exits + +## UI Panels + +### StatusBar + +Displays current training state at the top of the screen: + +- Training state icon (🚀 Training / ⏸ Paused / ✅ Done / ❌ Error) +- Run ID +- Model name +- Current step +- Progress bar with percentage + +### MetricsPanel + +Real-time ASCII chart rendered with `plotext`: + +- Plots up to 4 metrics simultaneously +- Supports zoom (by step range and y-axis range) +- Auto-selects first 3 available metrics if no selection +- Hint bar shows hidden metrics that can be switched via agent +- Retains up to 2000 data points + +### LogPanel + +Scrolling log viewer: + +- Strips ANSI escape sequences for clean display +- Hard-wraps long lines to prevent overflow +- Handles `\r` carriage returns from progress bars +- Retains last 500 lines + +### ChatPanel + +Interactive chat interface: + +- User input with streaming agent responses +- Throttled token flushing (80ms) for smooth display +- Stream reset on tool-call detection +- Supports Rich markup formatting + +## Logging + +All TUI logs are written to `./tui.log` (current working directory): + +- Rotated at 5MB with 3 backups +- **No console output** — avoids corrupting Textual's alt-screen buffer +- Use `--verbose` for DEBUG level logging diff --git a/docs/source_en/Components/TUI/SkillProvider.md b/docs/source_en/Components/TUI/SkillProvider.md new file mode 100644 index 000000000..d008cf978 --- /dev/null +++ b/docs/source_en/Components/TUI/SkillProvider.md @@ -0,0 +1,71 @@ +# SkillProvider + +The skill system allows Twinkle's TUI agent to dynamically load specialized knowledge from external sources (Git repos, APIs, local files) and inject them into the LLM's system prompt. + +## Architecture + +| Class | Role | +|-------|------| +| **Skill** | Dataclass holding a single skill's name, content, and source | +| **SkillProvider** | Abstract base class for fetching skills from a source | +| **SkillManager** | Orchestrates multiple providers, aggregates skills for prompt injection | + +## Skill Dataclass + +```python +@dataclasses.dataclass +class Skill: + name: str # Short identifier (typically filename without extension) + content: str # Full markdown content + source: str # Provider name + relative path for traceability +``` + +## Creating a Custom Provider + +Subclass `SkillProvider` and implement `name` and `fetch()`: + +```python +from twinkle_client.skills.base import SkillProvider + +class MySkillProvider(SkillProvider): + + @property + def name(self) -> str: + return 'my-skills' + + async def fetch(self) -> None: + # Download/clone skill files to self.cache_dir + # e.g., git clone, API download, file copy + ... +``` + +The default `load_skills()` scans `self.cache_dir` for `.md` files (skipping README, LICENSE, etc.) and returns `Skill` objects. + +## SkillManager + +```python +from twinkle_client.skills.manager import SkillManager + +manager = SkillManager() +manager.register(my_provider) +manager.register(another_provider) + +# Fetch and load all skills +skills = await manager.load_all() + +# Format for LLM system prompt injection +prompt_section = manager.format_for_prompt() +``` + +### Key Methods + +| Method | Description | +|--------|-------------| +| `register(provider)` | Add a skill provider | +| `load_all()` | Fetch + load from all providers | +| `format_for_prompt()` | Render skills as formatted text for system prompt | +| `get_skill_names()` | List names of loaded skills | + +## Cache Directory + +By default, skills are cached at `~/.cache/twinkle/tui/skills//`. Override by passing `cache_dir` to the provider constructor. diff --git a/docs/source_en/Components/TUI/index.rst b/docs/source_en/Components/TUI/index.rst new file mode 100644 index 000000000..29cdad073 --- /dev/null +++ b/docs/source_en/Components/TUI/index.rst @@ -0,0 +1,7 @@ +TUI +=============== +.. toctree:: + :maxdepth: 1 + + Auto-Research.md + SkillProvider.md diff --git a/docs/source_en/Components/Task Processor/GRPOProcessor.md b/docs/source_en/Components/Task Processor/GRPOProcessor.md deleted file mode 100644 index adff73c45..000000000 --- a/docs/source_en/Components/Task Processor/GRPOProcessor.md +++ /dev/null @@ -1,19 +0,0 @@ -# GRPOLossProcessor - -GRPOLossProcessor is a task processor wrapper designed for GRPO reinforcement learning training. It extends InputProcessor with GRPO-specific data preparation. - -```python -from twinkle.processor import GRPOLossProcessor - -processor = GRPOLossProcessor( - device_mesh=..., - padding_free=False, - framework='transformers', -) - -model.set_processor(processor) -``` - -GRPOLossProcessor wraps the base `InputProcessor` and adds handling for GRPO-specific fields such as advantages, old log-probabilities, and reference log-probabilities that are required by the GRPO loss function. - -> For standard SFT tasks, use `InputProcessor` directly. Use `GRPOLossProcessor` when your training loop involves GRPO or its variants. diff --git a/docs/source_en/Components/Task Processor/index.rst b/docs/source_en/Components/Task Processor/index.rst index 1f20fdbca..1e9d600a4 100644 --- a/docs/source_en/Components/Task Processor/index.rst +++ b/docs/source_en/Components/Task Processor/index.rst @@ -4,4 +4,3 @@ Task Processor :maxdepth: 1 InputProcessor.md - GRPOProcessor.md diff --git a/docs/source_en/Components/Template/DeepSeekV4Template.md b/docs/source_en/Components/Template/DeepSeekV4Template.md new file mode 100644 index 000000000..bbd74928e --- /dev/null +++ b/docs/source_en/Components/Template/DeepSeekV4Template.md @@ -0,0 +1,56 @@ +# DeepSeek-V4 Template + +The `DeepseekV4Template` provides native support for DeepSeek V4's custom chat template encoding, including its unique thinking mode, tool-call protocol, and multi-token special tokens. + +## Usage + +```python +from twinkle.template import DeepseekV4Template + +template = DeepseekV4Template( + model_id='deepseek-ai/DeepSeek-V4', + enable_thinking=True, +) +``` + +## Features + +- **Custom tokenizer wrapper**: Overrides `apply_chat_template` with DeepSeek V4's encoding protocol +- **Thinking mode**: Supports `thinking` / `chat` modes with configurable reasoning effort +- **Tool calls**: Native DSML (DeepSeek Markup Language) tool-call encoding +- **Multi-token EOS**: Handles DeepSeek V4's multi-character special tokens + +## Thinking Modes + +```python +# Enable deep thinking (reasoning mode) +template = DeepseekV4Template(model_id='...', enable_thinking=True) + +# Control reasoning effort +# 'max' or 'high' enables extended reasoning budget +template.encode(messages, reasoning_effort='max') +``` + +## Tool Call Support + +DeepSeek V4 uses its own DSML protocol for structured function calling: + +```python +messages = [ + {'role': 'user', 'content': 'What is the weather in Shanghai?'}, +] +tools = [ + {'type': 'function', 'function': {'name': 'get_weather', 'parameters': {...}}} +] + +features = template.encode(messages, tools=tools) +``` + +## Key Differences from Base Template + +| Feature | Base Template | DeepseekV4Template | +|:--------|:-------------|:-------------------| +| Chat template | HuggingFace native | Custom DSML encoding | +| Thinking | `` tags | Native thinking mode toggle | +| Tool calls | Hermes/Qwen format | DSML tool blocks | +| EOS handling | Single token | Multi-token special markers | diff --git a/docs/source_en/Components/Template/Template.md b/docs/source_en/Components/Template/Template.md index 60962a33a..b9124412a 100644 --- a/docs/source_en/Components/Template/Template.md +++ b/docs/source_en/Components/Template/Template.md @@ -2,6 +2,64 @@ The template is a key component for converting Trajectory to InputFeature. +```python +class Template: + + def __init__(self, + model_id: str, + use_chat_template: bool = True, + max_length: Optional[int] = 8192, + truncation_strategy: Literal['raise', 'left', 'right', 'split', 'delete'] = 'raise', + default_system: Optional[str] = None): + ... + + def batch_encode(self, trajectories: Union[Dict[str, Any], List[Trajectory]]) -> List[InputFeature]: + # Batch encode samples + ... + + def check(self, trajectory: Trajectory) -> Optional[Trajectory]: + # Encode one sample and return the original sample + # Generally used to check data reasonableness in RL algorithms like GRPO + ... + + def batch_check(self, trajectories: List[Trajectory]) -> List[Optional[Trajectory]]: + # Batch check samples + ... + + def decode(self, token_ids: List[int], **kwargs) -> str: + # Decode sample + ... + + def batch_decode(self, token_ids: List[List[int]], **kwargs) -> List[str]: + # Batch decode samples + ... +``` + +- model_id: Model id containing tokenizer or processor +- use_chat_template: Whether to use chat_template. If not used, it is generally a pre-training scenario +- max_length: Maximum length of a single sample +- truncation_strategy: How to handle the sample if it exceeds the maximum length + - raise: Throw an exception. Generally used for very precise dataset scenarios + - left: Remove tokens on the left to conform to max_length + - right: Remove tokens on the right to conform to max_length + - split: Split the oversized sample into multiple max_length chunks (not supported for multimodal, LazyDataset, or IterablePackingDataset) + - delete: Drop the entire sample if it exceeds max_length +- default_system: If the dataset does not have a system message, use the default system + +> Template does not support using functions as replacements because it needs to support many functions internally. If you need to write a new Template, please inherit the `Template` class. +> Generally speaking, using the Template base class is sufficient for pure text models. In the base class, we use tokenizer.apply_chat_template to encode the model, which is universal for general pure text models. + +# Template mapping + +Currently, the model-template mapping is simple: + +- Template class: Supported in all pure text LLMs. +- DeepseekV4Template class: For DeepSeek V4, rewrites the chat template encoding logic, `encode_messages` is built into twinkle. +- Qwen3_5Template class: For Qwen3.5 MLLMs. +# Template + +The template is a key component for converting Trajectory to InputFeature. + ```python class Template: diff --git a/docs/source_en/Components/Template/ToolCallParsers.md b/docs/source_en/Components/Template/ToolCallParsers.md new file mode 100644 index 000000000..8d4e3f988 --- /dev/null +++ b/docs/source_en/Components/Template/ToolCallParsers.md @@ -0,0 +1,98 @@ +# Tool Call Parsers + +Twinkle's template system includes a modular tool-call parsing framework for training models with function calling capabilities. + +## Architecture + +``` +ToolCallRegistry +├── HermesQwenParser — Hermes/Qwen style ... +├── ReActParser — ReAct Thought/Action/Observation +├── ClineParser — Cline XML-based tool calls +└── VCPParser — VCP protocol +``` + +## ToolCallParser Interface + +```python +from twinkle.template.tools import ToolCallParser + +class ToolCallParser(ABC): + name: str = '' + open_marker: str | None = None + close_marker: str | None = None + + def detect(self, text: str) -> bool: + """Check if text contains this format's markup.""" + ... + + def parse(self, text: str) -> List[Dict[str, Any]]: + """Extract tool calls in OpenAI format.""" + ... + + def clean(self, text: str) -> str: + """Strip markup, return plain content.""" + ... +``` + +## ToolCallRegistry + +The registry auto-discovers parsers and routes detection: + +```python +from twinkle.template.tools import ToolCallRegistry + +# Detect which format a completion uses +parser = ToolCallRegistry.detect_first(completion_text) +if parser: + tool_calls = parser.parse(completion_text) + clean_text = parser.clean(completion_text) +``` + +## Built-in Parsers + +### HermesQwenParser + +Parses Hermes/Qwen-style function calls: + +```xml + +{"name": "get_weather", "arguments": {"city": "Shanghai"}} + +``` + +### ReActParser + +Parses ReAct-style reasoning traces: + +``` +Thought: I need to check the weather +Action: get_weather +Action Input: {"city": "Shanghai"} +Observation: ... +``` + +### ClineParser + +Parses Cline XML-based tool invocations with structured parameters. + +### VCPParser + +Parses VCP (Visual Code Protocol) tool calls. + +## Usage in Training + +Tool call parsers integrate with the Template during preprocessing: + +```python +from twinkle.template import Template + +template = Template( + model_id='ms://Qwen/Qwen3.5-4B', + enable_thinking=True, +) + +# Template automatically uses ToolCallRegistry for +# tool-call aware tokenization during encoding +features = template.encode(messages, tools=tool_definitions) +``` diff --git a/docs/source_en/Components/Template/index.rst b/docs/source_en/Components/Template/index.rst index cd5fddb42..c3d125f04 100644 --- a/docs/source_en/Components/Template/index.rst +++ b/docs/source_en/Components/Template/index.rst @@ -4,3 +4,11 @@ Template :maxdepth: 1 Template.md + DeepSeekV4Template.md + ToolCallParsers.md +Template +=============== +.. toctree:: + :maxdepth: 1 + + Template.md diff --git a/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md b/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md index 169adb86e..b5f55a575 100644 --- a/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md +++ b/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md @@ -40,6 +40,24 @@ class DeviceMesh: It is recommended to use `from_sizes` to construct it. +### Parameter Reference + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `world_size` | Total number of processes | 1 | +| `dp_size` | Data parallel degree | 1 | +| `fsdp_size` | Fully Sharded Data Parallel degree | None | +| `tp_size` | Tensor parallel degree | None | +| `pp_size` | Pipeline parallel degree | None | +| `ulysses_size` | Ulysses sequence parallel degree | None | +| `cp_size` | Context parallel degree | None | +| `ep_size` | Expert parallel degree (for MoE models) | None | +| `etp_size` | Expert tensor parallel degree | None | +| `ep_fsdp_size` | FSDP degree within each EP group | None | +| `vpp_size` | Virtual pipeline parallel degree | None | +| `device_type` | Device type (`cuda`, `npu`, etc.) | `cuda` | +| `sequence_parallel` | Enable Megatron-style sequence parallel | False | + Let's give an example: ```python diff --git a/docs/source_en/Components/Training Middleware/Expert-Parallel.md b/docs/source_en/Components/Training Middleware/Expert-Parallel.md new file mode 100644 index 000000000..edd0ec03d --- /dev/null +++ b/docs/source_en/Components/Training Middleware/Expert-Parallel.md @@ -0,0 +1,74 @@ +# Expert Parallel (EP) + +Expert Parallel distributes Mixture-of-Experts (MoE) model experts across multiple GPUs, allowing each rank to hold a subset of experts. This reduces per-GPU memory and enables training of large MoE models. + +## Overview + +| Concept | Description | +|---------|-------------| +| **ExpertParallelConfig** | Configuration dataclass controlling EP behavior | +| **apply_expert_parallel()** | Entry point that shards experts and patches forward | +| **shard_experts()** | Evenly splits experts across EP ranks | +| **patch_forward()** | Replaces MoE block forward with EP-aware all-to-all communication | + +## Configuration + +```python +from twinkle.model.transformers.moe.expert_parallel import ExpertParallelConfig + +config = ExpertParallelConfig( + enabled=True, # Enable expert parallel + router_dtype='fp32', # Router computation dtype: 'fp32', 'bf16', 'fp16' + keep_router_logits=True, # Return router logits alongside hidden states + ignore_shared_experts=False,# Skip shared expert computation (e.g. DeepSeek) + ep_size=None, # EP world size (consumed by TransformersModel) +) +``` + +## Usage with DeviceMesh + +EP is activated by setting `ep_size` in `DeviceMesh.from_sizes()`. The framework automatically calls `apply_expert_parallel()` during model initialization. + +```python +from twinkle.utils import DeviceMesh + +# 8 GPUs: 2-way EP × 4-way data parallel +device_mesh = DeviceMesh.from_sizes( + world_size=8, + dp_size=4, + ep_size=2, +) +``` + +For combined EP + FSDP sharding on the expert parameters: + +```python +# 8 GPUs: 2-way EP with FSDP within each EP group +device_mesh = DeviceMesh.from_sizes( + world_size=8, + dp_size=2, + ep_size=2, + ep_fsdp_size=2, +) +``` + +## Communication Pattern + +The EP forward pass follows a 4-stage pipeline: + +1. **Preprocess** — compute per-expert token counts and split sizes +2. **Token Pre-All2All** — permute tokens by expert assignment, then all-to-all exchange across EP ranks +3. **Expert Compute** — each rank runs its local experts on received tokens +4. **Token Post-All2All** — all-to-all exchange results back, unpermute and apply routing weights + +``` +Input tokens → Router → [preprocess] → [pre_all2all] → [local experts] → [post_all2all] → Output +``` + +## Requirements + +- `num_experts` must be divisible by `ep_size` +- `torch.distributed` must be initialized +- MoE blocks must define a `gate`/`router` module and `experts` (either `nn.ModuleList` or tensor-style `gate_up_proj`/`down_proj`) +- Both ModuleList-style and tensor-style (fused) experts are supported +- Shared experts (e.g. DeepSeek MoE) are handled automatically unless `ignore_shared_experts=True` diff --git a/docs/source_en/Components/Training Middleware/Padding-Free.md b/docs/source_en/Components/Training Middleware/Padding-Free.md new file mode 100644 index 000000000..44dd7ba14 --- /dev/null +++ b/docs/source_en/Components/Training Middleware/Padding-Free.md @@ -0,0 +1,52 @@ +# Padding-Free Training + +Padding-free (also called "packing") training eliminates wasted computation on padding tokens by concatenating multiple sequences into a single packed batch. Twinkle supports padding-free training for both standard attention and Qwen3.5's GatedDeltaNet linear attention. + +## How It Works + +Instead of padding all sequences to `max_length`, padding-free packs multiple sequences into one row and uses `position_ids` to track sequence boundaries. This avoids wasted FLOPs on padding tokens. + +``` +Standard: [tok tok tok PAD PAD PAD] [tok tok PAD PAD PAD PAD] +Packed: [tok tok tok tok tok ...] ← no padding waste +``` + +## Usage + +Padding-free is enabled via `PackingDataset` or `IterablePackingDataset`: + +```python +from twinkle.dataset import PackingDataset + +dataset = PackingDataset( + dataset=base_dataset, + max_length=8192, +) +``` + +The dataset automatically packs sequences and generates correct `position_ids` with resets at sequence boundaries. + +## GatedDeltaNet Patch (Qwen3.5) + +Qwen3.5 uses a hybrid architecture mixing standard attention with GatedDeltaNet linear attention. The native GatedDeltaNet implementation does not reset its linear-attention state at packed sequence boundaries. + +`GatedDeltaNetPaddingFreePatch` fixes this by: + +1. Patching `Qwen3_5DecoderLayer.forward` to pass `cu_seq_lens_q` (cumulative sequence lengths) to linear attention layers +2. Patching `Qwen3_5GatedDeltaNet.forward` to use flash-linear-attention kernels (`causal_conv1d`, `chunk_gated_delta_rule`) with `cu_seqlens` support + +The patch is applied automatically when padding-free is detected on Qwen3.5 models. + +### Requirements + +- `flash-linear-attention` package must be installed +- Only needed for Qwen3.5 models with GatedDeltaNet layers +- When sequence parallel is enabled, a separate `Qwen3_5GatedDeltaNetUlyssesPatch` is used instead + +## Attention Backend Requirements + +| Attention Backend | Padding-Free Support | +|-------------------|---------------------| +| FlashAttention2 | Fully supported | +| SDPA | Supported (incompatible with sequence parallel) | +| Eager | Not supported | diff --git a/docs/source_en/Components/Training Middleware/Sequence-Parallel.md b/docs/source_en/Components/Training Middleware/Sequence-Parallel.md new file mode 100644 index 000000000..d08b01d67 --- /dev/null +++ b/docs/source_en/Components/Training Middleware/Sequence-Parallel.md @@ -0,0 +1,68 @@ +# Sequence Parallel (SP) + +Sequence Parallel splits long sequences across multiple GPUs along the sequence dimension, enabling training with sequence lengths that exceed single-GPU memory. Twinkle implements Ulysses-style sequence parallel with optional derived ring attention. + +## Overview + +| Concept | Description | +|---------|-------------| +| **SequenceParallelConfig** | Configuration dataclass for SP | +| **SequenceParallelStrategy** | Strategy class that wraps SP lifecycle | +| **SequenceParallel** | Core implementation handling pad/split/gather | + +## Configuration + +```python +from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallelConfig + +config = SequenceParallelConfig( + enabled=True, # Enable sequence parallel + ulysses_size=None, # Ulysses SP degree (auto-derived from DeviceMesh if None) + gather_logits=True, # Gather logits after forward for loss computation +) +``` + +## Usage with DeviceMesh + +SP is activated by setting `ulysses_size` in `DeviceMesh.from_sizes()`: + +```python +from twinkle.utils import DeviceMesh + +# 8 GPUs: 4-way Ulysses SP × 2-way data parallel +device_mesh = DeviceMesh.from_sizes( + world_size=8, + dp_size=2, + ulysses_size=4, +) +``` + +## How It Works + +1. **Pad** — input sequences are padded to a length divisible by SP world size +2. **Split** — padded inputs are evenly split across SP ranks along the sequence dimension +3. **Distributed Attention** — FlashAttention2 is patched to perform Ulysses all-to-all communication before/after attention computation +4. **Gather** — after forward, logits are gathered back to full sequence length for loss computation + +## Supported Attention Backends + +| Backend | Status | +|---------|--------| +| FlashAttention2 | Fully supported (including packed/padding-free sequences) | +| SDPA | Supported (non-packed batches only) | +| Derived Ring Attention | Supported with FlashAttention2 only (`rp_world_size > 1`) | + +## Qwen3.5 Linear Attention + +SP automatically detects Qwen3.5 GatedDeltaNet linear attention layers and applies the `Qwen3_5GatedDeltaNetUlyssesPatch` for correct sequence-parallel behavior on hybrid attention architectures. + +## MoE Auxiliary Loss + +For MoE models, SP automatically installs a forward hook that gathers router logits across SP ranks before auxiliary loss computation, ensuring correct load-balancing signals. + +## Key Constraints + +- `num_key_value_heads` must be divisible by `ulysses_size` (for Ulysses) or use ring attention fallback +- Packed/padding-free batches require FlashAttention2 +- Derived ring attention requires `batch_size == 1` (packed format) +- `torch.distributed` must be initialized diff --git a/docs/source_en/Components/Training Middleware/TwinkleClient.md b/docs/source_en/Components/Training Middleware/TwinkleClient.md new file mode 100644 index 000000000..18a1437db --- /dev/null +++ b/docs/source_en/Components/Training Middleware/TwinkleClient.md @@ -0,0 +1,81 @@ +# TwinkleClient + +`TwinkleClient` is the Python client for interacting with the Twinkle REST API. It manages sessions, training runs, and checkpoints. + +## Initialization + +```python +from twinkle_client.manager import TwinkleClient + +client = TwinkleClient( + base_url='http://localhost:8000', # Or TWINKLE_SERVER_URL env var + api_key='your-api-key', # Or TWINKLE_SERVER_TOKEN env var + route_prefix='/twinkle', # API route prefix + session_heartbeat_interval=10, # Heartbeat interval in seconds + session_metadata={'user': 'alice'}, # Optional session metadata +) +``` + +On init, the client: +1. Sets `base_url` and `api_key` into shared context (used by all client objects) +2. Creates a server-side session +3. Starts a background heartbeat thread to keep the session alive + +## Health Check + +```python +is_healthy = client.health_check() # Returns True/False +capabilities = client.get_server_capabilities() # Supported models +``` + +## Training Runs + +```python +# List runs +runs = client.list_training_runs(limit=20, offset=0) + +# List with pagination cursor +runs, cursor = client.list_training_runs_with_cursor(limit=20) + +# Get specific run +run = client.get_training_run(run_id='run_abc123') + +# Find by base model +qwen_runs = client.find_training_run_by_model('Qwen/Qwen3.5-4B') +``` + +## Checkpoints + +```python +# List checkpoints for a run +checkpoints = client.list_checkpoints(run_id='run_abc123') + +# Get checkpoint path +parsed = client.get_checkpoint_path(run_id, checkpoint_id) +# parsed.path → filesystem path +# parsed.twinkle_path → twinkle:// URI + +# Get latest checkpoint (useful for resume training) +latest_path = client.get_latest_checkpoint_path(run_id) + +# Delete checkpoint +client.delete_checkpoint(run_id, checkpoint_id) +``` + +## Capacity & Weights Info + +```python +# LoRA capacity +capacity = client.get_capacity_info() +# capacity.max_loras, capacity.used_loras, capacity.free_loras + +# Weights metadata +info = client.get_weights_info('twinkle://run_id/weights/checkpoint') +# info.base_model, info.is_lora, info.lora_rank +``` + +## Cleanup + +```python +client.close() # Stops heartbeat thread (also registered via atexit) +``` diff --git a/docs/source_en/Components/Training Middleware/index.rst b/docs/source_en/Components/Training Middleware/index.rst index 014dfdc66..b2dc3acee 100644 --- a/docs/source_en/Components/Training Middleware/index.rst +++ b/docs/source_en/Components/Training Middleware/index.rst @@ -4,4 +4,8 @@ Training Middleware :maxdepth: 1 DeviceMesh-and-DeviceGroup.md + Expert-Parallel.md + Sequence-Parallel.md + Padding-Free.md RemoteClass.md + TwinkleClient.md diff --git a/docs/source_en/Usage Guide/Embedding-Training.md b/docs/source_en/Usage Guide/Embedding-Training.md new file mode 100644 index 000000000..d86769c38 --- /dev/null +++ b/docs/source_en/Usage Guide/Embedding-Training.md @@ -0,0 +1,120 @@ +# Embedding Training + +Twinkle supports contrastive embedding model training with InfoNCE loss, in-batch negatives, and cross-rank gathering. This guide demonstrates how to train embedding models using Twinkle. + +--- + +## Overview + +Embedding training in Twinkle uses the following core components: + +| Component | Role | +|:----------|:-----| +| `InfonceLoss` | Contrastive loss with in-batch negatives | +| `EmbeddingMetric` | Tracks pos/neg similarity and loss | +| `TransformersModel` | Trainable embedding model (with LoRA or full) | +| `InputProcessor` | Processes anchor/positive pairs into features | + +### Data Format + +Each training sample consists of **(anchor, positive)** pairs. In the embedding feature tensor: + +``` +embeddings: [anchor_0, positive_0, anchor_1, positive_1, ...] +labels: [ 1, 0, 1, 0, ...] +``` + +- `labels=1` marks the start of a new group (anchor) +- `labels=0` marks positives/negatives within the group + +--- + +## Basic Embedding Training + +A minimal embedding training script with DDP: + +```python +import twinkle +from twinkle import DeviceGroup, DeviceMesh, get_logger +from twinkle.dataloader import DataLoader +from twinkle.loss import InfonceLoss +from twinkle.metric import EmbeddingMetric +from twinkle.model import TransformersModel +from twinkle.processor import InputProcessor +from twinkle.template import Qwen3_5Template + +logger = get_logger() + +# --- Configuration --- +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +MODEL_GPUS = 4 +BATCH_SIZE = 32 +LEARNING_RATE = 1e-5 +TEMPERATURE = 0.07 +EMB_MAX_LENGTH = 8192 + +# --- Initialize --- +device_groups = [ + DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), +] +model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) +twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups) + +# --- Model --- +model = TransformersModel( + model_id=MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + ddp_config={'find_unused_parameters': True}, +) +model.set_processor(InputProcessor) +model.set_loss(InfonceLoss, temperature=TEMPERATURE, use_batch=True) +model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE) +model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=200, + num_training_steps=total_steps, +) +model.add_metric(EmbeddingMetric, is_training=True) + +# --- Template --- +template = Qwen3_5Template( + model_id=MODEL_ID, + max_length=EMB_MAX_LENGTH, + enable_thinking=False, +) + +# --- Training Loop --- +for step, batch in enumerate(dataloader): + # batch: list of features with anchor/positive pairs + model.forward_backward(inputs=batch, task='embedding') + model.clip_grad_and_step(gradient_accumulation_steps=1) + + if step % 10 == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Step {step}: {metric}') +``` + +### Key Parameters + +| Parameter | Recommended | Description | +|:----------|:------------|:------------| +| `temperature` | 0.05–0.1 | Lower = sharper contrast. 0.07 keeps gradients flowing until cosine > 0.75 | +| `use_batch` | True | Enables cross-sample in-batch negatives for better efficiency | +| `hard_negatives` | None or 7 | Fix negative count per sample; None uses all in-batch | +| `find_unused_parameters` | True | Required for embedding models (only last hidden state contributes gradients) | + +--- + +## Monitoring + +The `EmbeddingMetric` reports key training signals: + +| Metric | What it means | +|:-------|:--------------| +| `pos_sim` | Average anchor-positive cosine similarity (target: > 0.8) | +| `neg_sim` | Average anchor-negative similarity (target: < 0.3) | +| `loss` | InfoNCE loss value | +| `grad_norm` | Gradient magnitude | + +Healthy training shows `pos_sim` rising and `neg_sim` stable or falling. If `pos_sim` saturates near 1.0, lower the temperature. diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index 1c1a70fb6..707473910 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -156,6 +156,8 @@ if __name__ == '__main__': In this training code, we constructed a dataset and loaded the Qwen/Qwen3.5-4B model, used LoRA with the all-linear approach, and completed one training run. In the logs, you can observe the process of loss gradually converging. +> **Tip — Full-Parameter Training**: The example above uses LoRA for efficiency. To switch to full-parameter training, simply remove the `add_adapter_to_model` call (and the `from peft import LoraConfig` import). Everything else stays the same. + ### torchrun Twinkle supports running training in torchrun mode. In this scenario, Ray-related dependencies do not need to be installed. @@ -471,7 +473,7 @@ python train.py A major feature of Twinkle is support for multi-tenant mixed training. Specifically, multiple users can use a single base model for LoRA training, which can greatly reduce server-side deployment costs. -Checkpoint resumption is also supported in client-server training. The recommended flow is to call `model.resume_from_checkpoint(resume_path)` to restore weights and optimizer state, then call `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` to skip consumed data. See [Twinkle-Client](./Server%20and%20Client/Twinkle-Client.md) and [self_cognition.py](../../../cookbook/client/twinkle/self_host/self_cognition.py). +Checkpoint resumption is also supported in client-server training. The recommended flow is to call `model.resume_from_checkpoint(resume_path)` to restore weights and optimizer state, then call `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` to skip consumed data. See [Twinkle-Client](./Server%20and%20Client/Twinkle-Client.md) and [self_cognition.py](../../../cookbook/server_mode/twinkle/self_host/self_cognition.py). Suppose we start a service using eight GPUs. First, we need to start the Ray cluster: @@ -493,6 +495,8 @@ Next, start the server: twinkle-server launch -c cookbook/client/server/transformer/server_config.yaml ``` +> For details on how to write `server_config.yaml`, see [Server Configuration](../Server%20and%20Client/Server.md). + The server will start three services: a sampler cluster, a model cluster, and a utility cluster. Now you can perform client-side training: diff --git a/docs/source_en/index.rst b/docs/source_en/index.rst index ef477f7fc..6128079c8 100644 --- a/docs/source_en/index.rst +++ b/docs/source_en/index.rst @@ -15,6 +15,7 @@ Twinkle DOCUMENTATION Usage Guide/NPU-Support.md Usage Guide/Train-as-a-Service.md Usage Guide/Introduction-with-Qwen3.5.md + Usage Guide/Embedding-Training.md .. toctree:: :maxdepth: 2 @@ -30,7 +31,6 @@ Twinkle DOCUMENTATION Components/Sampler/index.rst Components/Reward/index.rst Components/Advantage/index.rst - Components/Gym/index.rst Components/Hub/index.rst Components/Checkpoint Engine/index.rst Components/Metrics/index.rst @@ -41,6 +41,10 @@ Twinkle DOCUMENTATION Components/Plugin/index.rst Components/Kernel/index.rst Components/Training Middleware/index.rst + Components/CLI/index.rst + Components/Notifier/index.rst + Components/Agentic/index.rst + Components/TUI/index.rst Indices and tables ================== diff --git a/docs/source_zh/index.rst b/docs/source_zh/index.rst index 3d07d4b2a..6a7be7b5b 100644 --- a/docs/source_zh/index.rst +++ b/docs/source_zh/index.rst @@ -15,6 +15,7 @@ Twinkle DOCUMENTATION 使用指引/NPU的支持.md 使用指引/训练服务.md 使用指引/Qwen3.5最佳实践.md + 使用指引/Embedding训练.md .. toctree:: :maxdepth: 2 @@ -30,7 +31,6 @@ Twinkle DOCUMENTATION 组件/采样器/index.rst 组件/奖励/index.rst 组件/优势/index.rst - 组件/Gym/index.rst 组件/Hub/index.rst 组件/检查点引擎/index.rst 组件/指标/index.rst @@ -41,6 +41,10 @@ Twinkle DOCUMENTATION 组件/组件化/index.rst 组件/Kernel/index.rst 组件/训练中间件/index.rst + 组件/CLI/index.rst + 组件/通知器/index.rst + 组件/Agentic/index.rst + 组件/TUI/index.rst Indices and tables ================== diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Embedding\350\256\255\347\273\203.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Embedding\350\256\255\347\273\203.md" new file mode 100644 index 000000000..94ea86ebe --- /dev/null +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Embedding\350\256\255\347\273\203.md" @@ -0,0 +1,120 @@ +# Embedding 模型训练 + +Twinkle 支持基于 InfoNCE 损失的对比学习 Embedding 模型训练,内置 in-batch negatives 和跨 rank 聚合。本文介绍如何使用 Twinkle 训练 Embedding 模型。 + +--- + +## 概述 + +Embedding 训练使用以下核心组件: + +| 组件 | 职责 | +|:-----|:-----| +| `InfonceLoss` | 对比损失,支持 in-batch negatives | +| `EmbeddingMetric` | 追踪正/负对相似度和损失 | +| `TransformersModel` | 可训练的 Embedding 模型(LoRA 或全参) | +| `InputProcessor` | 将 anchor/positive 对处理为特征 | + +### 数据格式 + +每个训练样本由 **(anchor, positive)** 对组成。在 Embedding 特征张量中: + +``` +embeddings: [anchor_0, positive_0, anchor_1, positive_1, ...] +labels: [ 1, 0, 1, 0, ...] +``` + +- `labels=1` 标记新分组的起始位置(anchor) +- `labels=0` 标记组内的 positive/negative + +--- + +## 基础 Embedding 训练 + +使用 DDP 的最小化 Embedding 训练脚本: + +```python +import twinkle +from twinkle import DeviceGroup, DeviceMesh, get_logger +from twinkle.dataloader import DataLoader +from twinkle.loss import InfonceLoss +from twinkle.metric import EmbeddingMetric +from twinkle.model import TransformersModel +from twinkle.processor import InputProcessor +from twinkle.template import Qwen3_5Template + +logger = get_logger() + +# --- 配置 --- +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +MODEL_GPUS = 4 +BATCH_SIZE = 32 +LEARNING_RATE = 1e-5 +TEMPERATURE = 0.07 +EMB_MAX_LENGTH = 8192 + +# --- 初始化 --- +device_groups = [ + DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), +] +model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) +twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups) + +# --- 模型 --- +model = TransformersModel( + model_id=MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + ddp_config={'find_unused_parameters': True}, +) +model.set_processor(InputProcessor) +model.set_loss(InfonceLoss, temperature=TEMPERATURE, use_batch=True) +model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE) +model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=200, + num_training_steps=total_steps, +) +model.add_metric(EmbeddingMetric, is_training=True) + +# --- 模板 --- +template = Qwen3_5Template( + model_id=MODEL_ID, + max_length=EMB_MAX_LENGTH, + enable_thinking=False, +) + +# --- 训练循环 --- +for step, batch in enumerate(dataloader): + # batch: 包含 anchor/positive 对的特征列表 + model.forward_backward(inputs=batch, task='embedding') + model.clip_grad_and_step(gradient_accumulation_steps=1) + + if step % 10 == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Step {step}: {metric}') +``` + +### 关键参数 + +| 参数 | 推荐值 | 说明 | +|:----|:------|:-----| +| `temperature` | 0.05–0.1 | 越低对比越尖锐;0.07 保持梯度流动直至 cosine > 0.75 | +| `use_batch` | True | 启用跨样本 in-batch negatives 提升效率 | +| `hard_negatives` | None 或 7 | 固定每样本负例数量;None 使用全部 in-batch | +| `find_unused_parameters` | True | Embedding 模型必需(仅最后隐藏状态产生梯度) | + +--- + +## 监控指标 + +`EmbeddingMetric` 报告关键训练信号: + +| 指标 | 含义 | +|:----|:-----| +| `pos_sim` | anchor-positive 平均余弦相似度(目标 > 0.8) | +| `neg_sim` | anchor-negative 平均相似度(目标 < 0.3) | +| `loss` | InfoNCE 损失值 | +| `grad_norm` | 梯度范数 | + +健康的训练表现为 `pos_sim` 持续上升、`neg_sim` 稳定或下降。如果 `pos_sim` 过早饱和至 1.0 附近,应降低 temperature。 diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" index 39f6fe182..feaf7f4e8 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" @@ -257,6 +257,19 @@ pip install torch_npu-2.7.1-cp311-cp311-linux_aarch64.whl 2. “待验证”功能可以尝试,但可能遇到兼容性问题 3. 遇到问题时,参考对应的示例代码进行配置 +## 示例代码 + +Twinkle 在 NPU 上已验证的示例目前聚焦 Megatron smoke 路径;SFT 和 GRPO cookbook 示例暂无对应文件。 + +### 远程训练(Tinker 协议) +- **服务端配置**:[cookbook/remote/tinker/ascend/](https://github.com/modelscope/twinkle/tree/main/cookbook/remote/tinker/ascend) + - 提供 HTTP API 接口 + - 支持远程训练和推理 + - 适用于生产环境部署 + +**运行示例**: +暂无对应命令示例。 + ## 参考资源 diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index 19189948f..3bc5c4ba4 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -157,6 +157,8 @@ if __name__ == '__main__': 在这个训练代码中,我们构造了一个数据集并拉起了Qwen/Qwen3.5-4B模型,使用all-linear方式加载了lora,并完成了一次训练。在日志中,可以看到loss逐步收敛的过程。 +> **提示 — 全参数训练**:上面的示例使用 LoRA 以提高效率。若要切换为全参数训练,只需移除 `add_adapter_to_model` 调用(以及 `from peft import LoraConfig` 导入),其余代码完全不变。 + ### torchrun Twinkle 支持以 torchrun 模式运行训练。在这种场景下,不需要安装 Ray 相关的依赖。 @@ -470,7 +472,7 @@ python train.py ``` ### 远程训练 -client-server 训练场景同样支持断点续训。推荐流程是调用 `model.resume_from_checkpoint(resume_path)` 恢复权重和优化器状态,再调用 `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` 跳过已消费数据。详细示例可参考 [Twinkle客户端](./服务端和客户端/Twinkle客户端.md) 和 [self_cognition.py](../../../cookbook/client/twinkle/self_host/self_cognition.py)。 +client-server 训练场景同样支持断点续训。推荐流程是调用 `model.resume_from_checkpoint(resume_path)` 恢复权重和优化器状态,再调用 `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` 跳过已消费数据。详细示例可参考 [Twinkle客户端](./服务端和客户端/Twinkle客户端.md) 和 [self_cognition.py](../../../cookbook/server_mode/twinkle/self_host/self_cognition.py)。 Twinkle 的一大特色是支持多租户用户混合训练。具体来说,多个用户可以使用一个基模进行 LoRA 训练,这样可以极大减小服务端部署成本。 @@ -494,6 +496,8 @@ CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0 twinkle-server launch -c cookbook/client/server/transformer/server_config.yaml ``` +> `server_config.yaml` 的编写方式详见 [服务端配置](../服务端和客户端/服务端.md)。 + 服务端会启动一个包含 Sampler 集群、模型集群、工具集群的三个服务。 下面可以进行client端训练: diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/Envs.md" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Envs.md" new file mode 100644 index 000000000..675f05baf --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Envs.md" @@ -0,0 +1,183 @@ +# 执行环境(Envs) + +Envs 模块提供了用于 Agentic 训练的 RL 执行环境抽象。环境可以在多轮 rollout 中交互式参与,也可以批量评估已完成的轨迹。 + +## Env 基类 + +```python +from twinkle_agentic.envs.base import Env, StepResult + +class Env(ABC): + + def reset(self, trajectory=None) -> StepResult: + """重置环境,开始新一轮。""" + + @abstractmethod + def step(self, tool_name: str, arguments: dict) -> StepResult: + """执行单个动作,返回观测 + 奖励 + 完成标志。""" + + def tools(self) -> List[ToolInfo]: + """返回此环境中可用的工具定义。""" + + def evaluate(self, trajectories, **kwargs) -> List[float]: + """批量评估已完成的轨迹,返回奖励列表。""" + + def close(self) -> None: + """释放资源。""" +``` + +### StepResult + +```python +@dataclass +class StepResult: + observation: str = '' # 动作执行后的环境观测 + reward: float = 0.0 # 此步骤的标量奖励 + done: bool = False # 是否终止 + info: Dict[str, Any] = field(default_factory=dict) # 额外元数据 +``` + +### 两种使用模式 + +1. **交互模式**(多轮 rollout)—— 逐步执行: + +```python +env = MyEnv() +env.reset(trajectory) +result = env.step('search', {'query': 'Python'}) +# ... 重复直到 result.done +``` + +2. **批量评估模式** —— 评估已完成的轨迹: + +```python +rewards = env.evaluate(completed_trajectories) +``` + +## EnvTool + +`EnvTool` 将 `Env` 包装为 `Tool`,连接环境与 `ToolManager` 和 `MultiTurnRollout`。 + +```python +from twinkle_agentic.envs.env_tool import EnvTool +from twinkle_agentic.tools.tool_manager import ToolManager + +env = MyEnv() + +# 为环境中定义的每个工具创建一个 EnvTool +env_tools = EnvTool.from_env(env) + +# 注册到 ToolManager +manager = ToolManager(env_tools) +``` + +### 核心特性 + +| 特性 | 说明 | +|------|------| +| `from_env(env)` | 工厂方法:为 `env.tools()` 中的每个工具创建一个 `EnvTool`。 | +| `last_result` | 存储最近一次 `StepResult` 供调用方检查。 | +| `done` | 属性:最后一步是否终止了回合。 | +| `episode_reward` | 属性:来自 `info['episode_reward']` 的累计奖励。 | + +### 手动构造 + +```python +env_tool = EnvTool( + env=my_env, + tool_name='execute_code', + description='在沙箱中执行 Python 代码。', + parameters={ + 'type': 'object', + 'properties': { + 'code': {'type': 'string', 'description': '要执行的 Python 代码。'}, + }, + 'required': ['code'], + }, +) +``` + +## OpenEnv + +`OpenEnv` 将基于 WebSocket 的 [OpenEnv](https://github.com/OpenEnv) 环境服务器适配为同步的 Twinkle `Env`。 + +```python +from twinkle_agentic.envs.openenv import OpenEnv + +env = OpenEnv( + base_url='http://localhost:8000', + env_cls='coding_env.CodingEnv', # 可选的类型化客户端 + env_kwargs={'message_timeout_s': 30}, + tool_schema=[...], # 可选的工具定义 +) +``` + +### 参数 + +| 参数 | 类型 | 说明 | +|------|------|------| +| `base_url` | `str` | 运行中的 OpenEnv 服务器 URL。 | +| `env_cls` | `str` 或 class | 类型化客户端的点分导入路径或类。`None` 使用 `GenericEnvClient`。 | +| `env_kwargs` | `Dict` | 传递给客户端构造函数的额外参数。 | +| `tool_schema` | `List[ToolInfo]` | 通过 `tools()` 暴露的工具定义。 | +| `action_mapper` | `Callable` | 自定义函数,将 `(tool_name, args)` 映射为发送给服务器的动作字典。 | + +### 与 Rollout 集成使用 + +```python +from twinkle_agentic.envs.openenv import OpenEnv +from twinkle_agentic.envs.env_tool import EnvTool +from twinkle_agentic.tools.tool_manager import ToolManager +from twinkle_agentic.rollout.api_multi_turn import APIMultiTurnRollout + +# 设置环境 +env = OpenEnv(base_url='http://localhost:8000', tool_schema=[...]) +env.reset() + +# 桥接到 ToolManager +env_tools = EnvTool.from_env(env) +manager = ToolManager(env_tools) + +# 在 rollout 中使用 +rollout = APIMultiTurnRollout(api=api, tool_manager=manager, max_turns=10) +results = rollout(trajectories) +``` + +### 实现自定义环境 + +```python +from twinkle_agentic.envs.base import Env, StepResult + +class CodeExecutionEnv(Env): + + def reset(self, trajectory=None): + self._sandbox = create_sandbox() + return StepResult(observation='沙箱已就绪。') + + def step(self, tool_name, arguments): + code = arguments.get('code', '') + output = self._sandbox.run(code) + return StepResult( + observation=output, + reward=1.0 if 'error' not in output.lower() else 0.0, + done=False, + ) + + def tools(self): + return [{ + 'type': 'function', + 'function': { + 'name': 'execute_code', + 'description': '运行 Python 代码。', + 'parameters': { + 'type': 'object', + 'properties': { + 'code': {'type': 'string'}, + }, + }, + }, + }] + + def close(self): + self._sandbox.cleanup() +``` diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/Multi-Turn-Tool-Usage.md" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Multi-Turn-Tool-Usage.md" new file mode 100644 index 000000000..1c13bfbe0 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Multi-Turn-Tool-Usage.md" @@ -0,0 +1,205 @@ +# 多轮工具使用指南 + +本指南介绍如何在 Twinkle 中设置和运行带工具调用的多轮 Agentic rollout。 + +## 架构概览 + +Agentic rollout 管线由四个核心组件组成: + +- **Tool** —— 实现特定能力(搜索、代码执行等) +- **ToolManager** —— 注册工具并分发 LLM 工具调用 +- **Env**(可选)—— RL 环境,通过 `EnvTool` 暴露工具 +- **Rollout** —— 驱动多轮对话循环 + +## 快速开始:基于 API 的 Rollout + +使用 OpenAI 兼容 API 运行多轮工具使用 rollout 的最简方式: + +```python +from twinkle_agentic.protocol.openai import OpenAI +from twinkle_agentic.tools.base import Tool +from twinkle_agentic.tools.tool_manager import ToolManager +from twinkle_agentic.rollout.api_multi_turn import APIMultiTurnRollout +from twinkle.data_format.sampling import SamplingParams + +# 1. 定义工具 +class WeatherTool(Tool): + def __call__(self, tool_name, arguments): + city = arguments.get('city', '未知') + return f'{city}的天气:晴,25°C。' + + def tool_info(self): + return { + 'type': 'function', + 'function': { + 'name': 'get_weather', + 'description': '获取城市的当前天气。', + 'parameters': { + 'type': 'object', + 'properties': { + 'city': {'type': 'string', 'description': '城市名称。'}, + }, + 'required': ['city'], + }, + }, + } + +# 2. 设置 ToolManager +manager = ToolManager([WeatherTool()]) + +# 3. 创建 API 客户端 +api = OpenAI(model='qwen3.5-32b', base_url='http://localhost:8000/v1') + +# 4. 创建 rollout +rollout = APIMultiTurnRollout( + api=api, + tool_manager=manager, + sampling_params=SamplingParams(temperature=0.7, max_tokens=2048), + max_turns=6, + concurrency=8, +) + +# 5. 准备轨迹 +trajectories = [ + { + 'messages': [ + {'role': 'user', 'content': '北京今天天气怎么样?'}, + ], + }, +] + +# 6. 运行 rollout +results = rollout(trajectories) +for r in results: + print(f"轮次: {r['turns']}, 停止原因: {r['stop_reason']}") + for msg in r['messages']: + print(f" [{msg['role']}] {msg.get('content', '')[:100]}") +``` + +## 训练集成:基于 vLLM 的 Rollout + +用于 RLHF 训练时,使用 `MultiTurnRollout`,它会生成 `input_ids` 和 `labels`: + +```python +from twinkle_agentic.rollout.multi_turn import MultiTurnRollout +from twinkle.data_format.sampling import SamplingParams + +rollout = MultiTurnRollout( + sampler=vllm_sampler, # vLLMSampler 实例 + template=template, # 聊天模板 + tool_manager=manager, + sampling_params=SamplingParams(temperature=0.7, max_tokens=4096), + max_turns=6, + max_trajectory_tokens=8192, + trace_dir='rollout_traces/', +) + +# 在 GRPO 训练循环中 +results = rollout(batch_trajectories) +# results 包含 input_ids、labels、logprobs 用于训练 +``` + +## 将环境用作工具 + +将 RL 环境桥接到工具管线中: + +```python +from twinkle_agentic.envs.base import Env, StepResult +from twinkle_agentic.envs.env_tool import EnvTool +from twinkle_agentic.tools.tool_manager import ToolManager + +# 定义环境 +class CodeEnv(Env): + def step(self, tool_name, arguments): + code = arguments.get('code', '') + # 在沙箱中执行代码 + result = execute_in_sandbox(code) + return StepResult(observation=result, reward=1.0, done=False) + + def tools(self): + return [{ + 'type': 'function', + 'function': { + 'name': 'run_python', + 'description': '执行 Python 代码。', + 'parameters': { + 'type': 'object', + 'properties': { + 'code': {'type': 'string'}, + }, + 'required': ['code'], + }, + }, + }] + +# 桥接 Env -> Tool -> ToolManager +env = CodeEnv() +env_tools = EnvTool.from_env(env) +manager = ToolManager(env_tools) + +# 照常在 rollout 中使用 manager +rollout = APIMultiTurnRollout(api=api, tool_manager=manager, max_turns=10) +``` + +## 使用 OpenEnv 环境 + +连接远程 OpenEnv WebSocket 服务器: + +```python +from twinkle_agentic.envs.openenv import OpenEnv +from twinkle_agentic.envs.env_tool import EnvTool + +env = OpenEnv( + base_url='http://localhost:8000', + env_cls='coding_env.CodingEnv', + tool_schema=[{ + 'type': 'function', + 'function': { + 'name': 'submit', + 'description': '提交代码解决方案。', + 'parameters': { + 'type': 'object', + 'properties': { + 'code': {'type': 'string'}, + }, + }, + }, + }], +) + +env.reset() +env_tools = EnvTool.from_env(env) +manager = ToolManager(env_tools) +``` + +## 每轨迹独立 ToolManager + +当每个轨迹需要独立工具集时(例如,轨迹绑定的状态): + +```python +# 创建每轨迹的 manager +managers = [] +for traj in trajectories: + env = create_env_for(traj) + env_tools = EnvTool.from_env(env) + managers.append(ToolManager(env_tools)) + +# 传入列表(与轨迹 1:1 对齐) +results = rollout(trajectories, tool_manager=managers) +``` + +## 跟踪调试 + +两种 rollout 实现都支持跟踪文件输出用于调试: + +```python +rollout = APIMultiTurnRollout( + api=api, + tool_manager=manager, + trace_dir='traces/', + trace_callback=lambda t: t['turns'] > 1, # 仅存储多轮对话 + success_callback=lambda t: t.get('stop_reason') == 'stop', +) +``` + +跟踪文件以 `{step}-{ok|fail}-{id}.json` 格式保存,包含完整对话和元数据。 diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/Preprocessor.md" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Preprocessor.md" new file mode 100644 index 000000000..a5730abc8 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Preprocessor.md" @@ -0,0 +1,189 @@ +# Agentic 预处理器 + +Agentic 预处理器模块提供了基于流水线的多轮对话数据质量过滤框架,用于 RLHF / Agentic 微调之前的训练数据清洗和过滤。 + +## QualityPreprocessor + +`QualityPreprocessor` 是一个轻量级流水线运行器,接受过滤器列表并按顺序执行。每个步骤接收行列表,返回 `(kept, dropped)`,流水线会记录每步统计信息。 + +```python +from twinkle_agentic.preprocessor import QualityPreprocessor, HardFilter, DeadLoopFilter + +pipeline = [ + HardFilter(min_user_chars=10), + DeadLoopFilter(), +] +preprocessor = QualityPreprocessor(pipeline, dropped_log_path='dropped.jsonl') + +# rows 是列格式的字典(Dataset.map 格式) +cleaned = preprocessor(rows) +``` + +### 参数 + +| 参数 | 类型 | 说明 | +|------|------|------| +| `pipeline` | `List[Callable]` | 有序的过滤步骤列表。每个步骤接收 `List[Dict]`,返回 `(kept, dropped)`。 | +| `dropped_log_path` | `str` | 可选的 JSONL 文件路径,用于记录被丢弃的行及步骤名称和原因。 | + +## 内置过滤器 + +### HardFilter + +基于硬规则的过滤器,使用确定性规则移除质量差的行。支持多语言检测(EN/ZH/JA/KO)。 + +```python +from twinkle_agentic.preprocessor import HardFilter + +f = HardFilter( + min_user_chars=10, # 非 CJK 用户查询最小字符数 + min_user_chars_cjk=6, # CJK 用户查询最小字符数 + min_assistant_chars_2turn=80, # 两轮对话中助手回复最小长度 + min_thinking_chars=200, # 思考链最小长度(可豁免过滤) + system_deny_keywords=['hack', 'exploit'], + max_chars_per_round=50000, + max_total_chars=200000, + max_rounds=50, +) +``` + +**丢弃原因:** `trivial_single_turn`(平凡单轮)、`shallow_reply`(浅回复)、`all_empty_assistant`(全空助手)、`system_deny_keyword`(系统拒绝关键词)、`round_too_long`(单轮过长)、`total_too_long`(总长过长)、`too_many_rounds`(轮次过多) + +### DeadLoopFilter + +检测助手消息中的犹豫/死循环模式——重复自我纠正、级联纠正和高 n-gram 重复。 + +```python +from twinkle_agentic.preprocessor import DeadLoopFilter + +f = DeadLoopFilter( + hesitation_density_threshold=7.0, # 每 1000 字符犹豫标记数(响应) + cascade_threshold=5, # 窗口内级联标记数 + cascade_window=800, # 窗口大小(字符) + repetition_threshold=0.45, # N-gram 重复率 + think_hesitation_density_threshold=15.0, # 块更宽松 + think_repetition_threshold=0.65, +) +``` + +对 `` 推理块使用更宽松的阈值(允许自由发散),对可见响应使用更严格的阈值。 + +### DedupFilter + +全局最长优先去重。签名由第一个真实用户轮次(首尾)和第一个助手回复推导。 + +```python +from twinkle_agentic.preprocessor import DedupFilter + +f = DedupFilter(prefix_chars=100, asst_chars=100) +kept, dropped = f(all_rows) # 必须在一次调用中传入整个数据集 +``` + +> **注意:** `DedupFilter` 需要在单次调用中接收完整数据集。**不要**将它放入 `QualityPreprocessor` 中(后者按批处理)。请在流水线之前或之后单独运行。 + +### RefuseFilter + +检测第一条助手回复中的自我引用式拒绝(如"我无法帮助您")。多语言模式匹配(EN/ZH/JA/KO)。 + +```python +from twinkle_agentic.preprocessor import RefuseFilter + +f = RefuseFilter(check_window=600) # 仅检查前 N 个字符 +``` + +### TokenSoupFilter + +检测乱码/token-soup 输出,检查替换字符、控制字符、私用区 Unicode、泄漏的特殊 token、单字符重复和脚本混乱。 + +```python +from twinkle_agentic.preprocessor import TokenSoupFilter + +f = TokenSoupFilter( + replacement_char_ratio=0.02, + special_token_count=20, + script_chaos_threshold=0.55, +) +``` + +### PIIPresidioFilter + +基于 Microsoft Presidio + spaCy NER + Faker 的多语言 PII 检测和重写。检测并替换个人身份信息(姓名、邮箱、电话号码、地址等)。 + +```python +from twinkle_agentic.preprocessor import PIIPresidioFilter + +f = PIIPresidioFilter(languages=['en', 'zh']) +``` + +### IntentClassifier + +启发式意图分类器,为每行标注检测到的意图。可插拔的检测器管线。 + +```python +from twinkle_agentic.preprocessor import IntentClassifier + +classifier = IntentClassifier() +``` + +**意图类别:** `tool_call`(工具调用)、`code`(代码)、`math`(数学)、`complex_logic`(复杂逻辑)、`reasoning`(推理)、`user_dissatisfaction`(用户不满)、`other`(其他) + +### ScoreFilter + +可插拔评分器过滤器,内置字符级指标、语义相似度和代码执行评分器。 + +```python +from twinkle_agentic.preprocessor import ScoreFilter + +f = ScoreFilter() +``` + +**内置评分器:** `ChrMinScorer`、`SIFDScorer`、`PassNScorer`、`ParaphraseScorer` + +### ModelFilter + +按模型 ID 白名单过滤行。 + +```python +from twinkle_agentic.preprocessor import ModelFilter + +f = ModelFilter(allowed_models=['qwen3.5-4b', 'qwen3.5-32b']) +``` + +### MessageNormalizer + +三遍消息规范化:心跳剥离、工具调用重写、连续同角色消息合并。 + +```python +from twinkle_agentic.preprocessor import MessageNormalizer + +normalizer = MessageNormalizer() +``` + +## 完整流水线示例 + +```python +from twinkle_agentic.preprocessor import ( + QualityPreprocessor, + HardFilter, + DeadLoopFilter, + RefuseFilter, + TokenSoupFilter, + MessageNormalizer, + DedupFilter, +) + +# 第一步:全局去重(必须在完整数据集上运行) +dedup = DedupFilter() +rows, _ = dedup(all_rows) + +# 第二步:按批流水线 +pipeline = [ + HardFilter(min_user_chars=10, max_rounds=30), + DeadLoopFilter(), + RefuseFilter(), + TokenSoupFilter(), + MessageNormalizer(), +] +preprocessor = QualityPreprocessor(pipeline, dropped_log_path='dropped.jsonl') +cleaned = preprocessor(rows) +``` diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/Protocol.md" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Protocol.md" new file mode 100644 index 000000000..1e03092d6 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Protocol.md" @@ -0,0 +1,91 @@ +# 协议(Protocol) + +Protocol 模块提供了抽象的 LLM API 客户端接口及其 OpenAI 兼容实现。它将 Twinkle 的 `Trajectory` / `SamplingParams` 数据类型与外部 LLM 推理服务连接起来。 + +## API 基类 + +```python +from abc import ABC, abstractmethod +from twinkle.data_format import Trajectory +from twinkle.data_format.message import Message +from twinkle.data_format.sampling import SamplingParams + +class API(ABC): + """抽象 LLM API 客户端:Trajectory + SamplingParams -> 助手 Message""" + + @abstractmethod + def __call__( + self, + trajectory: Trajectory, + sampling_params: SamplingParams, + **kwargs, + ) -> Union[Message, List[Message]]: + raise NotImplementedError() +``` + +`API` 类定义了一个简单的契约:给定对话轨迹和采样参数,返回一条或多条助手消息。 + +## OpenAI + +`OpenAI` 是内置实现,兼容任何支持 `/v1/chat/completions` 协议的端点(OpenAI、Azure OpenAI、vLLM、SGLang、Ollama 等)。 + +```python +from twinkle_agentic.protocol.openai import OpenAI + +api = OpenAI( + model='qwen3.5-32b', + base_url='http://localhost:8000/v1', + api_key='EMPTY', +) +``` + +### 参数 + +| 参数 | 类型 | 说明 | +|------|------|------| +| `model` | `str` | API 请求中传递的模型名称。 | +| `api_key` | `str` | API 密钥。默认使用 `OPENAI_API_KEY` 环境变量。 | +| `base_url` | `str` | API 端点的基础 URL(如 `http://localhost:8000/v1`)。 | +| `client_kwargs` | `Dict` | 转发给 `openai.OpenAI` 客户端构造函数的额外关键字参数。 | + +### 使用方法 + +```python +from twinkle.data_format import Trajectory +from twinkle.data_format.sampling import SamplingParams + +trajectory = { + 'messages': [ + {'role': 'user', 'content': '法国的首都是什么?'}, + ] +} + +sp = SamplingParams(temperature=0.7, max_tokens=512) +reply = api(trajectory, sp) +# reply 是一个 Message 字典:{'role': 'assistant', 'content': '...'} +``` + +### 特性 + +- **工具调用**:自动将 `trajectory['tools']` 映射到 API 请求,并解析响应中的结构化 `tool_calls`。 +- **推理内容**:保留支持推理的模型返回的 `reasoning_content`(如 o1 风格推理)。 +- **完成原因**:在返回消息中暴露 `finish_reason`,供多轮驱动器检测长度截断。 +- **多样本**:当 `sampling_params.num_samples > 1` 时,返回消息列表(每个 choice 一条)。 + +### 自定义 API 客户端 + +要集成非 OpenAI API,请继承 `API`: + +```python +from twinkle_agentic.protocol.base import API + +class MyCustomAPI(API): + + def __call__(self, trajectory, sampling_params, **kwargs): + # 调用自定义端点 + response = my_llm_client.chat( + messages=trajectory['messages'], + temperature=sampling_params.temperature, + ) + return {'role': 'assistant', 'content': response.text} +``` diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/Rollout.md" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Rollout.md" new file mode 100644 index 000000000..b74c1e791 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Rollout.md" @@ -0,0 +1,140 @@ +# 多轮 Rollout + +Rollout 模块提供了用于 Agentic RLHF 训练的多轮对话 rollout 引擎。包含两种实现:用于批量 vLLM 采样的 `MultiTurnRollout` 和用于 OpenAI 兼容 API 端点的 `APIMultiTurnRollout`。 + +## Rollout 基类 + +```python +from abc import ABC, abstractmethod +from twinkle.data_format import Trajectory + +class Rollout(ABC): + + @abstractmethod + def __call__(self, trajectories: List[Trajectory], **kwargs) -> List[Trajectory]: + raise NotImplementedError() +``` + +所有 rollout 接受轨迹列表并返回相同数量的轨迹,附带额外字段(`messages`、`turns`、`stop_reason`、`truncated`)。 + +## MultiTurnRollout + +批量多轮 rollout 引擎,使用 vLLM 采样器进行生成。每轮中所有活跃轨迹通过单次批量采样调用并行处理,最大化吞吐量。 + +### 每轮循环 + +1. 将每个轨迹编码为带生成提示的 `InputFeature` +2. 批量调用 `sampler.sample(active_pifs)` —— 所有活跃轨迹并行 +3. 检查终止条件:`stop_reason == 'length'`、无工具调用、或达到最大轮次 +4. 通过 `ToolManager` 分发工具调用,追加工具响应 +5. 计算桥接 token(工具轮次 + 生成提示),设置 `labels = -100` +6. 重复直到所有轨迹完成 + +```python +from twinkle_agentic.rollout.multi_turn import MultiTurnRollout +from twinkle_agentic.tools.tool_manager import ToolManager +from twinkle.data_format.sampling import SamplingParams + +rollout = MultiTurnRollout( + sampler=vllm_sampler, + template=template, + tool_manager=tool_manager, + sampling_params=SamplingParams(temperature=0.7, max_tokens=4096), + max_turns=6, + max_trajectory_tokens=8192, + trace_dir='rollout_traces/', +) + +# 运行 rollout +results = rollout(trajectories) +``` + +### 参数 + +| 参数 | 类型 | 说明 | +|------|------|------| +| `sampler` | Sampler | 用于批量生成的 vLLM 采样器实例。 | +| `template` | `Template` | 用于编码/解码的聊天模板。 | +| `tool_manager` | `ToolManager` | 工具分发器。也可以按调用传入。 | +| `sampling_params` | `SamplingParams` | 默认采样参数。 | +| `max_turns` | `int` | 每个轨迹的最大轮次(默认:6)。 | +| `max_trajectory_tokens` | `int` | 最大总 token 长度;超出则截断轨迹。 | +| `trace_dir` | `str` | 每轨迹 JSON 跟踪文件的目录。 | +| `trace_callback` | `Callable` | 决定是否存储轨迹跟踪。 | +| `success_callback` | `Callable` | 决定文件名前缀(`ok-` 或 `fail-`)。 | + +### 输出字段 + +每个输出轨迹字典包含: + +| 字段 | 类型 | 说明 | +|------|------|------| +| `messages` | `List[Dict]` | 包含工具轮次的完整对话。 | +| `input_ids` | `List[int]` | 完整序列的 token ID。 | +| `labels` | `List[int]` | 训练标签(非可训练 token 为 `-100`)。 | +| `turns` | `int` | 执行的轮次数。 | +| `stop_reason` | `str` | `'stop'` / `'length'` | +| `truncated` | `bool` | 轨迹是否被截断。 | +| `logprobs` | `List` | 每 token 的对数概率(如有)。 | + +### Ray 远程支持 + +`MultiTurnRollout` 使用 `@remote_class()` 装饰器,支持作为 Ray actor 透明部署: + +```python +# rollout 可以作为 Ray 远程 actor 运行 +rollout_actor = MultiTurnRollout.remote(sampler=sampler, template=template, ...) +results = ray.get(rollout_actor.__call__.remote(trajectories)) +``` + +## APIMultiTurnRollout + +通过 OpenAI 兼容 chat-completions API 进行多轮 rollout。每个轨迹在线程池中独立运行,实现网络并发。 + +```python +from twinkle_agentic.rollout.api_multi_turn import APIMultiTurnRollout +from twinkle_agentic.protocol.openai import OpenAI + +api = OpenAI(model='qwen3.5-32b', base_url='http://localhost:8000/v1') + +rollout = APIMultiTurnRollout( + api=api, + tool_manager=tool_manager, + sampling_params=SamplingParams(temperature=0.7), + max_turns=6, + concurrency=8, + trace_dir='api_traces/', +) + +results = rollout(trajectories) +``` + +### 参数 + +| 参数 | 类型 | 说明 | +|------|------|------| +| `api` | `OpenAI` | OpenAI 兼容 API 客户端。 | +| `tool_manager` | `ToolManager` | 工具分发器(单个或按轨迹的列表)。 | +| `sampling_params` | `SamplingParams` | 默认采样参数。 | +| `max_turns` | `int` | 每轨迹最大轮次(默认:6)。 | +| `concurrency` | `int` | 并行 API 调用的线程池大小(默认:8)。 | +| `extra_body` | `Dict` | API 请求中附加的额外字段。 | +| `trace_dir` | `str` | 跟踪文件目录。 | + +### 停止原因 + +| 原因 | 说明 | +|------|------| +| `stop` | 助手回复未包含工具调用(自然结束)。 | +| `length` | API 返回 `finish_reason='length'`(token 限制)。 | +| `max_turns` | 达到 `max_turns` 限制。 | +| `api_error` | API 调用或工具执行抛出异常。 | + +## 选择建议 + +| 特性 | MultiTurnRollout | APIMultiTurnRollout | +|------|-----------------|---------------------| +| **后端** | vLLM 采样器(本地 GPU) | OpenAI 兼容 API | +| **训练集成** | 生成 `input_ids` / `labels` 用于 GRPO | 仅消息(用于数据收集) | +| **批处理** | GPU 级别批量并行 | 网络级别线程并发 | +| **用例** | 在线 RLHF 训练循环 | 离线数据生成 / 评估 | diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/Tools.md" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Tools.md" new file mode 100644 index 000000000..122b75a14 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Tools.md" @@ -0,0 +1,119 @@ +# 工具与 ToolManager + +Tools 模块提供了抽象工具接口和中央工具分发器(`ToolManager`),用于 Agentic 多轮 rollout。工具遵循 OpenAI function-calling schema,与 LLM 工具调用能力无缝集成。 + +## Tool 基类 + +```python +from abc import ABC, abstractmethod +from twinkle.data_format import Tool as ToolInfo + +class Tool(ABC): + + @abstractmethod + def __call__(self, tool_name: str, arguments: Dict[str, Any]) -> str: + """执行工具并返回字符串结果。""" + raise NotImplementedError + + @abstractmethod + def tool_info(self) -> ToolInfo: + """返回 OpenAI 兼容的工具 schema。""" + raise NotImplementedError +``` + +### 实现自定义工具 + +```python +from twinkle_agentic.tools.base import Tool + +class SearchTool(Tool): + + def __call__(self, tool_name: str, arguments: dict) -> str: + query = arguments.get('query', '') + # 执行搜索逻辑 + return f'搜索结果:{query}' + + def tool_info(self): + return { + 'type': 'function', + 'function': { + 'name': 'search', + 'description': '搜索网络信息。', + 'parameters': { + 'type': 'object', + 'properties': { + 'query': { + 'type': 'string', + 'description': '搜索查询。', + }, + }, + 'required': ['query'], + }, + }, + } +``` + +## ToolManager + +`ToolManager` 是工具的注册中心和分发器。它解析 LLM 结构化输出中的工具调用,并路由到正确的工具实现。 + +```python +from twinkle_agentic.tools.tool_manager import ToolManager + +# 通过 Tool 实例列表初始化 +manager = ToolManager([search_tool, calculator_tool]) + +# 或通过字典初始化 +manager = ToolManager({'search': search_tool, 'calc': calculator_tool}) + +# 或动态注册 +manager = ToolManager() +manager.register(search_tool) +manager.register(calculator_tool) +``` + +### 核心方法 + +| 方法 | 说明 | +|------|------| +| `register(tool)` | 注册工具(名称从 `tool_info()` 提取)。 | +| `unregister(name)` | 按名称移除工具。 | +| `names()` | 列出所有已注册的工具名称。 | +| `copy()` | 创建管理器的浅拷贝。 | +| `tool_infos()` | 返回所有工具 schema 列表(用于 API 请求)。 | +| `__call__(tool_call)` | 分发工具调用并返回结果字符串。 | + +### 分发工具调用 + +`ToolManager` 接受 OpenAI 格式的工具调用字典: + +```python +tool_call = { + 'id': 'call_1', + 'type': 'function', + 'function': { + 'name': 'search', + 'arguments': '{"query": "Python 教程"}', + }, +} + +result = manager(tool_call) +# result: '搜索结果:Python 教程' +``` + +**错误处理:** 如果工具名未知、参数是无效 JSON 或工具抛出异常,`ToolManager` 返回描述性错误字符串而不是抛出异常——这保证了 rollout 循环的持续运行。 + +### 与 Rollout 集成 + +```python +from twinkle_agentic.rollout.multi_turn import MultiTurnRollout + +rollout = MultiTurnRollout( + sampler=sampler, + template=template, + tool_manager=manager, # 传入工具管理器 + max_turns=6, +) +``` + +Rollout 引擎对模型生成的每个工具调用执行 `manager(tool_call)`,并将结果作为 `{'role': 'tool', 'content': result}` 消息追加。 diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/index.rst" new file mode 100644 index 000000000..802034366 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/index.rst" @@ -0,0 +1,11 @@ +Agentic +=============== +.. toctree:: + :maxdepth: 1 + + Preprocessor.md + Protocol.md + Rollout.md + Tools.md + Envs.md + Multi-Turn-Tool-Usage.md diff --git "a/docs/source_zh/\347\273\204\344\273\266/CLI/CLI.md" "b/docs/source_zh/\347\273\204\344\273\266/CLI/CLI.md" new file mode 100644 index 000000000..e21c5ea44 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/CLI/CLI.md" @@ -0,0 +1,134 @@ +# CLI 命令行配置 + +CLI 模块为 Twinkle 训练脚本提供统一的配置系统。它将多种配置来源(环境变量、`.env` 文件、YAML 配置、命令行参数)合并到一个带类型的 `Args` 数据类中。 + +## 配置优先级 + +配置按以下顺序应用(后者覆盖前者): + +1. **数据类默认值** — 开箱即用 +2. **`.env` 文件** — 项目本地配置 +3. **环境变量** — `TWINKLE_` 前缀或裸键名 +4. **YAML 配置文件** — `--config path/to/config.yaml` +5. **命令行参数** — `--key value`(最高优先级) + +所有键名不区分大小写,横杠和下划线等价。 + +## 快速开始 + +```python +from twinkle.cli import CLI + +args = CLI.from_args() + +# 访问类型化的参数组 +print(args.model.model_id) +print(args.training.max_steps) +print(args.optimizer.learning_rate) + +# 或获取字典用于组件构造 +model_kwargs = args.get_model_args() +optimizer_kwargs = args.get_optimizer_args() +``` + +## 参数组 + +| 分组 | 类名 | 关键参数 | +|:-----|:-----|:---------| +| model | `ModelArgs` | `model_id`, `mixed_precision`, `strategy`, `gradient_checkpointing` | +| lora | `LoraArgs` | `use_lora`, `lora_r`, `lora_alpha`, `lora_target_modules` | +| dataset | `DatasetArgs` | `dataset_id`, `subset_name`, `split`, `streaming` | +| template | `TemplateArgs` | `template_cls`, `max_length`, `truncation_strategy`, `enable_thinking` | +| training | `TrainingArgs` | `max_steps`, `batch_size`, `micro_batch_size`, `output_dir`, `save_steps` | +| optimizer | `OptimizerArgs` | `optimizer_cls`, `learning_rate`, `weight_decay`, `max_grad_norm` | +| scheduler | `SchedulerArgs` | `scheduler_cls`, `num_warmup_steps`, `t_max` | +| loss | `LossArgs` | `loss_cls`, `epsilon`, `beta`, `sft_weight` | +| sampler | `SamplerArgs` | `sampler_type`, `gpu_memory_utilization`, `tensor_parallel_size` | +| sampling | `SamplingArgs` | `max_tokens`, `temperature`, `top_k`, `top_p`, `num_samples` | +| infra | `InfraArgs` | `mode`, `nproc_per_node`, `model_gpus`, `sampler_gpus`, `dp_size` | +| server | `ServerArgs` | `config`, `host`, `port`, `ray_namespace` | +| rl | `RLArgs` | `num_generations`, `advantage_type`, `reward_fns` | +| checkpoint | `CheckpointArgs` | `save_optimizer`, `merge_and_sync`, `platform` | + +## YAML 配置示例 + +```yaml +# config.yaml +model_id: ms://Qwen/Qwen3.5-4B +mixed_precision: bf16 +strategy: accelerate + +use_lora: true +lora_r: 16 +lora_alpha: 32 + +dataset_id: ms://swift/self-cognition +max_length: 4096 + +batch_size: 8 +micro_batch_size: 2 +max_steps: 200 +learning_rate: 1e-5 + +mode: ray +nproc_per_node: 8 +model_gpus: 4 +sampler_gpus: 4 +``` + +## 命令行用法 + +```bash +# 使用 YAML 配置 +python train.py --config config.yaml + +# 覆盖特定值 +python train.py --config config.yaml --learning_rate 5e-6 --max_steps 500 + +# 布尔标志 +python train.py --use_lora --no_gradient_checkpointing + +# 无配置文件(全部从命令行指定) +python train.py --model_id ms://Qwen/Qwen3.5-4B --batch_size 4 +``` + +## 环境变量 + +```bash +# TWINKLE_ 前缀 +export TWINKLE_MODEL_ID=ms://Qwen/Qwen3.5-4B +export TWINKLE_LEARNING_RATE=1e-5 + +# 或裸键名(当能识别时) +export MODEL_ID=ms://Qwen/Qwen3.5-4B +``` + +## 字段别名 + +部分字段支持别名: + +- `learning_rate` ↔ `lr` +- `nproc_per_node` ↔ `num_gpus` +- `max_tokens` ↔ `max_new_tokens` +- `use_megatron=true` → `strategy=native_fsdp` + +## 自定义配置源 + +你可以通过自定义配置源扩展 CLI: + +```python +from twinkle.cli.cli import ConfigSource, Args, ConfigResolver + +class RemoteConfigSource(ConfigSource): + def __init__(self, url: str): + self.url = url + + def load(self) -> dict: + import requests + return requests.get(self.url).json() + +# 应用自定义配置源 +args = Args() +resolver = ConfigResolver(args) +resolver.apply(RemoteConfigSource('http://config-server/my-config').load()) +``` diff --git a/docs/source_en/Components/Gym/index.rst "b/docs/source_zh/\347\273\204\344\273\266/CLI/index.rst" similarity index 76% rename from docs/source_en/Components/Gym/index.rst rename to "docs/source_zh/\347\273\204\344\273\266/CLI/index.rst" index 85d941b97..cf59fa766 100644 --- a/docs/source_en/Components/Gym/index.rst +++ "b/docs/source_zh/\347\273\204\344\273\266/CLI/index.rst" @@ -1,6 +1,6 @@ -Gym +CLI =============== .. toctree:: :maxdepth: 1 - Gym.md + CLI.md diff --git "a/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" "b/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" deleted file mode 100644 index 63dc87aa7..000000000 --- "a/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" +++ /dev/null @@ -1,26 +0,0 @@ -# Gym - -Gym 组件为 Twinkle 中的强化学习环境提供接口。 - -```python -from twinkle.gym import Gym - -class CustomGym(Gym): - - def step(self, trajectories, **kwargs): - """ - 执行一个 RL 步骤:评估轨迹并返回奖励。 - - Args: - trajectories: 模型生成的待评估轨迹 - **kwargs: 额外参数 - - Returns: - 每个轨迹的奖励值 - """ - ... -``` - -Gym 抽象允许你插入自定义 RL 环境与训练循环交互。它将奖励计算和环境交互与核心训练逻辑解耦。 - -> Gym 通常用于在线策略 RL 训练中,环境需要对模型生成的输出提供反馈。 diff --git "a/docs/source_zh/\347\273\204\344\273\266/TUI/Auto-Research.md" "b/docs/source_zh/\347\273\204\344\273\266/TUI/Auto-Research.md" new file mode 100644 index 000000000..9624f5ae7 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/TUI/Auto-Research.md" @@ -0,0 +1,313 @@ +# Auto-Research (TUI) + +Twinkle TUI 是一个基于终端的智能训练助手,支持通过**自然语言控制、监控和调试 ML 训练**。它将聊天驱动的 AI 代理与实时指标可视化、日志流、以及自动化健康监控器相结合,能够自主检测并修复训练故障。 + +## 架构概览 + +``` +┌──────────────────────────────────────────────────────────┐ +│ TwinkleTUI (Textual 应用) │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ StatusBar: 状态 / run_id / 模型 / step / 进度条 │ │ +│ ├──────────────────────┬───────────────────────────────┤ │ +│ │ MetricsPanel │ LogPanel │ │ +│ │ (ASCII 图表) │ (滚动日志) │ │ +│ ├──────────────────────┤ │ │ +│ │ ChatPanel │ │ │ +│ │ (用户 <-> 代理) │ │ │ +│ └──────────────────────┴───────────────────────────────┘ │ +│ │ +│ 后台服务: │ +│ AgentLoop ─── LLM 工具调用循环 │ +│ TrainingMonitor ─── 定期健康检查与自动修复 │ +│ MetricsPoller ─── 增量指标读取 │ +│ LogsPoller ─── 增量日志尾读 │ +│ SkillsLoader ─── 异步插件加载 │ +└──────────────────────────────────────────────────────────┘ +``` + +## 安装与启动 + +TUI 是 `twinkle-client` 包的一部分: + +```bash +pip install twinkle-client +``` + +### 命令行用法 + +```bash +# 基本启动(使用默认本地 Ollama 端点) +twinkle-tui + +# 指定 LLM 后端 +twinkle-tui --llm-base-url http://localhost:11434/v1 --llm-model qwen3.5 + +# 连接到已有训练运行 +twinkle-tui --run-id my-grpo-run + +# 使用远程 API(如 OpenAI 兼容接口) +twinkle-tui --llm-base-url https://api.example.com/v1 --llm-api-key sk-xxx --llm-model gpt-4o + +# 启用调试日志 +twinkle-tui --verbose +``` + +也可作为 Python 模块运行: + +```bash +python -m twinkle_client.tui +``` + +### CLI 参数 + +| 参数 | 环境变量 | 默认值 | 说明 | +|------|---------|--------|------| +| `--run-id`, `-r` | `TWINKLE_TUI_RUN_ID` | None | 连接到已有训练运行 | +| `--llm-base-url` | `TWINKLE_LLM_BASE_URL` | `http://localhost:11434/v1` | LLM API 基础 URL | +| `--llm-model` | `TWINKLE_LLM_MODEL` | `qwen3.5` | LLM 模型名称 | +| `--llm-api-key` | `TWINKLE_LLM_API_KEY` | `not-needed` | LLM API 密钥 | +| `--verbose`, `-v` | `TWINKLE_TUI_VERBOSE` | `False` | 启用 DEBUG 日志 | +| `--version`, `-V` | — | — | 显示版本并退出 | + +### 快捷键 + +| 按键 | 操作 | +|------|------| +| `q` | 退出 | +| `Ctrl+P` | 切换指标面板 | +| `Ctrl+L` | 清空日志 | + +## 聊天代理 + +TUI 的核心是一个 **LLM 驱动的工具调用代理**(`AgentLoop`),通过 OpenAI 兼容 API 处理自然语言命令。代理维护对话历史并自动修剪(保留最近 50 条消息),每次交互最多支持 10 轮工具调用。 + +### 你可以这样说 + +**训练生命周期:** +- *"列出我的训练运行"* +- *"用 Qwen3.5-4B 在 gsm8k 上启动一个新的 GRPO 训练"* +- *"暂停当前运行"* +- *"恢复训练"* +- *"停止训练"* + +**服务器管理:** +- *"启动服务器,使用 Qwen3.5-4B 和一个 2 卡的 Qwen3.5-72B 采样器"* +- *"关闭服务器"* +- *"有多少 GPU 可用?"* + +**监控与分析:** +- *"训练进展如何?"* +- *"显示 reward 相关的指标"* +- *"放大到 step 100-200"* +- *"重置图表视图"* + +**搜索:** +- *"搜索数学数据集"* +- *"在 ModelScope 上查找 Qwen 模型"* + +### 可用工具 + +代理内置 13 个工具: + +| 工具 | 说明 | +|------|------| +| `list_training_runs` | 列出所有训练运行 | +| `get_training_status` | 获取详细状态和最近指标 | +| `start_server` | 启动 Ray 集群 + Twinkle Server(幂等) | +| `shutdown_server` | 关闭服务器并释放 GPU 资源 | +| `start_training` | 创建并启动新的训练运行 | +| `select_run` | 切换监控到另一个运行 | +| `pause_training` | 暂停训练(SIGKILL,服务器保留状态) | +| `resume_training` | 通过重新启动客户端脚本恢复训练 | +| `stop_training` | 停止训练(SIGTERM,保存检查点) | +| `update_script` | 更新训练脚本(带版本归档) | +| `list_supported_models` | 查询服务器支持的模型 | +| `search_datasets` | 在 ModelScope 搜索数据集 | +| `search_models` | 在 ModelScope 搜索模型 | +| `zoom_metrics` | 调整指标图表视图范围 | +| `select_metrics` | 选择显示哪些指标(最多 4 个) | +| `get_cluster_info` | 获取 GPU/集群资源信息 | + +### 服务器启动 + +`start_server` 工具自动化一个多步骤流程: + +1. **GPU 检测** — `nvidia-smi` 硬件扫描 +2. **GPU 分配** — 在训练模型和采样器之间分配 GPU +3. **配置生成** — 自动创建 `server_config.yaml` +4. **Ray 集群启动** — 多节点 GPU 分区,隔离 `CUDA_VISIBLE_DEVICES` +5. **服务器启动** — 作为后台进程启动 Twinkle Server +6. **健康检查** — 轮询 `/api/v1/healthz` 直到就绪 + +支持多模型拓扑:1 个训练模型 + N 个采样器/教师模型。 + +### Skills 系统 + +TUI 支持从三个来源加载可扩展的技能插件: + +1. **内置技能** — 包含在 `twinkle_client/skills/bundled/` 中 +2. **用户本地技能** — `~/.cache/twinkle/tui/skills/local/` +3. **社区技能** — 从 ModelScope 获取(尽力而为,10 秒超时) + +技能在启动后异步加载并注入代理的系统提示词中。代理在技能加载完成前即可使用。 + +## 训练监控器(自动修复) + +`TrainingMonitor` 是一个后台服务,每 **30 秒**运行一次,收集当前训练运行的所有可用信号,并提交给 LLM 进行分析。 + +### 收集的信号 + +- **进程状态**:alive / dead / unknown +- **output.log 尾部**:最后 1500 个字符(优先提取 traceback) +- **指标**:最近条目 + 前半段 vs 后半段趋势分析 +- **停滞时长**:自最后一次产生指标以来的秒数 +- **当前 train.py**:完整脚本源码(用于精确修复) + +### 决策框架 + +LLM 将每次检查分类为三种操作之一: + +| 决策 | 触发条件 | 执行动作 | +|------|---------|---------| +| **LGTM** | 训练正常推进 | 无操作 | +| **WARNING** | Loss 平台期、reward hacking、KL 爆炸等 | 向用户报告观察结果 | +| **FIX** | 脚本崩溃、进程死亡并有 traceback | 自动修复并重启 | + +### 自动修复流程 + +当需要 FIX 时: + +1. LLM 输出诊断 + 完整修复脚本 +2. 监控器将旧 `train.py` 归档为 `train_v{N}.py` +3. 将修复脚本写为新的 `train.py` +4. 通过 `resume_training` 重新启动训练 +5. 重置停滞追踪 + +安全保障: +- 每个运行最多 **3 次自动修复尝试**(防止无限重试循环) +- 修复尝试按 `run_id` 追踪 +- 快照去重避免对未变化状态的重复分析 + +## 基于文件的连接层 + +TUI 通过本地文件系统与训练进程通信: + +``` +~/.cache/twinkle/{run_id}/ +├── meta.json — 运行元数据(model_id、config、status、pid) +├── metrics.jsonl — 每步一个 JSON 对象(增量) +├── output.log — 训练的 stdout+stderr 合并输出 +├── train.py — 当前活动训练脚本 +└── train_v{N}.py — 归档的历史脚本版本 +``` + +### 训练控制模型 + +在 Server 模式下,Twinkle Server 将所有模型/优化器状态保留在 GPU 内存中: + +- **暂停** = 杀死客户端进程 (SIGKILL) — 服务器状态保留 +- **恢复** = 重新启动客户端脚本 — 无缝继续训练 +- **停止** = SIGTERM — 触发检查点保存后退出 +- **关闭服务器** = 释放 GPU 资源,**销毁**模型状态 + +## TrainingRuntime(脚本集成) + +训练脚本使用 `TrainingRuntime` 与 TUI 集成: + +```python +from twinkle_client.tui.runtime import TrainingRuntime + +rt = TrainingRuntime(run_id='my-grpo-run') +rt.start(model_id='Qwen/Qwen3.5-4B', config={'lr': 1e-5}) +rt.register_graceful_shutdown(model, dataloader) + +for step, batch in enumerate(dataloader): + # ... 训练逻辑 ... + rt.log_metrics(step=step, loss=loss, reward=reward, grad_norm=gn, lr=lr) + rt.log(f'Completed step {step}, loss={loss:.4f}') + +rt.finish() +``` + +### 核心方法 + +| 方法 | 说明 | +|------|------| +| `start(model_id, config, script_path)` | 初始化运行目录和元数据 | +| `log_metrics(**kwargs)` | 向 `metrics.jsonl` 写入指标条目 | +| `log(message)` | 打印日志消息(被捕获为 `output.log`) | +| `get_resume_info()` | 获取 `last_step` 用于从检查点恢复 | +| `finish(status)` | 标记训练完成,关闭文件 | +| `register_graceful_shutdown(model, dataloader)` | 注册 SIGTERM 处理器以保存检查点 | + +### 断点续训支持 + +`TrainingRuntime` 自动将训练进度保存到 `meta.json`(每 5 秒节流写入一次)。脚本可以使用 `get_resume_info()` 从上次保存的步数恢复: + +```python +rt = TrainingRuntime(run_id='my-run') +resume = rt.get_resume_info() +global_step = resume['last_step'] + +if global_step > 0: + dataloader.skip_consumed_samples(global_step * BATCH_SIZE) + print(f'从 step {global_step} 恢复训练') +``` + +### 优雅关停 + +调用 `register_graceful_shutdown()` 后,会安装一个 SIGTERM 处理器: + +1. 保存模型检查点(LoRA 权重 + 优化器状态) +2. 保存数据加载器位置(`consumed_train_samples`) +3. 记录检查点路径 +4. 标记训练为 `stopped` 并退出 + +## UI 面板 + +### StatusBar(状态栏) + +显示在屏幕顶部的当前训练状态: + +- 训练状态图标(🚀 训练中 / ⏸ 已暂停 / ✅ 已完成 / ❌ 错误) +- Run ID +- 模型名称 +- 当前步数 +- 百分比进度条 + +### MetricsPanel(指标面板) + +使用 `plotext` 渲染的实时 ASCII 图表: + +- 同时绘制最多 4 个指标 +- 支持缩放(按步数范围和 y 轴范围) +- 未选择时自动显示前 3 个可用指标 +- 提示栏显示可通过代理切换的隐藏指标 +- 保留最多 2000 个数据点 + +### LogPanel(日志面板) + +滚动日志查看器: + +- 自动剥离 ANSI 转义序列 +- 硬换行长行以防止溢出 +- 处理进度条的 `\r` 回车符 +- 保留最后 500 行 + +### ChatPanel(聊天面板) + +交互式聊天界面: + +- 用户输入,流式代理响应 +- 节流令牌刷新(80ms)确保平滑显示 +- 工具调用检测时流重置 +- 支持 Rich 标记格式 + +## 日志记录 + +所有 TUI 日志写入 `./tui.log`(当前工作目录): + +- 5MB 时轮转,保留 3 个备份 +- **无控制台输出** — 避免破坏 Textual 的 alt-screen 缓冲区 +- 使用 `--verbose` 启用 DEBUG 级别日志 diff --git "a/docs/source_zh/\347\273\204\344\273\266/TUI/SkillProvider\346\212\200\350\203\275\347\263\273\347\273\237.md" "b/docs/source_zh/\347\273\204\344\273\266/TUI/SkillProvider\346\212\200\350\203\275\347\263\273\347\273\237.md" new file mode 100644 index 000000000..11637331e --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/TUI/SkillProvider\346\212\200\350\203\275\347\263\273\347\273\237.md" @@ -0,0 +1,71 @@ +# SkillProvider 技能系统 + +技能系统允许 Twinkle 的 TUI 智能体从外部来源(Git 仓库、API、本地文件)动态加载专业知识,并注入到 LLM 的系统提示词中。 + +## 架构 + +| 类 | 角色 | +|----|------| +| **Skill** | 持有单个技能名称、内容和来源的数据类 | +| **SkillProvider** | 从数据源获取技能的抽象基类 | +| **SkillManager** | 编排多个 Provider,聚合技能用于提示词注入 | + +## Skill 数据类 + +```python +@dataclasses.dataclass +class Skill: + name: str # 简短标识符(通常为文件名去除扩展名) + content: str # 完整的 Markdown 内容 + source: str # Provider 名称 + 相对路径,用于可追溯性 +``` + +## 创建自定义 Provider + +继承 `SkillProvider` 并实现 `name` 和 `fetch()`: + +```python +from twinkle_client.skills.base import SkillProvider + +class MySkillProvider(SkillProvider): + + @property + def name(self) -> str: + return 'my-skills' + + async def fetch(self) -> None: + # 将技能文件下载/克隆到 self.cache_dir + # 例如:git clone、API 下载、文件拷贝 + ... +``` + +默认的 `load_skills()` 会扫描 `self.cache_dir` 中的 `.md` 文件(跳过 README、LICENSE 等),返回 `Skill` 对象。 + +## SkillManager + +```python +from twinkle_client.skills.manager import SkillManager + +manager = SkillManager() +manager.register(my_provider) +manager.register(another_provider) + +# 拉取并加载所有技能 +skills = await manager.load_all() + +# 格式化为 LLM 系统提示词注入内容 +prompt_section = manager.format_for_prompt() +``` + +### 关键方法 + +| 方法 | 说明 | +|------|------| +| `register(provider)` | 添加技能 Provider | +| `load_all()` | 从所有 Provider 拉取并加载 | +| `format_for_prompt()` | 将技能渲染为系统提示词格式 | +| `get_skill_names()` | 列出已加载技能名称 | + +## 缓存目录 + +默认缓存在 `~/.cache/twinkle/tui/skills//`。可通过向 Provider 构造函数传入 `cache_dir` 参数覆盖。 diff --git "a/docs/source_zh/\347\273\204\344\273\266/TUI/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/TUI/index.rst" new file mode 100644 index 000000000..32ec8dc40 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/TUI/index.rst" @@ -0,0 +1,7 @@ +TUI +=============== +.. toctree:: + :maxdepth: 1 + + Auto-Research.md + SkillProvider技能系统.md diff --git "a/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/GRPOProcessor.md" "b/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/GRPOProcessor.md" deleted file mode 100644 index afb8f0948..000000000 --- "a/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/GRPOProcessor.md" +++ /dev/null @@ -1,19 +0,0 @@ -# GRPOLossProcessor - -GRPOLossProcessor 是专为 GRPO 强化学习训练设计的任务处理器包装器。它在 InputProcessor 基础上扩展了 GRPO 特有的数据准备功能。 - -```python -from twinkle.processor import GRPOLossProcessor - -processor = GRPOLossProcessor( - device_mesh=..., - padding_free=False, - framework='transformers', -) - -model.set_processor(processor) -``` - -GRPOLossProcessor 包装了基础 `InputProcessor`,并添加了 GRPO 特有字段的处理,如优势值、旧对数概率和参考对数概率,这些是 GRPO 损失函数所需要的。 - -> 对于标准 SFT 任务,直接使用 `InputProcessor`。当训练循环涉及 GRPO 或其变体时,使用 `GRPOLossProcessor`。 diff --git "a/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/index.rst" index 1eb839f0e..a2c88eaf4 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/index.rst" +++ "b/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/index.rst" @@ -4,4 +4,3 @@ :maxdepth: 1 InputProcessor.md - GRPOProcessor.md diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/EmbeddingMetric.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/EmbeddingMetric.md" new file mode 100644 index 000000000..ab770498c --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/EmbeddingMetric.md" @@ -0,0 +1,31 @@ +# EmbeddingMetric + +`EmbeddingMetric` 跟踪对比学习(InfoNCE)训练中的嵌入质量,报告锚点-正样本余弦相似度和批内负样本相似度。 + +## 使用方法 + +```python +from twinkle.metric import EmbeddingMetric + +metric = EmbeddingMetric(device_mesh=device_mesh, process_group=process_group) + +# 训练中 +metric.accumulate(inputs, outputs) + +# 日志间隔时 +results = metric.calculate() +# results: {'pos_sim': '0.8523', 'neg_sim': '0.2134', 'loss': '0.3412', ...} +``` + +## 输出指标 + +| 指标 | 说明 | +|:-----|:-----| +| `pos_sim` | 锚点与正样本的平均余弦相似度 | +| `pos_sim_min` | 批内最小正样本相似度 | +| `pos_sim_max` | 批内最大正样本相似度 | +| `neg_sim` | 锚点与其他正样本(批内负样本)的平均相似度 | +| `loss` | 平均对比损失值 | +| `grad_norm` | 梯度范数 | + +> 此指标与 `InfonceLoss` 配合使用,适用于嵌入/检索模型训练。 diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/GRPOMetric.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/GRPOMetric.md" new file mode 100644 index 000000000..434dc17c2 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/GRPOMetric.md" @@ -0,0 +1,39 @@ +# GRPOMetric + +`GRPOMetric` 跟踪 GRPO 训练中的策略优化诊断指标,包括 KL 散度、裁剪率、熵和对数概率统计。 + +## 使用方法 + +```python +from twinkle.metric import GRPOMetric + +metric = GRPOMetric( + device_mesh=device_mesh, + process_group=process_group, + epsilon=0.2, # PPO 裁剪范围 + temperature=1.0, # 用于 logp 重缩放的采样温度 + top_k_kl=10, # 每步记录 top-K 高 KL token +) + +# 训练循环中 +metric.accumulate(inputs, outputs, old_logps=old_logps, advantages=advantages) + +# 日志间隔时 +results = metric.calculate() +``` + +## 输出指标 + +| 指标 | 说明 | +|:-----|:-----| +| `train/policy_confidence` | exp(mean_new_logp) — 越高表示模型越自信 | +| `train/mean_new_logp` | 当前策略下生成 token 的平均对数概率 | +| `train/mean_old_logp` | 参考策略下的平均对数概率 | +| `train/approx_kl` | Schulman K3 KL 估计器 | +| `train/entropy` | 平均 token 级熵 | +| `train/clip_ratio` | 被裁剪的 token 比例 | + +## 变体 + +- **`GSPOMetric`** — 序列级裁剪率(几何平均比率) +- **`CISPOMetric`** — 无条件裁剪率(不按优势符号门控) diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/index.rst" index 6e03f97cf..d5ba804be 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/index.rst" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/index.rst" @@ -3,6 +3,19 @@ .. toctree:: :maxdepth: 1 + TrainMetric.md + LossMetric.md + Accuracy.md + CompletionRewardMetric.md + DPOMetric.md + GRPOMetric.md + EmbeddingMetric.md + 构建指标.md +指标 +=============== +.. toctree:: + :maxdepth: 1 + TrainMetric.md LossMetric.md Accuracy.md diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/InfoNCELoss.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/InfoNCELoss.md" new file mode 100644 index 000000000..f8fbaa1be --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/InfoNCELoss.md" @@ -0,0 +1,68 @@ +# InfoNCE 损失 + +`InfonceLoss` 实现带批内负样本和可选跨 rank 聚合的对比学习损失,用于嵌入/检索模型训练。 + +## 使用方法 + +```python +from twinkle.loss import InfonceLoss + +loss_fn = InfonceLoss( + temperature=0.1, + use_batch=True, # 启用批内负样本 + hard_negatives=7, # 固定每样本负样本数 + mask_fake_negative=True, # 遮蔽假负样本 + fake_neg_margin=0.1, # 假负样本检测阈值 +) + +model.set_loss(loss_fn) +``` + +## 输入格式 + +每个样本按 `锚点(1) + 正样本(1) + 负样本(n)` 排列。`inputs['labels']` 是一维掩码,`1` 标记每组的起始位置。 + +``` +embeddings: [a0, p0, n0_1, n0_2, a1, p1, n1_1, n1_2, ...] +labels: [ 1, 0, 0, 0, 1, 0, 0, 0, ...] +``` + +## 参数 + +| 参数 | 类型 | 默认值 | 说明 | +|:-----|:-----|:-------|:-----| +| `temperature` | float | 0.1 | 相似度缩放因子 | +| `use_batch` | bool | True | 使用跨样本批内负样本 | +| `hard_negatives` | int | None | 固定每样本负样本数(截断/上采样)| +| `mask_fake_negative` | bool | False | 遮蔽高于 positive + margin 的 logit | +| `fake_neg_margin` | float | 0.1 | 假负样本遮蔽阈值 | +| `include_qq` | bool | False | 添加 query-query 相似度块 | +| `include_dd` | bool | False | 添加 doc-doc 相似度块 | + +## 跨 Rank 聚合 + +当 `use_batch=True` 且分布式训练激活时,嵌入会从所有 DP rank 聚合以最大化批内负样本多样性。仅本地分片保留梯度。 + +## 相似度块 + +该损失支持三种相似度块,提供全面的对比学习信号: + +- **Q→D(默认)**:Query 到所有 Document — 主要对比信号 +- **Q→Q**(`include_qq=True`):Query 到其他所有 Query — 防止 query 坍缩 +- **D→D**(`include_dd=True`):Document 到其他所有 Document — Qwen3-Embedding 风格 + +## 示例:Embedding 训练 + +```python +from twinkle.loss import InfonceLoss +from twinkle.metric import EmbeddingMetric + +# 配置 Embedding 模型 +model.set_loss(InfonceLoss(temperature=0.05, use_batch=True, include_qq=True)) +model.set_metric(EmbeddingMetric(device_mesh=mesh, process_group=pg)) + +# 训练循环 +for batch in dataloader: + model.forward_backward(batch) + model.clip_grad_and_step() +``` diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/index.rst" index ea813f56f..0a2a890cf 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/index.rst" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/index.rst" @@ -3,6 +3,19 @@ .. toctree:: :maxdepth: 1 + CrossEntropy.md + ChunkedCrossEntropy.md + DPOLoss.md + GKDLoss.md + GRPOLoss.md + InfoNCELoss.md + MSELoss.md + 构建损失.md +损失 +=============== +.. toctree:: + :maxdepth: 1 + CrossEntropy.md ChunkedCrossEntropy.md DPOLoss.md diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraTransformersModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraTransformersModel.md" index 4017aea7e..5ae13f739 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraTransformersModel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraTransformersModel.md" @@ -30,3 +30,48 @@ class MultiLoraTransformersModel: 正因如此,用户的r必须要小于等于max_r的配置,在实际训练时仅会使用lora的部分rank参与计算。 MultiLoraTransformersModel支持`@remote_class`注解,并且支持device_mesh,这意味着它可以运行在ray的worker中。 + +## 租户生命周期 + +底层使用 `MultiLora` 管理器来处理租户 LoRA 槽位。关键 API: + +### acquire_lora + +为租户获取一个可用的 LoRA 槽位: + +```python +adapter_name = model.multi_lora.acquire_lora('tenant_a', LoraConfig(r=16, lora_alpha=32)) +``` + +- 如果所有槽位已被占用或 `config.r > max_r`,则抛出 `RuntimeError` + +### release_lora + +释放租户的 LoRA 槽位,权重重置为初始状态: + +```python +model.multi_lora.release_lora('tenant_a') +``` + +### 上下文管理器 + +使用 `adapter()` 进行作用域激活: + +```python +with model.multi_lora.adapter('tenant_a') as name: + output = model.forward(inputs) +``` + +### LoraTenant + +每个槽位以 `LoraTenant` 数据类追踪: + +```python +@dataclass +class LoraTenant: + index: int # 槽位索引 (0..max_loras-1) + adapter_name: str # 内部名称(如 "lora_0") + config: LoraConfig # 预分配配置(max_r) + tenant_adapter_name: str # 面向用户的租户名(空闲时为 None) + tenant_config: LoraConfig # 租户实际配置(空闲时为 None) +``` diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/SupportedModels.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/SupportedModels.md" new file mode 100644 index 000000000..bfbb03ea0 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/SupportedModels.md" @@ -0,0 +1,77 @@ +# 支持的模型 + +Twinkle 支持任何兼容 HuggingFace Transformers 或 Megatron-LM 的模型。以下是经过测试的模型列表。 + +## 语言模型 + +| 模型系列 | 模型 ID | 参数量 | 特性 | +|:---------|:--------|:-------|:-----| +| Qwen 3.5 | `Qwen/Qwen3.5-0.6B` ~ `Qwen/Qwen3.5-235B-A22B` | 0.6B–235B | MoE、思考模式 | +| Qwen 2.5 | `Qwen/Qwen2.5-0.5B` ~ `Qwen/Qwen2.5-72B` | 0.5B–72B | Dense | +| DeepSeek V4 | `deepseek-ai/DeepSeek-V4` | 685B MoE | 自定义 DSML 编码 | +| DeepSeek R1 | `deepseek-ai/DeepSeek-R1` | 685B MoE | 推理 | +| LLaMA 3 | `meta-llama/Llama-3.3-70B-Instruct` | 8B–70B | Dense | +| Mistral | `mistralai/Mistral-7B-v0.3` | 7B | Dense | +| Yi | `01-ai/Yi-1.5-34B` | 6B–34B | Dense | +| GLM-4 | `THUDM/glm-4-9b-chat` | 9B | Dense | +| InternLM 2.5 | `internlm/internlm2_5-7b-chat` | 7B–20B | Dense | + +## 视觉语言模型 + +| 模型系列 | 模型 ID | 特性 | +|:---------|:--------|:-----| +| Qwen 3.5 VL | `Qwen/Qwen3.5-VL-3B` ~ `Qwen/Qwen3.5-VL-72B` | 图片、视频 | +| Qwen 2.5 VL | `Qwen/Qwen2.5-VL-7B-Instruct` | 图片、视频 | +| InternVL 2.5 | `OpenGVLab/InternVL2_5-8B` | 图片 | + +## 嵌入模型 + +| 模型系列 | 模型 ID | 训练方法 | +|:---------|:--------|:---------| +| Qwen3 Embedding | `Qwen/Qwen3-Embedding-0.6B` | InfoNCE 对比学习 | +| GTE | `thenlper/gte-large-zh` | InfoNCE 对比学习 | + +## 模型加载 + +```python +from twinkle.model import TransformersModel + +# 从 ModelScope 加载(ms:// 前缀) +model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + +# 从 HuggingFace 加载(hf:// 前缀) +model = TransformersModel(model_id='hf://meta-llama/Llama-3.3-70B-Instruct') + +# 本地路径 +model = TransformersModel(model_id='/path/to/model') +``` + +## 框架支持 + +| 框架 | 类名 | 适用场景 | +|:-----|:-----|:---------| +| Transformers | `TransformersModel` | 通用训练(SFT、RLHF、DPO)| +| Transformers + Multi-LoRA | `MultiLoraTransformersModel` | 多租户训练 | +| Megatron-LM | `MegatronModel` | 大规模分布式预训练 | +| Megatron + Multi-LoRA | `MultiLoraMegatronModel` | 大规模多租户 | + +## 精度支持 + +| 模式 | 说明 | +|:-----|:-----| +| `bf16` | BFloat16 混合精度(推荐 A100/H100)| +| `fp16` | Float16 混合精度(适用于旧 GPU)| +| `fp8` | FP8 精度(H100 + Transformer Engine)| +| `no` | 全精度(仅用于调试)| + +## 并行策略 + +| 策略 | 配置键 | 说明 | +|:-----|:-------|:-----| +| FSDP | `strategy=accelerate` | Accelerate 管理的 FSDP(默认)| +| 原生 FSDP | `strategy=native_fsdp` | PyTorch 原生 FSDP | +| 张量并行 | `tp_size` | 跨 GPU 切分层 | +| 流水线并行 | `pp_size` | 切分模型阶段 | +| 数据并行 | `dp_size` | 复制模型,切分数据 | +| 序列并行 | `sequence_parallel` | 切分长序列 | +| 专家并行 | `ep_size` | MoE 专家分布 | diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/index.rst" index 713ea35c6..d20155bd7 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/index.rst" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/index.rst" @@ -8,3 +8,14 @@ MultiLoraTransformersModel.md MegatronModel.md MultiLoraMegatronModel.md + SupportedModels.md +模型 +=============== +.. toctree:: + :maxdepth: 1 + + TwinkleModel.md + TransformersModel.md + MultiLoraTransformersModel.md + MegatronModel.md + MultiLoraMegatronModel.md diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/DeepSeekV4Template.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/DeepSeekV4Template.md" new file mode 100644 index 000000000..053b51051 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/DeepSeekV4Template.md" @@ -0,0 +1,30 @@ +# DeepSeek-V4 模板 + +`DeepseekV4Template` 为 DeepSeek V4 提供原生支持,包括其独特的思考模式、工具调用协议和多 token 特殊标记。 + +## 使用方法 + +```python +from twinkle.template import DeepseekV4Template + +template = DeepseekV4Template( + model_id='deepseek-ai/DeepSeek-V4', + enable_thinking=True, +) +``` + +## 特性 + +- **自定义 tokenizer 包装**:用 DeepSeek V4 的编码协议覆盖 `apply_chat_template` +- **思考模式**:支持 `thinking` / `chat` 模式切换 +- **工具调用**:原生 DSML 工具调用编码 +- **多 token EOS**:处理 DeepSeek V4 的多字符特殊标记 + +## 与基础模板的区别 + +| 特性 | 基础模板 | DeepseekV4Template | +|:-----|:---------|:-------------------| +| Chat 模板 | HuggingFace 原生 | 自定义 DSML 编码 | +| 思考模式 | `` 标签 | 原生思考模式开关 | +| 工具调用 | Hermes/Qwen 格式 | DSML 工具块 | +| EOS 处理 | 单 token | 多 token 特殊标记 | diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md" index 364275b64..c3f5918e1 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md" @@ -9,7 +9,7 @@ class Template: model_id: str, use_chat_template: bool = True, max_length: Optional[int] = 8192, - truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise', + truncation_strategy: Literal['raise', 'left', 'right', 'split', 'delete'] = 'raise', default_system: Optional[str] = None): ... @@ -42,7 +42,9 @@ class Template: - raise: 抛出异常。一般用于非常精确的数据集场景 - left: 移除左边的 token,使其符合 max_length - right: 移除右边的 token,使其符合 max_length - - default_system: 如果数据集没有 system,则使用默认 system + - split: 将超长样本切分为多个 max_length 的片段(不支持多模态、LazyDataset、IterablePackingDataset) + - delete: 直接丢弃超长样本 +- default_system: 如果数据集没有 system,则使用默认 system > Template 不支持使用函数来代替,因为其内部要支持的功能较多。如果需要编写新的 Template,请继承 `Template` 类。 > 一般来说,纯文本模型使用 Template 基类就足够了,在基类中我们使用了 tokenizer.apply_chat_template 来编码模型,对一般的纯文本模型是通用的。 diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/ToolCallParsers.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/ToolCallParsers.md" new file mode 100644 index 000000000..9d52be1c7 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/ToolCallParsers.md" @@ -0,0 +1,53 @@ +# 工具调用解析器 + +Twinkle 的模板系统包含模块化的工具调用解析框架,用于训练具有函数调用能力的模型。 + +## 架构 + +``` +ToolCallRegistry +├── HermesQwenParser — Hermes/Qwen 风格 ... +├── ReActParser — ReAct Thought/Action/Observation +├── ClineParser — Cline XML 工具调用 +└── VCPParser — VCP 协议 +``` + +## ToolCallParser 接口 + +```python +from twinkle.template.tools import ToolCallParser + +class ToolCallParser(ABC): + name: str = '' + + def detect(self, text: str) -> bool: + """检查文本是否包含此格式的标记""" + + def parse(self, text: str) -> List[Dict[str, Any]]: + """提取 OpenAI 格式的工具调用""" + + def clean(self, text: str) -> str: + """去除标记,返回纯内容""" +``` + +## ToolCallRegistry + +注册表自动发现解析器并路由检测: + +```python +from twinkle.template.tools import ToolCallRegistry + +# 检测补全使用了哪种格式 +parser = ToolCallRegistry.detect_first(completion_text) +if parser: + tool_calls = parser.parse(completion_text) +``` + +## 内置解析器 + +| 解析器 | 格式说明 | +|:-------|:---------| +| HermesQwenParser | `{"name": "...", "arguments": {...}}` | +| ReActParser | Thought/Action/Action Input/Observation | +| ClineParser | Cline XML 结构化参数 | +| VCPParser | Visual Code Protocol | diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/index.rst" index 9ab4c887b..840adf497 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/index.rst" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/index.rst" @@ -4,3 +4,11 @@ :maxdepth: 1 Template.md + DeepSeekV4Template.md + ToolCallParsers.md +模板 +=============== +.. toctree:: + :maxdepth: 1 + + Template.md diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" index 00ec1f308..5842d51dd 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" @@ -40,6 +40,96 @@ class DeviceMesh: 推荐使用 `from_sizes` 来构造它。 +### 参数参考 + +| 参数 | 说明 | 默认值 | +|------|------|--------| +| `world_size` | 总进程数 | 1 | +| `dp_size` | 数据并行度 | 1 | +| `fsdp_size` | 全分片数据并行度 | None | +| `tp_size` | 张量并行度 | None | +| `pp_size` | 流水线并行度 | None | +| `ulysses_size` | Ulysses 序列并行度 | None | +| `cp_size` | 上下文并行度 | None | +| `ep_size` | 专家并行度(MoE 模型)| None | +| `etp_size` | 专家张量并行度 | None | +| `ep_fsdp_size` | 每个 EP 组内的 FSDP 度 | None | +| `vpp_size` | 虚拟流水线并行度 | None | +| `device_type` | 设备类型(`cuda`、`npu` 等)| `cuda` | +| `sequence_parallel` | 启用 Megatron 风格序列并行 | False | + +我们举一个例子: + +```python +sampler_device_mesh = DeviceMesh.from_sizes(dp_size=4) +actor_device_mesh = DeviceMesh.from_sizes(dp_size=2, pp_size=2, tp_size=2) + +dataloader = DataLoader(...) +sampler = vLLMSampler(..., device_mesh=sampler_device_mesh, remote_group=...) +actor = MegatronModel(..., device_mesh=actor_device_mesh, remote_group=...) + +for data in dataloader: + sampler_output = sampler.sample(data) + input_data = [seq.new_input_feature for response in sampler_output for seq in response.sequences] + ... + model_output = actor.forward(input_data) +``` + +我们以上面的伪代码来分析数据传递情况。 + +dataloader 取出数据 -> 按照 dp_size=4 分发给 sampler -> 按照 dp_size=4 收集数据 -> 按照 dp_size=2 分发给模型 -> 按照 dp_size=2 收集输出 + +通过 DeviceMesh,可以将数据流平顺地在各个 group 和组件之间流转起来。 + +数据的分发判断由 DeviceMesh 的 `get_slice` 方法执行: + +```python +batch[device_mesh.get_slice(len(batch))] +``` + +get_slice 会根据当前 rank,计算出当前 worker 属于哪个 dp 组,并获取对应的数据。该过程发生在 DataLoader 的 DeviceMeshSampler 中,同样发生在 remote_class 的 dispatch 和 collect 中。 +# DeviceMesh/DeviceGroup + +这两个类用于表达硬件资源分配和网络拓扑,Twinkle 的数据分发和收集也依赖它们。 + +## DeviceGroup + +```python +@dataclass +class DeviceGroup: + name: str + ranks: Union[List[int], int] + device_type: str + visible_devices: Optional[str] = None # Optional: explicitly set visible devices (e.g., "8,9") + gpus_per_worker: int = 1 +``` + +- name: 资源组名 +- ranks: 占用硬件列表,如果是CPU资源仅支持int类型 +- device_type: 硬件类型,例如 GPU/CPU/NPU 等 +- visible_devices: 可见资源列表,用于希望仅使用部分 rank 的硬件的情况 +- gpus_per_worker: 每个 worker 占用多少硬件 + +如果训练 RL,开发者可以构造多个这样的组,并将对应的模型、采样器分配进入其中。 + +## DeviceMesh + +DeviceMesh 承载了组件拓扑、分布式并行信息,这个类会在组件内传递,用于数据分发和数据收集。 + +```python +@dataclass +class DeviceMesh: + ... + + @staticmethod + def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, tp_size: int = None, + pp_size: int = None, ulysses_size: int = None, cp_size: int = None, ep_size: int = None, + etp_size: int = None,vpp_size: int = None, device_type: str = 'cuda', sequence_parallel: bool = False) -> "DeviceMesh": + ... +``` + +推荐使用 `from_sizes` 来构造它。 + 我们举一个例子: ```python diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/Padding-Free\350\256\255\347\273\203.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/Padding-Free\350\256\255\347\273\203.md" new file mode 100644 index 000000000..8bd783cff --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/Padding-Free\350\256\255\347\273\203.md" @@ -0,0 +1,52 @@ +# Padding-Free 训练 + +Padding-free(也称为"打包")训练通过将多个序列拼接到一个打包批次中,消除了对 padding token 的无效计算。Twinkle 支持标准注意力和 Qwen3.5 GatedDeltaNet 线性注意力的 padding-free 训练。 + +## 工作原理 + +不同于将所有序列填充到 `max_length`,padding-free 将多个序列打包到一行中,并使用 `position_ids` 跟踪序列边界,从而避免在 padding token 上浪费算力。 + +``` +标准方式: [tok tok tok PAD PAD PAD] [tok tok PAD PAD PAD PAD] +打包方式: [tok tok tok tok tok ...] ← 无 padding 浪费 +``` + +## 使用方式 + +通过 `PackingDataset` 或 `IterablePackingDataset` 启用: + +```python +from twinkle.dataset import PackingDataset + +dataset = PackingDataset( + dataset=base_dataset, + max_length=8192, +) +``` + +数据集会自动打包序列并生成正确的 `position_ids`,在序列边界处重置。 + +## GatedDeltaNet 补丁(Qwen3.5) + +Qwen3.5 使用混合架构,融合了标准注意力和 GatedDeltaNet 线性注意力。原生 GatedDeltaNet 实现不会在打包序列边界处重置线性注意力状态。 + +`GatedDeltaNetPaddingFreePatch` 通过以下方式修复: + +1. Patch `Qwen3_5DecoderLayer.forward`,将 `cu_seq_lens_q`(累积序列长度)传递给线性注意力层 +2. Patch `Qwen3_5GatedDeltaNet.forward`,使用支持 `cu_seqlens` 的 flash-linear-attention 内核(`causal_conv1d`、`chunk_gated_delta_rule`) + +在 Qwen3.5 模型上检测到 padding-free 时,补丁会自动应用。 + +### 要求 + +- 需安装 `flash-linear-attention` 包 +- 仅适用于含 GatedDeltaNet 层的 Qwen3.5 模型 +- 启用序列并行时,会使用 `Qwen3_5GatedDeltaNetUlyssesPatch` 替代 + +## 注意力后端要求 + +| 注意力后端 | Padding-Free 支持 | +|-----------|-------------------| +| FlashAttention2 | 完全支持 | +| SDPA | 支持(不兼容序列并行) | +| Eager | 不支持 | diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/TwinkleClient\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/TwinkleClient\345\256\242\346\210\267\347\253\257.md" new file mode 100644 index 000000000..327e2e40c --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/TwinkleClient\345\256\242\346\210\267\347\253\257.md" @@ -0,0 +1,81 @@ +# TwinkleClient 客户端 + +`TwinkleClient` 是与 Twinkle REST API 交互的 Python 客户端,管理会话、训练任务和检查点。 + +## 初始化 + +```python +from twinkle_client.manager import TwinkleClient + +client = TwinkleClient( + base_url='http://localhost:8000', # 或 TWINKLE_SERVER_URL 环境变量 + api_key='your-api-key', # 或 TWINKLE_SERVER_TOKEN 环境变量 + route_prefix='/twinkle', # API 路由前缀 + session_heartbeat_interval=10, # 心跳间隔(秒) + session_metadata={'user': 'alice'}, # 可选的会话元数据 +) +``` + +初始化时客户端会: +1. 将 `base_url` 和 `api_key` 设置到共享上下文(所有客户端对象自动使用) +2. 创建服务端会话 +3. 启动后台心跳线程保持会话活跃 + +## 健康检查 + +```python +is_healthy = client.health_check() # 返回 True/False +capabilities = client.get_server_capabilities() # 支持的模型 +``` + +## 训练任务 + +```python +# 列出训练任务 +runs = client.list_training_runs(limit=20, offset=0) + +# 带分页游标列出 +runs, cursor = client.list_training_runs_with_cursor(limit=20) + +# 获取特定任务 +run = client.get_training_run(run_id='run_abc123') + +# 按基础模型查找 +qwen_runs = client.find_training_run_by_model('Qwen/Qwen3.5-4B') +``` + +## 检查点 + +```python +# 列出训练任务的检查点 +checkpoints = client.list_checkpoints(run_id='run_abc123') + +# 获取检查点路径 +parsed = client.get_checkpoint_path(run_id, checkpoint_id) +# parsed.path → 文件系统路径 +# parsed.twinkle_path → twinkle:// URI + +# 获取最新检查点(用于恢复训练) +latest_path = client.get_latest_checkpoint_path(run_id) + +# 删除检查点 +client.delete_checkpoint(run_id, checkpoint_id) +``` + +## 容量与权重信息 + +```python +# LoRA 容量 +capacity = client.get_capacity_info() +# capacity.max_loras, capacity.used_loras, capacity.free_loras + +# 权重元数据 +info = client.get_weights_info('twinkle://run_id/weights/checkpoint') +# info.base_model, info.is_lora, info.lora_rank +``` + +## 清理 + +```python +client.close() # 停止心跳线程(也通过 atexit 自动注册) +``` diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/index.rst" index 7174ce690..377098988 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/index.rst" +++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/index.rst" @@ -4,4 +4,8 @@ :maxdepth: 1 DeviceMesh和DeviceGroup.md + 专家并行.md + 序列并行.md + Padding-Free训练.md RemoteClass.md + TwinkleClient客户端.md diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/\344\270\223\345\256\266\345\271\266\350\241\214.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/\344\270\223\345\256\266\345\271\266\350\241\214.md" new file mode 100644 index 000000000..a0112249b --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/\344\270\223\345\256\266\345\271\266\350\241\214.md" @@ -0,0 +1,73 @@ +# 专家并行 (EP) + +专家并行将混合专家模型(MoE)的专家分布到多个 GPU 上,每个 rank 只持有部分专家。这降低了单卡显存占用,使大规模 MoE 模型的训练成为可能。 + +## 概览 + +| 概念 | 说明 | +|------|------| +| **ExpertParallelConfig** | 控制 EP 行为的配置数据类 | +| **apply_expert_parallel()** | 入口函数,负责分片专家并替换前向传播 | +| **shard_experts()** | 将专家均匀分配到各 EP rank | +| **patch_forward()** | 将 MoE block 的 forward 替换为带 all-to-all 通信的 EP 版本 | + +## 配置 + +```python +from twinkle.model.transformers.moe.expert_parallel import ExpertParallelConfig + +config = ExpertParallelConfig( + enabled=True, # 启用专家并行 + router_dtype='fp32', # 路由计算精度:'fp32', 'bf16', 'fp16' + keep_router_logits=True, # 在输出中保留路由 logits + ignore_shared_experts=False,# 跳过共享专家计算(如 DeepSeek) + ep_size=None, # EP 并行度(由 TransformersModel 使用) +) +``` + +## 配合 DeviceMesh 使用 + +在 `DeviceMesh.from_sizes()` 中设置 `ep_size` 即可激活 EP。框架会在模型初始化时自动调用 `apply_expert_parallel()`。 + +```python +from twinkle.utils import DeviceMesh + +# 8 卡:2 路 EP × 4 路数据并行 +device_mesh = DeviceMesh.from_sizes( + world_size=8, + dp_size=4, + ep_size=2, +) +``` + +EP + FSDP 组合分片: + +```python +# 8 卡:2 路 EP,每个 EP 组内 2 路 FSDP +device_mesh = DeviceMesh.from_sizes( + world_size=8, + dp_size=2, + ep_size=2, + ep_fsdp_size=2, +) +``` + +## 通信模式 + +EP 前向传播遵循 4 阶段流水线: + +1. **预处理** — 计算每个专家的 token 数量和分割大小 +2. **Token Pre-All2All** — 按专家分配排列 token,然后在 EP rank 间执行 all-to-all 交换 +3. **专家计算** — 每个 rank 在接收到的 token 上运行本地专家 +4. **Token Post-All2All** — all-to-all 交换结果,反排列并应用路由权重 + +``` +输入 token → 路由器 → [预处理] → [pre_all2all] → [本地专家] → [post_all2all] → 输出 +``` + +## 要求 + +- `num_experts` 必须能被 `ep_size` 整除 +- `torch.distributed` 必须已初始化 +- MoE block 必须定义 `gate`/`router` 模块和 `experts`(支持 `nn.ModuleList` 或张量形式的 `gate_up_proj`/`down_proj`) +- 共享专家(如 DeepSeek MoE)会自动处理,除非设置 `ignore_shared_experts=True` diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/\345\272\217\345\210\227\345\271\266\350\241\214.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/\345\272\217\345\210\227\345\271\266\350\241\214.md" new file mode 100644 index 000000000..e1d997188 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/\345\272\217\345\210\227\345\271\266\350\241\214.md" @@ -0,0 +1,68 @@ +# 序列并行 (SP) + +序列并行沿序列维度将长序列分割到多个 GPU 上,使训练能处理超出单卡显存的序列长度。Twinkle 实现了 Ulysses 风格的序列并行,并可选地支持派生环形注意力。 + +## 概览 + +| 概念 | 说明 | +|------|------| +| **SequenceParallelConfig** | SP 配置数据类 | +| **SequenceParallelStrategy** | 封装 SP 生命周期的策略类 | +| **SequenceParallel** | 核心实现,处理填充/分割/聚合 | + +## 配置 + +```python +from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallelConfig + +config = SequenceParallelConfig( + enabled=True, # 启用序列并行 + ulysses_size=None, # Ulysses SP 并行度(若为 None 则从 DeviceMesh 自动推导) + gather_logits=True, # 前向后聚合 logits 用于损失计算 +) +``` + +## 配合 DeviceMesh 使用 + +在 `DeviceMesh.from_sizes()` 中设置 `ulysses_size` 即可激活 SP: + +```python +from twinkle.utils import DeviceMesh + +# 8 卡:4 路 Ulysses SP × 2 路数据并行 +device_mesh = DeviceMesh.from_sizes( + world_size=8, + dp_size=2, + ulysses_size=4, +) +``` + +## 工作原理 + +1. **填充** — 输入序列被填充到可被 SP 并行度整除的长度 +2. **分割** — 填充后的输入沿序列维度均匀分配到各 SP rank +3. **分布式注意力** — FlashAttention2 被 patch 为在注意力计算前后执行 Ulysses all-to-all 通信 +4. **聚合** — 前向传播后,logits 被聚合回完整序列长度用于损失计算 + +## 支持的注意力后端 + +| 后端 | 状态 | +|------|------| +| FlashAttention2 | 完全支持(包括打包/padding-free 序列)| +| SDPA | 支持(仅非打包批次)| +| 派生环形注意力 | 仅支持 FlashAttention2(`rp_world_size > 1`)| + +## Qwen3.5 线性注意力 + +SP 自动检测 Qwen3.5 GatedDeltaNet 线性注意力层,并应用 `Qwen3_5GatedDeltaNetUlyssesPatch`,确保混合注意力架构下序列并行的正确性。 + +## MoE 辅助损失 + +对于 MoE 模型,SP 自动安装前向 hook,在计算辅助损失前跨 SP rank 聚合路由 logits,确保负载均衡信号的正确性。 + +## 关键约束 + +- `num_key_value_heads` 必须能被 `ulysses_size` 整除(Ulysses 模式),否则回退到环形注意力 +- 打包/padding-free 批次需要 FlashAttention2 +- 派生环形注意力要求 `batch_size == 1`(打包格式) +- `torch.distributed` 必须已初始化 diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\200\232\347\237\245\345\231\250/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\351\200\232\347\237\245\345\231\250/index.rst" new file mode 100644 index 000000000..8b3692d51 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/\351\200\232\347\237\245\345\231\250/index.rst" @@ -0,0 +1,6 @@ +通知器 +=============== +.. toctree:: + :maxdepth: 1 + + 通知器.md diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\200\232\347\237\245\345\231\250/\351\200\232\347\237\245\345\231\250.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\200\232\347\237\245\345\231\250/\351\200\232\347\237\245\345\231\250.md" new file mode 100644 index 000000000..823e0d459 --- /dev/null +++ "b/docs/source_zh/\347\273\204\344\273\266/\351\200\232\347\237\245\345\231\250/\351\200\232\347\237\245\345\231\250.md" @@ -0,0 +1,93 @@ +# 通知器 + +通知器组件提供可插拔的通知系统,用于在训练过程中发送告警。当异常发生或训练事件需要关注时,通知器将消息投递到外部渠道(如钉钉 Webhook)。 + +## 基础接口 + +```python +from twinkle.notifier import Notifier + +class Notifier: + def __call__(self, message: str): + """发送通知消息""" + ... + + def to_dict(self) -> dict: + """序列化(用于 checkpoint 保存/恢复)""" + ... + + @classmethod + def from_dict(cls, data: dict) -> Notifier: + """从序列化数据恢复""" + ... +``` + +## DingNotifier(钉钉通知) + +向钉钉自定义机器人 Webhook 发送通知。 + +```python +from twinkle.notifier import DingNotifier + +notifier = DingNotifier( + ding_url='https://oapi.dingtalk.com/robot/send?access_token=xxx', + secret='SECxxxxxxx', # 可选:签名模式 + timeout=5.0, +) + +# 发送消息 +notifier("### 训练完成\n\n- Steps: 1000\n- Loss: 0.25") +``` + +**参数:** +- `ding_url`:完整的钉钉 Webhook URL(含 access_token) +- `secret`:可选签名密钥(签名模式机器人) +- `timeout`:HTTP 请求超时时间,单位秒(默认 5.0) + +消息以钉钉 **Markdown** 格式发送。第一个标题行自动提取为聊天预览标题。 + +## 异常通知 + +Twinkle 提供带去重的自动异常通知: + +```python +from twinkle.notifier.base import notify_exception + +# 自动发送格式化的异常信息 +# 每个唯一异常只有一个 rank 发送(防止消息洪泛) +try: + model.forward_backward(batch) +except Exception as e: + notify_exception(notifier, context='forward_backward', exc=e, name='sft_train') +``` + +通知包含: +- 异常类型和消息 +- 完整堆栈跟踪 +- 运行时元数据(rank、PID、主机名) +- 去重:所有 rank 中每个唯一异常只发一条通知 + +## 自定义通知器 + +继承 `Notifier` 创建自定义通知器: + +```python +from twinkle.notifier import Notifier + +class SlackNotifier(Notifier): + def __init__(self, webhook_url: str): + self.webhook_url = webhook_url + + def __call__(self, message: str): + import requests + requests.post(self.webhook_url, json={'text': message}) + + def to_dict(self): + return {'class': 'SlackNotifier', 'webhook_url': self.webhook_url} + + @classmethod + def _from_dict_impl(cls, data): + return cls(webhook_url=data['webhook_url']) +``` + +> 通知器通过 `__init_subclass__` 自动注册,因此 `Notifier.from_dict()` 可以按类名恢复任何子类。 diff --git a/new_feature.txt b/new_feature.txt new file mode 100644 index 000000000..39ce0d763 --- /dev/null +++ b/new_feature.txt @@ -0,0 +1,204 @@ +# Native ML LLM Control + +基于 Twinkle Server Mode 架构,实现 LLM Agent 驱动的 TUI 训练控制系统。面向零基础开发者,通过自然语言对话完成模型训练的全生命周期管理。 + + +## 一、整体架构 + +采用"无状态客户端 + 有状态服务端"的 Server Mode 架构: + +- **Server 端(Ray 集群)**:模型权重、LoRA adapter、optimizer 状态、LR scheduler 全部驻留在 GPU 内存 +- **Client 端(TUI + 训练脚本)**:完全无状态,仅负责数据加载、训练循环逻辑、指标上报 +- **核心特性**:杀死 client = 暂停(server 保留全部状态),重启 client = 恢复(零成本继续训练) + +支持两种运行模式: +1. **本地自建**:启动本地 Ray 集群 + Twinkle Server,需评估 GPU 资源和 DeviceMesh +2. **线上云服务**:连接 ModelScope 托管服务(`http://www.modelscope.cn/twinkle`),无需本地 GPU + + +## 二、TUI 界面(基于 Textual 框架) + +TUI 采用 Grid 布局,包含四个核心面板: + +1. **状态栏(StatusBar)**:顶部横跨,显示当前训练 run_id、状态、步数进度 +2. **指标面板(MetricsPanel)**:左上区域,绘制 loss / reward / grad_norm 等指标曲线,支持自然语言控制放大缩小和还原 +3. **对话面板(ChatPanel)**:左下区域,用户与 LLM agent 的对话界面,支持 UTF-8 中文输入输出 +4. **日志面板(LogPanel)**:右侧纵向,滚动显示训练运行日志 + +快捷键:`q` 退出、`Ctrl+P` 切换指标面板、`Ctrl+L` 清空日志。 + +关闭 TUI 后训练不中断(Server 端状态不受影响),重新打开 TUI 可继续监控。 + + +## 三、LLM Agent 系统 + +### 3.1 对话式 Agent(AgentLoop) + +用户通过 Chat 面板与 Agent 对话,Agent 通过 tool_call 执行训练管理操作: + +| 工具 | 功能 | +|------|------| +| `list_training_runs` | 列出所有活跃和历史训练任务 | +| `get_training_status` | 获取指定 run 的状态和近期指标 | +| `pause_training` | 暂停训练(SIGKILL client,server 保留状态) | +| `resume_training` | 恢复训练(重启 client 脚本) | +| `stop_training` | 优雅停止(SIGTERM,脚本自动保存 checkpoint + dataloader 位置后退出) | +| `list_supported_models` | 查询 Server 支持的模型列表(本地/云端) | +| `search_datasets` | 在 ModelScope 搜索数据集 | +| `search_models` | 在 ModelScope 搜索模型 | +| `zoom_metrics` | 自然语言控制指标图表缩放 | + +### 3.2 自动监控(TrainingMonitor) + +后台 LLM 定期(默认 30 秒)读取 metrics 和 logs,进行趋势分析: + +- 检测异常:loss 突增/NaN、reward 停滞、gradient 爆炸/消失、KL 散度过大、entropy 坍塌 +- 主动建议:调整学习率、更换 reward 组合、增加 num_generations +- 通过 Chat 面板推送诊断报告(`[Monitor] ...`) +- 无硬编码规则,所有分析由 LLM 推理完成 + +### 3.3 Skills 可扩展框架 + +通过 `SkillManager` + `ModelScopeSkillProvider` 加载技能文档,为 Agent 提供领域知识: + +- `skills/twinkle-training.md`:训练脚本编写指导(1260+ 行) +- `skills/autoresearch.md`:自动化研究实验设计(256 行) +- 支持从 ModelScope 远程加载社区共享 Skills + + +## 四、训练控制机制 + +### 4.1 训练进程管理 + +| 操作 | 信号 | 行为 | 恢复方式 | +|------|------|------|----------| +| 暂停 | SIGKILL | 立即杀死 client | 重启同一脚本(adapter_name 相同即可继续) | +| 停止 | SIGTERM | 脚本保存 checkpoint + dataloader 位置后退出 | `model.resume_from_checkpoint()` + `dataloader.resume_from_checkpoint()` | +| 修改超参 | SIGKILL → 编辑 → 重启 | 新配置生效,optimizer 状态保留 | 使用相同 adapter_name | +| 重置训练 | 使用新 adapter_name | 全新开始 | 旧 adapter 按 `adapter_timeout` 自动清理 | + +### 4.2 优雅退出(SIGTERM) + +每个训练脚本必须注册 graceful shutdown handler: + +```python +rt = TrainingRuntime(run_id='my-exp') +rt.register_graceful_shutdown(model, dataloader) +``` + +收到 SIGTERM 后自动执行: +1. 保存模型 checkpoint(含 optimizer 状态) +2. 记录 dataloader 已消费的样本数(`consumed_train_samples`) +3. 写入 `rt.finish(status='stopped')` +4. 安全退出 + + +## 五、数据通信(TUI ↔ 训练脚本) + +训练脚本通过 `TrainingRuntime` 写入本地 JSONL 文件,TUI 通过 `LocalConnection` 读取: + +``` +~/.cache/twinkle/{run_id}/ +├── meta.json # 运行元信息(model_id、config、status、pid、script_path、script_version) +├── train.py # 当前活跃版本(始终是最新的) +├── train_v1.py # 归档:第1版脚本(出错的原始版本) +├── train_v2.py # 归档:第2版脚本(如果也有问题) +├── metrics.jsonl # 每步一行 JSON(step, loss, reward, grad_norm, lr, ...) +└── logs.jsonl # 事件日志(timestamp + message) +``` + +**脚本命名与版本管理规则:** +- `train.py` 始终是当前活跃版本,`resume_training` 只执行它 +- 当 Agent 修复脚本时,旧版自动归档为 `train_v{N}.py`(保留完整修改历史) +- `meta.json` 中 `script_version` 字段记录当前版本号 +- `run_id` 由用户定义(如 `'grpo-gsm8k'`、`'sft-self-cognition'`) +- 同一 `run_id` 下可多次修改脚本:server 端 adapter 状态不变,只有 client 逻辑更新 + +**脚本更新流程(Agent 自动执行):** +1. 脚本出错停止 → Agent 读取 logs/metrics 诊断问题 +2. Agent 调用 `update_script(run_id, new_code)` → 旧 `train.py` 归档为 `train_v{N}.py`,新代码写入 `train.py` +3. Agent 调用 `resume_training(run_id)` → 重新执行最新的 `train.py` + +**TUI 通过 PID 进行进程控制:** +- `meta.json` 记录 `pid`(进程 ID) +- 暂停 = `os.kill(pid, SIGKILL)` +- 停止 = `os.kill(pid, SIGTERM)` → 脚本优雅保存 checkpoint +- 恢复 = `subprocess.Popen(['python', script_path])` → 新 PID 写回 meta + +TUI 支持增量读取(tail 模式),避免大文件全量加载。 + + +## 六、训练前规划(Pre-Training Planning) + +Agent 在编写训练脚本前,必须完成以下评估(本地模式): + +1. **集群资源评估**:GPU 数量、型号、显存(`nvidia-smi` / `ray status`) +2. **模型显存估算**:LoRA training ≈ model weights + 20% overhead +3. **DeviceMesh 设计**:根据 GPU 数决定 model vs sampler 分配(决策树) +4. **训练时间预估**:`total_steps × time_per_step` +5. **数据集搜索**:通过 ModelScope API 或 `search_datasets` 工具 +6. **模型选择**:根据任务类型和资源约束推荐 + +云服务模式下可跳过 1-4,直接进入数据集和模型选择。 + +`list_supported_models` 工具用于查询 Server 实际支持的模型列表,避免选择不可用模型。 + + +## 七、Skills 文档内容 + +### twinkle-training.md(训练脚本编写指导) + +覆盖以下内容: +- Pre-Training Planning 完整流程 +- Ray 集群配置(DeviceGroup / DeviceMesh / initialize) +- 模型后端(Transformers / Megatron)初始化 +- Dataset 加载、Template 编码、Preprocessor 使用 +- 所有训练方式示例:SFT、GRPO、DPO、GKD、PT +- Server Mode 完整说明(本地自建 + 云服务两种模式) +- Cloud Service Mode(ModelScope 托管,两种客户端 API 对比) +- Sampler 配置与权重同步 +- MultiTurnRollout 多轮对话采样 +- TUI 集成:TrainingRuntime、指标上报、优雅退出 +- 实验管理文件夹规范 + +### autoresearch.md(自动化研究实验设计) + +指导 Agent 如何: +- 分析用户需求,选取合适的训练方法 +- 根据资源约束选择模型规模 +- 配置超参数(SFT / GRPO / DPO 默认值和调优建议) +- 设计多阶段 Pipeline(数据清洗 → SFT → GRPO/DPO) +- 编写数据清洗和转换流程 +- 组织实验输出文件夹 + + +## 八、训练脚本规范 + +所有训练脚本必须: + +1. 使用 **Server Mode 语法**:`twinkle_client`(模型操作)+ `twinkle`(数据处理) +2. 连接到运行中的 Twinkle Server(本地或云端) +3. 注册 SIGTERM graceful shutdown handler +4. 通过 `TrainingRuntime` 上报所有可用指标(loss, reward, grad_norm, lr, ...) +5. 每个实验独立文件夹,包含 `plan.md`、`config.yaml`、`train.py`、`train.sh` + +两种客户端 API: +- **Twinkle 原生**:`init_twinkle_client()` → `MultiLoraTransformersModel` → `forward_backward()` → `clip_grad_and_step()` +- **Tinker 兼容**:`init_tinker_client()` → `ServiceClient` → `create_lora_training_client()` → `forward_backward()` → `optim_step()` + + +## 九、插件化架构与组件约束 + +Twinkle 是一个插件化框架,所有核心组件(loss、preprocessor、metric、sampler 等)均以插件形式注册,本地模式下支持用户编写新组件。 + +### 本地自建模式 +- 完全可扩展:可编写自定义 loss、preprocessor、metric、reward function +- 通过装饰器注册(如 `@register_loss('MyLoss')`),然后在训练脚本中以字符串名引用 + +### 线上云服务模式(ModelScope 托管) +- **安全限制**:不支持传入类、函数对象或 pickle 序列化 +- **只能使用已注册的内置组件**,以字符串名引用 +- 内置 Loss:`CrossEntropyLoss`、`GRPOLoss`、`DPOLoss`、`GKDLoss` +- 内置 Preprocessor:`SFTPreprocessor`、`RLPreprocessor`、`DPOPreprocessor` + +Agent 编写脚本时必须判断目标环境:本地可用自定义组件,云端只能用内置名称。 diff --git a/pyproject.toml b/pyproject.toml index 5daa14fe9..484dbc9e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,24 +16,28 @@ dependencies = [ "transformers", "typer>=0.9.0", "pyzmq", + "accelerate", + "torch>=2.6.0,<3.0.0", ] [project.scripts] twinkle-server = "twinkle.server.cli:main" +twinkle-tui = "twinkle_client.tui:main" [project.optional-dependencies] -transformers = [ - "accelerate", - "torch>=2.6.0,<3.0.0", - "torchvision", -] -kernels = ["kernels"] megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]", "mcore_bridge"] -vllm = ["vllm>=0.11"] -ray = ["ray[serve]"] -datajuicer = ["py-data-juicer"] -tinker = ["tinker==0.14.0"] -test = ["hypothesis>=6.0", "pytest", "pytest-asyncio"] +data = ["py-data-juicer"] +rl = [ + "vllm>=0.11", + "ray[serve]" +] +client = [ + "textual>=1.0.0", + "plotext>=5.2.0", + "openai>=1.0.0", + "httpx>=0.25.0", + "tinker==0.16.1", +] server = [ "redis>=5.0", "psutil>=5.9.0", @@ -43,6 +47,11 @@ server = [ "opentelemetry-exporter-otlp", "opentelemetry-instrumentation-logging", ] +test = [ + "hypothesis>=6.0", + "pytest", + "pytest-asyncio" +] docs = [ "sphinx>=5.3.0,<6.0.0", "docutils>=0.16.0,<0.17.0", @@ -68,3 +77,6 @@ build-backend = "setuptools.build_meta" [tool.setuptools.packages.find] where = ["src"] + +[tool.setuptools.package-data] +"twinkle_client.skills.bundled" = ["*.md"] diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py index cde5c519d..3860d2840 100644 --- a/src/twinkle/checkpoint_engine/manager.py +++ b/src/twinkle/checkpoint_engine/manager.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py -import time from typing import List, Optional from twinkle import Platform, get_logger diff --git a/src/twinkle/checkpoint_engine/mixin.py b/src/twinkle/checkpoint_engine/mixin.py index e2e5d94d5..8dc15c926 100644 --- a/src/twinkle/checkpoint_engine/mixin.py +++ b/src/twinkle/checkpoint_engine/mixin.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import os from twinkle import Platform, remote_function from twinkle.checkpoint_engine.base import CheckpointEngine diff --git a/src/twinkle/cli/cli.py b/src/twinkle/cli/cli.py index 1730887f2..085ad7ec6 100644 --- a/src/twinkle/cli/cli.py +++ b/src/twinkle/cli/cli.py @@ -1,12 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from __future__ import annotations - import os import sys from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields from pathlib import Path -from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union +from typing import Any, Iterator, Literal + # ──────────────────────────────────────────────────────────────────────────────── # Arg group dataclasses @@ -46,6 +45,7 @@ class LoraArgs: lora_dropout: float = 0.05 lora_target_modules: list[str] | None = None adapter_name: str = 'default' + lora_path: str | None = None @dataclass @@ -84,6 +84,7 @@ class TrainingArgs: log_interval: int = 10 eval_interval: int | None = None eval_samples: int | None = None + train_samples: int | None = None resume_from_checkpoint: str | None = None resume_only_model: bool = False ignore_data_skip: bool = False @@ -117,9 +118,11 @@ class SchedulerArgs: @dataclass class LossArgs: loss_cls: str = 'CrossEntropyLoss' + loss_type: str = 'sigmoid' epsilon: float = 0.2 epsilon_high: float | None = None - beta: float = 0.0 + beta: float = 0.1 + sft_weight: float = 1.0 entropy_coef: float = 0.0 ignore_index: int = -100 @@ -155,6 +158,7 @@ class InfraArgs: ncpu_proc_per_node: int = 8 model_gpus: int | None = None sampler_gpus: int | None = None + ref_model_gpus: int | None = None world_size: int | None = None dp_size: int | None = None fsdp_size: int | None = None @@ -185,6 +189,11 @@ class RLArgs: advantage_type: str = 'GRPOAdvantage' advantage_scale: Literal['group', 'batch', 'none'] = 'group' reward_fns: list[str] | None = None + student_model_id: str | None = None + teacher_model_id: str | None = None + gkd_beta: float = 0.5 + gkd_temperature: float = 1.0 + gkd_topk: int = 64 @dataclass @@ -192,6 +201,7 @@ class CheckpointArgs: save_optimizer: bool = True merge_and_sync: bool = True platform: str = 'GPU' + lora_sync_dir: str | None = None # ──────────────────────────────────────────────────────────────────────────────── @@ -243,7 +253,7 @@ def _resolve_path(self) -> Path | None: class EnvVarSource(ConfigSource): """Reads os.environ; recognizes TWINKLE_ prefix and any key known to the registry.""" - def __init__(self, registry: ConfigRegistry): + def __init__(self, registry: 'ConfigRegistry'): self._registry = registry def load(self) -> dict[str, str]: diff --git a/src/twinkle/data_format/output.py b/src/twinkle/data_format/output.py index 763ef246f..596252fb6 100644 --- a/src/twinkle/data_format/output.py +++ b/src/twinkle/data_format/output.py @@ -20,11 +20,13 @@ class ModelOutput(TypedDict, total=False): loss: The loss calculated by the model. logps: The log-probabilities of correct tokens by the model. num_tokens: The token denominator associated with ``loss``. + embeddings: The embeddings output by the model, used be embedding task. """ logits: Optional[OutputType] loss: Optional[OutputType] logps: Optional[OutputType] num_tokens: Optional[OutputType] + embeddings: Optional[OutputType] class LossOutput(TypedDict, total=False): diff --git a/src/twinkle/data_format/sampling.py b/src/twinkle/data_format/sampling.py index 1d5fe07c6..01ff0377d 100644 --- a/src/twinkle/data_format/sampling.py +++ b/src/twinkle/data_format/sampling.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import numpy as np from dataclasses import dataclass from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index c392d56cf..408c8d4b4 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -146,7 +146,7 @@ def _tracking_iter(self, inner): def skip_consumed_samples(self, consumed_train_samples: int) -> None: from torch.utils.data import IterableDataset - if isinstance(self.dataset, IterableDataset): + if isinstance(self.dataset, IterableDataset) or consumed_train_samples is None or consumed_train_samples <= 0: warnings.warn('IterableDataset does not support consumed-data skipping; continuing without skipping.') self._skip_samples = 0 return @@ -164,6 +164,7 @@ def resume_from_checkpoint(self, consumed_train_samples, **kwargs): @remote_function() def get_state(self) -> dict: + """The dataloader state for saving.""" return {'consumed_train_samples': self._consumed_train_samples} def _rebuild_sampler_stack(self): diff --git a/src/twinkle/dataset/iterable_packing_dataset.py b/src/twinkle/dataset/iterable_packing_dataset.py index ca7c6fbd8..ab1d3a982 100644 --- a/src/twinkle/dataset/iterable_packing_dataset.py +++ b/src/twinkle/dataset/iterable_packing_dataset.py @@ -88,10 +88,27 @@ def _fetch_data_out_queue(self, last_res, num_samples): last_res += res return last_res - @staticmethod - def _cyclic_iter(iterable): - while True: - yield from iterable + def _write_through_iter(self, iterable): + """Yields from iterable, meanwhile, save it to disk if needed. + Saving is needed when you are using several datasets at a time. + """ + if not self.cyclic: + for row in iterable: + self._write_through(row) + yield row + return + else: + first_pass = True + while True: + empty = True + for row in iterable: + empty = False + if first_pass: + self._write_through(row) + yield row + if empty: + return + first_pass = False @remote_function() def __iter__(self): @@ -102,10 +119,7 @@ def __iter__(self): except StopIteration: return - if self.cyclic: - iterator = self._cyclic_iter(self.dataset) - else: - iterator = iter(self.dataset) + iterator = self._write_through_iter(self.dataset) data = [] max_length = self.template.max_length or 2048 while True: diff --git a/src/twinkle/dataset/packing_dataset.py b/src/twinkle/dataset/packing_dataset.py index fa4acbd57..ada9498b8 100644 --- a/src/twinkle/dataset/packing_dataset.py +++ b/src/twinkle/dataset/packing_dataset.py @@ -114,6 +114,8 @@ def __getitem__(self, index): assert self._packed_called, 'Call `pack_dataset()` first before index the sample.' sequence = self.packed_idx[index] rows = [self.dataset[i] for i in sequence] + for row in rows: + self._write_through(row) output = {} for key in rows[0]: output[key] = [r[key] for r in rows] diff --git a/src/twinkle/gym/__init__.py b/src/twinkle/gym/__init__.py deleted file mode 100644 index 44b0771bb..000000000 --- a/src/twinkle/gym/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from .base import Gym diff --git a/src/twinkle/gym/base.py b/src/twinkle/gym/base.py deleted file mode 100644 index aca798093..000000000 --- a/src/twinkle/gym/base.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. - - -class Gym: - - def __init__(self): - pass - - def step(self): - pass diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 83e10d132..a2760c900 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -1,11 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import functools import inspect -import itertools import json import numpy as np import os -import random import sys from typing import Any, Callable, List, Literal, Optional, TypeVar, Union @@ -59,7 +57,7 @@ def _tag_exc(exc: BaseException, caller: Optional[str]) -> None: prefix = f'[twinkle driver caller: {caller}] ' exc.args = (prefix + str(exc.args[0]), *exc.args[1:]) if exc.args else (prefix.rstrip(), ) exc._twinkle_caller_augmented = True - except Exception: # noqa: BLE001 + except Exception: # noqa pass @@ -404,6 +402,7 @@ def dispatch_func(arg, n): return result elif dispatch == 'slice_dp': + assert device_mesh is not None # split by dp. each worker in one ep will receive the same argument result = [] # if device_mesh is not None: @@ -420,14 +419,6 @@ def dispatch_func(arg, n): import torch if isinstance(arg, list) or isinstance(arg, torch.Tensor): _args = [] - if device_mesh is None: - total = len(arg) - chunk = max(1, (total + n - 1) // n) - for i in range(n): - start = i * chunk - end = min(total, start + chunk) - _args.append(arg[start:end]) - return _args for i in range(n): _args.append(arg[device_mesh.get_slice( len(arg), device_mesh.get_data_rank_from_global_rank(i * _rank_stride))]) @@ -696,11 +687,12 @@ def __next__(_self): return decorator -def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp'], Callable] = 'slice', +def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp', 'last_pp_first'], Callable] = 'slice', execute: Literal['first', 'peer', 'all'] = 'all', collect: Union[Literal['none', 'flatten', 'mean', 'sum', 'first', 'last_pp'], Callable] = 'none', sync: bool = False, - lazy_collect: Optional[bool] = None): + lazy_collect: Optional[bool] = None, + timeout: Optional[float] = None): """Patch each method called from remote(which class should be decorated with `remote_class`) with this decorator. Args: @@ -726,6 +718,7 @@ def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp'], Callabl sync: If True, use synchronous execution (execute_all_sync) instead of async. Required for methods with NCCL collective operations (e.g., Megatron forward_backward). lazy_collect: Do lazy collect, this boolean value decides whether this function needs lazy collect. If setting to None, it will follow the global setting. + timeout: Timeout in seconds for ray.get() when collecting results. Instance attribute ``_ray_get_timeout`` overrides this. """ # noqa def decorator(func: Callable[..., T1]) -> Callable[..., T1]: @@ -773,7 +766,9 @@ def wrapper(self, *args, **kwargs) -> T1: result = execute_method(func.__name__, _workers_and_args) # This is a result future, call it to get the actual result - result_func = RayHelper.do_get_and_collect_func(_collect_func, collect, result, device_mesh) + _rgt = getattr(self, '_ray_get_timeout', None) or timeout + result_func = RayHelper.do_get_and_collect_func( + _collect_func, collect, result, device_mesh, timeout=_rgt) _local_lazy_collect = _lazy_collect if func.__name__ == '__iter__': # return self @@ -803,18 +798,13 @@ def wrapper(self, *args, **kwargs) -> T1: # And this is user independent, only decided by the code. _local_lazy_collect = self._lazy_collect if _local_lazy_collect: - # Wrap the deferred collector so that exceptions - # raised when the caller later materializes the - # result also trigger the notifier. Attributes - # (``_futures`` etc.) on the original collector - # are preserved for downstream code paths. _orig_result_func = result_func @functools.wraps(_orig_result_func) def _notifying_result_func(*rargs, **rkwargs): try: return _orig_result_func(*rargs, **rkwargs) - except Exception as _e: # noqa: BLE001 + except Exception as _e: # noqa _tag_exc(_e, _caller) notify_exception(_notifier, _ctx, _e, _name) raise diff --git a/src/twinkle/infra/_ray/ray_helper.py b/src/twinkle/infra/_ray/ray_helper.py index 0d8908a35..5cd792c3a 100644 --- a/src/twinkle/infra/_ray/ray_helper.py +++ b/src/twinkle/infra/_ray/ray_helper.py @@ -161,18 +161,20 @@ def get_node_address(): return ip, port @staticmethod - def do_get_and_collect_func(collect_func: Callable, method: Union[str, Callable], futures, device_mesh): + def do_get_and_collect_func(collect_func: Callable, method: Union[str, Callable], futures, device_mesh, + timeout=None): """Return a callable to collect results in the workers.""" class LazyCollect: - def __init__(self, futures, method, collect_func, device_mesh): + def __init__(self, futures, method, collect_func, device_mesh, timeout=None): self._futures = futures self._method = method self._collect_func = collect_func self._is_lazy_collect = True self.device_mesh = device_mesh self._result = None # Cache collected results + self._timeout = timeout def _get_result(self): """Internal method to lazily collect and cache results""" @@ -181,7 +183,7 @@ def _get_result(self): result = [] for future in self._futures: if isinstance(future, ray.ObjectRef): - result.append(ray.get(future)) + result.append(ray.get(future, timeout=self._timeout)) else: result.append(future) self._result = self._collect_func(self._method, result, device_mesh=self.device_mesh) @@ -199,7 +201,7 @@ def __len__(self): """Support len() function""" return len(self._get_result()) - return LazyCollect(futures, method, collect_func, device_mesh) + return LazyCollect(futures, method, collect_func, device_mesh, timeout=timeout) @staticmethod def do_get_and_collect(args, kwargs): diff --git a/src/twinkle/loss/chunked_cross_entropy.py b/src/twinkle/loss/chunked_cross_entropy.py index 22d3d4077..061ca2168 100644 --- a/src/twinkle/loss/chunked_cross_entropy.py +++ b/src/twinkle/loss/chunked_cross_entropy.py @@ -1,63 +1,159 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import math -from typing import Any - -from ..data_format import LossOutput +from twinkle.data_format import LossOutput from .base import Loss +# Lazily-built singleton autograd.Function, so we neither pay the +# class-construction cost on every forward nor force a top-level torch import. +_CHUNKED_CE_FUNC = None + + +def _get_chunked_ce_func(): + global _CHUNKED_CE_FUNC + if _CHUNKED_CE_FUNC is not None: + return _CHUNKED_CE_FUNC + + import torch + import torch.nn.functional as F + + class _ChunkedCrossEntropyFunc(torch.autograd.Function): + """Chunked CE that materialises log_softmax(B, V) only one chunk at a time. + + Forward returns a scalar loss; backward writes per-token gradients into + a freshly allocated `grad_logits` tensor (the input `logits` is never + mutated). Mathematically equivalent to ``CrossEntropyLoss`` in the same + package; ``chunk_size`` only controls the memory/throughput trade-off. + """ + + @staticmethod + def forward(ctx, logits, labels, chunk_size, ignore_index, reduction, dft): + ctx.save_for_backward(logits, labels) + ctx.chunk_size = chunk_size + ctx.ignore_index = ignore_index + ctx.reduction = reduction + ctx.dft = dft + + n = logits.shape[0] + # Use fp32 accumulators so we don't lose precision when summing + # over many tokens under fp16/bf16 autocast (matches cross_entropy.py). + total_loss = logits.new_zeros((), dtype=torch.float32) + total_count = logits.new_zeros((), dtype=torch.float32) + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + logits_chunk = logits[start:end] + labels_chunk = labels[start:end] + mask = (labels_chunk != ignore_index).float() + + logps = F.log_softmax(logits_chunk, dim=-1).gather( + -1, labels_chunk.clamp(min=0).unsqueeze(-1)).squeeze(-1) + per_token = -logps * logps.exp() if dft else -logps + + total_loss = total_loss + (per_token * mask).sum() + total_count = total_count + mask.sum() + + ctx.num_tokens = total_count.detach() + if reduction == 'mean': + return total_loss / total_count.clamp(min=1) + return total_loss + + @staticmethod + def backward(ctx, grad_output): + logits, labels = ctx.saved_tensors + chunk_size = ctx.chunk_size + ignore_index = ctx.ignore_index + reduction = ctx.reduction + dft = ctx.dft + + if reduction == 'mean': + scale = grad_output / ctx.num_tokens.clamp(min=1) + else: + scale = grad_output + + grad_logits = torch.empty_like(logits) + n = logits.shape[0] + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + logits_chunk = logits[start:end].detach().requires_grad_(True) + labels_chunk = labels[start:end] + mask = (labels_chunk != ignore_index).float() + + with torch.enable_grad(): + logps = F.log_softmax(logits_chunk, dim=-1).gather( + -1, labels_chunk.clamp(min=0).unsqueeze(-1)).squeeze(-1) + per_token = -logps * logps.exp() if dft else -logps + loss_chunk = (per_token * mask).sum() + + grad_chunk = torch.autograd.grad(loss_chunk, logits_chunk, retain_graph=False)[0] + grad_logits[start:end] = grad_chunk * scale + + # logits, labels, chunk_size, ignore_index, reduction, dft + return grad_logits, None, None, None, None, None + + _CHUNKED_CE_FUNC = _ChunkedCrossEntropyFunc + return _CHUNKED_CE_FUNC + class ChunkedCrossEntropyLoss(Loss): - """TODO untested code""" + """CE loss that chunks the (B, V) softmax to bound peak memory. + + Drop-in replacement for :class:`CrossEntropyLoss` when ``outputs['logits']`` + is large (e.g. long sequence x big vocab). Behaviour matches that loss + bit-for-bit; ``chunk_size`` only affects memory/throughput. + + Args: + chunk_size: How many rows of ``logits`` to process per chunk. + ignore_index: Label id treated as padding (excluded from loss). + reduction: ``'mean'`` or ``'sum'``; matches ``CrossEntropyLoss``. + dft: If True, use DFT weighting ``-p*log(p)`` (arxiv 2508.05629). + """ - def __init__(self, chunk_size): + require_logits = True + # We chunk the (B, V) softmax ourselves; tell upstream not to materialise + # `logps` (which would already pay the full memory cost we're trying to + # avoid). The `_loss_from_logps` fast path is kept only for the rare case + # where someone explicitly hands us pre-computed logps. + require_logps = False + + def __init__(self, + chunk_size: int, + ignore_index: int = -100, + reduction: str = 'mean', + dft: bool = False, + **kwargs): + super().__init__() + assert chunk_size > 0, 'chunk_size must be positive' + assert reduction in ('mean', 'sum'), f"reduction must be 'mean' or 'sum', got {reduction!r}" self.chunk_size = chunk_size + self.ignore_index = ignore_index + self.reduction = reduction + self.dft = dft def __call__(self, inputs, outputs, **kwargs): - import torch - - class ChunkedCrossEntropyLossFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, logits, labels, chunk_size): - import torch - ctx.save_for_backward(logits, labels) - ctx.chunk_size = chunk_size - - losses = [] - for i in range(math.ceil(logits.shape[0] / chunk_size)): - l_start = i * chunk_size - l_end = min((i + 1) * chunk_size, logits.shape[0]) - logits_chunk = logits[l_start:l_end] - labels_chunk = labels[l_start:l_end] - loss_fct = torch.nn.CrossEntropyLoss(reduction='none') - loss_chunk = loss_fct(logits_chunk, labels_chunk) - losses.append(loss_chunk) - del logits_chunk - del labels_chunk - all_losses = torch.cat(losses) - return all_losses - - @staticmethod - def backward(ctx: Any, *grad_outputs: Any): - import torch - logits, labels = ctx.saved_tensors - chunk_size = ctx.chunk_size - - for i in range(math.ceil(logits.shape[0] / chunk_size)): - l_start = i * chunk_size - l_end = min((i + 1) * chunk_size, logits.shape[0]) - logits_chunk = logits[l_start:l_end].detach().requires_grad_(True) - labels_chunk = labels[l_start:l_end] - loss_fct = torch.nn.CrossEntropyLoss(reduction='none') - with torch.enable_grad(): - loss_chunk = loss_fct(logits_chunk, labels_chunk) - grad_output_chunk = grad_outputs[0][l_start:l_end] - _loss_chunk = (loss_chunk * grad_output_chunk).sum() - grad_chunk = torch.autograd.grad(_loss_chunk, logits_chunk, retain_graph=False)[0] - logits[l_start:l_end] = grad_chunk - - return logits, None, None + labels = inputs['labels'] + logps = outputs.get('logps') + + # Fast path: if logps is already gathered upstream, chunking the + # softmax is moot — fall back to the same scalar formula as + # CrossEntropyLoss to keep behaviour identical. + if logps is not None: + return self._loss_from_logps(labels, logps) logits = outputs['logits'] - labels = inputs['labels'] - return LossOutput(loss=ChunkedCrossEntropyLossFunc.apply(logits, labels, self.chunk_size), num_tokens=0) + labels = labels.view(-1) + logits = logits.view(-1, logits.shape[-1]) + + func = _get_chunked_ce_func() + loss = func.apply(logits, labels, self.chunk_size, self.ignore_index, self.reduction, self.dft) + + if self.reduction == 'mean': + return LossOutput(loss=loss, num_tokens=0) + num_tokens = (labels != self.ignore_index).float().sum().clamp(min=1) + return LossOutput(loss=loss, num_tokens=num_tokens) + + def _loss_from_logps(self, labels, logps): + mask = (labels != self.ignore_index).float() + per_token = -logps * logps.exp() if self.dft else -logps + if self.reduction == 'mean': + return LossOutput(loss=(per_token * mask).sum() / mask.sum().clamp(min=1), num_tokens=0) + return LossOutput(loss=(per_token * mask).sum(), num_tokens=mask.sum().clamp(min=1)) diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index fe526ab46..d53019513 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -7,7 +7,7 @@ (https://arxiv.org/abs/2305.18290) """ from typing import TYPE_CHECKING, Dict, List, Optional, Union - +import math from twinkle.data_format import LossOutput from twinkle.loss.base import Loss from twinkle.utils.torch_utils import selective_log_softmax @@ -132,6 +132,12 @@ def __init__( **kwargs, ): super().__init__(ignore_index=ignore_index) + if loss_type not in ('sigmoid', 'hinge', 'ipo', 'kto_pair'): + raise ValueError(f'Unknown loss_type: {loss_type}') + if label_smoothing > 0 and loss_type != 'sigmoid': + raise ValueError( + f'label_smoothing > 0 is only defined for loss_type="sigmoid", ' + f'got loss_type="{loss_type}". Set label_smoothing=0.0 or switch to sigmoid.') self.beta = beta self.label_smoothing = label_smoothing self.loss_type = loss_type @@ -217,6 +223,11 @@ def _compute_dpo_loss( if self.loss_type == 'sigmoid': # Standard DPO loss: -log(sigmoid(beta * margin)) losses = -F.logsigmoid(logits) + # Apply label smoothing (only meaningful here: Bradley-Terry soft labels). + if self.label_smoothing > 0: + # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected + smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference + losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses elif self.loss_type == 'hinge': # Hinge loss variant losses = torch.relu(1 - logits) @@ -234,12 +245,6 @@ def _compute_dpo_loss( else: raise ValueError(f'Unknown loss_type: {self.loss_type}') - # Apply label smoothing if specified - if self.label_smoothing > 0: - # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected - smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference - losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses - return losses.mean() def __call__( @@ -321,7 +326,8 @@ def __call__( reference_chosen_logps = torch.zeros_like(policy_chosen_logps) reference_rejected_logps = torch.zeros_like(policy_rejected_logps) else: - return LossOutput(loss=torch.tensor(0.0, device=chosen_logps.device), num_tokens=0) + zero = (policy_chosen_logps.sum() + policy_rejected_logps.sum()) * 0.0 + return LossOutput(loss=zero, num_tokens=0) # Compute DPO loss dpo_loss = self._compute_dpo_loss( @@ -535,11 +541,23 @@ def __call__( # Odds ratio: log(odds_chosen / odds_rejected) # log_odds = log(p/(1-p)) = log(p) - log(1-p) - # Compute entirely in log-space to avoid exp() underflow: - # log(p) = avg_logps (already in log-space) - # log(1-p) = log1p(-exp(avg_logps)) (numerically stable via log1p) - log_odds_chosen = chosen_avg_logps - torch.log1p(-torch.exp(chosen_avg_logps)) - log_odds_rejected = rejected_avg_logps - torch.log1p(-torch.exp(rejected_avg_logps)) + # Compute log(1-p) = log(1 - exp(avg_logp)) numerically stably: + # - For x > -log(2): log(-expm1(x)) (avoids log(0) when p → 1) + # - For x ≤ -log(2): log1p(-exp(x)) (avoids cancellation when p → 0) + # ``avg_logp ∈ (-∞, 0]`` so the threshold partitions the safe regime. + log_two = math.log(2.0) + + def _log1mexp(x: 'torch.Tensor') -> 'torch.Tensor': + # Clamp at a tiny negative to keep both branches well-defined when p≈1. + x_safe = torch.clamp(x, max=-1e-7) + return torch.where( + x_safe > -log_two, + torch.log(-torch.expm1(x_safe)), + torch.log1p(-torch.exp(x_safe)), + ) + + log_odds_chosen = chosen_avg_logps - _log1mexp(chosen_avg_logps) + log_odds_rejected = rejected_avg_logps - _log1mexp(rejected_avg_logps) # ORPO odds ratio loss odds_ratio = log_odds_chosen - log_odds_rejected diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index 3f7db4bfb..7c198ad02 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -41,6 +41,10 @@ def __init__( chunk_size: int = 512, **kwargs, ): + if not (0.0 <= beta <= 1.0): + raise ValueError(f'beta must be in [0, 1], got {beta}') + if temperature <= 0: + raise ValueError(f'temperature must be > 0, got {temperature}') self.beta = beta self.temperature = temperature self.ignore_index = ignore_index @@ -94,6 +98,7 @@ def __call__( labels=labels, beta=self.beta, temperature=self.temperature, + ignore_index=self.ignore_index, chunk_size=self.chunk_size, topk=topk, teacher_topk_logprobs=teacher_topk_logprobs, @@ -108,6 +113,7 @@ def _generalized_jsd_loss( labels=None, beta: float = 0.5, temperature: float = 1.0, + ignore_index: int = -100, chunk_size: int = 512, topk: Optional[int] = None, teacher_topk_logprobs=None, @@ -164,7 +170,7 @@ def _generalized_jsd_loss( # ── Mask valid (response) tokens ────────────────────────────────────── if labels is not None: - mask = labels != -100 # ignore_index is always -100 per convention + mask = labels != ignore_index # Vocab-size mismatch (e.g. Qwen2.5-VL-3B vs 7B): pad the smaller side # so both distributions are defined over the same token set. stu_dim = student_logits.shape[-1] @@ -178,12 +184,15 @@ def _generalized_jsd_loss( student_logits = student_logits[mask] # [num_valid, vocab/topk] teacher_logits = teacher_logits[mask] num_valid = mask.sum() + # ``[mask]`` already created fresh storage, so in-place divide is safe + # and avoids an extra [num_valid, V] allocation. + student_logits.div_(temperature) + teacher_logits.div_(temperature) else: - student_logits = student_logits.view(-1, student_logits.size(-1)) - teacher_logits = teacher_logits.view(-1, teacher_logits.size(-1)) + # Keep logits, may be an infer scenario + student_logits = student_logits.reshape(-1, student_logits.size(-1)) / temperature + teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1)) / temperature num_valid = student_logits.size(0) - student_logits.div_(temperature) - teacher_logits.div_(temperature) if num_valid == 0: return student_logits.new_zeros(()) diff --git a/src/twinkle/loss/grpo.py b/src/twinkle/loss/grpo.py index 4bb71216c..781b22060 100644 --- a/src/twinkle/loss/grpo.py +++ b/src/twinkle/loss/grpo.py @@ -42,18 +42,6 @@ def __init__( self.require_entropy = entropy_coef > 0.0 self.ignore_index = ignore_index - def _compute_loss_mask(self, labels: 'torch.Tensor') -> 'torch.Tensor': - """ - Compute loss mask from labels. - - Args: - labels: [batch, seq_len] target token ids, -100 for ignored positions - - Returns: - mask: [batch, seq_len] float tensor, 1.0 for valid positions, 0.0 for ignored - """ - return (labels != self.ignore_index).float() - def _compute_log_importance_weights( self, per_token_logps: 'torch.Tensor', @@ -165,10 +153,13 @@ def _pad_and_align_to_batch( return data # Already aligned if data.dim() == 1: data = data.unsqueeze(1) - if data.shape[1] == 1: # Scalars - result = torch.full((batch_size, seq_len), fill_value, dtype=dtype, device=device) - result[mask] = data[mask.any(dim=1).nonzero(as_tuple=True)[0].repeat_interleave(mask.sum(dim=1)), 0] - return result + if data.shape[1] == 1: + assert data.shape[0] == batch_size, ( + f'scalar broadcast expects data.shape[0]==batch_size, ' + f'got data.shape={tuple(data.shape)} mask.shape={(batch_size, seq_len)}') + fill = torch.full((batch_size, seq_len), fill_value, dtype=dtype, device=device) + expanded = data.expand(batch_size, seq_len) + return torch.where(mask, expanded, fill) data = [data[i] for i in range(batch_size)] # To list # Handle list (scalars or sequences) @@ -276,10 +267,12 @@ def __call__( ) # GRPO loss is ill-defined without advantages (e.g. ref-logps-only forward, - # or eval/validation forwards). Return a zero loss so the forward still - # flows through cleanly and callers can harvest outputs['logps'] freely. + # or eval/validation forwards). Return a zero loss that still flows through + # autograd so DDP/FSDP do not see unused params, and callers can harvest + # outputs['logps'] freely. if advantages is None: - return LossOutput(loss=torch.zeros((), device=device, dtype=logps.dtype), num_tokens=0) + zero = logps.sum() * 0.0 + return LossOutput(loss=zero, num_tokens=0) advantages = self._pad_and_align_to_batch( advantages, diff --git a/src/twinkle/loss/infonce.py b/src/twinkle/loss/infonce.py index 68d14840c..c356bd64c 100644 --- a/src/twinkle/loss/infonce.py +++ b/src/twinkle/loss/infonce.py @@ -13,8 +13,6 @@ import numpy as np import torch import torch.distributed as dist -import torch.nn.functional as F -from enum import Enum from torch import nn from typing import Optional @@ -22,15 +20,6 @@ from .base import Loss -# Borrowed from sentence_transformers. -class SiameseDistanceMetric(Enum): - """Distance metrics available to the pairwise contrastive losses.""" - - EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2) # noqa - MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1) # noqa - COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) # noqa - - def _extract_sentences(outputs) -> torch.Tensor: """Return [B, D] sentence embeddings from postprocess_tensor_sp output. @@ -119,6 +108,11 @@ def __init__( process_group=None, **kwargs, ): + if mask_fake_negative and fake_neg_margin <= 0: + raise ValueError( + f'fake_neg_margin must be > 0 when mask_fake_negative=True, got {fake_neg_margin}. ' + 'A non-positive margin would mask out the positive itself or every above-positive ' + 'logit indiscriminately, collapsing the contrastive signal.') self.temperature = temperature self.use_batch = use_batch self.hard_negatives = hard_negatives @@ -129,7 +123,13 @@ def __init__( self.process_group = process_group def _gather_across_dp(self, sentences: torch.Tensor, labels: torch.Tensor): - """All-gather embeddings & labels across DP ranks; only local shard keeps grad.""" + """All-gather embeddings & labels across DP ranks; only local shard keeps grad. + + NCCL ``all_gather`` requires every rank to send the *same* tensor size. Under + ``slice_dp`` dispatch the per-rank batch is uneven (``divmod`` splits), so we + pad each rank to the global max along dim-0, do an equal-sized all_gather, + then strip padding back. Only the local shard retains gradients. + """ if not (dist.is_available() and dist.is_initialized()): return sentences, labels world_size = dist.get_world_size(group=self.process_group) @@ -137,24 +137,40 @@ def _gather_across_dp(self, sentences: torch.Tensor, labels: torch.Tensor): return sentences, labels rank = dist.get_rank(group=self.process_group) - # variable per-rank shapes require communicating shape first - local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long) - shapes = [torch.empty_like(local_shape) for _ in range(world_size)] - dist.all_gather(shapes, local_shape, group=self.process_group) - all_sentences = [sentences.new_empty(shape.tolist()) for shape in shapes] - dist.all_gather(all_sentences, sentences.contiguous(), group=self.process_group) - - local_label_shape = labels.new_tensor(labels.shape, dtype=torch.long) - label_shapes = [torch.empty_like(local_label_shape) for _ in range(world_size)] - dist.all_gather(label_shapes, local_label_shape, group=self.process_group) - all_labels = [labels.new_empty(shape.tolist()) for shape in label_shapes] - dist.all_gather(all_labels, labels.contiguous(), group=self.process_group) - - # keep the local shard differentiable; detach others - all_sentences[rank] = sentences + # ``labels`` is a 1-D mask aligned to ``sentences`` along dim-0, so they + # share the same per-rank size. Gather sizes once and reuse for both. + assert sentences.shape[0] == labels.shape[0], ( + f'sentences/labels dim-0 mismatch: {sentences.shape[0]} vs {labels.shape[0]}') + local_n = torch.tensor([sentences.shape[0]], device=sentences.device, dtype=torch.long) + sizes = [torch.empty_like(local_n) for _ in range(world_size)] + dist.all_gather(sizes, local_n, group=self.process_group) + sizes_int = [int(s.item()) for s in sizes] + max_n = max(sizes_int) + + def _pad_gather(tensor: torch.Tensor): + if tensor.shape[0] < max_n: + pad_shape = (max_n - tensor.shape[0],) + tuple(tensor.shape[1:]) + padded = torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=0) + else: + padded = tensor + buffers = [torch.empty_like(padded) for _ in range(world_size)] + dist.all_gather(buffers, padded.contiguous(), group=self.process_group) + return buffers + + sent_buffers = _pad_gather(sentences) + label_buffers = _pad_gather(labels) + + # Strip padding; keep local shard differentiable, detach others. + all_sentences = [] + all_labels = [] for idx in range(world_size): - if idx != rank: - all_sentences[idx] = all_sentences[idx].detach() + n = sizes_int[idx] + if idx == rank: + all_sentences.append(sentences) + all_labels.append(labels) + else: + all_sentences.append(sent_buffers[idx][:n].detach()) + all_labels.append(label_buffers[idx][:n]) return torch.cat(all_sentences, dim=0), torch.cat(all_labels, dim=0) def __call__(self, inputs, outputs, **kwargs) -> LossOutput: diff --git a/src/twinkle/metric/accuracy.py b/src/twinkle/metric/accuracy.py index b3034c57a..4dfb01198 100644 --- a/src/twinkle/metric/accuracy.py +++ b/src/twinkle/metric/accuracy.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import numpy as np from typing import List, Union from ..data_format import InputFeature, ModelOutput diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index b203d255e..024cb0473 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -131,6 +131,12 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M ref_outputs = kwargs.get('ref_outputs') if ref_outputs is not None: ref_logps = ref_outputs.get('logps') + if ref_logps is not None: + if isinstance(ref_logps, list): + if len(ref_logps) == 0: + ref_logps = None + else: + ref_logps = pad_and_stack_tensors(ref_logps) if ref_logps is not None: # Align ref_logps to match labels shape (handles different seq lengths) ref_logps = self._align_logps(ref_logps, labels.shape, labels.device, logps.dtype) diff --git a/src/twinkle/metric/embedding.py b/src/twinkle/metric/embedding.py index 9fb3aed8c..8b3681031 100644 --- a/src/twinkle/metric/embedding.py +++ b/src/twinkle/metric/embedding.py @@ -1,7 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import torch -import torch.distributed as dist -import torch.nn.functional as F from typing import List, Union from twinkle.data_format import InputFeature, ModelOutput @@ -32,6 +29,9 @@ def reset(self): self.grad_norm = 0.0 def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs): + import torch + import torch.distributed as dist + import torch.nn.functional as F sentences = outputs.get('embeddings') if sentences is None: sentences = outputs.get('logits') @@ -44,22 +44,34 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M inputs = [inputs] labels = torch.cat([inp['labels'].view(-1) for inp in inputs], dim=0) - # Gather embeddings and labels across DP for in-batch stats + # Gather embeddings and labels across DP for in-batch stats. + # NCCL ``all_gather`` requires every rank to send the same tensor size, + # but ``slice_dp`` dispatch (``divmod`` split) can leave per-rank dim-0 + # uneven. Pad to the global max along dim-0, gather, then strip padding. if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: world_size = dist.get_world_size() - local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long) - shapes = [torch.empty_like(local_shape) for _ in range(world_size)] - dist.all_gather(shapes, local_shape) - all_sentences = [sentences.new_empty(s.tolist()) for s in shapes] - dist.all_gather(all_sentences, sentences.contiguous()) - sentences = torch.cat(all_sentences, dim=0) - - local_lshape = labels.new_tensor(labels.shape, dtype=torch.long) - lshapes = [torch.empty_like(local_lshape) for _ in range(world_size)] - dist.all_gather(lshapes, local_lshape) - all_labels = [labels.new_empty(s.tolist()) for s in lshapes] - dist.all_gather(all_labels, labels.contiguous()) - labels = torch.cat(all_labels, dim=0) + assert sentences.shape[0] == labels.shape[0], ( + f'sentences/labels dim-0 mismatch: {sentences.shape[0]} vs {labels.shape[0]}') + local_n = torch.tensor([sentences.shape[0]], device=sentences.device, dtype=torch.long) + sizes = [torch.empty_like(local_n) for _ in range(world_size)] + dist.all_gather(sizes, local_n) + sizes_int = [int(s.item()) for s in sizes] + max_n = max(sizes_int) + + def _pad_gather(tensor: 'torch.Tensor') -> 'List[torch.Tensor]': + if tensor.shape[0] < max_n: + pad_shape = (max_n - tensor.shape[0],) + tuple(tensor.shape[1:]) + padded = torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=0) + else: + padded = tensor + buffers = [torch.empty_like(padded) for _ in range(world_size)] + dist.all_gather(buffers, padded.contiguous()) + return buffers + + sent_buffers = _pad_gather(sentences) + label_buffers = _pad_gather(labels) + sentences = torch.cat([sent_buffers[i][:sizes_int[i]] for i in range(world_size)], dim=0) + labels = torch.cat([label_buffers[i][:sizes_int[i]] for i in range(world_size)], dim=0) anchor_idx = torch.nonzero(labels, as_tuple=False).squeeze(-1) if anchor_idx.numel() == 0: diff --git a/src/twinkle/metric/grpo.py b/src/twinkle/metric/grpo.py index 06e082eeb..e2797b1ec 100644 --- a/src/twinkle/metric/grpo.py +++ b/src/twinkle/metric/grpo.py @@ -3,9 +3,12 @@ from typing import Any, Dict, List, Optional, Union from twinkle.data_format import InputFeature, ModelOutput +from twinkle.utils import get_logger from twinkle.utils.transformers_utils import align_logps_to_mask from .base import Metric +logger = get_logger() + class GRPOMetric(Metric): @@ -254,6 +257,11 @@ def accumulate( if len(seq_lens) == 1: merged = torch.cat(label_tensors, dim=0) inputs_list = [{'labels': merged}] + else: + logger.warning( + f'GRPOMetric: logps is a single tensor but inputs_list has ' + f'{len(inputs_list)} mb with mismatched seq_lens={sorted(seq_lens)}. ' + f'Only mb[0] will be accumulated; check the model forward path.') flat_old: Optional[List] = None if old_logps is not None and isinstance(old_logps, (list, tuple)): @@ -284,7 +292,17 @@ def accumulate( # Uncommon: aligned global tensor. Only honour when it # exactly matches the single-mb shape; otherwise drop. import torch as _torch # noqa: F811 - old_slice = old_logps if (_torch.is_tensor(old_logps) and old_logps.shape == logps_mb.shape) else None + if _torch.is_tensor(old_logps) and old_logps.shape == logps_mb.shape: + old_slice = old_logps + else: + if mb_idx == 0: + # Warn once per accumulate call (not per mb) to avoid log spam. + old_shape = tuple(old_logps.shape) if _torch.is_tensor(old_logps) else 'unknown' + logger.warning( + f'GRPOMetric: old_logps shape {old_shape} does not match ' + f'logps_mb shape {tuple(logps_mb.shape)}; ratio/kl metrics will ' + f'be skipped for this step.') + old_slice = None else: old_slice = None diff --git a/src/twinkle/metric/train_metric.py b/src/twinkle/metric/train_metric.py index da82a8783..8d785c38b 100644 --- a/src/twinkle/metric/train_metric.py +++ b/src/twinkle/metric/train_metric.py @@ -2,7 +2,7 @@ import time from typing import List, Union -from ..data_format import InputFeature, ModelOutput +from twinkle.data_format import InputFeature, ModelOutput from .base import Metric diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index a4d4ea064..8ea00d696 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -1,8 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from abc import ABC, abstractmethod -from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union from twinkle import Platform, torch_util from twinkle.data_format import InputFeature, ModelOutput diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index a5ea3fc56..60b45f774 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -420,7 +420,7 @@ def forward_step_func(data_iterator, model): embeddings = output_tensor elif labels is not None and is_last_pp: _loss_require_logps = getattr(_loss_instance, 'require_logps', True) - _loss_require_entropy = (hasattr(_loss_instance, 'require_entropy') and _loss_instance.require_entropy) + _loss_require_entropy = getattr(_loss_instance, 'require_entropy', True) _packed = batch.get('packed_seq_params') cu_seqlens_q = getattr(_packed, 'cu_seqlens_q', None) if _packed is not None else None if _loss_require_logps: @@ -446,7 +446,7 @@ def forward_step_func(data_iterator, model): _outputs = {'logps': logps} if entropies is not None: _outputs['entropies'] = entropies - if hasattr(_loss_instance, 'require_logits') and _loss_instance.require_logits: + if getattr(_loss_instance, 'require_logits', False): _outputs['logits'] = output_tensor batch, _outputs = processor.unpack_packed_sequences(batch, _outputs) logps = _outputs['logps'] @@ -990,7 +990,9 @@ def _get_rng_state() -> 'ShardedObject': 'random_rng_state': random.getstate(), 'np_rng_state': np.random.get_state(), 'torch_rng_state': torch.get_rng_state(), - 'cuda_rng_state': torch.cuda.get_rng_state(), + # Backend-agnostic device RNG (CUDA / NPU / MPS); key kept as + # 'cuda_rng_state' for backward compatibility with existing checkpoints. + 'cuda_rng_state': Platform.get_device_rng_state(), 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states(), } rng_state_list = [rng_state] @@ -1112,7 +1114,7 @@ def _save_mcore_optimizer( with open(tracker_path, 'w') as f: f.write(str(iteration)) - logging.getLogger(__name__).info(f'Saved mcore optimizer state at iteration {iteration} ' + logger.info(f'Saved mcore optimizer state at iteration {iteration} ' f'to {checkpoint_dir}') def _load_mcore_optimizer( @@ -1139,7 +1141,7 @@ def _load_mcore_optimizer( ) iteration = self._read_iteration(tracker_path) if iteration == 0: - logging.getLogger(__name__).warning(f'No checkpoint found in {checkpoint_dir}') + logger.warning(f'No checkpoint found in {checkpoint_dir}') return iter_dir = os.path.join(checkpoint_dir, f'iter_{iteration:07d}') @@ -1201,7 +1203,9 @@ def _load_mcore_optimizer( random.setstate(rng['random_rng_state']) np.random.set_state(rng['np_rng_state']) torch.set_rng_state(rng['torch_rng_state']) - torch.cuda.set_rng_state(rng['cuda_rng_state']) + # Backend-agnostic restore: tolerates ckpt produced on different backend + # (returns None) and avoids hard-coded torch.cuda which crashes on NPU. + Platform.set_device_rng_state(rng.get('cuda_rng_state')) tensor_parallel.get_cuda_rng_tracker().set_states(rng['rng_tracker_states'], ) # Restore iteration counter. @@ -1211,26 +1215,26 @@ def _load_mcore_optimizer( if dist.is_initialized(): dist.barrier() - logging.getLogger(__name__).info(f'Resumed from mcore checkpoint at iteration {iteration} ' + logger.info(f'Resumed from mcore checkpoint at iteration {iteration} ' f'from {checkpoint_dir}') @staticmethod def _read_iteration(tracker_path: str) -> int: - if not os.path.exists(tracker_path): - return 0 - with open(tracker_path) as f: - iteration = int(f.read().strip()) + # All ranks must enter the all_reduce together; missing tracker on some + # ranks (e.g. NFS lag, partial mount) must NOT short-circuit, otherwise + # the remaining ranks hang at the collective. Treat missing as 0 and + # let MAX reduction recover the canonical iteration from any rank that + # successfully read the file. + iteration = 0 + if os.path.exists(tracker_path): + with open(tracker_path) as f: + iteration = int(f.read().strip()) if torch.distributed.is_initialized(): - iters_cuda = torch.tensor( - [iteration], - dtype=torch.long, - device='cuda', - ) - torch.distributed.all_reduce( - iters_cuda, - op=torch.distributed.ReduceOp.MAX, - ) - iteration = iters_cuda[0].item() + # Use Platform.get_local_device() to stay backend-agnostic + # (CUDA / NPU / MPS); 'cuda' would crash on NPU. + iters_dev = torch.tensor([iteration], dtype=torch.long, device=Platform.get_local_device()) + torch.distributed.all_reduce(iters_dev, op=torch.distributed.ReduceOp.MAX) + iteration = int(iters_dev[0].item()) return iteration def _merge_lora_adapters(self, adapter_name: str = 'default'): @@ -1256,7 +1260,7 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non For distributed training: - All PP ranks participate in export (each has different layers) - - Only DP rank 0 actually writes to disk + - Only global rank 0 actually writes shared config files - Uses barrier for synchronization For LoRA training: @@ -1264,12 +1268,9 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non """ # Check if this is LoRA training is_peft_format = (adapter_name != _default_adapter_name) + is_global_zero = (not dist.is_initialized()) or dist.get_rank() == 0 - # Create output directory on rank 0 only - from megatron.core import parallel_state as mpu - dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0 - - if dp_rank == 0: + if is_global_zero: os.makedirs(output_dir, exist_ok=True) # Synchronize before saving @@ -1281,8 +1282,8 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non self.strategy.bridge.save_weights( model, output_dir, peft_format=is_peft_format, adapter_name=adapter_name, converter=lora_converter) - # Save config on rank 0 only - if dp_rank == 0: + # Save config on global rank 0 only (avoid concurrent writers). + if is_global_zero: self.hf_config.save_pretrained(output_dir) if isinstance(model[0], PeftModel): config = model[0].peft_config[adapter_name] @@ -1291,11 +1292,13 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non model[0].peft_config[adapter_name].save_pretrained(output_dir) config.target_modules = target_modules + if dist.is_initialized(): + dist.barrier() + def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_converter=None): """Save in Megatron checkpoint format.""" + is_global_zero = (not dist.is_initialized()) or dist.get_rank() == 0 os.makedirs(output_dir, exist_ok=True) - from megatron.core import parallel_state as mpu - dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0 state_dict = self._get_trainable_parameters(adapter_name) cpu_state_dict = {} for k, v in state_dict.items(): @@ -1311,13 +1314,18 @@ def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_convert rank = dist.get_rank() if dist.is_initialized() else 0 checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt') torch.save(cpu_state_dict, checkpoint_path) - # Save config on rank 0 only + # Save shared config on global rank 0 only (avoid concurrent writers). model = self.strategy.unwrap_model(self.model) - if dp_rank == 0: + if is_global_zero: self.hf_config.save_pretrained(output_dir) if isinstance(model[0], PeftModel): model[0].peft_config[adapter_name].save_pretrained(output_dir) + # Finalize barrier: ensure all ranks finish writing model_rank*.pt + # before the caller proceeds (e.g. uploading / loading the ckpt). + if dist.is_initialized(): + dist.barrier() + def _save_tokenizer(self, output_dir: str, **kwargs): from twinkle.utils import is_last_rank if not is_last_rank(): diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 2dd6b7a53..78981b888 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -15,7 +15,7 @@ from transformers import AutoConfig, PretrainedConfig from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union -from twinkle import DeviceMesh, remote_class, remote_function, requires, template, torch_util +from twinkle import DeviceMesh, Platform, remote_class, remote_function, requires, template, torch_util from twinkle.data_format import InputFeature, Trajectory from twinkle.hub import HubOperation from twinkle.infra import collect_tensor_dict @@ -26,6 +26,9 @@ from ._mindspeed_runtime import ensure_mindspeed_adaptor_patched from .megatron import MegatronModel from .strategy import MegatronStrategy +from twinkle.utils import get_logger + +logger = get_logger() @remote_class(execute='all') @@ -221,8 +224,11 @@ def _save_local_training_rng_state(): 'np_rng_state': np.random.get_state(), 'torch_rng_state': torch.get_rng_state(), } - if torch.cuda.is_available(): - rng_state['cuda_rng_state'] = torch.cuda.get_rng_state() + # Backend-agnostic device RNG capture (CUDA / NPU / MPS). Key is kept as + # 'cuda_rng_state' for backward compatibility with existing checkpoints. + device_rng = Platform.get_device_rng_state() + if device_rng is not None: + rng_state['cuda_rng_state'] = device_rng rng_state['rng_tracker_states'] = tensor_parallel.get_cuda_rng_tracker().get_states() return rng_state @@ -233,8 +239,10 @@ def _load_local_training_rng_state(rng_state): random.setstate(rng_state['random_rng_state']) np.random.set_state(rng_state['np_rng_state']) torch.set_rng_state(rng_state['torch_rng_state']) - if 'cuda_rng_state' in rng_state and torch.cuda.is_available(): - torch.cuda.set_rng_state(rng_state['cuda_rng_state']) + # Backend-agnostic device RNG restore: tolerates ckpt produced on different + # backend (key absent or None) and avoids hard-coded torch.cuda on NPU. + if 'cuda_rng_state' in rng_state: + Platform.set_device_rng_state(rng_state['cuda_rng_state']) tensor_parallel.get_cuda_rng_tracker().set_states(rng_state['rng_tracker_states']) def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kwargs): @@ -251,19 +259,35 @@ def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kw torch.save(state_dict, self._rank_local_optimizer_path(checkpoint_dir)) + if dist.is_initialized(): + dist.barrier() + def _load_multi_lora_optimizer(self, checkpoint_dir: str, adapter_name: str = '', **kwargs): no_load_optim = kwargs.pop('no_load_optim', False) - no_load_rng = kwargs.pop('no_load_rng', False) + no_load_rng = kwargs.pop('no_load_rng', True) optimizer_config = self.optimizer_group.get(adapter_name) state_dict = torch.load(self._rank_local_optimizer_path(checkpoint_dir), map_location='cpu', weights_only=False) if not no_load_optim and optimizer_config is not None: if optimizer_config.optimizer is not None and 'optimizer' in state_dict: optimizer_config.optimizer.load_state_dict(state_dict['optimizer']) + device = Platform.get_local_device() + for group_state in optimizer_config.optimizer.state.values(): + if not isinstance(group_state, dict): + continue + for k, v in group_state.items(): + if isinstance(v, torch.Tensor): + group_state[k] = v.to(device) if optimizer_config.lr_scheduler is not None and 'opt_param_scheduler' in state_dict: optimizer_config.lr_scheduler.load_state_dict(state_dict['opt_param_scheduler']) + # RNG state is intentionally not restored in multi-tenant mode: + # restoring the global RNG would silently affect other active tenants' + # dropout / initialization behaviour. if not no_load_rng and 'rng_state' in state_dict: - self._load_local_training_rng_state(state_dict['rng_state']) + logger.warning( + 'Skipping RNG state restoration in multi-tenant mode. ' + 'Global RNG is shared across tenants; restoring it would ' + 'affect other active adapters.') if optimizer_config is not None and 'iteration' in state_dict: optimizer_config.cur_step = state_dict['iteration'] @@ -354,6 +378,11 @@ def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **k self._check_adapter_valid(adapter_name) trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + if not os.path.isfile(trainer_state_path): + raise FileNotFoundError( + f'trainer_state.json not found in {checkpoint_dir}. ' + f'Ensure the checkpoint was saved with save_optimizer=True.') + with open(trainer_state_path) as f: trainer_state = json.load(f) diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index 819014eb8..1bd809025 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -48,7 +48,6 @@ def __init__( ddp_config: Dict[str, Any] = None, **kwargs, ): - import torch.distributed as dist from megatron.core import mpu self.device_mesh = device_mesh self.use_distributed_optimizer = use_distributed_optimizer diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index d0434991f..018f4c494 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -34,10 +34,8 @@ def __init__( parallelism_config = self._parallelism_config_from_device_mesh(device_mesh) fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient_init) - kwargs_handlers = [] - kwargs_handlers.append( - InitProcessGroupKwargs( - timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))))) + kwargs_handlers = [InitProcessGroupKwargs( + timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))))] if ddp_config is not None: from accelerate import DistributedDataParallelKwargs ddp_config = DistributedDataParallelKwargs(**ddp_config) @@ -131,8 +129,7 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di return fsdp_plugin def wrap_model(self, model, *args): - result = self.accelerator.prepare(model, *args) - return result + return self.accelerator.prepare(model, *args) def unwrap_model(self, model): return self.accelerator.unwrap_model(model, keep_torch_compile=False) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 61733d7dc..cac375263 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -414,8 +414,8 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec inputs = optimizer_config.template.batch_encode(inputs) # noqa processor: InputProcessor = optimizer_config.processor loss_instance = optimizer_config.loss_instance - loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits) - loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy) + loss_require_logits = getattr(loss_instance, 'require_logits', False) + loss_require_entropy = getattr(loss_instance, 'require_entropy', False) loss_require_logps = getattr(loss_instance, 'require_logps', True) assert isinstance(processor, InputProcessor), 'Set a correct `InputProcessor` before forwarding' inputs: Dict[str, Any] = processor( @@ -490,8 +490,8 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T processor: InputProcessor = optimizer_config.processor assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding' loss_instance = optimizer_config.loss_instance - loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits) - loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy) + loss_require_logits = getattr(loss_instance, 'require_logits', False) + loss_require_entropy = getattr(loss_instance, 'require_entropy', False) loss_require_logps = getattr(loss_instance, 'require_logps', True) inputs: Dict[str, Any] = processor( inputs, @@ -929,7 +929,6 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int if optimizer_config.cur_step % interval != 0: return model = self.strategy.unwrap_model(self.model) - processed_state_dict = {} save_kwargs = {} if adapter_name == _default_adapter_name: # Full model save diff --git a/src/twinkle/notifier/__init__.py b/src/twinkle/notifier/__init__.py index 329cb6f1d..067db71a1 100644 --- a/src/twinkle/notifier/__init__.py +++ b/src/twinkle/notifier/__init__.py @@ -1,2 +1,3 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. from .base import Notifier, notify_exception from .ding_notifier import DingNotifier diff --git a/src/twinkle/notifier/base.py b/src/twinkle/notifier/base.py index a83903b53..6f50ca659 100644 --- a/src/twinkle/notifier/base.py +++ b/src/twinkle/notifier/base.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import os from typing import Dict, Optional @@ -66,7 +67,7 @@ def notify_exception(notifier: Notifier, context: str, exc: BaseException, name: if not _try_claim_notify_slot(exc, context, name): try: setattr(exc, '_twinkle_notified', True) - except Exception: # noqa: BLE001 + except Exception: # noqa pass return diff --git a/src/twinkle/notifier/ding_notifier.py b/src/twinkle/notifier/ding_notifier.py index fe102d8a5..fc535edd7 100644 --- a/src/twinkle/notifier/ding_notifier.py +++ b/src/twinkle/notifier/ding_notifier.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import base64 import hashlib import hmac diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py index 97065fba1..39d3257be 100644 --- a/src/twinkle/preprocessor/llm.py +++ b/src/twinkle/preprocessor/llm.py @@ -48,9 +48,9 @@ def preprocess(self, row) -> Trajectory: class SelfCognitionProcessor(Preprocessor): - def __init__(self, model_name, model_author): - self.model_name = model_name - self.model_author = model_author + def __init__(self, model_name=None, model_author=None): + self.model_name = model_name or 'twinkle robot' + self.model_author = model_author or 'twinkle lab' def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 8709d98ab..14a0e206b 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -42,7 +42,6 @@ class InputProcessor: 'video_grid_thw': 0, 'input_features': 0.0, 'feature_attention_mask': 0, - 'mm_token_type_ids': 0, } # VLM fields to concatenate (not pad) in batch @@ -108,8 +107,12 @@ def to_tensor(_input): # so tensor ops like labels != ignore_index or .to(device) would fail without this. if isinstance(value, np.ndarray): value = torch.from_numpy(value) - elif (isinstance(value, list) and isinstance(value[0], - (int, float, np.number))) or key == 'position_ids': + elif isinstance(value, list) and len(value) > 0 and isinstance( + value[0], (int, float, np.number)): + value = torch.tensor(value) + elif key == 'position_ids' and not isinstance(value, torch.Tensor): + if value is None: + continue value = torch.tensor(value) elif (isinstance(value, list)) and key in ('completion_mask', 'mm_token_type_ids'): value = torch.tensor(value) @@ -284,7 +287,9 @@ def pad_cp_inputs(input_tensor: torch.Tensor, padding_value: int) -> torch.Tenso return input_tensor if cp_size > 1: - position_ids_f = position_ids.flatten() + pos_for_cu = position_ids[:1] if position_ids.dim() >= 2 and position_ids.shape[0] > 1 \ + else position_ids + position_ids_f = pos_for_cu.flatten() indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32) cu_seqlens = torch.cat([ indices_q[position_ids_f == 0], @@ -354,8 +359,11 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], di view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:]) val = val.view(view_shape) - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu', - pin_memory=True).cuda(non_blocking=True) + index = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], + device=inputs.device, + dtype=torch.long, + ) val = val.index_select(dim, index) view_shape = (*inputs.shape[:dim], -1, *inputs.shape[dim + 1:]) new_inputs.append(val.view(view_shape)) @@ -402,17 +410,18 @@ def prepare_transformers_padding_free_patch(self, inputs: List[InputFeature], ** if not padding_free or bool(kwargs.get('enable_sp', False)): return inputs - from twinkle.patch import apply_patch - from twinkle.patch.gdn_padding_free import GatedDeltaNetPaddingFreePatch - - apply_patch( - model, - GatedDeltaNetPaddingFreePatch, - hf_config=kwargs.get('hf_config'), - enable_sp=False, - ) if not getattr(model, '_twinkle_gdn_padding_free_patched', False): - return inputs + from twinkle.patch import apply_patch + from twinkle.patch.gdn_padding_free import GatedDeltaNetPaddingFreePatch + + apply_patch( + model, + GatedDeltaNetPaddingFreePatch, + hf_config=kwargs.get('hf_config'), + enable_sp=False, + ) + if not getattr(model, '_twinkle_gdn_padding_free_patched', False): + return inputs for _inp in inputs: position_ids = _inp.get('position_ids') @@ -631,15 +640,27 @@ def to_transformers_dict(inputs: List[InputFeature], **kwargs) -> List[InputFeat output = {} _keys = [ 'input_ids', - 'input_embeddings', + 'inputs_embeds', 'attention_mask', 'position_ids', 'labels', 'completion_mask', + 'cu_seq_lens_q', + 'cu_seq_lens_k', + 'cu_seqlens_q', + 'cu_seqlens_kv', + 'max_length_q', + 'max_length_k', + 'packed_seq_params', ] + list(InputProcessor.VLM_CONCAT_FIELDS) for key in list(_input.keys()): - if key in _keys: - output[key] = np.array(_input[key]) if not isinstance(_input[key], torch.Tensor) else _input[key] + if key not in _keys: + continue + value = _input[key] + if isinstance(value, torch.Tensor) or not isinstance(value, (list, np.ndarray)): + output[key] = value + else: + output[key] = np.array(value) results.append(InputFeature(**output)) return results @@ -694,7 +715,8 @@ def is_mm_position_ids(position_ids): result[key] = self._create_4d_attention_mask(values) elif key == 'position_ids' and is_mm_position_ids(values[0]): result[key] = InputProcessor._pad_sequence(values, self.padding_map[key], self.padding_side) - result[key] = result[key].reshape(values[0].shape[0], len(values), -1) + num_axes = values[0].shape[0] + result[key] = result[key].reshape(len(values), num_axes, -1).permute(1, 0, 2).contiguous() elif isinstance(values[0], torch.Tensor): result[key] = InputProcessor._pad_sequence(values, self.padding_map[key], self.padding_side) if result[key].dim() == 1: @@ -776,6 +798,5 @@ def postprocess_tensor_cp(self, tensor, cu_seqlens=None): if self.device_mesh.cp_world_size <= 1: return tensor from megatron.core import parallel_state as mpu - from twinkle.utils.torch_utils import gather_cp_load_balanced return gather_cp_load_balanced(tensor, mpu.get_context_parallel_group(), seq_dim=1, cu_seqlens=cu_seqlens) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 32dc1ca50..0433ef5a8 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -1,24 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""vLLM-based sampler using VLLMEngine (AsyncLLM). - -Device Configuration: - vLLMSampler automatically detects the number of available GPUs from - CUDA_VISIBLE_DEVICES environment variable (set by twinkle's ResourceManager) - and configures vLLM's tensor_parallel_size accordingly. - - To use tensor parallelism, configure DeviceGroup with gpus_per_worker > 1: - - # DP2 with TP2 (4 GPUs total, 2 workers, each with 2 GPUs) - DeviceGroup(name='sampler', ranks=[0,1,2,3], gpus_per_worker=2) - - # TP4 (4 GPUs, 1 worker with all 4 GPUs) - DeviceGroup(name='sampler', ranks=[0,1,2,3], gpus_per_worker=4) - -Data Flow: - When multiple vLLMSampler workers exist (DP > 1): - - Data is dispatched via dispatch='slice_dp' (each worker gets a slice) - - Results are collected via collect='flatten' (merged into single list) -""" import asyncio import atexit import numpy as np diff --git a/src/twinkle/server/state/backend/factory.py b/src/twinkle/server/state/backend/factory.py index 326a6916f..be24c7824 100644 --- a/src/twinkle/server/state/backend/factory.py +++ b/src/twinkle/server/state/backend/factory.py @@ -1,12 +1,11 @@ """Backend factory for creating StateBackend instances based on configuration.""" from __future__ import annotations -import logging - from twinkle.server.config.persistence import PersistenceConfig +from twinkle.utils import get_logger from .base import StateBackend -logger = logging.getLogger(__name__) +logger = get_logger() def create_backend(config: PersistenceConfig | None = None) -> StateBackend: diff --git a/src/twinkle/server/state/base.py b/src/twinkle/server/state/base.py index d931ce336..8cb055ae4 100644 --- a/src/twinkle/server/state/base.py +++ b/src/twinkle/server/state/base.py @@ -1,7 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations -import logging import time from abc import ABC, abstractmethod from datetime import datetime, timezone @@ -9,9 +8,10 @@ from typing import Generic, TypeVar from twinkle.server.state.backend.base import StateBackend +from twinkle.utils import get_logger T = TypeVar('T', bound=BaseModel) -logger = logging.getLogger(__name__) +logger = get_logger() class BaseManager(ABC, Generic[T]): diff --git a/src/twinkle/server/state/session_manager.py b/src/twinkle/server/state/session_manager.py index 442019ea4..1bc901b0c 100644 --- a/src/twinkle/server/state/session_manager.py +++ b/src/twinkle/server/state/session_manager.py @@ -2,14 +2,14 @@ from __future__ import annotations import functools -import logging import time +from twinkle.utils import get_logger from .backend.base import ConcurrencyError, StateBackend from .base import BaseManager from .models import SessionRecord -logger = logging.getLogger(__name__) +logger = get_logger() def _session_touch_transform(existing: dict | None, *, now: float) -> dict | None: diff --git a/src/twinkle/server/telemetry/provider.py b/src/twinkle/server/telemetry/provider.py index 77212c757..059f301fb 100644 --- a/src/twinkle/server/telemetry/provider.py +++ b/src/twinkle/server/telemetry/provider.py @@ -16,8 +16,9 @@ from typing import Any from twinkle.server.config.telemetry import TelemetryConfig +from twinkle.utils import get_logger -logger = logging.getLogger(__name__) +logger = get_logger() # Loggers belonging to the OTLP transport stack. Their own records must never # be routed back through the OTLP LoggingHandler: an exporter error logged diff --git a/src/twinkle/server/telemetry/worker_init.py b/src/twinkle/server/telemetry/worker_init.py index 997f2e140..40edc6628 100644 --- a/src/twinkle/server/telemetry/worker_init.py +++ b/src/twinkle/server/telemetry/worker_init.py @@ -7,10 +7,11 @@ """ from __future__ import annotations -import logging import os -logger = logging.getLogger(__name__) +from twinkle.utils import get_logger + +logger = get_logger() _worker_initialized = False diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py index 6c4bdddd2..168456eab 100644 --- a/src/twinkle/template/__init__.py +++ b/src/twinkle/template/__init__.py @@ -2,3 +2,4 @@ from .base import Template from .deepseek_v4 import DeepseekV4Template from .qwen3_5_vl import Qwen3_5Template +from .tools import ToolCallParser, ToolCallRegistry, ClineParser, HermesQwenParser, ReActParser, VCPParser diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 26c2e4f26..c1e8f069f 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -32,6 +32,19 @@ class Template: video_placeholder: str = '