
Tiny-R2: A Hybrid Architecture Combining SWA, CSA, HCA, mHC and DSMoE Under the DeepSeek V4 Design Paradigm

(Figures: model structure, benchmark, loss curve)

Tiny-R2: Model Architecture and Training Pipeline Documentation


๐Ÿ“‹ Table of Contents

  1. Project Overview
  2. Model Architecture Overview
  3. Core Components in Detail
  4. Training Pipeline
  5. Key Technical Features
  6. Appendix: Figure Index

Project Overview

Tiny-R2 is a project aimed at rapidly reproducing DeepSeek V4/R2. The following architecture components are implemented so far:

  • Sparse attention mechanism (HCA-CSA hybrid attention)
  • Mixture of experts (DeepSeek MoE)
  • Hyper-Connections
  • Dual-optimizer strategy (Muon + AdamW)
  • OPD post-training support (on-policy distillation)

Model Architecture Overview

Model structure

Quick Start

2.1 Install dependencies

pip install tiktoken datasets transformers huggingface_hub
pip install git+https://github.com/KellerJordan/Muon
pip install --upgrade transformers
hf auth login --force

2.2 ๅฏๅŠจ่ฎญ็ปƒ

2.2.1ๆ”ฏๆŒ้‡‡็”จAgent่ฟ›่กŒ่‡ชไธป่ง‚ๅฏŸๅ’Œ่ฎญ็ปƒ่ฐƒๆ•ดlrใ€clip่ถ…ๅ‚๏ผŒไปฅๅฎž็Žฐๆ›ดๅŠ ็จณๅฎš็š„ๆ™บ่ƒฝๅŒ–่ฎญ็ปƒ;้ป˜่ฎคไธๅผ€ๅฏ๏ผ›ๅผ€ๅฏๅŽ้œ€่ฆไฝฟ็”จgemini็š„api key

python train.py --n_layer 6 --n_embd 1536 --hc 'True' --mhc 'True' --n_experts 8 --max_iters 10000 --attention_types 'Sparse' --batch_size 8 --ctx_len 2048 --hf_dataset 'karpathy/climbmix-400b-shuffle' --resume True --save_best_only True

2.2.2 Set --use_agent_observe to enable Agent-assisted training; you must supply your Gemini API key:

python train.py --n_layer 6 --n_embd 1536 --hc 'True' --mhc 'True' --n_experts 8 --max_iters 10000 --attention_types 'Sparse' --batch_size 8 --ctx_len 2048 --hf_dataset 'karpathy/climbmix-400b-shuffle' --resume True --save_best_only True --use_agent_observe True --gemini_api_key "your gemini apikey"

2.3 Evaluate the trained model's perplexity (PPL)

python evaluate.py --checkpoint checkpoints/best_model_step_xxx.pt 

2.4 ๅฏๅŠจOPDๅœจ็บฟ่’ธ้ฆ

python opd_train.py --batch_size 4 --ctx_len 2048 --hf_teacher_model Qwen/Qwen3.5-9B --student_ckpt "./opd_checkpoints/best_model_step_0.pt" --tokenizer_name Qwen/Qwen3.5-9B --dataset mmlu_pro
     

Core Components in Detail

3.1 Attention mechanisms

Tiny-R2 supports both full attention and sparse attention, switchable via the attention_types configuration:

3.1.1 CausalSelfAttention (full attention)

Standard causal self-attention:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Projections: Q, K, V from a single linear layer
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

        # Value residual connections
        self.v_residual = config.v_residual
        self.lamb1 = nn.Parameter(torch.tensor(0.5))
        self.lamb2 = nn.Parameter(torch.tensor(0.5))

        # Flash Attention support (PyTorch >= 2.0)
        self.flash = hasattr(F, "scaled_dot_product_attention")

Key features:

  • Flash Attention acceleration (when available)
  • Value residual connections
  • Standard causal masking
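The causal masking above can be illustrated with a minimal, torch-free sketch (dimensions and values are illustrative only, not the model's):

```python
import math

def causal_attention(q, k, v):
    """Toy single-head causal attention over lists of vectors.

    q, k, v: lists of equal-length float vectors, one per position.
    Position i may only attend to positions j <= i (the causal mask).
    """
    d = len(q[0])
    out = []
    for i in range(len(q)):
        # Scaled dot-product scores against allowed (past) positions only
        scores = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d)
                  for j in range(i + 1)]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        weights = [e / z for e in exps]  # softmax over j <= i
        out.append([sum(w * v[j][t] for j, w in enumerate(weights))
                    for t in range(len(v[0]))])
    return out
```

Because position 0 can only attend to itself, its output is exactly v[0]; in the real module this loop is replaced by F.scaled_dot_product_attention with is_causal=True when Flash Attention is available.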

3.1.2 HCA-CSA hybrid attention

A hybrid attention mechanism combining HCA and CSA.

Three running modes:

Mode   Branch config   Description
HCA    [1, 0, 0]       Hyper-compression branch
SWA    [0, 0, 1]       Sliding-window branch
CSA    [0, 1, 0]       Compression + selection branch

3.2 HCA-CSA hybrid attention

HCA-CSA is one of Tiny-R2's core innovations: three parallel branches implement efficient sparse attention.

Architecture flow

                       Input x
                          โ†“
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚ Query Preparation (HCA style)                                โ”‚
โ”‚   compress_q โ†’ q_norm โ†’ decompress_q โ†’ RoPE โ†’ Query         โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
                          โ†“
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚   Branch 1       โ”‚   Branch 2       โ”‚   Branch 3           โ”‚
โ”‚   Compression    โ”‚   Selection      โ”‚   Sliding Window     โ”‚
โ”‚   (HCA)          โ”‚   (CSA)          โ”‚   (SWA)              โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚ compress_kv      โ”‚ importance_score โ”‚ window_k/v           โ”‚
โ”‚ kv_norm          โ”‚ topk selection   โ”‚ sliding_window       โ”‚
โ”‚ decompress_k/v   โ”‚ selection_k/v    โ”‚ RoPE                 โ”‚
โ”‚ k_rope           โ”‚ RoPE             โ”‚                      โ”‚
โ”‚ K/V Recombine    โ”‚ K/V Selected     โ”‚ K/V Window           โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
    โ†“                    โ†“                    โ†“
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚ Attention Computation                                        โ”‚
โ”‚   Attention 1: (Q @ K1.T) @ V1                               โ”‚
โ”‚   Attention 2: (Q @ K2.T) @ V2                               โ”‚
โ”‚   Attention 3: (Q @ K3.T) @ V3                               โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
    โ†“
branch_gate (Linear + Softmax) โ†’ Weighted Sum
    โ†“
proj (Linear) โ†’ res_dropout โ†’ Output
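The final gating step in the flow above (branch_gate → Weighted Sum) can be sketched in plain Python; the three branch outputs and gate logits here are placeholders, not the model's actual tensors:

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]

def gate_branches(branch_outputs, gate_logits):
    """Combine per-branch attention outputs with softmax gate weights.

    branch_outputs: three equal-length vectors (compression, selection,
    sliding-window branch outputs for one token).
    gate_logits: three raw scores, e.g. from a linear layer on the input.
    """
    w = softmax(gate_logits)
    dim = len(branch_outputs[0])
    return [sum(w[b] * branch_outputs[b][t] for b in range(3))
            for t in range(dim)]
```

With equal logits each branch contributes one third; in the model the gate is learned, so the mix shifts per token.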

ๅ…ณ้”ฎๅ‚ๆ•ฐ

# HCA ๅ‚ๆ•ฐ
self.v_head_dim = 32
self.kv_lora_rank = 32
self.q_lora_rank = 3 * self.kv_lora_rank
self.rope_head_dim = 64
self.nope_head_dim = 32

# CSA ๅ‚ๆ•ฐ
self.block_size = config.block_size      # TokenๅŽ‹็ผฉๅ—ๅคงๅฐ
self.window_size = config.window_size    # ๆป‘ๅŠจ็ช—ๅฃๅคงๅฐ
self.num_tokens_to_keep = config.num_tokens_to_keep  # ้€‰ๆ‹ฉไฟ็•™็š„tokenๆ•ฐ
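The selection branch's top-k step (importance_score → topk selection in the flow above) can be sketched as follows; the scoring itself is assumed to come from the model, so only the selection logic is shown:

```python
def select_top_blocks(importance_scores, num_to_keep):
    """Pick indices of the highest-scoring KV blocks (CSA-style selection).

    importance_scores: one score per compressed KV block.
    Returns indices in ascending order so the kept blocks stay in
    temporal order for the subsequent attention computation.
    """
    ranked = sorted(range(len(importance_scores)),
                    key=lambda i: importance_scores[i], reverse=True)
    return sorted(ranked[:num_to_keep])
```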

3.3 ๅ‰้ฆˆ็ฝ‘็ปœไธŽ MoE

3.3.1 MLP

ๆ ‡ๅ‡†็š„ๅ‰้ฆˆ็ฝ‘็ปœ๏ผŒไฝฟ็”จ ReLUยฒ ๆฟ€ๆดปๅ‡ฝๆ•ฐ๏ผš

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()  # ReLU squared
        x = self.c_proj(x)
        return x

3.3.2 DSMoE (DeepSeek Mixture of Experts)

A DeepSeek-style mixture-of-experts layer:

Input x [B, T, C]
    โ†“
Gate Network (Linear + UnitCenteredNoise)
    โ†“
Softmax โ†’ Top-k Selection
    โ†“
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                    Expert Networks                          โ”‚
โ”‚  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚
โ”‚  โ”‚ Shared Exp 0 โ”‚  โ”‚ Expert 1 โ”‚  โ”‚ Expert 2 โ”‚  โ”‚  ...   โ”‚ โ”‚
โ”‚  โ”‚ (Always On)  โ”‚  โ”‚ (Top-k)  โ”‚  โ”‚ (Top-k)  โ”‚  โ”‚ (Top-k)โ”‚ โ”‚
โ”‚  โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜  โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜  โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜  โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
    โ†“
Weighted Sum of Expert Outputs
    โ†“
Output [B, T, C]

Key features:

Feature             Description
Shared Expert       Always-active shared expert that provides stability
Routed Experts      Top-k routed experts
Load Balance Loss   Load-balancing loss that prevents expert collapse
Expert Bias         Learnable expert bias used to refine routing
UnitCenteredNoise   Noise added during training to encourage exploration
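The document does not spell out the noise distribution, but a common unit-centered choice multiplies each gate score by uniform noise centered at 1, leaving the expected score unchanged. A sketch under that assumption:

```python
import random

def unit_centered_noise(values, width=0.2, rng=None):
    """Multiply each gate score by noise drawn uniformly from
    [1 - width/2, 1 + width/2]. Centering at 1 keeps the expected
    score unchanged while giving routing some exploration.
    (width and the uniform distribution are illustrative assumptions.)
    """
    rng = rng or random.Random()
    return [v * (1.0 + width * (rng.random() - 0.5)) for v in values]
```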

Load balance loss computation:

def moe_load_balance_loss(router_weights, num_experts):
    # Fraction of total routing mass received by each expert
    load = router_weights.sum(dim=0)
    load = load / load.sum()
    # Penalize squared deviation from the uniform (ideal) load
    ideal = torch.full_like(load, 1.0 / num_experts)
    loss = num_experts * torch.sum((load - ideal) ** 2)
    return loss
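A torch-free restatement of the same loss makes its behavior easy to check: it is zero for a perfectly balanced router and grows with imbalance.

```python
def load_balance_loss(router_weights, num_experts):
    """Pure-Python version of moe_load_balance_loss above.

    router_weights: rows of per-token routing weights,
    one column per expert.
    """
    load = [sum(row[e] for row in router_weights) for e in range(num_experts)]
    total = sum(load)
    load = [l / total for l in load]
    ideal = 1.0 / num_experts
    return num_experts * sum((l - ideal) ** 2 for l in load)
```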

3.4 Hyper-Connections

Hyper-Connections are another major innovation in Tiny-R2, strengthening information flow through a multi-stream routing mechanism.

Core concept:

# ๅˆๅง‹ๅŒ– Hyper-Connections
self.init_hc, self.expand_stream, self.reduce_stream = \
    get_init_and_expand_reduce_stream_functions(
        config.hc_num_streams,
        num_fracs=config.hc_num_fracs,
        disable=config.hc_disable,
    )

# ๅœจๆฏไธช Block ไธญไฝฟ็”จ
self.hc_attn = init_hc(
    dim=config.n_embd,
    branch=self.attn_branch,
    layer_index=index * 2,
    mhc=config.mhc,
    sinkhorn_iters=config.sinkhorn_iters,
    sinkhorn_tau=config.sinkhorn_tau,
)

ๅ…ณ้”ฎๅ‚ๆ•ฐ๏ผš

ๅ‚ๆ•ฐ ่ฏดๆ˜Ž
hc_num_streams ่ถ…่ฟžๆŽฅๆตๆ•ฐ้‡
hc_num_fracs ๅˆ†ๆฎตๆ•ฐ้‡
mhc ๅคš่ถ…่ฟžๆŽฅ้…็ฝฎ
sinkhorn_iters Sinkhorn ็ฎ—ๆณ•่ฟญไปฃๆฌกๆ•ฐ
sinkhorn_tau Sinkhorn ๆธฉๅบฆๅ‚ๆ•ฐ
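The sinkhorn_iters and sinkhorn_tau parameters suggest the stream-routing matrix is normalized with the Sinkhorn algorithm; a minimal sketch under that assumption (the model's actual formulation is not shown in this document):

```python
import math

def sinkhorn(scores, iters=3, tau=1.0):
    """Sinkhorn normalization of a square score matrix: exponentiate with
    temperature tau, then alternately normalize rows and columns so the
    result approaches a doubly-stochastic routing matrix.
    """
    m = [[math.exp(s / tau) for s in row] for row in scores]
    n = len(m)
    for _ in range(iters):
        # Row normalization
        m = [[v / sum(row) for v in row] for row in m]
        # Column normalization
        col = [sum(m[r][c] for r in range(n)) for c in range(n)]
        m = [[m[r][c] / col[c] for c in range(n)] for r in range(n)]
    return m
```

Lower tau sharpens the routing toward a permutation; more iterations tighten the doubly-stochastic constraint.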

่ฎญ็ปƒๆต็จ‹

4.1 ๅˆๅง‹ๅŒ–้˜ถๆฎต

Parse Arguments โ†’ Update Config โ†’ Init WandB โ†’ Setup Distributed โ†’ Setup AMP

4.2 ๆ•ฐๆฎๅ‡†ๅค‡

Load HF Dataset (flytech/python-codes-25k)
    โ†“
Init GPT2 Tokenizer
    โ†“
Create TokenBuffer

TokenBuffer responsibilities:

  • Stream the HuggingFace dataset
  • Dynamically refill the token buffer
  • Emit contiguous token batches

4.3 ๆจกๅž‹ๅˆๅง‹ๅŒ–

Create Transformer
    โ†“
Configure Optimizers (Muon + AdamW)
    โ†“
Create LR Scheduler (Warmup + Cosine)
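The warmup + cosine schedule can be written as a single function of the step; lr_at is an illustrative name, and the boundary behavior is a common convention rather than the project's exact code:

```python
import math

def lr_at(step, max_lr, warmup_iters, max_iters, min_lr=0.0):
    """Warmup + cosine schedule: linear warmup to max_lr over
    warmup_iters steps, then cosine decay to min_lr at max_iters.
    """
    if step < warmup_iters:
        return max_lr * (step + 1) / warmup_iters
    progress = (step - warmup_iters) / (max_iters - warmup_iters)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
```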

4.4 Training loop

For iter in range(max_iters):
    โ”‚
    โ”œโ”€โ”€ For step in grad_accum_steps:
    โ”‚       โ”œโ”€โ”€ Get Batch (TokenBuffer)
    โ”‚       โ”œโ”€โ”€ Forward Pass (model)
    โ”‚       โ”œโ”€โ”€ Backward Pass (scaler.scale)
    โ”‚       โ””โ”€โ”€ Collect Router Weights
    โ”‚
    โ”œโ”€โ”€ Gradient Clipping (clip_grad_norm_)
    โ”œโ”€โ”€ Optimizer Steps (Muon + AdamW)
    โ”œโ”€โ”€ Update Scaler (scaler.update)
    โ”œโ”€โ”€ LR Scheduler Step
    โ”œโ”€โ”€ Update Expert Biases (load balancing)
    โ””โ”€โ”€ Log Metrics (WandB)

4.5 ่ฏ„ไผฐไธŽไฟๅญ˜

If iter % eval_interval == 0:
    โ”œโ”€โ”€ Estimate Loss (eval mode)
    โ”œโ”€โ”€ Save Checkpoint (if val_loss < 5.27)
    โ””โ”€โ”€ Log to WandB

4.6 ไผ˜ๅŒ–ๅ™จ้…็ฝฎ

Tiny-R2 ไฝฟ็”จๅŒไผ˜ๅŒ–ๅ™จ็ญ–็•ฅ๏ผš

def configure_optimizers(self, weight_decay, learning_rate, device):
    muon_params = []    # โ‰ฅ2D parameters in blocks
    adamw_params = []   # Other parameters
    
    for name, param in self.named_parameters():
        if 'blocks' in name and param.ndim >= 2:
            muon_params.append(param)
        else:
            adamw_params.append(param)
    
    return [
        Muon(muon_params, lr=0.02, momentum=0.95),
        torch.optim.AdamW(adamw_params, lr=learning_rate, 
                          betas=(0.90, 0.95), weight_decay=weight_decay)
    ]

Key Technical Features

5.1 Attention mechanism comparison

Feature         CausalSelfAttention   HCA-CSA Hybrid
Complexity      O(n²)                 O(n) ~ O(n log n)
Memory usage    High                  Low
Best suited     Short sequences       Long sequences
Branches        1                     3 (configurable)

5.2 FFN type comparison

Feature              MLP        DSMoE
Parameter count      Fixed      Shared + routed
Compute              Fixed      Sparse activation
Expressiveness       Standard   Stronger
Training stability   High       Needs load balancing

5.3 ๆ ธๅฟƒ้…็ฝฎๅ‚ๆ•ฐ

# ๆจกๅž‹ๆžถๆž„
n_embd = 512        # ๅตŒๅ…ฅ็ปดๅบฆ
n_head = 8          # ๆณจๆ„ๅŠ›ๅคดๆ•ฐ
n_layer = 8         # ๅฑ‚ๆ•ฐ
n_experts = 8       # ไธ“ๅฎถๆ•ฐ้‡
num_exp = 2         # ๆฏtokenๆฟ€ๆดป็š„ไธ“ๅฎถๆ•ฐ

# ๆณจๆ„ๅŠ›้…็ฝฎ
attention_types = ["FULL", "Spares", ...]  # ๆฏๅฑ‚ๆณจๆ„ๅŠ›็ฑปๅž‹
attention_mode = ["FULL", "SWA", "CSA"]    # ็จ€็–ๆณจๆ„ๅŠ›ๆจกๅผ

# Hyper-Connections
hc = True           # ๅฏ็”จ่ถ…่ฟžๆŽฅ
hc_num_streams = 4  # ๆตๆ•ฐ้‡

# ่ฎญ็ปƒ
batch_size = 32
ctx_len = 512       # ไธŠไธ‹ๆ–‡้•ฟๅบฆ
lr = 1e-3
warmup_iters = 1000
max_iters = 100000

Appendix: Figure Index

The figures accompanying this document are stored under /mnt/okcomputer/output/:

File                    Description
model_architecture.png  Overall model architecture diagram
loss.png                Loss convergence over 2B training tokens
benchmark.png           wikitext-103 benchmark results

ๅ‚่€ƒ่ต„ๆ–™


About

Tiny-R2: A hybrid architecture integrating SWA, CSA, HCA, mHC, and DSMoE under the DeepSeek V4 design paradigm, enabling single-GPU OPD post-training.
