Skip to content

Activation Beacon,关于Loss定义的疑问:平均 token loss 是否会稀释真正依赖 beacon memory 的位置? #1565

@Strandding

Description

@Strandding

您好,我在复现 activation_beacon 的预训练时,对当前训练 loss 的定义和训练现象有一个疑问,想请教一下这是否符合作者的预期。

我目前使用的是类似 examples 中的预训练配置,预训练数据为普通 LM 数据(如 RedPajama),采用的模型为Qwen2-7B-Instruct

· 观察到一个现象:
在预训练过程中,loss 从一开始大约 2.13,训练到后面也基本还是在 2.04 左右波动,没有看到特别明显的下降。

· 我对代码的理解:
训练时最终 loss 是对所有有效 token 的 cross-entropy 做平均,beacon token 本身不参与 loss。
在 interleave + full-coverage 下,同一个窗口内部的大量 raw token 仍然可见,因此很多 token 的预测主要还是依赖局部 raw context,比较简单。
真正更能体现 beacon 是否学到作用的位置,应该是那些需要跨窗口利用前一窗信息的位置;比如在 window = stride 时,前一窗 raw KV 不再保留,主要只能依赖 beacon memory,此时更能检验beacon的作用。

· 思考:
真正“依赖 beacon memory”的 token 在整条序列中只占一部分,但训练 loss 是对所有有效 token 统一平均,大量主要靠局部 raw context 就能预测好的 token,会主导整体 loss,这样整体 loss 可能更多反映的是“普通 LM token 的平均预测难度”,而不是 beacon memory 是否真的学好了。这是否也是我在预训练时看到 loss 从 2.13 到 2.04 左右、下降不明显的原因之一?

· 请教几个问题:
1、我的这个理解是否正确?如果确实是预期设计,想请教一下这样设计的原因
2、预训练阶段 loss 从大约 2.13 训练到 2.0 左右、整体下降不明显,这种现象是否正常?
3、其实2.04左右的 Loss 也接近原始模型(Qwen2-7B-Instruct)在训练集上的效果,这是否侧面印证了训练Loss被大量依靠局部 raw context 预测的 token稀释?
4、作者是否尝试过更聚焦 beacon 作用的位置指标或目标,例如:
单独统计跨窗口位置的 loss
对更依赖 beacon memory 的 token 加权
单独汇报一项 beacon-related loss 作为监控指标

期待您的回复!

附上预训练脚本
output_name=beacon-qwen2-pretrain torchrun --nproc_per_node 8 $DDP -m main.train \ --output_dir data/outputs/$output_name \ --model_name_or_path /dataset/common/tzh/model/Qwen2-7B-Instruct \ --train_data /dataset/common/tzh/dataset/long-llm/redpajama/train.json \ --min_length 2400 \ --max_length 20000 \ --group_by_stride strict \ --enable_beacon \ --beacon_window 2048 \ --beacon_stride 2048 \ --beacon_attn full-coverage \ --beacon_attend_prev True \ --beacon_sink_size 0 \ --beacon_ratio 2 4 8 16 32 \ --beacon_ratio_mix step-random \ --beacon_param q k v \ --beacon_pos interleave \ --attn_impl flash_attention_2 \ --gradient_checkpointing \ --use_reentrant False \ --save_only_model \ --save_strategy epoch \ --evaluation_strategy steps \ --num_train_epochs 1 \ --logging_steps 50 \ --bf16 \ --deepspeed data/deepspeed/stage2.json

