diff --git a/INSTALL_MEGATRON.sh b/INSTALL_MEGATRON.sh
index 276598c50..951b0171d 100644
--- a/INSTALL_MEGATRON.sh
+++ b/INSTALL_MEGATRON.sh
@@ -5,6 +5,7 @@
set -e # Exit immediately on error
export SETUPTOOLS_USE_DISTUTILS=local
+export UV_INDEX_URL=${UV_INDEX_URL:-https://mirrors.aliyun.com/pypi/simple/}
echo "=========================================="
echo "Starting deep learning dependencies installation..."
echo "=========================================="
@@ -53,15 +54,15 @@ TORCH_CUDA_ARCH_LIST=$(get_cuda_arch "$GPU_NAME")
export TORCH_CUDA_ARCH_LIST
echo "Using CUDA architecture: $TORCH_CUDA_ARCH_LIST"
-# Install latest base packages
+# Install vllm 0.21.x (latest 0.2x uses CUDA 12 toolchain, avoids CUDA 13 CUTLASS conflicts)
echo ""
-echo "Installing peft, accelerate, transformers, modelscope..."
-pip install --upgrade peft accelerate transformers "modelscope[framework]" --no-cache-dir
+echo "Installing vllm 0.21..."
+uv pip install "vllm>=0.21,<0.22"
-# Install latest vllm
+# Install latest base packages
echo ""
-echo "Installing latest vllm..."
-pip install --upgrade vllm --no-cache-dir
+echo "Installing peft, accelerate, transformers, modelscope..."
+uv pip install --upgrade peft accelerate transformers "modelscope[framework]"
# Get site-packages path and install transformer_engine and megatron_core
echo ""
@@ -69,26 +70,30 @@ echo "Installing transformer_engine and megatron_core..."
SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])")
echo "Site-packages path: $SITE_PACKAGES"
-CUDNN_PATH=$SITE_PACKAGES/nvidia/cudnn \
-CPLUS_INCLUDE_PATH=$SITE_PACKAGES/nvidia/cudnn/include \
-pip install --no-build-isolation "transformer_engine[pytorch]" --no-cache-dir
+export CUDA_HOME="${SITE_PACKAGES}/nvidia/cu13"
+export PATH="$CUDA_HOME/bin:$PATH"
+export CPATH="$CUDA_HOME/include${CPATH:+:$CPATH}"  # no trailing ':' when unset (an empty entry means cwd)
+export LIBRARY_PATH="$CUDA_HOME/lib${LIBRARY_PATH:+:$LIBRARY_PATH}"
+export LD_LIBRARY_PATH="$CUDA_HOME/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
+uv pip install transformer_engine_torch --no-build-isolation
-pip install megatron_core mcore_bridge --no-cache-dir
+uv pip install megatron_core mcore_bridge
-# Install flash-attention (force local build)
+# Install flash-attention
+# Prefer prebuilt wheel; fall back to source build only if needed.
echo ""
-echo "Installing flash-attention (local build for $GPU_NAME)..."
-TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST" \
-MAX_JOBS=8 \
-FLASH_ATTENTION_FORCE_BUILD=TRUE \
-pip install flash-attn --no-build-isolation --no-cache-dir
+echo "Installing flash-attention..."
+export TORCH_CUDA_ARCH_LIST
+export MAX_JOBS=8
+pip install flash-attn --no-build-isolation --no-cache-dir || \
+    FLASH_ATTENTION_FORCE_BUILD=TRUE pip install flash-attn --no-build-isolation --no-cache-dir
-pip install flash-linear-attention -U --no-cache-dir
+uv pip install flash-linear-attention --upgrade
# Install numpy
echo ""
echo "Installing numpy==2.2 and deep_gemm..."
-pip install numpy==2.2 --no-cache-dir
+uv pip install numpy==2.2
# Verify installation
echo ""
diff --git a/README.md b/README.md
index 4dd203cfb..4867358a3 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@
- English Documentation   |   中文文档   |   Twinkle Web  
+ English Documentation   |   中文文档   |   Twinkle Web  
## ✨ What is Twinkle?
@@ -100,11 +100,11 @@ sh INSTALL_MEGATRON.sh
| DPO multi-LoRA training | transformers | [Script](cookbook/rl/dpo_multi_lora.py) |
| GKD on-policy distillation | megatron | [Script](cookbook/rl/gkd_on_policy.py) |
| GKD off-policy distillation | megatron | [Script](cookbook/rl/gkd_off_policy.py) |
-| Tinker client finetuning (self-host) | transformers | [Script](cookbook/client/tinker/self_host) |
-| Tinker client finetuning (ModelScope) | transformers | [Script](cookbook/client/tinker/modelscope) |
-| Twinkle client finetuning (self-host) | transformers | [Script](cookbook/client/twinkle/self_host) |
-| Twinkle client finetuning (ModelScope) | transformers | [Script](cookbook/client/twinkle/modelscope) |
-| Server startup scripts | transformers/megatron | [Script](cookbook/client/server) |
+| Tinker client finetuning (self-host) | transformers | [Script](cookbook/server_mode/tinker/self_host) |
+| Tinker client finetuning (ModelScope) | transformers | [Script](cookbook/server_mode/tinker/modelscope) |
+| Twinkle client finetuning (self-host) | transformers | [Script](cookbook/server_mode/twinkle/self_host) |
+| Twinkle client finetuning (ModelScope) | transformers | [Script](cookbook/server_mode/twinkle/modelscope) |
+| Server startup scripts | transformers/megatron | [Script](cookbook/server_mode/server) |
## Changelog
- 🎉2026-05-20 Support DeepSeek-V4-Flash and DeepSeek-V4-Pro models.
@@ -122,7 +122,7 @@ sh INSTALL_MEGATRON.sh
We are rolling out training service built atop Twinkle✨ on ModelScope. You may
train via API endpoint `base_url=https://www.modelscope.cn/twinkle`. For more details, please refer to
-our [documentation](docs/source_en/Usage%20Guide/Train-as-a-Service.md).
+our [documentation](https://modelscope.github.io/twinkle-web/docs/usage-guide/train-as-a-service/).
## Supported Hardware
@@ -177,7 +177,7 @@ supported on Twinkle✨ framework.
## Sample Code
Below are some of the capabilities demonstrated in the example code. For a complete introduction to training capabilities,
-please refer to [Quick Start](docs/source_en/Usage%20Guide/Quick-Start.md) and [cookbook](cookbook).
+please refer to [Quick Start](https://modelscope.github.io/twinkle-web/docs/usage-guide/quick-start/) and [cookbook](cookbook).
### Train with Ray
diff --git a/README_ZH.md b/README_ZH.md
index 5d588b393..92923209c 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -19,7 +19,7 @@ by ModelScope &
- 英文文档   |   中文文档   |   Twinkle 站点  
+ 英文文档   |   中文文档   |   Twinkle 站点  
## ✨ Twinkle 是什么?
@@ -94,13 +94,13 @@ sh INSTALL_MEGATRON.sh
| DPO 多 LoRA 训练 | transformers | [脚本](cookbook/rl/dpo_multi_lora.py) |
| GKD 在线蒸馏 | megatron | [脚本](cookbook/rl/gkd_on_policy.py) |
| GKD 离线蒸馏 | megatron | [脚本](cookbook/rl/gkd_off_policy.py) |
-| Tinker 客户端微调(自部署) | transformers | [脚本](cookbook/client/tinker/self_host) |
-| Tinker 客户端微调(ModelScope) | transformers | [脚本](cookbook/client/tinker/modelscope) |
-| Twinkle 客户端微调(自部署) | transformers | [脚本](cookbook/client/twinkle/self_host) |
-| Twinkle 客户端微调(ModelScope) | transformers | [脚本](cookbook/client/twinkle/modelscope) |
-| 服务端启动脚本 | transformers/megatron | [脚本](cookbook/client/server) |
+| Tinker 客户端微调(自部署) | transformers | [脚本](cookbook/server_mode/tinker/self_host) |
+| Tinker 客户端微调(ModelScope) | transformers | [脚本](cookbook/server_mode/tinker/modelscope) |
+| Twinkle 客户端微调(自部署) | transformers | [脚本](cookbook/server_mode/twinkle/self_host) |
+| Twinkle 客户端微调(ModelScope) | transformers | [脚本](cookbook/server_mode/twinkle/modelscope) |
+| 服务端启动脚本 | transformers/megatron | [脚本](cookbook/server_mode/server) |
-Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Client等各场景下。其算法过程是外露的,非常便于修改和调试。完整的框架介绍请查看[快速开始](docs/source_zh/使用指引/快速开始.md)
+Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Client等各场景下。其算法过程是外露的,非常便于修改和调试。完整的框架介绍请查看[快速开始](https://modelscope.github.io/twinkle-web/zh/docs/usage-guide/quick-start/)
## 更新日志
- 🎉2026-05-20 支持DeepSeek-V4-Flash and DeepSeek-V4-Pro系列模型。
@@ -116,7 +116,7 @@ Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Cl
## ModelScope 的训练服务
-我们正在 ModelScope 上推出基于 Twinkle✨ 构建的训练服务。你可以通过 API 端点 `base_url=https://www.modelscope.cn/twinkle` 进行训练。更多详情请参阅我们的[文档](docs/source_zh/使用指引/训练服务.md)。
+我们正在 ModelScope 上推出基于 Twinkle✨ 构建的训练服务。你可以通过 API 端点 `base_url=https://www.modelscope.cn/twinkle` 进行训练。更多详情请参阅我们的[文档](https://modelscope.github.io/twinkle-web/zh/docs/usage-guide/train-as-a-service/)。
## 支持的硬件
@@ -166,7 +166,7 @@ Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Cl
## 示例代码
-下面列出了示例代码的一部分能力。完整的训练能力介绍请参考[快速开始](docs/source_zh/使用指引/快速开始.md)以及[cookbook](cookbook)。
+下面列出了示例代码的一部分能力。完整的训练能力介绍请参考[快速开始](https://modelscope.github.io/twinkle-web/zh/docs/usage-guide/quick-start/)以及[cookbook](cookbook)。
### 使用 Ray 训练
diff --git a/cookbook/exp/condenser/untested/eval_condensed_compressed.sh b/cookbook/exp/condenser/untested/eval_condensed_compressed.sh
index 5567a1a3b..833b446fa 100755
--- a/cookbook/exp/condenser/untested/eval_condensed_compressed.sh
+++ b/cookbook/exp/condenser/untested/eval_condensed_compressed.sh
@@ -3,14 +3,14 @@
# Identical --dataset / --limit / --model_id as eval_condensed_native.sh for an A/B comparison.
set -euo pipefail
-DATASET="${DATASET:-/mnt/data/yzhao/datasets/musique_ans_v1.0_dev.jsonl}"
-MODEL_ID="${MODEL_ID:-ms://Qwen/Qwen3.5-4B}"
-CONDENSER_LORA="${CONDENSER_LORA:-ms://twinkle-kit/Qwen3.5-4B-Condenser}"
-LIMIT="${LIMIT:-500}"
-NUM_GPUS="${NUM_GPUS:-4}"
-OUT_DIR="${OUT_DIR:-eval_out}"
+DATASET="/mnt/data/yzhao/datasets/musique_ans_v1.0_dev.jsonl"
+MODEL_ID="ms://Qwen/Qwen3.5-4B"
+CONDENSER_LORA="ms://twinkle-kit/Qwen3.5-4B-Condenser"
+LIMIT="500"
+NUM_GPUS="4"
+OUT_DIR="eval_out"
-CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3} \
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
python cookbook/exp/eval_condensed.py \
--mode condensed \
--dataset_format musique \
diff --git a/cookbook/exp/condenser/untested/eval_condensed_native.sh b/cookbook/exp/condenser/untested/eval_condensed_native.sh
index 0849e9378..176c767b6 100755
--- a/cookbook/exp/condenser/untested/eval_condensed_native.sh
+++ b/cookbook/exp/condenser/untested/eval_condensed_native.sh
@@ -3,13 +3,13 @@
# Compare against eval_condensed_compressed.sh on identical --dataset / --limit / --model_id.
set -euo pipefail
-DATASET="${DATASET:-/mnt/data/yzhao/datasets/musique_ans_v1.0_dev.jsonl}"
-MODEL_ID="${MODEL_ID:-ms://Qwen/Qwen3.5-4B}"
-LIMIT="${LIMIT:-500}"
-NUM_GPUS="${NUM_GPUS:-4}"
-OUT_DIR="${OUT_DIR:-eval_out}"
+DATASET="/mnt/data/yzhao/datasets/musique_ans_v1.0_dev.jsonl"
+MODEL_ID="ms://Qwen/Qwen3.5-4B"
+LIMIT="500"
+NUM_GPUS="4"
+OUT_DIR="eval_out"
-CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3} \
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
python cookbook/exp/eval_condensed.py \
--mode native \
--dataset_format musique \
diff --git a/cookbook/exp/embedding/build_thinking_rag_index.py b/cookbook/exp/embedding/build_thinking_rag_index.py
new file mode 100644
index 000000000..d228a597a
--- /dev/null
+++ b/cookbook/exp/embedding/build_thinking_rag_index.py
@@ -0,0 +1,935 @@
+"""Build a thinking-trace RAG index from condensed (query, cot) pairs.
+
+Pipeline (per row, batched):
+ 1. Load (user_query, reasoning_content) pairs via ``get_dataset`` (``dataset_index`` by default).
+ 2. Compress query with ``RAG_QUERY_HINT`` and cot with ``RAG_THINKING_HINT``
+ (a symmetric Problem/Skill/Knowledge schema defined in this file) using a
+ Twinkle ``vLLMSampler`` (TP=4 across GPUs 0-3). Reuses the system/user
+ wrappers from ``cookbook/exp/condenser/make_condenser_dataset.py``.
+ 3. On condenser truncation (``stop_reason='length'`` or skeleton-incomplete
+ output), fall back to an external OpenAI-compatible API.
+ 4. Encode the condensed pair via the trained embedding model — Twinkle
+ ``TransformersModel`` on the ``emb_model`` device group (DP=4 across GPUs
+ 4-7) using ``forward_only(task='embedding')``, the same code path as
+ training.
+ 5. Compute cosine similarity for each (query, thinking) pair, drop pairs with
+ ``sim < SIM_THRESHOLD``, and insert kept rows into LanceDB. The vector
+ column carries the **positive (compressed-skill)** embedding so a search
+ keyed by an anchor-encoded query retrieves the matching thinking trace.
+ 6. Each row stores the **raw thinking** alongside its embedding, so a hit
+ in the index can directly surface the original CoT.
+
+Eval mode (``--mode eval`` or ``--mode both``):
+ * Self-recall test — encode a sample of dataset queries (whose corresponding
+ rows are already in the index) as anchors and report recall@1/5/10 plus
+ a per-source breakdown.
+
+Architecture (8 GPUs):
+ * GPU 0-3: vLLM condenser (tensor-parallel, ``DeviceGroup name='sampler'``)
+ * GPU 4-7: TransformersModel embedding (data-parallel, ``DeviceGroup name='emb_model'``)
+ * Single ``twinkle.initialize(mode='ray', ...)`` call wires both groups.
+
+Launch examples:
+ python build_thinking_rag_index.py --mode build --total 500000
+ python build_thinking_rag_index.py --mode eval --eval-size 1000
+ python build_thinking_rag_index.py --mode both --total 200000 --eval-size 500
+"""
+import argparse
+import json
+import os
+import re
+import sys
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+
+# ---------------------------------------------------------------------------
+# Compress prompts — MUST match train_embedding_full_ddp.py exactly.
+# ---------------------------------------------------------------------------
+_HERE = Path(__file__).resolve().parent
+sys.path.insert(0, str(_HERE))
+
+COMPRESS_SYSTEM = """\
+You are a compression and summary assistant. For the (query, source) pair, emit a Markdown \
+answer with TWO sections, designed to pair with the `extract_compressed` tool: \
+the reader absorbs `## Summary` directly, then calls `extract_compressed` \
+on any topic-key listed under `## More` to recover its \
+fuller content.
+
+ `## Summary` \u2014 extreme-density text the reader reads directly.
+ `## More` \u2014 a topic index whose keys are valid arguments \
+to `extract_compressed` for recovering material not captured inline.
+
+Together the two sections must form a COMPLETE, NON-DISTORTING inventory of the \
+source for the query \u2014 nothing essential lost, nothing implied that the source \
+does not support. NO preamble, NO meta-commentary, NO code fences wrapping the \
+whole output.
+
+Output skeleton:
+
+## Summary
+Topic:
+
+
+## More
+- :
+- ...
+
+Format selection for the inline body (pick the MOST COMPACT form per query, mix \
+when helpful):
+- Interface / signature \u2192 code notation directly: `func(a:int)->str`
+- Factual / entity \u2192 telegraphic prose; drop function words; \":\" for \"is\", \",\" \
+for \"has\"
+- Skill / how-to / usage \u2192 lead with `Use when: `; numbered telegraphic \
+steps `1.do X 2.then Y`; close with `Output: ` when relevant
+- Procedural \u2192 numbered short steps
+- Analytical / design \u2192 hierarchical bullets with abbreviations
+
+`## Summary` rules:
+1. TOPIC LINE \u2014 line 1 is ALWAYS `Topic: `, even when the \
+query is narrow. Anchors both the reader and the tool.
+2. DENSITY \u2014 every token in the body carries query-relevant signal; cut filler.
+3. PRIMARY-COMPLETE \u2014 never silently drop a fact essential to answering the \
+query. Anything cut for length MUST appear as a key under \
+`## More`.
+4. NON-MISLEADING \u2014 phrasing must not let the reader infer anything the source \
+does not support; partial truths that mislead are worse than honest omissions \
+flagged in the index.
+5. SELF-CONTAINED \u2014 the reader can act on the answer without re-opening the source.
+6. FAITHFUL \u2014 only content the source supports; no fabrication, no extrapolation.
+7. LANGUAGE \u2014 match the source language.
+8. NO outer code fences around the whole answer; no meta-commentary.
+
+`## More` rules (MANDATORY \u2014 this section is never omitted):
+1. FORMAT \u2014 each bullet is `- : `:
+ \u2022 topic-key \u2014 short, unambiguous, grounded in source vocabulary so the \
+`extract_compressed` tool can locate the aspect (e.g. `decorators`, \
+`error handling`, `pitfalls`).
+ \u2022 hint \u2014 tells WHAT the reader gains by expanding (concrete numbers, code \
+listings, secondary cases, edge details, related context, \u2026); do NOT restate \
+the inline answer.
+2. CRITERION \u2014 each bullet names an aspect that EXISTS in the source but is \
+NOT fully captured inline. Material that genuinely fits inline without \
+distortion MUST NOT be duplicated here.
+3. FAITHFUL \u2014 hints must be grounded in the source; never speculate or invent.
+4. ORDER \u2014 by relevance to the query, then by importance.
+5. EMPTY CASE \u2014 if the source is so short / single-purpose that everything \
+fits inline, write a single line `- (none)`.
+
+Now begin.\
+"""
+
+COMPRESS_USER = (
+ 'Downstream model will read your compressed block to decide whether to '
+ 'expand it. Compress faithfully: preserve the passage topic + core facts. '
+ 'Do NOT invent facts. Do NOT drop major facts. Do NOT write meta-commentary '
+ 'about the Query (never write "Query info: absent", "no X mention", etc.); '
+ 'if the passage does not address the Query, still summarize the passage. '
+ 'CRITICAL LANGUAGE RULE: detect the dominant language of the Passage '
+ '(NOT the Query, NOT this instruction) and write the ENTIRE output in that '
+ 'same language; English passage \u2192 English output, Chinese passage \u2192 '
+ 'Chinese output, Japanese passage \u2192 Japanese output. NEVER translate, '
+ 'NEVER mix languages, NEVER copy these instructions into the output.\n\n'
+ '## Query (ordering hint only \u2014 still summarize the whole passage)\n{query}\n\n'
+ '## Passage\n{text}')
+
+# Default dataset loader is the index-time corpus (broader retrieval profile);
+# pass --dataset-module dataset_think to fall back to the training mix.
+from dataset_index import get_dataset as _default_get_dataset # noqa: E402
+
+_GET_DATASET = _default_get_dataset
+
+import twinkle # noqa: E402
+from twinkle import DeviceGroup, DeviceMesh, get_logger # noqa: E402
+from twinkle.data_format import SamplingParams as TwinkleSamplingParams # noqa: E402
+from twinkle.loss import InfonceLoss # noqa: E402
+from twinkle.model import TransformersModel # noqa: E402
+from twinkle.processor import InputProcessor # noqa: E402
+from twinkle.sampler import vLLMSampler # noqa: E402
+from twinkle.template import Qwen3_5Template # noqa: E402
+from twinkle.utils.parallel import PosixFileLock # noqa: E402
+from twinkle_agentic.protocol.openai import OpenAI as OpenAIClient # noqa: E402
+
+logger = get_logger()
+
+
+# ===========================================================================
+# Config (most fields overridable via CLI / env)
+# ===========================================================================
+
+EMBED_MODEL_ID = os.environ.get(
+ 'EMBED_MODEL_ID',
+ 'output/embedding_lora_transformers/step_4000',
+)
+CONDENSE_MODEL_ID = os.environ.get('CONDENSE_MODEL_ID', 'ms://twinkle-kit/Qwen3.5-4B-CM-v2')
+
+# Twinkle device topology: TP=4 sampler on 0-3, DP=4 embedding on 4-7.
+SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
+EMB_GPUS = int(os.environ.get('EMB_GPUS', 4))
+NUM_GPUS = SAMPLER_GPUS + EMB_GPUS
+
+# vLLM engine sizing.
+CONDENSE_GPU_MEM = float(os.environ.get('CONDENSE_GPU_MEM', 0.85))
+CONDENSE_MAX_MODEL_LEN = int(os.environ.get('CONDENSE_MAX_MODEL_LEN', 32768))
+CONDENSE_MAX_TOKENS = int(os.environ.get('CONDENSE_MAX_TOKENS', 8192))
+COMPRESS_TEMPERATURE = float(os.environ.get('COMPRESS_TEMPERATURE', 0.2))
+COMPRESS_TOP_P = float(os.environ.get('COMPRESS_TOP_P', 0.5))
+
+# Embedding sizing.
+EMBED_MAX_LENGTH = int(os.environ.get('EMBED_MAX_LENGTH', 8192))
+
+SIM_THRESHOLD = float(os.environ.get('SIM_THRESHOLD', 0.65))
+MIN_TEXT_CHARS = int(os.environ.get('MIN_TEXT_CHARS', 256))
+
+# Hard-templated hints: the condenser SFT prior maps `Skill` to the legacy
+# `Use when: / numbered steps / Output:` skeleton on long inputs; embedding the
+# exact 4-line body template + explicit negative constraints is the only way to
+# override it deterministically across query and cot sides.
+RAG_QUERY_HINT = (
+ 'Summarize this query for retrieval. '
+ 'The body of ## Summary MUST follow this EXACT 4-line template — '
+ 'do NOT emit "Use when:", numbered procedure steps, or "Output:":\n'
+ 'Topic: \n'
+ 'Problem: \n'
+ 'Skill: \n'
+ 'Knowledge: \n'
+ 'Then emit the mandatory ## More section as usual. '
+ 'Topic must name the specific pattern, never generic labels.')
+RAG_THINKING_HINT = (
+ 'Summarize this reasoning trace for retrieval. '
+ 'The body of ## Summary MUST follow this EXACT 4-line template — '
+ 'do NOT emit "Use when:", numbered procedure steps, or "Output:":\n'
+ 'Topic: \n'
+ 'Problem: \n'
+ 'Skill: \n'
+ 'Knowledge: \n'
+ 'Then emit the mandatory ## More section as usual. '
+ 'Topic must name the specific pattern, never generic labels.')
+
+# OpenAI API fallback (used when vLLM truncates).
+COMPRESS_API_KEY = os.environ.get('COMPRESS_API_KEY', '')
+COMPRESS_BASE_URL = os.environ.get(
+ 'COMPRESS_BASE_URL', 'https://dashscope.aliyuncs.com/compatible-mode/v1')
+COMPRESS_API_MODEL = os.environ.get('COMPRESS_API_MODEL', 'qwen3.7-max')
+
+# Source → coarse domain (for filtered eval).
+DOMAIN_MAP = {
+ 'CodeX-2M-Thinking': 'code',
+ 'OpenThoughts3-1.2M': 'reasoning',
+ 'LIMO-v2': 'math',
+ 'Chinese-DeepSeek-R1-Distill-data-110k': 'reasoning_zh',
+ 'Opus-4.6-Reasoning-3000x-filtered': 'reasoning',
+ 'claude-opus-4.6-10000x': 'mixed',
+ 'angrygiraffe-claude-opus-4.6-4.7-reasoning-8.7k': 'mixed',
+}
+
+
+# ===========================================================================
+# Small helpers
+# ===========================================================================
+
+_LEGACY_USE_WHEN_RE = re.compile(r'(?im)^\s*Use when\s*:')
+_SCHEMA_MARKERS = ('Problem:', 'Skill:', 'Knowledge:')
+
+
+def _is_truncated_compression(text: str) -> bool:
+ """Reject structurally incomplete OR schema-regressed condenser output.
+
+ Triggers API fallback when the vLLM output:
+ * lacks ``## Summary`` / ``## More``,
+ * has an empty or unterminated ``## More`` bullet list, or
+ * regresses to the legacy ``Use when: / numbered-steps / Output:`` skeleton
+ instead of the mandated Problem/Skill/Knowledge 4-line body — the
+ dominant cot-side failure mode that drives sim < 0.45 drops.
+ """
+ if not text or not text.strip():
+ return True
+ if '## More' not in text or '## Summary' not in text:
+ return True
+ after_more = text.split('## More', 1)[1].strip()
+ if not after_more:
+ return True
+ last_line = after_more.splitlines()[-1].strip()
+ if not (last_line.startswith('-') or last_line.endswith(')')):
+ return True
+ summary_body = text.split('## Summary', 1)[1].split('## More', 1)[0]
+ if _LEGACY_USE_WHEN_RE.search(summary_body):
+ return True
+ if not all(marker in summary_body for marker in _SCHEMA_MARKERS):
+ return True
+ return False
+
+
+def _strip_outer_codefence(text: str) -> str:
+ m = re.match(r'^```[a-zA-Z]*\n(.*?)\n```\s*$', text, re.DOTALL)
+ if m:
+ return m.group(1).strip()
+ return text.strip()
+
+
+def _wrap_anchor(text: str) -> List[Dict[str, str]]:
+ """Anchor-side message wrapping (must match training)."""
+ return [
+ {'role': 'user', 'content': text},
+ {'role': 'assistant', 'content': 'Match the correct response here.'},
+ ]
+
+
+def _wrap_positive(text: str) -> List[Dict[str, str]]:
+ """Positive-side message wrapping (must match training)."""
+ return [
+ {'role': 'user', 'content': 'Match the correct query here.'},
+ {'role': 'assistant', 'content': text},
+ ]
+
+
+def _short(text: str, n: int = 96) -> str:
+ text = (text or '').replace('\n', ' ').strip()
+ return text[:n] + ('…' if len(text) > n else '')
+
+
+def _detect_lang(text: str) -> str:
+ if not text:
+ return 'unknown'
+ cjk = sum(1 for ch in text[:512] if '\u4e00' <= ch <= '\u9fff')
+ return 'zh' if cjk >= 8 else 'en'
+
+
+def _build_compress_messages(text: str, query: str) -> List[Dict[str, str]]:
+ return [
+ {'role': 'system', 'content': COMPRESS_SYSTEM},
+ {'role': 'user', 'content': COMPRESS_USER.format(query=query, text=text)},
+ ]
+
+
+# ===========================================================================
+# Twinkle component wrappers
+# ===========================================================================
+
+def initialize_twinkle() -> Tuple[DeviceMesh, DeviceMesh]:
+ """Wire two device groups (sampler / emb_model) and return their meshes."""
+ device_groups = [
+ DeviceGroup(
+ name='sampler',
+ ranks=list(range(SAMPLER_GPUS)),
+ device_type='GPU',
+ gpus_per_worker=SAMPLER_GPUS, # TP=4 → one worker spans all 4 GPUs
+ ),
+ DeviceGroup(
+ name='emb_model',
+ ranks=list(range(SAMPLER_GPUS, NUM_GPUS)),
+ device_type='GPU',
+ ),
+ ]
+ sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, tp_size=SAMPLER_GPUS)
+ emb_mesh = DeviceMesh.from_sizes(world_size=EMB_GPUS, dp_size=EMB_GPUS)
+ twinkle.initialize(
+ mode='ray',
+ nproc_per_node=NUM_GPUS,
+ groups=device_groups,
+ lazy_collect=False,
+ )
+ return sampler_mesh, emb_mesh
+
+
+def build_sampler(sampler_mesh: DeviceMesh) -> vLLMSampler:
+ sampler = vLLMSampler(
+ model_id=CONDENSE_MODEL_ID,
+ engine_args={
+ 'gpu_memory_utilization': CONDENSE_GPU_MEM,
+ 'max_model_len': CONDENSE_MAX_MODEL_LEN,
+ },
+ device_mesh=sampler_mesh,
+ remote_group='sampler',
+ )
+ sampler.set_template(
+ 'Qwen3_5Template',
+ model_id=CONDENSE_MODEL_ID,
+ enable_thinking=False,
+ max_length=CONDENSE_MAX_MODEL_LEN,
+ )
+ return sampler
+
+
+def build_emb_model(emb_mesh: DeviceMesh) -> Tuple[TransformersModel, Qwen3_5Template]:
+ model = TransformersModel(
+ model_id=EMBED_MODEL_ID,
+ device_mesh=emb_mesh,
+ remote_group='emb_model',
+ )
+ model.set_processor(InputProcessor)
+ # InfonceLoss is required by the framework even though forward_only does
+ # not actually invoke it; matches the training-time configuration.
+ model.set_loss(InfonceLoss, temperature=0.03, use_batch=True)
+ # Qwen3.5-specific subclass applies orphan- chat-template patches.
+ template = Qwen3_5Template(
+ model_id=EMBED_MODEL_ID,
+ max_length=EMBED_MAX_LENGTH,
+ truncation_strategy='delete',
+ enable_thinking=False,
+ )
+ return model, template
+
+
+# ===========================================================================
+# Compression helpers (vLLMSampler) + API fallback
+# ===========================================================================
+
+def _vllm_compress(sampler: vLLMSampler, texts: List[str], query_hint: str
+ ) -> List[Tuple[str, str]]:
+ """Compress ``texts`` via the sampler; return ``(decoded, stop_reason)``."""
+ if not texts:
+ return []
+ prompts = [{'messages': _build_compress_messages(t, query_hint)} for t in texts]
+ params = TwinkleSamplingParams(
+ max_tokens=CONDENSE_MAX_TOKENS,
+ temperature=COMPRESS_TEMPERATURE,
+ top_p=COMPRESS_TOP_P,
+ num_samples=1,
+ )
+ responses = sampler.sample(prompts, params)
+ results: List[Tuple[str, str]] = []
+ for resp in responses:
+ seq = resp.sequences[0] if resp and resp.sequences else None
+ if seq is None:
+ results.append(('', 'error'))
+ continue
+ text = seq.decoded or ''
+ # Strip any leaked chat-template special tokens like ``<|im_end|>``.
+ text = re.sub(r'<\|[^|]+\|>', '', text).rstrip()
+ text = _strip_outer_codefence(text)
+ results.append((text, seq.stop_reason or 'stop'))
+ return results
+
+
+def _api_compress(api: OpenAIClient, messages: List[Dict[str, str]]) -> Optional[str]:
+ sp = TwinkleSamplingParams(temperature=COMPRESS_TEMPERATURE, max_tokens=CONDENSE_MAX_TOKENS)
+ try:
+ reply = api({'messages': messages}, sp, extra_body={'enable_thinking': False})
+ except Exception as exc: # noqa: BLE001 — broad catch is intentional
+ sys.stderr.write(f'[api_fallback] error: {exc}\n')
+ return None
+ content = (reply.get('content') or '').strip()
+ if not content:
+ return None
+ return _strip_outer_codefence(content)
+
+
+def _resolve_compressed(sampler: vLLMSampler, api: Optional[OpenAIClient],
+ texts: List[str], query_hint: str) -> List[Optional[str]]:
+ """Run vLLM batch; replace truncations / skeleton-incomplete with API output."""
+ pairs = _vllm_compress(sampler, texts, query_hint)
+ results: List[Optional[str]] = []
+ for (text, stop), src_text in zip(pairs, texts):
+ if stop != 'length' and not _is_truncated_compression(text):
+ results.append(text)
+ continue
+ if api is None:
+ results.append(None)
+ continue
+ api_text = _api_compress(api, _build_compress_messages(src_text, query_hint))
+ if api_text is None or _is_truncated_compression(api_text):
+ results.append(None)
+ else:
+ results.append(api_text)
+ return results
+
+
+# ===========================================================================
+# Embedding helpers (TransformersModel.forward_only(task='embedding'))
+# ===========================================================================
+
+def _build_features(template: Qwen3_5Template, texts: List[str], role: str
+ ) -> List[Dict[str, Any]]:
+ """Wrap each text into the role-specific anchor / positive feature dict."""
+ features: List[Dict[str, Any]] = []
+ for text in texts:
+ if not text or not text.strip():
+ # Pad with a single space so positional alignment holds against
+ # the input list — the caller filters out empty-text rows upstream.
+ text = ' '
+ if role == 'anchor':
+ feat = template.encode({'messages': _wrap_anchor(text)})
+ feat['labels'] = [1]
+ else:
+ feat = template.encode({'messages': _wrap_positive(text)})
+ feat['labels'] = [0]
+ features.append(feat)
+ return features
+
+
+def get_embeddings(model: TransformersModel, template: Qwen3_5Template,
+                   texts: List[str], role: str) -> np.ndarray:
+    """Return ``[N, H]`` float32 L2-normalised embeddings for ``texts``.
+
+    Inputs are padded up to a multiple of ``EMB_GPUS`` and sliced back to the
+    original ``N``: the dispatch layer (``_dispatch_args``) starves any rank
+    whose chunk lands beyond ``len(texts)``, so a single forward of fewer than
+    ``EMB_GPUS`` items (e.g. the probe) would otherwise raise
+    ``Batch too small for {EMB_GPUS} workers``.
+    """
+    if not texts:
+        return np.zeros((0, 0), dtype=np.float32)  # empty input: stay 2-D (H unknown here)
+    n = len(texts)
+    pad_n = (-n) % EMB_GPUS
+    padded = list(texts) + [' '] * pad_n if pad_n else list(texts)
+    features = _build_features(template, padded, role)
+    out = model.forward_only(inputs=features, task='embedding', return_logits=True)
+    emb = out['embeddings']
+    if isinstance(emb, torch.Tensor):
+        emb = emb.detach().to(torch.float32).cpu().numpy()
+    emb = np.asarray(emb, dtype=np.float32)
+    return emb[:n] if pad_n else emb
+
+
+def _probe_hidden_size(model: TransformersModel, template: Qwen3_5Template) -> int:
+ """One-shot warmup forward to read out the embedding dimension."""
+ emb = get_embeddings(model, template, ['probe'], role='anchor')
+ if emb.ndim != 2 or emb.shape[0] == 0:
+ raise RuntimeError(f'unexpected embedding shape from probe: {emb.shape}')
+ return int(emb.shape[1])
+
+
+# ===========================================================================
+# LanceDB I/O
+# ===========================================================================
+
+def _make_arrow_schema(hidden_size: int):
+ import pyarrow as pa
+ return pa.schema([
+ pa.field('id', pa.string()),
+ pa.field('vector', pa.list_(pa.float32(), hidden_size)),
+ pa.field('thinking_raw', pa.string()),
+ pa.field('query_raw', pa.string()),
+ pa.field('cot_compressed', pa.string()),
+ pa.field('query_compressed', pa.string()),
+ pa.field('source', pa.string()),
+ pa.field('domain', pa.string()),
+ pa.field('language', pa.string()),
+ pa.field('sim', pa.float32()),
+ ])
+
+
+def _open_or_create_table(db_path: str, table_name: str, hidden_size: int,
+ mode: str):
+ """Open an existing table for append/eval, or create a fresh one."""
+ import lancedb
+ db = lancedb.connect(db_path)
+ schema = _make_arrow_schema(hidden_size)
+ if table_name in db.table_names():
+ if mode == 'overwrite':
+ db.drop_table(table_name)
+ tbl = db.create_table(table_name, schema=schema, mode='overwrite')
+ else:
+ tbl = db.open_table(table_name)
+ else:
+ tbl = db.create_table(table_name, schema=schema, mode='create')
+ return db, tbl
+
+
+def _existing_ids(table) -> set:
+ try:
+ col = table.to_pandas(columns=['id'])
+ return set(col['id'].astype(str).tolist())
+ except Exception: # noqa: BLE001
+ return set()
+
+
+# ===========================================================================
+# Build pipeline
+# ===========================================================================
+
+def _stream_corpus(total: Optional[int], load_from_cache_file: bool,
+                   max_rows: int = 0) -> Iterator[Dict[str, Any]]:
+    """Yield rows from the configured dataset loader, optionally truncated.
+
+    ``max_rows`` caps the number of yielded rows only when it is non-zero and
+    smaller than the dataset length; the row counts are reported on stderr.
+    """
+    dataset = _GET_DATASET(total=total, load_from_cache_file=load_from_cache_file)
+    available = len(dataset)
+    if max_rows and max_rows < available:
+        cap = max_rows
+    else:
+        cap = available
+    suffix = f' → yielding first {cap}\n' if cap < available else '\n'
+    sys.stderr.write(f'[corpus] get_dataset: {available} rows' + suffix)
+    for position, row in enumerate(dataset):
+        if position >= cap:
+            return
+        yield row
+
+
+def _extract_query_cot(row: Dict[str, Any]) -> Tuple[str, str]:
+    """Return ``(first user content, reasoning_content of the first assistant turn)``.
+
+    Scanning stops at the first assistant message; non-dict entries are
+    skipped and missing fields come back as empty strings.
+    """
+    query = ''
+    for message in row.get('messages') or []:
+        if not isinstance(message, dict):
+            continue
+        role = message.get('role') or ''
+        if role == 'assistant':
+            return query, (message.get('reasoning_content') or '').strip()
+        if role == 'user' and not query:
+            query = (message.get('content') or '').strip()
+    return query, ''
+
+
+def _log_miss(misses_path: str, lock: PosixFileLock, record: Dict[str, Any]) -> None:
+    """Append ``record`` as one JSON line to ``misses_path`` while holding ``lock``.
+
+    Values that are not JSON-serializable are stringified (``default=str``);
+    non-ASCII text is written unescaped.
+    """
+    payload = json.dumps(record, ensure_ascii=False, default=str)
+    with lock, open(misses_path, 'a', encoding='utf-8') as sink:
+        sink.write(payload + '\n')
+
+
+def build_index(args: argparse.Namespace,
+                sampler: vLLMSampler,
+                emb_model: TransformersModel,
+                emb_template: Qwen3_5Template,
+                api: Optional[OpenAIClient]) -> None:
+    """Stream the corpus, condense + embed each row, and store survivors in LanceDB.
+
+    For every batch of ``args.batch_size`` candidate rows:
+
+    1. compress the query (``RAG_QUERY_HINT``) and the CoT (``RAG_THINKING_HINT``);
+       queries shorter than ``MIN_TEXT_CHARS`` are passed through unchanged;
+    2. embed the compressed query (anchor) and compressed CoT (positive);
+    3. keep rows whose anchor/positive dot product is >= ``SIM_THRESHOLD`` and
+       insert them with the positive embedding as the stored vector.
+
+    Rows rejected in step 1 or 3 are appended to the misses JSONL log. Rows
+    whose id is already in the table, or already queued in the current batch,
+    are counted as duplicates and skipped. After streaming, an IVF_PQ index is
+    (re)built unless ``args.skip_index`` is set or fewer than 64 rows were kept.
+    """
+    # ---- Probe embedding dimension -----------------------------------------
+    sys.stderr.write('[build] probing embedding hidden size...\n')
+    hidden_size = _probe_hidden_size(emb_model, emb_template)
+    sys.stderr.write(f'[build] hidden_size={hidden_size}\n')
+
+    # ---- LanceDB ------------------------------------------------------------
+    _db, tbl = _open_or_create_table(
+        args.db_path, args.table, hidden_size,
+        mode='overwrite' if args.overwrite else 'append',
+    )
+    indexed = _existing_ids(tbl) if not args.overwrite else set()
+    sys.stderr.write(f'[build] table "{args.table}" — {len(indexed)} existing rows.\n')
+
+    misses_path = args.misses_log or (str(Path(args.db_path) / f'{args.table}.misses.jsonl'))
+    Path(misses_path).parent.mkdir(parents=True, exist_ok=True)
+    misses_lock = PosixFileLock(misses_path + '.lock')
+
+    # ---- Streaming loop -----------------------------------------------------
+    n_seen = n_kept = n_dropped_short = n_dropped_compress = n_dropped_sim = 0
+    n_dropped_dup = 0
+    pbar = tqdm(desc='index', unit='row', dynamic_ncols=True)
+
+    batch: List[Dict[str, Any]] = []
+    # Ids queued in the current, not-yet-flushed batch. ``indexed`` only grows
+    # after a flush, so without this set a repeated id inside one batch would
+    # be inserted twice.
+    batch_ids: set = set()
+
+    def _flush(rows: List[Dict[str, Any]]) -> None:
+        nonlocal n_kept, n_dropped_compress, n_dropped_sim
+        if not rows:
+            return
+        # Phase 1 — compress query (RAG_QUERY_HINT) and cot (RAG_THINKING_HINT).
+        # Short queries bypass condenser (passthrough) — matches training behaviour.
+        long_q_indices = [i for i, r in enumerate(rows) if len(r['query_raw']) >= MIN_TEXT_CHARS]
+        q_compressed: List[Optional[str]] = [None] * len(rows)
+        for i, r in enumerate(rows):
+            if len(r['query_raw']) < MIN_TEXT_CHARS:
+                q_compressed[i] = r['query_raw']
+        if long_q_indices:
+            long_results = _resolve_compressed(
+                sampler, api, [rows[i]['query_raw'] for i in long_q_indices], RAG_QUERY_HINT)
+            for idx, res in zip(long_q_indices, long_results):
+                q_compressed[idx] = res
+        c_compressed = _resolve_compressed(
+            sampler, api, [r['cot_raw'] for r in rows], RAG_THINKING_HINT)
+        kept_rows: List[Dict[str, Any]] = []
+        for r, q_cmp, c_cmp in zip(rows, q_compressed, c_compressed):
+            if not q_cmp or not c_cmp:
+                n_dropped_compress += 1
+                _log_miss(misses_path, misses_lock, {
+                    'id': r['id'], 'source': r['source'], 'reason': 'compress_fail',
+                    'query_raw_head': _short(r['query_raw'], 200),
+                    'cot_raw_head': _short(r['cot_raw'], 200),
+                })
+                continue
+            r['query_compressed'] = q_cmp
+            r['cot_compressed'] = c_cmp
+            kept_rows.append(r)
+        if not kept_rows:
+            return
+        # Phase 2 — encode anchor (compressed query) + positive (compressed cot).
+        anchor_emb = get_embeddings(
+            emb_model, emb_template, [r['query_compressed'] for r in kept_rows], role='anchor')
+        positive_emb = get_embeddings(
+            emb_model, emb_template, [r['cot_compressed'] for r in kept_rows], role='positive')
+        # Row-wise dot product of anchor and positive embeddings.
+        sims = (anchor_emb * positive_emb).sum(axis=1).astype(np.float32)
+        # Phase 3 — sim filter + LanceDB insert.
+        to_insert: List[Dict[str, Any]] = []
+        for idx, (r, sim_val) in enumerate(zip(kept_rows, sims)):
+            tag = 'KEEP' if sim_val >= SIM_THRESHOLD else 'DROP'
+            print(f'[{tag} sim={sim_val:.4f}] {r["source"][:24]} '
+                  f'q={_short(r["query_raw"], 60)!r} '
+                  f'cot={_short(r["cot_raw"], 60)!r}', flush=True)
+            if sim_val < SIM_THRESHOLD:
+                n_dropped_sim += 1
+                _log_miss(misses_path, misses_lock, {
+                    'id': r['id'], 'source': r['source'], 'reason': 'sim_low',
+                    'sim': float(sim_val),
+                    'query_raw': r['query_raw'],
+                    'cot_raw': r['cot_raw'],
+                    'query_compressed': r['query_compressed'],
+                    'cot_compressed': r['cot_compressed'],
+                })
+                continue
+            to_insert.append({
+                'id': r['id'],
+                'vector': positive_emb[idx].tolist(),
+                'thinking_raw': r['cot_raw'],
+                'query_raw': r['query_raw'],
+                'cot_compressed': r['cot_compressed'],
+                'query_compressed': r['query_compressed'],
+                'source': r['source'],
+                'domain': DOMAIN_MAP.get(r['source'], 'mixed'),
+                'language': _detect_lang(r['cot_raw']),
+                'sim': float(sim_val),
+            })
+        if to_insert:
+            tbl.add(to_insert)
+            n_kept += len(to_insert)
+            indexed.update(r['id'] for r in to_insert)
+
+    try:
+        for row in _stream_corpus(total=args.total, load_from_cache_file=not args.no_cache,
+                                  max_rows=args.max_rows):
+            # Check the cap before counting so ``seen`` excludes the row that
+            # merely triggered the stop.
+            if args.limit and n_kept >= args.limit:
+                break
+            n_seen += 1
+            rid = row.get('id') or ''
+            if not rid:
+                continue
+            if rid in indexed or rid in batch_ids:
+                n_dropped_dup += 1
+                continue
+            user_query, cot = _extract_query_cot(row)
+            if not user_query or len(cot) < MIN_TEXT_CHARS:
+                n_dropped_short += 1
+                continue
+            batch.append({
+                'id': rid,
+                'source': row.get('source') or 'unknown',
+                'query_raw': user_query,
+                'cot_raw': cot,
+            })
+            batch_ids.add(rid)
+            if len(batch) >= args.batch_size:
+                _flush(batch)
+                batch.clear()
+                batch_ids.clear()
+                pbar.set_postfix(kept=n_kept, sim_drop=n_dropped_sim,
+                                 cmp_drop=n_dropped_compress, refresh=False)
+            pbar.update(1)
+        if batch:
+            _flush(batch)
+            batch.clear()
+            batch_ids.clear()
+    finally:
+        pbar.close()
+
+    sys.stderr.write(
+        f'[build] seen={n_seen} kept={n_kept} sim_drop={n_dropped_sim} '
+        f'cmp_drop={n_dropped_compress} short_drop={n_dropped_short} '
+        f'dup_skip={n_dropped_dup}\n')
+
+    # ---- Build vector index for fast retrieval ------------------------------
+    total_rows = tbl.count_rows()
+    if n_kept >= 64 and not args.skip_index:
+        sys.stderr.write('[build] creating IVF_PQ index (metric=dot)...\n')
+        # ``replace=True`` rebuilds the index over the whole table, so size the
+        # partitions from the table's row count rather than this run's inserts
+        # (the two differ on append runs).
+        n_partitions = max(8, min(256, total_rows // 1000 + 1))
+        try:
+            tbl.create_index(
+                metric='dot',
+                vector_column_name='vector',
+                num_partitions=n_partitions,
+                num_sub_vectors=16,
+                index_type='IVF_PQ',
+                replace=True,
+            )
+        except Exception as exc:  # noqa: BLE001
+            sys.stderr.write(f'[build] index build failed: {exc} '
+                             '(table is still queryable via brute-force scan)\n')
+    sys.stderr.write(f'[build] done. table rows={total_rows}\n')
+
+
+# ===========================================================================
+# Eval pipeline (self-recall on indexed rows)
+# ===========================================================================
+
+def eval_recall(args: argparse.Namespace,
+                sampler: vLLMSampler,
+                emb_model: TransformersModel,
+                emb_template: Qwen3_5Template,
+                api: Optional[OpenAIClient]) -> None:
+    """Measure self-recall@k of the index against the corpus it was built from.
+
+    Only corpus rows whose ``id`` is already stored in the table and whose
+    user query has at least ``MIN_TEXT_CHARS`` characters are probed, up to
+    ``args.eval_size`` probes. Each query is condensed with ``RAG_QUERY_HINT``,
+    embedded on the anchor side, and searched with the dot metric; a probe
+    counts as a hit at ``k`` when the row with the same ``id`` appears among
+    the first ``k`` results. Overall recall for k in {1, 5, 10, ``args.top_k``}
+    and per-source recall@10 are printed to stdout.
+    """
+    import lancedb
+
+    connection = lancedb.connect(args.db_path)
+    if args.table not in connection.table_names():
+        raise SystemExit(f'[eval] table "{args.table}" does not exist in {args.db_path}')
+    tbl = connection.open_table(args.table)
+    indexed_ids = _existing_ids(tbl)
+    sys.stderr.write(f'[eval] table rows={tbl.count_rows()} indexed_ids={len(indexed_ids)}\n')
+    if not indexed_ids:
+        sys.stderr.write('[eval] empty index — nothing to evaluate.\n')
+        return
+
+    ks = sorted({1, 5, 10, args.top_k})
+    depth = max(ks)
+    hits = dict.fromkeys(ks, 0)
+    per_source_hits: Dict[str, Dict[int, int]] = {}
+    per_source_total: Dict[str, int] = {}
+    probed = 0
+
+    pbar = tqdm(desc='eval', unit='probe', dynamic_ncols=True)
+    pending: List[Dict[str, Any]] = []
+
+    def _flush(rows: List[Dict[str, Any]]) -> None:
+        nonlocal probed
+        if not rows:
+            return
+        condensed = _resolve_compressed(
+            sampler, api, [r['query_raw'] for r in rows], RAG_QUERY_HINT)
+        # Probes whose query could not be condensed are left out entirely.
+        usable = [(r, c) for r, c in zip(rows, condensed) if c]
+        if not usable:
+            return
+        anchors = get_embeddings(
+            emb_model, emb_template, [c for _, c in usable], role='anchor')
+        for (probe, _), vector in zip(usable, anchors):
+            search = tbl.search(vector.astype(np.float32).tolist())
+            results = search.metric('dot').limit(depth).select(['id', 'source']).to_list()
+            retrieved = [item['id'] for item in results]
+            rank = retrieved.index(probe['id']) if probe['id'] in retrieved else -1
+            source = probe['source']
+            source_hits = per_source_hits.setdefault(source, {kk: 0 for kk in ks})
+            per_source_total[source] = per_source_total.get(source, 0) + 1
+            for k in ks:
+                if 0 <= rank < k:
+                    hits[k] += 1
+                    source_hits[k] += 1
+            probed += 1
+        pbar.update(len(usable))
+
+    try:
+        corpus = _stream_corpus(total=args.total, load_from_cache_file=not args.no_cache,
+                                max_rows=args.max_rows)
+        for row in corpus:
+            if probed + len(pending) >= args.eval_size:
+                break
+            rid = row.get('id') or ''
+            if not rid or rid not in indexed_ids:
+                continue
+            user_query, _ = _extract_query_cot(row)
+            if not user_query or len(user_query) < MIN_TEXT_CHARS:
+                continue
+            pending.append({
+                'id': rid,
+                'source': row.get('source') or 'unknown',
+                'query_raw': user_query,
+            })
+            if len(pending) >= args.batch_size:
+                _flush(pending)
+                pending.clear()
+        if pending:
+            _flush(pending)
+    finally:
+        pbar.close()
+
+    if probed == 0:
+        sys.stderr.write(
+            '[eval] no probed rows — index empty, queries too short, or '
+            'corpus exhausted before eval-size?\n')
+        return
+
+    print('\n=== Recall @ k (self-recall, gold present in index) ===')
+    print(f'probed = {probed}')
+    for k in ks:
+        print(f'  recall@{k:<3} = {hits[k]/probed:.4f}  ({hits[k]}/{probed})')
+
+    print('\n=== Per-source recall@10 ===')
+    for src in sorted(per_source_total):
+        tot = per_source_total[src]
+        h10 = per_source_hits.get(src, {}).get(10, 0)
+        print(f'  {src:<48s} {h10/tot:.4f}  ({h10}/{tot})')
+
+
+# ===========================================================================
+# CLI
+# ===========================================================================
+
+def parse_args() -> argparse.Namespace:
+    """Parse the command line (``sys.argv``) for the build/eval entry point.
+
+    Returns:
+        Namespace with ``mode``, ``db_path``, ``table``, ``total``,
+        ``dataset_module``, ``limit``, ``max_rows``, ``batch_size``,
+        ``no_cache``, ``overwrite``, ``skip_index``, ``misses_log``,
+        ``eval_size`` and ``top_k``.
+    """
+    p = argparse.ArgumentParser(description=__doc__,
+                                formatter_class=argparse.RawDescriptionHelpFormatter)
+    p.add_argument('--mode', choices=['build', 'eval', 'both'], default='build')
+    p.add_argument('--db-path', default='./output/thinking_rag/lance.db',
+                   help='LanceDB on-disk directory (persisted across runs).')
+    p.add_argument('--table', default='thinking_traces',
+                   help='LanceDB table name within --db-path.')
+    p.add_argument('--total', type=int, default=0,
+                   help='Total dataset rows to scale corpus to (0 = base sizes from the loader module).')
+    p.add_argument('--dataset-module', default='dataset_index',
+                   choices=['dataset_index', 'dataset_think'],
+                   help='Which loader to use: dataset_index (RAG profile) or '
+                        'dataset_think (training mix).')
+    p.add_argument('--limit', type=int, default=0,
+                   help='Stop building once this many rows are kept (0 = no cap).')
+    p.add_argument('--max-rows', type=int, default=0,
+                   help='Truncate corpus to this many rows AFTER get_dataset (0 = no cap). '
+                        'Use this instead of --total to avoid invalidating the dataset cache.')
+    p.add_argument('--batch-size', type=int, default=64,
+                   help='Rows per condense+encode batch.')
+    # The loader is selectable via --dataset-module, so do not name one module here.
+    p.add_argument('--no-cache', action='store_true',
+                   help="Disable load_from_cache_file in the selected loader's get_dataset.")
+    p.add_argument('--overwrite', action='store_true',
+                   help='Drop the table before build and start fresh.')
+    p.add_argument('--skip-index', action='store_true',
+                   help='Skip IVF_PQ index build at the end (debug).')
+    # Default mirrors build_index: Path(db_path) / f'{table}.misses.jsonl'.
+    p.add_argument('--misses-log', default='',
+                   help='Path for filtered-row JSONL log '
+                        '(defaults to <db-path>/<table>.misses.jsonl).')
+
+    # eval-only
+    p.add_argument('--eval-size', type=int, default=500,
+                   help='Number of probes for self-recall evaluation.')
+    # eval_recall reports k in {1, 5, 10, top_k}, so 10 is reported even when top_k < 10.
+    p.add_argument('--top-k', type=int, default=10,
+                   help='Additional k to report. k = 1, 5 and 10 are always reported.')
+
+    return p.parse_args()
+
+
+def main() -> None:
+    """CLI entry point: set up the model stack once, then run build and/or eval.
+
+    Creates ``--db-path``, optionally swaps the dataset loader to
+    ``dataset_think.get_dataset``, initializes Twinkle, the vLLM sampler and
+    the embedding model, and builds the OpenAI client only when
+    ``COMPRESS_API_KEY`` is set.
+    """
+    global _GET_DATASET
+
+    args = parse_args()
+    Path(args.db_path).mkdir(parents=True, exist_ok=True)
+
+    if args.dataset_module == 'dataset_think':
+        from dataset_think import get_dataset as _swap
+        _GET_DATASET = _swap
+    sys.stderr.write(f'[main] dataset loader: {args.dataset_module}\n')
+
+    # Build/eval both depend on the same Twinkle stack — initialize once.
+    sampler_mesh, emb_mesh = initialize_twinkle()
+    sys.stderr.write(f'[main] twinkle initialized: '
+                     f'sampler ranks 0-{SAMPLER_GPUS - 1} (TP={SAMPLER_GPUS}), '
+                     f'emb_model ranks {SAMPLER_GPUS}-{NUM_GPUS - 1} (DP={EMB_GPUS}).\n')
+
+    sys.stderr.write('[main] starting vLLM condenser sampler...\n')
+    sampler = build_sampler(sampler_mesh)
+    sys.stderr.write('[main] starting embedding TransformersModel...\n')
+    emb_model, emb_template = build_emb_model(emb_mesh)
+
+    api: Optional[OpenAIClient]
+    if not COMPRESS_API_KEY:
+        api = None
+        sys.stderr.write(
+            '[main] WARNING: COMPRESS_API_KEY unset — truncated rows will be dropped.\n')
+    else:
+        api = OpenAIClient(
+            model=COMPRESS_API_MODEL,
+            api_key=COMPRESS_API_KEY,
+            base_url=COMPRESS_BASE_URL,
+        )
+
+    run_build = args.mode in ('build', 'both')
+    run_eval = args.mode in ('eval', 'both')
+    if run_build:
+        build_index(args, sampler, emb_model, emb_template, api)
+    if run_eval:
+        eval_recall(args, sampler, emb_model, emb_template, api)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cookbook/exp/embedding/dataset_index.py b/cookbook/exp/embedding/dataset_index.py
new file mode 100644
index 000000000..7d2905a59
--- /dev/null
+++ b/cookbook/exp/embedding/dataset_index.py
@@ -0,0 +1,718 @@
+"""RAG-index corpus loader — abstract reasoning skills + textbook-style methods.
+
+Distinct from training-time ``dataset_think.py``. Optimizes for **abstraction
+density**, not raw coverage: every row should encode a transferable method,
+theorem, or solution pattern that downstream queries can retrieve as a
+"use-when-X-do-Y" recipe.
+
+Single-table design (``thinking_traces``); EMBED_QUERY_COT condense step in
+``build_thinking_rag_index`` homogenizes thinking-style and textbook-style
+content into the same retrieval form, so dual-table is unnecessary. The
+``source`` field carries the original dataset name for eval-time
+domain-bucket diagnostics.
+
+Output schema matches ``dataset_think.get_dataset()``: ``{id, source, messages}``
+with ``messages[1].reasoning_content`` carrying the CoT.
+
+Mix (≈3.6M rows base, 11 datasets):
+ Math thinking 23% — OpenMathReasoning + OpenR1-Math-220k + s1K-1.1
+ Code thinking 19% — OpenCodeReasoning-2 + codeforces-cots
+ Cross-domain R1 39% — Bespoke-Stratos + dolphin-r1 + reasoning-v1-20m
+ + natural_reasoning
+ Textbook synth 19% — cosmopedia v1 (auto_math_text, chunked by H2)
+ Olympiad solutions <1% — Omni-MATH
+
+Dropped: camel-ai/{physics,chemistry,biology} (zip-only, no parquet/jsonl) and
+swift/stack-exchange-paired (dataset_infos.json/data layout mismatch); the
+textbook-density gap is covered by a larger cosmopedia slice.
+
+Textbook processors synthesize a question from the chapter heading and place
+the explanatory body into the ``cot`` field — embedding+condense reads
+``query | cot`` so the textbook prose becomes a retrievable method.
+
+Field extraction is defensive: each processor tries multiple plausible column
+names and silently drops rows that miss a usable signal. Inspect
+``dropped_index.jsonl`` after the first run to verify field-name guesses.
+"""
+import re
+from typing import Any, Dict, List, Optional
+
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import Preprocessor
+
+from dataset_think import _THINK_RE, _hash_id, _register, ToMessagesProcessor
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+# Sky-T1 / Bespoke-Stratos custom markers (used in place of <think>...</think>).
+# Group 1 is the text between the begin/end markers; the match is non-greedy
+# and DOTALL lets it span multiple lines.
+_BOT_RE = re.compile(
+ r'<\|begin_of_thought\|>(.*?)<\|end_of_thought\|>', re.DOTALL)
+_BOS_RE = re.compile(
+ r'<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>', re.DOTALL)
+
+# H2 heading split for cosmopedia-style markdown chunks.
+# MULTILINE anchors the pattern to individual lines; group 1 is the heading
+# text without the leading "## " and without trailing whitespace.
+_H2_RE = re.compile(r'^##\s+(.+?)\s*$', re.MULTILINE)
+
+
+def _split_think(text: str) -> tuple:
+    """Split ``text`` into ``(cot, response)`` around the first ``_THINK_RE`` match.
+
+    ``cot`` is the stripped capture group of the match (presumably the body of
+    a ``<think>`` block — the pattern lives in ``dataset_think``) and
+    ``response`` is the stripped text after it. Without a match ``cot`` is
+    empty and the whole stripped text is the response; empty input gives
+    ``('', '')``.
+    """
+    if not text:
+        return '', ''
+    found = _THINK_RE.search(text)
+    if found is None:
+        return '', text.strip()
+    reasoning = found.group(1).strip()
+    remainder = text[found.end():].strip()
+    return reasoning, remainder
+
+
+def _split_sky_t1(text: str) -> tuple:
+    """Extract ``(cot, response)`` from Sky-T1 / Bespoke-Stratos marker text.
+
+    ``cot`` is the stripped thought block and ``response`` the stripped
+    solution block; either is ``''`` when its markers are absent.
+    """
+    if not text:
+        return '', ''
+
+    def _first_group(pattern) -> str:
+        found = pattern.search(text)
+        return found.group(1).strip() if found else ''
+
+    return _first_group(_BOT_RE), _first_group(_BOS_RE)
+
+
+def _from_messages(messages: Any) -> tuple:
+    """Return ``(first user text, first assistant text)`` from a chat-style list.
+
+    Accepts both ``role``/``content`` and ``from``/``value`` keys, with
+    ``user``/``human`` and ``assistant``/``gpt`` role names. Entries that are
+    not dicts or whose content is not a string are ignored. Scanning stops at
+    the first assistant entry, so a user turn after it is never returned.
+    """
+    if not isinstance(messages, list):
+        return '', ''
+    query = ''
+    for entry in messages:
+        if not isinstance(entry, dict):
+            continue
+        speaker = entry.get('role') or entry.get('from') or ''
+        body = entry.get('content') or entry.get('value') or ''
+        if not isinstance(body, str):
+            continue
+        if speaker in ('assistant', 'gpt'):
+            return query, body.strip()
+        if speaker in ('user', 'human') and not query:
+            query = body.strip()
+    return query, ''
+
+
+def _chunk_by_h2(text: str, min_chars: int = 200, max_chars: int = 6000):
+    """Yield ``(title, body)`` pairs for the ``## `` sections of markdown ``text``.
+
+    Only sections whose stripped body length lies in ``[min_chars, max_chars]``
+    and whose title is non-empty are yielded; text before the first heading is
+    not emitted. When there is no heading at all, the whole stripped text is
+    treated as one section titled with the first line of its first 80
+    characters.
+    """
+    if not text:
+        return
+    headings = list(_H2_RE.finditer(text))
+    if not headings:
+        whole = text.strip()
+        head = whole[:80].splitlines()[0] if whole else ''
+        if head and min_chars <= len(whole) <= max_chars:
+            yield head, whole
+        return
+    # Each section runs from the end of its heading to the start of the next
+    # heading (or to the end of the text for the last one).
+    section_ends = [h.start() for h in headings[1:]] + [len(text)]
+    for heading, stop in zip(headings, section_ends):
+        title = heading.group(1).strip()
+        body = text[heading.end():stop].strip()
+        if min_chars <= len(body) <= max_chars and title:
+            yield title, body
+
+
+# ===========================================================================
+# Math thinking
+# ===========================================================================
+
+OPEN_MATH_REASONING_REPO = 'ms://AI-ModelScope/OpenMathReasoning'
+
+
+class OpenMathReasoningProcessor(Preprocessor):
+    """OpenMathReasoning → ``{id, source, query, cot, response}``.
+
+    Reads ``problem`` (or ``question``) as the query and
+    ``generated_solution`` (or ``solution``/``output``) as the R1 trace, which
+    is split with ``_split_think``. Rows without a reasoning block are
+    dropped. When nothing follows the reasoning block, ``expected_answer`` (or
+    ``answer``) is used as the response; rows with no response at all are
+    dropped too.
+    """
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        records: List[Dict[str, Any]] = []
+        for item in self.map_col_to_row(rows):
+            problem = (item.get('problem') or item.get('question') or '').strip()
+            trace = (item.get('generated_solution') or item.get('solution')
+                     or item.get('output') or '').strip()
+            if not (problem and trace):
+                continue
+            cot, final = _split_think(trace)
+            if not cot:
+                continue
+            # Trace ended right after the reasoning block: fall back to the
+            # reference answer column.
+            final = final or (item.get('expected_answer') or item.get('answer') or '').strip()
+            if not final:
+                continue
+            records.append({
+                'id': _hash_id('open_math_reasoning', f'{problem}\n{final}'),
+                'source': 'OpenMathReasoning',
+                'query': problem,
+                'cot': cot,
+                'response': final,
+            })
+        return self.map_row_to_col(records)
+
+
+OPEN_R1_MATH_REPO = 'ms://open-r1/OpenR1-Math-220k'
+
+
+class OpenR1MathProcessor(Preprocessor):
+    """OpenR1-Math-220k → ``{id, source, query, cot, response}``.
+
+    Trace selection, in order:
+
+    1. the first non-empty entry of ``generations`` whose parallel
+       ``correctness_math_verify`` flag is truthy (only when both lists have
+       the same length);
+    2. otherwise the first non-empty entry of ``generations``, verified or not;
+    3. otherwise ``solution``.
+
+    The chosen trace is split with ``_split_think``; rows without a reasoning
+    block are dropped, and ``answer`` backs up an empty response.
+    """
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        records: List[Dict[str, Any]] = []
+        for item in self.map_col_to_row(rows):
+            problem = (item.get('problem') or item.get('question') or '').strip()
+            if not problem:
+                continue
+            generations = item.get('generations')
+            flags = item.get('correctness_math_verify')
+            trace = ''
+            if isinstance(generations, list):
+                if isinstance(flags, list) and len(flags) == len(generations):
+                    trace = next(
+                        (g.strip() for g, ok in zip(generations, flags)
+                         if ok and isinstance(g, str) and g.strip()),
+                        '')
+                trace = trace or next(
+                    (g.strip() for g in generations if isinstance(g, str) and g.strip()),
+                    '')
+            trace = trace or (item.get('solution') or '').strip()
+            if not trace:
+                continue
+            cot, final = _split_think(trace)
+            if not cot:
+                continue
+            final = final or (item.get('answer') or '').strip()
+            if not final:
+                continue
+            records.append({
+                'id': _hash_id('open_r1_math', f'{problem}\n{final}'),
+                'source': 'OpenR1-Math-220k',
+                'query': problem,
+                'cot': cot,
+                'response': final,
+            })
+        return self.map_row_to_col(records)
+
+
+S1K_REPO = 'ms://simplescaling/s1K-1.1'
+
+
+class S1KProcessor(Preprocessor):
+    """s1K-1.1 → ``{id, source, query, cot, response}``.
+
+    Query comes from ``question`` (or ``problem``); the CoT from
+    ``deepseek_thinking_trajectory`` (or ``thinking_trajectories`` /
+    ``thinking``), where a list value is joined with blank lines; the response
+    from ``deepseek_attempt`` (or ``attempt`` / ``answer`` / ``solution``).
+    Rows missing any of the three are dropped.
+    """
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        records: List[Dict[str, Any]] = []
+        for item in self.map_col_to_row(rows):
+            question = (item.get('question') or item.get('problem') or '').strip()
+            trajectory = (item.get('deepseek_thinking_trajectory')
+                          or item.get('thinking_trajectories')
+                          or item.get('thinking') or '')
+            if isinstance(trajectory, list):
+                # Keep only the string entries of a multi-part trajectory.
+                trajectory = '\n\n'.join(part for part in trajectory if isinstance(part, str))
+            cot = (trajectory or '').strip()
+            attempt = (item.get('deepseek_attempt') or item.get('attempt')
+                       or item.get('answer') or item.get('solution') or '').strip()
+            if not (question and cot and attempt):
+                continue
+            records.append({
+                'id': _hash_id('s1k', f'{question}\n{attempt}'),
+                'source': 's1K-1.1',
+                'query': question,
+                'cot': cot,
+                'response': attempt,
+            })
+        return self.map_row_to_col(records)
+
+
+# ===========================================================================
+# Code thinking
+# ===========================================================================
+
+OPEN_CODE_REASONING_REPO = 'ms://nv-community/OpenCodeReasoning-2'
+
+
+class OpenCodeReasoning2Processor(Preprocessor):
+    """OpenCodeReasoning-2 → ``{id, source, query, cot, response}``.
+
+    Query comes from ``input`` (or ``problem`` / ``question``); rows whose
+    query is empty or the literal ``'-'`` are dropped. The trace is taken from
+    ``r1_generation`` (or ``reasoning_content`` / ``solution`` / ``output``)
+    and split with ``_split_think``; ``expected_solution`` (or ``answer``)
+    backs up an empty response.
+    """
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        records: List[Dict[str, Any]] = []
+        for item in self.map_col_to_row(rows):
+            prompt = (item.get('input') or item.get('problem')
+                      or item.get('question') or '').strip()
+            # OCR-2 'python' split ships dirty rows where question is literally '-';
+            # the real prompt is buried in r1_generation and not recoverable here.
+            if prompt in ('', '-'):
+                continue
+            trace = (item.get('r1_generation') or item.get('reasoning_content')
+                     or item.get('solution') or item.get('output') or '').strip()
+            if not trace:
+                continue
+            cot, final = _split_think(trace)
+            if not cot:
+                continue
+            final = final or (item.get('expected_solution') or item.get('answer') or '').strip()
+            if not final:
+                continue
+            records.append({
+                'id': _hash_id('opencode_reasoning2', f'{prompt}\n{final}'),
+                'source': 'OpenCodeReasoning-2',
+                'query': prompt,
+                'cot': cot,
+                'response': final,
+            })
+        return self.map_row_to_col(records)
+
+
+CODEFORCES_COTS_REPO = 'ms://open-r1/codeforces-cots'
+
+
+class CodeforcesCotsProcessor(Preprocessor):
+    """codeforces-cots → ``{id, source, query, cot, response}``.
+
+    Query comes from ``description`` (or ``problem`` / ``input`` /
+    ``question``) and the R1 trace from ``generation`` (or ``solution`` /
+    ``output``). The trace is split with ``_split_think``; rows lacking either
+    the reasoning block or the text after it are dropped.
+    """
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        records: List[Dict[str, Any]] = []
+        for item in self.map_col_to_row(rows):
+            statement = (item.get('description') or item.get('problem')
+                         or item.get('input') or item.get('question') or '').strip()
+            trace = (item.get('generation') or item.get('solution')
+                     or item.get('output') or '').strip()
+            if not (statement and trace):
+                continue
+            cot, final = _split_think(trace)
+            if not (cot and final):
+                continue
+            records.append({
+                'id': _hash_id('codeforces_cots', f'{statement}\n{final}'),
+                'source': 'codeforces-cots',
+                'query': statement,
+                'cot': cot,
+                'response': final,
+            })
+        return self.map_row_to_col(records)
+
+
+# ===========================================================================
+# Cross-domain R1
+# ===========================================================================
+
+BESPOKE_STRATOS_REPO = 'ms://bespokelabs/Bespoke-Stratos-17k'
+
+
+class BespokeStratosProcessor(Preprocessor):
+    """Bespoke-Stratos-17k → ``{id, source, query, cot, response}``.
+
+    The first user/assistant pair is read from ``conversations`` (or
+    ``messages``). The assistant text is parsed with the Sky-T1 markers
+    ``<|begin_of_thought|>...<|end_of_thought|>`` /
+    ``<|begin_of_solution|>...<|end_of_solution|>``; when no thought block is
+    found, ``_split_think`` is tried instead. Rows without both a CoT and a
+    response are dropped.
+    """
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        records: List[Dict[str, Any]] = []
+        for item in self.map_col_to_row(rows):
+            prompt, reply = _from_messages(
+                item.get('conversations') or item.get('messages'))
+            if not (prompt and reply):
+                continue
+            cot, final = _split_sky_t1(reply)
+            if not cot:
+                cot, final = _split_think(reply)
+            if not (cot and final):
+                continue
+            records.append({
+                'id': _hash_id('bespoke_stratos', f'{prompt}\n{final}'),
+                'source': 'Bespoke-Stratos-17k',
+                'query': prompt,
+                'cot': cot,
+                'response': final,
+            })
+        return self.map_row_to_col(records)
+
+
+DOLPHIN_R1_REPO = 'ms://AI-ModelScope/dolphin-r1'
+
+
+class DolphinR1Processor(Preprocessor):
+    """dolphin-r1 → ``{id, source, query, cot, response}``.
+
+    The query is the user turn found in ``messages`` (or ``conversations``);
+    if several user turns exist, the last one wins. CoT and response come from
+    the flat ``reasoning`` (or ``reasoning_content``) and ``answer`` columns.
+    When either is missing, the first assistant turn is split with
+    ``_split_think`` and, if it contains a reasoning block, fills the gaps.
+    Rows still missing query, CoT or response are dropped.
+    """
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        records: List[Dict[str, Any]] = []
+        for item in self.map_col_to_row(rows):
+            turns = item.get('messages') or item.get('conversations')
+            has_turns = isinstance(turns, list)
+            prompt = ''
+            if has_turns:
+                for turn in turns:
+                    if not isinstance(turn, dict):
+                        continue
+                    speaker = turn.get('role') or turn.get('from') or ''
+                    text = turn.get('content') or turn.get('value') or ''
+                    # No break: a later user turn overrides an earlier one.
+                    if speaker in ('user', 'human') and isinstance(text, str):
+                        prompt = text.strip()
+            cot = (item.get('reasoning') or item.get('reasoning_content') or '').strip()
+            final = (item.get('answer') or '').strip()
+            if has_turns and not (cot and final):
+                _, reply = _from_messages(turns)
+                if reply:
+                    inner_cot, inner_final = _split_think(reply)
+                    if inner_cot:
+                        cot = cot or inner_cot
+                        final = final or inner_final or reply
+            if not (prompt and cot and final):
+                continue
+            records.append({
+                'id': _hash_id('dolphin_r1', f'{prompt}\n{final}'),
+                'source': 'dolphin-r1',
+                'query': prompt,
+                'cot': cot,
+                'response': final,
+            })
+        return self.map_row_to_col(records)
+
+
+GLAIVE_REASONING_REPO = 'ms://glaiveai/reasoning-v1-20m'
+
+
+class GlaiveReasoningProcessor(Preprocessor):
+    """reasoning-v1-20m → ``{id, source, query, cot, response}``.
+
+    Query comes from ``prompt`` (or ``question`` / ``input``) and the R1 trace
+    from ``response`` (or ``output`` / ``answer``). The trace is split with
+    ``_split_think``; rows lacking either the reasoning block or the answer
+    after it are dropped.
+    """
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        records: List[Dict[str, Any]] = []
+        for item in self.map_col_to_row(rows):
+            prompt = (item.get('prompt') or item.get('question')
+                      or item.get('input') or '').strip()
+            trace = (item.get('response') or item.get('output')
+                     or item.get('answer') or '').strip()
+            if not (prompt and trace):
+                continue
+            cot, final = _split_think(trace)
+            if not (cot and final):
+                continue
+            records.append({
+                'id': _hash_id('glaive_reasoning', f'{prompt}\n{final}'),
+                'source': 'reasoning-v1-20m',
+                'query': prompt,
+                'cot': cot,
+                'response': final,
+            })
+        return self.map_row_to_col(records)
+
+
+NATURAL_REASONING_REPO = 'ms://facebook/natural_reasoning'
+
+
+class NaturalReasoningProcessor(Preprocessor):
+    """natural_reasoning → ``{id, source, query, cot, response}``.
+
+    Query is ``question``. The CoT is the first non-empty text found in the
+    ``responses`` list (each entry's ``response``, else ``reasoning`` /
+    ``thinking`` / ``answer``), falling back to the row-level ``reasoning`` /
+    ``thinking`` / ``response`` columns. The response is ``reference_answer``
+    (or ``answer``); rows without a CoT or a response are dropped.
+    """
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        records: List[Dict[str, Any]] = []
+        for item in self.map_col_to_row(rows):
+            question = (item.get('question') or '').strip()
+            if not question:
+                continue
+            cot = ''
+            candidates = item.get('responses')
+            if isinstance(candidates, list):
+                for candidate in candidates:
+                    if isinstance(candidate, dict):
+                        text = (candidate.get('response') or candidate.get('reasoning')
+                                or candidate.get('thinking') or candidate.get('answer')
+                                or '').strip()
+                        if text:
+                            cot = text
+                            break
+            cot = cot or (item.get('reasoning') or item.get('thinking')
+                          or item.get('response') or '').strip()
+            reference = (item.get('reference_answer') or item.get('answer') or '').strip()
+            if not (cot and reference):
+                continue
+            records.append({
+                'id': _hash_id('natural_reasoning', f'{question}\n{reference}'),
+                'source': 'natural_reasoning',
+                'query': question,
+                'cot': cot,
+                'response': reference,
+            })
+        return self.map_row_to_col(records)
+
+
+# ===========================================================================
+# Textbook-style — synthesize query from chapter heading; body → cot
+# ===========================================================================
+
+COSMOPEDIA_REPO = 'ms://HuggingFaceTB/cosmopedia'
+
+
+class CosmopediaProcessor(Preprocessor):
+    """cosmopedia v1 → ``{id, source, query, cot, response}``.
+
+    Each row's ``text`` (or ``content``) is split into sections by
+    ``_chunk_by_h2``. Within a section, the first paragraph (text up to the
+    first blank line) is combined with the heading to form the query and the
+    remainder becomes ``cot``; sections where either part is shorter than 256
+    characters are skipped. ``response`` is always empty.
+    """
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        records: List[Dict[str, Any]] = []
+        for item in self.map_col_to_row(rows):
+            chapter = (item.get('text') or item.get('content') or '').strip()
+            if not chapter:
+                continue
+            for heading, section in _chunk_by_h2(chapter):
+                # Heading-only "Explain: X" was 1-2 tokens and impossible to align
+                # with full-section cot. Promote the section's lead paragraph into
+                # the query so anchor carries real semantic content.
+                lead, _sep, remainder = section.partition('\n\n')
+                lead = lead.strip()
+                remainder = remainder.strip()
+                if len(lead) < 256 or len(remainder) < 256:
+                    continue
+                if heading:
+                    query = f'{heading}\n\n{lead}'
+                else:
+                    query = lead
+                records.append({
+                    'id': _hash_id('cosmopedia', f'{heading}\n{lead[:200]}'),
+                    'source': 'cosmopedia-v1',
+                    'query': query,
+                    'cot': remainder,
+                    'response': '',
+                })
+        return self.map_row_to_col(records)
+
+
+OMNI_MATH_REPO = 'ms://AI-ModelScope/Omni-MATH'
+
+
+class OmniMathProcessor(Preprocessor):
+    """Omni-MATH → ``{id, source, query, cot, response}``.
+
+    Field mapping: ``problem`` (or ``question``) → query, ``solution`` → cot,
+    ``answer`` (or ``expected_answer``) → response. Rows without a query or
+    without a solution are dropped; a missing answer is kept as an empty
+    response.
+    """
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        records: List[Dict[str, Any]] = []
+        for row in self.map_col_to_row(rows):
+            problem = (row.get('problem') or row.get('question') or '').strip()
+            proof = (row.get('solution') or '').strip()
+            final = (row.get('answer') or row.get('expected_answer') or '').strip()
+            if problem and proof:
+                records.append({
+                    # Only the first 200 solution characters feed the id hash.
+                    'id': _hash_id('omni_math', f'{problem}\n{proof[:200]}'),
+                    'source': 'Omni-MATH',
+                    'query': problem,
+                    'cot': proof,
+                    'response': final,
+                })
+        return self.map_row_to_col(records)
+
+
+# ===========================================================================
+# Mix configuration — base sizes target ≈3.6M total rows
+# ===========================================================================
+
+# Per-source row budgets at scale 1.0; ``_build_dataset`` looks sizes up by these keys.
+_BASE_SIZES = {
+    'open_math_reasoning': 600_000,
+    'open_r1_math': 220_000,
+    's1k': 1_000,
+    'opencode_reasoning2': 500_000,
+    'codeforces_cots': 200_000,
+    'bespoke_stratos': 17_000,
+    'dolphin_r1': 400_000,
+    'glaive_reasoning': 800_000,
+    'natural_reasoning': 200_000,
+    'cosmopedia': 700_000,
+    'omni_math': 4_000,
+}
+
+
+def _scaled_sizes(total: Optional[int]) -> Dict[str, int]:
+    """Return per-source row budgets, rescaled so they sum to roughly ``total``.
+
+    Args:
+        total: Desired overall row count. ``None`` or a non-positive value
+            means "no scaling".
+
+    Returns:
+        A new dict keyed like ``_BASE_SIZES``. Without scaling it is a copy of
+        the base sizes; otherwise each entry is the base size multiplied by
+        ``total / sum(base sizes)``, rounded, and never below 1.
+    """
+    if total is None or total <= 0:
+        return _BASE_SIZES.copy()
+    factor = total / sum(_BASE_SIZES.values())
+    scaled: Dict[str, int] = {}
+    for name, base in _BASE_SIZES.items():
+        scaled[name] = max(1, int(round(base * factor)))
+    return scaled
+
+
+def _build_dataset(total: Optional[int] = None,
+                   load_from_cache_file: bool = True) -> Dataset:
+    """Register every source, mix them, and shuffle the result.
+
+    Args:
+        total: Forwarded to ``_scaled_sizes`` to derive the per-source row
+            budgets that are used as ``data_slice`` ranges.
+        load_from_cache_file: Forwarded to every ``_register`` call.
+
+    Returns:
+        The mixed ``Dataset`` with its rows shuffled using seed 42.
+
+    Note:
+        S1K, Bespoke-Stratos and Omni-MATH are registered without a
+        ``data_slice``, so their entries in ``sizes`` are not applied here.
+    """
+    sizes = _scaled_sizes(total)
+    dataset = Dataset()
+
+    # One entry per source, in registration order:
+    #   (processor class, repo id, extra DatasetMeta kwargs, key into ``sizes``).
+    # A ``None`` key registers the source without a ``data_slice``.
+    sources = [
+        (OpenMathReasoningProcessor, OPEN_MATH_REASONING_REPO,
+         {'split': 'cot'}, 'open_math_reasoning'),
+        (OpenR1MathProcessor, OPEN_R1_MATH_REPO,
+         {'split': 'train'}, 'open_r1_math'),
+        (S1KProcessor, S1K_REPO,
+         {'split': 'train'}, None),
+        (OpenCodeReasoning2Processor, OPEN_CODE_REASONING_REPO,
+         {'subset_name': 'train', 'split': 'python'}, 'opencode_reasoning2'),
+        (CodeforcesCotsProcessor, CODEFORCES_COTS_REPO,
+         {'subset_name': 'solutions_w_editorials_decontaminated', 'split': 'train'},
+         'codeforces_cots'),
+        (BespokeStratosProcessor, BESPOKE_STRATOS_REPO,
+         {'split': 'train'}, None),
+        (DolphinR1Processor, DOLPHIN_R1_REPO,
+         {'subset_name': 'reasoning-deepseek', 'split': 'train'}, 'dolphin_r1'),
+        (GlaiveReasoningProcessor, GLAIVE_REASONING_REPO,
+         {'split': 'train'}, 'glaive_reasoning'),
+        (NaturalReasoningProcessor, NATURAL_REASONING_REPO,
+         {'split': 'train'}, 'natural_reasoning'),
+        (CosmopediaProcessor, COSMOPEDIA_REPO,
+         {'subset_name': 'auto_math_text', 'split': 'train'}, 'cosmopedia'),
+        (OmniMathProcessor, OMNI_MATH_REPO,
+         {'split': 'test'}, None),
+    ]
+    for processor_cls, repo_id, meta_kwargs, size_key in sources:
+        if size_key is not None:
+            meta_kwargs = dict(meta_kwargs, data_slice=range(sizes[size_key]))
+        _register(dataset, processor_cls,
+                  DatasetMeta(dataset_id=repo_id, **meta_kwargs),
+                  load_from_cache_file=load_from_cache_file)
+
+    dataset.mix_dataset(False)
+    # The mix is concatenated in registration order; shuffling lets a streaming
+    # consumer see all sources interleaved instead of working through the whole
+    # OpenMathReasoning slice before it reaches the code/textbook sources.
+    dataset.dataset = dataset.dataset.shuffle(seed=42)
+    return dataset
+
+
+def get_dataset(total: Optional[int] = None,
+                dropped_log: Optional[str] = None,
+                load_from_cache_file: bool = True) -> Dataset:
+    """Build the RAG-index corpus, convert it to messages, and quality-filter it.
+
+    Meant to mirror ``dataset_think.get_dataset`` (same signature and output
+    schema) so ``build_thinking_rag_index`` can consume either module.
+
+    Args:
+        total: Target row count forwarded to ``_build_dataset``.
+        dropped_log: Path handed to ``QualityPreprocessor`` as
+            ``dropped_log_path``; ``None`` is passed on as ``''``.
+        load_from_cache_file: Forwarded to the build, filter and map steps.
+
+    Returns:
+        The ``Dataset`` after the query-length filter, ``ToMessagesProcessor``
+        and the quality pipeline have been applied.
+    """
+    from twinkle_agentic.preprocessor import (DeadLoopFilter, FixUnicodeFilter, HardFilter,
+                                              MessageSanityFilter, QualityPreprocessor,
+                                              RefuseFilter, RemoveRepeatSentencesFilter,
+                                              TokenNumFilter, TokenSoupFilter)
+
+    corpus = _build_dataset(total=total, load_from_cache_file=load_from_cache_file)
+
+    # The anchor side needs enough text to embed meaningfully, so rows whose
+    # query is shorter than 100 characters (one-line math problems, OmniMath
+    # stubs, ...) are removed before the message conversion.
+    corpus.dataset = corpus.dataset.filter(
+        lambda x: len((x.get('query') or '').strip()) >= 100,
+        num_proc=32, load_from_cache_file=load_from_cache_file)
+
+    corpus.map(ToMessagesProcessor(), remove_columns=['query', 'cot', 'response'],
+               load_from_cache_file=load_from_cache_file)
+
+    # Stages are listed in the order they are handed to QualityPreprocessor.
+    quality_stages = [
+        HardFilter(),
+        RefuseFilter(),
+        DeadLoopFilter(),
+        TokenSoupFilter(),
+        MessageSanityFilter(min_turns=1, max_msg_chars=200000),
+        FixUnicodeFilter(),
+        RemoveRepeatSentencesFilter(),
+        TokenNumFilter(max_num=32768),
+    ]
+    quality = QualityPreprocessor(pipeline=quality_stages,
+                                  dropped_log_path=dropped_log or '')
+    corpus.map(quality, num_proc=32, load_from_cache_file=load_from_cache_file)
+    return corpus
+
+
+if __name__ == '__main__':
+    # Manual smoke run: rebuild the corpus without caches and print its size.
+    import os
+    dropped_log = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                               'dropped_index.jsonl')
+    # Start from a clean log so entries from an earlier run do not linger.
+    if os.path.exists(dropped_log):
+        os.remove(dropped_log)
+    # The path has to be passed explicitly; get_dataset turns a missing
+    # dropped_log into an empty dropped_log_path.
+    dataset = get_dataset(dropped_log=dropped_log, load_from_cache_file=False)
+    print(len(dataset))
diff --git a/cookbook/exp/embedding/train_embedding_full_ddp.py b/cookbook/exp/embedding/train_embedding_full_ddp.py
index 492e29aae..5db9f786c 100644
--- a/cookbook/exp/embedding/train_embedding_full_ddp.py
+++ b/cookbook/exp/embedding/train_embedding_full_ddp.py
@@ -1,14 +1,12 @@
-"""LoRA embedding training with online condenser self-improvement.
+"""LoRA embedding training with online compression via frozen vLLM condenser.
Architecture (8 GPUs total):
- Ranks 0-3 (``model``): Trainable embedding model with LoRA, InfoNCE loss.
- - Ranks 4-5 (``condenser_sampler``): Frozen vLLM condenser for online compression.
- - Ranks 6-7 (``condenser_model``): Trainable condenser with LoRA for self-improvement.
+ - Ranks 4-7 (``condenser_sampler``): Frozen vLLM condenser for online compression.
-When the condenser sampler truncates (stop_reason='length'), an external OpenAI-
-compatible API produces the correct compression. The failure is logged as SFT
-training data. A background thread retrains the condenser on accumulated failures
-mixed with condense_300K, then syncs weights back to the sampler.
+When the condenser sampler truncates or regresses to the legacy schema, an
+external OpenAI-compatible API produces the correct compression. The failure is
+logged to failures.jsonl for offline SFT data regeneration.
Launch:
python cookbook/exp/train_embedding_lora_ddp.py
@@ -19,6 +17,7 @@
import re
import sys
import threading
+import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
@@ -27,7 +26,6 @@
import twinkle
from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
-from twinkle.checkpoint_engine import CheckpointEngineManager
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.loss import InfonceLoss
@@ -35,12 +33,13 @@
from twinkle.model import TransformersModel
from twinkle.processor import InputProcessor
from twinkle.sampler import vLLMSampler
-from twinkle.template import Template
+from twinkle.template import Qwen3_5Template, Template
from twinkle.utils.parallel import PosixFileLock
from twinkle_agentic.protocol.openai import OpenAI as OpenAIClient
sys.path.insert(0, str(Path(__file__).resolve().parent))
-from dataset_think import get_dataset # noqa: E402
+from dataset_think import get_dataset as get_dataset_think # noqa: E402
+from dataset_index import get_dataset as get_dataset_index # noqa: E402
logger = get_logger()
@@ -54,29 +53,33 @@
# -- GPU placement (8 total) --------------------------------------------------
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
-CONDENSER_SAMPLER_GPUS = int(os.environ.get('CONDENSER_SAMPLER_GPUS', 2))
-CONDENSER_MODEL_GPUS = int(os.environ.get('CONDENSER_MODEL_GPUS', 2))
-NUM_GPUS = MODEL_GPUS + CONDENSER_SAMPLER_GPUS + CONDENSER_MODEL_GPUS
+CONDENSER_SAMPLER_GPUS = int(os.environ.get('CONDENSER_SAMPLER_GPUS', 4))
+NUM_GPUS = MODEL_GPUS + CONDENSER_SAMPLER_GPUS
# -- Embedding training hyper-params ------------------------------------------
EMB_MAX_LENGTH = 8192
HARD_NEGATIVES = None
-TEMPERATURE = 0.03
+# 0.07 keeps gradient on diag pairs until cosine clears ~0.75; 0.03 saturated near 0.40.
+TEMPERATURE = 0.07
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 32))
-LEARNING_RATE = 1.5e-6
+LEARNING_RATE = 1e-5
GRADIENT_ACCUMULATION_STEPS = 1
LOG_INTERVAL = 2
-SAVE_INTERVAL = 4000
-NUM_EPOCHS = 2
+SAVE_INTERVAL = 2000
+NUM_EPOCHS = 1
TOTAL_SAMPLES: Optional[int] = None
+# Post-build caps on each loader (None = no cap). Applied via .select() before mix.
+THINK_CAP: Optional[int] = 400_000
+INDEX_CAP: Optional[int] = 400_000
+MIX_SHUFFLE_SEED = 42
# -- Resume from checkpoint ---------------------------------------------------
-RESUME_CHECKPOINT = os.environ.get(
- 'RESUME_CHECKPOINT',
- './output/embedding_lora_transformers/step_16000')
-RESUME_STEP = int(os.environ.get('RESUME_STEP', 16000))
+# Empty by default — build_model falls back to MODEL_ID (the published emb model).
+# Set both to point at a local in-progress run only when resuming the *same* schedule.
+RESUME_CHECKPOINT = os.environ.get('RESUME_CHECKPOINT', '')
+RESUME_STEP = int(os.environ.get('RESUME_STEP', 0))
# -- Online-compression knobs -------------------------------------------------
# Below this length, condenser fabricates content for open-ended short prompts;
@@ -87,16 +90,18 @@
COMPRESS_TOP_P = 0.5
COMPRESS_MAX_MODEL_LEN = 32768
+# How many BATCH_SIZE chunks to fetch and compress in one vLLM call.
+PREFETCH_BATCH_MULTIPLIER = int(os.environ.get('PREFETCH_BATCH_MULTIPLIER', 8))
+
# -- OpenAI API fallback for truncated compressions ---------------------------
COMPRESS_API_KEY = os.environ.get('COMPRESS_API_KEY', '')
COMPRESS_BASE_URL = os.environ.get('COMPRESS_BASE_URL', 'https://dashscope.aliyuncs.com/compatible-mode/v1')
COMPRESS_MODEL = os.environ.get('COMPRESS_MODEL', 'qwen3.7-max')
-
-# -- Condenser retraining knobs -----------------------------------------------
-CONDENSER_DATASET_ID = 'ms://twinkle-kit/condense_300K'
-CONDENSER_RETRAIN_SAMPLES = 128
-CONDENSER_RETRAIN_EPOCHS = 3
-CONDENSER_RETRAIN_LR = 1e-5
+# Minimum gap between API calls (seconds); bounds dashscope qps under provider limits.
+API_MIN_INTERVAL = float(os.environ.get('API_MIN_INTERVAL', 0.1))
+API_CONCURRENCY = int(os.environ.get('API_CONCURRENCY', 8))
+# vLLM sampler timeout (seconds); if a sample() call exceeds this, fall back to API.
+SAMPLER_TIMEOUT = float(os.environ.get('SAMPLER_TIMEOUT', 300))
# -- Output paths -------------------------------------------------------------
OUTPUT_DIR = f'./output/embedding_lora_{BACKEND}'
@@ -204,6 +209,17 @@
_sample_counter = 0
_sample_counter_lock = threading.Lock()
+_api_throttle_lock = threading.Lock()
+_api_last_call = [0.0]
+
+
+def _api_throttle():
+    """Space out API calls so consecutive ones are at least ``API_MIN_INTERVAL`` s apart.
+
+    The lock is held while sleeping, so concurrent callers are released one at
+    a time; the time of the latest release is kept in ``_api_last_call[0]``.
+    """
+    with _api_throttle_lock:
+        remaining = API_MIN_INTERVAL - (time.monotonic() - _api_last_call[0])
+        if remaining > 0:
+            time.sleep(remaining)
+        _api_last_call[0] = time.monotonic()
+
def _next_sample_id() -> int:
global _sample_counter
@@ -314,11 +330,40 @@ def save_checkpoint(model, name: str):
# Compression prompt building
# =============================================================================
+# Hard-templated hints: the condenser SFT prior maps `Skill` to the legacy
+# `Use when: / numbered steps / Output:` skeleton on long inputs; embedding the
+# exact 4-line body template + explicit negative constraints is the only way to
+# override it deterministically across query and cot sides.
EMBED_QUERY_Q = (
+ 'Summarize this query for retrieval. '
+ 'The body of ## Summary MUST follow this EXACT 4-line template — '
+ 'do NOT emit "Use when:", numbered procedure steps, or "Output:":\n'
+ 'Topic: \n'
+ 'Problem: \n'
+ 'Skill: \n'
+ 'Knowledge: \n'
+ 'Then emit the mandatory ## More section as usual. '
+ 'Topic must name the specific pattern, never generic labels.')
+EMBED_QUERY_COT = (
+ 'Summarize this reasoning trace for retrieval. '
+ 'The body of ## Summary MUST follow this EXACT 4-line template — '
+ 'do NOT emit "Use when:", numbered procedure steps, or "Output:":\n'
+ 'Topic: \n'
+ 'Problem: \n'
+ 'Skill: \n'
+ 'Knowledge: \n'
+ 'Then emit the mandatory ## More section as usual. '
+ 'Topic must name the specific pattern, never generic labels.')
+
+# Legacy schema (Use when: / numbered steps / Output:) — mixed in 50/50 with the
+# new schema to expose the embedder to schema-invariant semantic alignment.
+# Both query and cot of the SAME pair always use the SAME schema; cross-schema
+# anchors and positives would re-introduce the schema asymmetry we just fixed.
+EMBED_QUERY_Q_LEGACY = (
'What problem does this passage address, and what skill or method is needed? '
'Topic must name the specific pattern, never generic labels. '
'Compress into a retrieval-friendly need description.')
-EMBED_QUERY_COT = (
+EMBED_QUERY_COT_LEGACY = (
'Extract the reusable skill: trigger conditions, key steps, and expected output. '
'Topic names the method/pattern; format as "Use when: ...", numbered steps, '
'"Output: ...". Compress into a standardized procedure for retrieval.')
@@ -342,45 +387,59 @@ def _extract_query_cot(row: Dict[str, Any]):
def _build_compress_prompts(rows: List[Dict[str, Any]]) -> tuple:
"""Build prompts for compressing both query and cot per row.
- Returns (prompts, valid_indices, raw_pairs, prompt_queries, passthrough) where:
+ Returns (prompts, valid_indices, raw_pairs, prompt_queries, passthrough, schemas)
+ where:
- prompts: flat-interleaved [query_0, cot_0, query_1, cot_1, ...]; ``None`` means
passthrough (use raw text directly, do not call sampler)
- valid_indices: which rows passed the min-length filter
- raw_pairs: [(query, cot), ...]
- prompt_queries: the query string used for each prompt (for failure logging)
- passthrough: parallel to prompts; non-None text means "use this verbatim as qc"
+ - schemas: parallel to prompts; 'new' or 'legacy', drives validator branch
"""
prompts: List[Optional[Dict[str, Any]]] = []
valid_indices: List[int] = []
raw_pairs: List[tuple] = []
prompt_queries: List[str] = []
passthrough: List[Optional[str]] = []
+ schemas: List[str] = []
+ # Conservative char budget: 32768 max_length - 8192 gen - ~2k prompt overhead = ~22k tokens.
+ # 30k cap bounds vLLM batch latency (vLLM batches by max prompt length).
+ _MAX_COT_CHARS = 30_000
for i, row in enumerate(rows):
query, cot = _extract_query_cot(row)
if not query or len(cot) < MIN_TEXT_CHARS:
continue
+ if len(cot) > _MAX_COT_CHARS:
+ continue
valid_indices.append(i)
raw_pairs.append((query, cot))
+ # 50/50 schema mix; same schema for query+cot of one pair to keep alignment.
+ schema = 'legacy' if (i % 2 == 0) else 'new'
+ q_hint = EMBED_QUERY_Q_LEGACY if schema == 'legacy' else EMBED_QUERY_Q
+ c_hint = EMBED_QUERY_COT_LEGACY if schema == 'legacy' else EMBED_QUERY_COT
# Short query bypasses condenser to avoid skeleton-induced hallucination.
if len(query) < MIN_TEXT_CHARS:
prompts.append(None)
passthrough.append(query)
else:
- user = COMPRESS_USER.format(query=EMBED_QUERY_Q, text=query)
+ user = COMPRESS_USER.format(query=q_hint, text=query)
prompts.append({'messages': [
{'role': 'system', 'content': COMPRESS_SYSTEM},
{'role': 'user', 'content': user},
]})
passthrough.append(None)
- prompt_queries.append(EMBED_QUERY_Q)
- user = COMPRESS_USER.format(query=EMBED_QUERY_COT, text=cot)
+ prompt_queries.append(q_hint)
+ schemas.append(schema)
+ user = COMPRESS_USER.format(query=c_hint, text=cot)
prompts.append({'messages': [
{'role': 'system', 'content': COMPRESS_SYSTEM},
{'role': 'user', 'content': user},
]})
- prompt_queries.append(EMBED_QUERY_COT)
+ prompt_queries.append(c_hint)
passthrough.append(None)
- return prompts, valid_indices, raw_pairs, prompt_queries, passthrough
+ schemas.append(schema)
+ return prompts, valid_indices, raw_pairs, prompt_queries, passthrough, schemas
def _get_first_feature(decoded_text: str, template: Template, role: str) -> Optional[Dict[str, Any]]:
@@ -405,13 +464,24 @@ def _get_first_feature(decoded_text: str, template: Template, role: str) -> Opti
# OpenAI API fallback
# =============================================================================
-def _is_truncated_compression(text: str) -> bool:
- """Detect structurally incomplete output that vLLM may report as stop_reason='stop'.
+_LEGACY_USE_WHEN_RE = re.compile(r'(?im)^\s*Use when\s*:')
+_SCHEMA_MARKERS = ('Problem:', 'Skill:', 'Knowledge:')
+
+
+def _is_truncated_compression(text: str, schema: str = 'new') -> bool:
+ """Reject structurally incomplete OR schema-regressed condenser output.
- The condenser sometimes emits a chat-template token mid-skeleton (which we then
- strip), so the visible text ends mid-sentence even though stop_reason!='length'.
- The COMPRESS_SYSTEM skeleton mandates a `## More` section ending in a bullet list;
- its absence is an unambiguous truncation signal.
+ Triggers API fallback when the vLLM output:
+ * lacks ``## Summary`` / ``## More``,
+ * has an empty or unterminated ``## More`` bullet list, or
+ * (schema='new' only) regresses to the legacy ``Use when: / numbered-steps /
+ Output:`` skeleton instead of the mandated Problem/Skill/Knowledge 4-line
+ body — the dominant cot-side failure mode that drives sim < 0.45 drops on
+ the RAG index.
+
+ For schema='legacy', body markers are intentionally NOT enforced: the legacy
+ template legitimately emits ``Use when:`` and the SFT prior already produces
+ that shape natively, so only structural completeness is checked.
"""
if not text or not text.strip():
return True
@@ -423,11 +493,18 @@ def _is_truncated_compression(text: str) -> bool:
last_line = after_more.splitlines()[-1].strip()
if not (last_line.startswith('-') or last_line.endswith(')')):
return True
+ if schema == 'new':
+ summary_body = text.split('## Summary', 1)[1].split('## More', 1)[0]
+ if _LEGACY_USE_WHEN_RE.search(summary_body):
+ return True
+ if not all(marker in summary_body for marker in _SCHEMA_MARKERS):
+ return True
return False
def _api_compress(api_client: OpenAIClient, prompt: Dict[str, Any]) -> Optional[str]:
"""Call external API to compress when vLLM truncates."""
+ _api_throttle()
trajectory = {'messages': prompt['messages']}
# Cap max_tokens to leave ample prompt headroom inside the API model context.
sp = SamplingParams(temperature=0.2, max_tokens=8192)
@@ -446,61 +523,12 @@ def _api_compress(api_client: OpenAIClient, prompt: Dict[str, Any]) -> Optional[
return content
-# =============================================================================
-# Condenser Retrainer (background thread)
-# =============================================================================
-
-class CondenserRetrainer:
- """Async condenser self-improvement: retrains from failures, syncs to sampler."""
-
- def __init__(self, condenser_model, ckpt_manager: CheckpointEngineManager,
- condenser_sampler):
- self._model = condenser_model
- self._ckpt_manager = ckpt_manager
- self._sampler = condenser_sampler
- self._signal = threading.Event()
- self._stop = threading.Event()
- self._thread = threading.Thread(target=self._loop, daemon=True)
- self._condense_300k_cache = None
- self._retrain_count = 0
- # Prevents sample() and sync_weights() from running concurrently
- self.sampler_lock = threading.Lock()
-
- def start(self):
- self._thread.start()
-
- def stop(self):
- self._stop.set()
- self._signal.set()
- self._thread.join(timeout=10)
-
- def notify_failure(self):
- self._signal.set()
-
- def _loop(self):
- while not self._stop.is_set():
- self._signal.wait(timeout=60)
- if self._stop.is_set():
- break
- if not self._signal.is_set():
- continue
- self._signal.clear()
- try:
- self._retrain_and_sync()
- except Exception as exc:
- logger.error(f'[condenser_retrain] crashed: {exc}')
-
- def _retrain_and_sync(self):
- # Retrain + sync temporarily disabled; failures.jsonl is written directly by _log_failure.
- pass
-
-
# =============================================================================
# Main training
# =============================================================================
def train():
- # -------- Device groups (3 groups) ----------------------------------------
+ # -------- Device groups (2 groups) ----------------------------------------
device_groups = [
DeviceGroup(name='model',
ranks=list(range(MODEL_GPUS)),
@@ -508,22 +536,31 @@ def train():
DeviceGroup(name='condenser_sampler',
ranks=list(range(MODEL_GPUS, MODEL_GPUS + CONDENSER_SAMPLER_GPUS)),
device_type='GPU'),
- DeviceGroup(name='condenser_model',
- ranks=list(range(MODEL_GPUS + CONDENSER_SAMPLER_GPUS, NUM_GPUS)),
- device_type='GPU'),
]
model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
condenser_sampler_mesh = DeviceMesh.from_sizes(
world_size=CONDENSER_SAMPLER_GPUS, dp_size=CONDENSER_SAMPLER_GPUS)
- condenser_model_mesh = DeviceMesh.from_sizes(
- world_size=CONDENSER_MODEL_GPUS, dp_size=1, fsdp_size=CONDENSER_MODEL_GPUS)
twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups)
# -------- Data -----------------------------------------------------------
- dataset = get_dataset(total=TOTAL_SAMPLES, load_from_cache_file=True)
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
- total_forward_steps = len(dataloader) * NUM_EPOCHS
+ dataset = get_dataset_think(total=TOTAL_SAMPLES, load_from_cache_file=True)
+ if THINK_CAP and len(dataset.dataset) > THINK_CAP:
+ dataset.dataset = dataset.dataset.select(range(THINK_CAP))
+ if INDEX_CAP != 0:
+ from datasets import concatenate_datasets
+ ds_index = get_dataset_index(total=None, load_from_cache_file=True)
+ if INDEX_CAP and len(ds_index.dataset) > INDEX_CAP:
+ ds_index.dataset = ds_index.dataset.select(range(INDEX_CAP))
+ n_think = len(dataset.dataset)
+ n_index = len(ds_index.dataset)
+ # Both loaders emit identical {id, source, messages} schema post-QP.
+ dataset.dataset = concatenate_datasets(
+ [dataset.dataset, ds_index.dataset]).shuffle(seed=MIX_SHUFFLE_SEED)
+ logger.info(f'[mix] think={n_think} + index={n_index} → total={len(dataset.dataset)}')
+ _mega_batch_size = BATCH_SIZE * PREFETCH_BATCH_MULTIPLIER
+ dataloader = DataLoader(dataset=dataset, batch_size=_mega_batch_size, shuffle=True)
+ total_forward_steps = len(dataloader) * PREFETCH_BATCH_MULTIPLIER * NUM_EPOCHS
optimizer_steps = total_forward_steps // GRADIENT_ACCUMULATION_STEPS
# -------- Embedding model (4 GPU) ----------------------------------------
@@ -534,10 +571,10 @@ def train():
setup_optimizer(model, optimizer_steps)
model.add_metric(EmbeddingMetric, is_training=True)
- # -------- Condenser sampler (2 GPU, vLLM) --------------------------------
- emb_template = Template(model_id=MODEL_ID, max_length=EMB_MAX_LENGTH, enable_thinking=False)
+ # -------- Condenser sampler (4 GPU, vLLM) --------------------------------
+ emb_template = Qwen3_5Template(model_id=MODEL_ID, max_length=EMB_MAX_LENGTH, enable_thinking=False)
# Special tokens come from the condenser tokenizer because the leak we strip is in its decoded output.
- condenser_template = Template(model_id=CONDENSE_MODEL_ID, max_length=DATASET_MAX_TOKENS,
+ condenser_template = Qwen3_5Template(model_id=CONDENSE_MODEL_ID, max_length=DATASET_MAX_TOKENS,
enable_thinking=False)
_special_tokens = set(condenser_template.tokenizer.all_special_tokens)
condenser_sampler = vLLMSampler(
@@ -559,23 +596,32 @@ def train():
num_samples=1,
)
- # -------- Condenser model (2 GPU, trainable full-param) -------------------
- condenser_model = TransformersModel(
- model_id=CONDENSE_MODEL_ID,
- device_mesh=condenser_model_mesh,
- remote_group='condenser_model',
- )
- condenser_model.set_optimizer(optimizer_cls='AdamW', lr=CONDENSER_RETRAIN_LR)
-
- # -------- CheckpointEngineManager: condenser_model → condenser_sampler ---
- condenser_ckpt_manager = CheckpointEngineManager(
- model=condenser_model, sampler=condenser_sampler)
- condenser_ckpt_manager.sync_weights()
-
- # -------- Background retrainer -------------------------------------------
- retrainer = CondenserRetrainer(condenser_model, condenser_ckpt_manager,
- condenser_sampler)
- retrainer.start()
+ condenser_sampler._ray_get_timeout = SAMPLER_TIMEOUT
+ _sampler_epoch = 0
+
+ def _rebuild_sampler():
+ """Kill stuck actors and recreate the vLLM sampler from scratch."""
+ nonlocal condenser_sampler, _sampler_epoch
+ import ray
+ for actor in getattr(condenser_sampler, '_actors', []):
+ try:
+ ray.kill(actor, no_restart=True)
+ except Exception:
+ pass
+ logger.warning('[sampler] killed stuck actors, recreating sampler \u2026')
+ new = vLLMSampler(
+ model_id=CONDENSE_MODEL_ID,
+ engine_args={'gpu_memory_utilization': 0.8, 'max_model_len': COMPRESS_MAX_MODEL_LEN},
+ device_mesh=condenser_sampler_mesh,
+ remote_group='condenser_sampler',
+ )
+ new.set_template(
+ TEMPLATE_NAME, model_id=CONDENSE_MODEL_ID, enable_thinking=False,
+ truncation_strategy='delete', max_length=DATASET_MAX_TOKENS)
+ new._ray_get_timeout = SAMPLER_TIMEOUT
+ condenser_sampler = new
+ _sampler_epoch += 1
+ logger.warning('[sampler] sampler rebuilt successfully')
# -------- OpenAI API client for fallback ---------------------------------
api_client = OpenAIClient(
@@ -606,28 +652,41 @@ def train():
# -------- Train loop -----------------------------------------------------
def _sample_batch(raw_batch):
"""Compress via vLLM sampler; fall back to API on truncation."""
- compress_prompts, valid_indices, raw_pairs, prompt_queries, passthrough = \
+ _t_enter = time.monotonic()
+ compress_prompts, valid_indices, raw_pairs, prompt_queries, passthrough, schemas = \
_build_compress_prompts(raw_batch)
- if not compress_prompts:
+ _t_build = time.monotonic()
+ if len(compress_prompts) < 4:
return None
# Only submit non-passthrough prompts to the sampler.
sampler_input = [p for p in compress_prompts if p is not None]
sampler_pos = [ri for ri, p in enumerate(compress_prompts) if p is not None]
if sampler_input:
- with retrainer.sampler_lock:
+ try:
sampler_responses = condenser_sampler.sample(sampler_input, compress_params)
+ except Exception as exc:
+ logger.warning(f'[sampler] error \u2192 API fallback: {exc}')
+ sampler_responses = [None] * len(sampler_input)
+ if 'Timeout' in type(exc).__name__:
+ try:
+ _rebuild_sampler()
+ except Exception as re_exc:
+ logger.error(f'[sampler] rebuild failed: {re_exc}')
else:
sampler_responses = []
+ _t_sample = time.monotonic()
+
responses = [None] * len(compress_prompts)
for resp, pos in zip(sampler_responses, sampler_pos):
responses[pos] = resp
# Extract decoded texts; detect truncations and fall back to API
- decoded_texts: List[str] = []
+ decoded_texts: List[Optional[str]] = [None] * len(compress_prompts)
+ fallback_indices: List[int] = []
for ri in range(len(compress_prompts)):
if passthrough[ri] is not None:
- decoded_texts.append(passthrough[ri])
+ decoded_texts[ri] = passthrough[ri]
continue
resp = responses[ri]
seq = resp.sequences[0] if resp and resp.sequences else None
@@ -638,27 +697,33 @@ def _sample_batch(raw_batch):
text = text.replace(tok, '')
text = text.rstrip()
- # Premature-EOS: model emits chat-template token mid-skeleton, vLLM reports
- # stop_reason='stop' but the stripped text is structurally incomplete.
needs_fallback = (not seq or seq.stop_reason == 'length'
- or _is_truncated_compression(text))
+ or _is_truncated_compression(text, schemas[ri]))
if not needs_fallback:
- decoded_texts.append(text)
- continue
-
- api_result = _api_compress(api_client, compress_prompts[ri])
- # Skip logging when the API itself produced truncated output: an incomplete
- # gold answer would teach the condenser to imitate broken outputs.
- if api_result and not _is_truncated_compression(api_result):
- decoded_texts.append(api_result)
- pair_idx = ri // 2
- q_raw, c_raw = raw_pairs[pair_idx]
- source_text = q_raw if ri % 2 == 0 else c_raw
- _log_failure(source_text, prompt_queries[ri], api_result,
- valid_indices[pair_idx])
- retrainer.notify_failure()
+ decoded_texts[ri] = text
else:
- decoded_texts.append('')
+ fallback_indices.append(ri)
+
+ _api_calls = len(fallback_indices)
+ if fallback_indices:
+ from concurrent.futures import as_completed
+ api_futures = {}
+ with ThreadPoolExecutor(max_workers=API_CONCURRENCY) as api_pool:
+ for ri in fallback_indices:
+ api_futures[api_pool.submit(_api_compress, api_client, compress_prompts[ri])] = ri
+ for fut in as_completed(api_futures):
+ ri = api_futures[fut]
+ api_result = fut.result()
+ if api_result and not _is_truncated_compression(api_result, schemas[ri]):
+ decoded_texts[ri] = api_result
+ pair_idx = ri // 2
+ q_raw, c_raw = raw_pairs[pair_idx]
+ source_text = q_raw if ri % 2 == 0 else c_raw
+ _log_failure(source_text, prompt_queries[ri], api_result,
+ valid_indices[pair_idx])
+ else:
+ decoded_texts[ri] = ''
+ _t_api = time.monotonic()
# Build embedding features from decoded texts
emb_features: List[Dict[str, Any]] = []
@@ -673,69 +738,96 @@ def _sample_batch(raw_batch):
if feat_q and feat_c:
emb_features.append(feat_q)
emb_features.append(feat_c)
+ _t_feat = time.monotonic()
- if len(emb_features) < 4:
- return None
- return emb_features
+ logger.info(
+ f'[prefetch] prompts={len(sampler_input)} api={_api_calls} feats={len(emb_features)} | '
+ f'build={_t_build - _t_enter:.1f}s '
+ f'vllm={_t_sample - _t_build:.1f}s '
+ f'api={_t_api - _t_sample:.1f}s feat={_t_feat - _t_api:.1f}s '
+ f'total={_t_feat - _t_enter:.1f}s')
+
+ _target = BATCH_SIZE * 2
+ minibatches = [emb_features[i:i + _target] for i in range(0, len(emb_features), _target)]
+ minibatches = [mb for mb in minibatches if len(mb) >= 4]
+ return minibatches if minibatches else None
cur_step = RESUME_STEP
- # Compute which epoch and how many batches to skip within that epoch
_batches_per_epoch = len(dataloader)
- _start_epoch = cur_step // _batches_per_epoch if cur_step > 0 else 0
- _skip_batches_in_epoch = cur_step - _start_epoch * _batches_per_epoch if cur_step > 0 else 0
+ _steps_per_mega = PREFETCH_BATCH_MULTIPLIER
+ _start_epoch = cur_step // (_batches_per_epoch * _steps_per_mega) if cur_step > 0 else 0
+ _skip_batches_in_epoch = max(0, cur_step // _steps_per_mega - _start_epoch * _batches_per_epoch)
+
+ _ema_prefetch = 0.0
+ _ema_train = 0.0
+ _ema_alpha = 0.1
prefetch_executor = ThreadPoolExecutor(max_workers=1)
for epoch in range(_start_epoch, NUM_EPOCHS):
- # Skip consumed samples for the resume epoch (shuffle order won't match
- # exactly, but the correct number of samples is skipped).
if _skip_batches_in_epoch > 0:
- dataloader.skip_consumed_samples(_skip_batches_in_epoch * BATCH_SIZE)
+ dataloader.skip_consumed_samples(_skip_batches_in_epoch * _mega_batch_size)
batch_iter = iter(dataloader)
- # Reset skip after first resumed epoch
_skip_batches_in_epoch = 0
- prefetch_future = None
- first_batch = next(batch_iter, None)
- if first_batch is not None:
- prefetch_future = prefetch_executor.submit(_sample_batch, first_batch)
- for raw_batch in batch_iter:
- emb_features = prefetch_future.result() if prefetch_future else None
- prefetch_future = prefetch_executor.submit(_sample_batch, raw_batch)
+ first = next(batch_iter, None)
+ future = prefetch_executor.submit(_sample_batch, first) if first else None
- if emb_features is None:
+ for raw_mega_batch in batch_iter:
+ t0 = time.monotonic()
+ minibatches = future.result() if future else None
+ t_prefetch = time.monotonic() - t0
+ future = prefetch_executor.submit(_sample_batch, raw_mega_batch)
+
+ if not minibatches:
continue
- model.forward_backward(inputs=emb_features, task='embedding')
- model.clip_grad_and_step(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
- cur_step += 1
-
- if cur_step % LOG_INTERVAL == 0:
- metric = model.calculate_metric(is_training=True)
- logger.info(
- f'Epoch {epoch} Step {cur_step}/{total_forward_steps}, metric: {metric}')
- log_dict = {}
- for k, v in metric.items():
- if not v:
- continue
- try:
- log_dict[k] = float(v)
- except (ValueError, TypeError):
- pass
- log_dict['epoch'] = epoch
- swanlab.log(log_dict, step=cur_step)
- if cur_step % SAVE_INTERVAL == 0:
- save_checkpoint(model, f'step_{cur_step}')
-
- # # Drain last prefetched batch
- # if prefetch_future is not None:
- # emb_features = prefetch_future.result()
- # if emb_features is not None:
- # model.forward_backward(inputs=emb_features, task='embedding')
- # model.clip_grad_and_step()
- # cur_step += 1
+ for mb in minibatches:
+ t1 = time.monotonic()
+ model.forward_backward(inputs=mb, task='embedding')
+ model.clip_grad_and_step(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
+ t_train = time.monotonic() - t1
+ cur_step += 1
+
+ _ema_prefetch = _ema_alpha * t_prefetch + (1 - _ema_alpha) * _ema_prefetch if cur_step > RESUME_STEP + 1 else t_prefetch
+ _ema_train = _ema_alpha * t_train + (1 - _ema_alpha) * _ema_train if cur_step > RESUME_STEP + 1 else t_train
+
+ if cur_step % LOG_INTERVAL == 0:
+ metric = model.calculate_metric(is_training=True)
+ _bottleneck = 'PREFETCH' if _ema_prefetch > _ema_train else 'TRAIN'
+ logger.info(
+ f'Epoch {epoch} Step {cur_step}/{total_forward_steps}, metric: {metric} | '
+ f'prefetch={t_prefetch:.1f}s(ema {_ema_prefetch:.1f}) '
+ f'train={t_train:.1f}s(ema {_ema_train:.1f}) '
+ f'bottleneck={_bottleneck}')
+ log_dict = {}
+ for k, v in metric.items():
+ if not v:
+ continue
+ try:
+ log_dict[k] = float(v)
+ except (ValueError, TypeError):
+ pass
+ log_dict['epoch'] = epoch
+ log_dict['prefetch_sec'] = round(t_prefetch, 2)
+ log_dict['train_sec'] = round(t_train, 2)
+ swanlab.log(log_dict, step=cur_step)
+ if cur_step % SAVE_INTERVAL == 0:
+ save_checkpoint(model, f'step_{cur_step}')
+ t_prefetch = 0.0
+
+ # Drain final mega-batch
+ if future:
+ minibatches = future.result()
+ future = None
+ if minibatches:
+ for mb in minibatches:
+ model.forward_backward(inputs=mb, task='embedding')
+ model.clip_grad_and_step(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
+ cur_step += 1
+ if cur_step % SAVE_INTERVAL == 0:
+ save_checkpoint(model, f'step_{cur_step}')
prefetch_executor.shutdown(wait=False)
- retrainer.stop()
save_checkpoint(model, 'last-checkpoint')
diff --git a/cookbook/megatron/tp.py b/cookbook/megatron/tp.py
index 650cf67b6..13c5ccfb1 100644
--- a/cookbook/megatron/tp.py
+++ b/cookbook/megatron/tp.py
@@ -5,42 +5,26 @@
import twinkle
from twinkle import DeviceMesh, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import MegatronModel
from twinkle.preprocessor import SelfCognitionProcessor
logger = get_logger()
+args = CLI.from_args()
-MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
-DATASET_ID = 'ms://swift/self-cognition'
-TEMPLATE_NAME = 'Qwen3_5Template'
-MODEL_NAME = 'twinkle大模型'
-MODEL_AUTHOR = 'ModelScope社区'
-DP_SIZE = 2
-TP_SIZE = 2
-PP_SIZE = 2
-BATCH_SIZE = 16
-LEARNING_RATE = 1e-4
-LOG_INTERVAL = 5
-EVAL_INTERVAL = 20
-EVAL_SAMPLES = 100
-TRAIN_SAMPLES = 1000
-
-OUTPUT_DIR = './output/megatron_tp'
-RESUME_FROM_CHECKPOINT = None
-RESUME_ONLY_MODEL = False
-IGNORE_DATA_SKIP = False
-ADAPTER_NAME = 'default'
-
-device_mesh = DeviceMesh.from_sizes(dp_size=DP_SIZE, tp_size=TP_SIZE, pp_size=PP_SIZE)
-twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+device_mesh = DeviceMesh.from_sizes(dp_size=args.infra.dp_size, tp_size=args.infra.tp_size, pp_size=args.infra.pp_size)
+twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh)
def build_dataset(num_samples: int) -> Dataset:
- dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples)))
- dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID)
- dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR))
+ dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(num_samples)))
+ dataset.set_template(args.template.template_cls, model_id=args.model.model_id)
+ dataset.map(SelfCognitionProcessor(
+ args.extra.get('model_name', 'twinkle大模型'),
+ args.extra.get('model_author', 'ModelScope社区'),
+ ))
dataset.encode()
return dataset
@@ -48,42 +32,45 @@ def build_dataset(num_samples: int) -> Dataset:
def save_checkpoint(model: MegatronModel, checkpoint_name: str, dataloader: DataLoader):
model.save(
checkpoint_name,
- output_dir=OUTPUT_DIR,
- adapter_name=ADAPTER_NAME,
- save_optimizer=True,
+ output_dir=args.training.output_dir,
+ adapter_name=args.lora.adapter_name,
+ save_optimizer=args.checkpoint.save_optimizer,
consumed_train_samples=dataloader.get_state()['consumed_train_samples'],
)
def evaluate(model):
- dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE)
+ eval_samples = args.training.eval_samples or 100
+ dataloader = DataLoader(dataset=build_dataset(eval_samples), batch_size=args.training.batch_size)
for batch in tqdm(dataloader):
model.forward_only(inputs=batch)
return model.calculate_metric(is_training=False)
def train():
- dataset = build_dataset(TRAIN_SAMPLES)
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
+ train_samples = args.training.train_samples or 1000
+ dataset = build_dataset(train_samples)
+ dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size)
- model = MegatronModel(model_id=MODEL_ID)
+ model = MegatronModel(model_id=args.model.model_id)
- lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
+ lora_config = LoraConfig(**args.get_lora_args())
# Comment this to use full-parameter training
- model.add_adapter_to_model(ADAPTER_NAME, lora_config)
- model.set_optimizer(optimizer_cls='default', lr=LEARNING_RATE)
- model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader))
+ model.add_adapter_to_model(args.lora.adapter_name, lora_config)
+ model.set_optimizer(optimizer_cls='default', lr=args.optimizer.learning_rate)
+ model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=args.scheduler.num_warmup_steps,
+ lr_decay_steps=len(dataloader))
start_step = 0
- if RESUME_FROM_CHECKPOINT:
- checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()
+ if args.training.resume_from_checkpoint:
+ checkpoint_path = Path(args.training.resume_from_checkpoint).expanduser().resolve()
kwargs = {}
- if ADAPTER_NAME:
- kwargs['adapter_name'] = ADAPTER_NAME
+ if args.lora.adapter_name:
+ kwargs['adapter_name'] = args.lora.adapter_name
progress = model.resume_from_checkpoint(
- str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs)
- if not IGNORE_DATA_SKIP:
+ str(checkpoint_path), resume_only_model=args.training.resume_only_model, **kwargs)
+ if not args.training.ignore_data_skip:
dataloader.resume_from_checkpoint(progress['consumed_train_samples'])
start_step = progress['cur_step']
@@ -92,14 +79,15 @@ def train():
logger.info(f'Total steps: {len(dataloader)}')
best_loss = float('inf')
+ eval_interval = args.training.eval_interval or 20
for step, batch in enumerate(dataloader, start=start_step):
model.forward_backward(inputs=batch)
model.clip_grad_and_step()
- if step % LOG_INTERVAL == 0:
+ if step % args.training.log_interval == 0:
metric = model.calculate_metric(is_training=True)
logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
- if step > 0 and step % EVAL_INTERVAL == 0:
+ if step > 0 and step % eval_interval == 0:
metrics = evaluate(model)
logger.info(f'Eval metric: {metrics}')
metrics['step'] = step
diff --git a/cookbook/megatron/tp.sh b/cookbook/megatron/tp.sh
index 5516130e3..789c54379 100644
--- a/cookbook/megatron/tp.sh
+++ b/cookbook/megatron/tp.sh
@@ -1 +1,23 @@
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 tp.py
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Megatron TP + LoRA training.
+# All training config passed as CLI flags. Override at invocation, e.g.:
+# bash tp.sh --model-id ms://Qwen/Qwen3.5-4B --dp-size 2 --tp-size 4
+
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ torchrun --nproc_per_node=8 tp.py \
+ --model-id ms://Qwen/Qwen3.5-4B \
+ --dataset-id ms://swift/self-cognition \
+ --template-cls Qwen3_5Template \
+ --dp-size 4 \
+ --tp-size 2 \
+ --batch-size 8 \
+ --lr 1e-4 \
+ --train-samples 1000 \
+ --log-interval 10 \
+ --eval-interval 20 \
+ --output-dir ./output/megatron_tp \
+ --model-name twinkle大模型 \
+ --model-author ModelScope社区 \
+ "$@"
diff --git a/cookbook/megatron/tp_moe.py b/cookbook/megatron/tp_moe.py
index a13b0e58a..11e2c7d84 100644
--- a/cookbook/megatron/tp_moe.py
+++ b/cookbook/megatron/tp_moe.py
@@ -1,29 +1,38 @@
-import os
from peft import LoraConfig
from tqdm import tqdm
import twinkle
from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import MegatronModel
from twinkle.preprocessor import SelfCognitionProcessor
+logger = get_logger()
+args = CLI.from_args()
+
# Construct a device_mesh, tp=pp=ep=dp=2
-device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2, ep_size=2, sequence_parallel=True)
+device_mesh = DeviceMesh.from_sizes(
+ dp_size=args.infra.dp_size, tp_size=args.infra.tp_size,
+ pp_size=args.infra.pp_size, ep_size=args.infra.ep_size,
+ sequence_parallel=args.infra.sequence_parallel,
+)
# use torchrun mode
-twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-
-logger = get_logger()
+twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh)
def eval(model):
- # 100 Samples
- dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
- dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-35B-A3B')
- dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ # Eval samples
+ eval_samples = args.training.eval_samples or 100
+ dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(eval_samples)))
+ dataset.set_template(args.template.template_cls, model_id=args.model.model_id)
+ dataset.map(SelfCognitionProcessor(
+ args.extra.get('model_name', 'twinkle大模型'),
+ args.extra.get('model_author', 'ModelScope社区'),
+ ))
dataset.encode()
- dataloader = DataLoader(dataset=dataset, batch_size=16)
+ dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size)
for step, batch in tqdm(enumerate(dataloader)):
model.forward_only(inputs=batch)
metrics = model.calculate_metric(is_training=False)
@@ -31,44 +40,49 @@ def eval(model):
def train():
- # 1000 samples
- dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
+ # Training samples
+ train_samples = args.training.train_samples or 1000
+ dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(train_samples)))
# Set template to prepare encoding
- dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-35B-A3B')
+ dataset.set_template(args.template.template_cls, model_id=args.model.model_id)
# Preprocess the dataset to standard format
- dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ dataset.map(SelfCognitionProcessor(
+ args.extra.get('model_name', 'twinkle大模型'),
+ args.extra.get('model_author', 'ModelScope社区'),
+ ))
# Encode dataset
dataset.encode()
- # Global batch size = 1, dp_size = 1
- dataloader = DataLoader(dataset=dataset, batch_size=16)
+ # Global batch size
+ dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size)
# Use a MegatronModel
- model = MegatronModel(model_id='ms://Qwen/Qwen3.5-35B-A3B')
+ model = MegatronModel(model_id=args.model.model_id)
- lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
+ lora_config = LoraConfig(**args.get_lora_args())
- # Add a lora to model, with name `default`
+ # Add a lora to model, with name from args
# Comment this to use full-parameter training
- model.add_adapter_to_model('default', lora_config)
- # Add Optimizer for lora `default`
- model.set_optimizer(optimizer_cls='default', lr=1e-4)
- # Add LRScheduler for lora `default`
- model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader))
+ model.add_adapter_to_model(args.lora.adapter_name, lora_config)
+ # Add Optimizer
+ model.set_optimizer(optimizer_cls='default', lr=args.optimizer.learning_rate)
+ # Add LRScheduler
+ model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=args.scheduler.num_warmup_steps,
+ lr_decay_steps=len(dataloader))
logger.info(get_device_placement())
# Print the training config
logger.info(model.get_train_configs())
logger.info(f'Total steps: {len(dataloader)}')
loss_metric = 99.0
- # lora: 23G * 8
+ eval_interval = args.training.eval_interval or 20
for step, batch in enumerate(dataloader):
# Do forward and backward
model.forward_backward(inputs=batch)
# Step
model.clip_grad_and_step()
- if step % 5 == 0:
+ if step % args.training.log_interval == 0:
# Print metric
metric = model.calculate_metric(is_training=True)
logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
- if step > 0 and step % 20 == 0:
+ if step > 0 and step % eval_interval == 0:
metrics = eval(model)
logger.info(f'Eval metric: {metrics}')
metrics['step'] = step
diff --git a/cookbook/megatron/tp_moe.sh b/cookbook/megatron/tp_moe.sh
index 58e586464..7f6a2d06b 100644
--- a/cookbook/megatron/tp_moe.sh
+++ b/cookbook/megatron/tp_moe.sh
@@ -1 +1,25 @@
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 tp_moe.py
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Megatron TP + MoE + LoRA training (8 processes; dp/tp/pp/ep sizes of 2 by default).
+# All training config is passed as CLI flags; extra flags are appended via "$@", e.g.:
+# bash tp_moe.sh --model-id ms://Qwen/Qwen3.5-35B-A3B --batch-size 16
+
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ torchrun --nproc_per_node=8 tp_moe.py \
+ --model-id ms://Qwen/Qwen3.5-35B-A3B \
+ --dataset-id ms://swift/self-cognition \
+ --template-cls Qwen3_5Template \
+ --dp-size 2 \
+ --tp-size 2 \
+ --pp-size 2 \
+ --ep-size 2 \
+ --sequence-parallel \
+ --batch-size 8 \
+ --lr 1e-4 \
+ --train-samples 1000 \
+ --log-interval 10 \
+ --eval-interval 20 \
+ --model-name twinkle大模型 \
+ --model-author ModelScope社区 \
+ "$@"
diff --git a/cookbook/mm/fsdp2.py b/cookbook/mm/fsdp2.py
index 4dc508506..2edaf54e9 100644
--- a/cookbook/mm/fsdp2.py
+++ b/cookbook/mm/fsdp2.py
@@ -3,18 +3,20 @@
import twinkle
from twinkle import DeviceMesh, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.data_format import Trajectory, Message
from twinkle.dataloader import DataLoader
from twinkle.dataset import LazyDataset, DatasetMeta
from twinkle.model import TransformersModel
from twinkle.preprocessor import Preprocessor
-# Construct a device_mesh, fsdp=2
-device_mesh = DeviceMesh.from_sizes(fsdp_size=2)
-# use torchrun mode
-twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-
logger = get_logger()
+args = CLI.from_args()
+
+# Construct a device_mesh
+device_mesh = DeviceMesh.from_sizes(fsdp_size=args.infra.fsdp_size, dp_size=args.infra.dp_size)
+# use torchrun mode
+twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh)
class LatexOCRProcessor(Preprocessor):
@@ -35,12 +37,13 @@ def preprocess(self, row) -> Trajectory:
def eval(model):
- # 100 Samples
- dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(100)))
- dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
+ # Eval samples
+ eval_samples = args.training.eval_samples or 100
+ dataset = LazyDataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(eval_samples)))
+ dataset.set_template(args.template.template_cls, model_id=args.model.model_id)
dataset.map(LatexOCRProcessor)
dataset.encode()
- dataloader = DataLoader(dataset=dataset, batch_size=8)
+ dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size)
for step, batch in tqdm(enumerate(dataloader)):
model.forward_only(inputs=batch)
model.calculate_loss()
@@ -49,54 +52,56 @@ def eval(model):
def train():
- # 2000 samples
- dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(2000)))
+ # Training samples
+ train_samples = args.training.train_samples or 2000
+ dataset = LazyDataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(train_samples)))
# Set template to prepare encoding
- dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B', max_length=1024)
+ dataset.set_template(args.template.template_cls, model_id=args.model.model_id, max_length=args.template.max_length)
# Preprocess the dataset to standard format
dataset.map(LatexOCRProcessor)
# Encode dataset
dataset.encode()
- # Global batch size = 4, for GPUs, so 2 sample per GPU
- dataloader = DataLoader(dataset=dataset, batch_size=4)
+ # Global batch size
+ dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size)
# Use a TransformersModel
- from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration
- model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B', model_cls=Qwen3_5ForConditionalGeneration)
+ model = TransformersModel(model_id=args.model.model_id, model_cls=args.model.model_cls)
model.model._no_split_modules = {'Qwen3_5DecoderLayer'}
- lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
+ lora_config = LoraConfig(**args.get_lora_args())
- # Add a lora to model, with name `default`
- # Comment this to use full-parameter training
- model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
- # Add Optimizer for lora `default`
- model.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
- model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
- # Add LRScheduler for lora `default`
+ # Add a lora to model
+ model.add_adapter_to_model(args.lora.adapter_name, lora_config,
+ gradient_accumulation_steps=args.training.gradient_accumulation_steps)
+ # Add Optimizer
+ model.set_template(args.template.template_cls, model_id=args.model.model_id)
+ model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate)
+ # Add LRScheduler
model.set_lr_scheduler(
- scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))
+ scheduler_cls=args.scheduler.scheduler_cls, num_warmup_steps=args.scheduler.num_warmup_steps,
+ num_training_steps=len(dataloader))
logger.info(get_device_placement())
# Print the training config
logger.info(model.get_train_configs())
logger.info(f'Total steps: {len(dataloader)}')
loss_metric = 99.0
+ eval_interval = args.training.eval_interval or 200
for step, batch in enumerate(dataloader):
# Do forward and backward
model.forward_backward(inputs=batch)
# Step
model.clip_grad_and_step()
- if step % 20 == 0:
+ if step % args.training.log_interval == 0:
# Print metric
metric = model.calculate_metric(is_training=True)
logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
- if step > 0 and step % 200 == 0:
+ if step > 0 and step % eval_interval == 0:
metrics = eval(model)
logger.info(f'Eval metric: {metrics}')
metrics['step'] = step
if loss_metric > float(metrics['loss']):
model.save(f'checkpoint-{step}')
loss_metric = float(metrics['loss'])
- model.save(f'last-checkpoint')
+ model.save('last-checkpoint')
if __name__ == '__main__':
diff --git a/cookbook/mm/fsdp2.sh b/cookbook/mm/fsdp2.sh
index 46e9f27f6..2e0bed3d5 100644
--- a/cookbook/mm/fsdp2.sh
+++ b/cookbook/mm/fsdp2.sh
@@ -1 +1,21 @@
-CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 fsdp2.py
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Multi-modal FSDP2 + LoRA training (LaTeX OCR).
+# Default model/template stay in the Qwen3.5 family: fsdp2.py sets _no_split_modules to Qwen3_5DecoderLayer.
+# All training config is passed as CLI flags; override at invocation, e.g.: bash fsdp2.sh --batch-size 4
+
+CUDA_VISIBLE_DEVICES=0,1 \
+ torchrun --nproc_per_node=2 fsdp2.py \
+ --model-id ms://Qwen/Qwen3.5-4B \
+ --dataset-id ms://AI-ModelScope/LaTeX_OCR \
+ --template-cls Qwen3_5Template \
+ --dp-size 2 \
+ --batch-size 2 \
+ --lr 1e-4 \
+ --gradient-accumulation-steps 4 \
+ --train-samples 2000 \
+ --eval-samples 100 \
+ --eval-interval 200 \
+ --log-interval 10 \
+ "$@"
diff --git a/cookbook/mm/fsdp2_gemma4_12b_mm.py b/cookbook/mm/fsdp2_gemma4_12b_mm.py
index c21932b33..62e26d776 100644
--- a/cookbook/mm/fsdp2_gemma4_12b_mm.py
+++ b/cookbook/mm/fsdp2_gemma4_12b_mm.py
@@ -1,4 +1,3 @@
-import os
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoConfig
@@ -8,35 +7,29 @@
import twinkle
from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
-# from twinkle.preprocessor import SelfCognitionProcessor, LatexOCRProcessor
logger = get_logger()
+args = CLI.from_args()
########## Construct a device_mesh ##########
device_mesh = DeviceMesh.from_sizes(
- # fsdp_size=2,
- # dp_size=1,
- # ep_size=2,
+ fsdp_size=args.infra.fsdp_size,
+ dp_size=args.infra.dp_size,
+ ep_size=args.infra.ep_size,
device_type=Platform.get_platform().device_prefix(),
)
# use torchrun mode
-twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh)
########## hyperparameters ##########
-IGNORE_MISMATCHED_SIZES = True
-# MODEL_PATH = 'ms://google/gemma-4-26b-a4b'
-MODEL_PATH = 'ms://google/gemma-4-12b'
-DATASET_PATH = 'ms://AI-ModelScope/LaTeX_OCR'
-TRAIN_LEN = 2000
-BATCH_SIZE = 4
-METRIC_STEP = 10
-SAVE_STEP = 10
+IGNORE_MISMATCHED_SIZES = args.extra.get('ignore_mismatched_sizes', True)
### reduce model layers for debug
-TEXT_NUM_LAYERS = 8 # gemma-4-12b text_config.num_hidden_layers=48
+TEXT_NUM_LAYERS = args.extra.get('text_num_layers', None)
from twinkle.preprocessor import Preprocessor
from twinkle.data_format import Message, Trajectory
@@ -79,24 +72,21 @@ def train():
'messages': List(sub_msg_feat)
})
### prepare dataset and dataloader
- dataset = Dataset(features=writer_features, dataset_meta=DatasetMeta(DATASET_PATH, subset_name='default', data_slice=range(TRAIN_LEN)))
+ train_samples = args.training.train_samples or 2000
+ dataset = Dataset(features=writer_features, dataset_meta=DatasetMeta(
+ args.dataset.dataset_id, subset_name=args.dataset.subset_name, data_slice=range(train_samples)))
# Set template to prepare encoding
- dataset.set_template('Template', model_id=MODEL_PATH)
+ dataset.set_template(args.template.template_cls, model_id=args.model.model_id)
# Preprocess the dataset to standard format
- # dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
dataset.map(preprocess_func=LatexOCRProcessor)
# Encode dataset
dataset.encode()
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
+ dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size)
config, kwargs = AutoConfig.from_pretrained(
- MODEL_PATH,
+ args.model.model_id,
trust_remote_code=True,
return_unused_kwargs=True,
- # code_revision=code_revision,
- # _commit_hash=commit_hash,
- # **hub_kwargs,
- # **kwargs,
)
if isinstance(config, Gemma4UnifiedConfig): # 减层
@@ -111,10 +101,10 @@ def train():
from transformers import AutoModelForMultimodalLM
model = TransformersModel(
model_cls=AutoModelForMultimodalLM,
- model_id=MODEL_PATH,
+ model_id=args.model.model_id,
config=config,
device_mesh=device_mesh,
- strategy='accelerate', # native_fsdp、 accelerate
+ strategy=args.model.strategy,
ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
fsdp_config={
'reshard_after_forward': True,
@@ -126,46 +116,46 @@ def train():
},
)
- lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
+ lora_config = LoraConfig(**args.get_lora_args())
- # Add a lora to model, with name `default`
- # Comment this to use full-parameter training
- model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
- # Add Optimizer for lora `default`
- model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
+ # Add a lora to model
+ model.add_adapter_to_model(args.lora.adapter_name, lora_config,
+ gradient_accumulation_steps=args.training.gradient_accumulation_steps)
+ # Add Optimizer
+ model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate)
- # Add LRScheduler for lora `default`
+ # Add LRScheduler
model.set_lr_scheduler(
- scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))
+ scheduler_cls=args.scheduler.scheduler_cls, num_warmup_steps=args.scheduler.num_warmup_steps,
+ num_training_steps=len(dataloader))
logger.info(get_device_placement())
# Print the training config
logger.info(model.get_train_configs())
logger.info(f'Total steps: {len(dataloader)}')
best_eval_loss = float('inf')
- # lora: 8G * 8
- # full: 18G * 8
### eval dataset and dataloader
- EVAL_LENGTH = 100
- eval_dataset = Dataset(features=writer_features, dataset_meta=DatasetMeta(DATASET_PATH, subset_name='default', data_slice=range(EVAL_LENGTH)))
- eval_dataset.set_template('Template', model_id=MODEL_PATH)
- # eval_dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ eval_samples = args.training.eval_samples or 100
+ eval_dataset = Dataset(features=writer_features, dataset_meta=DatasetMeta(
+ args.dataset.dataset_id, subset_name=args.dataset.subset_name, data_slice=range(eval_samples)))
+ eval_dataset.set_template(args.template.template_cls, model_id=args.model.model_id)
eval_dataset.map(preprocess_func=LatexOCRProcessor)
eval_dataset.encode()
- eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8)
+ eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=args.training.batch_size)
+ save_step = args.training.save_steps
for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
# Do forward and backward
model.forward_backward(inputs=batch)
# Step
model.clip_grad_and_step()
- if step % METRIC_STEP == 0:
+ if step % args.training.log_interval == 0:
# Print metric
metric = model.calculate_metric(is_training=True)
logger.info(f'Current is step {step} of {len(dataloader)}, Train metric: {metric}')
- if step % SAVE_STEP == 0:
+ if step % save_step == 0:
metrics = evaluate(model, eval_dataloader)
metrics['step'] = step
if float(metrics['loss']) < best_eval_loss:
diff --git a/cookbook/mm/fsdp2_gemma4_mm.py b/cookbook/mm/fsdp2_gemma4_mm.py
index 778051874..5c756cfbe 100644
--- a/cookbook/mm/fsdp2_gemma4_mm.py
+++ b/cookbook/mm/fsdp2_gemma4_mm.py
@@ -1,4 +1,3 @@
-import os
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoConfig
@@ -8,35 +7,30 @@
import twinkle
from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
-# from twinkle.preprocessor import SelfCognitionProcessor, LatexOCRProcessor
logger = get_logger()
+args = CLI.from_args()
########## Construct a device_mesh ##########
device_mesh = DeviceMesh.from_sizes(
- # fsdp_size=2,
- # dp_size=1,
- # ep_size=2,
+ fsdp_size=args.infra.fsdp_size,
+ dp_size=args.infra.dp_size,
+ ep_size=args.infra.ep_size,
device_type=Platform.get_platform().device_prefix(),
)
# use torchrun mode
-twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh)
########## hyperparameters ##########
-IGNORE_MISMATCHED_SIZES = True
-MODEL_PATH = 'ms://google/gemma-4-26b-a4b'
-DATASET_PATH = 'ms://AI-ModelScope/LaTeX_OCR'
-TRAIN_LEN = 2000
-BATCH_SIZE = 4
-METRIC_STEP = 10
-SAVE_STEP = 10
+IGNORE_MISMATCHED_SIZES = args.extra.get('ignore_mismatched_sizes', True)
### reduce model layers for debug
-TEXT_NUM_LAYERS = 3
-VISION_NUM_LAYERS = 3
+TEXT_NUM_LAYERS = args.extra.get('text_num_layers', None)
+VISION_NUM_LAYERS = args.extra.get('vision_num_layers', None)
from twinkle.preprocessor import Preprocessor
@@ -67,24 +61,20 @@ def eval(model, eval_dataloader):
def train():
### prepare dataset and dataloader
- dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(TRAIN_LEN)))
+ train_samples = args.training.train_samples or 2000
+ dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(train_samples)))
# Set template to prepare encoding
- dataset.set_template('Template', model_id=MODEL_PATH)
+ dataset.set_template(args.template.template_cls, model_id=args.model.model_id)
# Preprocess the dataset to standard format
- # dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
dataset.map(preprocess_func=LatexOCRProcessor)
# Encode dataset
dataset.encode()
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
+ dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size)
config, kwargs = AutoConfig.from_pretrained(
- MODEL_PATH,
+ args.model.model_id,
trust_remote_code=True,
return_unused_kwargs=True,
- # code_revision=code_revision,
- # _commit_hash=commit_hash,
- # **hub_kwargs,
- # **kwargs,
)
if isinstance(config, Gemma4Config): # 减层
@@ -101,10 +91,10 @@ def train():
# Use a TransformersModel
model = TransformersModel(
- model_id=MODEL_PATH,
+ model_id=args.model.model_id,
config=config,
device_mesh=device_mesh,
- strategy='accelerate', # native_fsdp、 accelerate
+ strategy=args.model.strategy,
ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
fsdp_config={
'reshard_after_forward': True,
@@ -116,46 +106,45 @@ def train():
},
)
- lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
+ lora_config = LoraConfig(**args.get_lora_args())
- # Add a lora to model, with name `default`
- # Comment this to use full-parameter training
- model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
- # Add Optimizer for lora `default`
- model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
+ # Add a LoRA adapter to the model under the configured adapter name
+ model.add_adapter_to_model(args.lora.adapter_name, lora_config,
+ gradient_accumulation_steps=args.training.gradient_accumulation_steps)
+ # Add Optimizer
+ model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate)
- # Add LRScheduler for lora `default`
+ # Add LRScheduler
model.set_lr_scheduler(
- scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))
+ scheduler_cls=args.scheduler.scheduler_cls, num_warmup_steps=args.scheduler.num_warmup_steps,
+ num_training_steps=len(dataloader))
logger.info(get_device_placement())
# Print the training config
logger.info(model.get_train_configs())
logger.info(f'Total steps: {len(dataloader)}')
best_eval_loss = float('inf')
- # lora: 8G * 8
- # full: 18G * 8
### eval dataset and dataloader
- EVAL_LENGTH = 100
- eval_dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(EVAL_LENGTH)))
- eval_dataset.set_template('Template', model_id=MODEL_PATH)
- # eval_dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ eval_samples = args.training.eval_samples or 100
+ eval_dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(eval_samples)))
+ eval_dataset.set_template(args.template.template_cls, model_id=args.model.model_id)
eval_dataset.map(preprocess_func=LatexOCRProcessor)
eval_dataset.encode()
- eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8)
+ eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=args.training.batch_size)
+ save_step = args.training.save_steps
for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
# Do forward and backward
model.forward_backward(inputs=batch)
# Step
model.clip_grad_and_step()
- if step % METRIC_STEP == 0:
+ if step % args.training.log_interval == 0:
# Print metric
metric = model.calculate_metric(is_training=True)
logger.info(f'Current is step {step} of {len(dataloader)}, Train metric: {metric}')
- if step % SAVE_STEP == 0:
+ if step % save_step == 0:
metrics = eval(model, eval_dataloader)
metrics['step'] = step
if float(metrics['loss']) < best_eval_loss:
diff --git a/cookbook/mm/fsdp2_gemma4_mm.sh b/cookbook/mm/fsdp2_gemma4_mm.sh
index c67113d8f..82d9ef1d5 100644
--- a/cookbook/mm/fsdp2_gemma4_mm.sh
+++ b/cookbook/mm/fsdp2_gemma4_mm.sh
@@ -1,3 +1,21 @@
-export CUDA_VISIBLE_DEVICES=0,1
+#!/usr/bin/env bash
+set -euo pipefail
-torchrun --nnodes=1 --nproc_per_node=2 fsdp2_gemma4_mm.py
+# Multi-modal FSDP2 + LoRA training for Gemma4 (LaTeX OCR).
+# Defaults are passed as CLI flags below; flags given at invocation are appended via "$@" to override them, e.g.:
+# bash fsdp2_gemma4_mm.sh --model-id ms://google/gemma-4-4b-it --batch-size 4
+
+CUDA_VISIBLE_DEVICES=0,1 \
+ torchrun --nnodes=1 --nproc_per_node=2 fsdp2_gemma4_mm.py \
+ --model-id ms://google/gemma-4-12b-it \
+ --dataset-id ms://AI-ModelScope/LaTeX_OCR \
+ --template-cls Gemma4Template \
+ --dp-size 2 \
+ --batch-size 2 \
+ --lr 1e-4 \
+ --gradient-accumulation-steps 4 \
+ --train-samples 2000 \
+ --eval-samples 100 \
+ --log-interval 10 \
+ --save-steps 200 \
+ "$@"
diff --git a/cookbook/rl/dpo_full.py b/cookbook/rl/dpo/dpo_full.py
similarity index 92%
rename from cookbook/rl/dpo_full.py
rename to cookbook/rl/dpo/dpo_full.py
index 8610b986f..afb3f6155 100644
--- a/cookbook/rl/dpo_full.py
+++ b/cookbook/rl/dpo/dpo_full.py
@@ -49,6 +49,7 @@
import twinkle
from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.data_format import Trajectory
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
@@ -58,25 +59,26 @@
from twinkle.processor import InputProcessor
logger = get_logger()
+args = CLI.from_args()
# ── Configuration ─────────────────────────────────────────────────────────────
-USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0))
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B')
-DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')
+USE_MEGATRON = args.model.strategy != 'native_fsdp'
+MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3-4B'
+DATASET_ID = args.dataset.dataset_id or 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji'
-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
-REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4))
+MODEL_GPUS = args.infra.model_gpus or 4
+REF_MODEL_GPUS = args.infra.ref_model_gpus or 4
NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2))
-LEARNING_RATE = float(os.environ.get('LR', 1e-5))
-DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
-SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization
-LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo, simpo, orpo, cpo
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100))
-MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
-SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')
+BATCH_SIZE = args.training.batch_size or 8
+GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 2
+LEARNING_RATE = args.optimizer.learning_rate or 1e-5
+DPO_BETA = args.loss.beta
+SFT_WEIGHT = args.loss.sft_weight
+LOSS_TYPE = args.loss.loss_type
+SAVE_STEPS = args.training.save_steps or 100
+MAX_LENGTH = args.template.max_length
+SYSTEM_PROMPT = args.template.default_system or 'You are a helpful assistant.'
def create_dpo_dataset():
diff --git a/cookbook/rl/dpo/dpo_full.sh b/cookbook/rl/dpo/dpo_full.sh
new file mode 100644
index 000000000..cffba898b
--- /dev/null
+++ b/cookbook/rl/dpo/dpo_full.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# DPO Full-Parameter Training via Ray.
+# Uses separate policy and reference model GPU groups.
+# Defaults are passed as CLI flags below; flags given at invocation are appended via "$@" to override them, e.g.:
+# bash dpo_full.sh --model-id ms://Qwen/Qwen3-8B --beta 0.05
+
+python dpo_full.py \
+ --model-id ms://Qwen/Qwen3-4B \
+ --dataset-id ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
+ --model-gpus 4 \
+ --ref-model-gpus 4 \
+ --batch-size 8 \
+ --gradient-accumulation-steps 2 \
+ --lr 1e-5 \
+ --beta 0.1 \
+ --sft-weight 1.0 \
+ --loss-type sigmoid \
+ --max-length 2048 \
+ --save-steps 100 \
+ "$@"
diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo/dpo_lora.py
similarity index 91%
rename from cookbook/rl/dpo_lora.py
rename to cookbook/rl/dpo/dpo_lora.py
index c7ec3147c..868de1521 100644
--- a/cookbook/rl/dpo_lora.py
+++ b/cookbook/rl/dpo/dpo_lora.py
@@ -48,6 +48,7 @@
import twinkle
from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.data_format import Trajectory
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
@@ -57,24 +58,25 @@
from twinkle.processor import InputProcessor
logger = get_logger()
+args = CLI.from_args()
# ── Configuration ─────────────────────────────────────────────────────────────
-USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0))
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B')
-DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')
-
-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8))
-
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2))
-LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # LoRA DPO requires higher LR (1e-4 to 3e-4)
-DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
-SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization
-LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100))
-MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
-ADAPTER_NAME = 'default'
-SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')
+USE_MEGATRON = args.model.strategy != 'native_fsdp'
+MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3-4B'
+DATASET_ID = args.dataset.dataset_id or 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji'
+
+MODEL_GPUS = args.infra.model_gpus or 8
+
+BATCH_SIZE = args.training.batch_size or 8
+GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 2
+LEARNING_RATE = args.optimizer.learning_rate or 1e-4
+DPO_BETA = args.loss.beta
+SFT_WEIGHT = args.loss.sft_weight
+LOSS_TYPE = args.loss.loss_type
+SAVE_STEPS = args.training.save_steps or 100
+MAX_LENGTH = args.template.max_length
+ADAPTER_NAME = args.lora.adapter_name or 'default'
+SYSTEM_PROMPT = args.template.default_system or 'You are a helpful assistant.'
def create_dpo_dataset():
diff --git a/cookbook/rl/dpo/dpo_lora.sh b/cookbook/rl/dpo/dpo_lora.sh
new file mode 100644
index 000000000..7af42b6dc
--- /dev/null
+++ b/cookbook/rl/dpo/dpo_lora.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# DPO LoRA Training via Ray (single GPU group).
+# Uses base model (disable_lora=True) as reference model.
+# Defaults are passed as CLI flags below; flags given at invocation are appended via "$@" to override them, e.g.:
+# bash dpo_lora.sh --model-id ms://Qwen/Qwen3-8B --lr 5e-5
+
+python dpo_lora.py \
+ --model-id ms://Qwen/Qwen3-4B \
+ --dataset-id ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
+ --model-gpus 8 \
+ --batch-size 8 \
+ --gradient-accumulation-steps 2 \
+ --lr 1e-4 \
+ --beta 0.1 \
+ --sft-weight 1.0 \
+ --loss-type sigmoid \
+ --max-length 2048 \
+ --save-steps 100 \
+ --adapter-name default \
+ "$@"
diff --git a/cookbook/rl/dpo_multi_lora.py b/cookbook/rl/dpo/dpo_multi_lora.py
similarity index 91%
rename from cookbook/rl/dpo_multi_lora.py
rename to cookbook/rl/dpo/dpo_multi_lora.py
index 7c09bf61f..0a43322fc 100644
--- a/cookbook/rl/dpo_multi_lora.py
+++ b/cookbook/rl/dpo/dpo_multi_lora.py
@@ -48,6 +48,7 @@
import twinkle
from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.data_format import Trajectory
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
@@ -57,23 +58,24 @@
from twinkle.processor import InputProcessor
logger = get_logger()
+args = CLI.from_args()
# ── Configuration ─────────────────────────────────────────────────────────────
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
-DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')
-
-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2))
-
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2))
-LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # LoRA DPO requires higher LR (1e-4 to 3e-4)
-DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
-SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization
-LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100))
-MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
-ADAPTER_NAME = 'default_0'
-SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')
+MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.5-4B'
+DATASET_ID = args.dataset.dataset_id or 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji'
+
+MODEL_GPUS = args.infra.model_gpus or 2
+
+BATCH_SIZE = args.training.batch_size or 8
+GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 2
+LEARNING_RATE = args.optimizer.learning_rate or 1e-4
+DPO_BETA = args.loss.beta
+SFT_WEIGHT = args.loss.sft_weight
+LOSS_TYPE = args.loss.loss_type
+SAVE_STEPS = args.training.save_steps or 100
+MAX_LENGTH = args.template.max_length
+ADAPTER_NAME = args.lora.adapter_name or 'default_0'
+SYSTEM_PROMPT = args.template.default_system or 'You are a helpful assistant.'
def create_dpo_dataset():
diff --git a/cookbook/rl/dpo/dpo_multi_lora.sh b/cookbook/rl/dpo/dpo_multi_lora.sh
new file mode 100644
index 000000000..0652b95f2
--- /dev/null
+++ b/cookbook/rl/dpo/dpo_multi_lora.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# DPO MultiLoRA Training via Ray (Megatron backend).
+# Uses base model (disable_lora=True) as reference model.
+# Defaults are passed as CLI flags below; flags given at invocation are appended via "$@" to override them, e.g.:
+# bash dpo_multi_lora.sh --model-id ms://Qwen/Qwen3.5-4B --lr 5e-5
+
+python dpo_multi_lora.py \
+ --model-id ms://Qwen/Qwen3.5-4B \
+ --dataset-id ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
+ --model-gpus 2 \
+ --batch-size 8 \
+ --gradient-accumulation-steps 2 \
+ --lr 1e-4 \
+ --beta 0.1 \
+ --sft-weight 1.0 \
+ --loss-type sigmoid \
+ --max-length 2048 \
+ --save-steps 100 \
+ --adapter-name default_0 \
+ "$@"
diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd/gkd_off_policy.py
similarity index 94%
rename from cookbook/rl/gkd_off_policy.py
rename to cookbook/rl/gkd/gkd_off_policy.py
index 204e90f92..bdf992463 100644
--- a/cookbook/rl/gkd_off_policy.py
+++ b/cookbook/rl/gkd/gkd_off_policy.py
@@ -45,6 +45,7 @@
import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
@@ -55,23 +56,24 @@
from twinkle.template import Template
logger = get_logger()
+args = CLI.from_args()
# ── Configuration ─────────────────────────────────────────────────────────────
-STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B')
-TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B')
+STUDENT_MODEL_ID = args.rl.student_model_id or 'ms://Qwen/Qwen3-0.6B'
+TEACHER_MODEL_ID = args.rl.teacher_model_id or 'ms://Qwen/Qwen3-8B'
-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
-SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
+MODEL_GPUS = args.infra.model_gpus or 4
+SAMPLER_GPUS = args.infra.sampler_gpus or 4
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16))
-MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
-LEARNING_RATE = float(os.environ.get('LR', 5e-5))
+BATCH_SIZE = args.training.batch_size or 16
+MAX_STEPS = args.training.max_steps or 1000
+LEARNING_RATE = args.optimizer.learning_rate or 5e-5
-GKD_BETA = float(os.environ.get('GKD_BETA', 0.5))
-GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0))
-GKD_TOPK = int(os.environ.get('GKD_TOPK', 64))
-ADAPTER_NAME = 'default'
+GKD_BETA = args.rl.gkd_beta
+GKD_TEMPERATURE = args.rl.gkd_temperature
+GKD_TOPK = args.rl.gkd_topk
+ADAPTER_NAME = args.lora.adapter_name or 'default'
SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem step by step and put '
'your final answer within #### ')
diff --git a/cookbook/rl/gkd/gkd_off_policy.sh b/cookbook/rl/gkd/gkd_off_policy.sh
new file mode 100644
index 000000000..262542062
--- /dev/null
+++ b/cookbook/rl/gkd/gkd_off_policy.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# GKD Off-Policy Distillation via Ray.
+# Teacher vLLM computes prompt logprobs on existing dataset responses.
+# Student Megatron model learns to match teacher's token distribution.
+# Defaults are passed as CLI flags below; flags given at invocation are appended via "$@" to override them, e.g.:
+# bash gkd_off_policy.sh --student-model-id ms://Qwen/Qwen3-1.7B --gkd-beta 0.3
+
+python gkd_off_policy.py \
+ --student-model-id ms://Qwen/Qwen3-0.6B \
+ --teacher-model-id ms://Qwen/Qwen3-8B \
+ --model-gpus 4 \
+ --sampler-gpus 4 \
+ --batch-size 16 \
+ --max-steps 1000 \
+ --lr 5e-5 \
+ --gkd-beta 0.5 \
+ --gkd-temperature 1.0 \
+ --gkd-topk 64 \
+ --adapter-name default \
+ "$@"
diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd/gkd_on_policy.py
similarity index 94%
rename from cookbook/rl/gkd_on_policy.py
rename to cookbook/rl/gkd/gkd_on_policy.py
index 2675d0358..1ddc9d89e 100644
--- a/cookbook/rl/gkd_on_policy.py
+++ b/cookbook/rl/gkd/gkd_on_policy.py
@@ -51,6 +51,7 @@
import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.checkpoint_engine import CheckpointEngineManager
+from twinkle.cli import CLI
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import DatasetMeta, LazyDataset
@@ -60,26 +61,27 @@
from twinkle.sampler import vLLMSampler
logger = get_logger()
+args = CLI.from_args()
# ── Configuration ─────────────────────────────────────────────────────────────
-STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
-TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3.5-9B')
-USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))
+STUDENT_MODEL_ID = args.rl.student_model_id or 'ms://Qwen/Qwen3.5-4B'
+TEACHER_MODEL_ID = args.rl.teacher_model_id or 'ms://Qwen/Qwen3.5-9B'
+USE_MEGATRON = args.model.strategy != 'native_fsdp'
-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
-SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2))
+MODEL_GPUS = args.infra.model_gpus or 4
+SAMPLER_GPUS = args.infra.sampler_gpus or 2
NUM_GPUS = MODEL_GPUS + 2*SAMPLER_GPUS
-MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048))
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
-MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
-LEARNING_RATE = float(os.environ.get('LR', 5e-5))
-N_SAMPLES = int(os.environ.get('N_SAMPLES', 1))
+MAX_NEW_TOKENS = args.sampling.max_tokens or 2048
+BATCH_SIZE = args.training.batch_size or 4
+MAX_STEPS = args.training.max_steps or 1000
+LEARNING_RATE = args.optimizer.learning_rate or 5e-5
+N_SAMPLES = args.sampling.num_samples
-GKD_BETA = float(os.environ.get('GKD_BETA', 0.5))
-GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0))
-GKD_TOPK = int(os.environ.get('GKD_TOPK', 64))
-ADAPTER_NAME = 'default'
+GKD_BETA = args.rl.gkd_beta
+GKD_TEMPERATURE = args.rl.gkd_temperature
+GKD_TOPK = args.rl.gkd_topk
+ADAPTER_NAME = args.lora.adapter_name or 'default'
# OlympiadBench subsets
SUBSETS = [
diff --git a/cookbook/rl/gkd/gkd_on_policy.sh b/cookbook/rl/gkd/gkd_on_policy.sh
new file mode 100644
index 000000000..ed37e5fb5
--- /dev/null
+++ b/cookbook/rl/gkd/gkd_on_policy.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# GKD On-Policy Multimodal Distillation via Ray.
+# Student generates on-policy, teacher provides top-k prompt logprobs,
+# student trains to match teacher's distribution.
+# Defaults are passed as CLI flags below; flags given at invocation are appended via "$@" to override them, e.g.:
+# bash gkd_on_policy.sh --student-model-id ms://Qwen/Qwen3.5-4B --max-steps 500
+
+python gkd_on_policy.py \
+ --student-model-id ms://Qwen/Qwen3.5-4B \
+ --teacher-model-id ms://Qwen/Qwen3.5-9B \
+ --model-gpus 4 \
+ --sampler-gpus 2 \
+ --batch-size 4 \
+ --max-steps 1000 \
+ --max-tokens 2048 \
+ --lr 5e-5 \
+ --num-samples 1 \
+ --gkd-beta 0.5 \
+ --gkd-temperature 1.0 \
+ --gkd-topk 64 \
+ --adapter-name default \
+ "$@"
diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo/grpo.py
similarity index 88%
rename from cookbook/rl/grpo.py
rename to cookbook/rl/grpo/grpo.py
index dd16e7f07..e50402383 100644
--- a/cookbook/rl/grpo.py
+++ b/cookbook/rl/grpo/grpo.py
@@ -7,6 +7,7 @@
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
+from twinkle.cli import CLI
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
@@ -19,24 +20,25 @@
from twinkle.preprocessor.llm import GSM8KProcessor
logger = get_logger()
+args = CLI.from_args()
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
-USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0')))
+MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.5-4B'
+USE_MEGATRON = args.model.strategy != 'native_fsdp'
-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
-SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS',4))
+MODEL_GPUS = args.infra.model_gpus or 4
+SAMPLER_GPUS = args.infra.sampler_gpus or 4
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
-NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
-MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
-LEARNING_RATE = float(os.environ.get('LR', 1e-5))
-MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size
-MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) # global completion-level mini-batch-size
-MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # per-device-micro-batch-size (completion-level), batch_size in forward_backward
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
-ADAPTER_NAME = 'default'
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50))
+NUM_GENERATIONS = args.rl.num_generations or 8
+MAX_NEW_TOKENS = args.sampling.max_tokens or 4096
+LEARNING_RATE = args.optimizer.learning_rate or 1e-5
+MAX_STEPS = args.training.max_steps or 200
+BATCH_SIZE = args.training.batch_size or 8
+MINI_BATCH_SIZE = args.training.mini_batch_size or 8
+MICRO_BATCH_SIZE = args.training.micro_batch_size or 2
+GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 1
+ADAPTER_NAME = args.lora.adapter_name or 'default'
+SAVE_STEPS = args.training.save_steps or 50
def create_gsm8k_dataset():
dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
diff --git a/cookbook/rl/grpo/grpo.sh b/cookbook/rl/grpo/grpo.sh
new file mode 100644
index 000000000..b6feb02fa
--- /dev/null
+++ b/cookbook/rl/grpo/grpo.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# GRPO training on GSM8K via Ray.
+# Model + vLLM sampler on separate GPU groups.
+# Defaults are passed as CLI flags below; flags given at invocation are appended via "$@" to override them, e.g.:
+# bash grpo.sh --model-id ms://Qwen/Qwen3.5-4B --max-steps 500
+
+python grpo.py \
+ --model-id ms://Qwen/Qwen3.5-4B \
+ --model-gpus 4 \
+ --sampler-gpus 4 \
+ --num-generations 8 \
+ --max-tokens 4096 \
+ --batch-size 8 \
+ --mini-batch-size 8 \
+ --micro-batch-size 2 \
+ --max-steps 200 \
+ --lr 1e-5 \
+ --save-steps 50 \
+ --adapter-name default \
+ "$@"
diff --git a/cookbook/rl/grpo_mm.py b/cookbook/rl/grpo/grpo_mm.py
similarity index 92%
rename from cookbook/rl/grpo_mm.py
rename to cookbook/rl/grpo/grpo_mm.py
index 1f89c7a91..f7de43ca5 100644
--- a/cookbook/rl/grpo_mm.py
+++ b/cookbook/rl/grpo/grpo_mm.py
@@ -14,6 +14,7 @@
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
+from twinkle.cli import CLI
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import DatasetMeta, LazyDataset
@@ -28,27 +29,28 @@
from twinkle.sampler import vLLMSampler
logger = get_logger()
+args = CLI.from_args()
# Model configuration
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
-USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))
+MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.5-4B'
+USE_MEGATRON = args.model.strategy != 'native_fsdp'
# GPU configuration
-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
-SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
+MODEL_GPUS = args.infra.model_gpus or 4
+SAMPLER_GPUS = args.infra.sampler_gpus or 4
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
# Training hyperparameters
-NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
-MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
-LEARNING_RATE = float(os.environ.get('LR', 1e-5))
-MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
-MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))
-MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1))
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
-ADAPTER_NAME = 'default'
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50))
+NUM_GENERATIONS = args.rl.num_generations or 8
+MAX_NEW_TOKENS = args.sampling.max_tokens or 4096
+LEARNING_RATE = args.optimizer.learning_rate or 1e-5
+MAX_STEPS = args.training.max_steps or 1000
+BATCH_SIZE = args.training.batch_size or 4
+MINI_BATCH_SIZE = args.training.mini_batch_size or 4
+MICRO_BATCH_SIZE = args.training.micro_batch_size or 1
+GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 1
+ADAPTER_NAME = args.lora.adapter_name or 'default'
+SAVE_STEPS = args.training.save_steps or 50
# Dataset configuration
SUBSETS = [
diff --git a/cookbook/rl/grpo/grpo_mm.sh b/cookbook/rl/grpo/grpo_mm.sh
new file mode 100644
index 000000000..b5ca2fda3
--- /dev/null
+++ b/cookbook/rl/grpo/grpo_mm.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# GRPO Multimodal training on OlympiadBench via Ray.
+# Supports multimodal math/physics problems (Chinese CEE).
+# Defaults are passed as CLI flags below; flags given at invocation are appended via "$@" to override them, e.g.:
+# bash grpo_mm.sh --model-id ms://Qwen/Qwen3.5-4B --max-steps 500
+
+python grpo_mm.py \
+ --model-id ms://Qwen/Qwen3.5-4B \
+ --model-gpus 4 \
+ --sampler-gpus 4 \
+ --num-generations 8 \
+ --max-tokens 4096 \
+ --batch-size 4 \
+ --mini-batch-size 4 \
+ --micro-batch-size 1 \
+ --max-steps 1000 \
+ --lr 1e-5 \
+ --save-steps 50 \
+ --adapter-name default \
+ "$@"
diff --git a/cookbook/rl/short_math_grpo.py b/cookbook/rl/grpo/short_math_grpo.py
similarity index 91%
rename from cookbook/rl/short_math_grpo.py
rename to cookbook/rl/grpo/short_math_grpo.py
index 5e107b0ae..91fcd7669 100644
--- a/cookbook/rl/short_math_grpo.py
+++ b/cookbook/rl/grpo/short_math_grpo.py
@@ -14,6 +14,7 @@
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
+from twinkle.cli import CLI
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
@@ -26,26 +27,27 @@
from twinkle.preprocessor.llm import GSM8KProcessor
logger = get_logger()
+args = CLI.from_args()
# ========== Configuration ==========
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
-USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))
+MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.5-4B'
+USE_MEGATRON = args.model.strategy != 'native_fsdp'
-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
-SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
+MODEL_GPUS = args.infra.model_gpus or 4
+SAMPLER_GPUS = args.infra.sampler_gpus or 4
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
-NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
-MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
-LEARNING_RATE = float(os.environ.get('LR', 1e-5))
-MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8))
-MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8))
-MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
-ADAPTER_NAME = 'default'
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000))
-LORA_RANK = int(os.environ.get('LORA_RANK', 16))
+NUM_GENERATIONS = args.rl.num_generations or 8
+MAX_NEW_TOKENS = args.sampling.max_tokens or 4096
+LEARNING_RATE = args.optimizer.learning_rate or 1e-5
+MAX_STEPS = args.training.max_steps or 1000
+BATCH_SIZE = args.training.batch_size or 8
+MINI_BATCH_SIZE = args.training.mini_batch_size or 8
+MICRO_BATCH_SIZE = args.training.micro_batch_size or 2
+GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 1
+ADAPTER_NAME = args.lora.adapter_name or 'default'
+SAVE_STEPS = args.training.save_steps or 1000
+LORA_RANK = args.lora.lora_r or 16
SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning '
'and put your final answer within \\boxed{}.')
diff --git a/cookbook/rl/grpo/short_math_grpo.sh b/cookbook/rl/grpo/short_math_grpo.sh
new file mode 100644
index 000000000..033507dc8
--- /dev/null
+++ b/cookbook/rl/grpo/short_math_grpo.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# GRPO Short Math Reasoning on GSM8K via Ray.
+# Uses a short reasoning format: shorter thinking earns a higher brevity reward.
+# Defaults are passed as CLI flags below; flags given at invocation are appended via "$@" to override them, e.g.:
+# bash short_math_grpo.sh --model-id ms://Qwen/Qwen3.5-4B --max-steps 500
+
+python short_math_grpo.py \
+ --model-id ms://Qwen/Qwen3.5-4B \
+ --model-gpus 4 \
+ --sampler-gpus 4 \
+ --num-generations 8 \
+ --max-tokens 4096 \
+ --batch-size 8 \
+ --mini-batch-size 8 \
+ --micro-batch-size 2 \
+ --max-steps 1000 \
+ --lr 1e-5 \
+ --lora-r 16 \
+ --save-steps 1000 \
+ --adapter-name default \
+ "$@"
diff --git a/cookbook/rl/short_math_grpo_moe.py b/cookbook/rl/grpo/short_math_grpo_moe.py
similarity index 90%
rename from cookbook/rl/short_math_grpo_moe.py
rename to cookbook/rl/grpo/short_math_grpo_moe.py
index 9d870eacb..f19747282 100644
--- a/cookbook/rl/short_math_grpo_moe.py
+++ b/cookbook/rl/grpo/short_math_grpo_moe.py
@@ -14,6 +14,7 @@
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
+from twinkle.cli import CLI
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
@@ -26,31 +27,32 @@
from twinkle.preprocessor.llm import GSM8KProcessor
logger = get_logger()
+args = CLI.from_args()
# ========== Configuration ==========
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B')
-USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))
+MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.6-35B-A3B'
+USE_MEGATRON = args.model.strategy != 'native_fsdp'
-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
-MODEL_EP = int(os.environ.get('MODEL_EP', 2))
-MODEL_TP = int(os.environ.get('MODEL_TP', 2))
-MODEL_PP = int(os.environ.get('MODEL_PP', 2))
+MODEL_GPUS = args.infra.model_gpus or 4
+MODEL_EP = args.infra.ep_size or 2
+MODEL_TP = args.infra.tp_size or 2
+MODEL_PP = args.infra.pp_size or 2
-SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2))
-SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 2))
+SAMPLER_GPUS = args.infra.sampler_gpus or 2
+SAMPLER_TP = args.sampler.tensor_parallel_size or 2
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
-NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
-MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
-LEARNING_RATE = float(os.environ.get('LR', 5e-5))
-MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
-MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))
-MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1))
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
-ADAPTER_NAME = 'default'
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000))
-LORA_RANK = int(os.environ.get('LORA_RANK', 16))
+NUM_GENERATIONS = args.rl.num_generations or 8
+MAX_NEW_TOKENS = args.sampling.max_tokens or 4096
+LEARNING_RATE = args.optimizer.learning_rate or 5e-5
+MAX_STEPS = args.training.max_steps or 1000
+BATCH_SIZE = args.training.batch_size or 4
+MINI_BATCH_SIZE = args.training.mini_batch_size or 4
+MICRO_BATCH_SIZE = args.training.micro_batch_size or 1
+GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 1
+ADAPTER_NAME = args.lora.adapter_name or 'default'
+SAVE_STEPS = args.training.save_steps or 1000
+LORA_RANK = args.lora.lora_r or 16
SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning '
'and put your final answer within \\boxed{}.')
diff --git a/cookbook/rl/grpo/short_math_grpo_moe.sh b/cookbook/rl/grpo/short_math_grpo_moe.sh
new file mode 100644
index 000000000..00369610c
--- /dev/null
+++ b/cookbook/rl/grpo/short_math_grpo_moe.sh
@@ -0,0 +1,27 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# GRPO Short Math MoE on GSM8K via Ray (Megatron MoE model, TP+EP+PP parallelism).
+# All training config is passed as CLI flags. Override at invocation, e.g.:
+#   bash short_math_grpo_moe.sh --model-id ms://Qwen/Qwen3.6-35B-A3B --max-steps 500
+# The trainer path is resolved relative to this script, so the launcher works from any cwd.
+
+python "$(dirname -- "${BASH_SOURCE[0]:-$0}")/short_math_grpo_moe.py" \
+    --model-id ms://Qwen/Qwen3.6-35B-A3B \
+    --model-gpus 4 \
+    --sampler-gpus 2 \
+    --ep-size 2 \
+    --tp-size 2 \
+    --pp-size 2 \
+    --tensor-parallel-size 2 \
+    --num-generations 8 \
+    --max-tokens 4096 \
+    --batch-size 4 \
+    --mini-batch-size 4 \
+    --micro-batch-size 1 \
+    --max-steps 1000 \
+    --lr 5e-5 \
+    --lora-r 16 \
+    --save-steps 1000 \
+    --adapter-name default \
+    "$@"
diff --git a/cookbook/rl/short_math_grpo_multi_lora.py b/cookbook/rl/grpo/short_math_grpo_multi_lora.py
similarity index 92%
rename from cookbook/rl/short_math_grpo_multi_lora.py
rename to cookbook/rl/grpo/short_math_grpo_multi_lora.py
index 9dad8df30..cff4bb4b9 100644
--- a/cookbook/rl/short_math_grpo_multi_lora.py
+++ b/cookbook/rl/grpo/short_math_grpo_multi_lora.py
@@ -21,6 +21,7 @@
import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
+from twinkle.cli import CLI
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
@@ -33,28 +34,29 @@
from twinkle.preprocessor.llm import GSM8KProcessor
logger = get_logger()
+args = CLI.from_args()
# ========== Configuration ==========
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B')
+MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.6-35B-A3B'
-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
-SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2))
-SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 2))
+MODEL_GPUS = args.infra.model_gpus or 4
+SAMPLER_GPUS = args.infra.sampler_gpus or 2
+SAMPLER_TP = args.sampler.tensor_parallel_size or 2
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
-NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
-MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
-LEARNING_RATE = float(os.environ.get('LR', 5e-5))
-MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
-MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))
-MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1))
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
-ADAPTER_NAME = 'default_0'
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000))
-LORA_RANK = int(os.environ.get('LORA_RANK', 16))
-LORA_SYNC_DIR = os.environ.get('LORA_SYNC_DIR', 'output/lora_sync')
+NUM_GENERATIONS = args.rl.num_generations or 8
+MAX_NEW_TOKENS = args.sampling.max_tokens or 4096
+LEARNING_RATE = args.optimizer.learning_rate or 5e-5
+MAX_STEPS = args.training.max_steps or 1000
+BATCH_SIZE = args.training.batch_size or 4
+MINI_BATCH_SIZE = args.training.mini_batch_size or 4
+MICRO_BATCH_SIZE = args.training.micro_batch_size or 1
+GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 1
+ADAPTER_NAME = args.lora.adapter_name or 'default_0'
+SAVE_STEPS = args.training.save_steps or 1000
+LORA_RANK = args.lora.lora_r or 16
+LORA_SYNC_DIR = args.checkpoint.lora_sync_dir or 'output/lora_sync'
SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning '
'and put your final answer within \\boxed{}.')
diff --git a/cookbook/rl/grpo/short_math_grpo_multi_lora.sh b/cookbook/rl/grpo/short_math_grpo_multi_lora.sh
new file mode 100644
index 000000000..a465250c8
--- /dev/null
+++ b/cookbook/rl/grpo/short_math_grpo_multi_lora.sh
@@ -0,0 +1,26 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# GRPO Short Math MultiLoRA on GSM8K via Ray.
+# Uses MultiLoraMegatronModel with filesystem-based LoRA sync to vLLM.
+# Model: Qwen3.6-35B-A3B (MoE) with tp=2, ep=2, pp=2.
+# All training config is passed as CLI flags. Override at invocation, e.g.:
+#   bash short_math_grpo_multi_lora.sh --model-id ms://Qwen/Qwen3.6-35B-A3B --max-steps 500
+# The trainer path is resolved relative to this script, so the launcher works from any cwd.
+python "$(dirname -- "${BASH_SOURCE[0]:-$0}")/short_math_grpo_multi_lora.py" \
+    --model-id ms://Qwen/Qwen3.6-35B-A3B \
+    --model-gpus 4 \
+    --sampler-gpus 2 \
+    --tensor-parallel-size 2 \
+    --num-generations 8 \
+    --max-tokens 4096 \
+    --batch-size 4 \
+    --mini-batch-size 4 \
+    --micro-batch-size 1 \
+    --max-steps 1000 \
+    --lr 5e-5 \
+    --lora-r 16 \
+    --save-steps 1000 \
+    --adapter-name default_0 \
+    --lora-sync-dir output/lora_sync \
+    "$@"
diff --git a/cookbook/rl/multi_turn/multi_turn_grpo.py b/cookbook/rl/multi_turn/multi_turn_grpo.py
new file mode 100644
index 000000000..6661bfcdc
--- /dev/null
+++ b/cookbook/rl/multi_turn/multi_turn_grpo.py
@@ -0,0 +1,457 @@
+"""Multi-turn GRPO training with EnvPool (integrated environment pool).
+
+Demonstrates how to train an LLM agent via GRPO on interactive environments
+(e.g. Blackjack) using EnvPool and Twinkle's MultiTurnRollout.
+
+EnvPool is deployed as a @remote_class component — either:
+ - With remote_group='env': runs on a dedicated CPU DeviceGroup (isolated)
+ - Without remote_group: runs locally in the driver (zero RPC overhead)
+
+The agent interacts with environments through tool calls:
+ 1. EnvPool manages N env instances; each trajectory maps to one slot.
+ 2. MultiTurnRollout drives the multi-turn loop: model generates tool calls,
+ EnvTool dispatches them to env.step(), observations are fed back.
+ 3. Episode reward is extracted after rollout completes.
+ 4. GRPO advantages are computed across the batch and used for policy update.
+
+Usage:
+    # No separate environment server is needed; environments are instantiated
+    # directly inside the EnvPool worker:
+    python multi_turn_grpo.py
+
+    # To run the environments on a dedicated CPU worker (isolated):
+    ENV_REMOTE=1 python multi_turn_grpo.py
+
+References:
+ - OpenEnv GRPO Blackjack: https://github.com/huggingface/OpenEnv/tree/main/examples/grpo_blackjack
+ - cookbook/rl/grpo/short_math_grpo.py (single-turn GRPO template)
+"""
+import os
+from typing import Any, Dict, List, Tuple
+
+from peft import LoraConfig
+
+import twinkle
+from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
+from twinkle.advantage import GRPOAdvantage
+from twinkle.checkpoint_engine import CheckpointEngineManager
+from twinkle.cli import CLI
+from twinkle.data_format import SamplingParams
+from twinkle.metric import CompletionRewardMetric
+from twinkle.model import TransformersModel
+from twinkle.processor import InputProcessor
+from twinkle.sampler import vLLMSampler
+from twinkle.template import Qwen3_5Template
+from twinkle_agentic.envs import EnvPool, EnvPoolAdapter, EnvTool
+from twinkle_agentic.rollout.multi_turn import MultiTurnRollout
+from twinkle_agentic.tools.tool_manager import ToolManager
+
+logger = get_logger()
+args = CLI.from_args()  # parsed once, at import time
+
+# ========== Configuration ==========
+MODEL_ID = args.model.model_id or 'ms://Qwen/Qwen3.5-4B'
+USE_MEGATRON = False
+# `value or default`: any falsy CLI value (None, 0, '') selects the default.
+MODEL_GPUS = args.infra.model_gpus or 4
+SAMPLER_GPUS = args.infra.sampler_gpus or 4
+NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
+
+NUM_GENERATIONS = args.rl.num_generations or 8
+MAX_NEW_TOKENS = args.sampling.max_tokens or 2048
+LEARNING_RATE = args.optimizer.learning_rate or 1e-5
+MAX_STEPS = args.training.max_steps or 1000
+BATCH_SIZE = args.training.batch_size or 4
+MINI_BATCH_SIZE = args.training.mini_batch_size or 8
+MICRO_BATCH_SIZE = args.training.micro_batch_size or 2
+GRADIENT_ACCUMULATION_STEPS = args.training.gradient_accumulation_steps or 1
+ADAPTER_NAME = args.lora.adapter_name or 'default'
+SAVE_STEPS = args.training.save_steps or 500
+LORA_RANK = args.lora.lora_r or 16
+MAX_TURNS = int(os.environ.get('MAX_TURNS', '6'))  # environment variable, not a CLI flag
+
+# Environment configuration (all read from environment variables)
+# ENV_CLS: import path to the environment class (no server needed)
+ENV_CLS = os.environ.get('ENV_CLS', 'blackjack_env:BlackjackEnv')
+# ENV_REMOTE: set to '1' to deploy envs on a dedicated CPU DeviceGroup
+ENV_REMOTE = os.environ.get('ENV_REMOTE', '0') == '1'
+# ENV_POOL_SIZE: passed to EnvPool as pool_size; 0 (default) means BATCH_SIZE * NUM_GENERATIONS
+ENV_POOL_SIZE = int(os.environ.get('ENV_POOL_SIZE', '0')) # 0 = auto
+
+# ========== Tool Schema (Blackjack example) ==========
+# Function-calling tool definitions offered to the model.
+# For blackjack: a single "play" tool whose `action` argument is 'hit' or 'stand'.
+# main() passes TOOL_SCHEMA to prepare_trajectories; rebind it for other environments.
+BLACKJACK_TOOL_SCHEMA = [
+ {
+ 'type': 'function',
+ 'function': {
+ 'name': 'play',
+ 'description': 'Take an action in the blackjack game.',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'action': {
+ 'type': 'string',
+ 'enum': ['hit', 'stand'],
+ 'description': 'The action to take: "hit" to draw a card, "stand" to keep current hand.',
+ }
+ },
+ 'required': ['action'],
+ },
+ },
+ }
+]
+
+TOOL_SCHEMA = BLACKJACK_TOOL_SCHEMA
+
+# Action name → OpenSpiel action_id mapping for blackjack (used by blackjack_action_mapper).
+# OpenSpiel blackjack: 0 = HIT, 1 = STAND
+BLACKJACK_ACTION_MAP = {'hit': 0, 'stand': 1}
+
+
+def blackjack_action_mapper(tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
+    """Translate a ``play`` tool call into an OpenSpiel-style action dict.
+
+    ``tool_name`` is not inspected. A missing or unrecognised ``action`` falls back to STAND (id 1).
+    """
+    requested = arguments.get('action', 'stand')
+    normalized = str(requested).lower().strip()
+    return {'action_id': BLACKJACK_ACTION_MAP.get(normalized, 1), 'game_name': 'blackjack'}
+
+# System prompt: main() hands it to prepare_trajectories, which puts it first in every trajectory.
+SYSTEM_PROMPT = """You are a skilled blackjack player. You will be told your current hand and the dealer's visible card.
+
+Your goal is to win the game by getting as close to 21 as possible without going over.
+
+Strategy guidelines:
+- Hit if your hand total is below 12
+- Consider the dealer's visible card when deciding
+- Stand if you have 17 or higher
+- Be cautious with hard hands (no ace counted as 11)
+
+Use the `play` tool to take actions. Always reason briefly before acting."""
+
+
+# ========== Environment Setup ==========
+def prepare_trajectories(
+    env_pool: EnvPool,
+    n_trajectories: int,
+    tool_schema: List[Dict],
+    system_prompt: str,
+    action_mapper=None,
+) -> Tuple[List[Dict[str, Any]], List[ToolManager], List[List[EnvTool]]]:
+    """Start one fresh episode per trajectory and build its opening conversation.
+
+    ``n_trajectories`` adapters are requested from ``env_pool``. Each adapter is
+    reset, wrapped via ``EnvTool.from_env`` and given its own ``ToolManager``; the
+    observation returned by the reset becomes the first user message.
+
+    Args:
+        env_pool: The EnvPool instance managing all environments.
+        n_trajectories: Total number of trajectories to create.
+        tool_schema: Tool definitions; forwarded to the pool and attached to every
+            trajectory under ``'tools'`` (the same list object is shared).
+        system_prompt: System prompt for the agent.
+        action_mapper: Optional callable forwarded to ``env_pool.get_adapters``.
+
+    Returns:
+        Tuple of (trajectories, tool_managers, env_tools_list). The three lists are
+        index-aligned, with one entry per adapter returned by the pool.
+    """
+    # Index-aligned outputs, filled in adapter order.
+    trajectories: List[Dict[str, Any]] = []
+    tool_managers: List[ToolManager] = []
+    env_tools_list: List[List[EnvTool]] = []
+
+    # One adapter per trajectory.
+    pool_adapters = env_pool.get_adapters(
+        n=n_trajectories,
+        tool_schema=tool_schema,
+        action_mapper=action_mapper,
+    )
+
+    for env_adapter in pool_adapters:
+        # Reset first: the resulting observation opens this episode's conversation.
+        first_observation = env_adapter.reset().observation
+
+        # Each trajectory gets its own EnvTool set and ToolManager.
+        tools_for_env = EnvTool.from_env(env_adapter)
+        manager = ToolManager(tools_for_env)
+
+        # System prompt followed by the initial observation as the user turn.
+        opening = {
+            'messages': [
+                {'role': 'system', 'content': system_prompt},
+                {'role': 'user', 'content': first_observation},
+            ],
+            'tools': tool_schema,
+        }
+        trajectories.append(opening)
+        tool_managers.append(manager)
+        env_tools_list.append(tools_for_env)
+
+    return trajectories, tool_managers, env_tools_list
+
+
+def extract_rewards(env_tools_list: List[List[EnvTool]]) -> List[float]:
+    """Collect one episode reward per trajectory once the rollout has finished.
+
+    The reward is read from ``episode_reward`` on the first EnvTool of each
+    trajectory; a trajectory whose tool list is empty contributes ``0.0``.
+
+    Returns:
+        Rewards in the same order as ``env_tools_list``.
+    """
+    return [
+        tools[0].episode_reward if tools else 0.0
+        for tools in env_tools_list
+    ]
+
+
+# ========== Main ==========
+def main():
+    """Run multi-turn GRPO training against environments served by an EnvPool.
+
+    Builds the model, the vLLM sampler and the EnvPool, then repeats
+    rollout -> group-relative advantages -> mini-batch updates until MAX_STEPS
+    optimizer steps are done. The EnvPool is closed even when training raises.
+    """
+    # Determine pool size
+    n_trajectories = BATCH_SIZE * NUM_GENERATIONS
+    pool_size = ENV_POOL_SIZE if ENV_POOL_SIZE > 0 else n_trajectories
+
+    # Device groups: model + sampler + (optionally) env
+    device_groups = [
+        DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
+        DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'),
+    ]
+
+    if ENV_REMOTE:
+        # Add a CPU-only DeviceGroup for env pool (1 CPU process, colocated on same node)
+        device_groups.append(
+            DeviceGroup(name='env', ranks=1, device_type='CPU'),
+        )
+
+    model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+    sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
+    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
+
+    lora_config = LoraConfig(
+        target_modules='all-linear',
+        r=LORA_RANK,
+        lora_alpha=LORA_RANK * 2,
+        lora_dropout=0.05,
+    )
+
+    if USE_MEGATRON:
+        from twinkle.model.megatron import MegatronModel
+        model = MegatronModel(
+            model_id=MODEL_ID,
+            device_mesh=model_mesh,
+            remote_group='model',
+            mixed_precision='bf16',
+            variable_seq_lengths=True,
+        )
+    else:
+        model = TransformersModel(
+            model_id=MODEL_ID,
+            device_mesh=model_mesh,
+            remote_group='model',
+        )
+
+    model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
+    if USE_MEGATRON:
+        model.set_optimizer('default', lr=LEARNING_RATE)
+        model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE)
+    else:
+        model.set_optimizer('AdamW', lr=LEARNING_RATE)
+        model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)
+
+    model.set_loss('GRPOLoss', epsilon=0.2)
+    model.set_processor(InputProcessor, padding_free=True)
+    model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False)
+
+    sampler = vLLMSampler(
+        model_id=MODEL_ID,
+        engine_args={
+            'gpu_memory_utilization': 0.8,
+            'max_model_len': 8192,
+            'max_lora_rank': max(32, LORA_RANK),  # must cover the adapter rank configured above
+            'enable_lora': True,
+            'enable_tower_connector_lora': True,
+        },
+        device_mesh=sampler_mesh,
+        remote_group='sampler',
+    )
+    sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False)
+
+    # ========== EnvPool: environment instances managed by Twinkle ==========
+    env_pool_kwargs = dict(
+        env_cls=ENV_CLS,
+        pool_size=pool_size,
+    )
+    if ENV_REMOTE:
+        # Deploy on dedicated CPU DeviceGroup
+        env_mesh = DeviceMesh.from_sizes(world_size=1, dp_size=1)
+        env_pool_kwargs['remote_group'] = 'env'
+        env_pool_kwargs['device_mesh'] = env_mesh
+    # else: runs locally in driver (zero RPC overhead)
+
+    env_pool = EnvPool(**env_pool_kwargs)
+    logger.info(f'EnvPool created: env_cls={ENV_CLS}, pool_size={pool_size}, '
+                f'remote={ENV_REMOTE}')
+
+    try:
+        # Local template for MultiTurnRollout bridge computation
+        rollout_template = Qwen3_5Template(MODEL_ID, max_length=8192, enable_thinking=False)
+        rollout_template.truncation_strategy = 'delete'
+
+        ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)
+        # MultiTurnRollout: tool_manager is optional at construction time;
+        # the actual per-trajectory ToolManagers are provided at call time.
+        sampling_params = SamplingParams(
+            max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1,
+            temperature=1.0, top_p=0.95,
+        )
+        rollout = MultiTurnRollout(
+            sampler=sampler,
+            template=rollout_template,
+            sampling_params=sampling_params,
+            max_turns=MAX_TURNS,
+        )
+
+        advantage_fn = GRPOAdvantage()
+        metrics = CompletionRewardMetric()
+
+        optim_step = 0
+        logger.info('Starting multi-turn GRPO training with EnvPool')
+        logger.info(f'ENV_CLS={ENV_CLS}, MAX_TURNS={MAX_TURNS}, NUM_GENERATIONS={NUM_GENERATIONS}')
+        logger.info(get_device_placement())
+
+        while optim_step < MAX_STEPS:
+            metrics.reset()
+
+            # 1. Prepare environments and initial trajectories (one independent episode each)
+            logger.info(f'[Step {optim_step}] Resetting {n_trajectories} environments...')
+            expand_prompts, tool_managers, env_tools_list = prepare_trajectories(
+                env_pool=env_pool,
+                n_trajectories=n_trajectories,
+                tool_schema=TOOL_SCHEMA,
+                system_prompt=SYSTEM_PROMPT,
+                action_mapper=blackjack_action_mapper,
+            )
+
+            # 2. Sync model weights to sampler
+            ckpt_manager.sync_weights(merge_and_sync=False)
+            sampler.reset_prefix_cache()
+
+            # 3. Run multi-turn rollout with per-trajectory ToolManagers
+            all_trajectories: List[Dict[str, Any]] = rollout(
+                expand_prompts,
+                tool_manager=tool_managers,
+            )
+
+            # 4. Extract rewards and logprobs
+            env_rewards = extract_rewards(env_tools_list)
+
+            all_old_logps: List[List[float]] = []
+            all_completion_lengths: List[int] = []
+            n_turns_per_rollout: List[int] = []
+
+            for traj in all_trajectories:
+                logprobs = traj.get('logprobs') or []
+                old_logps = [lp[0][1] for lp in logprobs] if logprobs else []
+                all_old_logps.append(old_logps)
+                # Completion length = number of trainable tokens (labels != -100)
+                labels = traj.get('labels') or []
+                comp_len = sum(1 for label in labels if label != -100)
+                all_completion_lengths.append(comp_len)
+                n_turns_per_rollout.append(int(traj.get('turns') or 0))
+
+            # 5. Compute advantages (group-relative within NUM_GENERATIONS)
+            total_rewards = env_rewards
+            advantages = advantage_fn(
+                total_rewards, num_generations=NUM_GENERATIONS, scale='group',
+            ).tolist()
+
+            # 6. Log metrics
+            metrics.accumulate(
+                completion_lengths=all_completion_lengths,
+                rewards={'total': total_rewards},
+            )
+
+            avg_reward = sum(total_rewards) / len(total_rewards) if total_rewards else 0.0
+            avg_turns = sum(n_turns_per_rollout) / len(n_turns_per_rollout) if n_turns_per_rollout else 0.0
+            logger.info(f'[Step {optim_step}] avg_reward={avg_reward:.3f}, avg_turns={avg_turns:.1f}')
+
+            # 7. Forward-backward with mini-batches
+            # Filter out oversized/truncated trajectories (strategy='delete'),
+            # keep only those with valid completions and ensure >= MODEL_GPUS inputs.
+            all_input_data: List[Dict[str, Any]] = []
+            filtered_old_logps: List[List[float]] = []
+            filtered_advantages: List[float] = []
+            max_len = rollout_template.max_length or float('inf')
+            for i, traj in enumerate(all_trajectories):
+                traj_len = len(traj.get('input_ids') or traj.get('labels') or [])
+                comp_len = all_completion_lengths[i]  # already counted in step 4
+                if traj_len > max_len or comp_len == 0:
+                    continue
+                all_input_data.append(traj)
+                filtered_old_logps.append(all_old_logps[i])
+                filtered_advantages.append(advantages[i])
+
+            if len(all_input_data) < MODEL_GPUS:
+                logger.warning(f'[Step {optim_step}] Only {len(all_input_data)} valid trajectories '
+                               f'after filtering (need >= {MODEL_GPUS}), skipping this batch.')
+                continue
+
+            all_old_logps = filtered_old_logps
+            advantages = filtered_advantages
+            total_completions = len(all_input_data)
+            logger.info(f'[Step {optim_step}] {total_completions}/{n_trajectories} trajectories '
+                        f'passed length filter (max_len={max_len})')
+
+            for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
+                mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
+                mb_inputs = all_input_data[mb_start:mb_end]
+                mb_old_logps = all_old_logps[mb_start:mb_end]
+                mb_advantages = advantages[mb_start:mb_end]
+
+                # Log trajectory lengths before forward_backward
+                traj_lengths = [len(traj.get('labels') or traj.get('input_ids') or []) for traj in mb_inputs]
+                logger.info(f'[Step {optim_step}] mini-batch [{mb_start}:{mb_end}] '
+                            f'n_inputs={len(mb_inputs)}, dp_world={MODEL_GPUS}, '
+                            f'traj_lengths={traj_lengths}')
+
+                model.forward_backward(
+                    inputs=mb_inputs,
+                    old_logps=mb_old_logps,
+                    advantages=mb_advantages,
+                    micro_batch_size=MICRO_BATCH_SIZE,
+                )
+                model.clip_grad_and_step()
+                optim_step += 1
+
+                if optim_step >= MAX_STEPS:
+                    break
+                if optim_step % SAVE_STEPS == 0:
+                    model.save(f'multi-turn-grpo-checkpoint-{optim_step}')
+
+            # 8. Log step summary
+            log_dict = metrics.calculate()
+            log_dict.update(model.calculate_metric(is_training=True))
+            log_dict['avg_turns'] = avg_turns
+            log_dict['avg_reward'] = avg_reward
+            metrics.reset()
+            logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')
+    finally:
+        # Cleanup: release the environments whether training finished or raised
+        env_pool.close()
+
+    logger.info(f'Training completed. optim_steps={optim_step}')
+    model.save('multi-turn-grpo-final')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cookbook/sample/emb_sample.py b/cookbook/sample/emb_sample.py
index 6d7e4c599..da27a8155 100644
--- a/cookbook/sample/emb_sample.py
+++ b/cookbook/sample/emb_sample.py
@@ -12,15 +12,15 @@
python cookbook/sample/emb_sample.py
EMB_MODEL=./output/embedding_lora_transformers/step_16000 python cookbook/sample/emb_sample.py
"""
-import os
import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
import torch
import torch.nn.functional as F
import twinkle
from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.data_format import SamplingParams
from twinkle.loss import InfonceLoss
from twinkle.model import TransformersModel
@@ -29,12 +29,13 @@
from twinkle.template import Template
logger = get_logger()
+args = CLI.from_args()
# -- Config -------------------------------------------------------------------
-CONDENSE_MODEL_ID = os.environ.get('CONDENSE_MODEL_ID', 'ms://twinkle-kit/Qwen3.5-4B-CM-v2')
-EMB_MODEL_ID = os.environ.get('EMB_MODEL', 'ms://twinkle-kit/Qwen3.5-4B-QA-emb')
-SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 1))
-EMB_GPUS = int(os.environ.get('EMB_GPUS', 1))
+CONDENSE_MODEL_ID = args.extra.get('condense_model_id', 'ms://twinkle-kit/Qwen3.5-4B-CM-v2')
+EMB_MODEL_ID = args.extra.get('emb_model_id', 'ms://twinkle-kit/Qwen3.5-4B-QA-emb')
+SAMPLER_GPUS = args.infra.sampler_gpus or 1
+EMB_GPUS = int(args.extra.get('emb_gpus', 1))
EMB_MAX_LENGTH = 8192
# -- Prompts (aligned with train_embedding_full_ddp.py) -----------------------
diff --git a/cookbook/sample/sample.py b/cookbook/sample/sample.py
index b56460ea1..8cd452b8f 100644
--- a/cookbook/sample/sample.py
+++ b/cookbook/sample/sample.py
@@ -18,19 +18,19 @@
MODEL_ID=/path/to/model LORA_PATH=/path/to/adapter SAMPLER_GPUS=1 python sample.py
"""
-import os
from typing import List, Dict, Any
import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.data_format import SamplingParams
from twinkle.sampler import vLLMSampler
logger = get_logger()
+args = CLI.from_args()
-MODEL_ID = os.environ.get('MODEL_ID', 'Qwen/Qwen3.5-4B')
-LORA_PATH = os.environ.get('LORA_PATH', '/path/to/lora')
-SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 1))
+SAMPLER_GPUS = args.infra.sampler_gpus or 1
+LORA_PATH = args.lora.lora_path or '/path/to/lora'
def build_prompts() -> List[Dict[str, Any]]:
@@ -67,7 +67,7 @@ def main():
# ── 2. Create vLLMSampler with LoRA enabled ────────────────────────
sampler = vLLMSampler(
- model_id=MODEL_ID,
+ model_id=args.model.model_id,
engine_args={
'gpu_memory_utilization': 0.7,
'max_model_len': 4096,
@@ -79,7 +79,7 @@ def main():
device_mesh=sampler_mesh,
remote_group='sampler',
)
- sampler.set_template('Qwen3_5Template', model_id=MODEL_ID)
+ sampler.set_template('Qwen3_5Template', model_id=args.model.model_id)
logger.info(get_device_placement())
# ── 3. Configure sampling parameters ────────────────────────────────
@@ -92,7 +92,7 @@ def main():
# ── 4. Run inference ────────────────────────────────────────────────
prompts = build_prompts()
- logger.info(f'Sampling {len(prompts)} prompts with model {MODEL_ID} ...')
+ logger.info(f'Sampling {len(prompts)} prompts with model {args.model.model_id} ...')
responses = sampler.sample(prompts, sampling_params, adapter_path=LORA_PATH)
diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
index af72efa10..79daafd38 100644
--- a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
+++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
@@ -4,7 +4,6 @@
Run on 8 GPUs:
torchrun --nproc-per-node=8 cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
"""
-import os
from pathlib import Path
from peft import LoraConfig
@@ -12,45 +11,31 @@
import twinkle
from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor
logger = get_logger()
+args = CLI.from_args()
-MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash')
-DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
-TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
-GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
-LOG_INTERVAL = GRAD_ACCUM_STEPS
-LR = float(os.environ.get('LR', '1e-4'))
-MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
-LORA_R = int(os.environ.get('LORA_R', '8'))
-LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32'))
-ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1'
-OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output_dsv4')
-RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None
-RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1'
-IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1'
-ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default')
-NUM_GPUS = int(os.environ.get('NUM_GPUS', '8'))
+# Normalise to a real bool: extras may arrive as strings, and a bare '0' / 'false' would otherwise be truthy.
+ENABLE_EP = str(args.extra.get('enable_ep', True)).strip().lower() in ('1', 'true', 'yes', 'on')
device_mesh = DeviceMesh.from_sizes(
- fsdp_size=NUM_GPUS,
- dp_size=1,
- ep_size=NUM_GPUS,
+ fsdp_size=args.infra.fsdp_size,
+ dp_size=args.infra.dp_size,
+ ep_size=args.infra.ep_size,
device_type=Platform.get_platform().device_prefix(),
)
-twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh)
def _build_lora_config(enable_ep: bool):
if enable_ep:
return LoraConfig(
- r=LORA_R,
- lora_alpha=LORA_ALPHA,
+ r=args.lora.lora_r,
+ lora_alpha=args.lora.lora_alpha,
target_modules='all-linear',
exclude_modules=['o_a_proj'],
target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
@@ -60,8 +45,8 @@ def _build_lora_config(enable_ep: bool):
# during forward. That is not stable with plain FSDP2, so non-EP mode uses
# regular module LoRA and does not train expert parameters.
return LoraConfig(
- r=LORA_R,
- lora_alpha=LORA_ALPHA,
+ r=args.lora.lora_r,
+ lora_alpha=args.lora.lora_alpha,
exclude_modules=['o_a_proj'],
target_modules='all-linear',
)
@@ -70,31 +55,34 @@ def _build_lora_config(enable_ep: bool):
def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader):
return model.save(
name=checkpoint_name,
- output_dir=OUTPUT_DIR,
- adapter_name=ADAPTER_NAME,
- save_optimizer=True,
+ output_dir=args.training.output_dir,
+ adapter_name=args.lora.adapter_name,
+ save_optimizer=args.checkpoint.save_optimizer,
consumed_train_samples=dataloader.get_state()['consumed_train_samples'],
)
def train():
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ config = AutoConfig.from_pretrained(args.model.model_id, trust_remote_code=True)
text_config = getattr(config, 'text_config', config)
if hasattr(text_config, 'use_cache'):
text_config.use_cache = False
- dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID))
- dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
- dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope'))
+ dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id))
+ dataset.set_template(args.template.template_cls, model_id=args.model.model_id)
+ dataset.map(SelfCognitionProcessor(
+ args.extra.get('model_name', 'twinkle'),
+ args.extra.get('model_author', 'ModelScope'),
+ ))
dataset.encode(batched=True)
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)
+ dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size, device_mesh=device_mesh)
model = TransformersModel(
- model_id=MODEL_ID,
+ model_id=args.model.model_id,
config=config,
device_mesh=device_mesh,
strategy='native_fsdp',
- memory_efficient_init=True,
+ memory_efficient_init=args.model.memory_efficient_init,
fsdp_config={
'expert_parallel': {
'enabled': ENABLE_EP,
@@ -104,38 +92,41 @@ def train():
},
)
lora_cfg = _build_lora_config(ENABLE_EP)
- model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
- model.set_optimizer('AdamW', lr=LR, foreach=False)
+ model.add_adapter_to_model(args.lora.adapter_name, lora_cfg,
+ gradient_accumulation_steps=args.training.gradient_accumulation_steps)
+ model.set_optimizer(args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate, foreach=False)
model.set_lr_scheduler(
- scheduler_cls='CosineWarmupScheduler',
- num_warmup_steps=5,
+ scheduler_cls=args.scheduler.scheduler_cls,
+ num_warmup_steps=args.scheduler.num_warmup_steps,
num_training_steps=len(dataloader),
)
- if RESUME_FROM_CHECKPOINT:
- checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()
+ if args.training.resume_from_checkpoint:
+ checkpoint_path = Path(args.training.resume_from_checkpoint).expanduser().resolve()
kwargs = {}
- if ADAPTER_NAME:
- kwargs['adapter_name'] = ADAPTER_NAME
+ if args.lora.adapter_name:
+ kwargs['adapter_name'] = args.lora.adapter_name
progress = model.resume_from_checkpoint(
- str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs)
- if not IGNORE_DATA_SKIP:
+ str(checkpoint_path), resume_only_model=args.training.resume_only_model, **kwargs)
+ if not args.training.ignore_data_skip:
dataloader.resume_from_checkpoint(progress['consumed_train_samples'])
logger.info(get_device_placement())
logger.info(model.get_train_configs())
logger.info(
- f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, '
- f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}')
+ f'Total steps: {len(dataloader)}, batch_size={args.training.batch_size}, '
+ f'grad_accum={args.training.gradient_accumulation_steps}, '
+ f'enable_ep={ENABLE_EP}, output_dir={args.training.output_dir}')
- optimizer_group = model.optimizer_group[ADAPTER_NAME]
+ optimizer_group = model.optimizer_group[args.lora.adapter_name]
for batch in dataloader:
if callable(batch):
batch = batch()
model.forward_backward(inputs=batch)
- model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+ model.clip_grad_and_step(max_grad_norm=args.optimizer.max_grad_norm,
+ gradient_accumulation_steps=args.training.gradient_accumulation_steps)
cur_step = optimizer_group.cur_step
- if cur_step > 0 and cur_step % LOG_INTERVAL == 0:
+ if cur_step > 0 and cur_step % args.training.log_interval == 0:
metric = model.calculate_metric(is_training=True)
if callable(metric):
metric = metric()
diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh
index b4e3d9ffb..f2b01ff6e 100644
--- a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh
+++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh
@@ -4,13 +4,18 @@ set -euo pipefail
# EP + FSDP2 + LoRA training for DeepSeek-V4.
# ENABLE_EP=1 trains expert LoRA with target_parameters.
# ENABLE_EP=0 runs plain FSDP2 LoRA and does not train expert parameters.
+# All training config passed as CLI flags. Override at invocation, e.g.:
+# bash cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh --batch-size 8 --lr 5e-5
-export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
-export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
-export ENABLE_EP="${ENABLE_EP:-1}"
-export BATCH_SIZE="${BATCH_SIZE:-4}"
-export GRAD_ACCUM_STEPS="${GRAD_ACCUM_STEPS:-4}"
-export OUTPUT_DIR="${OUTPUT_DIR:-./output_dsv4}"
-
-torchrun --nproc-per-node="${NPROC_PER_NODE}" \
- cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ torchrun --nproc-per-node=8 \
+ cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py \
+ --model-id ms://deepseek-ai/DeepSeek-V3-0324 \
+ --dataset-id ms://swift/self-cognition \
+ --dp-size 4 \
+ --ep-size 2 \
+ --batch-size 4 \
+ --gradient-accumulation-steps 4 \
+ --output-dir ./output_dsv4 \
+ --enable-ep 1 \
+ "$@"
diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4_multinode.sh b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4_multinode.sh
index 7344474e5..47a7ebfc4 100644
--- a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4_multinode.sh
+++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4_multinode.sh
@@ -1,26 +1,39 @@
+#!/usr/bin/env bash
+set -euo pipefail
# `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights.
# Convert the checkpoint before training by following:
# https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87
# Install `transformers==5.8.0` before running this cookbook.
+# All training config passed as CLI flags. Override at invocation.
-export DSV4_MODEL_ID="ms://deepseek-ai/DeepSeek-V4-Flash-bf16"
-export DATASET_ID="ms://swift/self-cognition"
-# The following environment variables are required for multi-node training. Adjust the values according to your cluster setup.
-export GLOO_SOCKET_IFNAME="eth0" # Use ifconfig to check the network interface name
+# Multi-node networking config — adjust to your cluster setup (use ifconfig to find the network interface name).
+export GLOO_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export HCCL_EXEC_TIMEOUT=1200
export HCCL_CONNECT_TIMEOUT=1200
-export NNODES=4
-export NUM_GPUS=64
-export MASTER_ADDR="node0" # Replace with the IP address or hostname of the master node
-export MASTER_PORT=29500 # Replace with an open port on the master node
export HCCL_IF_BASE_PORT=20000
-torchrun --nnodes=$NNODES --node_rank=$NODE_RANK --nproc_per_node=16 \
- --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT ep_fsdp2_lora_deepseek_v4.py
+NNODES=4
+MASTER_ADDR=node0
+MASTER_PORT=29500
+NPROC_PER_NODE=16
-# NODE_RANK=0 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh
-# NODE_RANK=1 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh
-# NODE_RANK=2 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh
-# NODE_RANK=3 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh
+torchrun --nnodes=$NNODES --node_rank=${NODE_RANK:?"NODE_RANK must be set"} \
+ --nproc_per_node=$NPROC_PER_NODE \
+ --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
+ ep_fsdp2_lora_deepseek_v4.py \
+ --model-id ms://deepseek-ai/DeepSeek-V4-Flash-bf16 \
+ --dataset-id ms://swift/self-cognition \
+ --dp-size 4 \
+ --ep-size 2 \
+ --batch-size 4 \
+ --gradient-accumulation-steps 4 \
+ --output-dir ./output_dsv4_multinode \
+ --enable-ep 1 \
+ "$@"
+
+# NODE_RANK=0 bash ep_fsdp2_lora_deepseek_v4_multinode.sh
+# NODE_RANK=1 bash ep_fsdp2_lora_deepseek_v4_multinode.sh
+# NODE_RANK=2 bash ep_fsdp2_lora_deepseek_v4_multinode.sh
+# NODE_RANK=3 bash ep_fsdp2_lora_deepseek_v4_multinode.sh
diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
index a9f90111d..5d5088182 100644
--- a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
+++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
@@ -4,7 +4,6 @@
Run on 8 GPUs:
torchrun --nproc-per-node=8 cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
"""
-import os
from pathlib import Path
from peft import LoraConfig
@@ -12,6 +11,7 @@
import twinkle
from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
@@ -20,38 +20,24 @@
from twinkle.kernel import kernelize_model
logger = get_logger()
+args = CLI.from_args()
-MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B')
-DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
-TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template')
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
-GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
-LOG_INTERVAL = GRAD_ACCUM_STEPS
-LR = float(os.environ.get('LR', '1e-4'))
-MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
-LORA_R = int(os.environ.get('LORA_R', '8'))
-LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32'))
-ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1'
-OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output')
-RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None
-RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1'
-IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1'
-ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default')
+ENABLE_EP = args.extra.get('enable_ep', True)
device_mesh = DeviceMesh.from_sizes(
- fsdp_size=8,
- dp_size=1,
- ep_size=8,
+ fsdp_size=args.infra.fsdp_size,
+ dp_size=args.infra.dp_size,
+ ep_size=args.infra.ep_size,
device_type=Platform.get_platform().device_prefix(),
)
-twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh)
def _build_lora_config(enable_ep: bool):
if enable_ep:
return LoraConfig(
- r=LORA_R,
- lora_alpha=LORA_ALPHA,
+ r=args.lora.lora_r,
+ lora_alpha=args.lora.lora_alpha,
target_modules='all-linear',
target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
)
@@ -60,8 +46,8 @@ def _build_lora_config(enable_ep: bool):
# during forward. That is not stable with plain FSDP2, so non-EP mode uses
# regular module LoRA and does not train expert parameters.
return LoraConfig(
- r=LORA_R,
- lora_alpha=LORA_ALPHA,
+ r=args.lora.lora_r,
+ lora_alpha=args.lora.lora_alpha,
target_modules='all-linear',
)
@@ -69,30 +55,33 @@ def _build_lora_config(enable_ep: bool):
def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader):
return model.save(
name=checkpoint_name,
- output_dir=OUTPUT_DIR,
- adapter_name=ADAPTER_NAME,
- save_optimizer=True,
+ output_dir=args.training.output_dir,
+ adapter_name=args.lora.adapter_name,
+ save_optimizer=args.checkpoint.save_optimizer,
consumed_train_samples=dataloader.get_state()['consumed_train_samples'],
)
def train():
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ config = AutoConfig.from_pretrained(args.model.model_id, trust_remote_code=True)
text_config = getattr(config, 'text_config', config)
if hasattr(text_config, 'use_cache'):
text_config.use_cache = False
- dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID))
+ dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id))
try:
- dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
+ dataset.set_template(args.template.template_cls, model_id=args.model.model_id)
except ValueError:
- dataset.set_template('Qwen3_5Template', model_id=MODEL_ID)
- dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope'))
+ dataset.set_template('Qwen3_5Template', model_id=args.model.model_id)
+ dataset.map(SelfCognitionProcessor(
+ args.extra.get('model_name', 'twinkle'),
+ args.extra.get('model_author', 'ModelScope'),
+ ))
dataset.encode(batched=True)
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)
+ dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size, device_mesh=device_mesh)
model = TransformersModel(
- model_id=MODEL_ID,
+ model_id=args.model.model_id,
config=config,
device_mesh=device_mesh,
strategy='native_fsdp',
@@ -108,38 +97,41 @@ def train():
if Torch.is_npu_available():
model = kernelize_model(model, mode='train', device='npu')
lora_cfg = _build_lora_config(ENABLE_EP)
- model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
- model.set_optimizer('AdamW', lr=LR, foreach=False)
+ model.add_adapter_to_model(args.lora.adapter_name, lora_cfg,
+ gradient_accumulation_steps=args.training.gradient_accumulation_steps)
+ model.set_optimizer(args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate, foreach=False)
model.set_lr_scheduler(
- scheduler_cls='CosineWarmupScheduler',
- num_warmup_steps=5,
+ scheduler_cls=args.scheduler.scheduler_cls,
+ num_warmup_steps=args.scheduler.num_warmup_steps,
num_training_steps=len(dataloader),
)
- if RESUME_FROM_CHECKPOINT:
- checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()
+ if args.training.resume_from_checkpoint:
+ checkpoint_path = Path(args.training.resume_from_checkpoint).expanduser().resolve()
kwargs = {}
- if ADAPTER_NAME:
- kwargs['adapter_name'] = ADAPTER_NAME
+ if args.lora.adapter_name:
+ kwargs['adapter_name'] = args.lora.adapter_name
progress = model.resume_from_checkpoint(
- str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs)
- if not IGNORE_DATA_SKIP:
+ str(checkpoint_path), resume_only_model=args.training.resume_only_model, **kwargs)
+ if not args.training.ignore_data_skip:
dataloader.resume_from_checkpoint(progress['consumed_train_samples'])
logger.info(get_device_placement())
logger.info(model.get_train_configs())
logger.info(
- f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, '
- f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}')
+ f'Total steps: {len(dataloader)}, batch_size={args.training.batch_size}, '
+ f'grad_accum={args.training.gradient_accumulation_steps}, '
+ f'enable_ep={ENABLE_EP}, output_dir={args.training.output_dir}')
- optimizer_group = model.optimizer_group[ADAPTER_NAME]
+ optimizer_group = model.optimizer_group[args.lora.adapter_name]
for batch in dataloader:
if callable(batch):
batch = batch()
model.forward_backward(inputs=batch)
- model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+ model.clip_grad_and_step(max_grad_norm=args.optimizer.max_grad_norm,
+ gradient_accumulation_steps=args.training.gradient_accumulation_steps)
cur_step = optimizer_group.cur_step
- if cur_step > 0 and cur_step % LOG_INTERVAL == 0:
+ if cur_step > 0 and cur_step % args.training.log_interval == 0:
metric = model.calculate_metric(is_training=True)
if callable(metric):
metric = metric()
diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh
index 8f1813e4f..5132d9d0b 100644
--- a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh
+++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh
@@ -2,15 +2,19 @@
set -euo pipefail
# EP + FSDP2 + LoRA training for Qwen3.5-MoE.
-# ENABLE_EP=1 trains expert LoRA with target_parameters.
-# ENABLE_EP=0 runs plain FSDP2 LoRA and does not train expert parameters.
+# All training config passed as CLI flags. Override at invocation, e.g.:
+# bash cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh --batch-size 8 --lr 5e-5
-export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
-export NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
-export ENABLE_EP="${ENABLE_EP:-1}"
-export BATCH_SIZE="${BATCH_SIZE:-4}"
-export GRAD_ACCUM_STEPS="${GRAD_ACCUM_STEPS:-4}"
-export OUTPUT_DIR="${OUTPUT_DIR:-./output_qwen3_5_moe}"
-
-torchrun --nproc-per-node="${NPROC_PER_NODE}" \
- cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ torchrun --nproc-per-node=8 \
+ cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py \
+ --model-id ms://Qwen/Qwen3.5-30B-A3B \
+ --dataset-id ms://swift/self-cognition \
+ --template-cls Qwen3_5Template \
+ --dp-size 4 \
+ --ep-size 2 \
+ --batch-size 4 \
+ --gradient-accumulation-steps 4 \
+ --output-dir ./output_qwen3_5_moe \
+ --enable-ep 1 \
+ "$@"
diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py
index ad4c917f9..a3b4da645 100644
--- a/cookbook/transformers/fsdp2.py
+++ b/cookbook/transformers/fsdp2.py
@@ -2,6 +2,7 @@
from peft import LoraConfig
from tqdm import tqdm
+from torch.optim import Muon # PyTorch 2.9+; matrix-orthogonalized momentum optimizer.
import twinkle
from twinkle import DeviceMesh, get_device_placement, get_logger
@@ -51,7 +52,7 @@ def evaluate(model):
def train():
- train_samples = int(args.extra.get('train_samples', 1000))
+ train_samples = args.training.train_samples or 1000
dataset = build_dataset(train_samples)
dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size)
model = TransformersModel(model_id=args.model.model_id)
@@ -64,7 +65,16 @@ def train():
model.add_adapter_to_model(
args.lora.adapter_name, lora_config,
gradient_accumulation_steps=args.training.gradient_accumulation_steps)
- model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate)
+ # Muon optimizes 2D hidden-layer weight matrices via Newton-Schulz orthogonalization.
+ # In LoRA training the trainable params are exclusively lora_A / lora_B (both 2D),
+ # so Muon applies cleanly without an AdamW fallback for 1D params.
+ # ``adjust_lr_fn='match_rms_adamw'`` rescales the orthogonalized update so the same
+ # lr / weight_decay tuned for AdamW can be reused directly (Moonshot Muon recipe).
+ model.set_optimizer(
+ optimizer_cls=Muon,
+ lr=args.optimizer.learning_rate,
+ adjust_lr_fn='match_rms_adamw',
+ )
# Add LRScheduler for lora `default`
model.set_lr_scheduler(
diff --git a/cookbook/transformers/fsdp2.sh b/cookbook/transformers/fsdp2.sh
index bbe269629..c372fbf2e 100644
--- a/cookbook/transformers/fsdp2.sh
+++ b/cookbook/transformers/fsdp2.sh
@@ -2,7 +2,7 @@
# All training config passed as CLI flags. Override at invocation, e.g.:
# bash fsdp2.sh --batch-size 16 --lr 5e-5
-CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7} \
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --nproc_per_node=8 fsdp2.py \
--model-id ms://Qwen/Qwen3.5-4B \
--dataset-id ms://swift/self-cognition \
diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py
index a6fd0bdcb..56f22c801 100644
--- a/cookbook/transformers/sp_fsdp_dense.py
+++ b/cookbook/transformers/sp_fsdp_dense.py
@@ -1,9 +1,9 @@
-import numpy as np
from functools import partial
from peft import LoraConfig
import twinkle
-from twinkle import DeviceGroup, DeviceMesh, Platform, get_logger
+from twinkle import DeviceMesh, get_logger
+from twinkle.cli import CLI
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
@@ -12,48 +12,44 @@
from twinkle.kernel import kernelize_model
logger = get_logger()
-MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
-DATASETS = 'ms://swift/self-cognition'
-
-device_group = [DeviceGroup(
- name='default',
- ranks=[0, 1, 2, 3],
- device_type=Platform.get_platform().device_prefix(),
-)]
+args = CLI.from_args()
# FSDP + sequence-parallel validation over 4 GPUs: dp=2, fsdp=2.
# In Transformers route, ulysses_size is the total sequence-parallel degree.
-device_mesh = DeviceMesh(
- device_type=Platform.get_platform().device_prefix(),
- mesh=np.arange(4).reshape(2, 2),
- mesh_dim_names=('dp', 'fsdp'),
- ulysses_size=2,
+device_mesh = DeviceMesh.from_sizes(
+ dp_size=args.infra.dp_size,
+ fsdp_size=args.infra.fsdp_size,
+ ulysses_size=args.infra.ulysses_size,
)
twinkle.initialize(
- mode='local',
- nproc_per_node=4,
+ mode=args.infra.mode,
global_device_mesh=device_mesh,
- lazy_collect=False,
+ lazy_collect=args.infra.lazy_collect,
)
def eval(model):
+ eval_samples = args.training.eval_samples or 100
dataloader = DataLoader(
- dataset=partial(create_dataset, data_slice=range(100)),
- batch_size=4,
+ dataset=partial(create_dataset, data_slice=range(eval_samples)),
+ batch_size=args.training.batch_size,
device_mesh=device_mesh,
)
for _, batch in enumerate(dataloader):
- model.forward_only(inputs=batch, adapter_name='default')
- model.calculate_loss(adapter_name='default')
- return model.calculate_metric(is_training=False, adapter_name='default')
+ model.forward_only(inputs=batch, adapter_name=args.lora.adapter_name)
+ model.calculate_loss(adapter_name=args.lora.adapter_name)
+ return model.calculate_metric(is_training=False, adapter_name=args.lora.adapter_name)
def create_dataset(data_slice=None):
- dataset = Dataset(dataset_meta=DatasetMeta(DATASETS, data_slice=range(500)))
- dataset.set_template('Qwen3_5Template', model_id=MODEL_ID)
- dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队'))
+ train_samples = args.training.train_samples or 500
+ dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=data_slice or range(train_samples)))
+ dataset.set_template(args.template.template_cls, model_id=args.model.model_id)
+ dataset.map(SelfCognitionProcessor(
+ args.extra.get('model_name', 'twinkle模型'),
+ args.extra.get('model_author', 'twinkle团队'),
+ ))
dataset.encode(batched=True)
return dataset
@@ -61,36 +57,38 @@ def create_dataset(data_slice=None):
def train():
dataloader = DataLoader(
dataset=partial(create_dataset, data_slice=None),
- batch_size=8,
+ batch_size=args.training.batch_size,
device_mesh=device_mesh,
)
model = TransformersModel(
- model_id=MODEL_ID,
+ model_id=args.model.model_id,
device_mesh=device_mesh,
- strategy='native_fsdp',
+ strategy=args.model.strategy,
)
# npu patch
if Torch.is_npu_available():
model = kernelize_model(model, mode='train', device='npu')
- lora_config = LoraConfig(target_modules='all-linear')
- model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1)
- model.set_optimizer('AdamW', lr=1e-4, adapter_name='default')
+ lora_config = LoraConfig(**args.get_lora_args())
+ model.add_adapter_to_model(args.lora.adapter_name, lora_config,
+ gradient_accumulation_steps=args.training.gradient_accumulation_steps)
+ model.set_optimizer(args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate,
+ adapter_name=args.lora.adapter_name)
model.set_lr_scheduler(
- scheduler_cls='CosineWarmupScheduler',
- num_warmup_steps=5,
+ scheduler_cls=args.scheduler.scheduler_cls,
+ num_warmup_steps=args.scheduler.num_warmup_steps,
num_training_steps=len(dataloader),
- adapter_name='default',
+ adapter_name=args.lora.adapter_name,
)
- logger.info(model.get_train_configs(adapter_name='default'))
+ logger.info(model.get_train_configs(adapter_name=args.lora.adapter_name))
logger.info(f'Total steps: {len(dataloader)}')
for step, batch in enumerate(dataloader):
- model.forward_backward(inputs=batch, adapter_name='default')
- model.clip_grad_and_step(adapter_name='default')
- if step % 20 == 0:
- metric = model.calculate_metric(is_training=True, adapter_name='default')
+ model.forward_backward(inputs=batch, adapter_name=args.lora.adapter_name)
+ model.clip_grad_and_step(adapter_name=args.lora.adapter_name)
+ if step % args.training.log_interval == 0:
+ metric = model.calculate_metric(is_training=True, adapter_name=args.lora.adapter_name)
logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
model.save('last-checkpoint', interval=1)
diff --git a/cookbook/transformers/sp_fsdp_dense.sh b/cookbook/transformers/sp_fsdp_dense.sh
index 2a8bcf08b..841561fe4 100644
--- a/cookbook/transformers/sp_fsdp_dense.sh
+++ b/cookbook/transformers/sp_fsdp_dense.sh
@@ -1,11 +1,24 @@
-#!/bin/bash
-# To enable Transformers sequence parallelism, please set ulysses_size > 1.
-# ulysses_size is interpreted as the total sequence-parallel degree.
-# device_mesh = DeviceMesh(
-# device_type="cuda",
-# mesh=np.arange(4).reshape(2, 2),
-# mesh_dim_names=("dp", "fsdp"),
-# ulysses_size=2,
-# )
-#
-CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 sp_fsdp_dense.py
+#!/usr/bin/env bash
+set -euo pipefail
+
+# FSDP + Sequence Parallelism training.
+# To enable Transformers sequence parallelism, pass --ulysses-size with a value greater than 1.
+# All training config passed as CLI flags. Override at invocation, e.g.:
+# bash sp_fsdp_dense.sh --model-id ms://Qwen/Qwen3.5-4B --ulysses-size 4
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+ torchrun --nproc_per_node=4 sp_fsdp_dense.py \
+ --model-id ms://Qwen/Qwen3.5-4B \
+ --dataset-id ms://swift/self-cognition \
+ --template-cls Qwen3_5Template \
+ --dp-size 2 \
+ --fsdp-size 2 \
+ --ulysses-size 2 \
+ --batch-size 4 \
+ --lr 1e-4 \
+ --gradient-accumulation-steps 2 \
+ --train-samples 500 \
+ --log-interval 10 \
+ --model-name twinkle模型 \
+ --model-author twinkle团队 \
+ "$@"
diff --git a/docs/source_en/Components/Agentic/Envs.md b/docs/source_en/Components/Agentic/Envs.md
new file mode 100644
index 000000000..3e2e90b3a
--- /dev/null
+++ b/docs/source_en/Components/Agentic/Envs.md
@@ -0,0 +1,183 @@
+# Environments (Envs)
+
+The Envs module provides an RL execution environment abstraction for agentic training. Environments can participate in multi-turn rollouts interactively or evaluate completed trajectories in batch.
+
+## Env Base Class
+
+```python
+from twinkle_agentic.envs.base import Env, StepResult
+
+class Env(ABC):
+
+ def reset(self, trajectory=None) -> StepResult:
+ """Reset for a new episode."""
+
+ @abstractmethod
+ def step(self, tool_name: str, arguments: dict) -> StepResult:
+ """Execute a single action, return observation + reward + done."""
+
+ def tools(self) -> List[ToolInfo]:
+ """Return tool definitions available in this environment."""
+
+ def evaluate(self, trajectories, **kwargs) -> List[float]:
+ """Batch-evaluate completed trajectories, return rewards."""
+
+ def close(self) -> None:
+ """Release resources."""
+```
+
+### StepResult
+
+```python
+@dataclass
+class StepResult:
+ observation: str = '' # Environment observation after the action
+ reward: float = 0.0 # Scalar reward for this step
+ done: bool = False # Whether the episode is terminated
+ info: Dict[str, Any] = field(default_factory=dict) # Extra metadata
+```
+
+### Two Usage Modes
+
+1. **Interactive mode** (multi-turn rollout) — step-by-step execution:
+
+```python
+env = MyEnv()
+env.reset(trajectory)
+result = env.step('search', {'query': 'Python'})
+# ... repeat until result.done
+```
+
+2. **Batch evaluation mode** — evaluate completed trajectories:
+
+```python
+rewards = env.evaluate(completed_trajectories)
+```
+
+## EnvTool
+
+`EnvTool` wraps an `Env` as a `Tool`, bridging the environment with `ToolManager` and `MultiTurnRollout`.
+
+```python
+from twinkle_agentic.envs.env_tool import EnvTool
+from twinkle_agentic.tools.tool_manager import ToolManager
+
+env = MyEnv()
+
+# Create one EnvTool per tool defined in the environment
+env_tools = EnvTool.from_env(env)
+
+# Register into ToolManager
+manager = ToolManager(env_tools)
+```
+
+### Key Features
+
+| Feature | Description |
+|---------|-------------|
+| `from_env(env)` | Factory: creates one `EnvTool` per tool in `env.tools()`. |
+| `last_result` | Stores the most recent `StepResult` for inspection. |
+| `done` | Property: whether the last step terminated the episode. |
+| `episode_reward` | Property: cumulative reward from `info['episode_reward']`. |
+
+### Manual Construction
+
+```python
+env_tool = EnvTool(
+ env=my_env,
+ tool_name='execute_code',
+ description='Execute Python code in a sandbox.',
+ parameters={
+ 'type': 'object',
+ 'properties': {
+ 'code': {'type': 'string', 'description': 'Python code to execute.'},
+ },
+ 'required': ['code'],
+ },
+)
+```
+
+## OpenEnv
+
+`OpenEnv` adapts an [OpenEnv](https://github.com/OpenEnv) WebSocket-based environment server as a synchronous Twinkle `Env`.
+
+```python
+from twinkle_agentic.envs.openenv import OpenEnv
+
+env = OpenEnv(
+ base_url='http://localhost:8000',
+ env_cls='coding_env.CodingEnv', # Optional typed client
+ env_kwargs={'message_timeout_s': 30},
+ tool_schema=[...], # Optional tool definitions
+)
+```
+
+### Parameters
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| `base_url` | `str` | URL of the running OpenEnv server. |
+| `env_cls` | `str` or class | Dotted import path or class for a typed client. `None` uses `GenericEnvClient`. |
+| `env_kwargs` | `Dict` | Extra kwargs for the client constructor. |
+| `tool_schema` | `List[ToolInfo]` | Tool definitions exposed via `tools()`. |
+| `action_mapper` | `Callable` | Custom function to map `(tool_name, args)` to the action dict sent to the server. |
+
+### Usage with Rollout
+
+```python
+from twinkle_agentic.envs.openenv import OpenEnv
+from twinkle_agentic.envs.env_tool import EnvTool
+from twinkle_agentic.tools.tool_manager import ToolManager
+from twinkle_agentic.rollout.api_multi_turn import APIMultiTurnRollout
+
+# Set up environment
+env = OpenEnv(base_url='http://localhost:8000', tool_schema=[...])
+env.reset()
+
+# Bridge to ToolManager
+env_tools = EnvTool.from_env(env)
+manager = ToolManager(env_tools)
+
+# Use in rollout
+rollout = APIMultiTurnRollout(api=api, tool_manager=manager, max_turns=10)
+results = rollout(trajectories)
+```
+
+### Implementing a Custom Environment
+
+```python
+from twinkle_agentic.envs.base import Env, StepResult
+
+class CodeExecutionEnv(Env):
+
+ def reset(self, trajectory=None):
+ self._sandbox = create_sandbox()
+ return StepResult(observation='Sandbox ready.')
+
+ def step(self, tool_name, arguments):
+ code = arguments.get('code', '')
+ output = self._sandbox.run(code)
+ return StepResult(
+ observation=output,
+ reward=1.0 if 'error' not in output.lower() else 0.0,
+ done=False,
+ )
+
+ def tools(self):
+ return [{
+ 'type': 'function',
+ 'function': {
+ 'name': 'execute_code',
+ 'description': 'Run Python code.',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'code': {'type': 'string'},
+ },
+ },
+ },
+ }]
+
+ def close(self):
+ self._sandbox.cleanup()
+```
diff --git a/docs/source_en/Components/Agentic/Multi-Turn-Tool-Usage.md b/docs/source_en/Components/Agentic/Multi-Turn-Tool-Usage.md
new file mode 100644
index 000000000..06a962a33
--- /dev/null
+++ b/docs/source_en/Components/Agentic/Multi-Turn-Tool-Usage.md
@@ -0,0 +1,205 @@
+# Multi-Turn Tool Usage Guide
+
+This guide shows how to set up and run multi-turn agentic rollouts with tool use in Twinkle.
+
+## Architecture Overview
+
+The agentic rollout pipeline consists of four key components:
+
+- **Tool** — implements a specific capability (search, code execution, etc.)
+- **ToolManager** — registers tools and dispatches LLM tool calls
+- **Env** (optional) — RL environment that exposes tools via `EnvTool`
+- **Rollout** — drives the multi-turn conversation loop
+
+## Quick Start: API-based Rollout
+
+The simplest way to run a multi-turn tool-use rollout is through an OpenAI-compatible API:
+
+```python
+from twinkle_agentic.protocol.openai import OpenAI
+from twinkle_agentic.tools.base import Tool
+from twinkle_agentic.tools.tool_manager import ToolManager
+from twinkle_agentic.rollout.api_multi_turn import APIMultiTurnRollout
+from twinkle.data_format.sampling import SamplingParams
+
+# 1. Define tools
+class WeatherTool(Tool):
+ def __call__(self, tool_name, arguments):
+ city = arguments.get('city', 'unknown')
+ return f'The weather in {city} is sunny, 25°C.'
+
+ def tool_info(self):
+ return {
+ 'type': 'function',
+ 'function': {
+ 'name': 'get_weather',
+ 'description': 'Get the current weather for a city.',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'city': {'type': 'string', 'description': 'City name.'},
+ },
+ 'required': ['city'],
+ },
+ },
+ }
+
+# 2. Set up ToolManager
+manager = ToolManager([WeatherTool()])
+
+# 3. Create API client
+api = OpenAI(model='qwen3.5-32b', base_url='http://localhost:8000/v1')
+
+# 4. Create rollout
+rollout = APIMultiTurnRollout(
+ api=api,
+ tool_manager=manager,
+ sampling_params=SamplingParams(temperature=0.7, max_tokens=2048),
+ max_turns=6,
+ concurrency=8,
+)
+
+# 5. Prepare trajectories
+trajectories = [
+ {
+ 'messages': [
+ {'role': 'user', 'content': "What's the weather like in Beijing?"},
+ ],
+ },
+]
+
+# 6. Run rollout
+results = rollout(trajectories)
+for r in results:
+ print(f"Turns: {r['turns']}, Stop: {r['stop_reason']}")
+ for msg in r['messages']:
+ print(f" [{msg['role']}] {msg.get('content', '')[:100]}")
+```
+
+## Training Integration: vLLM-based Rollout
+
+For RLHF training, use `MultiTurnRollout`, which produces `input_ids` and `labels`:
+
+```python
+from twinkle_agentic.rollout.multi_turn import MultiTurnRollout
+from twinkle.data_format.sampling import SamplingParams
+
+rollout = MultiTurnRollout(
+ sampler=vllm_sampler, # vLLMSampler instance
+ template=template, # Chat template
+ tool_manager=manager,
+ sampling_params=SamplingParams(temperature=0.7, max_tokens=4096),
+ max_turns=6,
+ max_trajectory_tokens=8192,
+ trace_dir='rollout_traces/',
+)
+
+# In GRPO training loop
+results = rollout(batch_trajectories)
+# results contain input_ids, labels, logprobs for training
+```
+
+## Using Environments as Tools
+
+Bridge an RL environment into the tool pipeline:
+
+```python
+from twinkle_agentic.envs.base import Env, StepResult
+from twinkle_agentic.envs.env_tool import EnvTool
+from twinkle_agentic.tools.tool_manager import ToolManager
+
+# Define environment
+class CodeEnv(Env):
+ def step(self, tool_name, arguments):
+ code = arguments.get('code', '')
+ # Execute code in sandbox
+ result = execute_in_sandbox(code)
+ return StepResult(observation=result, reward=1.0, done=False)
+
+ def tools(self):
+ return [{
+ 'type': 'function',
+ 'function': {
+ 'name': 'run_python',
+ 'description': 'Execute Python code.',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'code': {'type': 'string'},
+ },
+ 'required': ['code'],
+ },
+ },
+ }]
+
+# Bridge Env -> Tool -> ToolManager
+env = CodeEnv()
+env_tools = EnvTool.from_env(env)
+manager = ToolManager(env_tools)
+
+# Use manager in rollout as usual
+rollout = APIMultiTurnRollout(api=api, tool_manager=manager, max_turns=10)
+```
+
+## Using OpenEnv Environments
+
+Connect to a remote OpenEnv WebSocket server:
+
+```python
+from twinkle_agentic.envs.openenv import OpenEnv
+from twinkle_agentic.envs.env_tool import EnvTool
+
+env = OpenEnv(
+ base_url='http://localhost:8000',
+ env_cls='coding_env.CodingEnv',
+ tool_schema=[{
+ 'type': 'function',
+ 'function': {
+ 'name': 'submit',
+ 'description': 'Submit code solution.',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'code': {'type': 'string'},
+ },
+ },
+ },
+ }],
+)
+
+env.reset()
+env_tools = EnvTool.from_env(env)
+manager = ToolManager(env_tools)
+```
+
+## Per-Trajectory Tool Managers
+
+When each trajectory needs its own tool set (e.g., trajectory-bound state), pass one `ToolManager` per trajectory:
+
+```python
+# Create per-trajectory managers
+managers = []
+for traj in trajectories:
+ env = create_env_for(traj)
+ env_tools = EnvTool.from_env(env)
+ managers.append(ToolManager(env_tools))
+
+# Pass as a list (aligned 1:1 with trajectories)
+results = rollout(trajectories, tool_manager=managers)
+```
+
+## Trace Debugging
+
+Both rollout implementations support trace dumps for debugging:
+
+```python
+rollout = APIMultiTurnRollout(
+ api=api,
+ tool_manager=manager,
+ trace_dir='traces/',
+ trace_callback=lambda t: t['turns'] > 1, # Only store multi-turn
+ success_callback=lambda t: t.get('stop_reason') == 'stop',
+)
+```
+
+Trace files are saved as `{step}-{ok|fail}-{id}.json` with the full conversation and metadata.
diff --git a/docs/source_en/Components/Agentic/Preprocessor.md b/docs/source_en/Components/Agentic/Preprocessor.md
new file mode 100644
index 000000000..b646739c0
--- /dev/null
+++ b/docs/source_en/Components/Agentic/Preprocessor.md
@@ -0,0 +1,189 @@
+# Agentic Preprocessor
+
+The agentic preprocessor module provides a pipeline-based data quality filtering framework for multi-turn conversation datasets. It is designed for cleaning and filtering training data before RLHF / agentic fine-tuning.
+
+## QualityPreprocessor
+
+`QualityPreprocessor` is a thin pipeline runner that accepts a list of filter callables and runs them in sequence. Each step receives a list of rows, returns `(kept, dropped)`, and the pipeline logs per-step statistics.
+
+```python
+from twinkle_agentic.preprocessor import QualityPreprocessor, HardFilter, DeadLoopFilter
+
+pipeline = [
+ HardFilter(min_user_chars=10),
+ DeadLoopFilter(),
+]
+preprocessor = QualityPreprocessor(pipeline, dropped_log_path='dropped.jsonl')
+
+# rows is a dict of columns (Dataset.map format)
+cleaned = preprocessor(rows)
+```
+
+### Parameters
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| `pipeline` | `List[Callable]` | Ordered list of filter steps. Each step takes `List[Dict]` and returns `(kept, dropped)`. |
+| `dropped_log_path` | `str` | Optional JSONL file path for logging dropped rows with step name and reason. |
+
+## Built-in Filters
+
+### HardFilter
+
+Rule-based filter that removes trivially bad rows using deterministic rules. Supports multi-language detection (EN/ZH/JA/KO).
+
+```python
+from twinkle_agentic.preprocessor import HardFilter
+
+f = HardFilter(
+ min_user_chars=10, # Min chars for non-CJK user query
+ min_user_chars_cjk=6, # Min chars for CJK user query
+ min_assistant_chars_2turn=80, # Min assistant reply length (2-turn)
+ min_thinking_chars=200, # Min thinking chain length to exempt
+ system_deny_keywords=['hack', 'exploit'],
+ max_chars_per_round=50000,
+ max_total_chars=200000,
+ max_rounds=50,
+)
+```
+
+**Drop reasons:** `trivial_single_turn`, `shallow_reply`, `all_empty_assistant`, `system_deny_keyword`, `round_too_long`, `total_too_long`, `too_many_rounds`
+
+### DeadLoopFilter
+
+Detects assistant messages exhibiting hesitation or dead-loop patterns — repetitive self-corrections, cascading corrections, and high n-gram repetition.
+
+```python
+from twinkle_agentic.preprocessor import DeadLoopFilter
+
+f = DeadLoopFilter(
+ hesitation_density_threshold=7.0, # Markers per 1000 chars (response)
+ cascade_threshold=5, # Cascade markers in window
+ cascade_window=800, # Window size in chars
+ repetition_threshold=0.45, # N-gram repetition ratio
+ think_hesitation_density_threshold=15.0, # Laxer for <think> blocks
+ think_repetition_threshold=0.65,
+)
+```
+
+Uses separate threshold profiles for `<think>` reasoning blocks (laxer, free to ramble) and the visible response (stricter).
+
+### DedupFilter
+
+Global longest-wins deduplication. The signature is derived from the first real user turn (head+tail) and the first assistant reply.
+
+```python
+from twinkle_agentic.preprocessor import DedupFilter
+
+f = DedupFilter(prefix_chars=100, asst_chars=100)
+kept, dropped = f(all_rows) # Must see entire dataset in one call
+```
+
+> **Note:** `DedupFilter` requires the full dataset in a single call. Do **not** place it inside `QualityPreprocessor` (which processes per-batch). Run it separately before or after the pipeline.
+
+### RefuseFilter
+
+Detects self-referential refusals in the first assistant reply (e.g., "I cannot help with that"). Multi-language pattern matching (EN/ZH/JA/KO).
+
+```python
+from twinkle_agentic.preprocessor import RefuseFilter
+
+f = RefuseFilter(check_window=600) # Only check first N chars
+```
+
+### TokenSoupFilter
+
+Detects garbled / token-soup output by checking for replacement characters, control characters, private-use Unicode, leaked special tokens, single-character repetition, and script chaos.
+
+```python
+from twinkle_agentic.preprocessor import TokenSoupFilter
+
+f = TokenSoupFilter(
+ replacement_char_ratio=0.02,
+ special_token_count=20,
+ script_chaos_threshold=0.55,
+)
+```
+
+### PIIPresidioFilter
+
+Multi-language PII detection and rewriting using Microsoft Presidio + spaCy NER + Faker. Detects and replaces personally identifiable information (names, emails, phone numbers, addresses, etc.).
+
+```python
+from twinkle_agentic.preprocessor import PIIPresidioFilter
+
+f = PIIPresidioFilter(languages=['en', 'zh'])
+```
+
+### IntentClassifier
+
+Heuristic intent classifier that tags each row with detected intents. Pluggable detector pipeline.
+
+```python
+from twinkle_agentic.preprocessor import IntentClassifier
+
+classifier = IntentClassifier()
+```
+
+**Intent categories:** `tool_call`, `code`, `math`, `complex_logic`, `reasoning`, `user_dissatisfaction`, `other`
+
+### ScoreFilter
+
+Pluggable scorer-based filter with built-in scorers for character-level metrics, semantic similarity, and code execution.
+
+```python
+from twinkle_agentic.preprocessor import ScoreFilter
+
+f = ScoreFilter()
+```
+
+**Built-in scorers:** `ChrMinScorer`, `SIFDScorer`, `PassNScorer`, `ParaphraseScorer`
+
+### ModelFilter
+
+Filters rows by model ID whitelist.
+
+```python
+from twinkle_agentic.preprocessor import ModelFilter
+
+f = ModelFilter(allowed_models=['qwen3.5-4b', 'qwen3.5-32b'])
+```
+
+### MessageNormalizer
+
+Three-pass message normalization: heartbeat stripping, tool-call rewriting, and consecutive same-role message merging.
+
+```python
+from twinkle_agentic.preprocessor import MessageNormalizer
+
+normalizer = MessageNormalizer()
+```
+
+## Complete Pipeline Example
+
+```python
+from twinkle_agentic.preprocessor import (
+ QualityPreprocessor,
+ HardFilter,
+ DeadLoopFilter,
+ RefuseFilter,
+ TokenSoupFilter,
+ MessageNormalizer,
+ DedupFilter,
+)
+
+# Step 1: Global dedup (must run on full dataset)
+dedup = DedupFilter()
+rows, _ = dedup(all_rows)
+
+# Step 2: Per-batch pipeline
+pipeline = [
+ HardFilter(min_user_chars=10, max_rounds=30),
+ DeadLoopFilter(),
+ RefuseFilter(),
+ TokenSoupFilter(),
+ MessageNormalizer(),
+]
+preprocessor = QualityPreprocessor(pipeline, dropped_log_path='dropped.jsonl')
+cleaned = preprocessor(rows)
+```
diff --git a/docs/source_en/Components/Agentic/Protocol.md b/docs/source_en/Components/Agentic/Protocol.md
new file mode 100644
index 000000000..a44c5c3a2
--- /dev/null
+++ b/docs/source_en/Components/Agentic/Protocol.md
@@ -0,0 +1,91 @@
+# Protocol
+
+The Protocol module provides an abstract LLM API client interface and its OpenAI-compatible implementation. It bridges Twinkle's `Trajectory` / `SamplingParams` data types with external LLM inference services.
+
+## API Base Class
+
+```python
+from abc import ABC, abstractmethod
+from twinkle.data_format import Trajectory
+from twinkle.data_format.message import Message
+from twinkle.data_format.sampling import SamplingParams
+
+class API(ABC):
+ """Abstract LLM API client: Trajectory + SamplingParams -> assistant Message(s)."""
+
+ @abstractmethod
+ def __call__(
+ self,
+ trajectory: Trajectory,
+ sampling_params: SamplingParams,
+ **kwargs,
+ ) -> Union[Message, List[Message]]:
+ raise NotImplementedError()
+```
+
+The `API` class defines a simple contract: given a conversation trajectory and sampling parameters, return one or more assistant messages.
+
+## OpenAI
+
+`OpenAI` is the built-in implementation that works with any endpoint speaking the `/v1/chat/completions` protocol (OpenAI, Azure OpenAI, vLLM, SGLang, Ollama, etc.).
+
+```python
+from twinkle_agentic.protocol.openai import OpenAI
+
+api = OpenAI(
+ model='qwen3.5-32b',
+ base_url='http://localhost:8000/v1',
+ api_key='EMPTY',
+)
+```
+
+### Parameters
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| `model` | `str` | Model name to pass in the API request. |
+| `api_key` | `str` | API key. Defaults to the `OPENAI_API_KEY` environment variable. |
+| `base_url` | `str` | Base URL of the API endpoint (e.g. `http://localhost:8000/v1`). |
+| `client_kwargs` | `Dict` | Extra keyword arguments forwarded to the `openai.OpenAI` client constructor. |
+
+### Usage
+
+```python
+from twinkle.data_format import Trajectory
+from twinkle.data_format.sampling import SamplingParams
+
+trajectory = {
+ 'messages': [
+ {'role': 'user', 'content': 'What is the capital of France?'},
+ ]
+}
+
+sp = SamplingParams(temperature=0.7, max_tokens=512)
+reply = api(trajectory, sp)
+# reply is a Message dict: {'role': 'assistant', 'content': '...'}
+```
+
+### Features
+
+- **Tool calls**: Automatically maps `trajectory['tools']` to the API request and parses structured `tool_calls` from the response.
+- **Reasoning content**: Preserves `reasoning_content` from models that support it (e.g., o1-style reasoning).
+- **Finish reason**: Surfaces `finish_reason` on the returned message so multi-turn drivers can detect length-cap truncation.
+- **Multi-sample**: When `sampling_params.num_samples > 1`, returns a list of messages (one per choice).
+
+### Custom API Client
+
+To integrate a non-OpenAI API, subclass `API`:
+
+```python
+from twinkle_agentic.protocol.base import API
+
+class MyCustomAPI(API):
+
+ def __call__(self, trajectory, sampling_params, **kwargs):
+ # Call your custom endpoint
+ response = my_llm_client.chat(
+ messages=trajectory['messages'],
+ temperature=sampling_params.temperature,
+ )
+ return {'role': 'assistant', 'content': response.text}
+```
diff --git a/docs/source_en/Components/Agentic/Rollout.md b/docs/source_en/Components/Agentic/Rollout.md
new file mode 100644
index 000000000..94b143454
--- /dev/null
+++ b/docs/source_en/Components/Agentic/Rollout.md
@@ -0,0 +1,140 @@
+# Multi-Turn Rollout
+
+The Rollout module provides multi-turn conversation rollout engines for agentic RLHF training. Two implementations are available: `MultiTurnRollout` for batched vLLM sampling and `APIMultiTurnRollout` for OpenAI-compatible API endpoints.
+
+## Rollout Base Class
+
+```python
+from abc import ABC, abstractmethod
+from twinkle.data_format import Trajectory
+
+class Rollout(ABC):
+
+ @abstractmethod
+ def __call__(self, trajectories: List[Trajectory], **kwargs) -> List[Trajectory]:
+ raise NotImplementedError()
+```
+
+All rollouts accept a list of trajectories and return the same number of trajectories with additional fields (`messages`, `turns`, `stop_reason`, `truncated`).
+
+## MultiTurnRollout
+
+Batched multi-turn rollout engine that uses a vLLM sampler for generation. All active trajectories are sampled in a single batched call per turn for maximum throughput.
+
+### Per-turn Loop
+
+1. Encode each trajectory into an `InputFeature` with a generation prompt
+2. Batch `sampler.sample(active_pifs)` — all live trajectories in parallel
+3. Check termination: `stop_reason == 'length'`, no tool calls, or max turns reached
+4. Dispatch tools via `ToolManager`, append tool responses
+5. Compute bridge tokens (tool turns + generation prompt) with `labels = -100`
+6. Repeat until all trajectories are done
+
+```python
+from twinkle_agentic.rollout.multi_turn import MultiTurnRollout
+from twinkle_agentic.tools.tool_manager import ToolManager
+from twinkle.data_format.sampling import SamplingParams
+
+rollout = MultiTurnRollout(
+ sampler=vllm_sampler,
+ template=template,
+ tool_manager=tool_manager,
+ sampling_params=SamplingParams(temperature=0.7, max_tokens=4096),
+ max_turns=6,
+ max_trajectory_tokens=8192,
+ trace_dir='rollout_traces/',
+)
+
+# Run rollout
+results = rollout(trajectories)
+```
+
+### Parameters
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| `sampler` | `Sampler` | vLLM sampler instance for batched generation. |
+| `template` | `Template` | Chat template for encoding/decoding. |
+| `tool_manager` | `ToolManager` | Tool dispatcher. Can also be passed per-call. |
+| `sampling_params` | `SamplingParams` | Default sampling parameters. |
+| `max_turns` | `int` | Maximum number of turns per trajectory (default: 6). |
+| `max_trajectory_tokens` | `int` | Max total token length; exceeding truncates the trajectory. |
+| `trace_dir` | `str` | Directory for per-trajectory JSON trace dumps. |
+| `trace_callback` | `Callable` | Decides whether to store a trajectory trace. |
+| `success_callback` | `Callable` | Decides whether the trace file is tagged `ok` or `fail` in its filename. |
+
+### Output Fields
+
+Each output trajectory dict includes:
+
+| Field | Type | Description |
+|-------|------|-------------|
+| `messages` | `List[Dict]` | Full conversation including tool turns. |
+| `input_ids` | `List[int]` | Token IDs of the full sequence. |
+| `labels` | `List[int]` | Training labels (`-100` for non-trainable tokens). |
+| `turns` | `int` | Number of turns performed. |
+| `stop_reason` | `str` | `'stop'` / `'length'` |
+| `truncated` | `bool` | Whether the trajectory was truncated. |
+| `logprobs` | `List` | Per-token log probabilities (if available). |
+
+### Ray Remote Support
+
+`MultiTurnRollout` is decorated with `@remote_class()`, enabling transparent deployment as a Ray actor:
+
+```python
+# The rollout can run as a Ray remote actor
+rollout_actor = MultiTurnRollout.remote(sampler=sampler, template=template, ...)
+results = ray.get(rollout_actor.__call__.remote(trajectories))
+```
+
+## APIMultiTurnRollout
+
+Multi-turn rollout over an OpenAI-compatible chat-completions API. Each trajectory runs independently in a thread pool for network concurrency.
+
+```python
+from twinkle_agentic.rollout.api_multi_turn import APIMultiTurnRollout
+from twinkle_agentic.protocol.openai import OpenAI
+
+api = OpenAI(model='qwen3.5-32b', base_url='http://localhost:8000/v1')
+
+rollout = APIMultiTurnRollout(
+ api=api,
+ tool_manager=tool_manager,
+ sampling_params=SamplingParams(temperature=0.7),
+ max_turns=6,
+ concurrency=8,
+ trace_dir='api_traces/',
+)
+
+results = rollout(trajectories)
+```
+
+### Parameters
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| `api` | `OpenAI` | OpenAI-compatible API client. |
+| `tool_manager` | `ToolManager` | Tool dispatcher (single or per-trajectory list). |
+| `sampling_params` | `SamplingParams` | Default sampling parameters. |
+| `max_turns` | `int` | Maximum turns per trajectory (default: 6). |
+| `concurrency` | `int` | Thread pool size for parallel API calls (default: 8). |
+| `extra_body` | `Dict` | Extra fields to include in API requests. |
+| `trace_dir` | `str` | Directory for trace dumps. |
+
+### Stop Reasons
+
+| Reason | Description |
+|--------|-------------|
+| `stop` | Assistant responded without tool calls (natural end). |
+| `length` | API returned `finish_reason='length'` (token limit). |
+| `max_turns` | Reached `max_turns` limit. |
+| `api_error` | API call or tool execution raised an exception. |
+
+## Choosing Between Rollouts
+
+| Feature | MultiTurnRollout | APIMultiTurnRollout |
+|---------|-----------------|---------------------|
+| **Backend** | vLLM sampler (local GPU) | OpenAI-compatible API |
+| **Training integration** | Produces `input_ids` / `labels` for GRPO | Messages only (for data collection) |
+| **Batching** | GPU-level batch parallelism | Network-level thread concurrency |
+| **Use case** | Online RLHF training loop | Offline data generation / evaluation |
diff --git a/docs/source_en/Components/Agentic/Tools.md b/docs/source_en/Components/Agentic/Tools.md
new file mode 100644
index 000000000..e53c7371f
--- /dev/null
+++ b/docs/source_en/Components/Agentic/Tools.md
@@ -0,0 +1,119 @@
+# Tools & ToolManager
+
+The Tools module provides an abstract tool interface and a central tool dispatcher (`ToolManager`) for agentic multi-turn rollouts. Tools follow the OpenAI function-calling schema for seamless integration with LLM tool-use capabilities.
+
+## Tool Base Class
+
+```python
+from abc import ABC, abstractmethod
+from twinkle.data_format import Tool as ToolInfo
+
+class Tool(ABC):
+
+ @abstractmethod
+ def __call__(self, tool_name: str, arguments: Dict[str, Any]) -> str:
+ """Execute the tool and return a string result."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def tool_info(self) -> ToolInfo:
+ """Return OpenAI-compatible tool schema."""
+ raise NotImplementedError
+```
+
+### Implementing a Custom Tool
+
+```python
+from twinkle_agentic.tools.base import Tool
+
+class SearchTool(Tool):
+
+ def __call__(self, tool_name: str, arguments: dict) -> str:
+ query = arguments.get('query', '')
+ # Perform search logic
+ return f'Search results for: {query}'
+
+ def tool_info(self):
+ return {
+ 'type': 'function',
+ 'function': {
+ 'name': 'search',
+ 'description': 'Search the web for information.',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'query': {
+ 'type': 'string',
+ 'description': 'The search query.',
+ },
+ },
+ 'required': ['query'],
+ },
+ },
+ }
+```
+
+## ToolManager
+
+`ToolManager` is a registry and dispatcher for tools. It resolves tool calls from the LLM's structured output and routes them to the correct tool implementation.
+
+```python
+from twinkle_agentic.tools.tool_manager import ToolManager
+
+# Initialize with a list of Tool instances
+manager = ToolManager([search_tool, calculator_tool])
+
+# Or with a dict
+manager = ToolManager({'search': search_tool, 'calc': calculator_tool})
+
+# Or register dynamically
+manager = ToolManager()
+manager.register(search_tool)
+manager.register(calculator_tool)
+```
+
+### Key Methods
+
+| Method | Description |
+|--------|-------------|
+| `register(tool)` | Register a tool (name extracted from `tool_info()`). |
+| `unregister(name)` | Remove a tool by name. |
+| `names()` | List all registered tool names. |
+| `copy()` | Create a shallow copy of the manager. |
+| `tool_infos()` | Return a list of all tool schemas (for API requests). |
+| `__call__(tool_call)` | Dispatch a tool call and return the result string. |
+
+### Dispatching Tool Calls
+
+`ToolManager` accepts OpenAI-shaped tool call dicts:
+
+```python
+tool_call = {
+ 'id': 'call_1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'search',
+ 'arguments': '{"query": "Python tutorials"}',
+ },
+}
+
+result = manager(tool_call)
+# result: 'Search results for: Python tutorials'
+```
+
+**Error handling:** If the tool name is unknown, arguments are invalid JSON, or the tool raises an exception, `ToolManager` returns a descriptive error string instead of raising — this keeps the rollout loop running.
+
+### Integration with Rollout
+
+```python
+from twinkle_agentic.rollout.multi_turn import MultiTurnRollout
+
+rollout = MultiTurnRollout(
+ sampler=sampler,
+ template=template,
+ tool_manager=manager, # Pass manager to rollout
+ max_turns=6,
+)
+```
+
+The rollout engine calls `manager(tool_call)` for each tool call generated by the model, and appends the result as a `{'role': 'tool', 'content': result}` message.
diff --git a/docs/source_en/Components/Agentic/index.rst b/docs/source_en/Components/Agentic/index.rst
new file mode 100644
index 000000000..802034366
--- /dev/null
+++ b/docs/source_en/Components/Agentic/index.rst
@@ -0,0 +1,11 @@
+Agentic
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Preprocessor.md
+ Protocol.md
+ Rollout.md
+ Tools.md
+ Envs.md
+ Multi-Turn-Tool-Usage.md
diff --git a/docs/source_en/Components/CLI/CLI.md b/docs/source_en/Components/CLI/CLI.md
new file mode 100644
index 000000000..9efe3c894
--- /dev/null
+++ b/docs/source_en/Components/CLI/CLI.md
@@ -0,0 +1,134 @@
+# CLI
+
+The CLI module provides a unified configuration system for Twinkle training scripts. It merges multiple configuration sources (environment variables, `.env` files, YAML configs, and command-line arguments) into a single `Args` dataclass with typed argument groups.
+
+## Resolution Order
+
+Configuration sources are applied in order (later wins):
+
+1. **Dataclass defaults** — sensible out-of-the-box values
+2. **`.env` file** — project-local overrides
+3. **Environment variables** — `TWINKLE_` prefix or bare keys
+4. **YAML config file** — `--config path/to/config.yaml`
+5. **CLI overrides** — `--key value` (highest priority)
+
+All keys are case-insensitive, and dashes and underscores are interchangeable (e.g. `--learning-rate` and `--learning_rate` refer to the same field).
+
+## Quick Start
+
+```python
+from twinkle.cli import CLI
+
+args = CLI.from_args()
+
+# Access typed groups
+print(args.model.model_id)
+print(args.training.max_steps)
+print(args.optimizer.learning_rate)
+
+# Or get dictionaries for component construction
+model_kwargs = args.get_model_args()
+optimizer_kwargs = args.get_optimizer_args()
+```
+
+## Argument Groups
+
+| Group | Class | Key Parameters |
+|:------|:------|:---------------|
+| model | `ModelArgs` | `model_id`, `mixed_precision`, `strategy`, `gradient_checkpointing` |
+| lora | `LoraArgs` | `use_lora`, `lora_r`, `lora_alpha`, `lora_target_modules` |
+| dataset | `DatasetArgs` | `dataset_id`, `subset_name`, `split`, `streaming` |
+| template | `TemplateArgs` | `template_cls`, `max_length`, `truncation_strategy`, `enable_thinking` |
+| training | `TrainingArgs` | `max_steps`, `batch_size`, `micro_batch_size`, `output_dir`, `save_steps` |
+| optimizer | `OptimizerArgs` | `optimizer_cls`, `learning_rate`, `weight_decay`, `max_grad_norm` |
+| scheduler | `SchedulerArgs` | `scheduler_cls`, `num_warmup_steps`, `t_max` |
+| loss | `LossArgs` | `loss_cls`, `epsilon`, `beta`, `sft_weight` |
+| sampler | `SamplerArgs` | `sampler_type`, `gpu_memory_utilization`, `tensor_parallel_size` |
+| sampling | `SamplingArgs` | `max_tokens`, `temperature`, `top_k`, `top_p`, `num_samples` |
+| infra | `InfraArgs` | `mode`, `nproc_per_node`, `model_gpus`, `sampler_gpus`, `dp_size` |
+| server | `ServerArgs` | `config`, `host`, `port`, `ray_namespace` |
+| rl | `RLArgs` | `num_generations`, `advantage_type`, `reward_fns` |
+| checkpoint | `CheckpointArgs` | `save_optimizer`, `merge_and_sync`, `platform` |
+
+## YAML Configuration
+
+```yaml
+# config.yaml
+model_id: ms://Qwen/Qwen3.5-4B
+mixed_precision: bf16
+strategy: accelerate
+
+use_lora: true
+lora_r: 16
+lora_alpha: 32
+
+dataset_id: ms://swift/self-cognition
+max_length: 4096
+
+batch_size: 8
+micro_batch_size: 2
+max_steps: 200
+learning_rate: 1e-5
+
+mode: ray
+nproc_per_node: 8
+model_gpus: 4
+sampler_gpus: 4
+```
+
+## Command-Line Usage
+
+```bash
+# Use with YAML config
+python train.py --config config.yaml
+
+# Override specific values
+python train.py --config config.yaml --learning_rate 5e-6 --max_steps 500
+
+# Boolean flags
+python train.py --use_lora --no_gradient_checkpointing
+
+# Without config file (all from CLI)
+python train.py --model_id ms://Qwen/Qwen3.5-4B --batch_size 4
+```
+
+## Environment Variables
+
+```bash
+# TWINKLE_ prefix
+export TWINKLE_MODEL_ID=ms://Qwen/Qwen3.5-4B
+export TWINKLE_LEARNING_RATE=1e-5
+
+# Or bare keys (when recognized)
+export MODEL_ID=ms://Qwen/Qwen3.5-4B
+```
+
+## Field Aliases
+
+Some fields support aliases for convenience:
+
+- `learning_rate` ↔ `lr`
+- `nproc_per_node` ↔ `num_gpus`
+- `max_tokens` ↔ `max_new_tokens`
+- `use_megatron=true` → `strategy=native_fsdp`
+
+## Custom Config Sources
+
+You can extend the CLI with custom configuration sources:
+
+```python
+from twinkle.cli.cli import ConfigSource, Args, ConfigResolver
+
+class RemoteConfigSource(ConfigSource):
+ def __init__(self, url: str):
+ self.url = url
+
+ def load(self) -> dict:
+ import requests
+ return requests.get(self.url).json()
+
+# Apply custom source
+args = Args()
+resolver = ConfigResolver(args)
+resolver.apply(RemoteConfigSource('http://config-server/my-config').load())
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/Gym/index.rst" b/docs/source_en/Components/CLI/index.rst
similarity index 76%
rename from "docs/source_zh/\347\273\204\344\273\266/Gym/index.rst"
rename to docs/source_en/Components/CLI/index.rst
index 85d941b97..cf59fa766 100644
--- "a/docs/source_zh/\347\273\204\344\273\266/Gym/index.rst"
+++ b/docs/source_en/Components/CLI/index.rst
@@ -1,6 +1,6 @@
-Gym
+CLI
===============
.. toctree::
:maxdepth: 1
- Gym.md
+ CLI.md
diff --git a/docs/source_en/Components/Gym/Gym.md b/docs/source_en/Components/Gym/Gym.md
deleted file mode 100644
index 4db355b8a..000000000
--- a/docs/source_en/Components/Gym/Gym.md
+++ /dev/null
@@ -1,26 +0,0 @@
-# Gym
-
-The Gym component provides an interface for reinforcement learning environments in Twinkle.
-
-```python
-from twinkle.gym import Gym
-
-class CustomGym(Gym):
-
- def step(self, trajectories, **kwargs):
- """
- Execute one RL step: evaluate trajectories and return rewards.
-
- Args:
- trajectories: Model-generated trajectories to evaluate
- **kwargs: Additional arguments
-
- Returns:
- Reward values for each trajectory
- """
- ...
-```
-
-The Gym abstraction allows you to plug in custom RL environments that interact with the training loop. It decouples reward computation and environment interaction from the core training logic.
-
-> Gym is typically used in on-policy RL training where the environment needs to provide feedback on model-generated outputs.
diff --git a/docs/source_en/Components/Loss/InfoNCELoss.md b/docs/source_en/Components/Loss/InfoNCELoss.md
new file mode 100644
index 000000000..1c6cada5d
--- /dev/null
+++ b/docs/source_en/Components/Loss/InfoNCELoss.md
@@ -0,0 +1,68 @@
+# InfoNCE Loss
+
+The `InfonceLoss` implements contrastive learning with in-batch negatives and optional cross-rank gathering. It is designed for embedding/retrieval model training.
+
+## Usage
+
+```python
+from twinkle.loss import InfonceLoss
+
+loss_fn = InfonceLoss(
+ temperature=0.1,
+ use_batch=True, # Enable in-batch negatives
+ hard_negatives=7, # Fix negative count per sample
+ mask_fake_negative=True, # Mask false negatives
+ fake_neg_margin=0.1, # Margin for false negative detection
+)
+
+model.set_loss(loss_fn)
+```
+
+## Input Format
+
+Each sample is laid out as `anchor(1) + positive(1) + negatives(n)` in a flat embedding tensor. `inputs['labels']` is a 1-D mask of the same length in which `1` marks the first element (the anchor) of each group and `0` marks every other position.
+
+```
+embeddings: [a0, p0, n0_1, n0_2, a1, p1, n1_1, n1_2, ...]
+labels: [ 1, 0, 0, 0, 1, 0, 0, 0, ...]
+```
+
+## Parameters
+
+| Parameter | Type | Default | Description |
+|:----------|:-----|:--------|:------------|
+| `temperature` | float | 0.1 | Logit scaling factor |
+| `use_batch` | bool | True | Use cross-sample in-batch negatives |
+| `hard_negatives` | int | None | Fix per-sample negative count (truncate/upsample) |
+| `mask_fake_negative` | bool | False | Mask logits > positive + margin |
+| `fake_neg_margin` | float | 0.1 | Threshold for false negative masking |
+| `include_qq` | bool | False | Add query-query similarity block |
+| `include_dd` | bool | False | Add doc-doc similarity block |
+
+## Cross-Rank Gathering
+
+When `use_batch=True` and distributed training is active, embeddings are gathered from all DP ranks to maximize in-batch negative diversity. Only the local shard retains gradients.
+
+## Similarity Blocks
+
+The loss supports three similarity blocks for comprehensive contrastive learning:
+
+- **Q→D (default)**: Query to all documents — primary contrastive signal
+- **Q→Q** (`include_qq=True`): Query to all other queries — prevents query collapse
+- **D→D** (`include_dd=True`): Document to all other documents — Qwen3-Embedding style
+
+## Example: Embedding Training
+
+```python
+from twinkle.loss import InfonceLoss
+from twinkle.metric import EmbeddingMetric
+
+# Configure model for embedding
+model.set_loss(InfonceLoss(temperature=0.05, use_batch=True, include_qq=True))
+model.set_metric(EmbeddingMetric(device_mesh=mesh, process_group=pg))
+
+# Training loop
+for batch in dataloader:
+ model.forward_backward(batch)
+ model.clip_grad_and_step()
+```
diff --git a/docs/source_en/Components/Loss/index.rst b/docs/source_en/Components/Loss/index.rst
index dceaf20f0..c6f3cb2c9 100644
--- a/docs/source_en/Components/Loss/index.rst
+++ b/docs/source_en/Components/Loss/index.rst
@@ -3,6 +3,19 @@ Loss
.. toctree::
:maxdepth: 1
+ CrossEntropy.md
+ ChunkedCrossEntropy.md
+ DPOLoss.md
+ GKDLoss.md
+ GRPOLoss.md
+ InfoNCELoss.md
+ MSELoss.md
+ Building-Loss.md
+Loss
+===============
+.. toctree::
+ :maxdepth: 1
+
CrossEntropy.md
ChunkedCrossEntropy.md
DPOLoss.md
diff --git a/docs/source_en/Components/Metrics/EmbeddingMetric.md b/docs/source_en/Components/Metrics/EmbeddingMetric.md
new file mode 100644
index 000000000..e8cea06ba
--- /dev/null
+++ b/docs/source_en/Components/Metrics/EmbeddingMetric.md
@@ -0,0 +1,42 @@
+# EmbeddingMetric
+
+The `EmbeddingMetric` tracks embedding quality during contrastive (InfoNCE) training. It reports anchor-positive cosine similarity statistics and in-batch negative similarity.
+
+## Usage
+
+```python
+from twinkle.metric import EmbeddingMetric
+
+metric = EmbeddingMetric(device_mesh=device_mesh, process_group=process_group)
+
+# During training
+metric.accumulate(inputs, outputs)
+
+# At log interval
+results = metric.calculate()
+# results: {
+# 'pos_sim': '0.8523', # Mean anchor-positive cosine similarity
+# 'pos_sim_min': '0.7102', # Min across batch
+# 'pos_sim_max': '0.9451', # Max across batch
+# 'neg_sim': '0.2134', # Mean anchor-negative similarity
+# 'loss': '0.3412', # Average InfoNCE loss
+# 'grad_norm': '1.234567', # Gradient norm
+# }
+```
+
+## Reported Metrics
+
+| Metric | Description |
+|:-------|:------------|
+| `pos_sim` | Mean cosine similarity between anchors and their positives |
+| `pos_sim_min` | Minimum anchor-positive similarity in the batch |
+| `pos_sim_max` | Maximum anchor-positive similarity in the batch |
+| `neg_sim` | Mean similarity between anchors and other positives (in-batch negatives) |
+| `loss` | Average contrastive loss value |
+| `grad_norm` | Gradient norm (passed via kwargs) |
+
+## Cross-Rank Gathering
+
+`EmbeddingMetric` performs an `all_gather` to compute similarity statistics across all DP ranks, providing a global view of embedding quality even under data-parallel training.
+
+> This metric pairs with `InfonceLoss` for embedding/retrieval training tasks.
diff --git a/docs/source_en/Components/Metrics/GRPOMetric.md b/docs/source_en/Components/Metrics/GRPOMetric.md
new file mode 100644
index 000000000..cd4f11551
--- /dev/null
+++ b/docs/source_en/Components/Metrics/GRPOMetric.md
@@ -0,0 +1,66 @@
+# GRPOMetric
+
+The `GRPOMetric` tracks policy optimization diagnostics during GRPO training, including KL divergence, clipping rates, entropy, and log-probability statistics.
+
+## Usage
+
+```python
+from twinkle.metric import GRPOMetric
+
+metric = GRPOMetric(
+ device_mesh=device_mesh,
+ process_group=process_group,
+ epsilon=0.2, # PPO clip range
+ temperature=1.0, # Sampling temperature for logp rescaling
+ top_k_kl=10, # Track top-K high-KL tokens per step
+)
+
+# During training loop
+metric.accumulate(inputs, outputs, old_logps=old_logps, advantages=advantages)
+
+# At log interval
+results = metric.calculate()
+# results: {
+# 'train/policy_confidence': 0.85,
+# 'train/mean_new_logp': -1.23,
+# 'train/mean_old_logp': -1.30,
+# 'train/logp_diff_mean': 0.07,
+# 'train/approx_kl': 0.003,
+# 'train/token_kl_max': 0.15,
+# 'train/entropy': 2.1,
+# 'train/clip_ratio': 0.02,
+# 'train/clip_ratio_low': 0.01,
+# 'train/clip_ratio_high': 0.01,
+# }
+```
+
+## Reported Metrics
+
+| Metric | Description |
+|:-------|:------------|
+| `train/policy_confidence` | exp(mean_new_logp) — higher means model is more confident |
+| `train/mean_new_logp` | Average log-probability of generated tokens under current policy |
+| `train/mean_old_logp` | Average log-probability under the old policy (the `old_logps` passed to `accumulate`) |
+| `train/logp_diff_mean` | Mean (new - old) log-probability difference |
+| `train/approx_kl` | Schulman K3 estimator of KL(old \|\| new) |
+| `train/token_kl_max` | Maximum per-token KL across all ranks |
+| `train/token_ratio_max` | Maximum importance weight across all ranks |
+| `train/entropy` | Average token-level entropy |
+| `train/clip_ratio` | Fraction of tokens clipped (low + high) |
+| `train/clip_ratio_low` | Fraction clipped below (ratio < 1-ε, negative advantage) |
+| `train/clip_ratio_high` | Fraction clipped above (ratio > 1+ε, positive advantage) |
+
+## Variants
+
+- **`GSPOMetric`** — Computes clip rate at sequence level (geometric-mean ratio per sequence)
+- **`CISPOMetric`** — Unconditional clip rate (not gated by advantage sign)
+
+## Parameters
+
+| Parameter | Type | Default | Description |
+|:----------|:-----|:--------|:------------|
+| `epsilon` | float | 0.2 | Lower clip boundary |
+| `epsilon_high` | float | None | Upper clip boundary (defaults to epsilon) |
+| `temperature` | float | 1.0 | Rescale logps to T=1 before computing KL |
+| `top_k_kl` | int | 0 | If > 0, record top-K high-KL token details |
+| `ignore_index` | int | -100 | Label value to mask out |
diff --git a/docs/source_en/Components/Metrics/index.rst b/docs/source_en/Components/Metrics/index.rst
index 5d50e183b..68215482d 100644
--- a/docs/source_en/Components/Metrics/index.rst
+++ b/docs/source_en/Components/Metrics/index.rst
@@ -8,4 +8,6 @@ Metrics
Accuracy.md
CompletionRewardMetric.md
DPOMetric.md
+ GRPOMetric.md
+ EmbeddingMetric.md
Building-Metrics.md
diff --git a/docs/source_en/Components/Model/MultiLoraTransformersModel.md b/docs/source_en/Components/Model/MultiLoraTransformersModel.md
index c196f900b..0c78a6886 100644
--- a/docs/source_en/Components/Model/MultiLoraTransformersModel.md
+++ b/docs/source_en/Components/Model/MultiLoraTransformersModel.md
@@ -30,3 +30,48 @@ The reason for the existence of max_loras and max_r parameters is that Twinkle's
Because of this, the user's r must be less than or equal to the max_r configuration. During actual training, only part of the lora's rank will be used in the calculation.
MultiLoraTransformersModel supports the `@remote_class` annotation and supports device_mesh, which means it can run in Ray workers.
+
+## Tenant Lifecycle
+
+Under the hood, `MultiLoraTransformersModel` uses the `MultiLora` manager to handle tenant LoRA slots. The key APIs:
+
+### acquire_lora
+
+Claim an available LoRA slot for a tenant:
+
+```python
+adapter_name = model.multi_lora.acquire_lora('tenant_a', LoraConfig(r=16, lora_alpha=32))
+```
+
+- Raises `RuntimeError` if all slots are in use or `config.r > max_r`
+
+### release_lora
+
+Release a tenant's LoRA slot, resetting weights to initial state:
+
+```python
+model.multi_lora.release_lora('tenant_a')
+```
+
+### Context Manager
+
+Use `adapter()` for scoped activation:
+
+```python
+with model.multi_lora.adapter('tenant_a') as name:
+ output = model.forward(inputs)
+```
+
+### LoraTenant
+
+Each slot is tracked as a `LoraTenant` dataclass:
+
+```python
+@dataclass
+class LoraTenant:
+ index: int # Slot index (0..max_loras-1)
+ adapter_name: str # Internal name (e.g. "lora_0")
+ config: LoraConfig # Pre-allocated config (max_r)
+ tenant_adapter_name: str # User-facing tenant name (None if free)
+ tenant_config: LoraConfig # Tenant's actual config (None if free)
+```
diff --git a/docs/source_en/Components/Model/SupportedModels.md b/docs/source_en/Components/Model/SupportedModels.md
new file mode 100644
index 000000000..7cd9e8b4d
--- /dev/null
+++ b/docs/source_en/Components/Model/SupportedModels.md
@@ -0,0 +1,79 @@
+# Supported Models
+
+Twinkle supports any model compatible with HuggingFace Transformers or Megatron-LM. Below is a curated list of models tested with Twinkle.
+
+## Language Models
+
+| Model Family | Model IDs | Parameters | Features |
+|:-------------|:----------|:-----------|:---------|
+| Qwen 3.5 | `Qwen/Qwen3.5-0.6B` ~ `Qwen/Qwen3.5-235B-A22B` | 0.6B–235B | MoE, Thinking mode |
+| Qwen 2.5 | `Qwen/Qwen2.5-0.5B` ~ `Qwen/Qwen2.5-72B` | 0.5B–72B | Dense |
+| DeepSeek V4 | `deepseek-ai/DeepSeek-V4` | 685B MoE | Custom DSML encoding |
+| DeepSeek R1 | `deepseek-ai/DeepSeek-R1` | 685B MoE | Reasoning |
+| LLaMA 3 | `meta-llama/Llama-3.3-70B-Instruct` | 8B–70B | Dense |
+| Mistral | `mistralai/Mistral-7B-v0.3` | 7B | Dense |
+| Yi | `01-ai/Yi-1.5-34B` | 6B–34B | Dense |
+| GLM-4 | `THUDM/glm-4-9b-chat` | 9B | Dense |
+| InternLM 2.5 | `internlm/internlm2_5-7b-chat` | 7B–20B | Dense |
+
+## Vision-Language Models
+
+| Model Family | Model IDs | Features |
+|:-------------|:----------|:---------|
+| Qwen 3.5 VL | `Qwen/Qwen3.5-VL-3B` ~ `Qwen/Qwen3.5-VL-72B` | Image, Video |
+| Qwen 2.5 VL | `Qwen/Qwen2.5-VL-7B-Instruct` | Image, Video |
+| InternVL 2.5 | `OpenGVLab/InternVL2_5-8B` | Image |
+
+## Embedding Models
+
+| Model Family | Model IDs | Training Method |
+|:-------------|:----------|:----------------|
+| Qwen3 Embedding | `Qwen/Qwen3-Embedding-0.6B` | InfoNCE contrastive |
+| GTE | `thenlper/gte-large-zh` | InfoNCE contrastive |
+
+## Model Loading
+
+Models can be loaded from ModelScope or HuggingFace:
+
+```python
+from twinkle.model import TransformersModel
+
+# From ModelScope (ms:// prefix)
+model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B')
+
+# From HuggingFace (hf:// prefix)
+model = TransformersModel(model_id='hf://meta-llama/Llama-3.3-70B-Instruct')
+
+# Local path
+model = TransformersModel(model_id='/path/to/model')
+```
+
+## Framework Support
+
+| Framework | Class | Use Case |
+|:----------|:------|:---------|
+| Transformers | `TransformersModel` | General training (SFT, RLHF, DPO) |
+| Transformers + Multi-LoRA | `MultiLoraTransformersModel` | Multi-tenant training |
+| Megatron-LM | `MegatronModel` | Large-scale distributed pre-training |
+| Megatron + Multi-LoRA | `MultiLoraMegatronModel` | Large-scale multi-tenant |
+
+## Precision Support
+
+| Mode | Description |
+|:-----|:------------|
+| `bf16` | BFloat16 mixed precision (recommended for A100/H100) |
+| `fp16` | Float16 mixed precision (for older GPUs) |
+| `fp8` | FP8 precision (H100 with Transformer Engine) |
+| `no` | Full precision (debugging only) |
+
+## Parallelism Strategies
+
+| Strategy | Config Key | Description |
+|:---------|:-----------|:------------|
+| FSDP | `strategy=accelerate` | Accelerate-managed FSDP (default) |
+| Native FSDP | `strategy=native_fsdp` | PyTorch native FSDP |
+| Tensor Parallel | `tp_size` | Split layers across GPUs |
+| Pipeline Parallel | `pp_size` | Split model stages |
+| Data Parallel | `dp_size` | Replicate model, split data |
+| Sequence Parallel | `sequence_parallel` | Split long sequences |
+| Expert Parallel | `ep_size` | MoE expert distribution |
diff --git a/docs/source_en/Components/Model/index.rst b/docs/source_en/Components/Model/index.rst
index e0648f00f..4802cd0d3 100644
--- a/docs/source_en/Components/Model/index.rst
+++ b/docs/source_en/Components/Model/index.rst
@@ -8,3 +8,14 @@ Model
MultiLoraTransformersModel.md
MegatronModel.md
MultiLoraMegatronModel.md
+ SupportedModels.md
+Model
+===============
+.. toctree::
+ :maxdepth: 1
+
+ TwinkleModel.md
+ TransformersModel.md
+ MultiLoraTransformersModel.md
+ MegatronModel.md
+ MultiLoraMegatronModel.md
diff --git a/docs/source_en/Components/Notifier/Notifier.md b/docs/source_en/Components/Notifier/Notifier.md
new file mode 100644
index 000000000..c4e14b511
--- /dev/null
+++ b/docs/source_en/Components/Notifier/Notifier.md
@@ -0,0 +1,93 @@
+# Notifier
+
+The Notifier component provides a pluggable notification system for sending alerts during training. When exceptions occur or training events need attention, notifiers deliver messages to external channels (e.g., DingTalk webhooks).
+
+## Base Interface
+
+```python
+from twinkle.notifier import Notifier
+
+class Notifier:
+ def __call__(self, message: str):
+ """Send a notification message."""
+ ...
+
+ def to_dict(self) -> dict:
+ """Serialize for checkpoint/restore."""
+ ...
+
+ @classmethod
+ def from_dict(cls, data: dict) -> Notifier:
+ """Restore from serialized form."""
+ ...
+```
+
+## DingNotifier
+
+Sends notifications to DingTalk (钉钉) custom robot webhooks.
+
+```python
+from twinkle.notifier import DingNotifier
+
+notifier = DingNotifier(
+ ding_url='https://oapi.dingtalk.com/robot/send?access_token=xxx',
+ secret='SECxxxxxxx', # Optional: for signed robots
+ timeout=5.0,
+)
+
+# Send a message
+notifier("### Training Complete\n\n- Steps: 1000\n- Loss: 0.25")
+```
+
+**Parameters:**
+- `ding_url`: Full DingTalk webhook URL with access token
+- `secret`: Optional signing secret for signed-robot mode
+- `timeout`: HTTP request timeout in seconds (default: 5.0)
+
+Messages are sent in DingTalk **Markdown** format. The first heading line is extracted as the chat preview title.
+
+## Exception Notifications
+
+Twinkle provides automatic exception notification with deduplication:
+
+```python
+from twinkle.notifier.base import notify_exception
+
+# Automatically sends formatted exception info
+# Only one rank sends per unique exception (prevents flooding)
+try:
+ model.forward_backward(batch)
+except Exception as e:
+ notify_exception(notifier, context='forward_backward', exc=e, name='sft_train')
+```
+
+The notification includes:
+- Exception type and message
+- Full traceback
+- Runtime metadata (rank, PID, hostname)
+- Deduplication: only one notification per unique exception across all ranks
+
+## Custom Notifier
+
+Create custom notifiers by subclassing `Notifier`:
+
+```python
+from twinkle.notifier import Notifier
+
+class SlackNotifier(Notifier):
+ def __init__(self, webhook_url: str):
+ self.webhook_url = webhook_url
+
+ def __call__(self, message: str):
+ import requests
+ requests.post(self.webhook_url, json={'text': message})
+
+ def to_dict(self):
+ return {'class': 'SlackNotifier', 'webhook_url': self.webhook_url}
+
+ @classmethod
+ def _from_dict_impl(cls, data):
+ return cls(webhook_url=data['webhook_url'])
+```
+
+> Notifiers are registered automatically via `__init_subclass__`, so `Notifier.from_dict()` can restore any subclass by name.
diff --git a/docs/source_en/Components/Notifier/index.rst b/docs/source_en/Components/Notifier/index.rst
new file mode 100644
index 000000000..ff82117d4
--- /dev/null
+++ b/docs/source_en/Components/Notifier/index.rst
@@ -0,0 +1,6 @@
+Notifier
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Notifier.md
diff --git a/docs/source_en/Components/TUI/Auto-Research.md b/docs/source_en/Components/TUI/Auto-Research.md
new file mode 100644
index 000000000..bd27a9705
--- /dev/null
+++ b/docs/source_en/Components/TUI/Auto-Research.md
@@ -0,0 +1,313 @@
+# Auto-Research (TUI)
+
+Twinkle TUI is a terminal-based intelligent training assistant that lets you **control, monitor, and debug ML training through natural language**. It combines a chat-driven AI agent with real-time metrics visualization, log streaming, and an automated health monitor that can detect and fix training failures autonomously.
+
+## Architecture Overview
+
+```
+┌──────────────────────────────────────────────────────────┐
+│ TwinkleTUI (Textual App) │
+│ ┌──────────────────────────────────────────────────────┐ │
+│ │ StatusBar: state / run_id / model / step / progress │ │
+│ ├──────────────────────┬───────────────────────────────┤ │
+│ │ MetricsPanel │ LogPanel │ │
+│ │ (ASCII chart) │ (scrolling logs) │ │
+│ ├──────────────────────┤ │ │
+│ │ ChatPanel │ │ │
+│ │ (user <-> agent) │ │ │
+│ └──────────────────────┴───────────────────────────────┘ │
+│ │
+│ Background Services: │
+│ AgentLoop ─── LLM tool-calling loop │
+│ TrainingMonitor ─── periodic health check & auto-fix │
+│ MetricsPoller ─── incremental metrics reading │
+│ LogsPoller ─── incremental log tailing │
+│ SkillsLoader ─── async plugin loading │
+└──────────────────────────────────────────────────────────┘
+```
+
+## Installation & Launch
+
+TUI is part of the `twinkle-client` package:
+
+```bash
+pip install twinkle-client
+```
+
+### Command-Line Usage
+
+```bash
+# Basic launch (uses default local Ollama endpoint)
+twinkle-tui
+
+# Specify LLM backend
+twinkle-tui --llm-base-url http://localhost:11434/v1 --llm-model qwen3.5
+
+# Attach to an existing training run
+twinkle-tui --run-id my-grpo-run
+
+# Use a remote API (e.g., OpenAI-compatible)
+twinkle-tui --llm-base-url https://api.example.com/v1 --llm-api-key sk-xxx --llm-model gpt-4o
+
+# Enable debug logging
+twinkle-tui --verbose
+```
+
+Or run as a Python module:
+
+```bash
+python -m twinkle_client.tui
+```
+
+### CLI Options
+
+| Option | Env Var | Default | Description |
+|--------|---------|---------|-------------|
+| `--run-id`, `-r` | `TWINKLE_TUI_RUN_ID` | None | Attach to an existing training run |
+| `--llm-base-url` | `TWINKLE_LLM_BASE_URL` | `http://localhost:11434/v1` | LLM API base URL |
+| `--llm-model` | `TWINKLE_LLM_MODEL` | `qwen3.5` | LLM model name |
+| `--llm-api-key` | `TWINKLE_LLM_API_KEY` | `not-needed` | LLM API key |
+| `--verbose`, `-v` | `TWINKLE_TUI_VERBOSE` | `False` | Enable DEBUG logging |
+| `--version`, `-V` | — | — | Show version and exit |
+
+### Keyboard Shortcuts
+
+| Key | Action |
+|-----|--------|
+| `q` | Quit |
+| `Ctrl+P` | Toggle metrics panel |
+| `Ctrl+L` | Clear logs |
+
+## Chat Agent
+
+The core of TUI is an **LLM-powered tool-calling agent** (`AgentLoop`) that processes natural language commands through an OpenAI-compatible API. The agent maintains conversation history with automatic pruning (last 50 messages) and supports up to 10 tool-calling rounds per interaction.
+
+### What You Can Say
+
+**Training lifecycle:**
+- *"List my training runs"*
+- *"Start a new GRPO training with Qwen3.5-4B on gsm8k"*
+- *"Pause the current run"*
+- *"Resume training"*
+- *"Stop training"*
+
+**Server management:**
+- *"Start the server with Qwen3.5-4B and a Qwen3.5-72B sampler on 2 GPUs"*
+- *"Shut down the server"*
+- *"How many GPUs are available?"*
+
+**Monitoring & analysis:**
+- *"How is the training going?"*
+- *"Show me the reward-related metrics"*
+- *"Zoom into steps 100-200"*
+- *"Reset the chart view"*
+
+**Search:**
+- *"Search for math datasets"*
+- *"Find Qwen models on ModelScope"*
+
+### Available Tools
+
+The agent has access to 16 built-in tools:
+
+| Tool | Description |
+|------|-------------|
+| `list_training_runs` | List all training runs |
+| `get_training_status` | Get detailed status and recent metrics |
+| `start_server` | Start Ray cluster + Twinkle Server (idempotent) |
+| `shutdown_server` | Shut down server and release GPU resources |
+| `start_training` | Create and launch a new training run |
+| `select_run` | Switch monitoring to a different run |
+| `pause_training` | Pause training (SIGKILL, server retains state) |
+| `resume_training` | Resume by re-launching the client script |
+| `stop_training` | Stop training (SIGTERM, saves checkpoint) |
+| `update_script` | Update training script with version archiving |
+| `list_supported_models` | Query server for available models |
+| `search_datasets` | Search ModelScope for datasets |
+| `search_models` | Search ModelScope for models |
+| `zoom_metrics` | Adjust metrics chart view range |
+| `select_metrics` | Choose which metrics to display (max 4) |
+| `get_cluster_info` | Get GPU/cluster resource info |
+
+### Server Startup
+
+The `start_server` tool automates a multi-step pipeline:
+
+1. **GPU detection** — `nvidia-smi` hardware scan
+2. **GPU allocation** — partition GPUs between training model and samplers
+3. **Config generation** — auto-create `server_config.yaml`
+4. **Ray cluster startup** — multi-node GPU partitioning with isolated `CUDA_VISIBLE_DEVICES`
+5. **Server launch** — start Twinkle Server as background process
+6. **Health check** — poll `/api/v1/healthz` until ready
+
+Multi-model topology is supported: 1 training model + N sampler/teacher models.
+
+### Skills System
+
+TUI supports extensible skill plugins loaded from three sources:
+
+1. **Bundled skills** — shipped inside `twinkle_client/skills/bundled/`
+2. **User-local skills** — `~/.cache/twinkle/tui/skills/local/`
+3. **Community skills** — fetched from ModelScope (best-effort, 10s timeout)
+
+Skills are loaded asynchronously after startup and injected into the agent's system prompt. The agent is usable immediately even before skills finish loading.
+
+## Training Monitor (Auto-Fix)
+
+The `TrainingMonitor` is a background service that runs every **30 seconds**, collecting all available signals about the current training run and feeding them to the LLM for analysis.
+
+### Collected Signals
+
+- **Process status**: alive / dead / unknown
+- **output.log tail**: last 1500 chars (prioritizes tracebacks)
+- **Metrics**: recent entries + first-half vs second-half trend analysis
+- **Stall duration**: seconds since last metric was produced
+- **Current train.py**: full script source (for accurate fixes)
+
+### Decision Framework
+
+The LLM classifies each check into one of three actions:
+
+| Decision | When | Action |
+|----------|------|--------|
+| **LGTM** | Training progressing normally | No action |
+| **WARNING** | Loss plateau, reward hacking, KL explosion, etc. | Relay observation to user |
+| **FIX** | Script crashed, process dead with traceback | Auto-fix and restart |
+
+### Auto-Fix Pipeline
+
+When a FIX is needed:
+
+1. LLM outputs diagnosis + complete fixed script
+2. Monitor archives the old `train.py` as `train_v{N}.py`
+3. Writes the fixed script as the new `train.py`
+4. Re-launches training via `resume_training`
+5. Resets stall tracking for the new attempt
+
+Safety guardrails:
+- Max **3 auto-fix attempts** per run (prevents infinite retry loops)
+- Fix attempts are tracked per `run_id`
+- Snapshot deduplication avoids re-analyzing unchanged states
+
+## File-Based Connection
+
+TUI communicates with training processes through the local filesystem:
+
+```
+~/.cache/twinkle/{run_id}/
+├── meta.json — run metadata (model_id, config, status, pid)
+├── metrics.jsonl — one JSON object per step (incremental)
+├── output.log — combined stdout+stderr from training
+├── train.py — current active training script
+└── train_v{N}.py — archived previous script versions
+```
+
+### Training Control Model
+
+In Server Mode, the Twinkle Server retains all model/optimizer state in GPU memory:
+
+- **Pause** = kill client process (SIGKILL) — server state preserved
+- **Resume** = re-launch client script — seamlessly continues training
+- **Stop** = SIGTERM — triggers checkpoint saving then exits
+- **Shut down server** = releases GPU resources, **destroys** model state
+
+## TrainingRuntime (Script Integration)
+
+Training scripts use `TrainingRuntime` to integrate with TUI:
+
+```python
+from twinkle_client.tui.runtime import TrainingRuntime
+
+rt = TrainingRuntime(run_id='my-grpo-run')
+rt.start(model_id='Qwen/Qwen3.5-4B', config={'lr': 1e-5})
+rt.register_graceful_shutdown(model, dataloader)
+
+for step, batch in enumerate(dataloader):
+ # ... training logic ...
+ rt.log_metrics(step=step, loss=loss, reward=reward, grad_norm=gn, lr=lr)
+ rt.log(f'Completed step {step}, loss={loss:.4f}')
+
+rt.finish()
+```
+
+### Key Methods
+
+| Method | Description |
+|--------|-------------|
+| `start(model_id, config, script_path)` | Initialize run directory and metadata |
+| `log_metrics(**kwargs)` | Write metrics entry to `metrics.jsonl` |
+| `log(message)` | Print log message (captured in `output.log`) |
+| `get_resume_info()` | Get `last_step` for resuming from checkpoint |
+| `finish(status)` | Mark training as finished, close files |
+| `register_graceful_shutdown(model, dataloader)` | Register SIGTERM handler that saves checkpoint |
+
+### Resume Support
+
+`TrainingRuntime` automatically saves training progress to `meta.json` (throttled to every 5 seconds). Scripts can use `get_resume_info()` to resume from the last saved step:
+
+```python
+rt = TrainingRuntime(run_id='my-run')
+resume = rt.get_resume_info()
+global_step = resume['last_step']
+
+if global_step > 0:
+ dataloader.skip_consumed_samples(global_step * BATCH_SIZE)
+ print(f'Resuming from step {global_step}')
+```
+
+### Graceful Shutdown
+
+When `register_graceful_shutdown()` is called, a SIGTERM handler is installed that:
+
+1. Saves model checkpoint (LoRA weights + optimizer state)
+2. Saves dataloader position (`consumed_train_samples`)
+3. Logs the checkpoint path
+4. Marks training as `stopped` and exits
+
+## UI Panels
+
+### StatusBar
+
+Displays current training state at the top of the screen:
+
+- Training state icon (🚀 Training / ⏸ Paused / ✅ Done / ❌ Error)
+- Run ID
+- Model name
+- Current step
+- Progress bar with percentage
+
+### MetricsPanel
+
+Real-time ASCII chart rendered with `plotext`:
+
+- Plots up to 4 metrics simultaneously
+- Supports zoom (by step range and y-axis range)
+- Auto-selects first 3 available metrics if no selection
+- Hint bar shows hidden metrics that can be switched via agent
+- Retains up to 2000 data points
+
+### LogPanel
+
+Scrolling log viewer:
+
+- Strips ANSI escape sequences for clean display
+- Hard-wraps long lines to prevent overflow
+- Handles `\r` carriage returns from progress bars
+- Retains last 500 lines
+
+### ChatPanel
+
+Interactive chat interface:
+
+- User input with streaming agent responses
+- Throttled token flushing (80ms) for smooth display
+- Stream reset on tool-call detection
+- Supports Rich markup formatting
+
+## Logging
+
+All TUI logs are written to `./tui.log` (current working directory):
+
+- Rotated at 5MB with 3 backups
+- **No console output** — avoids corrupting Textual's alt-screen buffer
+- Use `--verbose` for DEBUG level logging
diff --git a/docs/source_en/Components/TUI/SkillProvider.md b/docs/source_en/Components/TUI/SkillProvider.md
new file mode 100644
index 000000000..d008cf978
--- /dev/null
+++ b/docs/source_en/Components/TUI/SkillProvider.md
@@ -0,0 +1,71 @@
+# SkillProvider
+
+The skill system allows Twinkle's TUI agent to dynamically load specialized knowledge from external sources (Git repos, APIs, local files) and inject them into the LLM's system prompt.
+
+## Architecture
+
+| Class | Role |
+|-------|------|
+| **Skill** | Dataclass holding a single skill's name, content, and source |
+| **SkillProvider** | Abstract base class for fetching skills from a source |
+| **SkillManager** | Orchestrates multiple providers, aggregates skills for prompt injection |
+
+## Skill Dataclass
+
+```python
+@dataclasses.dataclass
+class Skill:
+ name: str # Short identifier (typically filename without extension)
+ content: str # Full markdown content
+ source: str # Provider name + relative path for traceability
+```
+
+## Creating a Custom Provider
+
+Subclass `SkillProvider` and implement `name` and `fetch()`:
+
+```python
+from twinkle_client.skills.base import SkillProvider
+
+class MySkillProvider(SkillProvider):
+
+ @property
+ def name(self) -> str:
+ return 'my-skills'
+
+ async def fetch(self) -> None:
+ # Download/clone skill files to self.cache_dir
+ # e.g., git clone, API download, file copy
+ ...
+```
+
+The default `load_skills()` scans `self.cache_dir` for `.md` files (skipping README, LICENSE, etc.) and returns `Skill` objects.
+
+## SkillManager
+
+```python
+from twinkle_client.skills.manager import SkillManager
+
+manager = SkillManager()
+manager.register(my_provider)
+manager.register(another_provider)
+
+# Fetch and load all skills
+skills = await manager.load_all()
+
+# Format for LLM system prompt injection
+prompt_section = manager.format_for_prompt()
+```
+
+### Key Methods
+
+| Method | Description |
+|--------|-------------|
+| `register(provider)` | Add a skill provider |
+| `load_all()` | Fetch + load from all providers |
+| `format_for_prompt()` | Render skills as formatted text for system prompt |
+| `get_skill_names()` | List names of loaded skills |
+
+## Cache Directory
+
+By default, skills are cached at `~/.cache/twinkle/tui/skills/<provider_name>/`. Override by passing `cache_dir` to the provider constructor.
diff --git a/docs/source_en/Components/TUI/index.rst b/docs/source_en/Components/TUI/index.rst
new file mode 100644
index 000000000..29cdad073
--- /dev/null
+++ b/docs/source_en/Components/TUI/index.rst
@@ -0,0 +1,7 @@
+TUI
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Auto-Research.md
+ SkillProvider.md
diff --git a/docs/source_en/Components/Task Processor/GRPOProcessor.md b/docs/source_en/Components/Task Processor/GRPOProcessor.md
deleted file mode 100644
index adff73c45..000000000
--- a/docs/source_en/Components/Task Processor/GRPOProcessor.md
+++ /dev/null
@@ -1,19 +0,0 @@
-# GRPOLossProcessor
-
-GRPOLossProcessor is a task processor wrapper designed for GRPO reinforcement learning training. It extends InputProcessor with GRPO-specific data preparation.
-
-```python
-from twinkle.processor import GRPOLossProcessor
-
-processor = GRPOLossProcessor(
- device_mesh=...,
- padding_free=False,
- framework='transformers',
-)
-
-model.set_processor(processor)
-```
-
-GRPOLossProcessor wraps the base `InputProcessor` and adds handling for GRPO-specific fields such as advantages, old log-probabilities, and reference log-probabilities that are required by the GRPO loss function.
-
-> For standard SFT tasks, use `InputProcessor` directly. Use `GRPOLossProcessor` when your training loop involves GRPO or its variants.
diff --git a/docs/source_en/Components/Task Processor/index.rst b/docs/source_en/Components/Task Processor/index.rst
index 1f20fdbca..1e9d600a4 100644
--- a/docs/source_en/Components/Task Processor/index.rst
+++ b/docs/source_en/Components/Task Processor/index.rst
@@ -4,4 +4,3 @@ Task Processor
:maxdepth: 1
InputProcessor.md
- GRPOProcessor.md
diff --git a/docs/source_en/Components/Template/DeepSeekV4Template.md b/docs/source_en/Components/Template/DeepSeekV4Template.md
new file mode 100644
index 000000000..bbd74928e
--- /dev/null
+++ b/docs/source_en/Components/Template/DeepSeekV4Template.md
@@ -0,0 +1,56 @@
+# DeepSeek-V4 Template
+
+The `DeepseekV4Template` provides native support for DeepSeek V4's custom chat template encoding, including its unique thinking mode, tool-call protocol, and multi-token special tokens.
+
+## Usage
+
+```python
+from twinkle.template import DeepseekV4Template
+
+template = DeepseekV4Template(
+ model_id='deepseek-ai/DeepSeek-V4',
+ enable_thinking=True,
+)
+```
+
+## Features
+
+- **Custom tokenizer wrapper**: Overrides `apply_chat_template` with DeepSeek V4's encoding protocol
+- **Thinking mode**: Supports `thinking` / `chat` modes with configurable reasoning effort
+- **Tool calls**: Native DSML (DeepSeek Markup Language) tool-call encoding
+- **Multi-token EOS**: Handles DeepSeek V4's multi-character special tokens
+
+## Thinking Modes
+
+```python
+# Enable deep thinking (reasoning mode)
+template = DeepseekV4Template(model_id='...', enable_thinking=True)
+
+# Control reasoning effort
+# 'max' or 'high' enables extended reasoning budget
+template.encode(messages, reasoning_effort='max')
+```
+
+## Tool Call Support
+
+DeepSeek V4 uses its own DSML protocol for structured function calling:
+
+```python
+messages = [
+ {'role': 'user', 'content': 'What is the weather in Shanghai?'},
+]
+tools = [
+ {'type': 'function', 'function': {'name': 'get_weather', 'parameters': {...}}}
+]
+
+features = template.encode(messages, tools=tools)
+```
+
+## Key Differences from Base Template
+
+| Feature | Base Template | DeepseekV4Template |
+|:--------|:-------------|:-------------------|
+| Chat template | HuggingFace native | Custom DSML encoding |
+| Thinking | `<think>` tags | Native thinking mode toggle |
+| Tool calls | Hermes/Qwen format | DSML tool blocks |
+| EOS handling | Single token | Multi-token special markers |
diff --git a/docs/source_en/Components/Template/Template.md b/docs/source_en/Components/Template/Template.md
index 60962a33a..b9124412a 100644
--- a/docs/source_en/Components/Template/Template.md
+++ b/docs/source_en/Components/Template/Template.md
@@ -2,6 +2,64 @@
The template is a key component for converting Trajectory to InputFeature.
+```python
+class Template:
+
+ def __init__(self,
+ model_id: str,
+ use_chat_template: bool = True,
+ max_length: Optional[int] = 8192,
+ truncation_strategy: Literal['raise', 'left', 'right', 'split', 'delete'] = 'raise',
+ default_system: Optional[str] = None):
+ ...
+
+ def batch_encode(self, trajectories: Union[Dict[str, Any], List[Trajectory]]) -> List[InputFeature]:
+ # Batch encode samples
+ ...
+
+ def check(self, trajectory: Trajectory) -> Optional[Trajectory]:
+ # Encode one sample and return the original sample
+ # Generally used to check data reasonableness in RL algorithms like GRPO
+ ...
+
+ def batch_check(self, trajectories: List[Trajectory]) -> List[Optional[Trajectory]]:
+ # Batch check samples
+ ...
+
+ def decode(self, token_ids: List[int], **kwargs) -> str:
+ # Decode sample
+ ...
+
+ def batch_decode(self, token_ids: List[List[int]], **kwargs) -> List[str]:
+ # Batch decode samples
+ ...
+```
+
+- model_id: Model id containing tokenizer or processor
+- use_chat_template: Whether to use chat_template. If not used, it is generally a pre-training scenario
+- max_length: Maximum length of a single sample
+- truncation_strategy: How to handle the sample if it exceeds the maximum length
+ - raise: Throw an exception. Generally used for very precise dataset scenarios
+ - left: Remove tokens on the left to conform to max_length
+ - right: Remove tokens on the right to conform to max_length
+ - split: Split the oversized sample into multiple max_length chunks (not supported for multimodal, LazyDataset, or IterablePackingDataset)
+ - delete: Drop the entire sample if it exceeds max_length
+- default_system: If the dataset does not have a system message, use the default system
+
+> Template does not accept plain functions as replacements, because a template has to implement many methods internally. If you need to write a new Template, please inherit from the `Template` class.
+> Generally speaking, the `Template` base class is sufficient for pure-text models. The base class uses `tokenizer.apply_chat_template` to encode the messages, which works for most pure-text models.
+
+# Template mapping
+
+Currently, the model-template mapping is simple:
+
+- Template class: Supports all pure-text LLMs.
+- DeepseekV4Template class: For DeepSeek V4. It rewrites the chat template encoding logic; `encode_messages` is built into Twinkle.
+- Qwen3_5Template class: For Qwen3.5 MLLMs.
+# Template
+
+The template is a key component for converting Trajectory to InputFeature.
+
```python
class Template:
diff --git a/docs/source_en/Components/Template/ToolCallParsers.md b/docs/source_en/Components/Template/ToolCallParsers.md
new file mode 100644
index 000000000..8d4e3f988
--- /dev/null
+++ b/docs/source_en/Components/Template/ToolCallParsers.md
@@ -0,0 +1,98 @@
+# Tool Call Parsers
+
+Twinkle's template system includes a modular tool-call parsing framework for training models with function calling capabilities.
+
+## Architecture
+
+```
+ToolCallRegistry
+├── HermesQwenParser — Hermes/Qwen style <tool_call>...</tool_call>
+├── ReActParser — ReAct Thought/Action/Observation
+├── ClineParser — Cline XML-based tool calls
+└── VCPParser — VCP protocol
+```
+
+## ToolCallParser Interface
+
+```python
+from twinkle.template.tools import ToolCallParser
+
+class ToolCallParser(ABC):
+ name: str = ''
+ open_marker: str | None = None
+ close_marker: str | None = None
+
+ def detect(self, text: str) -> bool:
+ """Check if text contains this format's markup."""
+ ...
+
+ def parse(self, text: str) -> List[Dict[str, Any]]:
+ """Extract tool calls in OpenAI format."""
+ ...
+
+ def clean(self, text: str) -> str:
+ """Strip markup, return plain content."""
+ ...
+```
+
+## ToolCallRegistry
+
+The registry auto-discovers parsers and routes detection:
+
+```python
+from twinkle.template.tools import ToolCallRegistry
+
+# Detect which format a completion uses
+parser = ToolCallRegistry.detect_first(completion_text)
+if parser:
+ tool_calls = parser.parse(completion_text)
+ clean_text = parser.clean(completion_text)
+```
+
+## Built-in Parsers
+
+### HermesQwenParser
+
+Parses Hermes/Qwen-style function calls:
+
+```xml
+<tool_call>
+{"name": "get_weather", "arguments": {"city": "Shanghai"}}
+</tool_call>
+```
+
+### ReActParser
+
+Parses ReAct-style reasoning traces:
+
+```
+Thought: I need to check the weather
+Action: get_weather
+Action Input: {"city": "Shanghai"}
+Observation: ...
+```
+
+### ClineParser
+
+Parses Cline XML-based tool invocations with structured parameters.
+
+### VCPParser
+
+Parses VCP (Visual Code Protocol) tool calls.
+
+## Usage in Training
+
+Tool call parsers integrate with the Template during preprocessing:
+
+```python
+from twinkle.template import Template
+
+template = Template(
+ model_id='ms://Qwen/Qwen3.5-4B',
+ enable_thinking=True,
+)
+
+# Template automatically uses ToolCallRegistry for
+# tool-call aware tokenization during encoding
+features = template.encode(messages, tools=tool_definitions)
+```
diff --git a/docs/source_en/Components/Template/index.rst b/docs/source_en/Components/Template/index.rst
index cd5fddb42..c3d125f04 100644
--- a/docs/source_en/Components/Template/index.rst
+++ b/docs/source_en/Components/Template/index.rst
@@ -4,3 +4,11 @@ Template
:maxdepth: 1
Template.md
+ DeepSeekV4Template.md
+ ToolCallParsers.md
+Template
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Template.md
diff --git a/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md b/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md
index 169adb86e..b5f55a575 100644
--- a/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md
+++ b/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md
@@ -40,6 +40,24 @@ class DeviceMesh:
It is recommended to use `from_sizes` to construct it.
+### Parameter Reference
+
+| Parameter | Description | Default |
+|-----------|-------------|---------|
+| `world_size` | Total number of processes | 1 |
+| `dp_size` | Data parallel degree | 1 |
+| `fsdp_size` | Fully Sharded Data Parallel degree | None |
+| `tp_size` | Tensor parallel degree | None |
+| `pp_size` | Pipeline parallel degree | None |
+| `ulysses_size` | Ulysses sequence parallel degree | None |
+| `cp_size` | Context parallel degree | None |
+| `ep_size` | Expert parallel degree (for MoE models) | None |
+| `etp_size` | Expert tensor parallel degree | None |
+| `ep_fsdp_size` | FSDP degree within each EP group | None |
+| `vpp_size` | Virtual pipeline parallel degree | None |
+| `device_type` | Device type (`cuda`, `npu`, etc.) | `cuda` |
+| `sequence_parallel` | Enable Megatron-style sequence parallel | False |
+
Let's give an example:
```python
diff --git a/docs/source_en/Components/Training Middleware/Expert-Parallel.md b/docs/source_en/Components/Training Middleware/Expert-Parallel.md
new file mode 100644
index 000000000..edd0ec03d
--- /dev/null
+++ b/docs/source_en/Components/Training Middleware/Expert-Parallel.md
@@ -0,0 +1,74 @@
+# Expert Parallel (EP)
+
+Expert Parallel distributes Mixture-of-Experts (MoE) model experts across multiple GPUs, allowing each rank to hold a subset of experts. This reduces per-GPU memory and enables training of large MoE models.
+
+## Overview
+
+| Concept | Description |
+|---------|-------------|
+| **ExpertParallelConfig** | Configuration dataclass controlling EP behavior |
+| **apply_expert_parallel()** | Entry point that shards experts and patches forward |
+| **shard_experts()** | Evenly splits experts across EP ranks |
+| **patch_forward()** | Replaces MoE block forward with EP-aware all-to-all communication |
+
+## Configuration
+
+```python
+from twinkle.model.transformers.moe.expert_parallel import ExpertParallelConfig
+
+config = ExpertParallelConfig(
+ enabled=True, # Enable expert parallel
+ router_dtype='fp32', # Router computation dtype: 'fp32', 'bf16', 'fp16'
+ keep_router_logits=True, # Return router logits alongside hidden states
+ ignore_shared_experts=False,# Skip shared expert computation (e.g. DeepSeek)
+ ep_size=None, # EP world size (consumed by TransformersModel)
+)
+```
+
+## Usage with DeviceMesh
+
+EP is activated by setting `ep_size` in `DeviceMesh.from_sizes()`. The framework automatically calls `apply_expert_parallel()` during model initialization.
+
+```python
+from twinkle.utils import DeviceMesh
+
+# 8 GPUs: 2-way EP × 4-way data parallel
+device_mesh = DeviceMesh.from_sizes(
+ world_size=8,
+ dp_size=4,
+ ep_size=2,
+)
+```
+
+For combined EP + FSDP sharding on the expert parameters:
+
+```python
+# 8 GPUs: 2-way EP with FSDP within each EP group
+device_mesh = DeviceMesh.from_sizes(
+ world_size=8,
+ dp_size=2,
+ ep_size=2,
+ ep_fsdp_size=2,
+)
+```
+
+## Communication Pattern
+
+The EP forward pass follows a 4-stage pipeline:
+
+1. **Preprocess** — compute per-expert token counts and split sizes
+2. **Token Pre-All2All** — permute tokens by expert assignment, then all-to-all exchange across EP ranks
+3. **Expert Compute** — each rank runs its local experts on received tokens
+4. **Token Post-All2All** — all-to-all exchange results back, unpermute and apply routing weights
+
+```
+Input tokens → Router → [preprocess] → [pre_all2all] → [local experts] → [post_all2all] → Output
+```
+
+## Requirements
+
+- `num_experts` must be divisible by `ep_size`
+- `torch.distributed` must be initialized
+- MoE blocks must define a `gate`/`router` module and `experts` (either `nn.ModuleList` or tensor-style `gate_up_proj`/`down_proj`)
+- Both ModuleList-style and tensor-style (fused) experts are supported
+- Shared experts (e.g. DeepSeek MoE) are handled automatically unless `ignore_shared_experts=True`
diff --git a/docs/source_en/Components/Training Middleware/Padding-Free.md b/docs/source_en/Components/Training Middleware/Padding-Free.md
new file mode 100644
index 000000000..44dd7ba14
--- /dev/null
+++ b/docs/source_en/Components/Training Middleware/Padding-Free.md
@@ -0,0 +1,52 @@
+# Padding-Free Training
+
+Padding-free (also called "packing") training eliminates wasted computation on padding tokens by concatenating multiple sequences into a single packed batch. Twinkle supports padding-free training for both standard attention and Qwen3.5's GatedDeltaNet linear attention.
+
+## How It Works
+
+Instead of padding all sequences to `max_length`, padding-free packs multiple sequences into one row and uses `position_ids` to track sequence boundaries. This avoids wasted FLOPs on padding tokens.
+
+```
+Standard: [tok tok tok PAD PAD PAD] [tok tok PAD PAD PAD PAD]
+Packed: [tok tok tok tok tok ...] ← no padding waste
+```
+
+## Usage
+
+Padding-free is enabled via `PackingDataset` or `IterablePackingDataset`:
+
+```python
+from twinkle.dataset import PackingDataset
+
+dataset = PackingDataset(
+ dataset=base_dataset,
+ max_length=8192,
+)
+```
+
+The dataset automatically packs sequences and generates correct `position_ids` with resets at sequence boundaries.
+
+## GatedDeltaNet Patch (Qwen3.5)
+
+Qwen3.5 uses a hybrid architecture mixing standard attention with GatedDeltaNet linear attention. The native GatedDeltaNet implementation does not reset its linear-attention state at packed sequence boundaries.
+
+`GatedDeltaNetPaddingFreePatch` fixes this by:
+
+1. Patching `Qwen3_5DecoderLayer.forward` to pass `cu_seq_lens_q` (cumulative sequence lengths) to linear attention layers
+2. Patching `Qwen3_5GatedDeltaNet.forward` to use flash-linear-attention kernels (`causal_conv1d`, `chunk_gated_delta_rule`) with `cu_seqlens` support
+
+The patch is applied automatically when padding-free is detected on Qwen3.5 models.
+
+### Requirements
+
+- `flash-linear-attention` package must be installed
+- Only needed for Qwen3.5 models with GatedDeltaNet layers
+- When sequence parallel is enabled, a separate `Qwen3_5GatedDeltaNetUlyssesPatch` is used instead
+
+## Attention Backend Requirements
+
+| Attention Backend | Padding-Free Support |
+|-------------------|---------------------|
+| FlashAttention2 | Fully supported |
+| SDPA | Supported (incompatible with sequence parallel) |
+| Eager | Not supported |
diff --git a/docs/source_en/Components/Training Middleware/Sequence-Parallel.md b/docs/source_en/Components/Training Middleware/Sequence-Parallel.md
new file mode 100644
index 000000000..d08b01d67
--- /dev/null
+++ b/docs/source_en/Components/Training Middleware/Sequence-Parallel.md
@@ -0,0 +1,68 @@
+# Sequence Parallel (SP)
+
+Sequence Parallel splits long sequences across multiple GPUs along the sequence dimension, enabling training with sequence lengths that exceed single-GPU memory. Twinkle implements Ulysses-style sequence parallel with optional derived ring attention.
+
+## Overview
+
+| Concept | Description |
+|---------|-------------|
+| **SequenceParallelConfig** | Configuration dataclass for SP |
+| **SequenceParallelStrategy** | Strategy class that wraps SP lifecycle |
+| **SequenceParallel** | Core implementation handling pad/split/gather |
+
+## Configuration
+
+```python
+from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallelConfig
+
+config = SequenceParallelConfig(
+ enabled=True, # Enable sequence parallel
+ ulysses_size=None, # Ulysses SP degree (auto-derived from DeviceMesh if None)
+ gather_logits=True, # Gather logits after forward for loss computation
+)
+```
+
+## Usage with DeviceMesh
+
+SP is activated by setting `ulysses_size` in `DeviceMesh.from_sizes()`:
+
+```python
+from twinkle.utils import DeviceMesh
+
+# 8 GPUs: 4-way Ulysses SP × 2-way data parallel
+device_mesh = DeviceMesh.from_sizes(
+ world_size=8,
+ dp_size=2,
+ ulysses_size=4,
+)
+```
+
+## How It Works
+
+1. **Pad** — input sequences are padded to a length divisible by SP world size
+2. **Split** — padded inputs are evenly split across SP ranks along the sequence dimension
+3. **Distributed Attention** — FlashAttention2 is patched to perform Ulysses all-to-all communication before/after attention computation
+4. **Gather** — after forward, logits are gathered back to full sequence length for loss computation
+
+## Supported Attention Backends
+
+| Backend | Status |
+|---------|--------|
+| FlashAttention2 | Fully supported (including packed/padding-free sequences) |
+| SDPA | Supported (non-packed batches only) |
+| Derived Ring Attention | Supported with FlashAttention2 only (`rp_world_size > 1`) |
+
+## Qwen3.5 Linear Attention
+
+SP automatically detects Qwen3.5 GatedDeltaNet linear attention layers and applies the `Qwen3_5GatedDeltaNetUlyssesPatch` for correct sequence-parallel behavior on hybrid attention architectures.
+
+## MoE Auxiliary Loss
+
+For MoE models, SP automatically installs a forward hook that gathers router logits across SP ranks before auxiliary loss computation, ensuring correct load-balancing signals.
+
+## Key Constraints
+
+- `num_key_value_heads` must be divisible by `ulysses_size` for Ulysses; otherwise, use the derived ring attention fallback
+- Packed/padding-free batches require FlashAttention2
+- Derived ring attention requires `batch_size == 1` (packed format)
+- `torch.distributed` must be initialized
diff --git a/docs/source_en/Components/Training Middleware/TwinkleClient.md b/docs/source_en/Components/Training Middleware/TwinkleClient.md
new file mode 100644
index 000000000..18a1437db
--- /dev/null
+++ b/docs/source_en/Components/Training Middleware/TwinkleClient.md
@@ -0,0 +1,81 @@
+# TwinkleClient
+
+`TwinkleClient` is the Python client for interacting with the Twinkle REST API. It manages sessions, training runs, and checkpoints.
+
+## Initialization
+
+```python
+from twinkle_client.manager import TwinkleClient
+
+client = TwinkleClient(
+ base_url='http://localhost:8000', # Or TWINKLE_SERVER_URL env var
+ api_key='your-api-key', # Or TWINKLE_SERVER_TOKEN env var
+ route_prefix='/twinkle', # API route prefix
+ session_heartbeat_interval=10, # Heartbeat interval in seconds
+ session_metadata={'user': 'alice'}, # Optional session metadata
+)
+```
+
+On init, the client:
+1. Sets `base_url` and `api_key` into shared context (used by all client objects)
+2. Creates a server-side session
+3. Starts a background heartbeat thread to keep the session alive
+
+## Health Check
+
+```python
+is_healthy = client.health_check() # Returns True/False
+capabilities = client.get_server_capabilities() # Supported models
+```
+
+## Training Runs
+
+```python
+# List runs
+runs = client.list_training_runs(limit=20, offset=0)
+
+# List with pagination cursor
+runs, cursor = client.list_training_runs_with_cursor(limit=20)
+
+# Get specific run
+run = client.get_training_run(run_id='run_abc123')
+
+# Find by base model
+qwen_runs = client.find_training_run_by_model('Qwen/Qwen3.5-4B')
+```
+
+## Checkpoints
+
+```python
+# List checkpoints for a run
+checkpoints = client.list_checkpoints(run_id='run_abc123')
+
+# Get checkpoint path
+parsed = client.get_checkpoint_path(run_id, checkpoint_id)
+# parsed.path → filesystem path
+# parsed.twinkle_path → twinkle:// URI
+
+# Get latest checkpoint (useful for resume training)
+latest_path = client.get_latest_checkpoint_path(run_id)
+
+# Delete checkpoint
+client.delete_checkpoint(run_id, checkpoint_id)
+```
+
+## Capacity & Weights Info
+
+```python
+# LoRA capacity
+capacity = client.get_capacity_info()
+# capacity.max_loras, capacity.used_loras, capacity.free_loras
+
+# Weights metadata
+info = client.get_weights_info('twinkle://run_id/weights/checkpoint')
+# info.base_model, info.is_lora, info.lora_rank
+```
+
+## Cleanup
+
+```python
+client.close() # Stops heartbeat thread (also registered via atexit)
+```
diff --git a/docs/source_en/Components/Training Middleware/index.rst b/docs/source_en/Components/Training Middleware/index.rst
index 014dfdc66..b2dc3acee 100644
--- a/docs/source_en/Components/Training Middleware/index.rst
+++ b/docs/source_en/Components/Training Middleware/index.rst
@@ -4,4 +4,8 @@ Training Middleware
:maxdepth: 1
DeviceMesh-and-DeviceGroup.md
+ Expert-Parallel.md
+ Sequence-Parallel.md
+ Padding-Free.md
RemoteClass.md
+ TwinkleClient.md
diff --git a/docs/source_en/Usage Guide/Embedding-Training.md b/docs/source_en/Usage Guide/Embedding-Training.md
new file mode 100644
index 000000000..d86769c38
--- /dev/null
+++ b/docs/source_en/Usage Guide/Embedding-Training.md
@@ -0,0 +1,120 @@
+# Embedding Training
+
+Twinkle supports contrastive embedding model training with InfoNCE loss, in-batch negatives, and cross-rank gathering. This guide demonstrates how to train embedding models using Twinkle.
+
+---
+
+## Overview
+
+Embedding training in Twinkle uses the following core components:
+
+| Component | Role |
+|:----------|:-----|
+| `InfonceLoss` | Contrastive loss with in-batch negatives |
+| `EmbeddingMetric` | Tracks pos/neg similarity and loss |
+| `TransformersModel` | Trainable embedding model (with LoRA or full) |
+| `InputProcessor` | Processes anchor/positive pairs into features |
+
+### Data Format
+
+Each training sample consists of **(anchor, positive)** pairs. In the embedding feature tensor:
+
+```
+embeddings: [anchor_0, positive_0, anchor_1, positive_1, ...]
+labels: [ 1, 0, 1, 0, ...]
+```
+
+- `labels=1` marks the start of a new group (anchor)
+- `labels=0` marks positives/negatives within the group
+
+---
+
+## Basic Embedding Training
+
+A minimal embedding training script with DDP:
+
+```python
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.loss import InfonceLoss
+from twinkle.metric import EmbeddingMetric
+from twinkle.model import TransformersModel
+from twinkle.processor import InputProcessor
+from twinkle.template import Qwen3_5Template
+
+logger = get_logger()
+
+# --- Configuration ---
+MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
+MODEL_GPUS = 4
+BATCH_SIZE = 32
+LEARNING_RATE = 1e-5
+TEMPERATURE = 0.07
+EMB_MAX_LENGTH = 8192
+
+# --- Initialize ---
+device_groups = [
+ DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
+]
+model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups)
+
+# --- Model ---
+model = TransformersModel(
+ model_id=MODEL_ID,
+ device_mesh=model_mesh,
+ remote_group='model',
+ ddp_config={'find_unused_parameters': True},
+)
+model.set_processor(InputProcessor)
+model.set_loss(InfonceLoss, temperature=TEMPERATURE, use_batch=True)
+model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE)
+model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler',
+ num_warmup_steps=200,
+ num_training_steps=total_steps,
+)
+model.add_metric(EmbeddingMetric, is_training=True)
+
+# --- Template ---
+template = Qwen3_5Template(
+ model_id=MODEL_ID,
+ max_length=EMB_MAX_LENGTH,
+ enable_thinking=False,
+)
+
+# --- Training Loop ---
+for step, batch in enumerate(dataloader):
+ # `dataloader` must be constructed beforehand (not shown); each batch is a list of features with anchor/positive pairs
+ model.forward_backward(inputs=batch, task='embedding')
+ model.clip_grad_and_step(gradient_accumulation_steps=1)
+
+ if step % 10 == 0:
+ metric = model.calculate_metric(is_training=True)
+ logger.info(f'Step {step}: {metric}')
+```
+
+### Key Parameters
+
+| Parameter | Recommended | Description |
+|:----------|:------------|:------------|
+| `temperature` | 0.05–0.1 | Lower = sharper contrast. 0.07 keeps gradients flowing until cosine > 0.75 |
+| `use_batch` | True | Enables cross-sample in-batch negatives for better efficiency |
+| `hard_negatives` | None or 7 | Fix negative count per sample; None uses all in-batch |
+| `find_unused_parameters` | True | Required for embedding models (only last hidden state contributes gradients) |
+
+---
+
+## Monitoring
+
+The `EmbeddingMetric` reports key training signals:
+
+| Metric | What it means |
+|:-------|:--------------|
+| `pos_sim` | Average anchor-positive cosine similarity (target: > 0.8) |
+| `neg_sim` | Average anchor-negative similarity (target: < 0.3) |
+| `loss` | InfoNCE loss value |
+| `grad_norm` | Gradient magnitude |
+
+Healthy training shows `pos_sim` rising and `neg_sim` stable or falling. If `pos_sim` saturates near 1.0, lower the temperature.
diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md
index 1c1a70fb6..707473910 100644
--- a/docs/source_en/Usage Guide/Quick-Start.md
+++ b/docs/source_en/Usage Guide/Quick-Start.md
@@ -156,6 +156,8 @@ if __name__ == '__main__':
In this training code, we constructed a dataset and loaded the Qwen/Qwen3.5-4B model, used LoRA with the all-linear approach, and completed one training run. In the logs, you can observe the process of loss gradually converging.
+> **Tip — Full-Parameter Training**: The example above uses LoRA for efficiency. To switch to full-parameter training, simply remove the `add_adapter_to_model` call (and the `from peft import LoraConfig` import). Everything else stays the same.
+
### torchrun
Twinkle supports running training in torchrun mode. In this scenario, Ray-related dependencies do not need to be installed.
@@ -471,7 +473,7 @@ python train.py
A major feature of Twinkle is support for multi-tenant mixed training. Specifically, multiple users can use a single base model for LoRA training, which can greatly reduce server-side deployment costs.
-Checkpoint resumption is also supported in client-server training. The recommended flow is to call `model.resume_from_checkpoint(resume_path)` to restore weights and optimizer state, then call `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` to skip consumed data. See [Twinkle-Client](./Server%20and%20Client/Twinkle-Client.md) and [self_cognition.py](../../../cookbook/client/twinkle/self_host/self_cognition.py).
+Checkpoint resumption is also supported in client-server training. The recommended flow is to call `model.resume_from_checkpoint(resume_path)` to restore weights and optimizer state, then call `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` to skip consumed data. See [Twinkle-Client](./Server%20and%20Client/Twinkle-Client.md) and [self_cognition.py](../../../cookbook/server_mode/twinkle/self_host/self_cognition.py).
Suppose we start a service using eight GPUs. First, we need to start the Ray cluster:
@@ -493,6 +495,8 @@ Next, start the server:
twinkle-server launch -c cookbook/client/server/transformer/server_config.yaml
```
+> For details on how to write `server_config.yaml`, see [Server Configuration](./Server%20and%20Client/Server.md).
+
The server will start three services: a sampler cluster, a model cluster, and a utility cluster.
Now you can perform client-side training:
diff --git a/docs/source_en/index.rst b/docs/source_en/index.rst
index ef477f7fc..6128079c8 100644
--- a/docs/source_en/index.rst
+++ b/docs/source_en/index.rst
@@ -15,6 +15,7 @@ Twinkle DOCUMENTATION
Usage Guide/NPU-Support.md
Usage Guide/Train-as-a-Service.md
Usage Guide/Introduction-with-Qwen3.5.md
+ Usage Guide/Embedding-Training.md
.. toctree::
:maxdepth: 2
@@ -30,7 +31,6 @@ Twinkle DOCUMENTATION
Components/Sampler/index.rst
Components/Reward/index.rst
Components/Advantage/index.rst
- Components/Gym/index.rst
Components/Hub/index.rst
Components/Checkpoint Engine/index.rst
Components/Metrics/index.rst
@@ -41,6 +41,10 @@ Twinkle DOCUMENTATION
Components/Plugin/index.rst
Components/Kernel/index.rst
Components/Training Middleware/index.rst
+ Components/CLI/index.rst
+ Components/Notifier/index.rst
+ Components/Agentic/index.rst
+ Components/TUI/index.rst
Indices and tables
==================
diff --git a/docs/source_zh/index.rst b/docs/source_zh/index.rst
index 3d07d4b2a..6a7be7b5b 100644
--- a/docs/source_zh/index.rst
+++ b/docs/source_zh/index.rst
@@ -15,6 +15,7 @@ Twinkle DOCUMENTATION
使用指引/NPU的支持.md
使用指引/训练服务.md
使用指引/Qwen3.5最佳实践.md
+ 使用指引/Embedding训练.md
.. toctree::
:maxdepth: 2
@@ -30,7 +31,6 @@ Twinkle DOCUMENTATION
组件/采样器/index.rst
组件/奖励/index.rst
组件/优势/index.rst
- 组件/Gym/index.rst
组件/Hub/index.rst
组件/检查点引擎/index.rst
组件/指标/index.rst
@@ -41,6 +41,10 @@ Twinkle DOCUMENTATION
组件/组件化/index.rst
组件/Kernel/index.rst
组件/训练中间件/index.rst
+ 组件/CLI/index.rst
+ 组件/通知器/index.rst
+ 组件/Agentic/index.rst
+ 组件/TUI/index.rst
Indices and tables
==================
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Embedding\350\256\255\347\273\203.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Embedding\350\256\255\347\273\203.md"
new file mode 100644
index 000000000..94ea86ebe
--- /dev/null
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Embedding\350\256\255\347\273\203.md"
@@ -0,0 +1,120 @@
+# Embedding 模型训练
+
+Twinkle 支持基于 InfoNCE 损失的对比学习 Embedding 模型训练,内置 in-batch negatives 和跨 rank 聚合。本文介绍如何使用 Twinkle 训练 Embedding 模型。
+
+---
+
+## 概述
+
+Embedding 训练使用以下核心组件:
+
+| 组件 | 职责 |
+|:-----|:-----|
+| `InfonceLoss` | 对比损失,支持 in-batch negatives |
+| `EmbeddingMetric` | 追踪正/负对相似度和损失 |
+| `TransformersModel` | 可训练的 Embedding 模型(LoRA 或全参) |
+| `InputProcessor` | 将 anchor/positive 对处理为特征 |
+
+### 数据格式
+
+每个训练样本由 **(anchor, positive)** 对组成。在 Embedding 特征张量中:
+
+```
+embeddings: [anchor_0, positive_0, anchor_1, positive_1, ...]
+labels: [ 1, 0, 1, 0, ...]
+```
+
+- `labels=1` 标记新分组的起始位置(anchor)
+- `labels=0` 标记组内的 positive/negative
+
+---
+
+## 基础 Embedding 训练
+
+使用 DDP 的最小化 Embedding 训练脚本:
+
+```python
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.loss import InfonceLoss
+from twinkle.metric import EmbeddingMetric
+from twinkle.model import TransformersModel
+from twinkle.processor import InputProcessor
+from twinkle.template import Qwen3_5Template
+
+logger = get_logger()
+
+# --- 配置 ---
+MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
+MODEL_GPUS = 4
+BATCH_SIZE = 32
+LEARNING_RATE = 1e-5
+TEMPERATURE = 0.07
+EMB_MAX_LENGTH = 8192
+
+# --- 初始化 ---
+device_groups = [
+ DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
+]
+model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups)
+
+# --- 模型 ---
+model = TransformersModel(
+ model_id=MODEL_ID,
+ device_mesh=model_mesh,
+ remote_group='model',
+ ddp_config={'find_unused_parameters': True},
+)
+model.set_processor(InputProcessor)
+model.set_loss(InfonceLoss, temperature=TEMPERATURE, use_batch=True)
+model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE)
+model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler',
+ num_warmup_steps=200,
+ num_training_steps=total_steps,
+)
+model.add_metric(EmbeddingMetric, is_training=True)
+
+# --- 模板 ---
+template = Qwen3_5Template(
+ model_id=MODEL_ID,
+ max_length=EMB_MAX_LENGTH,
+ enable_thinking=False,
+)
+
+# --- 训练循环 ---
+for step, batch in enumerate(dataloader):
+ # batch: 包含 anchor/positive 对的特征列表
+ model.forward_backward(inputs=batch, task='embedding')
+ model.clip_grad_and_step(gradient_accumulation_steps=1)
+
+ if step % 10 == 0:
+ metric = model.calculate_metric(is_training=True)
+ logger.info(f'Step {step}: {metric}')
+```
+
+### 关键参数
+
+| 参数 | 推荐值 | 说明 |
+|:----|:------|:-----|
+| `temperature` | 0.05–0.1 | 越低对比越尖锐;0.07 保持梯度流动直至 cosine > 0.75 |
+| `use_batch` | True | 启用跨样本 in-batch negatives 提升效率 |
+| `hard_negatives` | None 或 7 | 固定每样本负例数量;None 使用全部 in-batch |
+| `find_unused_parameters` | True | Embedding 模型必需(仅最后隐藏状态产生梯度) |
+
+---
+
+## 监控指标
+
+`EmbeddingMetric` 报告关键训练信号:
+
+| 指标 | 含义 |
+|:----|:-----|
+| `pos_sim` | anchor-positive 平均余弦相似度(目标 > 0.8) |
+| `neg_sim` | anchor-negative 平均相似度(目标 < 0.3) |
+| `loss` | InfoNCE 损失值 |
+| `grad_norm` | 梯度范数 |
+
+健康的训练表现为 `pos_sim` 持续上升、`neg_sim` 稳定或下降。如果 `pos_sim` 过早饱和至 1.0 附近,应降低 temperature。
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md"
index 39f6fe182..feaf7f4e8 100644
--- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md"
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md"
@@ -257,6 +257,19 @@ pip install torch_npu-2.7.1-cp311-cp311-linux_aarch64.whl
2. “待验证”功能可以尝试,但可能遇到兼容性问题
3. 遇到问题时,参考对应的示例代码进行配置
+## 示例代码
+
+Twinkle 在 NPU 上已验证的示例目前聚焦 Megatron smoke 路径;SFT 和 GRPO cookbook 示例暂无对应文件。
+
+### 远程训练(Tinker 协议)
+- **服务端配置**:[cookbook/remote/tinker/ascend/](https://github.com/modelscope/twinkle/tree/main/cookbook/remote/tinker/ascend)
+ - 提供 HTTP API 接口
+ - 支持远程训练和推理
+ - 适用于生产环境部署
+
+**运行示例**:
+暂无对应命令示例。
+
## 参考资源
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md"
index 19189948f..3bc5c4ba4 100644
--- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md"
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md"
@@ -157,6 +157,8 @@ if __name__ == '__main__':
在这个训练代码中,我们构造了一个数据集并拉起了Qwen/Qwen3.5-4B模型,使用all-linear方式加载了lora,并完成了一次训练。在日志中,可以看到loss逐步收敛的过程。
+> **提示 — 全参数训练**:上面的示例使用 LoRA 以提高效率。若要切换为全参数训练,只需移除 `add_adapter_to_model` 调用(以及 `from peft import LoraConfig` 导入),其余代码完全不变。
+
### torchrun
Twinkle 支持以 torchrun 模式运行训练。在这种场景下,不需要安装 Ray 相关的依赖。
@@ -470,7 +472,7 @@ python train.py
```
### 远程训练
-client-server 训练场景同样支持断点续训。推荐流程是调用 `model.resume_from_checkpoint(resume_path)` 恢复权重和优化器状态,再调用 `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` 跳过已消费数据。详细示例可参考 [Twinkle客户端](./服务端和客户端/Twinkle客户端.md) 和 [self_cognition.py](../../../cookbook/client/twinkle/self_host/self_cognition.py)。
+client-server 训练场景同样支持断点续训。推荐流程是调用 `model.resume_from_checkpoint(resume_path)` 恢复权重和优化器状态,再调用 `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` 跳过已消费数据。详细示例可参考 [Twinkle客户端](./服务端和客户端/Twinkle客户端.md) 和 [self_cognition.py](../../../cookbook/server_mode/twinkle/self_host/self_cognition.py)。
Twinkle 的一大特色是支持多租户用户混合训练。具体来说,多个用户可以使用一个基模进行 LoRA 训练,这样可以极大减小服务端部署成本。
@@ -494,6 +496,8 @@ CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0
twinkle-server launch -c cookbook/client/server/transformer/server_config.yaml
```
+> `server_config.yaml` 的编写方式详见 [服务端配置](./服务端和客户端/服务端.md)。
+
 服务端会启动三个服务:Sampler 集群、模型集群和工具集群。
下面可以进行client端训练:
diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/Envs.md" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Envs.md"
new file mode 100644
index 000000000..675f05baf
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Envs.md"
@@ -0,0 +1,183 @@
+# 执行环境(Envs)
+
+Envs 模块提供了用于 Agentic 训练的 RL 执行环境抽象。环境可以在多轮 rollout 中交互式参与,也可以批量评估已完成的轨迹。
+
+## Env 基类
+
+```python
+from twinkle_agentic.envs.base import Env, StepResult
+
+class Env(ABC):
+
+ def reset(self, trajectory=None) -> StepResult:
+ """重置环境,开始新一轮。"""
+
+ @abstractmethod
+ def step(self, tool_name: str, arguments: dict) -> StepResult:
+ """执行单个动作,返回观测 + 奖励 + 完成标志。"""
+
+ def tools(self) -> List[ToolInfo]:
+ """返回此环境中可用的工具定义。"""
+
+ def evaluate(self, trajectories, **kwargs) -> List[float]:
+ """批量评估已完成的轨迹,返回奖励列表。"""
+
+ def close(self) -> None:
+ """释放资源。"""
+```
+
+### StepResult
+
+```python
+@dataclass
+class StepResult:
+ observation: str = '' # 动作执行后的环境观测
+ reward: float = 0.0 # 此步骤的标量奖励
+ done: bool = False # 是否终止
+ info: Dict[str, Any] = field(default_factory=dict) # 额外元数据
+```
+
+### 两种使用模式
+
+1. **交互模式**(多轮 rollout)—— 逐步执行:
+
+```python
+env = MyEnv()
+env.reset(trajectory)
+result = env.step('search', {'query': 'Python'})
+# ... 重复直到 result.done
+```
+
+2. **批量评估模式** —— 评估已完成的轨迹:
+
+```python
+rewards = env.evaluate(completed_trajectories)
+```
+
+## EnvTool
+
+`EnvTool` 将 `Env` 包装为 `Tool`,连接环境与 `ToolManager` 和 `MultiTurnRollout`。
+
+```python
+from twinkle_agentic.envs.env_tool import EnvTool
+from twinkle_agentic.tools.tool_manager import ToolManager
+
+env = MyEnv()
+
+# 为环境中定义的每个工具创建一个 EnvTool
+env_tools = EnvTool.from_env(env)
+
+# 注册到 ToolManager
+manager = ToolManager(env_tools)
+```
+
+### 核心特性
+
+| 特性 | 说明 |
+|------|------|
+| `from_env(env)` | 工厂方法:为 `env.tools()` 中的每个工具创建一个 `EnvTool`。 |
+| `last_result` | 存储最近一次 `StepResult` 供调用方检查。 |
+| `done` | 属性:最后一步是否终止了回合。 |
+| `episode_reward` | 属性:来自 `info['episode_reward']` 的累计奖励。 |
+
+### 手动构造
+
+```python
+env_tool = EnvTool(
+ env=my_env,
+ tool_name='execute_code',
+ description='在沙箱中执行 Python 代码。',
+ parameters={
+ 'type': 'object',
+ 'properties': {
+ 'code': {'type': 'string', 'description': '要执行的 Python 代码。'},
+ },
+ 'required': ['code'],
+ },
+)
+```
+
+## OpenEnv
+
+`OpenEnv` 将基于 WebSocket 的 [OpenEnv](https://github.com/OpenEnv) 环境服务器适配为同步的 Twinkle `Env`。
+
+```python
+from twinkle_agentic.envs.openenv import OpenEnv
+
+env = OpenEnv(
+ base_url='http://localhost:8000',
+ env_cls='coding_env.CodingEnv', # 可选的类型化客户端
+ env_kwargs={'message_timeout_s': 30},
+ tool_schema=[...], # 可选的工具定义
+)
+```
+
+### 参数
+
+| 参数 | 类型 | 说明 |
+|------|------|------|
+| `base_url` | `str` | 运行中的 OpenEnv 服务器 URL。 |
+| `env_cls` | `str` 或 class | 类型化客户端的点分导入路径或类。`None` 使用 `GenericEnvClient`。 |
+| `env_kwargs` | `Dict` | 传递给客户端构造函数的额外参数。 |
+| `tool_schema` | `List[ToolInfo]` | 通过 `tools()` 暴露的工具定义。 |
+| `action_mapper` | `Callable` | 自定义函数,将 `(tool_name, args)` 映射为发送给服务器的动作字典。 |
+
+### 与 Rollout 集成使用
+
+```python
+from twinkle_agentic.envs.openenv import OpenEnv
+from twinkle_agentic.envs.env_tool import EnvTool
+from twinkle_agentic.tools.tool_manager import ToolManager
+from twinkle_agentic.rollout.api_multi_turn import APIMultiTurnRollout
+
+# 设置环境
+env = OpenEnv(base_url='http://localhost:8000', tool_schema=[...])
+env.reset()
+
+# 桥接到 ToolManager
+env_tools = EnvTool.from_env(env)
+manager = ToolManager(env_tools)
+
+# 在 rollout 中使用
+rollout = APIMultiTurnRollout(api=api, tool_manager=manager, max_turns=10)
+results = rollout(trajectories)
+```
+
+### 实现自定义环境
+
+```python
+from twinkle_agentic.envs.base import Env, StepResult
+
+class CodeExecutionEnv(Env):
+
+ def reset(self, trajectory=None):
+ self._sandbox = create_sandbox()
+ return StepResult(observation='沙箱已就绪。')
+
+ def step(self, tool_name, arguments):
+ code = arguments.get('code', '')
+ output = self._sandbox.run(code)
+ return StepResult(
+ observation=output,
+ reward=1.0 if 'error' not in output.lower() else 0.0,
+ done=False,
+ )
+
+ def tools(self):
+ return [{
+ 'type': 'function',
+ 'function': {
+ 'name': 'execute_code',
+ 'description': '运行 Python 代码。',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'code': {'type': 'string'},
+ },
+ },
+ },
+ }]
+
+ def close(self):
+ self._sandbox.cleanup()
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/Multi-Turn-Tool-Usage.md" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Multi-Turn-Tool-Usage.md"
new file mode 100644
index 000000000..1c13bfbe0
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Multi-Turn-Tool-Usage.md"
@@ -0,0 +1,205 @@
+# 多轮工具使用指南
+
+本指南介绍如何在 Twinkle 中设置和运行带工具调用的多轮 Agentic rollout。
+
+## 架构概览
+
+Agentic rollout 管线由四个核心组件组成:
+
+- **Tool** —— 实现特定能力(搜索、代码执行等)
+- **ToolManager** —— 注册工具并分发 LLM 工具调用
+- **Env**(可选)—— RL 环境,通过 `EnvTool` 暴露工具
+- **Rollout** —— 驱动多轮对话循环
+
+## 快速开始:基于 API 的 Rollout
+
+使用 OpenAI 兼容 API 运行多轮工具使用 rollout 的最简方式:
+
+```python
+from twinkle_agentic.protocol.openai import OpenAI
+from twinkle_agentic.tools.base import Tool
+from twinkle_agentic.tools.tool_manager import ToolManager
+from twinkle_agentic.rollout.api_multi_turn import APIMultiTurnRollout
+from twinkle.data_format.sampling import SamplingParams
+
+# 1. 定义工具
+class WeatherTool(Tool):
+ def __call__(self, tool_name, arguments):
+ city = arguments.get('city', '未知')
+ return f'{city}的天气:晴,25°C。'
+
+ def tool_info(self):
+ return {
+ 'type': 'function',
+ 'function': {
+ 'name': 'get_weather',
+ 'description': '获取城市的当前天气。',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'city': {'type': 'string', 'description': '城市名称。'},
+ },
+ 'required': ['city'],
+ },
+ },
+ }
+
+# 2. 设置 ToolManager
+manager = ToolManager([WeatherTool()])
+
+# 3. 创建 API 客户端
+api = OpenAI(model='qwen3.5-32b', base_url='http://localhost:8000/v1')
+
+# 4. 创建 rollout
+rollout = APIMultiTurnRollout(
+ api=api,
+ tool_manager=manager,
+ sampling_params=SamplingParams(temperature=0.7, max_tokens=2048),
+ max_turns=6,
+ concurrency=8,
+)
+
+# 5. 准备轨迹
+trajectories = [
+ {
+ 'messages': [
+ {'role': 'user', 'content': '北京今天天气怎么样?'},
+ ],
+ },
+]
+
+# 6. 运行 rollout
+results = rollout(trajectories)
+for r in results:
+ print(f"轮次: {r['turns']}, 停止原因: {r['stop_reason']}")
+ for msg in r['messages']:
+ print(f" [{msg['role']}] {msg.get('content', '')[:100]}")
+```
+
+## 训练集成:基于 vLLM 的 Rollout
+
+用于 RLHF 训练时,使用 `MultiTurnRollout`,它会生成 `input_ids` 和 `labels`:
+
+```python
+from twinkle_agentic.rollout.multi_turn import MultiTurnRollout
+from twinkle.data_format.sampling import SamplingParams
+
+rollout = MultiTurnRollout(
+ sampler=vllm_sampler, # vLLMSampler 实例
+ template=template, # 聊天模板
+ tool_manager=manager,
+ sampling_params=SamplingParams(temperature=0.7, max_tokens=4096),
+ max_turns=6,
+ max_trajectory_tokens=8192,
+ trace_dir='rollout_traces/',
+)
+
+# 在 GRPO 训练循环中
+results = rollout(batch_trajectories)
+# results 包含 input_ids、labels、logprobs 用于训练
+```
+
+## 将环境用作工具
+
+将 RL 环境桥接到工具管线中:
+
+```python
+from twinkle_agentic.envs.base import Env, StepResult
+from twinkle_agentic.envs.env_tool import EnvTool
+from twinkle_agentic.tools.tool_manager import ToolManager
+
+# 定义环境
+class CodeEnv(Env):
+ def step(self, tool_name, arguments):
+ code = arguments.get('code', '')
+ # 在沙箱中执行代码
+ result = execute_in_sandbox(code)
+ return StepResult(observation=result, reward=1.0, done=False)
+
+ def tools(self):
+ return [{
+ 'type': 'function',
+ 'function': {
+ 'name': 'run_python',
+ 'description': '执行 Python 代码。',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'code': {'type': 'string'},
+ },
+ 'required': ['code'],
+ },
+ },
+ }]
+
+# 桥接 Env -> Tool -> ToolManager
+env = CodeEnv()
+env_tools = EnvTool.from_env(env)
+manager = ToolManager(env_tools)
+
+# 照常在 rollout 中使用 manager
+rollout = APIMultiTurnRollout(api=api, tool_manager=manager, max_turns=10)
+```
+
+## 使用 OpenEnv 环境
+
+连接远程 OpenEnv WebSocket 服务器:
+
+```python
+from twinkle_agentic.envs.openenv import OpenEnv
+from twinkle_agentic.envs.env_tool import EnvTool
+
+env = OpenEnv(
+ base_url='http://localhost:8000',
+ env_cls='coding_env.CodingEnv',
+ tool_schema=[{
+ 'type': 'function',
+ 'function': {
+ 'name': 'submit',
+ 'description': '提交代码解决方案。',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'code': {'type': 'string'},
+ },
+ },
+ },
+ }],
+)
+
+env.reset()
+env_tools = EnvTool.from_env(env)
+manager = ToolManager(env_tools)
+```
+
+## 每轨迹独立 ToolManager
+
+当每个轨迹需要独立工具集时(例如,轨迹绑定的状态):
+
+```python
+# 创建每轨迹的 manager
+managers = []
+for traj in trajectories:
+ env = create_env_for(traj)
+ env_tools = EnvTool.from_env(env)
+ managers.append(ToolManager(env_tools))
+
+# 传入列表(与轨迹 1:1 对齐)
+results = rollout(trajectories, tool_manager=managers)
+```
+
+## 跟踪调试
+
+两种 rollout 实现都支持跟踪文件输出用于调试:
+
+```python
+rollout = APIMultiTurnRollout(
+ api=api,
+ tool_manager=manager,
+ trace_dir='traces/',
+ trace_callback=lambda t: t['turns'] > 1, # 仅存储多轮对话
+ success_callback=lambda t: t.get('stop_reason') == 'stop',
+)
+```
+
+跟踪文件以 `{step}-{ok|fail}-{id}.json` 格式保存,包含完整对话和元数据。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/Preprocessor.md" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Preprocessor.md"
new file mode 100644
index 000000000..a5730abc8
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Preprocessor.md"
@@ -0,0 +1,189 @@
+# Agentic 预处理器
+
+Agentic 预处理器模块提供了基于流水线的多轮对话数据质量过滤框架,用于 RLHF / Agentic 微调之前的训练数据清洗和过滤。
+
+## QualityPreprocessor
+
+`QualityPreprocessor` 是一个轻量级流水线运行器,接受过滤器列表并按顺序执行。每个步骤接收行列表,返回 `(kept, dropped)`,流水线会记录每步统计信息。
+
+```python
+from twinkle_agentic.preprocessor import QualityPreprocessor, HardFilter, DeadLoopFilter
+
+pipeline = [
+ HardFilter(min_user_chars=10),
+ DeadLoopFilter(),
+]
+preprocessor = QualityPreprocessor(pipeline, dropped_log_path='dropped.jsonl')
+
+# rows 是列格式的字典(Dataset.map 格式)
+cleaned = preprocessor(rows)
+```
+
+### 参数
+
+| 参数 | 类型 | 说明 |
+|------|------|------|
+| `pipeline` | `List[Callable]` | 有序的过滤步骤列表。每个步骤接收 `List[Dict]`,返回 `(kept, dropped)`。 |
+| `dropped_log_path` | `str` | 可选的 JSONL 文件路径,用于记录被丢弃的行及步骤名称和原因。 |
+
+## 内置过滤器
+
+### HardFilter
+
+基于硬规则的过滤器,使用确定性规则移除质量差的行。支持多语言检测(EN/ZH/JA/KO)。
+
+```python
+from twinkle_agentic.preprocessor import HardFilter
+
+f = HardFilter(
+ min_user_chars=10, # 非 CJK 用户查询最小字符数
+ min_user_chars_cjk=6, # CJK 用户查询最小字符数
+ min_assistant_chars_2turn=80, # 两轮对话中助手回复最小长度
+ min_thinking_chars=200, # 思考链最小长度(可豁免过滤)
+ system_deny_keywords=['hack', 'exploit'],
+ max_chars_per_round=50000,
+ max_total_chars=200000,
+ max_rounds=50,
+)
+```
+
+**丢弃原因:** `trivial_single_turn`(平凡单轮)、`shallow_reply`(浅回复)、`all_empty_assistant`(全空助手)、`system_deny_keyword`(系统拒绝关键词)、`round_too_long`(单轮过长)、`total_too_long`(总长过长)、`too_many_rounds`(轮次过多)
+
+### DeadLoopFilter
+
+检测助手消息中的犹豫/死循环模式——重复自我纠正、级联纠正和高 n-gram 重复。
+
+```python
+from twinkle_agentic.preprocessor import DeadLoopFilter
+
+f = DeadLoopFilter(
+ hesitation_density_threshold=7.0, # 每 1000 字符犹豫标记数(响应)
+ cascade_threshold=5, # 窗口内级联标记数
+ cascade_window=800, # 窗口大小(字符)
+ repetition_threshold=0.45, # N-gram 重复率
+ think_hesitation_density_threshold=15.0, # 块更宽松
+ think_repetition_threshold=0.65,
+)
+```
+
+对 `<think>` 推理块使用更宽松的阈值(允许自由发散),对可见响应使用更严格的阈值。
+
+### DedupFilter
+
+全局最长优先去重。签名由第一个真实用户轮次(首尾)和第一个助手回复推导。
+
+```python
+from twinkle_agentic.preprocessor import DedupFilter
+
+f = DedupFilter(prefix_chars=100, asst_chars=100)
+kept, dropped = f(all_rows) # 必须在一次调用中传入整个数据集
+```
+
+> **注意:** `DedupFilter` 需要在单次调用中接收完整数据集。**不要**将它放入 `QualityPreprocessor` 中(后者按批处理)。请在流水线之前或之后单独运行。
+
+### RefuseFilter
+
+检测第一条助手回复中的自我引用式拒绝(如"我无法帮助您")。多语言模式匹配(EN/ZH/JA/KO)。
+
+```python
+from twinkle_agentic.preprocessor import RefuseFilter
+
+f = RefuseFilter(check_window=600) # 仅检查前 N 个字符
+```
+
+### TokenSoupFilter
+
+检测乱码/token-soup 输出,检查替换字符、控制字符、私用区 Unicode、泄漏的特殊 token、单字符重复和脚本混乱。
+
+```python
+from twinkle_agentic.preprocessor import TokenSoupFilter
+
+f = TokenSoupFilter(
+ replacement_char_ratio=0.02,
+ special_token_count=20,
+ script_chaos_threshold=0.55,
+)
+```
+
+### PIIPresidioFilter
+
+基于 Microsoft Presidio + spaCy NER + Faker 的多语言 PII 检测和重写。检测并替换个人身份信息(姓名、邮箱、电话号码、地址等)。
+
+```python
+from twinkle_agentic.preprocessor import PIIPresidioFilter
+
+f = PIIPresidioFilter(languages=['en', 'zh'])
+```
+
+### IntentClassifier
+
+启发式意图分类器,为每行标注检测到的意图。可插拔的检测器管线。
+
+```python
+from twinkle_agentic.preprocessor import IntentClassifier
+
+classifier = IntentClassifier()
+```
+
+**意图类别:** `tool_call`(工具调用)、`code`(代码)、`math`(数学)、`complex_logic`(复杂逻辑)、`reasoning`(推理)、`user_dissatisfaction`(用户不满)、`other`(其他)
+
+### ScoreFilter
+
+可插拔评分器过滤器,内置字符级指标、语义相似度和代码执行评分器。
+
+```python
+from twinkle_agentic.preprocessor import ScoreFilter
+
+f = ScoreFilter()
+```
+
+**内置评分器:** `ChrMinScorer`、`SIFDScorer`、`PassNScorer`、`ParaphraseScorer`
+
+### ModelFilter
+
+按模型 ID 白名单过滤行。
+
+```python
+from twinkle_agentic.preprocessor import ModelFilter
+
+f = ModelFilter(allowed_models=['qwen3.5-4b', 'qwen3.5-32b'])
+```
+
+### MessageNormalizer
+
+三遍消息规范化:心跳剥离、工具调用重写、连续同角色消息合并。
+
+```python
+from twinkle_agentic.preprocessor import MessageNormalizer
+
+normalizer = MessageNormalizer()
+```
+
+## 完整流水线示例
+
+```python
+from twinkle_agentic.preprocessor import (
+ QualityPreprocessor,
+ HardFilter,
+ DeadLoopFilter,
+ RefuseFilter,
+ TokenSoupFilter,
+ MessageNormalizer,
+ DedupFilter,
+)
+
+# 第一步:全局去重(必须在完整数据集上运行)
+dedup = DedupFilter()
+rows, _ = dedup(all_rows)
+
+# 第二步:按批流水线
+pipeline = [
+ HardFilter(min_user_chars=10, max_rounds=30),
+ DeadLoopFilter(),
+ RefuseFilter(),
+ TokenSoupFilter(),
+ MessageNormalizer(),
+]
+preprocessor = QualityPreprocessor(pipeline, dropped_log_path='dropped.jsonl')
+cleaned = preprocessor(rows)
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/Protocol.md" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Protocol.md"
new file mode 100644
index 000000000..1e03092d6
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Protocol.md"
@@ -0,0 +1,91 @@
+# 协议(Protocol)
+
+Protocol 模块提供了抽象的 LLM API 客户端接口及其 OpenAI 兼容实现。它将 Twinkle 的 `Trajectory` / `SamplingParams` 数据类型与外部 LLM 推理服务连接起来。
+
+## API 基类
+
+```python
+from abc import ABC, abstractmethod
+from twinkle.data_format import Trajectory
+from twinkle.data_format.message import Message
+from twinkle.data_format.sampling import SamplingParams
+
+class API(ABC):
+ """抽象 LLM API 客户端:Trajectory + SamplingParams -> 助手 Message"""
+
+ @abstractmethod
+ def __call__(
+ self,
+ trajectory: Trajectory,
+ sampling_params: SamplingParams,
+ **kwargs,
+ ) -> Union[Message, List[Message]]:
+ raise NotImplementedError()
+```
+
+`API` 类定义了一个简单的契约:给定对话轨迹和采样参数,返回一条或多条助手消息。
+
+## OpenAI
+
+`OpenAI` 是内置实现,兼容任何支持 `/v1/chat/completions` 协议的端点(OpenAI、Azure OpenAI、vLLM、SGLang、Ollama 等)。
+
+```python
+from twinkle_agentic.protocol.openai import OpenAI
+
+api = OpenAI(
+ model='qwen3.5-32b',
+ base_url='http://localhost:8000/v1',
+ api_key='EMPTY',
+)
+```
+
+### 参数
+
+| 参数 | 类型 | 说明 |
+|------|------|------|
+| `model` | `str` | API 请求中传递的模型名称。 |
+| `api_key` | `str` | API 密钥。默认使用 `OPENAI_API_KEY` 环境变量。 |
+| `base_url` | `str` | API 端点的基础 URL(如 `http://localhost:8000/v1`)。 |
+| `client_kwargs` | `Dict` | 转发给 `openai.OpenAI` 客户端构造函数的额外关键字参数。 |
+
+### 使用方法
+
+```python
+from twinkle.data_format import Trajectory
+from twinkle.data_format.sampling import SamplingParams
+
+trajectory = {
+ 'messages': [
+ {'role': 'user', 'content': '法国的首都是什么?'},
+ ]
+}
+
+sp = SamplingParams(temperature=0.7, max_tokens=512)
+reply = api(trajectory, sp)
+# reply 是一个 Message 字典:{'role': 'assistant', 'content': '...'}
+```
+
+### 特性
+
+- **工具调用**:自动将 `trajectory['tools']` 映射到 API 请求,并解析响应中的结构化 `tool_calls`。
+- **推理内容**:保留支持推理的模型返回的 `reasoning_content`(如 o1 风格推理)。
+- **完成原因**:在返回消息中暴露 `finish_reason`,供多轮驱动器检测长度截断。
+- **多样本**:当 `sampling_params.num_samples > 1` 时,返回消息列表(每个 choice 一条)。
+
+### 自定义 API 客户端
+
+要集成非 OpenAI API,请继承 `API`:
+
+```python
+from twinkle_agentic.protocol.base import API
+
+class MyCustomAPI(API):
+
+ def __call__(self, trajectory, sampling_params, **kwargs):
+ # 调用自定义端点
+ response = my_llm_client.chat(
+ messages=trajectory['messages'],
+ temperature=sampling_params.temperature,
+ )
+ return {'role': 'assistant', 'content': response.text}
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/Rollout.md" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Rollout.md"
new file mode 100644
index 000000000..b74c1e791
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Rollout.md"
@@ -0,0 +1,140 @@
+# 多轮 Rollout
+
+Rollout 模块提供了用于 Agentic RLHF 训练的多轮对话 rollout 引擎。包含两种实现:用于批量 vLLM 采样的 `MultiTurnRollout` 和用于 OpenAI 兼容 API 端点的 `APIMultiTurnRollout`。
+
+## Rollout 基类
+
+```python
+from abc import ABC, abstractmethod
+from twinkle.data_format import Trajectory
+
+class Rollout(ABC):
+
+ @abstractmethod
+ def __call__(self, trajectories: List[Trajectory], **kwargs) -> List[Trajectory]:
+ raise NotImplementedError()
+```
+
+所有 rollout 接受轨迹列表并返回相同数量的轨迹,附带额外字段(`messages`、`turns`、`stop_reason`、`truncated`)。
+
+## MultiTurnRollout
+
+批量多轮 rollout 引擎,使用 vLLM 采样器进行生成。每轮中所有活跃轨迹通过单次批量采样调用并行处理,最大化吞吐量。
+
+### 每轮循环
+
+1. 将每个轨迹编码为带生成提示的 `InputFeature`
+2. 批量调用 `sampler.sample(active_pifs)` —— 所有活跃轨迹并行
+3. 检查终止条件:`stop_reason == 'length'`、无工具调用、或达到最大轮次
+4. 通过 `ToolManager` 分发工具调用,追加工具响应
+5. 计算桥接 token(工具轮次 + 生成提示),设置 `labels = -100`
+6. 重复直到所有轨迹完成
+
+```python
+from twinkle_agentic.rollout.multi_turn import MultiTurnRollout
+from twinkle_agentic.tools.tool_manager import ToolManager
+from twinkle.data_format.sampling import SamplingParams
+
+rollout = MultiTurnRollout(
+ sampler=vllm_sampler,
+ template=template,
+ tool_manager=tool_manager,
+ sampling_params=SamplingParams(temperature=0.7, max_tokens=4096),
+ max_turns=6,
+ max_trajectory_tokens=8192,
+ trace_dir='rollout_traces/',
+)
+
+# 运行 rollout
+results = rollout(trajectories)
+```
+
+### 参数
+
+| 参数 | 类型 | 说明 |
+|------|------|------|
+| `sampler` | Sampler | 用于批量生成的 vLLM 采样器实例。 |
+| `template` | `Template` | 用于编码/解码的聊天模板。 |
+| `tool_manager` | `ToolManager` | 工具分发器。也可以按调用传入。 |
+| `sampling_params` | `SamplingParams` | 默认采样参数。 |
+| `max_turns` | `int` | 每个轨迹的最大轮次(默认:6)。 |
+| `max_trajectory_tokens` | `int` | 最大总 token 长度;超出则截断轨迹。 |
+| `trace_dir` | `str` | 每轨迹 JSON 跟踪文件的目录。 |
+| `trace_callback` | `Callable` | 决定是否存储轨迹跟踪。 |
+| `success_callback` | `Callable` | 决定文件名前缀(`ok-` 或 `fail-`)。 |
+
+### 输出字段
+
+每个输出轨迹字典包含:
+
+| 字段 | 类型 | 说明 |
+|------|------|------|
+| `messages` | `List[Dict]` | 包含工具轮次的完整对话。 |
+| `input_ids` | `List[int]` | 完整序列的 token ID。 |
+| `labels` | `List[int]` | 训练标签(非可训练 token 为 `-100`)。 |
+| `turns` | `int` | 执行的轮次数。 |
+| `stop_reason` | `str` | `'stop'` / `'length'` |
+| `truncated` | `bool` | 轨迹是否被截断。 |
+| `logprobs` | `List` | 每 token 的对数概率(如有)。 |
+
+### Ray 远程支持
+
+`MultiTurnRollout` 使用 `@remote_class()` 装饰器,支持作为 Ray actor 透明部署:
+
+```python
+# rollout 可以作为 Ray 远程 actor 运行
+rollout_actor = MultiTurnRollout.remote(sampler=sampler, template=template, ...)
+results = ray.get(rollout_actor.__call__.remote(trajectories))
+```
+
+## APIMultiTurnRollout
+
+通过 OpenAI 兼容 chat-completions API 进行多轮 rollout。每个轨迹在线程池中独立运行,实现网络并发。
+
+```python
+from twinkle_agentic.rollout.api_multi_turn import APIMultiTurnRollout
+from twinkle_agentic.protocol.openai import OpenAI
+
+api = OpenAI(model='qwen3.5-32b', base_url='http://localhost:8000/v1')
+
+rollout = APIMultiTurnRollout(
+ api=api,
+ tool_manager=tool_manager,
+ sampling_params=SamplingParams(temperature=0.7),
+ max_turns=6,
+ concurrency=8,
+ trace_dir='api_traces/',
+)
+
+results = rollout(trajectories)
+```
+
+### 参数
+
+| 参数 | 类型 | 说明 |
+|------|------|------|
+| `api` | `OpenAI` | OpenAI 兼容 API 客户端。 |
+| `tool_manager` | `ToolManager` | 工具分发器(单个或按轨迹的列表)。 |
+| `sampling_params` | `SamplingParams` | 默认采样参数。 |
+| `max_turns` | `int` | 每轨迹最大轮次(默认:6)。 |
+| `concurrency` | `int` | 并行 API 调用的线程池大小(默认:8)。 |
+| `extra_body` | `Dict` | API 请求中附加的额外字段。 |
+| `trace_dir` | `str` | 跟踪文件目录。 |
+
+### 停止原因
+
+| 原因 | 说明 |
+|------|------|
+| `stop` | 助手回复未包含工具调用(自然结束)。 |
+| `length` | API 返回 `finish_reason='length'`(token 限制)。 |
+| `max_turns` | 达到 `max_turns` 限制。 |
+| `api_error` | API 调用或工具执行抛出异常。 |
+
+## 选择建议
+
+| 特性 | MultiTurnRollout | APIMultiTurnRollout |
+|------|-----------------|---------------------|
+| **后端** | vLLM 采样器(本地 GPU) | OpenAI 兼容 API |
+| **训练集成** | 生成 `input_ids` / `labels` 用于 GRPO | 仅消息(用于数据收集) |
+| **批处理** | GPU 级别批量并行 | 网络级别线程并发 |
+| **用例** | 在线 RLHF 训练循环 | 离线数据生成 / 评估 |
diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/Tools.md" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Tools.md"
new file mode 100644
index 000000000..122b75a14
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/Tools.md"
@@ -0,0 +1,119 @@
+# 工具与 ToolManager
+
+Tools 模块提供了抽象工具接口和中央工具分发器(`ToolManager`),用于 Agentic 多轮 rollout。工具遵循 OpenAI function-calling schema,与 LLM 工具调用能力无缝集成。
+
+## Tool 基类
+
+```python
+from abc import ABC, abstractmethod
+from twinkle.data_format import Tool as ToolInfo
+
+class Tool(ABC):
+
+ @abstractmethod
+ def __call__(self, tool_name: str, arguments: Dict[str, Any]) -> str:
+ """执行工具并返回字符串结果。"""
+ raise NotImplementedError
+
+ @abstractmethod
+ def tool_info(self) -> ToolInfo:
+ """返回 OpenAI 兼容的工具 schema。"""
+ raise NotImplementedError
+```
+
+### 实现自定义工具
+
+```python
+from twinkle_agentic.tools.base import Tool
+
+class SearchTool(Tool):
+
+ def __call__(self, tool_name: str, arguments: dict) -> str:
+ query = arguments.get('query', '')
+ # 执行搜索逻辑
+ return f'搜索结果:{query}'
+
+ def tool_info(self):
+ return {
+ 'type': 'function',
+ 'function': {
+ 'name': 'search',
+ 'description': '搜索网络信息。',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'query': {
+ 'type': 'string',
+ 'description': '搜索查询。',
+ },
+ },
+ 'required': ['query'],
+ },
+ },
+ }
+```
+
+## ToolManager
+
+`ToolManager` 是工具的注册中心和分发器。它解析 LLM 结构化输出中的工具调用,并路由到正确的工具实现。
+
+```python
+from twinkle_agentic.tools.tool_manager import ToolManager
+
+# 通过 Tool 实例列表初始化
+manager = ToolManager([search_tool, calculator_tool])
+
+# 或通过字典初始化
+manager = ToolManager({'search': search_tool, 'calc': calculator_tool})
+
+# 或动态注册
+manager = ToolManager()
+manager.register(search_tool)
+manager.register(calculator_tool)
+```
+
+### 核心方法
+
+| 方法 | 说明 |
+|------|------|
+| `register(tool)` | 注册工具(名称从 `tool_info()` 提取)。 |
+| `unregister(name)` | 按名称移除工具。 |
+| `names()` | 列出所有已注册的工具名称。 |
+| `copy()` | 创建管理器的浅拷贝。 |
+| `tool_infos()` | 返回所有工具 schema 列表(用于 API 请求)。 |
+| `__call__(tool_call)` | 分发工具调用并返回结果字符串。 |
+
+### 分发工具调用
+
+`ToolManager` 接受 OpenAI 格式的工具调用字典:
+
+```python
+tool_call = {
+ 'id': 'call_1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'search',
+ 'arguments': '{"query": "Python 教程"}',
+ },
+}
+
+result = manager(tool_call)
+# result: '搜索结果:Python 教程'
+```
+
+**错误处理:** 如果工具名未知、参数是无效 JSON 或工具抛出异常,`ToolManager` 返回描述性错误字符串而不是抛出异常——这保证了 rollout 循环的持续运行。
+
+### 与 Rollout 集成
+
+```python
+from twinkle_agentic.rollout.multi_turn import MultiTurnRollout
+
+rollout = MultiTurnRollout(
+ sampler=sampler,
+ template=template,
+ tool_manager=manager, # 传入工具管理器
+ max_turns=6,
+)
+```
+
+Rollout 引擎对模型生成的每个工具调用执行 `manager(tool_call)`,并将结果作为 `{'role': 'tool', 'content': result}` 消息追加。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/Agentic/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/Agentic/index.rst"
new file mode 100644
index 000000000..802034366
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/Agentic/index.rst"
@@ -0,0 +1,11 @@
+Agentic
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Preprocessor.md
+ Protocol.md
+ Rollout.md
+ Tools.md
+ Envs.md
+ Multi-Turn-Tool-Usage.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/CLI/CLI.md" "b/docs/source_zh/\347\273\204\344\273\266/CLI/CLI.md"
new file mode 100644
index 000000000..e21c5ea44
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/CLI/CLI.md"
@@ -0,0 +1,134 @@
+# CLI 命令行配置
+
+CLI 模块为 Twinkle 训练脚本提供统一的配置系统。它将多种配置来源(环境变量、`.env` 文件、YAML 配置、命令行参数)合并到一个带类型的 `Args` 数据类中。
+
+## 配置优先级
+
+配置按以下顺序应用(后者覆盖前者):
+
+1. **数据类默认值** — 开箱即用
+2. **`.env` 文件** — 项目本地配置
+3. **环境变量** — `TWINKLE_` 前缀或裸键名
+4. **YAML 配置文件** — `--config path/to/config.yaml`
+5. **命令行参数** — `--key value`(最高优先级)
+
+所有键名不区分大小写,横杠和下划线等价。
+
+## 快速开始
+
+```python
+from twinkle.cli import CLI
+
+args = CLI.from_args()
+
+# 访问类型化的参数组
+print(args.model.model_id)
+print(args.training.max_steps)
+print(args.optimizer.learning_rate)
+
+# 或获取字典用于组件构造
+model_kwargs = args.get_model_args()
+optimizer_kwargs = args.get_optimizer_args()
+```
+
+## 参数组
+
+| 分组 | 类名 | 关键参数 |
+|:-----|:-----|:---------|
+| model | `ModelArgs` | `model_id`, `mixed_precision`, `strategy`, `gradient_checkpointing` |
+| lora | `LoraArgs` | `use_lora`, `lora_r`, `lora_alpha`, `lora_target_modules` |
+| dataset | `DatasetArgs` | `dataset_id`, `subset_name`, `split`, `streaming` |
+| template | `TemplateArgs` | `template_cls`, `max_length`, `truncation_strategy`, `enable_thinking` |
+| training | `TrainingArgs` | `max_steps`, `batch_size`, `micro_batch_size`, `output_dir`, `save_steps` |
+| optimizer | `OptimizerArgs` | `optimizer_cls`, `learning_rate`, `weight_decay`, `max_grad_norm` |
+| scheduler | `SchedulerArgs` | `scheduler_cls`, `num_warmup_steps`, `t_max` |
+| loss | `LossArgs` | `loss_cls`, `epsilon`, `beta`, `sft_weight` |
+| sampler | `SamplerArgs` | `sampler_type`, `gpu_memory_utilization`, `tensor_parallel_size` |
+| sampling | `SamplingArgs` | `max_tokens`, `temperature`, `top_k`, `top_p`, `num_samples` |
+| infra | `InfraArgs` | `mode`, `nproc_per_node`, `model_gpus`, `sampler_gpus`, `dp_size` |
+| server | `ServerArgs` | `config`, `host`, `port`, `ray_namespace` |
+| rl | `RLArgs` | `num_generations`, `advantage_type`, `reward_fns` |
+| checkpoint | `CheckpointArgs` | `save_optimizer`, `merge_and_sync`, `platform` |
+
+## YAML 配置示例
+
+```yaml
+# config.yaml
+model_id: ms://Qwen/Qwen3.5-4B
+mixed_precision: bf16
+strategy: accelerate
+
+use_lora: true
+lora_r: 16
+lora_alpha: 32
+
+dataset_id: ms://swift/self-cognition
+max_length: 4096
+
+batch_size: 8
+micro_batch_size: 2
+max_steps: 200
+learning_rate: 1e-5
+
+mode: ray
+nproc_per_node: 8
+model_gpus: 4
+sampler_gpus: 4
+```
+
+## 命令行用法
+
+```bash
+# 使用 YAML 配置
+python train.py --config config.yaml
+
+# 覆盖特定值
+python train.py --config config.yaml --learning_rate 5e-6 --max_steps 500
+
+# 布尔标志
+python train.py --use_lora --no_gradient_checkpointing
+
+# 无配置文件(全部从命令行指定)
+python train.py --model_id ms://Qwen/Qwen3.5-4B --batch_size 4
+```
+
+## 环境变量
+
+```bash
+# TWINKLE_ 前缀
+export TWINKLE_MODEL_ID=ms://Qwen/Qwen3.5-4B
+export TWINKLE_LEARNING_RATE=1e-5
+
+# 或裸键名(当能识别时)
+export MODEL_ID=ms://Qwen/Qwen3.5-4B
+```
+
+## 字段别名
+
+部分字段支持别名:
+
+- `learning_rate` ↔ `lr`
+- `nproc_per_node` ↔ `num_gpus`
+- `max_tokens` ↔ `max_new_tokens`
+- `use_megatron=true` → `strategy=native_fsdp`
+
+## 自定义配置源
+
+你可以通过自定义配置源扩展 CLI:
+
+```python
+from twinkle.cli.cli import ConfigSource, Args, ConfigResolver
+
+class RemoteConfigSource(ConfigSource):
+ def __init__(self, url: str):
+ self.url = url
+
+ def load(self) -> dict:
+ import requests
+ return requests.get(self.url).json()
+
+# 应用自定义配置源
+args = Args()
+resolver = ConfigResolver(args)
+resolver.apply(RemoteConfigSource('http://config-server/my-config').load())
+```
diff --git a/docs/source_en/Components/Gym/index.rst "b/docs/source_zh/\347\273\204\344\273\266/CLI/index.rst"
similarity index 76%
rename from docs/source_en/Components/Gym/index.rst
rename to "docs/source_zh/\347\273\204\344\273\266/CLI/index.rst"
index 85d941b97..cf59fa766 100644
--- a/docs/source_en/Components/Gym/index.rst
+++ "b/docs/source_zh/\347\273\204\344\273\266/CLI/index.rst"
@@ -1,6 +1,6 @@
-Gym
+CLI
===============
.. toctree::
:maxdepth: 1
- Gym.md
+ CLI.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" "b/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md"
deleted file mode 100644
index 63dc87aa7..000000000
--- "a/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md"
+++ /dev/null
@@ -1,26 +0,0 @@
-# Gym
-
-Gym 组件为 Twinkle 中的强化学习环境提供接口。
-
-```python
-from twinkle.gym import Gym
-
-class CustomGym(Gym):
-
- def step(self, trajectories, **kwargs):
- """
- 执行一个 RL 步骤:评估轨迹并返回奖励。
-
- Args:
- trajectories: 模型生成的待评估轨迹
- **kwargs: 额外参数
-
- Returns:
- 每个轨迹的奖励值
- """
- ...
-```
-
-Gym 抽象允许你插入自定义 RL 环境与训练循环交互。它将奖励计算和环境交互与核心训练逻辑解耦。
-
-> Gym 通常用于在线策略 RL 训练中,环境需要对模型生成的输出提供反馈。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/TUI/Auto-Research.md" "b/docs/source_zh/\347\273\204\344\273\266/TUI/Auto-Research.md"
new file mode 100644
index 000000000..9624f5ae7
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/TUI/Auto-Research.md"
@@ -0,0 +1,313 @@
+# Auto-Research (TUI)
+
+Twinkle TUI 是一个基于终端的智能训练助手,支持通过**自然语言控制、监控和调试 ML 训练**。它将聊天驱动的 AI 代理与实时指标可视化、日志流、以及自动化健康监控器相结合,能够自主检测并修复训练故障。
+
+## 架构概览
+
+```
+┌──────────────────────────────────────────────────────────┐
+│ TwinkleTUI (Textual 应用) │
+│ ┌──────────────────────────────────────────────────────┐ │
+│ │ StatusBar: 状态 / run_id / 模型 / step / 进度条 │ │
+│ ├──────────────────────┬───────────────────────────────┤ │
+│ │ MetricsPanel │ LogPanel │ │
+│ │ (ASCII 图表) │ (滚动日志) │ │
+│ ├──────────────────────┤ │ │
+│ │ ChatPanel │ │ │
+│ │ (用户 <-> 代理) │ │ │
+│ └──────────────────────┴───────────────────────────────┘ │
+│ │
+│ 后台服务: │
+│ AgentLoop ─── LLM 工具调用循环 │
+│ TrainingMonitor ─── 定期健康检查与自动修复 │
+│ MetricsPoller ─── 增量指标读取 │
+│ LogsPoller ─── 增量日志尾读 │
+│ SkillsLoader ─── 异步插件加载 │
+└──────────────────────────────────────────────────────────┘
+```
+
+## 安装与启动
+
+TUI 是 `twinkle-client` 包的一部分:
+
+```bash
+pip install twinkle-client
+```
+
+### 命令行用法
+
+```bash
+# 基本启动(使用默认本地 Ollama 端点)
+twinkle-tui
+
+# 指定 LLM 后端
+twinkle-tui --llm-base-url http://localhost:11434/v1 --llm-model qwen3.5
+
+# 连接到已有训练运行
+twinkle-tui --run-id my-grpo-run
+
+# 使用远程 API(如 OpenAI 兼容接口)
+twinkle-tui --llm-base-url https://api.example.com/v1 --llm-api-key sk-xxx --llm-model gpt-4o
+
+# 启用调试日志
+twinkle-tui --verbose
+```
+
+也可作为 Python 模块运行:
+
+```bash
+python -m twinkle_client.tui
+```
+
+### CLI 参数
+
+| 参数 | 环境变量 | 默认值 | 说明 |
+|------|---------|--------|------|
+| `--run-id`, `-r` | `TWINKLE_TUI_RUN_ID` | None | 连接到已有训练运行 |
+| `--llm-base-url` | `TWINKLE_LLM_BASE_URL` | `http://localhost:11434/v1` | LLM API 基础 URL |
+| `--llm-model` | `TWINKLE_LLM_MODEL` | `qwen3.5` | LLM 模型名称 |
+| `--llm-api-key` | `TWINKLE_LLM_API_KEY` | `not-needed` | LLM API 密钥 |
+| `--verbose`, `-v` | `TWINKLE_TUI_VERBOSE` | `False` | 启用 DEBUG 日志 |
+| `--version`, `-V` | — | — | 显示版本并退出 |
+
+### 快捷键
+
+| 按键 | 操作 |
+|------|------|
+| `q` | 退出 |
+| `Ctrl+P` | 切换指标面板 |
+| `Ctrl+L` | 清空日志 |
+
+## 聊天代理
+
+TUI 的核心是一个 **LLM 驱动的工具调用代理**(`AgentLoop`),通过 OpenAI 兼容 API 处理自然语言命令。代理维护对话历史并自动修剪(保留最近 50 条消息),每次交互最多支持 10 轮工具调用。
+
+### 你可以这样说
+
+**训练生命周期:**
+- *"列出我的训练运行"*
+- *"用 Qwen3.5-4B 在 gsm8k 上启动一个新的 GRPO 训练"*
+- *"暂停当前运行"*
+- *"恢复训练"*
+- *"停止训练"*
+
+**服务器管理:**
+- *"启动服务器,使用 Qwen3.5-4B 和一个 2 卡的 Qwen3.5-72B 采样器"*
+- *"关闭服务器"*
+- *"有多少 GPU 可用?"*
+
+**监控与分析:**
+- *"训练进展如何?"*
+- *"显示 reward 相关的指标"*
+- *"放大到 step 100-200"*
+- *"重置图表视图"*
+
+**搜索:**
+- *"搜索数学数据集"*
+- *"在 ModelScope 上查找 Qwen 模型"*
+
+### 可用工具
+
+代理内置 16 个工具:
+
+| 工具 | 说明 |
+|------|------|
+| `list_training_runs` | 列出所有训练运行 |
+| `get_training_status` | 获取详细状态和最近指标 |
+| `start_server` | 启动 Ray 集群 + Twinkle Server(幂等) |
+| `shutdown_server` | 关闭服务器并释放 GPU 资源 |
+| `start_training` | 创建并启动新的训练运行 |
+| `select_run` | 切换监控到另一个运行 |
+| `pause_training` | 暂停训练(SIGKILL,服务器保留状态) |
+| `resume_training` | 通过重新启动客户端脚本恢复训练 |
+| `stop_training` | 停止训练(SIGTERM,保存检查点) |
+| `update_script` | 更新训练脚本(带版本归档) |
+| `list_supported_models` | 查询服务器支持的模型 |
+| `search_datasets` | 在 ModelScope 搜索数据集 |
+| `search_models` | 在 ModelScope 搜索模型 |
+| `zoom_metrics` | 调整指标图表视图范围 |
+| `select_metrics` | 选择显示哪些指标(最多 4 个) |
+| `get_cluster_info` | 获取 GPU/集群资源信息 |
+
+### 服务器启动
+
+`start_server` 工具自动化一个多步骤流程:
+
+1. **GPU 检测** — `nvidia-smi` 硬件扫描
+2. **GPU 分配** — 在训练模型和采样器之间分配 GPU
+3. **配置生成** — 自动创建 `server_config.yaml`
+4. **Ray 集群启动** — 多节点 GPU 分区,隔离 `CUDA_VISIBLE_DEVICES`
+5. **服务器启动** — 作为后台进程启动 Twinkle Server
+6. **健康检查** — 轮询 `/api/v1/healthz` 直到就绪
+
+支持多模型拓扑:1 个训练模型 + N 个采样器/教师模型。
+
+### Skills 系统
+
+TUI 支持从三个来源加载可扩展的技能插件:
+
+1. **内置技能** — 包含在 `twinkle_client/skills/bundled/` 中
+2. **用户本地技能** — `~/.cache/twinkle/tui/skills/local/`
+3. **社区技能** — 从 ModelScope 获取(尽力而为,10 秒超时)
+
+技能在启动后异步加载并注入代理的系统提示词中。代理在技能加载完成前即可使用。
+
+## 训练监控器(自动修复)
+
+`TrainingMonitor` 是一个后台服务,每 **30 秒**运行一次,收集当前训练运行的所有可用信号,并提交给 LLM 进行分析。
+
+### 收集的信号
+
+- **进程状态**:alive / dead / unknown
+- **output.log 尾部**:最后 1500 个字符(优先提取 traceback)
+- **指标**:最近条目 + 前半段 vs 后半段趋势分析
+- **停滞时长**:自最后一次产生指标以来的秒数
+- **当前 train.py**:完整脚本源码(用于精确修复)
+
+### 决策框架
+
+LLM 将每次检查分类为三种操作之一:
+
+| 决策 | 触发条件 | 执行动作 |
+|------|---------|---------|
+| **LGTM** | 训练正常推进 | 无操作 |
+| **WARNING** | Loss 平台期、reward hacking、KL 爆炸等 | 向用户报告观察结果 |
+| **FIX** | 脚本崩溃、进程死亡并有 traceback | 自动修复并重启 |
+
+### 自动修复流程
+
+当需要 FIX 时:
+
+1. LLM 输出诊断 + 完整修复脚本
+2. 监控器将旧 `train.py` 归档为 `train_v{N}.py`
+3. 将修复脚本写为新的 `train.py`
+4. 通过 `resume_training` 重新启动训练
+5. 重置停滞追踪
+
+安全保障:
+- 每个运行最多 **3 次自动修复尝试**(防止无限重试循环)
+- 修复尝试按 `run_id` 追踪
+- 快照去重避免对未变化状态的重复分析
+
+## 基于文件的连接层
+
+TUI 通过本地文件系统与训练进程通信:
+
+```
+~/.cache/twinkle/{run_id}/
+├── meta.json — 运行元数据(model_id、config、status、pid)
+├── metrics.jsonl — 每步一个 JSON 对象(增量)
+├── output.log — 训练的 stdout+stderr 合并输出
+├── train.py — 当前活动训练脚本
+└── train_v{N}.py — 归档的历史脚本版本
+```
+
+### 训练控制模型
+
+在 Server 模式下,Twinkle Server 将所有模型/优化器状态保留在 GPU 内存中:
+
+- **暂停** = 杀死客户端进程 (SIGKILL) — 服务器状态保留
+- **恢复** = 重新启动客户端脚本 — 无缝继续训练
+- **停止** = SIGTERM — 触发检查点保存后退出
+- **关闭服务器** = 释放 GPU 资源,**销毁**模型状态
+
+## TrainingRuntime(脚本集成)
+
+训练脚本使用 `TrainingRuntime` 与 TUI 集成:
+
+```python
+from twinkle_client.tui.runtime import TrainingRuntime
+
+rt = TrainingRuntime(run_id='my-grpo-run')
+rt.start(model_id='Qwen/Qwen3.5-4B', config={'lr': 1e-5})
+rt.register_graceful_shutdown(model, dataloader)
+
+for step, batch in enumerate(dataloader):
+ # ... 训练逻辑 ...
+ rt.log_metrics(step=step, loss=loss, reward=reward, grad_norm=gn, lr=lr)
+ rt.log(f'Completed step {step}, loss={loss:.4f}')
+
+rt.finish()
+```
+
+### 核心方法
+
+| 方法 | 说明 |
+|------|------|
+| `start(model_id, config, script_path)` | 初始化运行目录和元数据 |
+| `log_metrics(**kwargs)` | 向 `metrics.jsonl` 写入指标条目 |
+| `log(message)` | 打印日志消息(会被捕获到 `output.log` 中) |
+| `get_resume_info()` | 获取 `last_step` 用于从检查点恢复 |
+| `finish(status)` | 标记训练完成,关闭文件 |
+| `register_graceful_shutdown(model, dataloader)` | 注册 SIGTERM 处理器以保存检查点 |
+
+### 断点续训支持
+
+`TrainingRuntime` 自动将训练进度保存到 `meta.json`(每 5 秒节流写入一次)。脚本可以使用 `get_resume_info()` 从上次保存的步数恢复:
+
+```python
+rt = TrainingRuntime(run_id='my-run')
+resume = rt.get_resume_info()
+global_step = resume['last_step']
+
+if global_step > 0:
+ dataloader.skip_consumed_samples(global_step * BATCH_SIZE)
+ print(f'从 step {global_step} 恢复训练')
+```
+
+### 优雅关停
+
+调用 `register_graceful_shutdown()` 后,会安装一个 SIGTERM 处理器:
+
+1. 保存模型检查点(LoRA 权重 + 优化器状态)
+2. 保存数据加载器位置(`consumed_train_samples`)
+3. 记录检查点路径
+4. 标记训练为 `stopped` 并退出
+
+## UI 面板
+
+### StatusBar(状态栏)
+
+显示在屏幕顶部的当前训练状态:
+
+- 训练状态图标(🚀 训练中 / ⏸ 已暂停 / ✅ 已完成 / ❌ 错误)
+- Run ID
+- 模型名称
+- 当前步数
+- 百分比进度条
+
+### MetricsPanel(指标面板)
+
+使用 `plotext` 渲染的实时 ASCII 图表:
+
+- 同时绘制最多 4 个指标
+- 支持缩放(按步数范围和 y 轴范围)
+- 未选择时自动显示前 3 个可用指标
+- 提示栏显示可通过代理切换的隐藏指标
+- 保留最多 2000 个数据点
+
+### LogPanel(日志面板)
+
+滚动日志查看器:
+
+- 自动剥离 ANSI 转义序列
+- 硬换行长行以防止溢出
+- 处理进度条的 `\r` 回车符
+- 保留最后 500 行
+
+### ChatPanel(聊天面板)
+
+交互式聊天界面:
+
+- 用户输入,流式代理响应
+- token 刷新节流(80ms),确保平滑显示
+- 检测到工具调用时重置流式输出
+- 支持 Rich 标记格式
+
+## 日志记录
+
+所有 TUI 日志写入 `./tui.log`(当前工作目录):
+
+- 5MB 时轮转,保留 3 个备份
+- **无控制台输出** — 避免破坏 Textual 的 alt-screen 缓冲区
+- 使用 `--verbose` 启用 DEBUG 级别日志
diff --git "a/docs/source_zh/\347\273\204\344\273\266/TUI/SkillProvider\346\212\200\350\203\275\347\263\273\347\273\237.md" "b/docs/source_zh/\347\273\204\344\273\266/TUI/SkillProvider\346\212\200\350\203\275\347\263\273\347\273\237.md"
new file mode 100644
index 000000000..11637331e
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/TUI/SkillProvider\346\212\200\350\203\275\347\263\273\347\273\237.md"
@@ -0,0 +1,71 @@
+# SkillProvider 技能系统
+
+技能系统允许 Twinkle 的 TUI 智能体从外部来源(Git 仓库、API、本地文件)动态加载专业知识,并注入到 LLM 的系统提示词中。
+
+## 架构
+
+| 类 | 角色 |
+|----|------|
+| **Skill** | 持有单个技能名称、内容和来源的数据类 |
+| **SkillProvider** | 从数据源获取技能的抽象基类 |
+| **SkillManager** | 编排多个 Provider,聚合技能用于提示词注入 |
+
+## Skill 数据类
+
+```python
+@dataclasses.dataclass
+class Skill:
+ name: str # 简短标识符(通常为文件名去除扩展名)
+ content: str # 完整的 Markdown 内容
+ source: str # Provider 名称 + 相对路径,用于可追溯性
+```
+
+## 创建自定义 Provider
+
+继承 `SkillProvider` 并实现 `name` 和 `fetch()`:
+
+```python
+from twinkle_client.skills.base import SkillProvider
+
+class MySkillProvider(SkillProvider):
+
+ @property
+ def name(self) -> str:
+ return 'my-skills'
+
+ async def fetch(self) -> None:
+ # 将技能文件下载/克隆到 self.cache_dir
+ # 例如:git clone、API 下载、文件拷贝
+ ...
+```
+
+默认的 `load_skills()` 会扫描 `self.cache_dir` 中的 `.md` 文件(跳过 README、LICENSE 等),返回 `Skill` 对象。
+
+## SkillManager
+
+```python
+from twinkle_client.skills.manager import SkillManager
+
+manager = SkillManager()
+manager.register(my_provider)
+manager.register(another_provider)
+
+# 拉取并加载所有技能
+skills = await manager.load_all()
+
+# 格式化为 LLM 系统提示词注入内容
+prompt_section = manager.format_for_prompt()
+```
+
+### 关键方法
+
+| 方法 | 说明 |
+|------|------|
+| `register(provider)` | 添加技能 Provider |
+| `load_all()` | 从所有 Provider 拉取并加载 |
+| `format_for_prompt()` | 将技能渲染为系统提示词格式 |
+| `get_skill_names()` | 列出已加载技能名称 |
+
+## 缓存目录
+
+默认缓存在 `~/.cache/twinkle/tui/skills/<provider_name>/`(其中 `<provider_name>` 为 Provider 的 `name`)。可通过向 Provider 构造函数传入 `cache_dir` 参数覆盖。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/TUI/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/TUI/index.rst"
new file mode 100644
index 000000000..32ec8dc40
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/TUI/index.rst"
@@ -0,0 +1,7 @@
+TUI
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Auto-Research.md
+ SkillProvider技能系统.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/GRPOProcessor.md" "b/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/GRPOProcessor.md"
deleted file mode 100644
index afb8f0948..000000000
--- "a/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/GRPOProcessor.md"
+++ /dev/null
@@ -1,19 +0,0 @@
-# GRPOLossProcessor
-
-GRPOLossProcessor 是专为 GRPO 强化学习训练设计的任务处理器包装器。它在 InputProcessor 基础上扩展了 GRPO 特有的数据准备功能。
-
-```python
-from twinkle.processor import GRPOLossProcessor
-
-processor = GRPOLossProcessor(
- device_mesh=...,
- padding_free=False,
- framework='transformers',
-)
-
-model.set_processor(processor)
-```
-
-GRPOLossProcessor 包装了基础 `InputProcessor`,并添加了 GRPO 特有字段的处理,如优势值、旧对数概率和参考对数概率,这些是 GRPO 损失函数所需要的。
-
-> 对于标准 SFT 任务,直接使用 `InputProcessor`。当训练循环涉及 GRPO 或其变体时,使用 `GRPOLossProcessor`。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/index.rst"
index 1eb839f0e..a2c88eaf4 100644
--- "a/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/index.rst"
+++ "b/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/index.rst"
@@ -4,4 +4,3 @@
:maxdepth: 1
InputProcessor.md
- GRPOProcessor.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/EmbeddingMetric.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/EmbeddingMetric.md"
new file mode 100644
index 000000000..ab770498c
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/EmbeddingMetric.md"
@@ -0,0 +1,31 @@
+# EmbeddingMetric
+
+`EmbeddingMetric` 跟踪对比学习(InfoNCE)训练中的嵌入质量,报告锚点-正样本余弦相似度和批内负样本相似度。
+
+## 使用方法
+
+```python
+from twinkle.metric import EmbeddingMetric
+
+metric = EmbeddingMetric(device_mesh=device_mesh, process_group=process_group)
+
+# 训练中
+metric.accumulate(inputs, outputs)
+
+# 日志间隔时
+results = metric.calculate()
+# results: {'pos_sim': '0.8523', 'neg_sim': '0.2134', 'loss': '0.3412', ...}
+```
+
+## 输出指标
+
+| 指标 | 说明 |
+|:-----|:-----|
+| `pos_sim` | 锚点与正样本的平均余弦相似度 |
+| `pos_sim_min` | 批内最小正样本相似度 |
+| `pos_sim_max` | 批内最大正样本相似度 |
+| `neg_sim` | 锚点与其他正样本(批内负样本)的平均相似度 |
+| `loss` | 平均对比损失值 |
+| `grad_norm` | 梯度范数 |
+
+> 此指标与 `InfonceLoss` 配合使用,适用于嵌入/检索模型训练。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/GRPOMetric.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/GRPOMetric.md"
new file mode 100644
index 000000000..434dc17c2
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/GRPOMetric.md"
@@ -0,0 +1,39 @@
+# GRPOMetric
+
+`GRPOMetric` 跟踪 GRPO 训练中的策略优化诊断指标,包括 KL 散度、裁剪率、熵和对数概率统计。
+
+## 使用方法
+
+```python
+from twinkle.metric import GRPOMetric
+
+metric = GRPOMetric(
+ device_mesh=device_mesh,
+ process_group=process_group,
+ epsilon=0.2, # PPO 裁剪范围
+ temperature=1.0, # 用于 logp 重缩放的采样温度
+ top_k_kl=10, # 每步记录 top-K 高 KL token
+)
+
+# 训练循环中
+metric.accumulate(inputs, outputs, old_logps=old_logps, advantages=advantages)
+
+# 日志间隔时
+results = metric.calculate()
+```
+
+## 输出指标
+
+| 指标 | 说明 |
+|:-----|:-----|
+| `train/policy_confidence` | exp(mean_new_logp) — 越高表示模型越自信 |
+| `train/mean_new_logp` | 当前策略下生成 token 的平均对数概率 |
+| `train/mean_old_logp` | 旧策略(采样时的策略)下的平均对数概率 |
+| `train/approx_kl` | Schulman K3 KL 估计器 |
+| `train/entropy` | 平均 token 级熵 |
+| `train/clip_ratio` | 被裁剪的 token 比例 |
+
+## 变体
+
+- **`GSPOMetric`** — 序列级裁剪率(几何平均比率)
+- **`CISPOMetric`** — 无条件裁剪率(不按优势符号门控)
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/index.rst"
index 6e03f97cf..d5ba804be 100644
--- "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/index.rst"
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/index.rst"
@@ -3,6 +3,19 @@
.. toctree::
:maxdepth: 1
+ TrainMetric.md
+ LossMetric.md
+ Accuracy.md
+ CompletionRewardMetric.md
+ DPOMetric.md
+ GRPOMetric.md
+ EmbeddingMetric.md
+ 构建指标.md
+指标
+===============
+.. toctree::
+ :maxdepth: 1
+
TrainMetric.md
LossMetric.md
Accuracy.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/InfoNCELoss.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/InfoNCELoss.md"
new file mode 100644
index 000000000..f8fbaa1be
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/InfoNCELoss.md"
@@ -0,0 +1,68 @@
+# InfoNCE 损失
+
+`InfonceLoss` 实现带批内负样本和可选跨 rank 聚合的对比学习损失,用于嵌入/检索模型训练。
+
+## 使用方法
+
+```python
+from twinkle.loss import InfonceLoss
+
+loss_fn = InfonceLoss(
+ temperature=0.1,
+ use_batch=True, # 启用批内负样本
+ hard_negatives=7, # 固定每样本负样本数
+ mask_fake_negative=True, # 遮蔽假负样本
+ fake_neg_margin=0.1, # 假负样本检测阈值
+)
+
+model.set_loss(loss_fn)
+```
+
+## 输入格式
+
+每个样本按 `锚点(1) + 正样本(1) + 负样本(n)` 排列。`inputs['labels']` 是一维掩码,`1` 标记每组的起始位置。
+
+```
+embeddings: [a0, p0, n0_1, n0_2, a1, p1, n1_1, n1_2, ...]
+labels: [ 1, 0, 0, 0, 1, 0, 0, 0, ...]
+```
+
+## 参数
+
+| 参数 | 类型 | 默认值 | 说明 |
+|:-----|:-----|:-------|:-----|
+| `temperature` | float | 0.1 | 相似度缩放因子 |
+| `use_batch` | bool | True | 使用跨样本批内负样本 |
+| `hard_negatives` | int | None | 固定每样本负样本数(截断/上采样)|
+| `mask_fake_negative` | bool | False | 遮蔽高于 positive + margin 的 logit |
+| `fake_neg_margin` | float | 0.1 | 假负样本遮蔽阈值 |
+| `include_qq` | bool | False | 添加 query-query 相似度块 |
+| `include_dd` | bool | False | 添加 doc-doc 相似度块 |
+
+## 跨 Rank 聚合
+
+当 `use_batch=True` 且分布式训练激活时,嵌入会从所有 DP rank 聚合以最大化批内负样本多样性。仅本地分片保留梯度。
+
+## 相似度块
+
+该损失支持三种相似度块,提供全面的对比学习信号:
+
+- **Q→D(默认)**:Query 到所有 Document — 主要对比信号
+- **Q→Q**(`include_qq=True`):Query 到其他所有 Query — 防止 query 坍缩
+- **D→D**(`include_dd=True`):Document 到其他所有 Document — Qwen3-Embedding 风格
+
+## 示例:Embedding 训练
+
+```python
+from twinkle.loss import InfonceLoss
+from twinkle.metric import EmbeddingMetric
+
+# 配置 Embedding 模型
+model.set_loss(InfonceLoss(temperature=0.05, use_batch=True, include_qq=True))
+model.set_metric(EmbeddingMetric(device_mesh=mesh, process_group=pg))
+
+# 训练循环
+for batch in dataloader:
+ model.forward_backward(batch)
+ model.clip_grad_and_step()
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/index.rst"
index ea813f56f..0a2a890cf 100644
--- "a/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/index.rst"
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/index.rst"
@@ -3,6 +3,19 @@
.. toctree::
:maxdepth: 1
+ CrossEntropy.md
+ ChunkedCrossEntropy.md
+ DPOLoss.md
+ GKDLoss.md
+ GRPOLoss.md
+ InfoNCELoss.md
+ MSELoss.md
+ 构建损失.md
+损失
+===============
+.. toctree::
+ :maxdepth: 1
+
CrossEntropy.md
ChunkedCrossEntropy.md
DPOLoss.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraTransformersModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraTransformersModel.md"
index 4017aea7e..5ae13f739 100644
--- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraTransformersModel.md"
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraTransformersModel.md"
@@ -30,3 +30,48 @@ class MultiLoraTransformersModel:
正因如此,用户的r必须要小于等于max_r的配置,在实际训练时仅会使用lora的部分rank参与计算。
MultiLoraTransformersModel支持`@remote_class`注解,并且支持device_mesh,这意味着它可以运行在ray的worker中。
+
+## 租户生命周期
+
+底层使用 `MultiLora` 管理器来处理租户 LoRA 槽位。关键 API:
+
+### acquire_lora
+
+为租户获取一个可用的 LoRA 槽位:
+
+```python
+adapter_name = model.multi_lora.acquire_lora('tenant_a', LoraConfig(r=16, lora_alpha=32))
+```
+
+- 如果所有槽位已被占用或 `config.r > max_r`,则抛出 `RuntimeError`
+
+### release_lora
+
+释放租户的 LoRA 槽位,权重重置为初始状态:
+
+```python
+model.multi_lora.release_lora('tenant_a')
+```
+
+### 上下文管理器
+
+使用 `adapter()` 进行作用域激活:
+
+```python
+with model.multi_lora.adapter('tenant_a') as name:
+ output = model.forward(inputs)
+```
+
+### LoraTenant
+
+每个槽位以 `LoraTenant` 数据类追踪:
+
+```python
+@dataclass
+class LoraTenant:
+ index: int # 槽位索引 (0..max_loras-1)
+ adapter_name: str # 内部名称(如 "lora_0")
+ config: LoraConfig # 预分配配置(max_r)
+ tenant_adapter_name: str # 面向用户的租户名(空闲时为 None)
+ tenant_config: LoraConfig # 租户实际配置(空闲时为 None)
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/SupportedModels.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/SupportedModels.md"
new file mode 100644
index 000000000..bfbb03ea0
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/SupportedModels.md"
@@ -0,0 +1,77 @@
+# 支持的模型
+
+Twinkle 支持任何兼容 HuggingFace Transformers 或 Megatron-LM 的模型。以下是经过测试的模型列表。
+
+## 语言模型
+
+| 模型系列 | 模型 ID | 参数量 | 特性 |
+|:---------|:--------|:-------|:-----|
+| Qwen 3.5 | `Qwen/Qwen3.5-0.6B` ~ `Qwen/Qwen3.5-235B-A22B` | 0.6B–235B | MoE、思考模式 |
+| Qwen 2.5 | `Qwen/Qwen2.5-0.5B` ~ `Qwen/Qwen2.5-72B` | 0.5B–72B | Dense |
+| DeepSeek V4 | `deepseek-ai/DeepSeek-V4` | 685B MoE | 自定义 DSML 编码 |
+| DeepSeek R1 | `deepseek-ai/DeepSeek-R1` | 685B MoE | 推理 |
+| LLaMA 3 | `meta-llama/Llama-3.3-70B-Instruct` | 8B–70B | Dense |
+| Mistral | `mistralai/Mistral-7B-v0.3` | 7B | Dense |
+| Yi | `01-ai/Yi-1.5-34B` | 6B–34B | Dense |
+| GLM-4 | `THUDM/glm-4-9b-chat` | 9B | Dense |
+| InternLM 2.5 | `internlm/internlm2_5-7b-chat` | 7B–20B | Dense |
+
+## 视觉语言模型
+
+| 模型系列 | 模型 ID | 特性 |
+|:---------|:--------|:-----|
+| Qwen 3.5 VL | `Qwen/Qwen3.5-VL-3B` ~ `Qwen/Qwen3.5-VL-72B` | 图片、视频 |
+| Qwen 2.5 VL | `Qwen/Qwen2.5-VL-7B-Instruct` | 图片、视频 |
+| InternVL 2.5 | `OpenGVLab/InternVL2_5-8B` | 图片 |
+
+## 嵌入模型
+
+| 模型系列 | 模型 ID | 训练方法 |
+|:---------|:--------|:---------|
+| Qwen3 Embedding | `Qwen/Qwen3-Embedding-0.6B` | InfoNCE 对比学习 |
+| GTE | `thenlper/gte-large-zh` | InfoNCE 对比学习 |
+
+## 模型加载
+
+```python
+from twinkle.model import TransformersModel
+
+# 从 ModelScope 加载(ms:// 前缀)
+model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B')
+
+# 从 HuggingFace 加载(hf:// 前缀)
+model = TransformersModel(model_id='hf://meta-llama/Llama-3.3-70B-Instruct')
+
+# 本地路径
+model = TransformersModel(model_id='/path/to/model')
+```
+
+## 框架支持
+
+| 框架 | 类名 | 适用场景 |
+|:-----|:-----|:---------|
+| Transformers | `TransformersModel` | 通用训练(SFT、RLHF、DPO)|
+| Transformers + Multi-LoRA | `MultiLoraTransformersModel` | 多租户训练 |
+| Megatron-LM | `MegatronModel` | 大规模分布式预训练 |
+| Megatron + Multi-LoRA | `MultiLoraMegatronModel` | 大规模多租户 |
+
+## 精度支持
+
+| 模式 | 说明 |
+|:-----|:-----|
+| `bf16` | BFloat16 混合精度(推荐 A100/H100)|
+| `fp16` | Float16 混合精度(适用于旧 GPU)|
+| `fp8` | FP8 精度(H100 + Transformer Engine)|
+| `no` | 全精度(仅用于调试)|
+
+## 并行策略
+
+| 策略 | 配置键 | 说明 |
+|:-----|:-------|:-----|
+| FSDP | `strategy=accelerate` | Accelerate 管理的 FSDP(默认)|
+| 原生 FSDP | `strategy=native_fsdp` | PyTorch 原生 FSDP |
+| 张量并行 | `tp_size` | 跨 GPU 切分层 |
+| 流水线并行 | `pp_size` | 切分模型阶段 |
+| 数据并行 | `dp_size` | 复制模型,切分数据 |
+| 序列并行 | `sequence_parallel` | 切分长序列 |
+| 专家并行 | `ep_size` | MoE 专家分布 |
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/index.rst"
index 713ea35c6..d20155bd7 100644
--- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/index.rst"
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/index.rst"
@@ -8,3 +8,14 @@
MultiLoraTransformersModel.md
MegatronModel.md
MultiLoraMegatronModel.md
+ SupportedModels.md
+模型
+===============
+.. toctree::
+ :maxdepth: 1
+
+ TwinkleModel.md
+ TransformersModel.md
+ MultiLoraTransformersModel.md
+ MegatronModel.md
+ MultiLoraMegatronModel.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/DeepSeekV4Template.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/DeepSeekV4Template.md"
new file mode 100644
index 000000000..053b51051
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/DeepSeekV4Template.md"
@@ -0,0 +1,30 @@
+# DeepSeek-V4 模板
+
+`DeepseekV4Template` 为 DeepSeek V4 提供原生支持,包括其独特的思考模式、工具调用协议和多 token 特殊标记。
+
+## 使用方法
+
+```python
+from twinkle.template import DeepseekV4Template
+
+template = DeepseekV4Template(
+ model_id='deepseek-ai/DeepSeek-V4',
+ enable_thinking=True,
+)
+```
+
+## 特性
+
+- **自定义 tokenizer 包装**:用 DeepSeek V4 的编码协议覆盖 `apply_chat_template`
+- **思考模式**:支持 `thinking` / `chat` 模式切换
+- **工具调用**:原生 DSML 工具调用编码
+- **多 token EOS**:处理 DeepSeek V4 的多字符特殊标记
+
+## 与基础模板的区别
+
+| 特性 | 基础模板 | DeepseekV4Template |
+|:-----|:---------|:-------------------|
+| Chat 模板 | HuggingFace 原生 | 自定义 DSML 编码 |
+| 思考模式 | `<think>` 标签 | 原生思考模式开关 |
+| 工具调用 | Hermes/Qwen 格式 | DSML 工具块 |
+| EOS 处理 | 单 token | 多 token 特殊标记 |
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md"
index 364275b64..c3f5918e1 100644
--- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md"
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md"
@@ -9,7 +9,7 @@ class Template:
model_id: str,
use_chat_template: bool = True,
max_length: Optional[int] = 8192,
- truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise',
+ truncation_strategy: Literal['raise', 'left', 'right', 'split', 'delete'] = 'raise',
default_system: Optional[str] = None):
...
@@ -42,7 +42,9 @@ class Template:
- raise: 抛出异常。一般用于非常精确的数据集场景
- left: 移除左边的 token,使其符合 max_length
- right: 移除右边的 token,使其符合 max_length
- - default_system: 如果数据集没有 system,则使用默认 system
+ - split: 将超长样本切分为多个 max_length 的片段(不支持多模态、LazyDataset、IterablePackingDataset)
+ - delete: 直接丢弃超长样本
+- default_system: 如果数据集没有 system,则使用默认 system
> Template 不支持使用函数来代替,因为其内部要支持的功能较多。如果需要编写新的 Template,请继承 `Template` 类。
> 一般来说,纯文本模型使用 Template 基类就足够了,在基类中我们使用了 tokenizer.apply_chat_template 来编码模型,对一般的纯文本模型是通用的。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/ToolCallParsers.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/ToolCallParsers.md"
new file mode 100644
index 000000000..9d52be1c7
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/ToolCallParsers.md"
@@ -0,0 +1,53 @@
+# 工具调用解析器
+
+Twinkle 的模板系统包含模块化的工具调用解析框架,用于训练具有函数调用能力的模型。
+
+## 架构
+
+```
+ToolCallRegistry
+├── HermesQwenParser — Hermes/Qwen 风格 <tool_call>...</tool_call>
+├── ReActParser — ReAct Thought/Action/Observation
+├── ClineParser — Cline XML 工具调用
+└── VCPParser — VCP 协议
+```
+
+## ToolCallParser 接口
+
+```python
+from twinkle.template.tools import ToolCallParser
+
+class ToolCallParser(ABC):
+ name: str = ''
+
+ def detect(self, text: str) -> bool:
+ """检查文本是否包含此格式的标记"""
+
+ def parse(self, text: str) -> List[Dict[str, Any]]:
+ """提取 OpenAI 格式的工具调用"""
+
+ def clean(self, text: str) -> str:
+ """去除标记,返回纯内容"""
+```
+
+## ToolCallRegistry
+
+注册表自动发现解析器并路由检测:
+
+```python
+from twinkle.template.tools import ToolCallRegistry
+
+# 检测补全使用了哪种格式
+parser = ToolCallRegistry.detect_first(completion_text)
+if parser:
+ tool_calls = parser.parse(completion_text)
+```
+
+## 内置解析器
+
+| 解析器 | 格式说明 |
+|:-------|:---------|
+| HermesQwenParser | `<tool_call>{"name": "...", "arguments": {...}}</tool_call>` |
+| ReActParser | Thought/Action/Action Input/Observation |
+| ClineParser | Cline XML 结构化参数 |
+| VCPParser | Visual Code Protocol |
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/index.rst"
index 9ab4c887b..840adf497 100644
--- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/index.rst"
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/index.rst"
@@ -4,3 +4,11 @@
:maxdepth: 1
Template.md
+ DeepSeekV4Template.md
+ ToolCallParsers.md
+模板
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Template.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md"
index 00ec1f308..5842d51dd 100644
--- "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md"
+++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md"
@@ -40,6 +40,96 @@ class DeviceMesh:
推荐使用 `from_sizes` 来构造它。
+### 参数参考
+
+| 参数 | 说明 | 默认值 |
+|------|------|--------|
+| `world_size` | 总进程数 | 1 |
+| `dp_size` | 数据并行度 | 1 |
+| `fsdp_size` | 全分片数据并行度 | None |
+| `tp_size` | 张量并行度 | None |
+| `pp_size` | 流水线并行度 | None |
+| `ulysses_size` | Ulysses 序列并行度 | None |
+| `cp_size` | 上下文并行度 | None |
+| `ep_size` | 专家并行度(MoE 模型)| None |
+| `etp_size` | 专家张量并行度 | None |
+| `ep_fsdp_size` | 每个 EP 组内的 FSDP 度 | None |
+| `vpp_size` | 虚拟流水线并行度 | None |
+| `device_type` | 设备类型(`cuda`、`npu` 等)| `cuda` |
+| `sequence_parallel` | 启用 Megatron 风格序列并行 | False |
+
+我们举一个例子:
+
+```python
+sampler_device_mesh = DeviceMesh.from_sizes(dp_size=4)
+actor_device_mesh = DeviceMesh.from_sizes(dp_size=2, pp_size=2, tp_size=2)
+
+dataloader = DataLoader(...)
+sampler = vLLMSampler(..., device_mesh=sampler_device_mesh, remote_group=...)
+actor = MegatronModel(..., device_mesh=actor_device_mesh, remote_group=...)
+
+for data in dataloader:
+ sampler_output = sampler.sample(data)
+ input_data = [seq.new_input_feature for response in sampler_output for seq in response.sequences]
+ ...
+ model_output = actor.forward(input_data)
+```
+
+我们以上面的伪代码来分析数据传递情况。
+
+dataloader 取出数据 -> 按照 dp_size=4 分发给 sampler -> 按照 dp_size=4 收集数据 -> 按照 dp_size=2 分发给模型 -> 按照 dp_size=2 收集输出
+
+通过 DeviceMesh,可以将数据流平顺地在各个 group 和组件之间流转起来。
+
+数据的分发判断由 DeviceMesh 的 `get_slice` 方法执行:
+
+```python
+batch[device_mesh.get_slice(len(batch))]
+```
+
+get_slice 会根据当前 rank,计算出当前 worker 属于哪个 dp 组,并获取对应的数据。该过程发生在 DataLoader 的 DeviceMeshSampler 中,同样发生在 remote_class 的 dispatch 和 collect 中。
+# DeviceMesh/DeviceGroup
+
+这两个类用于表达硬件资源分配和网络拓扑,Twinkle 的数据分发和收集也依赖它们。
+
+## DeviceGroup
+
+```python
+@dataclass
+class DeviceGroup:
+ name: str
+ ranks: Union[List[int], int]
+ device_type: str
+ visible_devices: Optional[str] = None # Optional: explicitly set visible devices (e.g., "8,9")
+ gpus_per_worker: int = 1
+```
+
+- name: 资源组名
+- ranks: 占用硬件列表,如果是CPU资源仅支持int类型
+- device_type: 硬件类型,例如 GPU/CPU/NPU 等
+- visible_devices: 可见资源列表,用于希望仅使用部分 rank 的硬件的情况
+- gpus_per_worker: 每个 worker 占用多少硬件
+
+如果训练 RL,开发者可以构造多个这样的组,并将对应的模型、采样器分配进入其中。
+
+## DeviceMesh
+
+DeviceMesh 承载了组件拓扑、分布式并行信息,这个类会在组件内传递,用于数据分发和数据收集。
+
+```python
+@dataclass
+class DeviceMesh:
+ ...
+
+ @staticmethod
+ def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, tp_size: int = None,
+ pp_size: int = None, ulysses_size: int = None, cp_size: int = None, ep_size: int = None,
+ etp_size: int = None,vpp_size: int = None, device_type: str = 'cuda', sequence_parallel: bool = False) -> "DeviceMesh":
+ ...
+```
+
+推荐使用 `from_sizes` 来构造它。
+
我们举一个例子:
```python
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/Padding-Free\350\256\255\347\273\203.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/Padding-Free\350\256\255\347\273\203.md"
new file mode 100644
index 000000000..8bd783cff
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/Padding-Free\350\256\255\347\273\203.md"
@@ -0,0 +1,52 @@
+# Padding-Free 训练
+
+Padding-free(也称为"打包")训练通过将多个序列拼接到一个打包批次中,消除了对 padding token 的无效计算。Twinkle 支持标准注意力和 Qwen3.5 GatedDeltaNet 线性注意力的 padding-free 训练。
+
+## 工作原理
+
+不同于将所有序列填充到 `max_length`,padding-free 将多个序列打包到一行中,并使用 `position_ids` 跟踪序列边界,从而避免在 padding token 上浪费算力。
+
+```
+标准方式: [tok tok tok PAD PAD PAD] [tok tok PAD PAD PAD PAD]
+打包方式: [tok tok tok tok tok ...] ← 无 padding 浪费
+```
+
+## 使用方式
+
+通过 `PackingDataset` 或 `IterablePackingDataset` 启用:
+
+```python
+from twinkle.dataset import PackingDataset
+
+dataset = PackingDataset(
+ dataset=base_dataset,
+ max_length=8192,
+)
+```
+
+数据集会自动打包序列并生成正确的 `position_ids`,在序列边界处重置。
+
+## GatedDeltaNet 补丁(Qwen3.5)
+
+Qwen3.5 使用混合架构,融合了标准注意力和 GatedDeltaNet 线性注意力。原生 GatedDeltaNet 实现不会在打包序列边界处重置线性注意力状态。
+
+`GatedDeltaNetPaddingFreePatch` 通过以下方式修复:
+
+1. Patch `Qwen3_5DecoderLayer.forward`,将 `cu_seq_lens_q`(累积序列长度)传递给线性注意力层
+2. Patch `Qwen3_5GatedDeltaNet.forward`,使用支持 `cu_seqlens` 的 flash-linear-attention 内核(`causal_conv1d`、`chunk_gated_delta_rule`)
+
+在 Qwen3.5 模型上检测到 padding-free 时,补丁会自动应用。
+
+### 要求
+
+- 需安装 `flash-linear-attention` 包
+- 仅适用于含 GatedDeltaNet 层的 Qwen3.5 模型
+- 启用序列并行时,会使用 `Qwen3_5GatedDeltaNetUlyssesPatch` 替代
+
+## 注意力后端要求
+
+| 注意力后端 | Padding-Free 支持 |
+|-----------|-------------------|
+| FlashAttention2 | 完全支持 |
+| SDPA | 支持(不兼容序列并行) |
+| Eager | 不支持 |
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/TwinkleClient\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/TwinkleClient\345\256\242\346\210\267\347\253\257.md"
new file mode 100644
index 000000000..327e2e40c
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/TwinkleClient\345\256\242\346\210\267\347\253\257.md"
@@ -0,0 +1,81 @@
+# TwinkleClient 客户端
+
+`TwinkleClient` 是与 Twinkle REST API 交互的 Python 客户端,管理会话、训练任务和检查点。
+
+## 初始化
+
+```python
+from twinkle_client.manager import TwinkleClient
+
+client = TwinkleClient(
+ base_url='http://localhost:8000', # 或 TWINKLE_SERVER_URL 环境变量
+ api_key='your-api-key', # 或 TWINKLE_SERVER_TOKEN 环境变量
+ route_prefix='/twinkle', # API 路由前缀
+ session_heartbeat_interval=10, # 心跳间隔(秒)
+ session_metadata={'user': 'alice'}, # 可选的会话元数据
+)
+```
+
+初始化时客户端会:
+1. 将 `base_url` 和 `api_key` 设置到共享上下文(所有客户端对象自动使用)
+2. 创建服务端会话
+3. 启动后台心跳线程保持会话活跃
+
+## 健康检查
+
+```python
+is_healthy = client.health_check() # 返回 True/False
+capabilities = client.get_server_capabilities() # 支持的模型
+```
+
+## 训练任务
+
+```python
+# 列出训练任务
+runs = client.list_training_runs(limit=20, offset=0)
+
+# 带分页游标列出
+runs, cursor = client.list_training_runs_with_cursor(limit=20)
+
+# 获取特定任务
+run = client.get_training_run(run_id='run_abc123')
+
+# 按基础模型查找
+qwen_runs = client.find_training_run_by_model('Qwen/Qwen3.5-4B')
+```
+
+## 检查点
+
+```python
+# 列出训练任务的检查点
+checkpoints = client.list_checkpoints(run_id='run_abc123')
+
+# 获取检查点路径
+parsed = client.get_checkpoint_path(run_id, checkpoint_id)
+# parsed.path → 文件系统路径
+# parsed.twinkle_path → twinkle:// URI
+
+# 获取最新检查点(用于恢复训练)
+latest_path = client.get_latest_checkpoint_path(run_id)
+
+# 删除检查点
+client.delete_checkpoint(run_id, checkpoint_id)
+```
+
+## 容量与权重信息
+
+```python
+# LoRA 容量
+capacity = client.get_capacity_info()
+# capacity.max_loras, capacity.used_loras, capacity.free_loras
+
+# 权重元数据
+info = client.get_weights_info('twinkle://run_id/weights/checkpoint')
+# info.base_model, info.is_lora, info.lora_rank
+```
+
+## 清理
+
+```python
+client.close() # 停止心跳线程(也通过 atexit 自动注册)
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/index.rst"
index 7174ce690..377098988 100644
--- "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/index.rst"
+++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/index.rst"
@@ -4,4 +4,8 @@
:maxdepth: 1
DeviceMesh和DeviceGroup.md
+ 专家并行.md
+ 序列并行.md
+ Padding-Free训练.md
RemoteClass.md
+ TwinkleClient客户端.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/\344\270\223\345\256\266\345\271\266\350\241\214.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/\344\270\223\345\256\266\345\271\266\350\241\214.md"
new file mode 100644
index 000000000..a0112249b
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/\344\270\223\345\256\266\345\271\266\350\241\214.md"
@@ -0,0 +1,73 @@
+# 专家并行 (EP)
+
+专家并行将混合专家模型(MoE)的专家分布到多个 GPU 上,每个 rank 只持有部分专家。这降低了单卡显存占用,使大规模 MoE 模型的训练成为可能。
+
+## 概览
+
+| 概念 | 说明 |
+|------|------|
+| **ExpertParallelConfig** | 控制 EP 行为的配置数据类 |
+| **apply_expert_parallel()** | 入口函数,负责分片专家并替换前向传播 |
+| **shard_experts()** | 将专家均匀分配到各 EP rank |
+| **patch_forward()** | 将 MoE block 的 forward 替换为带 all-to-all 通信的 EP 版本 |
+
+## 配置
+
+```python
+from twinkle.model.transformers.moe.expert_parallel import ExpertParallelConfig
+
+config = ExpertParallelConfig(
+ enabled=True, # 启用专家并行
+ router_dtype='fp32', # 路由计算精度:'fp32', 'bf16', 'fp16'
+ keep_router_logits=True, # 在输出中保留路由 logits
+ ignore_shared_experts=False,# 跳过共享专家计算(如 DeepSeek)
+ ep_size=None, # EP 并行度(由 TransformersModel 使用)
+)
+```
+
+## 配合 DeviceMesh 使用
+
+在 `DeviceMesh.from_sizes()` 中设置 `ep_size` 即可激活 EP。框架会在模型初始化时自动调用 `apply_expert_parallel()`。
+
+```python
+from twinkle.utils import DeviceMesh
+
+# 8 卡:2 路 EP × 4 路数据并行
+device_mesh = DeviceMesh.from_sizes(
+ world_size=8,
+ dp_size=4,
+ ep_size=2,
+)
+```
+
+EP + FSDP 组合分片:
+
+```python
+# 8 卡:2 路 EP,每个 EP 组内 2 路 FSDP
+device_mesh = DeviceMesh.from_sizes(
+ world_size=8,
+ dp_size=2,
+ ep_size=2,
+ ep_fsdp_size=2,
+)
+```
+
+## 通信模式
+
+EP 前向传播遵循 4 阶段流水线:
+
+1. **预处理** — 计算每个专家的 token 数量和分割大小
+2. **Token Pre-All2All** — 按专家分配排列 token,然后在 EP rank 间执行 all-to-all 交换
+3. **专家计算** — 每个 rank 在接收到的 token 上运行本地专家
+4. **Token Post-All2All** — all-to-all 交换结果,反排列并应用路由权重
+
+```
+输入 token → 路由器 → [预处理] → [pre_all2all] → [本地专家] → [post_all2all] → 输出
+```
+
+## 要求
+
+- `num_experts` 必须能被 `ep_size` 整除
+- `torch.distributed` 必须已初始化
+- MoE block 必须定义 `gate`/`router` 模块和 `experts`(支持 `nn.ModuleList` 或张量形式的 `gate_up_proj`/`down_proj`)
+- 共享专家(如 DeepSeek MoE)会自动处理,除非设置 `ignore_shared_experts=True`
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/\345\272\217\345\210\227\345\271\266\350\241\214.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/\345\272\217\345\210\227\345\271\266\350\241\214.md"
new file mode 100644
index 000000000..e1d997188
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/\345\272\217\345\210\227\345\271\266\350\241\214.md"
@@ -0,0 +1,68 @@
+# 序列并行 (SP)
+
+序列并行沿序列维度将长序列分割到多个 GPU 上,使训练能处理超出单卡显存的序列长度。Twinkle 实现了 Ulysses 风格的序列并行,并可选地支持派生环形注意力。
+
+## 概览
+
+| 概念 | 说明 |
+|------|------|
+| **SequenceParallelConfig** | SP 配置数据类 |
+| **SequenceParallelStrategy** | 封装 SP 生命周期的策略类 |
+| **SequenceParallel** | 核心实现,处理填充/分割/聚合 |
+
+## 配置
+
+```python
+from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallelConfig
+
+config = SequenceParallelConfig(
+ enabled=True, # 启用序列并行
+ ulysses_size=None, # Ulysses SP 并行度(若为 None 则从 DeviceMesh 自动推导)
+ gather_logits=True, # 前向后聚合 logits 用于损失计算
+)
+```
+
+## 配合 DeviceMesh 使用
+
+在 `DeviceMesh.from_sizes()` 中设置 `ulysses_size` 即可激活 SP:
+
+```python
+from twinkle.utils import DeviceMesh
+
+# 8 卡:4 路 Ulysses SP × 2 路数据并行
+device_mesh = DeviceMesh.from_sizes(
+ world_size=8,
+ dp_size=2,
+ ulysses_size=4,
+)
+```
+
+## 工作原理
+
+1. **填充** — 输入序列被填充到可被 SP 并行度整除的长度
+2. **分割** — 填充后的输入沿序列维度均匀分配到各 SP rank
+3. **分布式注意力** — FlashAttention2 被 patch 为在注意力计算前后执行 Ulysses all-to-all 通信
+4. **聚合** — 前向传播后,logits 被聚合回完整序列长度用于损失计算
+
+## 支持的注意力后端
+
+| 后端 | 状态 |
+|------|------|
+| FlashAttention2 | 完全支持(包括打包/padding-free 序列)|
+| SDPA | 支持(仅非打包批次)|
+| 派生环形注意力 | 仅支持 FlashAttention2(`rp_world_size > 1`)|
+
+## Qwen3.5 线性注意力
+
+SP 自动检测 Qwen3.5 GatedDeltaNet 线性注意力层,并应用 `Qwen3_5GatedDeltaNetUlyssesPatch`,确保混合注意力架构下序列并行的正确性。
+
+## MoE 辅助损失
+
+对于 MoE 模型,SP 自动安装前向 hook,在计算辅助损失前跨 SP rank 聚合路由 logits,确保负载均衡信号的正确性。
+
+## 关键约束
+
+- `num_key_value_heads` 必须能被 `ulysses_size` 整除(Ulysses 模式),否则回退到环形注意力
+- 打包/padding-free 批次需要 FlashAttention2
+- 派生环形注意力要求 `batch_size == 1`(打包格式)
+- `torch.distributed` 必须已初始化
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\200\232\347\237\245\345\231\250/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\351\200\232\347\237\245\345\231\250/index.rst"
new file mode 100644
index 000000000..8b3692d51
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\351\200\232\347\237\245\345\231\250/index.rst"
@@ -0,0 +1,6 @@
+通知器
+===============
+.. toctree::
+ :maxdepth: 1
+
+ 通知器.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\200\232\347\237\245\345\231\250/\351\200\232\347\237\245\345\231\250.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\200\232\347\237\245\345\231\250/\351\200\232\347\237\245\345\231\250.md"
new file mode 100644
index 000000000..823e0d459
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\351\200\232\347\237\245\345\231\250/\351\200\232\347\237\245\345\231\250.md"
@@ -0,0 +1,93 @@
+# 通知器
+
+通知器组件提供可插拔的通知系统,用于在训练过程中发送告警。当异常发生或训练事件需要关注时,通知器将消息投递到外部渠道(如钉钉 Webhook)。
+
+## 基础接口
+
+```python
+from twinkle.notifier import Notifier
+
+class Notifier:
+ def __call__(self, message: str):
+ """发送通知消息"""
+ ...
+
+ def to_dict(self) -> dict:
+ """序列化(用于 checkpoint 保存/恢复)"""
+ ...
+
+ @classmethod
+ def from_dict(cls, data: dict) -> Notifier:
+ """从序列化数据恢复"""
+ ...
+```
+
+## DingNotifier(钉钉通知)
+
+向钉钉自定义机器人 Webhook 发送通知。
+
+```python
+from twinkle.notifier import DingNotifier
+
+notifier = DingNotifier(
+ ding_url='https://oapi.dingtalk.com/robot/send?access_token=xxx',
+ secret='SECxxxxxxx', # 可选:签名模式
+ timeout=5.0,
+)
+
+# 发送消息
+notifier("### 训练完成\n\n- Steps: 1000\n- Loss: 0.25")
+```
+
+**参数:**
+- `ding_url`:完整的钉钉 Webhook URL(含 access_token)
+- `secret`:可选签名密钥(签名模式机器人)
+- `timeout`:HTTP 请求超时时间,单位秒(默认 5.0)
+
+消息以钉钉 **Markdown** 格式发送。第一个标题行自动提取为聊天预览标题。
+
+## 异常通知
+
+Twinkle 提供带去重的自动异常通知:
+
+```python
+from twinkle.notifier.base import notify_exception
+
+# 自动发送格式化的异常信息
+# 每个唯一异常只有一个 rank 发送(防止消息洪泛)
+try:
+ model.forward_backward(batch)
+except Exception as e:
+ notify_exception(notifier, context='forward_backward', exc=e, name='sft_train')
+```
+
+通知包含:
+- 异常类型和消息
+- 完整堆栈跟踪
+- 运行时元数据(rank、PID、主机名)
+- 去重:所有 rank 中每个唯一异常只发一条通知
+
+## 自定义通知器
+
+继承 `Notifier` 创建自定义通知器:
+
+```python
+from twinkle.notifier import Notifier
+
+class SlackNotifier(Notifier):
+ """Example notifier that posts each message to a Slack incoming webhook."""
+
+ def __init__(self, webhook_url: str):
+ self.webhook_url = webhook_url
+
+ def __call__(self, message: str):
+ import requests
+ # requests has no default timeout; bound the call so an unreachable
+ # webhook cannot block the training process indefinitely.
+ requests.post(self.webhook_url, json={'text': message}, timeout=5.0)
+
+ def to_dict(self):
+ return {'class': 'SlackNotifier', 'webhook_url': self.webhook_url}
+
+ @classmethod
+ def _from_dict_impl(cls, data):
+ return cls(webhook_url=data['webhook_url'])
+```
+
+> 通知器通过 `__init_subclass__` 自动注册,因此 `Notifier.from_dict()` 可以按类名恢复任何子类。
diff --git a/new_feature.txt b/new_feature.txt
new file mode 100644
index 000000000..39ce0d763
--- /dev/null
+++ b/new_feature.txt
@@ -0,0 +1,204 @@
+# Native ML LLM Control
+
+基于 Twinkle Server Mode 架构,实现 LLM Agent 驱动的 TUI 训练控制系统。面向零基础开发者,通过自然语言对话完成模型训练的全生命周期管理。
+
+
+## 一、整体架构
+
+采用"无状态客户端 + 有状态服务端"的 Server Mode 架构:
+
+- **Server 端(Ray 集群)**:模型权重、LoRA adapter、optimizer 状态、LR scheduler 全部驻留在 GPU 内存
+- **Client 端(TUI + 训练脚本)**:完全无状态,仅负责数据加载、训练循环逻辑、指标上报
+- **核心特性**:杀死 client = 暂停(server 保留全部状态),重启 client = 恢复(零成本继续训练)
+
+支持两种运行模式:
+1. **本地自建**:启动本地 Ray 集群 + Twinkle Server,需评估 GPU 资源和 DeviceMesh
+2. **线上云服务**:连接 ModelScope 托管服务(`http://www.modelscope.cn/twinkle`),无需本地 GPU
+
+
+## 二、TUI 界面(基于 Textual 框架)
+
+TUI 采用 Grid 布局,包含四个核心面板:
+
+1. **状态栏(StatusBar)**:顶部横跨,显示当前训练 run_id、状态、步数进度
+2. **指标面板(MetricsPanel)**:左上区域,绘制 loss / reward / grad_norm 等指标曲线,支持自然语言控制放大缩小和还原
+3. **对话面板(ChatPanel)**:左下区域,用户与 LLM agent 的对话界面,支持 UTF-8 中文输入输出
+4. **日志面板(LogPanel)**:右侧纵向,滚动显示训练运行日志
+
+快捷键:`q` 退出、`Ctrl+P` 切换指标面板、`Ctrl+L` 清空日志。
+
+关闭 TUI 后训练不中断(Server 端状态不受影响),重新打开 TUI 可继续监控。
+
+
+## 三、LLM Agent 系统
+
+### 3.1 对话式 Agent(AgentLoop)
+
+用户通过 Chat 面板与 Agent 对话,Agent 通过 tool_call 执行训练管理操作:
+
+| 工具 | 功能 |
+|------|------|
+| `list_training_runs` | 列出所有活跃和历史训练任务 |
+| `get_training_status` | 获取指定 run 的状态和近期指标 |
+| `pause_training` | 暂停训练(SIGKILL client,server 保留状态) |
+| `resume_training` | 恢复训练(重启 client 脚本) |
+| `stop_training` | 优雅停止(SIGTERM,脚本自动保存 checkpoint + dataloader 位置后退出) |
+| `list_supported_models` | 查询 Server 支持的模型列表(本地/云端) |
+| `search_datasets` | 在 ModelScope 搜索数据集 |
+| `search_models` | 在 ModelScope 搜索模型 |
+| `zoom_metrics` | 自然语言控制指标图表缩放 |
+| `update_script` | 更新训练脚本(旧 `train.py` 归档为 `train_v{N}.py`,新代码写入 `train.py`) |
+
+### 3.2 自动监控(TrainingMonitor)
+
+后台 LLM 定期(默认 30 秒)读取 metrics 和 logs,进行趋势分析:
+
+- 检测异常:loss 突增/NaN、reward 停滞、gradient 爆炸/消失、KL 散度过大、entropy 坍塌
+- 主动建议:调整学习率、更换 reward 组合、增加 num_generations
+- 通过 Chat 面板推送诊断报告(`[Monitor] ...`)
+- 无硬编码规则,所有分析由 LLM 推理完成
+
+### 3.3 Skills 可扩展框架
+
+通过 `SkillManager` + `ModelScopeSkillProvider` 加载技能文档,为 Agent 提供领域知识:
+
+- `skills/twinkle-training.md`:训练脚本编写指导(1260+ 行)
+- `skills/autoresearch.md`:自动化研究实验设计(256 行)
+- 支持从 ModelScope 远程加载社区共享 Skills
+
+
+## 四、训练控制机制
+
+### 4.1 训练进程管理
+
+| 操作 | 信号 | 行为 | 恢复方式 |
+|------|------|------|----------|
+| 暂停 | SIGKILL | 立即杀死 client | 重启同一脚本(adapter_name 相同即可继续) |
+| 停止 | SIGTERM | 脚本保存 checkpoint + dataloader 位置后退出 | `model.resume_from_checkpoint()` + `dataloader.resume_from_checkpoint()` |
+| 修改超参 | SIGKILL → 编辑 → 重启 | 新配置生效,optimizer 状态保留 | 使用相同 adapter_name |
+| 重置训练 | 使用新 adapter_name | 全新开始 | 旧 adapter 按 `adapter_timeout` 自动清理 |
+
+### 4.2 优雅退出(SIGTERM)
+
+每个训练脚本必须注册 graceful shutdown handler:
+
+```python
+rt = TrainingRuntime(run_id='my-exp')
+rt.register_graceful_shutdown(model, dataloader)
+```
+
+收到 SIGTERM 后自动执行:
+1. 保存模型 checkpoint(含 optimizer 状态)
+2. 记录 dataloader 已消费的样本数(`consumed_train_samples`)
+3. 写入 `rt.finish(status='stopped')`
+4. 安全退出
+
+
+## 五、数据通信(TUI ↔ 训练脚本)
+
+训练脚本通过 `TrainingRuntime` 写入本地 JSONL 文件,TUI 通过 `LocalConnection` 读取:
+
+```
+~/.cache/twinkle/{run_id}/
+├── meta.json # 运行元信息(model_id、config、status、pid、script_path、script_version)
+├── train.py # 当前活跃版本(始终是最新的)
+├── train_v1.py # 归档:第1版脚本(出错的原始版本)
+├── train_v2.py # 归档:第2版脚本(如果也有问题)
+├── metrics.jsonl # 每步一行 JSON(step, loss, reward, grad_norm, lr, ...)
+└── logs.jsonl # 事件日志(timestamp + message)
+```
+
+**脚本命名与版本管理规则:**
+- `train.py` 始终是当前活跃版本,`resume_training` 只执行它
+- 当 Agent 修复脚本时,旧版自动归档为 `train_v{N}.py`(保留完整修改历史)
+- `meta.json` 中 `script_version` 字段记录当前版本号
+- `run_id` 由用户定义(如 `'grpo-gsm8k'`、`'sft-self-cognition'`)
+- 同一 `run_id` 下可多次修改脚本:server 端 adapter 状态不变,只有 client 逻辑更新
+
+**脚本更新流程(Agent 自动执行):**
+1. 脚本出错停止 → Agent 读取 logs/metrics 诊断问题
+2. Agent 调用 `update_script(run_id, new_code)` → 旧 `train.py` 归档为 `train_v{N}.py`,新代码写入 `train.py`
+3. Agent 调用 `resume_training(run_id)` → 重新执行最新的 `train.py`
+
+**TUI 通过 PID 进行进程控制:**
+- `meta.json` 记录 `pid`(进程 ID)
+- 暂停 = `os.kill(pid, SIGKILL)`
+- 停止 = `os.kill(pid, SIGTERM)` → 脚本优雅保存 checkpoint
+- 恢复 = `subprocess.Popen(['python', script_path])` → 新 PID 写回 meta
+
+TUI 支持增量读取(tail 模式),避免大文件全量加载。
+
+
+## 六、训练前规划(Pre-Training Planning)
+
+Agent 在编写训练脚本前,必须完成以下评估(本地模式):
+
+1. **集群资源评估**:GPU 数量、型号、显存(`nvidia-smi` / `ray status`)
+2. **模型显存估算**:LoRA training ≈ model weights + 20% overhead
+3. **DeviceMesh 设计**:根据 GPU 数决定 model vs sampler 分配(决策树)
+4. **训练时间预估**:`total_steps × time_per_step`
+5. **数据集搜索**:通过 ModelScope API 或 `search_datasets` 工具
+6. **模型选择**:根据任务类型和资源约束推荐
+
+云服务模式下可跳过 1-4,直接进入数据集和模型选择。
+
+`list_supported_models` 工具用于查询 Server 实际支持的模型列表,避免选择不可用模型。
+
+
+## 七、Skills 文档内容
+
+### twinkle-training.md(训练脚本编写指导)
+
+覆盖以下内容:
+- Pre-Training Planning 完整流程
+- Ray 集群配置(DeviceGroup / DeviceMesh / initialize)
+- 模型后端(Transformers / Megatron)初始化
+- Dataset 加载、Template 编码、Preprocessor 使用
+- 所有训练方式示例:SFT、GRPO、DPO、GKD、PT
+- Server Mode 完整说明(本地自建 + 云服务两种模式)
+- Cloud Service Mode(ModelScope 托管,两种客户端 API 对比)
+- Sampler 配置与权重同步
+- MultiTurnRollout 多轮对话采样
+- TUI 集成:TrainingRuntime、指标上报、优雅退出
+- 实验管理文件夹规范
+
+### autoresearch.md(自动化研究实验设计)
+
+指导 Agent 如何:
+- 分析用户需求,选取合适的训练方法
+- 根据资源约束选择模型规模
+- 配置超参数(SFT / GRPO / DPO 默认值和调优建议)
+- 设计多阶段 Pipeline(数据清洗 → SFT → GRPO/DPO)
+- 编写数据清洗和转换流程
+- 组织实验输出文件夹
+
+
+## 八、训练脚本规范
+
+所有训练脚本必须:
+
+1. 使用 **Server Mode 语法**:`twinkle_client`(模型操作)+ `twinkle`(数据处理)
+2. 连接到运行中的 Twinkle Server(本地或云端)
+3. 注册 SIGTERM graceful shutdown handler
+4. 通过 `TrainingRuntime` 上报所有可用指标(loss, reward, grad_norm, lr, ...)
+5. 每个实验独立文件夹,包含 `plan.md`、`config.yaml`、`train.py`、`train.sh`
+
+两种客户端 API:
+- **Twinkle 原生**:`init_twinkle_client()` → `MultiLoraTransformersModel` → `forward_backward()` → `clip_grad_and_step()`
+- **Tinker 兼容**:`init_tinker_client()` → `ServiceClient` → `create_lora_training_client()` → `forward_backward()` → `optim_step()`
+
+
+## 九、插件化架构与组件约束
+
+Twinkle 是一个插件化框架,所有核心组件(loss、preprocessor、metric、sampler 等)均以插件形式注册,本地模式下支持用户编写新组件。
+
+### 本地自建模式
+- 完全可扩展:可编写自定义 loss、preprocessor、metric、reward function
+- 通过装饰器注册(如 `@register_loss('MyLoss')`),然后在训练脚本中以字符串名引用
+
+### 线上云服务模式(ModelScope 托管)
+- **安全限制**:不支持传入类、函数对象或 pickle 序列化
+- **只能使用已注册的内置组件**,以字符串名引用
+- 内置 Loss:`CrossEntropyLoss`、`GRPOLoss`、`DPOLoss`、`GKDLoss`
+- 内置 Preprocessor:`SFTPreprocessor`、`RLPreprocessor`、`DPOPreprocessor`
+
+Agent 编写脚本时必须判断目标环境:本地可用自定义组件,云端只能用内置名称。
diff --git a/pyproject.toml b/pyproject.toml
index 5daa14fe9..484dbc9e8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,24 +16,28 @@ dependencies = [
"transformers",
"typer>=0.9.0",
"pyzmq",
+ "accelerate",
+ "torch>=2.6.0,<3.0.0",
]
[project.scripts]
twinkle-server = "twinkle.server.cli:main"
+twinkle-tui = "twinkle_client.tui:main"
[project.optional-dependencies]
-transformers = [
- "accelerate",
- "torch>=2.6.0,<3.0.0",
- "torchvision",
-]
-kernels = ["kernels"]
megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]", "mcore_bridge"]
-vllm = ["vllm>=0.11"]
-ray = ["ray[serve]"]
-datajuicer = ["py-data-juicer"]
-tinker = ["tinker==0.14.0"]
-test = ["hypothesis>=6.0", "pytest", "pytest-asyncio"]
+data = ["py-data-juicer"]
+rl = [
+ "vllm>=0.11",
+ "ray[serve]"
+]
+client = [
+ "textual>=1.0.0",
+ "plotext>=5.2.0",
+ "openai>=1.0.0",
+ "httpx>=0.25.0",
+ "tinker==0.16.1",
+]
server = [
"redis>=5.0",
"psutil>=5.9.0",
@@ -43,6 +47,11 @@ server = [
"opentelemetry-exporter-otlp",
"opentelemetry-instrumentation-logging",
]
+test = [
+ "hypothesis>=6.0",
+ "pytest",
+ "pytest-asyncio"
+]
docs = [
"sphinx>=5.3.0,<6.0.0",
"docutils>=0.16.0,<0.17.0",
@@ -68,3 +77,6 @@ build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["src"]
+
+[tool.setuptools.package-data]
+"twinkle_client.skills.bundled" = ["*.md"]
diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py
index cde5c519d..3860d2840 100644
--- a/src/twinkle/checkpoint_engine/manager.py
+++ b/src/twinkle/checkpoint_engine/manager.py
@@ -1,6 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py
-import time
from typing import List, Optional
from twinkle import Platform, get_logger
diff --git a/src/twinkle/checkpoint_engine/mixin.py b/src/twinkle/checkpoint_engine/mixin.py
index e2e5d94d5..8dc15c926 100644
--- a/src/twinkle/checkpoint_engine/mixin.py
+++ b/src/twinkle/checkpoint_engine/mixin.py
@@ -1,5 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-import os
from twinkle import Platform, remote_function
from twinkle.checkpoint_engine.base import CheckpointEngine
diff --git a/src/twinkle/cli/cli.py b/src/twinkle/cli/cli.py
index 1730887f2..085ad7ec6 100644
--- a/src/twinkle/cli/cli.py
+++ b/src/twinkle/cli/cli.py
@@ -1,12 +1,11 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-from __future__ import annotations
-
import os
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from pathlib import Path
-from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union
+from typing import Any, Iterator, Literal
+
# ────────────────────────────────────────────────────────────────────────────────
# Arg group dataclasses
@@ -46,6 +45,7 @@ class LoraArgs:
lora_dropout: float = 0.05
lora_target_modules: list[str] | None = None
adapter_name: str = 'default'
+ lora_path: str | None = None
@dataclass
@@ -84,6 +84,7 @@ class TrainingArgs:
log_interval: int = 10
eval_interval: int | None = None
eval_samples: int | None = None
+ train_samples: int | None = None
resume_from_checkpoint: str | None = None
resume_only_model: bool = False
ignore_data_skip: bool = False
@@ -117,9 +118,11 @@ class SchedulerArgs:
@dataclass
class LossArgs:
loss_cls: str = 'CrossEntropyLoss'
+ loss_type: str = 'sigmoid'
epsilon: float = 0.2
epsilon_high: float | None = None
- beta: float = 0.0
+ beta: float = 0.1
+ sft_weight: float = 1.0
entropy_coef: float = 0.0
ignore_index: int = -100
@@ -155,6 +158,7 @@ class InfraArgs:
ncpu_proc_per_node: int = 8
model_gpus: int | None = None
sampler_gpus: int | None = None
+ ref_model_gpus: int | None = None
world_size: int | None = None
dp_size: int | None = None
fsdp_size: int | None = None
@@ -185,6 +189,11 @@ class RLArgs:
advantage_type: str = 'GRPOAdvantage'
advantage_scale: Literal['group', 'batch', 'none'] = 'group'
reward_fns: list[str] | None = None
+ student_model_id: str | None = None
+ teacher_model_id: str | None = None
+ gkd_beta: float = 0.5
+ gkd_temperature: float = 1.0
+ gkd_topk: int = 64
@dataclass
@@ -192,6 +201,7 @@ class CheckpointArgs:
save_optimizer: bool = True
merge_and_sync: bool = True
platform: str = 'GPU'
+ lora_sync_dir: str | None = None
# ────────────────────────────────────────────────────────────────────────────────
@@ -243,7 +253,7 @@ def _resolve_path(self) -> Path | None:
class EnvVarSource(ConfigSource):
"""Reads os.environ; recognizes TWINKLE_ prefix and any key known to the registry."""
- def __init__(self, registry: ConfigRegistry):
+ def __init__(self, registry: 'ConfigRegistry'):
self._registry = registry
def load(self) -> dict[str, str]:
diff --git a/src/twinkle/data_format/output.py b/src/twinkle/data_format/output.py
index 763ef246f..596252fb6 100644
--- a/src/twinkle/data_format/output.py
+++ b/src/twinkle/data_format/output.py
@@ -20,11 +20,13 @@ class ModelOutput(TypedDict, total=False):
loss: The loss calculated by the model.
logps: The log-probabilities of correct tokens by the model.
num_tokens: The token denominator associated with ``loss``.
+ embeddings: The embeddings output by the model, used by embedding tasks.
"""
logits: Optional[OutputType]
loss: Optional[OutputType]
logps: Optional[OutputType]
num_tokens: Optional[OutputType]
+ embeddings: Optional[OutputType]
class LossOutput(TypedDict, total=False):
diff --git a/src/twinkle/data_format/sampling.py b/src/twinkle/data_format/sampling.py
index 1d5fe07c6..01ff0377d 100644
--- a/src/twinkle/data_format/sampling.py
+++ b/src/twinkle/data_format/sampling.py
@@ -1,5 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-import numpy as np
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py
index c392d56cf..408c8d4b4 100644
--- a/src/twinkle/dataloader/dataloader.py
+++ b/src/twinkle/dataloader/dataloader.py
@@ -146,7 +146,7 @@ def _tracking_iter(self, inner):
def skip_consumed_samples(self, consumed_train_samples: int) -> None:
from torch.utils.data import IterableDataset
+ # A missing or non-positive count means there is nothing to skip. That is
+ # not an IterableDataset limitation, so return without the warning below.
+ if consumed_train_samples is None or consumed_train_samples <= 0:
+ self._skip_samples = 0
+ return
if isinstance(self.dataset, IterableDataset):
warnings.warn('IterableDataset does not support consumed-data skipping; continuing without skipping.')
self._skip_samples = 0
return
@@ -164,6 +164,7 @@ def resume_from_checkpoint(self, consumed_train_samples, **kwargs):
@remote_function()
def get_state(self) -> dict:
+ """The dataloader state for saving."""
return {'consumed_train_samples': self._consumed_train_samples}
def _rebuild_sampler_stack(self):
diff --git a/src/twinkle/dataset/iterable_packing_dataset.py b/src/twinkle/dataset/iterable_packing_dataset.py
index ca7c6fbd8..ab1d3a982 100644
--- a/src/twinkle/dataset/iterable_packing_dataset.py
+++ b/src/twinkle/dataset/iterable_packing_dataset.py
@@ -88,10 +88,27 @@ def _fetch_data_out_queue(self, last_res, num_samples):
last_res += res
return last_res
- @staticmethod
- def _cyclic_iter(iterable):
- while True:
- yield from iterable
+ def _write_through_iter(self, iterable):
+ """Yield every row of ``iterable``, handing each row to ``self._write_through``.
+
+ When ``self.cyclic`` is false the iterable is consumed exactly once.
+ When it is true the iterable is replayed indefinitely, but rows are passed
+ to ``self._write_through`` only during the first pass. Iteration ends as
+ soon as a pass produces no rows.
+
+ Per the original note, the write-through saves rows to disk, which is
+ needed when several datasets are used at a time.
+ """
+ if not self.cyclic:
+ for row in iterable:
+ self._write_through(row)
+ yield row
+ return
+ else:
+ first_pass = True
+ while True:
+ # Tracks whether this pass produced anything at all.
+ empty = True
+ for row in iterable:
+ empty = False
+ # Later passes replay rows already handed over in the first pass.
+ if first_pass:
+ self._write_through(row)
+ yield row
+ # An empty pass would otherwise make this loop spin forever.
+ if empty:
+ return
+ first_pass = False
@remote_function()
def __iter__(self):
@@ -102,10 +119,7 @@ def __iter__(self):
except StopIteration:
return
- if self.cyclic:
- iterator = self._cyclic_iter(self.dataset)
- else:
- iterator = iter(self.dataset)
+ iterator = self._write_through_iter(self.dataset)
data = []
max_length = self.template.max_length or 2048
while True:
diff --git a/src/twinkle/dataset/packing_dataset.py b/src/twinkle/dataset/packing_dataset.py
index fa4acbd57..ada9498b8 100644
--- a/src/twinkle/dataset/packing_dataset.py
+++ b/src/twinkle/dataset/packing_dataset.py
@@ -114,6 +114,8 @@ def __getitem__(self, index):
assert self._packed_called, 'Call `pack_dataset()` first before index the sample.'
sequence = self.packed_idx[index]
rows = [self.dataset[i] for i in sequence]
+ for row in rows:
+ self._write_through(row)
output = {}
for key in rows[0]:
output[key] = [r[key] for r in rows]
diff --git a/src/twinkle/gym/__init__.py b/src/twinkle/gym/__init__.py
deleted file mode 100644
index 44b0771bb..000000000
--- a/src/twinkle/gym/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
-from .base import Gym
diff --git a/src/twinkle/gym/base.py b/src/twinkle/gym/base.py
deleted file mode 100644
index aca798093..000000000
--- a/src/twinkle/gym/base.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
-
-
-class Gym:
-
- def __init__(self):
- pass
-
- def step(self):
- pass
diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py
index 83e10d132..a2760c900 100644
--- a/src/twinkle/infra/__init__.py
+++ b/src/twinkle/infra/__init__.py
@@ -1,11 +1,9 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import functools
import inspect
-import itertools
import json
import numpy as np
import os
-import random
import sys
from typing import Any, Callable, List, Literal, Optional, TypeVar, Union
@@ -59,7 +57,7 @@ def _tag_exc(exc: BaseException, caller: Optional[str]) -> None:
prefix = f'[twinkle driver caller: {caller}] '
exc.args = (prefix + str(exc.args[0]), *exc.args[1:]) if exc.args else (prefix.rstrip(), )
exc._twinkle_caller_augmented = True
- except Exception: # noqa: BLE001
+ except Exception: # noqa
pass
@@ -404,6 +402,7 @@ def dispatch_func(arg, n):
return result
elif dispatch == 'slice_dp':
+ assert device_mesh is not None
# split by dp. each worker in one ep will receive the same argument
result = []
# if device_mesh is not None:
@@ -420,14 +419,6 @@ def dispatch_func(arg, n):
import torch
if isinstance(arg, list) or isinstance(arg, torch.Tensor):
_args = []
- if device_mesh is None:
- total = len(arg)
- chunk = max(1, (total + n - 1) // n)
- for i in range(n):
- start = i * chunk
- end = min(total, start + chunk)
- _args.append(arg[start:end])
- return _args
for i in range(n):
_args.append(arg[device_mesh.get_slice(
len(arg), device_mesh.get_data_rank_from_global_rank(i * _rank_stride))])
@@ -696,11 +687,12 @@ def __next__(_self):
return decorator
-def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp'], Callable] = 'slice',
+def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp', 'last_pp_first'], Callable] = 'slice',
execute: Literal['first', 'peer', 'all'] = 'all',
collect: Union[Literal['none', 'flatten', 'mean', 'sum', 'first', 'last_pp'], Callable] = 'none',
sync: bool = False,
- lazy_collect: Optional[bool] = None):
+ lazy_collect: Optional[bool] = None,
+ timeout: Optional[float] = None):
"""Patch each method called from remote(which class should be decorated with `remote_class`) with this decorator.
Args:
@@ -726,6 +718,7 @@ def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp'], Callabl
sync: If True, use synchronous execution (execute_all_sync) instead of async.
Required for methods with NCCL collective operations (e.g., Megatron forward_backward).
lazy_collect: Do lazy collect, this boolean value decides whether this function needs lazy collect. If setting to None, it will follow the global setting.
+ timeout: Timeout in seconds for ray.get() when collecting results. Instance attribute ``_ray_get_timeout`` overrides this.
""" # noqa
def decorator(func: Callable[..., T1]) -> Callable[..., T1]:
@@ -773,7 +766,9 @@ def wrapper(self, *args, **kwargs) -> T1:
result = execute_method(func.__name__, _workers_and_args)
# This is a result future, call it to get the actual result
- result_func = RayHelper.do_get_and_collect_func(_collect_func, collect, result, device_mesh)
+ _rgt = getattr(self, '_ray_get_timeout', None) or timeout
+ result_func = RayHelper.do_get_and_collect_func(
+ _collect_func, collect, result, device_mesh, timeout=_rgt)
_local_lazy_collect = _lazy_collect
if func.__name__ == '__iter__':
# return self
@@ -803,18 +798,13 @@ def wrapper(self, *args, **kwargs) -> T1:
# And this is user independent, only decided by the code.
_local_lazy_collect = self._lazy_collect
if _local_lazy_collect:
- # Wrap the deferred collector so that exceptions
- # raised when the caller later materializes the
- # result also trigger the notifier. Attributes
- # (``_futures`` etc.) on the original collector
- # are preserved for downstream code paths.
_orig_result_func = result_func
@functools.wraps(_orig_result_func)
def _notifying_result_func(*rargs, **rkwargs):
try:
return _orig_result_func(*rargs, **rkwargs)
- except Exception as _e: # noqa: BLE001
+ except Exception as _e: # noqa
_tag_exc(_e, _caller)
notify_exception(_notifier, _ctx, _e, _name)
raise
diff --git a/src/twinkle/infra/_ray/ray_helper.py b/src/twinkle/infra/_ray/ray_helper.py
index 0d8908a35..5cd792c3a 100644
--- a/src/twinkle/infra/_ray/ray_helper.py
+++ b/src/twinkle/infra/_ray/ray_helper.py
@@ -161,18 +161,20 @@ def get_node_address():
return ip, port
@staticmethod
- def do_get_and_collect_func(collect_func: Callable, method: Union[str, Callable], futures, device_mesh):
+ def do_get_and_collect_func(collect_func: Callable, method: Union[str, Callable], futures, device_mesh,
+ timeout=None):
"""Return a callable to collect results in the workers."""
class LazyCollect:
- def __init__(self, futures, method, collect_func, device_mesh):
+ def __init__(self, futures, method, collect_func, device_mesh, timeout=None):
self._futures = futures
self._method = method
self._collect_func = collect_func
self._is_lazy_collect = True
self.device_mesh = device_mesh
self._result = None # Cache collected results
+ self._timeout = timeout
def _get_result(self):
"""Internal method to lazily collect and cache results"""
@@ -181,7 +183,7 @@ def _get_result(self):
result = []
for future in self._futures:
if isinstance(future, ray.ObjectRef):
- result.append(ray.get(future))
+ result.append(ray.get(future, timeout=self._timeout))
else:
result.append(future)
self._result = self._collect_func(self._method, result, device_mesh=self.device_mesh)
@@ -199,7 +201,7 @@ def __len__(self):
"""Support len() function"""
return len(self._get_result())
- return LazyCollect(futures, method, collect_func, device_mesh)
+ return LazyCollect(futures, method, collect_func, device_mesh, timeout=timeout)
@staticmethod
def do_get_and_collect(args, kwargs):
diff --git a/src/twinkle/loss/chunked_cross_entropy.py b/src/twinkle/loss/chunked_cross_entropy.py
index 22d3d4077..061ca2168 100644
--- a/src/twinkle/loss/chunked_cross_entropy.py
+++ b/src/twinkle/loss/chunked_cross_entropy.py
@@ -1,63 +1,159 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-import math
-from typing import Any
-
-from ..data_format import LossOutput
+from twinkle.data_format import LossOutput
from .base import Loss
+# Lazily-built singleton autograd.Function, so we neither pay the
+# class-construction cost on every forward nor force a top-level torch import.
+_CHUNKED_CE_FUNC = None
+
+
+def _get_chunked_ce_func():
+ """Return the chunked-CE ``autograd.Function`` class, building it on first use.
+
+ The class is cached in the module-level ``_CHUNKED_CE_FUNC`` so later calls
+ return the same object; ``torch`` is imported here, not at module import.
+ """
+ global _CHUNKED_CE_FUNC
+ if _CHUNKED_CE_FUNC is not None:
+ return _CHUNKED_CE_FUNC
+
+ import torch
+ import torch.nn.functional as F
+
+ class _ChunkedCrossEntropyFunc(torch.autograd.Function):
+ """Chunked CE that materialises log_softmax(B, V) only one chunk at a time.
+
+ Forward returns a scalar loss; backward writes per-token gradients into
+ a freshly allocated `grad_logits` tensor (the input `logits` is never
+ mutated). Mathematically equivalent to ``CrossEntropyLoss`` in the same
+ package; ``chunk_size`` only controls the memory/throughput trade-off.
+
+ With ``dft=True`` the per-token loss is ``-sg(p) * log(p)``: the
+ probability weight is detached, so it scales the gradient of
+ ``-log(p)`` without being differentiated itself.
+ """
+
+ @staticmethod
+ def forward(ctx, logits, labels, chunk_size, ignore_index, reduction, dft):
+ ctx.save_for_backward(logits, labels)
+ ctx.chunk_size = chunk_size
+ ctx.ignore_index = ignore_index
+ ctx.reduction = reduction
+ ctx.dft = dft
+
+ n = logits.shape[0]
+ # Use fp32 accumulators so we don't lose precision when summing
+ # over many tokens under fp16/bf16 autocast (matches cross_entropy.py).
+ total_loss = logits.new_zeros((), dtype=torch.float32)
+ total_count = logits.new_zeros((), dtype=torch.float32)
+
+ for start in range(0, n, chunk_size):
+ end = min(start + chunk_size, n)
+ logits_chunk = logits[start:end]
+ labels_chunk = labels[start:end]
+ mask = (labels_chunk != ignore_index).float()
+
+ # clamp(min=0) keeps gather in range for ignored labels; the mask
+ # removes their contribution afterwards.
+ logps = F.log_softmax(logits_chunk, dim=-1).gather(
+ -1, labels_chunk.clamp(min=0).unsqueeze(-1)).squeeze(-1)
+ # detach() does not change the value; it keeps this formula
+ # identical to the one differentiated in backward().
+ per_token = -logps * logps.detach().exp() if dft else -logps
+
+ total_loss = total_loss + (per_token * mask).sum()
+ total_count = total_count + mask.sum()
+
+ ctx.num_tokens = total_count.detach()
+ if reduction == 'mean':
+ return total_loss / total_count.clamp(min=1)
+ return total_loss
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ logits, labels = ctx.saved_tensors
+ chunk_size = ctx.chunk_size
+ ignore_index = ctx.ignore_index
+ reduction = ctx.reduction
+ dft = ctx.dft
+
+ if reduction == 'mean':
+ scale = grad_output / ctx.num_tokens.clamp(min=1)
+ else:
+ scale = grad_output
+
+ grad_logits = torch.empty_like(logits)
+ n = logits.shape[0]
+
+ for start in range(0, n, chunk_size):
+ end = min(start + chunk_size, n)
+ # Fresh leaf per chunk: the chunk's forward is recomputed under
+ # enable_grad so only one chunk's softmax graph is alive at a time.
+ logits_chunk = logits[start:end].detach().requires_grad_(True)
+ labels_chunk = labels[start:end]
+ mask = (labels_chunk != ignore_index).float()
+
+ with torch.enable_grad():
+ logps = F.log_softmax(logits_chunk, dim=-1).gather(
+ -1, labels_chunk.clamp(min=0).unsqueeze(-1)).squeeze(-1)
+ # The DFT weight p must be detached. Differentiating
+ # -p*log(p) through both factors scales the gradient by
+ # (1 + log p), which flips its sign for p < 1/e and pushes
+ # already-unlikely target tokens further down.
+ per_token = -logps * logps.detach().exp() if dft else -logps
+ loss_chunk = (per_token * mask).sum()
+
+ grad_chunk = torch.autograd.grad(loss_chunk, logits_chunk, retain_graph=False)[0]
+ grad_logits[start:end] = grad_chunk * scale
+
+ # logits, labels, chunk_size, ignore_index, reduction, dft
+ return grad_logits, None, None, None, None, None
+
+ _CHUNKED_CE_FUNC = _ChunkedCrossEntropyFunc
+ return _CHUNKED_CE_FUNC
+
class ChunkedCrossEntropyLoss(Loss):
- """TODO untested code"""
+ """CE loss that chunks the (B, V) softmax to bound peak memory.
+
+ Drop-in replacement for :class:`CrossEntropyLoss` when ``outputs['logits']``
+ is large (e.g. long sequence x big vocab). The loss is mathematically the
+ same; values may differ in the last bits because the sum is accumulated
+ chunk by chunk in fp32. ``chunk_size`` only affects memory/throughput.
+
+ NOTE(review): confirm that ``CrossEntropyLoss`` also detaches the DFT
+ weight, otherwise the two losses produce different gradients for ``dft=True``.
+
+ Args:
+ chunk_size: How many rows of ``logits`` to process per chunk.
+ ignore_index: Label id treated as padding (excluded from loss).
+ reduction: ``'mean'`` or ``'sum'``; matches ``CrossEntropyLoss``.
+ dft: If True, use DFT weighting ``-sg(p)*log(p)`` (arxiv 2508.05629),
+ where ``sg`` is stop-gradient: ``p`` weights the loss but is detached.
+ """
- def __init__(self, chunk_size):
+ require_logits = True
+ # We chunk the (B, V) softmax ourselves; tell upstream not to materialise
+ # `logps` (which would already pay the full memory cost we're trying to
+ # avoid). The `_loss_from_logps` fast path is kept only for the rare case
+ # where someone explicitly hands us pre-computed logps.
+ require_logps = False
+
+ def __init__(self,
+ chunk_size: int,
+ ignore_index: int = -100,
+ reduction: str = 'mean',
+ dft: bool = False,
+ **kwargs):
+ super().__init__()
+ assert chunk_size > 0, 'chunk_size must be positive'
+ assert reduction in ('mean', 'sum'), f"reduction must be 'mean' or 'sum', got {reduction!r}"
self.chunk_size = chunk_size
+ self.ignore_index = ignore_index
+ self.reduction = reduction
+ self.dft = dft
def __call__(self, inputs, outputs, **kwargs):
- import torch
-
- class ChunkedCrossEntropyLossFunc(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, logits, labels, chunk_size):
- import torch
- ctx.save_for_backward(logits, labels)
- ctx.chunk_size = chunk_size
-
- losses = []
- for i in range(math.ceil(logits.shape[0] / chunk_size)):
- l_start = i * chunk_size
- l_end = min((i + 1) * chunk_size, logits.shape[0])
- logits_chunk = logits[l_start:l_end]
- labels_chunk = labels[l_start:l_end]
- loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
- loss_chunk = loss_fct(logits_chunk, labels_chunk)
- losses.append(loss_chunk)
- del logits_chunk
- del labels_chunk
- all_losses = torch.cat(losses)
- return all_losses
-
- @staticmethod
- def backward(ctx: Any, *grad_outputs: Any):
- import torch
- logits, labels = ctx.saved_tensors
- chunk_size = ctx.chunk_size
-
- for i in range(math.ceil(logits.shape[0] / chunk_size)):
- l_start = i * chunk_size
- l_end = min((i + 1) * chunk_size, logits.shape[0])
- logits_chunk = logits[l_start:l_end].detach().requires_grad_(True)
- labels_chunk = labels[l_start:l_end]
- loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
- with torch.enable_grad():
- loss_chunk = loss_fct(logits_chunk, labels_chunk)
- grad_output_chunk = grad_outputs[0][l_start:l_end]
- _loss_chunk = (loss_chunk * grad_output_chunk).sum()
- grad_chunk = torch.autograd.grad(_loss_chunk, logits_chunk, retain_graph=False)[0]
- logits[l_start:l_end] = grad_chunk
-
- return logits, None, None
+ """Compute the loss from ``outputs['logps']`` if present, else from chunked ``outputs['logits']``.
+
+ Returns a ``LossOutput``. ``num_tokens`` is 0 for ``reduction='mean'`` and
+ the count of non-ignored labels (clamped to >= 1) for ``reduction='sum'``.
+ """
+ labels = inputs['labels']
+ logps = outputs.get('logps')
+
+ # Fast path: if logps is already gathered upstream, chunking the
+ # softmax is moot — fall back to the same scalar formula as
+ # CrossEntropyLoss to keep behaviour identical.
+ if logps is not None:
+ return self._loss_from_logps(labels, logps)
logits = outputs['logits']
- labels = inputs['labels']
- return LossOutput(loss=ChunkedCrossEntropyLossFunc.apply(logits, labels, self.chunk_size), num_tokens=0)
+ # reshape, not view: view raises on non-contiguous tensors (e.g. slices),
+ # while reshape still returns a zero-copy view when the layout allows it.
+ labels = labels.reshape(-1)
+ logits = logits.reshape(-1, logits.shape[-1])
+
+ func = _get_chunked_ce_func()
+ loss = func.apply(logits, labels, self.chunk_size, self.ignore_index, self.reduction, self.dft)
+
+ if self.reduction == 'mean':
+ return LossOutput(loss=loss, num_tokens=0)
+ num_tokens = (labels != self.ignore_index).float().sum().clamp(min=1)
+ return LossOutput(loss=loss, num_tokens=num_tokens)
+
+ def _loss_from_logps(self, labels, logps):
+ """Scalar loss from already-gathered per-token log-probs (same formula as the chunked path)."""
+ mask = (labels != self.ignore_index).float()
+ # Detach the DFT weight so only -log(p) is differentiated (see class docstring).
+ per_token = -logps * logps.detach().exp() if self.dft else -logps
+ if self.reduction == 'mean':
+ return LossOutput(loss=(per_token * mask).sum() / mask.sum().clamp(min=1), num_tokens=0)
+ return LossOutput(loss=(per_token * mask).sum(), num_tokens=mask.sum().clamp(min=1))
diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py
index fe526ab46..d53019513 100644
--- a/src/twinkle/loss/dpo.py
+++ b/src/twinkle/loss/dpo.py
@@ -7,7 +7,7 @@
(https://arxiv.org/abs/2305.18290)
"""
from typing import TYPE_CHECKING, Dict, List, Optional, Union
-
+import math
from twinkle.data_format import LossOutput
from twinkle.loss.base import Loss
from twinkle.utils.torch_utils import selective_log_softmax
@@ -132,6 +132,12 @@ def __init__(
**kwargs,
):
super().__init__(ignore_index=ignore_index)
+ if loss_type not in ('sigmoid', 'hinge', 'ipo', 'kto_pair'):
+ raise ValueError(f'Unknown loss_type: {loss_type}')
+ if label_smoothing > 0 and loss_type != 'sigmoid':
+ raise ValueError(
+ f'label_smoothing > 0 is only defined for loss_type="sigmoid", '
+ f'got loss_type="{loss_type}". Set label_smoothing=0.0 or switch to sigmoid.')
self.beta = beta
self.label_smoothing = label_smoothing
self.loss_type = loss_type
@@ -217,6 +223,11 @@ def _compute_dpo_loss(
if self.loss_type == 'sigmoid':
# Standard DPO loss: -log(sigmoid(beta * margin))
losses = -F.logsigmoid(logits)
+ # Apply label smoothing (only meaningful here: Bradley-Terry soft labels).
+ if self.label_smoothing > 0:
+ # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected
+ smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference
+ losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses
elif self.loss_type == 'hinge':
# Hinge loss variant
losses = torch.relu(1 - logits)
@@ -234,12 +245,6 @@ def _compute_dpo_loss(
else:
raise ValueError(f'Unknown loss_type: {self.loss_type}')
- # Apply label smoothing if specified
- if self.label_smoothing > 0:
- # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected
- smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference
- losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses
-
return losses.mean()
def __call__(
@@ -321,7 +326,8 @@ def __call__(
reference_chosen_logps = torch.zeros_like(policy_chosen_logps)
reference_rejected_logps = torch.zeros_like(policy_rejected_logps)
else:
- return LossOutput(loss=torch.tensor(0.0, device=chosen_logps.device), num_tokens=0)
+ zero = (policy_chosen_logps.sum() + policy_rejected_logps.sum()) * 0.0
+ return LossOutput(loss=zero, num_tokens=0)
# Compute DPO loss
dpo_loss = self._compute_dpo_loss(
@@ -535,11 +541,23 @@ def __call__(
# Odds ratio: log(odds_chosen / odds_rejected)
# log_odds = log(p/(1-p)) = log(p) - log(1-p)
- # Compute entirely in log-space to avoid exp() underflow:
- # log(p) = avg_logps (already in log-space)
- # log(1-p) = log1p(-exp(avg_logps)) (numerically stable via log1p)
- log_odds_chosen = chosen_avg_logps - torch.log1p(-torch.exp(chosen_avg_logps))
- log_odds_rejected = rejected_avg_logps - torch.log1p(-torch.exp(rejected_avg_logps))
+ # Compute log(1-p) = log(1 - exp(avg_logp)) numerically stably:
+ # - For x > -log(2): log(-expm1(x)) (avoids log(0) when p → 1)
+ # - For x ≤ -log(2): log1p(-exp(x)) (avoids cancellation when p → 0)
+ # ``avg_logp ∈ (-∞, 0]`` so the threshold partitions the safe regime.
+ log_two = math.log(2.0)
+
+ def _log1mexp(x: 'torch.Tensor') -> 'torch.Tensor':
+ # Clamp at a tiny negative to keep both branches well-defined when p≈1.
+ x_safe = torch.clamp(x, max=-1e-7)
+ return torch.where(
+ x_safe > -log_two,
+ torch.log(-torch.expm1(x_safe)),
+ torch.log1p(-torch.exp(x_safe)),
+ )
+
+ log_odds_chosen = chosen_avg_logps - _log1mexp(chosen_avg_logps)
+ log_odds_rejected = rejected_avg_logps - _log1mexp(rejected_avg_logps)
# ORPO odds ratio loss
odds_ratio = log_odds_chosen - log_odds_rejected
diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py
index 3f7db4bfb..7c198ad02 100644
--- a/src/twinkle/loss/gkd.py
+++ b/src/twinkle/loss/gkd.py
@@ -41,6 +41,10 @@ def __init__(
chunk_size: int = 512,
**kwargs,
):
+ if not (0.0 <= beta <= 1.0):
+ raise ValueError(f'beta must be in [0, 1], got {beta}')
+ if temperature <= 0:
+ raise ValueError(f'temperature must be > 0, got {temperature}')
self.beta = beta
self.temperature = temperature
self.ignore_index = ignore_index
@@ -94,6 +98,7 @@ def __call__(
labels=labels,
beta=self.beta,
temperature=self.temperature,
+ ignore_index=self.ignore_index,
chunk_size=self.chunk_size,
topk=topk,
teacher_topk_logprobs=teacher_topk_logprobs,
@@ -108,6 +113,7 @@ def _generalized_jsd_loss(
labels=None,
beta: float = 0.5,
temperature: float = 1.0,
+ ignore_index: int = -100,
chunk_size: int = 512,
topk: Optional[int] = None,
teacher_topk_logprobs=None,
@@ -164,7 +170,7 @@ def _generalized_jsd_loss(
# ── Mask valid (response) tokens ──────────────────────────────────────
if labels is not None:
- mask = labels != -100 # ignore_index is always -100 per convention
+ mask = labels != ignore_index
# Vocab-size mismatch (e.g. Qwen2.5-VL-3B vs 7B): pad the smaller side
# so both distributions are defined over the same token set.
stu_dim = student_logits.shape[-1]
@@ -178,12 +184,15 @@ def _generalized_jsd_loss(
student_logits = student_logits[mask] # [num_valid, vocab/topk]
teacher_logits = teacher_logits[mask]
num_valid = mask.sum()
+ # ``[mask]`` already created fresh storage, so in-place divide is safe
+ # and avoids an extra [num_valid, V] allocation.
+ student_logits.div_(temperature)
+ teacher_logits.div_(temperature)
else:
- student_logits = student_logits.view(-1, student_logits.size(-1))
- teacher_logits = teacher_logits.view(-1, teacher_logits.size(-1))
+ # Divide out of place so the caller's logits are left unmodified (this may be an inference scenario).
+ student_logits = student_logits.reshape(-1, student_logits.size(-1)) / temperature
+ teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1)) / temperature
num_valid = student_logits.size(0)
- student_logits.div_(temperature)
- teacher_logits.div_(temperature)
if num_valid == 0:
return student_logits.new_zeros(())
diff --git a/src/twinkle/loss/grpo.py b/src/twinkle/loss/grpo.py
index 4bb71216c..781b22060 100644
--- a/src/twinkle/loss/grpo.py
+++ b/src/twinkle/loss/grpo.py
@@ -42,18 +42,6 @@ def __init__(
self.require_entropy = entropy_coef > 0.0
self.ignore_index = ignore_index
- def _compute_loss_mask(self, labels: 'torch.Tensor') -> 'torch.Tensor':
- """
- Compute loss mask from labels.
-
- Args:
- labels: [batch, seq_len] target token ids, -100 for ignored positions
-
- Returns:
- mask: [batch, seq_len] float tensor, 1.0 for valid positions, 0.0 for ignored
- """
- return (labels != self.ignore_index).float()
-
def _compute_log_importance_weights(
self,
per_token_logps: 'torch.Tensor',
@@ -165,10 +153,13 @@ def _pad_and_align_to_batch(
return data # Already aligned
if data.dim() == 1:
data = data.unsqueeze(1)
- if data.shape[1] == 1: # Scalars
- result = torch.full((batch_size, seq_len), fill_value, dtype=dtype, device=device)
- result[mask] = data[mask.any(dim=1).nonzero(as_tuple=True)[0].repeat_interleave(mask.sum(dim=1)), 0]
- return result
+ if data.shape[1] == 1:
+ assert data.shape[0] == batch_size, (
+ f'scalar broadcast expects data.shape[0]==batch_size, '
+ f'got data.shape={tuple(data.shape)} mask.shape={(batch_size, seq_len)}')
+ fill = torch.full((batch_size, seq_len), fill_value, dtype=dtype, device=device)
+ expanded = data.expand(batch_size, seq_len)
+ return torch.where(mask, expanded, fill)
data = [data[i] for i in range(batch_size)] # To list
# Handle list (scalars or sequences)
@@ -276,10 +267,12 @@ def __call__(
)
# GRPO loss is ill-defined without advantages (e.g. ref-logps-only forward,
- # or eval/validation forwards). Return a zero loss so the forward still
- # flows through cleanly and callers can harvest outputs['logps'] freely.
+ # or eval/validation forwards). Return a zero loss that still flows through
+ # autograd so DDP/FSDP do not see unused params, and callers can harvest
+ # outputs['logps'] freely.
if advantages is None:
- return LossOutput(loss=torch.zeros((), device=device, dtype=logps.dtype), num_tokens=0)
+ zero = logps.sum() * 0.0
+ return LossOutput(loss=zero, num_tokens=0)
advantages = self._pad_and_align_to_batch(
advantages,
diff --git a/src/twinkle/loss/infonce.py b/src/twinkle/loss/infonce.py
index 68d14840c..c356bd64c 100644
--- a/src/twinkle/loss/infonce.py
+++ b/src/twinkle/loss/infonce.py
@@ -13,8 +13,6 @@
import numpy as np
import torch
import torch.distributed as dist
-import torch.nn.functional as F
-from enum import Enum
from torch import nn
from typing import Optional
@@ -22,15 +20,6 @@
from .base import Loss
-# Borrowed from sentence_transformers.
-class SiameseDistanceMetric(Enum):
- """Distance metrics available to the pairwise contrastive losses."""
-
- EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2) # noqa
- MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1) # noqa
- COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) # noqa
-
-
def _extract_sentences(outputs) -> torch.Tensor:
"""Return [B, D] sentence embeddings from postprocess_tensor_sp output.
@@ -119,6 +108,11 @@ def __init__(
process_group=None,
**kwargs,
):
+ if mask_fake_negative and fake_neg_margin <= 0:
+ raise ValueError(
+ f'fake_neg_margin must be > 0 when mask_fake_negative=True, got {fake_neg_margin}. '
+ 'A non-positive margin would mask out the positive itself or every above-positive '
+ 'logit indiscriminately, collapsing the contrastive signal.')
self.temperature = temperature
self.use_batch = use_batch
self.hard_negatives = hard_negatives
@@ -129,7 +123,13 @@ def __init__(
self.process_group = process_group
def _gather_across_dp(self, sentences: torch.Tensor, labels: torch.Tensor):
- """All-gather embeddings & labels across DP ranks; only local shard keeps grad."""
+ """All-gather embeddings & labels across DP ranks; only local shard keeps grad.
+
+ NCCL ``all_gather`` requires every rank to send the *same* tensor size. Under
+ ``slice_dp`` dispatch the per-rank batch is uneven (``divmod`` splits), so we
+ pad each rank to the global max along dim-0, do an equal-sized all_gather,
+ then strip padding back. Only the local shard retains gradients.
+ """
if not (dist.is_available() and dist.is_initialized()):
return sentences, labels
world_size = dist.get_world_size(group=self.process_group)
@@ -137,24 +137,40 @@ def _gather_across_dp(self, sentences: torch.Tensor, labels: torch.Tensor):
return sentences, labels
rank = dist.get_rank(group=self.process_group)
- # variable per-rank shapes require communicating shape first
- local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long)
- shapes = [torch.empty_like(local_shape) for _ in range(world_size)]
- dist.all_gather(shapes, local_shape, group=self.process_group)
- all_sentences = [sentences.new_empty(shape.tolist()) for shape in shapes]
- dist.all_gather(all_sentences, sentences.contiguous(), group=self.process_group)
-
- local_label_shape = labels.new_tensor(labels.shape, dtype=torch.long)
- label_shapes = [torch.empty_like(local_label_shape) for _ in range(world_size)]
- dist.all_gather(label_shapes, local_label_shape, group=self.process_group)
- all_labels = [labels.new_empty(shape.tolist()) for shape in label_shapes]
- dist.all_gather(all_labels, labels.contiguous(), group=self.process_group)
-
- # keep the local shard differentiable; detach others
- all_sentences[rank] = sentences
+ # ``labels`` is a 1-D mask aligned to ``sentences`` along dim-0, so they
+ # share the same per-rank size. Gather sizes once and reuse for both.
+ assert sentences.shape[0] == labels.shape[0], (
+ f'sentences/labels dim-0 mismatch: {sentences.shape[0]} vs {labels.shape[0]}')
+ local_n = torch.tensor([sentences.shape[0]], device=sentences.device, dtype=torch.long)
+ sizes = [torch.empty_like(local_n) for _ in range(world_size)]
+ dist.all_gather(sizes, local_n, group=self.process_group)
+ sizes_int = [int(s.item()) for s in sizes]
+ max_n = max(sizes_int)
+
+ def _pad_gather(tensor: torch.Tensor):
+ if tensor.shape[0] < max_n:
+ pad_shape = (max_n - tensor.shape[0],) + tuple(tensor.shape[1:])
+ padded = torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=0)
+ else:
+ padded = tensor
+ buffers = [torch.empty_like(padded) for _ in range(world_size)]
+ dist.all_gather(buffers, padded.contiguous(), group=self.process_group)
+ return buffers
+
+ sent_buffers = _pad_gather(sentences)
+ label_buffers = _pad_gather(labels)
+
+ # Strip padding; keep local shard differentiable, detach others.
+ all_sentences = []
+ all_labels = []
for idx in range(world_size):
- if idx != rank:
- all_sentences[idx] = all_sentences[idx].detach()
+ n = sizes_int[idx]
+ if idx == rank:
+ all_sentences.append(sentences)
+ all_labels.append(labels)
+ else:
+ all_sentences.append(sent_buffers[idx][:n].detach())
+ all_labels.append(label_buffers[idx][:n])
return torch.cat(all_sentences, dim=0), torch.cat(all_labels, dim=0)
def __call__(self, inputs, outputs, **kwargs) -> LossOutput:
diff --git a/src/twinkle/metric/accuracy.py b/src/twinkle/metric/accuracy.py
index b3034c57a..4dfb01198 100644
--- a/src/twinkle/metric/accuracy.py
+++ b/src/twinkle/metric/accuracy.py
@@ -1,5 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-import numpy as np
from typing import List, Union
from ..data_format import InputFeature, ModelOutput
diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py
index b203d255e..024cb0473 100644
--- a/src/twinkle/metric/dpo.py
+++ b/src/twinkle/metric/dpo.py
@@ -131,6 +131,12 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M
ref_outputs = kwargs.get('ref_outputs')
if ref_outputs is not None:
ref_logps = ref_outputs.get('logps')
+ if ref_logps is not None:
+ if isinstance(ref_logps, list):
+ if len(ref_logps) == 0:
+ ref_logps = None
+ else:
+ ref_logps = pad_and_stack_tensors(ref_logps)
if ref_logps is not None:
# Align ref_logps to match labels shape (handles different seq lengths)
ref_logps = self._align_logps(ref_logps, labels.shape, labels.device, logps.dtype)
diff --git a/src/twinkle/metric/embedding.py b/src/twinkle/metric/embedding.py
index 9fb3aed8c..8b3681031 100644
--- a/src/twinkle/metric/embedding.py
+++ b/src/twinkle/metric/embedding.py
@@ -1,7 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
from typing import List, Union
from twinkle.data_format import InputFeature, ModelOutput
@@ -32,6 +29,9 @@ def reset(self):
self.grad_norm = 0.0
def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs):
+ import torch
+ import torch.distributed as dist
+ import torch.nn.functional as F
sentences = outputs.get('embeddings')
if sentences is None:
sentences = outputs.get('logits')
@@ -44,22 +44,34 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M
inputs = [inputs]
labels = torch.cat([inp['labels'].view(-1) for inp in inputs], dim=0)
- # Gather embeddings and labels across DP for in-batch stats
+ # Gather embeddings and labels across DP for in-batch stats.
+ # NCCL ``all_gather`` requires every rank to send the same tensor size,
+ # but ``slice_dp`` dispatch (``divmod`` split) can leave per-rank dim-0
+ # uneven. Pad to the global max along dim-0, gather, then strip padding.
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
world_size = dist.get_world_size()
- local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long)
- shapes = [torch.empty_like(local_shape) for _ in range(world_size)]
- dist.all_gather(shapes, local_shape)
- all_sentences = [sentences.new_empty(s.tolist()) for s in shapes]
- dist.all_gather(all_sentences, sentences.contiguous())
- sentences = torch.cat(all_sentences, dim=0)
-
- local_lshape = labels.new_tensor(labels.shape, dtype=torch.long)
- lshapes = [torch.empty_like(local_lshape) for _ in range(world_size)]
- dist.all_gather(lshapes, local_lshape)
- all_labels = [labels.new_empty(s.tolist()) for s in lshapes]
- dist.all_gather(all_labels, labels.contiguous())
- labels = torch.cat(all_labels, dim=0)
+ assert sentences.shape[0] == labels.shape[0], (
+ f'sentences/labels dim-0 mismatch: {sentences.shape[0]} vs {labels.shape[0]}')
+ local_n = torch.tensor([sentences.shape[0]], device=sentences.device, dtype=torch.long)
+ sizes = [torch.empty_like(local_n) for _ in range(world_size)]
+ dist.all_gather(sizes, local_n)
+ sizes_int = [int(s.item()) for s in sizes]
+ max_n = max(sizes_int)
+
+ def _pad_gather(tensor: 'torch.Tensor') -> 'List[torch.Tensor]':
+ if tensor.shape[0] < max_n:
+ pad_shape = (max_n - tensor.shape[0],) + tuple(tensor.shape[1:])
+ padded = torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=0)
+ else:
+ padded = tensor
+ buffers = [torch.empty_like(padded) for _ in range(world_size)]
+ dist.all_gather(buffers, padded.contiguous())
+ return buffers
+
+ sent_buffers = _pad_gather(sentences)
+ label_buffers = _pad_gather(labels)
+ sentences = torch.cat([sent_buffers[i][:sizes_int[i]] for i in range(world_size)], dim=0)
+ labels = torch.cat([label_buffers[i][:sizes_int[i]] for i in range(world_size)], dim=0)
anchor_idx = torch.nonzero(labels, as_tuple=False).squeeze(-1)
if anchor_idx.numel() == 0:
diff --git a/src/twinkle/metric/grpo.py b/src/twinkle/metric/grpo.py
index 06e082eeb..e2797b1ec 100644
--- a/src/twinkle/metric/grpo.py
+++ b/src/twinkle/metric/grpo.py
@@ -3,9 +3,12 @@
from typing import Any, Dict, List, Optional, Union
from twinkle.data_format import InputFeature, ModelOutput
+from twinkle.utils import get_logger
from twinkle.utils.transformers_utils import align_logps_to_mask
from .base import Metric
+logger = get_logger()
+
class GRPOMetric(Metric):
@@ -254,6 +257,11 @@ def accumulate(
if len(seq_lens) == 1:
merged = torch.cat(label_tensors, dim=0)
inputs_list = [{'labels': merged}]
+ else:
+ logger.warning(
+ f'GRPOMetric: logps is a single tensor but inputs_list has '
+ f'{len(inputs_list)} mb with mismatched seq_lens={sorted(seq_lens)}. '
+ f'Only mb[0] will be accumulated; check the model forward path.')
flat_old: Optional[List] = None
if old_logps is not None and isinstance(old_logps, (list, tuple)):
@@ -284,7 +292,17 @@ def accumulate(
# Uncommon: aligned global tensor. Only honour when it
# exactly matches the single-mb shape; otherwise drop.
import torch as _torch # noqa: F811
- old_slice = old_logps if (_torch.is_tensor(old_logps) and old_logps.shape == logps_mb.shape) else None
+ if _torch.is_tensor(old_logps) and old_logps.shape == logps_mb.shape:
+ old_slice = old_logps
+ else:
+ if mb_idx == 0:
+ # Warn once per accumulate call (not per mb) to avoid log spam.
+ old_shape = tuple(old_logps.shape) if _torch.is_tensor(old_logps) else 'unknown'
+ logger.warning(
+ f'GRPOMetric: old_logps shape {old_shape} does not match '
+ f'logps_mb shape {tuple(logps_mb.shape)}; ratio/kl metrics will '
+ f'be skipped for this step.')
+ old_slice = None
else:
old_slice = None
diff --git a/src/twinkle/metric/train_metric.py b/src/twinkle/metric/train_metric.py
index da82a8783..8d785c38b 100644
--- a/src/twinkle/metric/train_metric.py
+++ b/src/twinkle/metric/train_metric.py
@@ -2,7 +2,7 @@
import time
from typing import List, Union
-from ..data_format import InputFeature, ModelOutput
+from twinkle.data_format import InputFeature, ModelOutput
from .base import Metric
diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py
index a4d4ea064..8ea00d696 100644
--- a/src/twinkle/model/base.py
+++ b/src/twinkle/model/base.py
@@ -1,8 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
from abc import ABC, abstractmethod
-from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union
from twinkle import Platform, torch_util
from twinkle.data_format import InputFeature, ModelOutput
diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index a5ea3fc56..60b45f774 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -420,7 +420,7 @@ def forward_step_func(data_iterator, model):
embeddings = output_tensor
elif labels is not None and is_last_pp:
_loss_require_logps = getattr(_loss_instance, 'require_logps', True)
- _loss_require_entropy = (hasattr(_loss_instance, 'require_entropy') and _loss_instance.require_entropy)
+ _loss_require_entropy = getattr(_loss_instance, 'require_entropy', False)  # opt-in: a missing attribute means no entropy
_packed = batch.get('packed_seq_params')
cu_seqlens_q = getattr(_packed, 'cu_seqlens_q', None) if _packed is not None else None
if _loss_require_logps:
@@ -446,7 +446,7 @@ def forward_step_func(data_iterator, model):
_outputs = {'logps': logps}
if entropies is not None:
_outputs['entropies'] = entropies
- if hasattr(_loss_instance, 'require_logits') and _loss_instance.require_logits:
+ if getattr(_loss_instance, 'require_logits', False):
_outputs['logits'] = output_tensor
batch, _outputs = processor.unpack_packed_sequences(batch, _outputs)
logps = _outputs['logps']
@@ -990,7 +990,9 @@ def _get_rng_state() -> 'ShardedObject':
'random_rng_state': random.getstate(),
'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
- 'cuda_rng_state': torch.cuda.get_rng_state(),
+ # Backend-agnostic device RNG (CUDA / NPU / MPS); key kept as
+ # 'cuda_rng_state' for backward compatibility with existing checkpoints.
+ 'cuda_rng_state': Platform.get_device_rng_state(),
'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states(),
}
rng_state_list = [rng_state]
@@ -1112,7 +1114,7 @@ def _save_mcore_optimizer(
with open(tracker_path, 'w') as f:
f.write(str(iteration))
- logging.getLogger(__name__).info(f'Saved mcore optimizer state at iteration {iteration} '
+ logger.info(f'Saved mcore optimizer state at iteration {iteration} '
f'to {checkpoint_dir}')
def _load_mcore_optimizer(
@@ -1139,7 +1141,7 @@ def _load_mcore_optimizer(
)
iteration = self._read_iteration(tracker_path)
if iteration == 0:
- logging.getLogger(__name__).warning(f'No checkpoint found in {checkpoint_dir}')
+ logger.warning(f'No checkpoint found in {checkpoint_dir}')
return
iter_dir = os.path.join(checkpoint_dir, f'iter_{iteration:07d}')
@@ -1201,7 +1203,9 @@ def _load_mcore_optimizer(
random.setstate(rng['random_rng_state'])
np.random.set_state(rng['np_rng_state'])
torch.set_rng_state(rng['torch_rng_state'])
- torch.cuda.set_rng_state(rng['cuda_rng_state'])
+ # Backend-agnostic restore: tolerates a ckpt produced on a different backend
+ # (stored state may be None or absent) and avoids hard-coded torch.cuda, which crashes on NPU.
+ Platform.set_device_rng_state(rng.get('cuda_rng_state'))
tensor_parallel.get_cuda_rng_tracker().set_states(rng['rng_tracker_states'], )
# Restore iteration counter.
@@ -1211,26 +1215,26 @@ def _load_mcore_optimizer(
if dist.is_initialized():
dist.barrier()
- logging.getLogger(__name__).info(f'Resumed from mcore checkpoint at iteration {iteration} '
+ logger.info(f'Resumed from mcore checkpoint at iteration {iteration} '
f'from {checkpoint_dir}')
@staticmethod
def _read_iteration(tracker_path: str) -> int:
- if not os.path.exists(tracker_path):
- return 0
- with open(tracker_path) as f:
- iteration = int(f.read().strip())
+ # All ranks must enter the all_reduce together; missing tracker on some
+ # ranks (e.g. NFS lag, partial mount) must NOT short-circuit, otherwise
+ # the remaining ranks hang at the collective. Treat missing as 0 and
+ # let MAX reduction recover the canonical iteration from any rank that
+ # successfully read the file.
+ iteration = 0
+ if os.path.exists(tracker_path):
+ with open(tracker_path) as f:
+ iteration = int(f.read().strip())
if torch.distributed.is_initialized():
- iters_cuda = torch.tensor(
- [iteration],
- dtype=torch.long,
- device='cuda',
- )
- torch.distributed.all_reduce(
- iters_cuda,
- op=torch.distributed.ReduceOp.MAX,
- )
- iteration = iters_cuda[0].item()
+ # Use Platform.get_local_device() to stay backend-agnostic
+ # (CUDA / NPU / MPS); 'cuda' would crash on NPU.
+ iters_dev = torch.tensor([iteration], dtype=torch.long, device=Platform.get_local_device())
+ torch.distributed.all_reduce(iters_dev, op=torch.distributed.ReduceOp.MAX)
+ iteration = int(iters_dev[0].item())
return iteration
def _merge_lora_adapters(self, adapter_name: str = 'default'):
@@ -1256,7 +1260,7 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non
For distributed training:
- All PP ranks participate in export (each has different layers)
- - Only DP rank 0 actually writes to disk
+ - Only global rank 0 actually writes shared config files
- Uses barrier for synchronization
For LoRA training:
@@ -1264,12 +1268,9 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non
"""
# Check if this is LoRA training
is_peft_format = (adapter_name != _default_adapter_name)
+ is_global_zero = (not dist.is_initialized()) or dist.get_rank() == 0
- # Create output directory on rank 0 only
- from megatron.core import parallel_state as mpu
- dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0
-
- if dp_rank == 0:
+ if is_global_zero:
os.makedirs(output_dir, exist_ok=True)
# Synchronize before saving
@@ -1281,8 +1282,8 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non
self.strategy.bridge.save_weights(
model, output_dir, peft_format=is_peft_format, adapter_name=adapter_name, converter=lora_converter)
- # Save config on rank 0 only
- if dp_rank == 0:
+ # Save config on global rank 0 only (avoid concurrent writers).
+ if is_global_zero:
self.hf_config.save_pretrained(output_dir)
if isinstance(model[0], PeftModel):
config = model[0].peft_config[adapter_name]
@@ -1291,11 +1292,13 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non
model[0].peft_config[adapter_name].save_pretrained(output_dir)
config.target_modules = target_modules
+ if dist.is_initialized():
+ dist.barrier()
+
def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_converter=None):
"""Save in Megatron checkpoint format."""
+ is_global_zero = (not dist.is_initialized()) or dist.get_rank() == 0
os.makedirs(output_dir, exist_ok=True)
- from megatron.core import parallel_state as mpu
- dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0
state_dict = self._get_trainable_parameters(adapter_name)
cpu_state_dict = {}
for k, v in state_dict.items():
@@ -1311,13 +1314,18 @@ def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_convert
rank = dist.get_rank() if dist.is_initialized() else 0
checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt')
torch.save(cpu_state_dict, checkpoint_path)
- # Save config on rank 0 only
+ # Save shared config on global rank 0 only (avoid concurrent writers).
model = self.strategy.unwrap_model(self.model)
- if dp_rank == 0:
+ if is_global_zero:
self.hf_config.save_pretrained(output_dir)
if isinstance(model[0], PeftModel):
model[0].peft_config[adapter_name].save_pretrained(output_dir)
+ # Finalize barrier: ensure all ranks finish writing model_rank*.pt
+ # before the caller proceeds (e.g. uploading / loading the ckpt).
+ if dist.is_initialized():
+ dist.barrier()
+
def _save_tokenizer(self, output_dir: str, **kwargs):
from twinkle.utils import is_last_rank
if not is_last_rank():
diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py
index 2dd6b7a53..78981b888 100644
--- a/src/twinkle/model/megatron/multi_lora_megatron.py
+++ b/src/twinkle/model/megatron/multi_lora_megatron.py
@@ -15,7 +15,7 @@
from transformers import AutoConfig, PretrainedConfig
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
-from twinkle import DeviceMesh, remote_class, remote_function, requires, template, torch_util
+from twinkle import DeviceMesh, Platform, remote_class, remote_function, requires, template, torch_util
from twinkle.data_format import InputFeature, Trajectory
from twinkle.hub import HubOperation
from twinkle.infra import collect_tensor_dict
@@ -26,6 +26,9 @@
from ._mindspeed_runtime import ensure_mindspeed_adaptor_patched
from .megatron import MegatronModel
from .strategy import MegatronStrategy
+from twinkle.utils import get_logger
+
+logger = get_logger()
@remote_class(execute='all')
@@ -221,8 +224,11 @@ def _save_local_training_rng_state():
'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
}
- if torch.cuda.is_available():
- rng_state['cuda_rng_state'] = torch.cuda.get_rng_state()
+ # Backend-agnostic device RNG capture (CUDA / NPU / MPS). Key is kept as
+ # 'cuda_rng_state' for backward compatibility with existing checkpoints.
+ device_rng = Platform.get_device_rng_state()
+ if device_rng is not None:
+ rng_state['cuda_rng_state'] = device_rng
rng_state['rng_tracker_states'] = tensor_parallel.get_cuda_rng_tracker().get_states()
return rng_state
@@ -233,8 +239,10 @@ def _load_local_training_rng_state(rng_state):
random.setstate(rng_state['random_rng_state'])
np.random.set_state(rng_state['np_rng_state'])
torch.set_rng_state(rng_state['torch_rng_state'])
- if 'cuda_rng_state' in rng_state and torch.cuda.is_available():
- torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
+ # Backend-agnostic device RNG restore: tolerates ckpt produced on different
+ # backend (key absent or None) and avoids hard-coded torch.cuda on NPU.
+ if 'cuda_rng_state' in rng_state:
+ Platform.set_device_rng_state(rng_state['cuda_rng_state'])
tensor_parallel.get_cuda_rng_tracker().set_states(rng_state['rng_tracker_states'])
def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kwargs):
@@ -251,19 +259,35 @@ def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kw
torch.save(state_dict, self._rank_local_optimizer_path(checkpoint_dir))
+ if dist.is_initialized():
+ dist.barrier()
+
def _load_multi_lora_optimizer(self, checkpoint_dir: str, adapter_name: str = '', **kwargs):
no_load_optim = kwargs.pop('no_load_optim', False)
- no_load_rng = kwargs.pop('no_load_rng', False)
+ no_load_rng = kwargs.pop('no_load_rng', True)
optimizer_config = self.optimizer_group.get(adapter_name)
state_dict = torch.load(self._rank_local_optimizer_path(checkpoint_dir), map_location='cpu', weights_only=False)
if not no_load_optim and optimizer_config is not None:
if optimizer_config.optimizer is not None and 'optimizer' in state_dict:
optimizer_config.optimizer.load_state_dict(state_dict['optimizer'])
+ device = Platform.get_local_device()
+ for group_state in optimizer_config.optimizer.state.values():
+ if not isinstance(group_state, dict):
+ continue
+ for k, v in group_state.items():
+ if isinstance(v, torch.Tensor):
+ group_state[k] = v.to(device)
if optimizer_config.lr_scheduler is not None and 'opt_param_scheduler' in state_dict:
optimizer_config.lr_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
+ # RNG state is intentionally not restored in multi-tenant mode:
+ # restoring the global RNG would silently affect other active tenants'
+ # dropout / initialization behaviour.
if not no_load_rng and 'rng_state' in state_dict:
- self._load_local_training_rng_state(state_dict['rng_state'])
+ logger.warning(
+ 'Skipping RNG state restoration in multi-tenant mode. '
+ 'Global RNG is shared across tenants; restoring it would '
+ 'affect other active adapters.')
if optimizer_config is not None and 'iteration' in state_dict:
optimizer_config.cur_step = state_dict['iteration']
@@ -354,6 +378,11 @@ def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **k
self._check_adapter_valid(adapter_name)
trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json')
+ if not os.path.isfile(trainer_state_path):
+ raise FileNotFoundError(
+ f'trainer_state.json not found in {checkpoint_dir}. '
+ f'Ensure the checkpoint was saved with save_optimizer=True.')
+
with open(trainer_state_path) as f:
trainer_state = json.load(f)
diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py
index 819014eb8..1bd809025 100644
--- a/src/twinkle/model/megatron/strategy/megatron.py
+++ b/src/twinkle/model/megatron/strategy/megatron.py
@@ -48,7 +48,6 @@ def __init__(
ddp_config: Dict[str, Any] = None,
**kwargs,
):
- import torch.distributed as dist
from megatron.core import mpu
self.device_mesh = device_mesh
self.use_distributed_optimizer = use_distributed_optimizer
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index d0434991f..018f4c494 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -34,10 +34,8 @@ def __init__(
parallelism_config = self._parallelism_config_from_device_mesh(device_mesh)
fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient_init)
- kwargs_handlers = []
- kwargs_handlers.append(
- InitProcessGroupKwargs(
- timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200')))))
+ kwargs_handlers = [InitProcessGroupKwargs(
+ timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))))]
if ddp_config is not None:
from accelerate import DistributedDataParallelKwargs
ddp_config = DistributedDataParallelKwargs(**ddp_config)
@@ -131,8 +129,7 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di
return fsdp_plugin
def wrap_model(self, model, *args):
- result = self.accelerator.prepare(model, *args)
- return result
+ return self.accelerator.prepare(model, *args)
def unwrap_model(self, model):
return self.accelerator.unwrap_model(model, keep_torch_compile=False)
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 61733d7dc..cac375263 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -414,8 +414,8 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
inputs = optimizer_config.template.batch_encode(inputs) # noqa
processor: InputProcessor = optimizer_config.processor
loss_instance = optimizer_config.loss_instance
- loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits)
- loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy)
+ loss_require_logits = getattr(loss_instance, 'require_logits', False)
+ loss_require_entropy = getattr(loss_instance, 'require_entropy', False)
loss_require_logps = getattr(loss_instance, 'require_logps', True)
assert isinstance(processor, InputProcessor), 'Set a correct `InputProcessor` before forwarding'
inputs: Dict[str, Any] = processor(
@@ -490,8 +490,8 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
processor: InputProcessor = optimizer_config.processor
assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding'
loss_instance = optimizer_config.loss_instance
- loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits)
- loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy)
+ loss_require_logits = getattr(loss_instance, 'require_logits', False)
+ loss_require_entropy = getattr(loss_instance, 'require_entropy', False)
loss_require_logps = getattr(loss_instance, 'require_logps', True)
inputs: Dict[str, Any] = processor(
inputs,
@@ -929,7 +929,6 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int
if optimizer_config.cur_step % interval != 0:
return
model = self.strategy.unwrap_model(self.model)
- processed_state_dict = {}
save_kwargs = {}
if adapter_name == _default_adapter_name:
# Full model save
diff --git a/src/twinkle/notifier/__init__.py b/src/twinkle/notifier/__init__.py
index 329cb6f1d..067db71a1 100644
--- a/src/twinkle/notifier/__init__.py
+++ b/src/twinkle/notifier/__init__.py
@@ -1,2 +1,3 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
from .base import Notifier, notify_exception
from .ding_notifier import DingNotifier
diff --git a/src/twinkle/notifier/base.py b/src/twinkle/notifier/base.py
index a83903b53..6f50ca659 100644
--- a/src/twinkle/notifier/base.py
+++ b/src/twinkle/notifier/base.py
@@ -1,3 +1,4 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
import os
from typing import Dict, Optional
@@ -66,7 +67,7 @@ def notify_exception(notifier: Notifier, context: str, exc: BaseException, name:
if not _try_claim_notify_slot(exc, context, name):
try:
setattr(exc, '_twinkle_notified', True)
- except Exception: # noqa: BLE001
+ except Exception: # noqa
pass
return
diff --git a/src/twinkle/notifier/ding_notifier.py b/src/twinkle/notifier/ding_notifier.py
index fe102d8a5..fc535edd7 100644
--- a/src/twinkle/notifier/ding_notifier.py
+++ b/src/twinkle/notifier/ding_notifier.py
@@ -1,3 +1,4 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
import base64
import hashlib
import hmac
diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py
index 97065fba1..39d3257be 100644
--- a/src/twinkle/preprocessor/llm.py
+++ b/src/twinkle/preprocessor/llm.py
@@ -48,9 +48,9 @@ def preprocess(self, row) -> Trajectory:
class SelfCognitionProcessor(Preprocessor):
- def __init__(self, model_name, model_author):
- self.model_name = model_name
- self.model_author = model_author
+ def __init__(self, model_name=None, model_author=None):
+ self.model_name = model_name or 'twinkle robot'
+ self.model_author = model_author or 'twinkle lab'
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
rows = self.map_col_to_row(rows)
diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py
index 8709d98ab..14a0e206b 100644
--- a/src/twinkle/processor/base.py
+++ b/src/twinkle/processor/base.py
@@ -42,7 +42,6 @@ class InputProcessor:
'video_grid_thw': 0,
'input_features': 0.0,
'feature_attention_mask': 0,
- 'mm_token_type_ids': 0,
}
# VLM fields to concatenate (not pad) in batch
@@ -108,8 +107,12 @@ def to_tensor(_input):
# so tensor ops like labels != ignore_index or .to(device) would fail without this.
if isinstance(value, np.ndarray):
value = torch.from_numpy(value)
- elif (isinstance(value, list) and isinstance(value[0],
- (int, float, np.number))) or key == 'position_ids':
+ elif isinstance(value, list) and len(value) > 0 and isinstance(
+ value[0], (int, float, np.number)):
+ value = torch.tensor(value)
+ elif key == 'position_ids' and not isinstance(value, torch.Tensor):
+ if value is None:
+ continue
value = torch.tensor(value)
elif (isinstance(value, list)) and key in ('completion_mask', 'mm_token_type_ids'):
value = torch.tensor(value)
@@ -284,7 +287,9 @@ def pad_cp_inputs(input_tensor: torch.Tensor, padding_value: int) -> torch.Tenso
return input_tensor
if cp_size > 1:
- position_ids_f = position_ids.flatten()
+ pos_for_cu = position_ids[:1] if position_ids.dim() >= 2 and position_ids.shape[0] > 1 \
+ else position_ids
+ position_ids_f = pos_for_cu.flatten()
indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32)
cu_seqlens = torch.cat([
indices_q[position_ids_f == 0],
@@ -354,8 +359,11 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], di
view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] //
(2 * cp_size), *inputs.shape[dim + 1:])
val = val.view(view_shape)
- index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu',
- pin_memory=True).cuda(non_blocking=True)
+ index = torch.tensor(
+ [cp_rank, (2 * cp_size - cp_rank - 1)],
+ device=inputs.device,
+ dtype=torch.long,
+ )
val = val.index_select(dim, index)
view_shape = (*inputs.shape[:dim], -1, *inputs.shape[dim + 1:])
new_inputs.append(val.view(view_shape))
@@ -402,17 +410,18 @@ def prepare_transformers_padding_free_patch(self, inputs: List[InputFeature], **
if not padding_free or bool(kwargs.get('enable_sp', False)):
return inputs
- from twinkle.patch import apply_patch
- from twinkle.patch.gdn_padding_free import GatedDeltaNetPaddingFreePatch
-
- apply_patch(
- model,
- GatedDeltaNetPaddingFreePatch,
- hf_config=kwargs.get('hf_config'),
- enable_sp=False,
- )
if not getattr(model, '_twinkle_gdn_padding_free_patched', False):
- return inputs
+ from twinkle.patch import apply_patch
+ from twinkle.patch.gdn_padding_free import GatedDeltaNetPaddingFreePatch
+
+ apply_patch(
+ model,
+ GatedDeltaNetPaddingFreePatch,
+ hf_config=kwargs.get('hf_config'),
+ enable_sp=False,
+ )
+ if not getattr(model, '_twinkle_gdn_padding_free_patched', False):
+ return inputs
for _inp in inputs:
position_ids = _inp.get('position_ids')
@@ -631,15 +640,27 @@ def to_transformers_dict(inputs: List[InputFeature], **kwargs) -> List[InputFeat
output = {}
_keys = [
'input_ids',
- 'input_embeddings',
+ 'inputs_embeds',
'attention_mask',
'position_ids',
'labels',
'completion_mask',
+ 'cu_seq_lens_q',
+ 'cu_seq_lens_k',
+ 'cu_seqlens_q',
+ 'cu_seqlens_kv',
+ 'max_length_q',
+ 'max_length_k',
+ 'packed_seq_params',
] + list(InputProcessor.VLM_CONCAT_FIELDS)
for key in list(_input.keys()):
- if key in _keys:
- output[key] = np.array(_input[key]) if not isinstance(_input[key], torch.Tensor) else _input[key]
+ if key not in _keys:
+ continue
+ value = _input[key]
+ if isinstance(value, torch.Tensor) or not isinstance(value, (list, np.ndarray)):
+ output[key] = value
+ else:
+ output[key] = np.array(value)
results.append(InputFeature(**output))
return results
@@ -694,7 +715,8 @@ def is_mm_position_ids(position_ids):
result[key] = self._create_4d_attention_mask(values)
elif key == 'position_ids' and is_mm_position_ids(values[0]):
result[key] = InputProcessor._pad_sequence(values, self.padding_map[key], self.padding_side)
- result[key] = result[key].reshape(values[0].shape[0], len(values), -1)
+ num_axes = values[0].shape[0]
+ result[key] = result[key].reshape(len(values), num_axes, -1).permute(1, 0, 2).contiguous()
elif isinstance(values[0], torch.Tensor):
result[key] = InputProcessor._pad_sequence(values, self.padding_map[key], self.padding_side)
if result[key].dim() == 1:
@@ -776,6 +798,5 @@ def postprocess_tensor_cp(self, tensor, cu_seqlens=None):
if self.device_mesh.cp_world_size <= 1:
return tensor
from megatron.core import parallel_state as mpu
-
from twinkle.utils.torch_utils import gather_cp_load_balanced
return gather_cp_load_balanced(tensor, mpu.get_context_parallel_group(), seq_dim=1, cu_seqlens=cu_seqlens)
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
index 32dc1ca50..0433ef5a8 100644
--- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
+++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
@@ -1,24 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-"""vLLM-based sampler using VLLMEngine (AsyncLLM).
-
-Device Configuration:
- vLLMSampler automatically detects the number of available GPUs from
- CUDA_VISIBLE_DEVICES environment variable (set by twinkle's ResourceManager)
- and configures vLLM's tensor_parallel_size accordingly.
-
- To use tensor parallelism, configure DeviceGroup with gpus_per_worker > 1:
-
- # DP2 with TP2 (4 GPUs total, 2 workers, each with 2 GPUs)
- DeviceGroup(name='sampler', ranks=[0,1,2,3], gpus_per_worker=2)
-
- # TP4 (4 GPUs, 1 worker with all 4 GPUs)
- DeviceGroup(name='sampler', ranks=[0,1,2,3], gpus_per_worker=4)
-
-Data Flow:
- When multiple vLLMSampler workers exist (DP > 1):
- - Data is dispatched via dispatch='slice_dp' (each worker gets a slice)
- - Results are collected via collect='flatten' (merged into single list)
-"""
import asyncio
import atexit
import numpy as np
diff --git a/src/twinkle/server/state/backend/factory.py b/src/twinkle/server/state/backend/factory.py
index 326a6916f..be24c7824 100644
--- a/src/twinkle/server/state/backend/factory.py
+++ b/src/twinkle/server/state/backend/factory.py
@@ -1,12 +1,11 @@
"""Backend factory for creating StateBackend instances based on configuration."""
from __future__ import annotations
-import logging
-
from twinkle.server.config.persistence import PersistenceConfig
+from twinkle.utils import get_logger
from .base import StateBackend
-logger = logging.getLogger(__name__)
+logger = get_logger()
def create_backend(config: PersistenceConfig | None = None) -> StateBackend:
diff --git a/src/twinkle/server/state/base.py b/src/twinkle/server/state/base.py
index d931ce336..8cb055ae4 100644
--- a/src/twinkle/server/state/base.py
+++ b/src/twinkle/server/state/base.py
@@ -1,7 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from __future__ import annotations
-import logging
import time
from abc import ABC, abstractmethod
from datetime import datetime, timezone
@@ -9,9 +8,10 @@
from typing import Generic, TypeVar
from twinkle.server.state.backend.base import StateBackend
+from twinkle.utils import get_logger
T = TypeVar('T', bound=BaseModel)
-logger = logging.getLogger(__name__)
+logger = get_logger()
class BaseManager(ABC, Generic[T]):
diff --git a/src/twinkle/server/state/session_manager.py b/src/twinkle/server/state/session_manager.py
index 442019ea4..1bc901b0c 100644
--- a/src/twinkle/server/state/session_manager.py
+++ b/src/twinkle/server/state/session_manager.py
@@ -2,14 +2,14 @@
from __future__ import annotations
import functools
-import logging
import time
+from twinkle.utils import get_logger
from .backend.base import ConcurrencyError, StateBackend
from .base import BaseManager
from .models import SessionRecord
-logger = logging.getLogger(__name__)
+logger = get_logger()
def _session_touch_transform(existing: dict | None, *, now: float) -> dict | None:
diff --git a/src/twinkle/server/telemetry/provider.py b/src/twinkle/server/telemetry/provider.py
index 77212c757..059f301fb 100644
--- a/src/twinkle/server/telemetry/provider.py
+++ b/src/twinkle/server/telemetry/provider.py
@@ -16,8 +16,9 @@
from typing import Any
from twinkle.server.config.telemetry import TelemetryConfig
+from twinkle.utils import get_logger
-logger = logging.getLogger(__name__)
+logger = get_logger()
# Loggers belonging to the OTLP transport stack. Their own records must never
# be routed back through the OTLP LoggingHandler: an exporter error logged
diff --git a/src/twinkle/server/telemetry/worker_init.py b/src/twinkle/server/telemetry/worker_init.py
index 997f2e140..40edc6628 100644
--- a/src/twinkle/server/telemetry/worker_init.py
+++ b/src/twinkle/server/telemetry/worker_init.py
@@ -7,10 +7,11 @@
"""
from __future__ import annotations
-import logging
import os
-logger = logging.getLogger(__name__)
+from twinkle.utils import get_logger
+
+logger = get_logger()
_worker_initialized = False
diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py
index 6c4bdddd2..168456eab 100644
--- a/src/twinkle/template/__init__.py
+++ b/src/twinkle/template/__init__.py
@@ -2,3 +2,4 @@
from .base import Template
from .deepseek_v4 import DeepseekV4Template
from .qwen3_5_vl import Qwen3_5Template
+from .tools import ToolCallParser, ToolCallRegistry, ClineParser, HermesQwenParser, ReActParser, VCPParser
diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py
index 26c2e4f26..c1e8f069f 100644
--- a/src/twinkle/template/base.py
+++ b/src/twinkle/template/base.py
@@ -32,6 +32,19 @@ class Template:
video_placeholder: str = '