-
Notifications
You must be signed in to change notification settings - Fork 845
Description
您好,我在复现 activation_beacon 的预训练时,对当前训练 loss 的定义和训练现象有一个疑问,想请教一下这是否符合作者的预期。
我目前使用的是类似 examples 中的预训练配置,预训练数据为普通 LM 数据(如 RedPajama),采用的模型为Qwen2-7B-Instruct
· 观察到一个现象:
在预训练过程中,loss 从一开始大约 2.13,训练到后面也基本还是在 2.04 左右波动,没有看到特别明显的下降。
· 我对代码的理解:
训练时最终 loss 是对所有有效 token 的 cross-entropy 做平均,beacon token 本身不参与 loss。
在 interleave + full-coverage 下,同一个窗口内部的大量 raw token 仍然可见,因此很多 token 的预测主要还是依赖局部 raw context,比较简单。
真正更能体现 beacon 是否学到作用的位置,应该是那些需要跨窗口利用前一窗信息的位置;比如在 window = stride 时,前一窗 raw KV 不再保留,主要只能依赖 beacon memory,此时更能检验beacon的作用。
· 思考:
真正“依赖 beacon memory”的 token 在整条序列中只占一部分,但训练 loss 是对所有有效 token 统一平均,大量主要靠局部 raw context 就能预测好的 token,会主导整体 loss,这样整体 loss 可能更多反映的是“普通 LM token 的平均预测难度”,而不是 beacon memory 是否真的学好了。这是否也是我在预训练时看到 loss 从 2.13 到 2.04 左右、下降不明显的原因之一?
· 请教几个问题:
1、我的这个理解是否正确?如果确实是预期设计,想请教一下这样设计的原因
2、预训练阶段 loss 从大约 2.13 训练到 2.0 左右、整体下降不明显,这种现象是否正常?
3、其实2.04左右的 Loss 也接近原始模型(Qwen2-7B-Instruct)在训练集上的效果,这是否侧面印证了训练Loss被大量依靠局部 raw context 预测的 token稀释?
4、作者是否尝试过更聚焦 beacon 作用的位置指标或目标,例如:
单独统计跨窗口位置的 loss
对更依赖 beacon memory 的 token 加权
单独汇报一项 beacon-related loss 作为监控指标
期待您的回复!
附上预训练脚本
output_name=beacon-qwen2-pretrain torchrun --nproc_per_node 8 $DDP -m main.train \ --output_dir data/outputs/$output_name \ --model_name_or_path /dataset/common/tzh/model/Qwen2-7B-Instruct \ --train_data /dataset/common/tzh/dataset/long-llm/redpajama/train.json \ --min_length 2400 \ --max_length 20000 \ --group_by_stride strict \ --enable_beacon \ --beacon_window 2048 \ --beacon_stride 2048 \ --beacon_attn full-coverage \ --beacon_attend_prev True \ --beacon_sink_size 0 \ --beacon_ratio 2 4 8 16 32 \ --beacon_ratio_mix step-random \ --beacon_param q k v \ --beacon_pos interleave \ --attn_impl flash_attention_2 \ --gradient_checkpointing \ --use_reentrant False \ --save_only_model \ --save_strategy epoch \ --evaluation_strategy steps \ --num_train_epochs 1 \ --logging_steps 50 \ --bf16 \ --deepspeed data/deepspeed/stage2.json
附上部分预训练输出:
{'loss': 2.1343, 'grad_norm': 0.2808859348297119, 'learning_rate': 4.979447262550289e-05, 'epoch': 0.0}
{'loss': 2.1084, 'grad_norm': 0.226092129945755, 'learning_rate': 4.957582648242085e-05, 'epoch': 0.01}
{'loss': 2.1167, 'grad_norm': 0.1432613581418991, 'learning_rate': 4.9357180339338815e-05, 'epoch': 0.01}
{'loss': 1.9753, 'grad_norm': 0.16457320749759674, 'learning_rate': 4.913853419625678e-05, 'epoch': 0.02}
{'loss': 2.0601, 'grad_norm': 0.12638713419437408, 'learning_rate': 4.891988805317474e-05, 'epoch': 0.02}
{'loss': 2.0375, 'grad_norm': 0.14955464005470276, 'learning_rate': 4.8701241910092706e-05, 'epoch': 0.03}
{'loss': 2.0465, 'grad_norm': 0.14118915796279907, 'learning_rate': 4.848259576701068e-05, 'epoch': 0.03}
{'loss': 2.0381, 'grad_norm': 0.1377200037240982, 'learning_rate': 4.8263949623928634e-05, 'epoch': 0.03}
{'loss': 2.024, 'grad_norm': 0.10688850283622742, 'learning_rate': 4.80453034808466e-05, 'epoch': 0.04}
{'loss': 2.0387, 'grad_norm': 0.11750514060258865, 'learning_rate': 4.782665733776456e-05, 'epoch': 0.04}
{'loss': 2.0662, 'grad_norm': 0.13407647609710693, 'learning_rate': 4.7608011194682526e-05, 'epoch': 0.05}
{'loss': 2.0861, 'grad_norm': 0.14363867044448853, 'learning_rate': 4.738936505160049e-05, 'epoch': 0.05}
{'loss': 2.0667, 'grad_norm': 0.1525307148694992, 'learning_rate': 4.717071890851846e-05, 'epoch': 0.06}
{'loss': 2.0437, 'grad_norm': 0.1380089372396469, 'learning_rate': 4.6952072765436424e-05, 'epoch': 0.06}
{'loss': 2.0895, 'grad_norm': 0.1160527840256691, 'learning_rate': 4.673342662235438e-05, 'epoch': 0.07}
{'loss': 1.9908, 'grad_norm': 0.12899190187454224, 'learning_rate': 4.6514780479272345e-05, 'epoch': 0.07}
{'loss': 2.0517, 'grad_norm': 0.12548774480819702, 'learning_rate': 4.629613433619031e-05, 'epoch': 0.07}
......
{'loss': 2.0505, 'grad_norm': 0.11952044069766998, 'learning_rate': 2.129613433619031e-06, 'epoch': 0.96}
{'loss': 1.9744, 'grad_norm': 0.11724891513586044, 'learning_rate': 1.9109672905369953e-06, 'epoch': 0.96}
{'loss': 2.052, 'grad_norm': 0.11619514226913452, 'learning_rate': 1.6923211474549588e-06, 'epoch': 0.97}
{'loss': 2.0327, 'grad_norm': 0.13094069063663483, 'learning_rate': 1.473675004372923e-06, 'epoch': 0.97}
{'loss': 2.056, 'grad_norm': 0.13238118588924408, 'learning_rate': 1.2550288612908868e-06, 'epoch': 0.97}
{'loss': 2.0527, 'grad_norm': 0.14477798342704773, 'learning_rate': 1.036382718208851e-06, 'epoch': 0.98}
{'loss': 1.9486, 'grad_norm': 0.17454515397548676, 'learning_rate': 8.177365751268147e-07, 'epoch': 0.98}
{'loss': 2.0118, 'grad_norm': 0.1428622007369995, 'learning_rate': 5.990904320447787e-07, 'epoch': 0.99}
{'loss': 2.0553, 'grad_norm': 0.15292023122310638, 'learning_rate': 3.8044428896274276e-07, 'epoch': 0.99}
{'loss': 2.0557, 'grad_norm': 0.14297017455101013, 'learning_rate': 1.6179814588070666e-07, 'epoch': 1.0}
{'train_runtime': 27943.4629, 'train_samples_per_second': 3.275, 'train_steps_per_second': 0.409, 'train_loss': 2.041090958266043, 'epoch': 1.0}