附上部分预训练输出:
{'loss': 2.1343, 'grad_norm': 0.2808859348297119, 'learning_rate': 4.979447262550289e-05, 'epoch': 0.0}
{'loss': 2.1084, 'grad_norm': 0.226092129945755, 'learning_rate': 4.957582648242085e-05, 'epoch': 0.01}
{'loss': 2.1167, 'grad_norm': 0.1432613581418991, 'learning_rate': 4.9357180339338815e-05, 'epoch': 0.01}
{'loss': 1.9753, 'grad_norm': 0.16457320749759674, 'learning_rate': 4.913853419625678e-05, 'epoch': 0.02}
{'loss': 2.0601, 'grad_norm': 0.12638713419437408, 'learning_rate': 4.891988805317474e-05, 'epoch': 0.02}
{'loss': 2.0375, 'grad_norm': 0.14955464005470276, 'learning_rate': 4.8701241910092706e-05, 'epoch': 0.03}
{'loss': 2.0465, 'grad_norm': 0.14118915796279907, 'learning_rate': 4.848259576701068e-05, 'epoch': 0.03}
{'loss': 2.0381, 'grad_norm': 0.1377200037240982, 'learning_rate': 4.8263949623928634e-05, 'epoch': 0.03}
{'loss': 2.024, 'grad_norm': 0.10688850283622742, 'learning_rate': 4.80453034808466e-05, 'epoch': 0.04}
{'loss': 2.0387, 'grad_norm': 0.11750514060258865, 'learning_rate': 4.782665733776456e-05, 'epoch': 0.04}
{'loss': 2.0662, 'grad_norm': 0.13407647609710693, 'learning_rate': 4.7608011194682526e-05, 'epoch': 0.05}
{'loss': 2.0861, 'grad_norm': 0.14363867044448853, 'learning_rate': 4.738936505160049e-05, 'epoch': 0.05}
{'loss': 2.0667, 'grad_norm': 0.1525307148694992, 'learning_rate': 4.717071890851846e-05, 'epoch': 0.06}
{'loss': 2.0437, 'grad_norm': 0.1380089372396469, 'learning_rate': 4.6952072765436424e-05, 'epoch': 0.06}
{'loss': 2.0895, 'grad_norm': 0.1160527840256691, 'learning_rate': 4.673342662235438e-05, 'epoch': 0.07}
{'loss': 1.9908, 'grad_norm': 0.12899190187454224, 'learning_rate': 4.6514780479272345e-05, 'epoch': 0.07}
{'loss': 2.0517, 'grad_norm': 0.12548774480819702, 'learning_rate': 4.629613433619031e-05, 'epoch': 0.07}

......

{'loss': 2.0505, 'grad_norm': 0.11952044069766998, 'learning_rate': 2.129613433619031e-06, 'epoch': 0.96}
{'loss': 1.9744, 'grad_norm': 0.11724891513586044, 'learning_rate': 1.9109672905369953e-06, 'epoch': 0.96}
{'loss': 2.052, 'grad_norm': 0.11619514226913452, 'learning_rate': 1.6923211474549588e-06, 'epoch': 0.97}
{'loss': 2.0327, 'grad_norm': 0.13094069063663483, 'learning_rate': 1.473675004372923e-06, 'epoch': 0.97}
{'loss': 2.056, 'grad_norm': 0.13238118588924408, 'learning_rate': 1.2550288612908868e-06, 'epoch': 0.97}
{'loss': 2.0527, 'grad_norm': 0.14477798342704773, 'learning_rate': 1.036382718208851e-06, 'epoch': 0.98}
{'loss': 1.9486, 'grad_norm': 0.17454515397548676, 'learning_rate': 8.177365751268147e-07, 'epoch': 0.98}
{'loss': 2.0118, 'grad_norm': 0.1428622007369995, 'learning_rate': 5.990904320447787e-07, 'epoch': 0.99}
{'loss': 2.0553, 'grad_norm': 0.15292023122310638, 'learning_rate': 3.8044428896274276e-07, 'epoch': 0.99}
{'loss': 2.0557, 'grad_norm': 0.14297017455101013, 'learning_rate': 1.6179814588070666e-07, 'epoch': 1.0}
{'train_runtime': 27943.4629, 'train_samples_per_second': 3.275, 'train_steps_per_second': 0.409, 'train_loss': 2.041090958266043, 'epoch': 1.0}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions